Commit 659976a2 authored by Findit's avatar Findit

Revert "[Dolphin] Refactor in preparation for query-based predictions."

This reverts commit 6d39702e.

Reason for revert:

Findit (https://goo.gl/kROfz5) identified CL at revision 630682 as the
culprit for failures in the build cycles as shown on:
https://findit-for-me.appspot.com/waterfall/culprit?key=ag9zfmZpbmRpdC1mb3ItbWVyRAsSDVdmU3VzcGVjdGVkQ0wiMWNocm9taXVtLzZkMzk3MDJlNzQ4YWUzMzg0MmU1MzlhZmNmNTVlYjI2MjM1ODJkODAM

Sample Failed Build: https://ci.chromium.org/buildbot/chromium.chromiumos/linux-chromeos-rel/19798

Sample Failed Step: non_single_process_mash_unit_tests

Original change's description:
> [Dolphin] Refactor in preparation for query-based predictions.
> 
>  - Extra methods added to RecurrenceRanker and RecurrencePredictor to accept queries.
> 
>  - Storage of targets has been moved from the RecurrenceRanker to individual
>    RecurrencePredictors. As a result, RecurrencePredictor must now handle targets
>    renames and removes themselves. Methods for this have been added to their API.
> 
>    This is necessary because query-based predictions will require a different data
>    structure to store target + query than zero-state predictions, which only require
>    a target. In order to keep the RecurrenceRanker usable for both tasks, choice of
>    data structure needs to moved into the RecurrencePredictors, where individual
>    predictors can use what's suitable.
> 
>  - The original FrecencyPredictor has been renamed ZeroStateFrecencyPredictor.
> 
>  - Some extra tests added.
> 
> Bug: 921444
> Change-Id: I95a6dca135928726c779261ec6b1663b1023bf1f
> Reviewed-on: https://chromium-review.googlesource.com/c/1459856
> Reviewed-by: Jia Meng <jiameng@chromium.org>
> Commit-Queue: Tony Yeoman <tby@chromium.org>
> Cr-Commit-Position: refs/heads/master@{#630682}

Change-Id: I9ac4e9efc7abf31d64e30cd9e3f5a1fa55bd9778
No-Presubmit: true
No-Tree-Checks: true
No-Try: true
Bug: 921444
Reviewed-on: https://chromium-review.googlesource.com/c/1462426
Cr-Commit-Position: refs/heads/master@{#630687}
parent f56337f3
......@@ -217,6 +217,8 @@ DEFINE_EQUIVTO_PROTO_LITE_1(FakeAppLaunchPredictorProto, rank_result);
DEFINE_EQUIVTO_PROTO_LITE_1(FakePredictorProto, counts);
DEFINE_EQUIVTO_PROTO_LITE_DEFAULT(FrecencyPredictorProto);
DEFINE_EQUIVTO_PROTO_LITE_5(FrecencyStoreProto,
values,
value_limit,
......@@ -238,9 +240,12 @@ DEFINE_EQUIVTO_PROTO_LITE_2(HourAppLaunchPredictorProto_FrequencyTable,
DEFINE_EQUIVTO_PROTO_LITE_2(RecurrencePredictorProto,
fake_predictor,
zero_state_frecency_predictor);
frecency_predictor);
DEFINE_EQUIVTO_PROTO_LITE_2(RecurrenceRankerProto, config_hash, predictor);
DEFINE_EQUIVTO_PROTO_LITE_3(RecurrenceRankerProto,
config_hash,
predictor,
targets);
DEFINE_EQUIVTO_PROTO_LITE_2(SerializedMrfuAppLaunchPredictorProto,
num_of_trains,
......@@ -249,9 +254,6 @@ DEFINE_EQUIVTO_PROTO_LITE_2(SerializedMrfuAppLaunchPredictorProto,
DEFINE_EQUIVTO_PROTO_LITE_2(SerializedMrfuAppLaunchPredictorProto_Score,
num_of_trains_at_last_update,
last_score);
DEFINE_EQUIVTO_PROTO_LITE_1(ZeroStateFrecencyPredictorProto, targets);
} // namespace internal
template <typename Proto>
......
......@@ -22,13 +22,7 @@ static constexpr float kMinScoreBeforeDelete =
} // namespace
FrecencyStore::FrecencyStore(int value_limit, float decay_coeff)
: value_limit_(value_limit), decay_coeff_(decay_coeff) {
if (decay_coeff <= 0.0f || decay_coeff >= 1.0f) {
LOG(ERROR) << "FrecencyStore decay_coeff has invalid value: " << decay_coeff
<< ", resetting to default.";
decay_coeff_ = 0.75f;
}
}
: value_limit_(value_limit), decay_coeff_(decay_coeff) {}
FrecencyStore::~FrecencyStore() {}
......
......@@ -7,42 +7,24 @@
#include <cmath>
#include "base/logging.h"
#include "chrome/browser/ui/app_list/search/search_result_ranker/frecency_store.h"
#include "chrome/browser/ui/app_list/search/search_result_ranker/frecency_store.pb.h"
#include "chrome/browser/ui/app_list/search/search_result_ranker/recurrence_predictor.pb.h"
namespace app_list {
FakePredictor::FakePredictor(FakePredictorConfig config) {}
FakePredictor::FakePredictor() = default;
FakePredictor::~FakePredictor() = default;
const char FakePredictor::kPredictorName[] = "FakePredictor";
const char* FakePredictor::GetPredictorName() const {
return kPredictorName;
}
void FakePredictor::Train(const std::string& target, const std::string& query) {
void FakePredictor::Train(unsigned int target) {
counts_[target] += 1.0f;
}
base::flat_map<std::string, float> FakePredictor::Rank(
const std::string& query) {
base::flat_map<unsigned int, float> FakePredictor::Rank() {
return counts_;
}
void FakePredictor::Rename(const std::string& target,
const std::string& new_target) {
auto it = counts_.find(target);
if (it != counts_.end()) {
counts_[new_target] = it->second;
counts_.erase(it);
}
}
void FakePredictor::Remove(const std::string& target) {
auto it = counts_.find(target);
if (it != counts_.end())
counts_.erase(it);
const char FakePredictor::kPredictorName[] = "FakePredictor";
const char* FakePredictor::GetPredictorName() const {
return kPredictorName;
}
void FakePredictor::ToProto(RecurrencePredictorProto* proto) const {
......@@ -55,72 +37,36 @@ void FakePredictor::FromProto(const RecurrencePredictorProto& proto) {
if (!proto.has_fake_predictor())
return;
auto predictor = proto.fake_predictor();
for (const auto& pair : predictor.counts())
auto fake_predictor_proto = proto.fake_predictor();
for (const auto& pair : fake_predictor_proto.counts()) {
counts_[pair.first] = pair.second;
}
ZeroStateFrecencyPredictor::ZeroStateFrecencyPredictor(
ZeroStateFrecencyPredictorConfig config)
: targets_(std::make_unique<FrecencyStore>(config.target_limit(),
config.decay_coeff())) {}
ZeroStateFrecencyPredictor::~ZeroStateFrecencyPredictor() = default;
const char ZeroStateFrecencyPredictor::kPredictorName[] =
"ZeroStateFrecencyPredictor";
const char* ZeroStateFrecencyPredictor::GetPredictorName() const {
return kPredictorName;
}
void ZeroStateFrecencyPredictor::Train(const std::string& target,
const std::string& query) {
if (!query.empty()) {
NOTREACHED();
LOG(ERROR) << "ZeroStateFrecencyPredictor was passed a query.";
return;
}
targets_->Update(target);
}
base::flat_map<std::string, float> ZeroStateFrecencyPredictor::Rank(
const std::string& query) {
if (!query.empty()) {
NOTREACHED();
LOG(ERROR) << "ZeroStateFrecencyPredictor was passed a query.";
return {};
}
base::flat_map<std::string, float> ranks;
for (const auto& target : targets_->GetAll())
ranks[target.first] = target.second.last_score;
return ranks;
}
void ZeroStateFrecencyPredictor::Rename(const std::string& target,
const std::string& new_target) {
targets_->Rename(target, new_target);
}
void ZeroStateFrecencyPredictor::Remove(const std::string& target) {
targets_->Remove(target);
// |FrecencyPredictor| uses the |RecurrenceRanker|'s store of targets for
// ranking, which is updated on calls to |RecurrenceRanker::Record|. So, there
// is no training work for the predictor itself to do.
void FrecencyPredictor::Train(unsigned int target) {}
// |FrecencyPredictor::Rank| is special-cased inside |RecurrenceRanker::Record|
// for efficiency reasons, so as not to do uneccessary conversion of targets to
// ids and back.
base::flat_map<unsigned int, float> FrecencyPredictor::Rank() {
NOTREACHED();
return {};
}
void ZeroStateFrecencyPredictor::ToProto(
RecurrencePredictorProto* proto) const {
auto* targets =
proto->mutable_zero_state_frecency_predictor()->mutable_targets();
targets_->ToProto(targets);
const char FrecencyPredictor::kPredictorName[] = "FrecencyPredictor";
const char* FrecencyPredictor::GetPredictorName() const {
return kPredictorName;
}
void ZeroStateFrecencyPredictor::FromProto(
const RecurrencePredictorProto& proto) {
if (!proto.has_zero_state_frecency_predictor())
return;
// Empty as all data used by the frecency predictor is serialised with
// |RecurrenceRanker|.
void FrecencyPredictor::ToProto(RecurrencePredictorProto* proto) const {}
const auto& predictor = proto.zero_state_frecency_predictor();
if (predictor.has_targets())
targets_->FromProto(predictor.targets());
}
// Empty as all data used by the frecency predictor is serialised with
// |RecurrenceRanker|.
void FrecencyPredictor::FromProto(const RecurrencePredictorProto& proto) {}
} // namespace app_list
......@@ -5,23 +5,15 @@
#ifndef CHROME_BROWSER_UI_APP_LIST_SEARCH_SEARCH_RESULT_RANKER_RECURRENCE_PREDICTOR_H_
#define CHROME_BROWSER_UI_APP_LIST_SEARCH_SEARCH_RESULT_RANKER_RECURRENCE_PREDICTOR_H_
#include <memory>
#include <string>
#include <vector>
#include "base/containers/flat_map.h"
#include "base/macros.h"
#include "chrome/browser/ui/app_list/search/search_result_ranker/recurrence_predictor.pb.h"
#include "chrome/browser/ui/app_list/search/search_result_ranker/recurrence_ranker_config.pb.h"
namespace app_list {
using FakePredictorConfig = RecurrenceRankerConfigProto::FakePredictorConfig;
using ZeroStateFrecencyPredictorConfig =
RecurrenceRankerConfigProto::ZeroStateFrecencyPredictorConfig;
class FrecencyStore;
// |RecurrencePredictor| is the interface for all predictors used by
// |RecurrenceRanker| to drive rankings. If a predictor has some form of
// serialisation, it should have a corresponding proto in
......@@ -30,21 +22,13 @@ class RecurrencePredictor {
public:
virtual ~RecurrencePredictor() = default;
// Train the predictor on an occurrence of |target| coinciding with |query|.
// The predictor will collect its own contextual information, eg. time of day,
// as part of training. Zero-state scenarios should use an empty string for
// |query|.
virtual void Train(const std::string& target, const std::string& query) = 0;
// Return a map of all known targets to their scores for the given query
// under this predictor. Scores must be within the range [0,1]. Zero-state
// scenarios should use an empty string for |query|.
virtual base::flat_map<std::string, float> Rank(const std::string& query) = 0;
// Rename a target, while keeping learned information on it.
virtual void Rename(const std::string& target,
const std::string& new_target) = 0;
// Remove a target entirely.
virtual void Remove(const std::string& target) = 0;
// Train the predictor on an occurrence of |target|. The predictor will
// collect its own contextual information, eg. time of day, as part of
// training.
virtual void Train(unsigned int target) = 0;
// Return a map of all known targets to their scores under this predictor.
// Scores must be within the range [0,1].
virtual base::flat_map<unsigned int, float> Rank() = 0;
virtual void ToProto(RecurrencePredictorProto* proto) const = 0;
virtual void FromProto(const RecurrencePredictorProto& proto) = 0;
......@@ -52,57 +36,51 @@ class RecurrencePredictor {
};
// |FakePredictor| is a simple 'predictor' used for testing. |Rank| returns the
// numbers of times each target has been trained on, and ignores the query
// altogether.
// numbers of times each target has been trained on.
//
// WARNING: this breaks the guarantees on the range of values a score can take,
// so should not be used for anything except testing.
class FakePredictor : public RecurrencePredictor {
public:
explicit FakePredictor(FakePredictorConfig config);
FakePredictor();
~FakePredictor() override;
// RecurrencePredictor:
void Train(const std::string& target, const std::string& query) override;
base::flat_map<std::string, float> Rank(const std::string& query) override;
void Rename(const std::string& target,
const std::string& new_target) override;
void Remove(const std::string& target) override;
// RecurrencePredictor overrides:
void Train(unsigned int target) override;
base::flat_map<unsigned int, float> Rank() override;
const char* GetPredictorName() const override;
void ToProto(RecurrencePredictorProto* proto) const override;
void FromProto(const RecurrencePredictorProto& proto) override;
const char* GetPredictorName() const override;
static const char kPredictorName[];
private:
base::flat_map<std::string, float> counts_;
base::flat_map<unsigned int, float> counts_;
DISALLOW_COPY_AND_ASSIGN(FakePredictor);
};
// |ZeroStateFrecencyPredictor| ranks targets according to their frecency, and
// can only be used for zero-state predictions, that is, an empty query string.
class ZeroStateFrecencyPredictor : public RecurrencePredictor {
// |FrecencyPredictor| simply returns targets in frecency order. To do this, it
// piggybacks off the existing targets FrecencyStore used in |RecurrenceRanker|.
// For efficiency reasons the ranker itself has a special case that handles the
// logic of |FrecencyPredictor::Rank|. This leaves |FrecencyPredictor| as a
// mostly empty class.
class FrecencyPredictor : public RecurrencePredictor {
public:
explicit ZeroStateFrecencyPredictor(ZeroStateFrecencyPredictorConfig config);
~ZeroStateFrecencyPredictor() override;
// RecurrencePredictor:
void Train(const std::string& target, const std::string& query) override;
base::flat_map<std::string, float> Rank(const std::string& query) override;
void Rename(const std::string& target,
const std::string& new_target) override;
void Remove(const std::string& target) override;
FrecencyPredictor() = default;
~FrecencyPredictor() override = default;
// RecurrencePredictor overrides:
void Train(unsigned int target) override;
base::flat_map<unsigned int, float> Rank() override;
const char* GetPredictorName() const override;
void ToProto(RecurrencePredictorProto* proto) const override;
void FromProto(const RecurrencePredictorProto& proto) override;
const char* GetPredictorName() const override;
static const char kPredictorName[];
private:
std::unique_ptr<FrecencyStore> targets_;
DISALLOW_COPY_AND_ASSIGN(ZeroStateFrecencyPredictor);
DISALLOW_COPY_AND_ASSIGN(FrecencyPredictor);
};
} // namespace app_list
......
......@@ -4,27 +4,23 @@
syntax = "proto2";
import "frecency_store.proto";
option optimize_for = LITE_RUNTIME;
package app_list;
// Fake predictor used for testing.
message FakePredictorProto {
// Maps targets to their score.
map<string, float> counts = 1;
map<uint32, float> counts = 1;
}
// Zero-state frecency predictor.
message ZeroStateFrecencyPredictorProto {
optional FrecencyStoreProto targets = 1;
}
// Frecency predictor. Uses the targets stored in the ranker, so serialises
// nothing of its own.
message FrecencyPredictorProto {}
// Represents the serialisation of one particular predictor.
message RecurrencePredictorProto {
oneof predictor {
FakePredictorProto fake_predictor = 1;
ZeroStateFrecencyPredictorProto zero_state_frecency_predictor = 2;
FrecencyPredictorProto frecency_predictor = 2;
}
}
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/browser/ui/app_list/search/search_result_ranker/recurrence_predictor.h"
#include <memory>
#include <vector>
#include "ash/public/cpp/app_list/app_list_features.h"
#include "base/files/scoped_temp_dir.h"
#include "base/hash.h"
#include "base/test/scoped_feature_list.h"
#include "base/test/scoped_task_environment.h"
#include "chrome/browser/ui/app_list/search/search_result_ranker/app_launch_predictor_test_util.h"
#include "chrome/browser/ui/app_list/search/search_result_ranker/frecency_store.h"
#include "chrome/browser/ui/app_list/search/search_result_ranker/recurrence_ranker_config.pb.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
using testing::_;
using testing::ElementsAre;
using testing::FloatEq;
using testing::IsSupersetOf;
using testing::NiceMock;
using testing::Pair;
using testing::Return;
using testing::StrEq;
using testing::UnorderedElementsAre;
namespace app_list {
class ZeroStateFrecencyPredictorTest : public testing::Test {
protected:
void SetUp() override {
Test::SetUp();
config_.set_target_limit(100u);
config_.set_decay_coeff(0.5f);
predictor_ = std::make_unique<ZeroStateFrecencyPredictor>(config_);
}
ZeroStateFrecencyPredictorConfig config_;
std::unique_ptr<ZeroStateFrecencyPredictor> predictor_;
};
TEST_F(ZeroStateFrecencyPredictorTest, RankWithNoTargets) {
EXPECT_TRUE(predictor_->Rank("").empty());
}
TEST_F(ZeroStateFrecencyPredictorTest, RecordAndRank) {
predictor_->Train("A", "");
predictor_->Train("B", "");
predictor_->Train("C", "");
EXPECT_THAT(predictor_->Rank(""),
UnorderedElementsAre(Pair("A", FloatEq(0.125f)),
Pair("B", FloatEq(0.25f)),
Pair("C", FloatEq(0.5f))));
}
TEST_F(ZeroStateFrecencyPredictorTest, Rename) {
predictor_->Train("A", "");
predictor_->Train("B", "");
predictor_->Train("B", "");
predictor_->Rename("B", "A");
EXPECT_THAT(predictor_->Rank(""),
UnorderedElementsAre(Pair("A", FloatEq(0.75f))));
}
TEST_F(ZeroStateFrecencyPredictorTest, RenameNonexistentTarget) {
predictor_->Train("A", "");
predictor_->Rename("B", "C");
EXPECT_THAT(predictor_->Rank(""),
UnorderedElementsAre(Pair("A", FloatEq(0.5f))));
}
TEST_F(ZeroStateFrecencyPredictorTest, Remove) {
predictor_->Train("A", "");
predictor_->Train("B", "");
predictor_->Remove("B");
EXPECT_THAT(predictor_->Rank(""),
UnorderedElementsAre(Pair("A", FloatEq(0.25f))));
}
TEST_F(ZeroStateFrecencyPredictorTest, RemoveNonexistentTarget) {
predictor_->Train("A", "");
predictor_->Remove("B");
EXPECT_THAT(predictor_->Rank(""),
UnorderedElementsAre(Pair("A", FloatEq(0.5f))));
}
TEST_F(ZeroStateFrecencyPredictorTest, TargetLimitExceeded) {
ZeroStateFrecencyPredictorConfig config;
config.set_target_limit(5u);
config.set_decay_coeff(0.9999f);
ZeroStateFrecencyPredictor predictor(config);
// Insert many more targets than the target limit.
for (int i = 0; i < 50; i++) {
predictor.Train(std::to_string(i), "");
}
// Check that some of the values have been deleted, and the most recent ones
// remain. We check loose bounds on these requirements, to prevent the test
// from being tied to implementation details of the |FrecencyStore| cleanup
// logic. See |FrecencyStoreTest::CleanupOnOverflow| for a corresponding, more
// precise test.
auto ranks = predictor.Rank("");
EXPECT_LE(ranks.size(), 10ul);
EXPECT_THAT(
ranks, testing::IsSupersetOf({Pair("45", _), Pair("46", _), Pair("47", _),
Pair("48", _), Pair("49", _)}));
}
TEST_F(ZeroStateFrecencyPredictorTest, ToAndFromProto) {
predictor_->Train("A", "");
predictor_->Train("B", "");
predictor_->Train("C", "");
const auto ranks = predictor_->Rank("");
RecurrencePredictorProto proto;
predictor_->ToProto(&proto);
predictor_->FromProto(proto);
EXPECT_EQ(ranks, predictor_->Rank(""));
}
TEST_F(ZeroStateFrecencyPredictorTest, FailsWithQuery) {
ASSERT_DEATH(predictor_->Rank("query"), "");
}
} // namespace app_list
......@@ -74,12 +74,11 @@ std::unique_ptr<RecurrenceRankerProto> LoadProtoFromDisk(
// Returns a new, configured instance of the predictor defined in |config|.
std::unique_ptr<RecurrencePredictor> MakePredictor(
const RecurrenceRankerConfigProto& config) {
RecurrenceRankerConfigProto config) {
if (config.has_frecency_predictor())
return std::make_unique<FrecencyPredictor>();
if (config.has_fake_predictor())
return std::make_unique<FakePredictor>(config.fake_predictor());
if (config.has_zero_state_frecency_predictor())
return std::make_unique<ZeroStateFrecencyPredictor>(
config.zero_state_frecency_predictor());
return std::make_unique<FakePredictor>();
NOTREACHED();
return nullptr;
......@@ -101,12 +100,13 @@ RecurrenceRanker::RecurrenceRanker(const base::FilePath& filepath,
{base::TaskPriority::BEST_EFFORT, base::MayBlock(),
base::TaskShutdownBehavior::SKIP_ON_SHUTDOWN});
targets_ = std::make_unique<FrecencyStore>(config.target_limit(),
config.target_decay_coeff());
if (is_ephemeral_user_) {
// Ephemeral users have no persistent storage, so we don't try and load the
// proto from disk. Instead, we fall back on using a frecency predictor,
// which is still useful with only data from the current session.
predictor_ = std::make_unique<ZeroStateFrecencyPredictor>(
config.fallback_predictor());
predictor_ = std::make_unique<FrecencyPredictor>();
} else {
predictor_ = MakePredictor(config);
......@@ -132,21 +132,25 @@ void RecurrenceRanker::OnLoadProtoFromDiskComplete(
return;
}
if (proto->has_targets())
targets_->FromProto(proto->targets());
if (proto->has_predictor())
predictor_->FromProto(proto->predictor());
load_from_disk_completed_ = true;
}
void RecurrenceRanker::Record(const std::string& target) {
Record(target, "");
}
void RecurrenceRanker::Record(const std::string& target,
const std::string& query) {
if (!load_from_disk_completed_)
return;
predictor_->Train(target, query);
targets_->Update(target);
// It might be possible that, despite just being updated, the target was
// removed from the store. Only train if the target is still valid.
Optional<unsigned int> id = targets_->GetId(target);
if (id.has_value())
predictor_->Train(id.value());
MaybeSave();
}
......@@ -155,7 +159,7 @@ void RecurrenceRanker::Rename(const std::string& target,
if (!load_from_disk_completed_)
return;
predictor_->Rename(target, new_target);
targets_->Rename(target, new_target);
MaybeSave();
}
......@@ -163,33 +167,44 @@ void RecurrenceRanker::Remove(const std::string& target) {
if (!load_from_disk_completed_)
return;
predictor_->Remove(target);
targets_->Remove(target);
MaybeSave();
}
base::flat_map<std::string, float> RecurrenceRanker::Rank() {
return Rank("");
}
base::flat_map<std::string, float> RecurrenceRanker::Rank(
const std::string& query) {
if (!load_from_disk_completed_)
return {};
return predictor_->Rank(query);
}
// Special case for a frecency predictor. Because this is simply a wrapper
// around the |RecurrenceRanker|'s targets store, we can directly return the
// contents of the store and avoid an uneccessary iteration through targets.
if (predictor_->GetPredictorName() == FrecencyPredictor::kPredictorName) {
base::flat_map<std::string, float> ranks;
for (const auto& pair : targets_->GetAll())
ranks[pair.first] = pair.second.last_score;
return ranks;
}
std::vector<std::pair<std::string, float>> RecurrenceRanker::RankTopN(int n) {
return RankTopN(n, "");
const base::flat_map<unsigned int, float> id_ranks = predictor_->Rank();
const base::flat_map<std::string, FrecencyStore::ValueData>& targets =
targets_->GetAll();
base::flat_map<std::string, float> ranks;
for (const auto& pair : targets) {
const auto& data = pair.second;
const auto it = id_ranks.find(data.id);
if (it == id_ranks.end())
continue;
ranks[pair.first] = it->second;
}
return ranks;
}
std::vector<std::pair<std::string, float>> RecurrenceRanker::RankTopN(
int n,
const std::string& query) {
std::vector<std::pair<std::string, float>> RecurrenceRanker::RankTopN(int n) {
if (!load_from_disk_completed_)
return {};
base::flat_map<std::string, float> ranks = Rank(query);
base::flat_map<std::string, float> ranks = Rank();
std::vector<std::pair<std::string, float>> sorted_ranks(ranks.begin(),
ranks.end());
std::sort(sorted_ranks.begin(), sorted_ranks.end(),
......@@ -221,6 +236,7 @@ void RecurrenceRanker::MaybeSave() {
void RecurrenceRanker::ToProto(RecurrenceRankerProto* proto) {
proto->set_config_hash(config_hash_);
predictor_->ToProto(proto->mutable_predictor());
targets_->ToProto(proto->mutable_targets());
}
void RecurrenceRanker::ForceSaveOnNextUpdateForTesting() {
......
......@@ -21,6 +21,7 @@
namespace app_list {
class FrecencyStore;
class RecurrencePredictor;
class RecurrenceRankerProto;
......@@ -39,11 +40,8 @@ class RecurrenceRanker {
~RecurrenceRanker();
// Record the use of a given target, and train the predictor on it. The one
// argument version is a shortcut for an empty query string, useful for
// zero-state scenarios.
// Record the use of a given target, and train the predictor on it.
void Record(const std::string& target);
void Record(const std::string& target, const std::string& query);
// Rename a target, while keeping learned information on it.
void Rename(const std::string& target, const std::string& new_target);
// Remove a target entirely.
......@@ -52,19 +50,14 @@ class RecurrenceRanker {
// Returns a map of target to score.
// - Higher scores are better.
// - Score are guaranteed to be in the range [0,1].
// The zero-argument version is a shortcut for an empty query string.
base::flat_map<std::string, float> Rank();
base::flat_map<std::string, float> Rank(const std::string& query);
// Returns a sorted vector of <target, score> pairs.
// - Higher scores are better.
// - Score are guaranteed to be in the range [0,1].
// - Pairs are sorted in descending order of score.
// - At most n results will be returned.
// The one-argument version is a shortcut for an empty query string.
std::vector<std::pair<std::string, float>> RankTopN(int n);
std::vector<std::pair<std::string, float>> RankTopN(int n,
const std::string& query);
private:
FRIEND_TEST_ALL_PREFIXES(RecurrenceRankerTest,
......@@ -89,6 +82,8 @@ class RecurrenceRanker {
// Internal predictor that drives ranking.
std::unique_ptr<RecurrencePredictor> predictor_;
// A store of all possible targets to return.
std::unique_ptr<FrecencyStore> targets_;
// Where to save the ranker.
const base::FilePath proto_filepath_;
......
......@@ -4,6 +4,7 @@
syntax = "proto2";
import "frecency_store.proto";
import "recurrence_predictor.proto";
option optimize_for = LITE_RUNTIME;
......@@ -17,4 +18,6 @@ message RecurrenceRankerProto {
optional uint32 config_hash = 1;
// Serialisation of the predictor used by the ranker.
optional RecurrencePredictorProto predictor = 2;
// Serialisation of the store of targets.
optional FrecencyStoreProto targets = 3;
}
......@@ -13,33 +13,23 @@ package app_list;
// Warning: this cannot contain any map fields, as they cannot be relied upon
// for a consistent hash.
message RecurrenceRankerConfigProto {
// Fields with IDs 1 (target_limit) and 2 (target_decay_coeff) have been
// deleted.
reserved 1;
reserved 2;
// The soft-maximum number of targets that are stored within the ranker.
required uint32 target_limit = 1;
// The frecency parameter used to control the frequency-recency tradeoff that
// determines when targets are removed. Must be in [0.5, 1.0], with 0.5
// meaning only-recency and 1.0 meaning only-frequency.
required float target_decay_coeff = 2;
required uint32 min_seconds_between_saves = 3;
// Config for a fake predictor, used for testing.
message FakePredictorConfig {}
// Config for a frecency predictor.
message ZeroStateFrecencyPredictorConfig {
// The soft-maximum number of targets that are stored within the predictor.
required uint32 target_limit = 201;
// The frecency parameter used to control the frequency-recency tradeoff
// that determines when targets are removed. Must be in [0.5, 1.0], with 0.5
// meaning only-recency and 1.0 meaning only-frequency.
required float decay_coeff = 202;
}
message FrecencyPredictorConfig {}
// The choice of which kind of predictor to use, and its configuration.
oneof predictor_config {
FakePredictorConfig fake_predictor = 10001;
ZeroStateFrecencyPredictorConfig zero_state_frecency_predictor = 10002;
FakePredictorConfig fake_predictor = 4;
FrecencyPredictorConfig frecency_predictor = 5;
}
// Configuration for a frecency predictor used as a fallback if the user is
// ephemeral.
required ZeroStateFrecencyPredictorConfig fallback_predictor = 11000;
}
......@@ -36,14 +36,13 @@ class RecurrenceRankerTest : public testing::Test {
ASSERT_TRUE(temp_dir_.CreateUniqueTempDir());
ranker_filepath_ = temp_dir_.GetPath().AppendASCII("recurrence_ranker");
auto* fallback = config_.mutable_fallback_predictor();
fallback->set_target_limit(0u);
fallback->set_decay_coeff(0.0f);
config_.set_target_limit(1000u);
config_.set_target_decay_coeff(0.75f);
config_.set_min_seconds_between_saves(120);
// Even if empty, the setting of the oneof |predictor_config| in
// |RecurrenceRankerConfigProto| is used to determine which predictor is
// constructed.
config_.mutable_fake_predictor();
config_.set_min_seconds_between_saves(5);
ranker_ =
std::make_unique<RecurrenceRanker>(ranker_filepath_, config_, false);
......@@ -58,9 +57,17 @@ class RecurrenceRankerTest : public testing::Test {
auto* counts =
proto.mutable_predictor()->mutable_fake_predictor()->mutable_counts();
(*counts)["A"] = 1.0f;
(*counts)["B"] = 2.0f;
(*counts)["C"] = 1.0f;
(*counts)[0u] = 1.0f;
(*counts)[1u] = 2.0f;
(*counts)[2u] = 1.0f;
FrecencyStore store(1000, 0.75f);
store.Update("A");
store.Update("B");
store.Update("B");
store.Update("C");
FrecencyStoreProto targets_proto;
store.ToProto(proto.mutable_targets());
return proto;
}
......@@ -74,7 +81,7 @@ class RecurrenceRankerTest : public testing::Test {
base::FilePath ranker_filepath_;
};
TEST_F(RecurrenceRankerTest, Record) {
TEST_F(RecurrenceRankerTest, CheckRecord) {
ranker_->Record("A");
ranker_->Record("B");
ranker_->Record("B");
......@@ -83,7 +90,7 @@ TEST_F(RecurrenceRankerTest, Record) {
Pair("B", FloatEq(2.0f))));
}
TEST_F(RecurrenceRankerTest, Rename) {
TEST_F(RecurrenceRankerTest, CheckRename) {
ranker_->Record("A");
ranker_->Record("B");
ranker_->Record("B");
......@@ -92,7 +99,7 @@ TEST_F(RecurrenceRankerTest, Rename) {
EXPECT_THAT(ranker_->Rank(), ElementsAre(Pair("A", FloatEq(2.0f))));
}
TEST_F(RecurrenceRankerTest, Remove) {
TEST_F(RecurrenceRankerTest, CheckRemove) {
ranker_->Record("A");
ranker_->Record("B");
ranker_->Record("B");
......@@ -206,14 +213,10 @@ TEST_F(RecurrenceRankerTest, SavedRankerRejectedIfConfigMismatched) {
// Construct a second ranker with a slightly different config.
RecurrenceRankerConfigProto other_config;
auto* fallback = other_config.mutable_fallback_predictor();
fallback->set_target_limit(0u);
fallback->set_decay_coeff(0.0f);
other_config.set_target_limit(1000u);
other_config.set_target_decay_coeff(0.76f);
other_config.set_min_seconds_between_saves(120);
other_config.mutable_fake_predictor();
// This is different.
other_config.set_min_seconds_between_saves(
config_.min_seconds_between_saves() + 1);
RecurrenceRanker other_ranker(ranker_filepath_, other_config, false);
Wait();
......@@ -226,36 +229,31 @@ TEST_F(RecurrenceRankerTest, SavedRankerRejectedIfConfigMismatched) {
}
TEST_F(RecurrenceRankerTest, EphemeralUsersUseFrecencyPredictor) {
RecurrenceRanker ephemeral_ranker(ranker_filepath_, config_, true);
auto ephemeral_ranker =
std::make_unique<RecurrenceRanker>(ranker_filepath_, config_, true);
Wait();
EXPECT_THAT(ephemeral_ranker.GetPredictorNameForTesting(),
StrEq(ZeroStateFrecencyPredictor::kPredictorName));
EXPECT_THAT(ephemeral_ranker->GetPredictorNameForTesting(),
StrEq(FrecencyPredictor().GetPredictorName()));
}
TEST_F(RecurrenceRankerTest, IntegrationWithZeroStateFrecencyPredictor) {
TEST_F(RecurrenceRankerTest, CheckFrecencyPredictor) {
RecurrenceRankerConfigProto config;
config.set_min_seconds_between_saves(5);
auto* predictor = config.mutable_zero_state_frecency_predictor();
predictor->set_target_limit(100u);
predictor->set_decay_coeff(0.5f);
auto* fallback = config.mutable_fallback_predictor();
fallback->set_target_limit(100u);
fallback->set_decay_coeff(0.5f);
RecurrenceRanker ranker(ranker_filepath_, config, false);
config.set_target_limit(1000u);
config.set_target_decay_coeff(0.5f);
config.set_min_seconds_between_saves(120);
config.mutable_frecency_predictor();
RecurrenceRanker frecency_ranker(ranker_filepath_, config, false);
Wait();
ranker.Record("A");
ranker.Record("D");
ranker.Record("C");
ranker.Record("E");
ranker.Rename("D", "B");
ranker.Remove("E");
ranker.Rename("E", "A");
EXPECT_THAT(ranker.Rank(), UnorderedElementsAre(Pair("A", FloatEq(0.0625f)),
Pair("B", FloatEq(0.125f)),
Pair("C", FloatEq(0.25f))));
frecency_ranker.Record("A");
frecency_ranker.Record("B");
frecency_ranker.Record("C");
EXPECT_THAT(frecency_ranker.Rank(),
UnorderedElementsAre(Pair("A", FloatEq(0.125f)),
Pair("B", FloatEq(0.25f)),
Pair("C", FloatEq(0.5f))));
}
} // namespace app_list
......@@ -4632,7 +4632,6 @@ test("unit_tests") {
"../browser/ui/app_list/search/search_result_ranker/app_launch_predictor_unittest.cc",
"../browser/ui/app_list/search/search_result_ranker/app_search_result_ranker_unittest.cc",
"../browser/ui/app_list/search/search_result_ranker/frecency_store_unittest.cc",
"../browser/ui/app_list/search/search_result_ranker/recurrence_predictor_unittest.cc",
"../browser/ui/app_list/search/search_result_ranker/recurrence_ranker_unittest.cc",
"../browser/ui/app_list/search/settings_shortcut/settings_shortcut_provider_unittest.cc",
"../browser/ui/app_list/search/settings_shortcut/settings_shortcut_result_unittest.cc",
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment