Commit 92647523 authored by tby's avatar tby Committed by Commit Bot

[Dolphin] Refactor to make predictors ID-based rather than string-based.

This CL makes modifications to the architecture of the RecurrenceRanker
and RecurrencePredictor to:

 1. Allow for better API support for targets and conditions.
 2. Store all target and condition strings in the Ranker, and only pass
    IDs to the predictors.
 3. Modify the ZeroStateFrecencyPredictor to work with IDs instead of
    strings.

This will have several follow-up CLs:

 1. UMA logging for various error cases, eg. proto serialisation
    errors.
 2. Cleanup logic for the ZeroStateFrecencyPredictor.
 3. A general conditional predictor for use with query-based
    predictions.

Change-Id: I86e38c28738adf7d5115d1dcd3b5953fab739ebf
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/1534583Reviewed-by: default avatarCharles . <charleszhao@chromium.org>
Commit-Queue: Charles . <charleszhao@chromium.org>
Cr-Commit-Position: refs/heads/master@{#659268}
parent e95b9a30
......@@ -250,7 +250,13 @@ DEFINE_EQUIVTO_PROTO_LITE_2(SerializedMrfuAppLaunchPredictorProto_Score,
num_of_trains_at_last_update,
last_score);
DEFINE_EQUIVTO_PROTO_LITE_1(ZeroStateFrecencyPredictorProto, targets);
DEFINE_EQUIVTO_PROTO_LITE_2(ZeroStateFrecencyPredictorProto,
targets,
num_updates);
DEFINE_EQUIVTO_PROTO_LITE_2(ZeroStateFrecencyPredictorProto_TargetData,
last_score,
last_num_updates);
} // namespace internal
......
......@@ -4,7 +4,10 @@
#include "chrome/browser/ui/app_list/search/search_result_ranker/frecency_store.h"
#include <algorithm>
#include <cmath>
#include <limits>
#include <utility>
#include "base/logging.h"
#include "base/stl_util.h"
......@@ -32,19 +35,26 @@ FrecencyStore::FrecencyStore(int value_limit, float decay_coeff)
FrecencyStore::~FrecencyStore() {}
void FrecencyStore::Update(const std::string& value) {
if (!values_.contains(value)) {
values_[value] = {next_id_, 0.0f, 0u};
++next_id_;
unsigned int FrecencyStore::Update(const std::string& value) {
if (static_cast<unsigned int>(values_.size()) >= 2 * value_limit_)
Cleanup();
ValueData* data;
auto it = values_.find(value);
if (it != values_.end()) {
data = &it->second;
} else {
auto ret = values_.insert(std::pair<std::string, FrecencyStore::ValueData>(
value, {next_id_++, 0.0f, 0u}));
DCHECK(ret.second);
data = &ret.first->second;
}
++num_updates_;
ValueData& data = values_[value];
DecayScore(&data);
data.last_score += 1.0f - decay_coeff_;
DecayScore(data);
data->last_score += 1.0f - decay_coeff_;
if (static_cast<unsigned int>(values_.size()) >= 2 * value_limit_)
Cleanup();
return data->id;
}
void FrecencyStore::Rename(const std::string& value,
......@@ -114,12 +124,6 @@ void FrecencyStore::FromProto(const FrecencyStoreProto& proto) {
void FrecencyStore::DecayScore(ValueData* data) {
int time_since_update = num_updates_ - data->last_num_updates;
if (time_since_update < 0) {
// |num_updates_| has overflowed. Fix our calculation of
// |time_since_update|.
time_since_update = num_updates_ + (std::numeric_limits<int>::max() -
data->last_num_updates);
}
if (time_since_update > 0) {
data->last_score *= std::pow(decay_coeff_, time_since_update);
......
......@@ -38,13 +38,15 @@ class FrecencyStore {
int32_t last_num_updates;
};
// Record the use of a value.
void Update(const std::string& value);
// Record the use of a value. Returns its ID.
unsigned int Update(const std::string& value);
// Change one value to another but retain its original ID and score.
void Rename(const std::string& value, const std::string& new_value);
// Remove a value from the store entirely.
// Remove a value and its associated ID from the store entirely.
void Remove(const std::string& value);
// Returns the ID for the given value. If the value is not in the store,
// return base::nullopt.
base::Optional<unsigned int> GetId(const std::string& value);
// Return all stored value data. This ensures all scores have been correctly
// updated, and none of the scores are below the |min_score_| threshold.
......
......@@ -132,14 +132,15 @@ TEST(FrecencyStoreTest, CleanupOnOverflow) {
FrecencyStore store(5, 0.9999f);
// |value_limit_| is 5, so cleanups should occur at 10, 20, ..., 50 values.
for (int i = 0; i < 50; i++) {
for (int i = 0; i <= 50; i++) {
store.Update(std::to_string(i));
}
// A cleanup just happened, so we should have only 45-49 stored.
EXPECT_THAT(store.GetAll(),
UnorderedElementsAre(Pair("45", _), Pair("46", _), Pair("47", _),
Pair("48", _), Pair("49", _)));
// A cleanup just happened, so we should have only 45-50 stored. This is six
// values because the cleanup happens before inserting the new value.
EXPECT_THAT(store.GetAll(), UnorderedElementsAre(
Pair("45", _), Pair("46", _), Pair("47", _),
Pair("48", _), Pair("49", _), Pair("50", _)));
}
TEST(FrecencyStoreTest, RenameValue) {
......
......@@ -7,12 +7,30 @@
#include <cmath>
#include "base/logging.h"
#include "chrome/browser/ui/app_list/search/search_result_ranker/frecency_store.h"
#include "chrome/browser/ui/app_list/search/search_result_ranker/frecency_store.pb.h"
#include "chrome/browser/ui/app_list/search/search_result_ranker/recurrence_predictor.pb.h"
namespace app_list {
void RecurrencePredictor::Train(unsigned int target) {
NOTREACHED();
}
void RecurrencePredictor::Train(unsigned int target, unsigned int condition) {
NOTREACHED();
}
base::flat_map<unsigned int, float> RecurrencePredictor::Rank() {
NOTREACHED();
return {};
}
base::flat_map<unsigned int, float> RecurrencePredictor::Rank(
unsigned int condition) {
NOTREACHED();
return {};
}
FakePredictor::FakePredictor(FakePredictorConfig config) {}
FakePredictor::~FakePredictor() = default;
......@@ -21,30 +39,14 @@ const char* FakePredictor::GetPredictorName() const {
return kPredictorName;
}
void FakePredictor::Train(const std::string& target, const std::string& query) {
void FakePredictor::Train(unsigned int target) {
counts_[target] += 1.0f;
}
base::flat_map<std::string, float> FakePredictor::Rank(
const std::string& query) {
base::flat_map<unsigned int, float> FakePredictor::Rank() {
return counts_;
}
void FakePredictor::Rename(const std::string& target,
const std::string& new_target) {
auto it = counts_.find(target);
if (it != counts_.end()) {
counts_[new_target] = it->second;
counts_.erase(it);
}
}
void FakePredictor::Remove(const std::string& target) {
auto it = counts_.find(target);
if (it != counts_.end())
counts_.erase(it);
}
void FakePredictor::ToProto(RecurrencePredictorProto* proto) const {
auto* counts = proto->mutable_fake_predictor()->mutable_counts();
for (auto& pair : counts_)
......@@ -54,16 +56,27 @@ void FakePredictor::ToProto(RecurrencePredictorProto* proto) const {
void FakePredictor::FromProto(const RecurrencePredictorProto& proto) {
if (!proto.has_fake_predictor())
return;
auto predictor = proto.fake_predictor();
for (const auto& pair : predictor.counts())
counts_[pair.first] = pair.second;
}
DefaultPredictor::DefaultPredictor(const DefaultPredictorConfig& config) {}
DefaultPredictor::~DefaultPredictor() {}
const char DefaultPredictor::kPredictorName[] = "DefaultPredictor";
const char* DefaultPredictor::GetPredictorName() const {
return kPredictorName;
}
void DefaultPredictor::ToProto(RecurrencePredictorProto* proto) const {}
void DefaultPredictor::FromProto(const RecurrencePredictorProto& proto) {}
ZeroStateFrecencyPredictor::ZeroStateFrecencyPredictor(
ZeroStateFrecencyPredictorConfig config)
: targets_(std::make_unique<FrecencyStore>(config.target_limit(),
config.decay_coeff())) {}
: decay_coeff_(config.decay_coeff()) {}
ZeroStateFrecencyPredictor::~ZeroStateFrecencyPredictor() = default;
const char ZeroStateFrecencyPredictor::kPredictorName[] =
......@@ -72,55 +85,59 @@ const char* ZeroStateFrecencyPredictor::GetPredictorName() const {
return kPredictorName;
}
void ZeroStateFrecencyPredictor::Train(const std::string& target,
const std::string& query) {
if (!query.empty()) {
NOTREACHED();
LOG(ERROR) << "ZeroStateFrecencyPredictor was passed a query.";
return;
}
targets_->Update(target);
void ZeroStateFrecencyPredictor::Train(unsigned int target) {
++num_updates_;
TargetData& data = targets_[target];
DecayScore(&data);
data.last_score += 1.0f - decay_coeff_;
}
base::flat_map<std::string, float> ZeroStateFrecencyPredictor::Rank(
const std::string& query) {
if (!query.empty()) {
NOTREACHED();
LOG(ERROR) << "ZeroStateFrecencyPredictor was passed a query.";
return {};
base::flat_map<unsigned int, float> ZeroStateFrecencyPredictor::Rank() {
base::flat_map<unsigned int, float> result;
for (auto& pair : targets_) {
DecayScore(&pair.second);
result[pair.first] = pair.second.last_score;
}
base::flat_map<std::string, float> ranks;
for (const auto& target : targets_->GetAll())
ranks[target.first] = target.second.last_score;
return ranks;
}
void ZeroStateFrecencyPredictor::Rename(const std::string& target,
const std::string& new_target) {
targets_->Rename(target, new_target);
}
void ZeroStateFrecencyPredictor::Remove(const std::string& target) {
targets_->Remove(target);
return result;
}
void ZeroStateFrecencyPredictor::ToProto(
RecurrencePredictorProto* proto) const {
auto* targets =
proto->mutable_zero_state_frecency_predictor()->mutable_targets();
targets_->ToProto(targets);
auto* predictor = proto->mutable_zero_state_frecency_predictor();
predictor->set_num_updates(num_updates_);
for (const auto& pair : targets_) {
auto* target_data = predictor->add_targets();
target_data->set_id(pair.first);
target_data->set_last_score(pair.second.last_score);
target_data->set_last_num_updates(pair.second.last_num_updates);
}
}
void ZeroStateFrecencyPredictor::FromProto(
const RecurrencePredictorProto& proto) {
if (!proto.has_zero_state_frecency_predictor())
return;
const auto& predictor = proto.zero_state_frecency_predictor();
if (predictor.has_targets())
targets_->FromProto(predictor.targets());
num_updates_ = predictor.num_updates();
base::flat_map<unsigned int, TargetData> targets;
for (const auto& target_data : predictor.targets()) {
targets[target_data.id()] = {target_data.last_score(),
target_data.last_num_updates()};
}
targets_.swap(targets);
}
void ZeroStateFrecencyPredictor::DecayScore(TargetData* data) {
int time_since_update = num_updates_ - data->last_num_updates;
if (time_since_update > 0) {
data->last_score *= std::pow(decay_coeff_, time_since_update);
data->last_num_updates = num_updates_;
}
}
} // namespace app_list
......@@ -10,18 +10,20 @@
#include <vector>
#include "base/containers/flat_map.h"
#include "base/logging.h"
#include "base/macros.h"
#include "chrome/browser/ui/app_list/search/search_result_ranker/frecency_store.h"
#include "chrome/browser/ui/app_list/search/search_result_ranker/recurrence_predictor.pb.h"
#include "chrome/browser/ui/app_list/search/search_result_ranker/recurrence_ranker_config.pb.h"
namespace app_list {
using FakePredictorConfig = RecurrenceRankerConfigProto::FakePredictorConfig;
using DefaultPredictorConfig =
RecurrenceRankerConfigProto::DefaultPredictorConfig;
using ZeroStateFrecencyPredictorConfig =
RecurrenceRankerConfigProto::ZeroStateFrecencyPredictorConfig;
class FrecencyStore;
// |RecurrencePredictor| is the interface for all predictors used by
// |RecurrenceRanker| to drive rankings. If a predictor has some form of
// serialisation, it should have a corresponding proto in
......@@ -30,30 +32,27 @@ class RecurrencePredictor {
public:
virtual ~RecurrencePredictor() = default;
// Train the predictor on an occurrence of |target| coinciding with |query|.
// The predictor will collect its own contextual information, eg. time of day,
// as part of training. Zero-state scenarios should use an empty string for
// |query|.
virtual void Train(const std::string& target, const std::string& query) = 0;
// Return a map of all known targets to their scores for the given query
// under this predictor. Scores must be within the range [0,1]. Zero-state
// scenarios should use an empty string for |query|.
virtual base::flat_map<std::string, float> Rank(const std::string& query) = 0;
// Train the predictor on an occurrence of |target| coinciding with
// |condition|. The predictor will collect its own contextual information, eg.
// time of day, as part of training. Zero-state predictors should use the
// one-argument version.
virtual void Train(unsigned int target);
virtual void Train(unsigned int target, unsigned int condition);
// Rename a target, while keeping learned information on it.
virtual void Rename(const std::string& target,
const std::string& new_target) = 0;
// Remove a target entirely.
virtual void Remove(const std::string& target) = 0;
// Return a map of all known targets to their scores for the given condition
// under this predictor. Scores must be within the range [0,1]. Zero-state
// predictors should use the zero-argument version.
virtual base::flat_map<unsigned int, float> Rank();
virtual base::flat_map<unsigned int, float> Rank(unsigned int condition);
virtual void ToProto(RecurrencePredictorProto* proto) const = 0;
virtual void FromProto(const RecurrencePredictorProto& proto) = 0;
virtual const char* GetPredictorName() const = 0;
};
// |FakePredictor| is a simple 'predictor' used for testing. |Rank| returns the
// numbers of times each target has been trained on, and ignores the query
// altogether.
// FakePredictor is a simple 'predictor' used for testing. Rank() returns the
// numbers of times each target has been trained on, and does not handle
// conditions.
//
// WARNING: this breaks the guarantees on the range of values a score can take,
// so should not be used for anything except testing.
......@@ -63,11 +62,8 @@ class FakePredictor : public RecurrencePredictor {
~FakePredictor() override;
// RecurrencePredictor:
void Train(const std::string& target, const std::string& query) override;
base::flat_map<std::string, float> Rank(const std::string& query) override;
void Rename(const std::string& target,
const std::string& new_target) override;
void Remove(const std::string& target) override;
void Train(unsigned int target) override;
base::flat_map<unsigned int, float> Rank() override;
void ToProto(RecurrencePredictorProto* proto) const override;
void FromProto(const RecurrencePredictorProto& proto) override;
const char* GetPredictorName() const override;
......@@ -75,24 +71,51 @@ class FakePredictor : public RecurrencePredictor {
static const char kPredictorName[];
private:
base::flat_map<std::string, float> counts_;
base::flat_map<unsigned int, float> counts_;
DISALLOW_COPY_AND_ASSIGN(FakePredictor);
};
// |ZeroStateFrecencyPredictor| ranks targets according to their frecency, and
// can only be used for zero-state predictions, that is, an empty query string.
// DefaultPredictor does no work on its own. Using this predictor makes the
// RecurrenceRanker return the scores of its FrecencyStore instead of using a
// predictor.
class DefaultPredictor : public RecurrencePredictor {
public:
explicit DefaultPredictor(const DefaultPredictorConfig& config);
~DefaultPredictor() override;
// RecurrencePredictor:
void ToProto(RecurrencePredictorProto* proto) const override;
void FromProto(const RecurrencePredictorProto& proto) override;
const char* GetPredictorName() const override;
static const char kPredictorName[];
private:
DISALLOW_COPY_AND_ASSIGN(DefaultPredictor);
};
// ZeroStateFrecencyPredictor ranks targets according to their frecency, and
// can only be used for zero-state predictions. This predictor allows for
// frecency-based rankings with different configuration to that of the ranker's
// FrecencyStore. If frecency-based rankings with the same configuration as the
// store are needed, the DefaultPredictor should be used instead.
class ZeroStateFrecencyPredictor : public RecurrencePredictor {
public:
explicit ZeroStateFrecencyPredictor(ZeroStateFrecencyPredictorConfig config);
~ZeroStateFrecencyPredictor() override;
// Records all information about a target: its id and score, along with the
// number of updates that had occurred when the score was last calculated.
// This is used for further score updates.
struct TargetData {
float last_score = 0.0f;
int32_t last_num_updates = 0;
};
// RecurrencePredictor:
void Train(const std::string& target, const std::string& query) override;
base::flat_map<std::string, float> Rank(const std::string& query) override;
void Rename(const std::string& target,
const std::string& new_target) override;
void Remove(const std::string& target) override;
void Train(unsigned int target) override;
base::flat_map<unsigned int, float> Rank() override;
void ToProto(RecurrencePredictorProto* proto) const override;
void FromProto(const RecurrencePredictorProto& proto) override;
const char* GetPredictorName() const override;
......@@ -100,7 +123,22 @@ class ZeroStateFrecencyPredictor : public RecurrencePredictor {
static const char kPredictorName[];
private:
std::unique_ptr<FrecencyStore> targets_;
// Decay the given target's score according to how many training steps have
// occurred since last update.
void DecayScore(TargetData* score);
// Decay the scores of all targets.
void DecayAllScores();
// Controls how quickly scores decay, in other words controls the trade-off
// between frequency and recency.
float decay_coeff_;
// Number of times the store has been updated.
unsigned int num_updates_ = 0;
// This stores all the data of the frecency predictor.
// TODO(tby): benchmark which map is best in practice for our use.
base::flat_map<unsigned int, ZeroStateFrecencyPredictor::TargetData> targets_;
DISALLOW_COPY_AND_ASSIGN(ZeroStateFrecencyPredictor);
};
......
......@@ -4,21 +4,34 @@
syntax = "proto2";
import "frecency_store.proto";
option optimize_for = LITE_RUNTIME;
package app_list;
// Fake predictor used for testing.
message FakePredictorProto {
// Maps targets to their score.
map<string, float> counts = 1;
// Maps target IDs to scores.
map<uint32, float> counts = 1;
}
// Zero-state frecency predictor.
message ZeroStateFrecencyPredictorProto {
optional FrecencyStoreProto targets = 1;
// Field 1 (targets) has been deleted.
reserved 1;
// Records all data relating to a particular stored target, corresponding
// exactly to the ZeroStateFrecencyPredictor::ValueData struct.
message TargetData {
required uint32 id = 1;
// The last calculated score associated with a value.
required float last_score = 2;
// The model's number of updates when the score was last calculated.
required uint32 last_num_updates = 3;
}
repeated TargetData targets = 4;
required uint32 num_updates = 5;
}
// Represents the serialisation of one particular predictor.
......
......@@ -19,13 +19,8 @@
#include "testing/gtest/include/gtest/gtest.h"
using testing::_;
using testing::ElementsAre;
using testing::FloatEq;
using testing::IsSupersetOf;
using testing::NiceMock;
using testing::Pair;
using testing::Return;
using testing::StrEq;
using testing::UnorderedElementsAre;
namespace app_list {
......@@ -35,7 +30,6 @@ class ZeroStateFrecencyPredictorTest : public testing::Test {
void SetUp() override {
Test::SetUp();
config_.set_target_limit(100u);
config_.set_decay_coeff(0.5f);
predictor_ = std::make_unique<ZeroStateFrecencyPredictor>(config_);
}
......@@ -45,90 +39,50 @@ class ZeroStateFrecencyPredictorTest : public testing::Test {
};
TEST_F(ZeroStateFrecencyPredictorTest, RankWithNoTargets) {
EXPECT_TRUE(predictor_->Rank("").empty());
EXPECT_TRUE(predictor_->Rank().empty());
}
TEST_F(ZeroStateFrecencyPredictorTest, RecordAndRank) {
predictor_->Train("A", "");
predictor_->Train("B", "");
predictor_->Train("C", "");
TEST_F(ZeroStateFrecencyPredictorTest, RecordAndRankSimple) {
predictor_->Train(2u);
predictor_->Train(4u);
predictor_->Train(6u);
EXPECT_THAT(predictor_->Rank(""),
UnorderedElementsAre(Pair("A", FloatEq(0.125f)),
Pair("B", FloatEq(0.25f)),
Pair("C", FloatEq(0.5f))));
}
TEST_F(ZeroStateFrecencyPredictorTest, Rename) {
predictor_->Train("A", "");
predictor_->Train("B", "");
predictor_->Train("B", "");
predictor_->Rename("B", "A");
EXPECT_THAT(predictor_->Rank(""),
UnorderedElementsAre(Pair("A", FloatEq(0.75f))));
}
TEST_F(ZeroStateFrecencyPredictorTest, RenameNonexistentTarget) {
predictor_->Train("A", "");
predictor_->Rename("B", "C");
EXPECT_THAT(predictor_->Rank(""),
UnorderedElementsAre(Pair("A", FloatEq(0.5f))));
}
TEST_F(ZeroStateFrecencyPredictorTest, Remove) {
predictor_->Train("A", "");
predictor_->Train("B", "");
predictor_->Remove("B");
EXPECT_THAT(predictor_->Rank(""),
UnorderedElementsAre(Pair("A", FloatEq(0.25f))));
}
TEST_F(ZeroStateFrecencyPredictorTest, RemoveNonexistentTarget) {
predictor_->Train("A", "");
predictor_->Remove("B");
EXPECT_THAT(predictor_->Rank(""),
UnorderedElementsAre(Pair("A", FloatEq(0.5f))));
EXPECT_THAT(
predictor_->Rank(),
UnorderedElementsAre(Pair(2u, FloatEq(0.125f)), Pair(4u, FloatEq(0.25f)),
Pair(6u, FloatEq(0.5f))));
}
TEST_F(ZeroStateFrecencyPredictorTest, TargetLimitExceeded) {
ZeroStateFrecencyPredictorConfig config;
config.set_target_limit(5u);
config.set_decay_coeff(0.9999f);
ZeroStateFrecencyPredictor predictor(config);
// Insert many more targets than the target limit.
for (int i = 0; i < 50; i++) {
predictor.Train(std::to_string(i), "");
TEST_F(ZeroStateFrecencyPredictorTest, RecordAndRankComplex) {
predictor_->Train(2u);
predictor_->Train(4u);
predictor_->Train(6u);
predictor_->Train(4u);
predictor_->Train(2u);
// Ranks should be deterministic.
for (int i = 0; i < 3; ++i) {
EXPECT_THAT(predictor_->Rank(),
UnorderedElementsAre(Pair(2u, FloatEq(0.53125f)),
Pair(4u, FloatEq(0.3125f)),
Pair(6u, FloatEq(0.125f))));
}
// Check that some of the values have been deleted, and the most recent ones
// remain. We check loose bounds on these requirements, to prevent the test
// from being tied to implementation details of the |FrecencyStore| cleanup
// logic. See |FrecencyStoreTest::CleanupOnOverflow| for a corresponding, more
// precise test.
auto ranks = predictor.Rank("");
EXPECT_LE(ranks.size(), 10ul);
EXPECT_THAT(
ranks, testing::IsSupersetOf({Pair("45", _), Pair("46", _), Pair("47", _),
Pair("48", _), Pair("49", _)}));
}
TEST_F(ZeroStateFrecencyPredictorTest, ToAndFromProto) {
predictor_->Train("A", "");
predictor_->Train("B", "");
predictor_->Train("C", "");
const auto ranks = predictor_->Rank("");
predictor_->Train(1u);
predictor_->Train(3u);
predictor_->Train(5u);
RecurrencePredictorProto proto;
predictor_->ToProto(&proto);
predictor_->FromProto(proto);
EXPECT_EQ(ranks, predictor_->Rank(""));
ZeroStateFrecencyPredictor new_predictor(config_);
new_predictor.FromProto(proto);
EXPECT_TRUE(proto.has_zero_state_frecency_predictor());
EXPECT_EQ(proto.zero_state_frecency_predictor().num_updates(), 3u);
EXPECT_EQ(predictor_->Rank(), new_predictor.Rank());
}
} // namespace app_list
......@@ -78,6 +78,8 @@ std::unique_ptr<RecurrencePredictor> MakePredictor(
const RecurrenceRankerConfigProto& config) {
if (config.has_fake_predictor())
return std::make_unique<FakePredictor>(config.fake_predictor());
if (config.has_default_predictor())
return std::make_unique<DefaultPredictor>(config.default_predictor());
if (config.has_zero_state_frecency_predictor())
return std::make_unique<ZeroStateFrecencyPredictor>(
config.zero_state_frecency_predictor());
......@@ -86,6 +88,48 @@ std::unique_ptr<RecurrencePredictor> MakePredictor(
return nullptr;
}
std::vector<std::pair<std::string, float>> SortAndTruncateRanks(
int n,
const base::flat_map<std::string, float>& ranks) {
std::vector<std::pair<std::string, float>> sorted_ranks(ranks.begin(),
ranks.end());
std::sort(sorted_ranks.begin(), sorted_ranks.end(),
[](const std::pair<std::string, float>& a,
const std::pair<std::string, float>& b) {
return a.second > b.second;
});
// vector::resize simply truncates the array if there are more than n
// elements. Note this is still O(N).
if (sorted_ranks.size() > static_cast<unsigned long>(n))
sorted_ranks.resize(n);
return sorted_ranks;
}
base::flat_map<std::string, float> ZipTargetsWithScores(
const base::flat_map<std::string, FrecencyStore::ValueData>& target_to_id,
const base::flat_map<unsigned int, float>& id_to_score) {
base::flat_map<std::string, float> target_to_score;
for (const auto& pair : target_to_id) {
DCHECK(pair.second.last_num_updates ==
target_to_id.begin()->second.last_num_updates);
const auto& it = id_to_score.find(pair.second.id);
if (it != id_to_score.end()) {
target_to_score[pair.first] = it->second;
}
}
return target_to_score;
}
base::flat_map<std::string, float> GetScoresFromFrecencyStore(
const base::flat_map<std::string, FrecencyStore::ValueData>& target_to_id) {
base::flat_map<std::string, float> target_to_score;
for (const auto& pair : target_to_id)
target_to_score[pair.first] = pair.second.last_score;
return target_to_score;
}
} // namespace
RecurrenceRanker::RecurrenceRanker(const base::FilePath& filepath,
......@@ -102,12 +146,16 @@ RecurrenceRanker::RecurrenceRanker(const base::FilePath& filepath,
{base::TaskPriority::BEST_EFFORT, base::MayBlock(),
base::TaskShutdownBehavior::SKIP_ON_SHUTDOWN});
targets_ = std::make_unique<FrecencyStore>(config.target_limit(),
config.target_decay());
conditions_ = std::make_unique<FrecencyStore>(config.condition_limit(),
config.condition_decay());
if (is_ephemeral_user_) {
// Ephemeral users have no persistent storage, so we don't try and load the
// proto from disk. Instead, we fall back on using a frecency predictor,
// which is still useful with only data from the current session.
predictor_ = std::make_unique<ZeroStateFrecencyPredictor>(
config.fallback_predictor());
// proto from disk. Instead, we fall back on using a default (frecency)
// predictor, which is still useful with only data from the current session.
predictor_ = std::make_unique<DefaultPredictor>(config.default_predictor());
} else {
predictor_ = MakePredictor(config);
......@@ -135,75 +183,123 @@ void RecurrenceRanker::OnLoadProtoFromDiskComplete(
if (proto->has_predictor())
predictor_->FromProto(proto->predictor());
if (proto->has_targets())
targets_->FromProto(proto->targets());
if (proto->has_conditions())
conditions_->FromProto(proto->conditions());
load_from_disk_completed_ = true;
}
void RecurrenceRanker::Record(const std::string& target) {
Record(target, "");
if (!load_from_disk_completed_)
return;
if (predictor_->GetPredictorName() == DefaultPredictor::kPredictorName) {
targets_->Update(target);
} else {
predictor_->Train(targets_->Update(target));
}
MaybeSave();
}
void RecurrenceRanker::Record(const std::string& target,
const std::string& query) {
const std::string& condition) {
if (!load_from_disk_completed_)
return;
predictor_->Train(target, query);
if (predictor_->GetPredictorName() == DefaultPredictor::kPredictorName) {
// TODO(921444): The default predictor does not support queries, so we fail
// here. Once we have a suitable query-based default predictor implemented,
// change this.
NOTREACHED();
} else {
predictor_->Train(targets_->Update(target), conditions_->Update(condition));
}
MaybeSave();
}
void RecurrenceRanker::Rename(const std::string& target,
const std::string& new_target) {
void RecurrenceRanker::RenameTarget(const std::string& target,
const std::string& new_target) {
if (!load_from_disk_completed_)
return;
predictor_->Rename(target, new_target);
targets_->Rename(target, new_target);
MaybeSave();
}
void RecurrenceRanker::Remove(const std::string& target) {
void RecurrenceRanker::RemoveTarget(const std::string& target) {
// TODO(tby): Find a solution to the edge case of a removal before disk
// loading is complete, resulting in the remove getting dropped.
if (!load_from_disk_completed_)
return;
predictor_->Remove(target);
targets_->Remove(target);
MaybeSave();
}
void RecurrenceRanker::RenameCondition(const std::string& condition,
const std::string& new_condition) {
if (!load_from_disk_completed_)
return;
conditions_->Rename(condition, new_condition);
MaybeSave();
}
void RecurrenceRanker::RemoveCondition(const std::string& condition) {
if (!load_from_disk_completed_)
return;
conditions_->Remove(condition);
MaybeSave();
}
base::flat_map<std::string, float> RecurrenceRanker::Rank() {
return Rank("");
if (!load_from_disk_completed_)
return {};
if (predictor_->GetPredictorName() == DefaultPredictor::kPredictorName)
return GetScoresFromFrecencyStore(targets_->GetAll());
return ZipTargetsWithScores(targets_->GetAll(), predictor_->Rank());
}
base::flat_map<std::string, float> RecurrenceRanker::Rank(
const std::string& query) {
const std::string& condition) {
if (!load_from_disk_completed_)
return {};
return predictor_->Rank(query);
base::Optional<unsigned int> condition_id = conditions_->GetId(condition);
if (condition_id == base::nullopt)
return {};
if (predictor_->GetPredictorName() == DefaultPredictor::kPredictorName) {
// TODO(921444): The default predictor does not support queries, so we fail
// here. Once we have a suitable query-based default predictor implemented,
// change this.
NOTREACHED();
return {};
}
return ZipTargetsWithScores(targets_->GetAll(),
predictor_->Rank(condition_id.value()));
}
std::vector<std::pair<std::string, float>> RecurrenceRanker::RankTopN(int n) {
return RankTopN(n, "");
if (!load_from_disk_completed_)
return {};
return SortAndTruncateRanks(n, Rank());
}
std::vector<std::pair<std::string, float>> RecurrenceRanker::RankTopN(
int n,
const std::string& query) {
const std::string& condition) {
if (!load_from_disk_completed_)
return {};
base::flat_map<std::string, float> ranks = Rank(query);
std::vector<std::pair<std::string, float>> sorted_ranks(ranks.begin(),
ranks.end());
std::sort(sorted_ranks.begin(), sorted_ranks.end(),
[](const std::pair<std::string, float>& a,
const std::pair<std::string, float>& b) {
return a.second > b.second;
});
// vector::resize simply truncates the array if there are more than n
// elements. Note this is still O(N).
if (sorted_ranks.size() > static_cast<unsigned long>(n))
sorted_ranks.resize(n);
return sorted_ranks;
return SortAndTruncateRanks(n, Rank(condition));
}
void RecurrenceRanker::MaybeSave() {
......@@ -222,6 +318,8 @@ void RecurrenceRanker::MaybeSave() {
void RecurrenceRanker::ToProto(RecurrenceRankerProto* proto) {
proto->set_config_hash(config_hash_);
predictor_->ToProto(proto->mutable_predictor());
targets_->ToProto(proto->mutable_targets());
conditions_->ToProto(proto->mutable_conditions());
}
void RecurrenceRanker::ForceSaveOnNextUpdateForTesting() {
......
......@@ -21,50 +21,64 @@
namespace app_list {
class FrecencyStore;
class RecurrencePredictor;
class RecurrenceRankerProto;
// |RecurrenceRanker| is the public interface of the ranking system. The methods
// of interest are:
// - Record, Rename, and Remove, for modifying the targets stored in the
// ranker.
// - Rank, for retrieving rankings under the current conditions.
// The class can be configured to use different predictors, via
// |app_list_features|.
// |RecurrenceRanker| is the public interface of the ranking system.
class RecurrenceRanker {
public:
RecurrenceRanker(const base::FilePath& filepath,
const RecurrenceRankerConfigProto& config,
bool is_ephemeral_user);
~RecurrenceRanker();
// Record the use of a given target, and train the predictor on it. The one
// argument version is a shortcut for an empty query string, useful for
// zero-state scenarios.
// Record the use of a given target, and train the predictor on it. The
// one-argument version should be used for zero-state predictions, and the
// two-argument version for condition-based predictions. This may save to
// disk, but is not guaranteed to.
void Record(const std::string& target);
void Record(const std::string& target, const std::string& query);
// Rename a target, while keeping learned information on it.
void Rename(const std::string& target, const std::string& new_target);
// Remove a target entirely.
void Remove(const std::string& target);
void Record(const std::string& target, const std::string& condition);
// Rename a target, while keeping learned information on it. This may save to
// disk, but is not guaranteed to.
// TODO(921444): Provide a mechanism to force save to disk.
void RenameTarget(const std::string& target, const std::string& new_target);
void RenameCondition(const std::string& condition,
const std::string& new_condition);
// Remove a target or condition entirely. This may save to disk, but is not
// guaranteed to. If the intention of this removal is to removal all knowledge
// of, for example, a sensitive target, then a ForceSaveToDisk call should be
// made after removal.
// TODO(921444): Provide a mechanism to force save to disk.
void RemoveTarget(const std::string& target);
void RemoveCondition(const std::string& condition);
// Returns a map of target to score.
// - Higher scores are better.
// - Score are guaranteed to be in the range [0,1].
// The zero-argument version is a shortcut for an empty query string.
// The zero-argument version should be used for zero-state predictions, and
// the one-argument version for condition-based predictions.
base::flat_map<std::string, float> Rank();
base::flat_map<std::string, float> Rank(const std::string& query);
base::flat_map<std::string, float> Rank(const std::string& condition);
// Returns a sorted vector of <target, score> pairs.
// - Higher scores are better.
// - Score are guaranteed to be in the range [0,1].
// - Pairs are sorted in descending order of score.
// - At most n results will be returned.
// The one-argument version is a shortcut for an empty query string.
// The zero-argument version should be used for zero-state predictions, and
// the one-argument version for condition-based predictions.
std::vector<std::pair<std::string, float>> RankTopN(int n);
std::vector<std::pair<std::string, float>> RankTopN(int n,
const std::string& query);
std::vector<std::pair<std::string, float>> RankTopN(
int n,
const std::string& condition);
// TODO(921444): Create a system for cleaning up internal predictor state that
// is stored indepent of the target/condition frecency stores.
const char* GetPredictorNameForTesting() const;
private:
FRIEND_TEST_ALL_PREFIXES(RecurrenceRankerTest,
......@@ -84,12 +98,17 @@ class RecurrenceRanker {
void ToProto(RecurrenceRankerProto* proto);
void FromProto(const RecurrenceRankerProto& proto);
const char* GetPredictorNameForTesting() const;
void ForceSaveOnNextUpdateForTesting();
// Internal predictor that drives ranking.
std::unique_ptr<RecurrencePredictor> predictor_;
// Storage for target strings, which maps them to IDs.
std::unique_ptr<FrecencyStore> targets_;
// Storage for condition strings, which maps them to IDs.
std::unique_ptr<FrecencyStore> conditions_;
// Where to save the ranker.
const base::FilePath proto_filepath_;
// Hash of client-supplied config, used for associating a serialised ranker
......
......@@ -4,6 +4,7 @@
syntax = "proto2";
import "frecency_store.proto";
import "recurrence_predictor.proto";
option optimize_for = LITE_RUNTIME;
......@@ -17,4 +18,8 @@ message RecurrenceRankerProto {
optional uint32 config_hash = 1;
// Serialisation of the predictor used by the ranker.
optional RecurrencePredictorProto predictor = 2;
// Serialisation of stored targets.
optional FrecencyStoreProto targets = 3;
// Serialisation of stored conditions.
optional FrecencyStoreProto conditions = 4;
}
......@@ -20,13 +20,22 @@ message RecurrenceRankerConfigProto {
required uint32 min_seconds_between_saves = 3;
required uint32 target_limit = 4;
required float target_decay = 5;
required uint32 condition_limit = 6;
required float condition_decay = 7;
// Config for a fake predictor, used for testing.
message FakePredictorConfig {}
// Config for a default predictor, which uses the scores of the frecency store
// as its ranks. As a result, it has no configuration of its own.
message DefaultPredictorConfig {}
// Config for a frecency predictor.
message ZeroStateFrecencyPredictorConfig {
// The soft-maximum number of targets that are stored within the predictor.
required uint32 target_limit = 201;
// Field 201 (target_limit) has been deleted.
reserved 201;
// The frecency parameter used to control the frequency-recency tradeoff
// that determines when targets are removed. Must be in [0.5, 1.0], with 0.5
// meaning only-recency and 1.0 meaning only-frequency.
......@@ -37,9 +46,6 @@ message RecurrenceRankerConfigProto {
oneof predictor_config {
FakePredictorConfig fake_predictor = 10001;
ZeroStateFrecencyPredictorConfig zero_state_frecency_predictor = 10002;
DefaultPredictorConfig default_predictor = 10003;
}
// Configuration for a frecency predictor used as a fallback if the user is
// ephemeral.
required ZeroStateFrecencyPredictorConfig fallback_predictor = 11000;
}
......@@ -22,14 +22,25 @@
using testing::_;
using testing::ElementsAre;
using testing::FloatEq;
using testing::NiceMock;
using testing::Pair;
using testing::Return;
using testing::StrEq;
using testing::UnorderedElementsAre;
namespace app_list {
namespace {
// For convenience, sets all fields of a config proto except for the predictor.
void PartiallyPopulateConfig(RecurrenceRankerConfigProto* config) {
config->set_target_limit(100u);
config->set_target_decay(0.8f);
config->set_condition_limit(101u);
config->set_condition_decay(0.81f);
config->set_min_seconds_between_saves(5);
}
} // namespace
class RecurrenceRankerTest : public testing::Test {
protected:
void SetUp() override {
......@@ -37,14 +48,11 @@ class RecurrenceRankerTest : public testing::Test {
ASSERT_TRUE(temp_dir_.CreateUniqueTempDir());
ranker_filepath_ = temp_dir_.GetPath().AppendASCII("recurrence_ranker");
auto* fallback = config_.mutable_fallback_predictor();
fallback->set_target_limit(0u);
fallback->set_decay_coeff(0.0f);
PartiallyPopulateConfig(&config_);
// Even if empty, the setting of the oneof |predictor_config| in
// |RecurrenceRankerConfigProto| is used to determine which predictor is
// RecurrenceRankerConfigProto is used to determine which predictor is
// constructed.
config_.mutable_fake_predictor();
config_.set_min_seconds_between_saves(5);
ranker_ =
std::make_unique<RecurrenceRanker>(ranker_filepath_, config_, false);
......@@ -57,11 +65,38 @@ class RecurrenceRankerTest : public testing::Test {
RecurrenceRankerProto proto;
proto.set_config_hash(base::PersistentHash(config_.SerializeAsString()));
// Make target frecency store.
auto* targets = proto.mutable_targets();
targets->set_value_limit(100u);
targets->set_decay_coeff(0.8f);
targets->set_num_updates(4);
targets->set_next_id(3);
auto* target_values = targets->mutable_values();
FrecencyStoreProto::ValueData value_data;
value_data.set_id(0u);
value_data.set_last_score(0.5f);
value_data.set_last_num_updates(1);
(*target_values)["A"] = value_data;
value_data = FrecencyStoreProto::ValueData();
value_data.set_id(1u);
value_data.set_last_score(0.5f);
value_data.set_last_num_updates(3);
(*target_values)["B"] = value_data;
value_data = FrecencyStoreProto::ValueData();
value_data.set_id(2u);
value_data.set_last_score(0.5f);
value_data.set_last_num_updates(4);
(*target_values)["C"] = value_data;
// Make FakePredictor counts.
auto* counts =
proto.mutable_predictor()->mutable_fake_predictor()->mutable_counts();
(*counts)["A"] = 1.0f;
(*counts)["B"] = 2.0f;
(*counts)["C"] = 1.0f;
(*counts)[0u] = 1.0f;
(*counts)[1u] = 2.0f;
(*counts)[2u] = 1.0f;
return proto;
}
......@@ -86,20 +121,20 @@ TEST_F(RecurrenceRankerTest, Record) {
Pair("B", FloatEq(2.0f))));
}
TEST_F(RecurrenceRankerTest, Rename) {
TEST_F(RecurrenceRankerTest, RenameTarget) {
ranker_->Record("A");
ranker_->Record("B");
ranker_->Record("B");
ranker_->Rename("B", "A");
ranker_->RenameTarget("B", "A");
EXPECT_THAT(ranker_->Rank(), ElementsAre(Pair("A", FloatEq(2.0f))));
}
TEST_F(RecurrenceRankerTest, Remove) {
TEST_F(RecurrenceRankerTest, RemoveTarget) {
ranker_->Record("A");
ranker_->Record("B");
ranker_->Record("B");
ranker_->Remove("A");
ranker_->RemoveTarget("A");
EXPECT_THAT(ranker_->Rank(), ElementsAre(Pair("B", FloatEq(2.0f))));
}
......@@ -109,11 +144,11 @@ TEST_F(RecurrenceRankerTest, ComplexRecordAndRank) {
ranker_->Record("B");
ranker_->Record("C");
ranker_->Record("B");
ranker_->Rename("D", "C");
ranker_->Remove("F");
ranker_->Rename("C", "F");
ranker_->Remove("A");
ranker_->Rename("C", "F");
ranker_->RenameTarget("D", "C");
ranker_->RemoveTarget("F");
ranker_->RenameTarget("C", "F");
ranker_->RemoveTarget("A");
ranker_->RenameTarget("C", "F");
ranker_->Record("A");
EXPECT_THAT(ranker_->Rank(), UnorderedElementsAre(Pair("A", FloatEq(1.0f)),
......@@ -209,11 +244,8 @@ TEST_F(RecurrenceRankerTest, SavedRankerRejectedIfConfigMismatched) {
// Construct a second ranker with a slightly different config.
RecurrenceRankerConfigProto other_config;
auto* fallback = other_config.mutable_fallback_predictor();
fallback->set_target_limit(0u);
fallback->set_decay_coeff(0.0f);
PartiallyPopulateConfig(&other_config);
other_config.mutable_fake_predictor();
// This is different.
other_config.set_min_seconds_between_saves(
config_.min_seconds_between_saves() + 1);
......@@ -228,35 +260,50 @@ TEST_F(RecurrenceRankerTest, SavedRankerRejectedIfConfigMismatched) {
EXPECT_THAT(ranker_->Rank(), UnorderedElementsAre(Pair("A", FloatEq(1.0f))));
}
TEST_F(RecurrenceRankerTest, EphemeralUsersUseFrecencyPredictor) {
TEST_F(RecurrenceRankerTest, EphemeralUsersUseDefaultPredictor) {
RecurrenceRanker ephemeral_ranker(ranker_filepath_, config_, true);
Wait();
EXPECT_THAT(ephemeral_ranker.GetPredictorNameForTesting(),
StrEq(ZeroStateFrecencyPredictor::kPredictorName));
StrEq(DefaultPredictor::kPredictorName));
}
TEST_F(RecurrenceRankerTest, IntegrationWithDefaultPredictor) {
RecurrenceRankerConfigProto config;
PartiallyPopulateConfig(&config);
config.mutable_default_predictor();
RecurrenceRanker ranker(ranker_filepath_, config, false);
Wait();
ranker.Record("A");
ranker.Record("A");
ranker.Record("B");
ranker.Record("C");
EXPECT_THAT(ranker.Rank(), UnorderedElementsAre(Pair("A", FloatEq(0.2304f)),
Pair("B", FloatEq(0.16f)),
Pair("C", FloatEq(0.2f))));
}
TEST_F(RecurrenceRankerTest, IntegrationWithZeroStateFrecencyPredictor) {
RecurrenceRankerConfigProto config;
config.set_min_seconds_between_saves(5);
PartiallyPopulateConfig(&config);
auto* predictor = config.mutable_zero_state_frecency_predictor();
predictor->set_target_limit(100u);
predictor->set_decay_coeff(0.5f);
auto* fallback = config.mutable_fallback_predictor();
fallback->set_target_limit(100u);
fallback->set_decay_coeff(0.5f);
RecurrenceRanker ranker(ranker_filepath_, config, false);
Wait();
ranker.Record("A");
ranker.Record("A");
ranker.Record("D");
ranker.Record("C");
ranker.Record("E");
ranker.Rename("D", "B");
ranker.Remove("E");
ranker.Rename("E", "A");
ranker.RenameTarget("D", "B");
ranker.RemoveTarget("E");
ranker.RenameTarget("E", "A");
EXPECT_THAT(ranker.Rank(), UnorderedElementsAre(Pair("A", FloatEq(0.0625f)),
EXPECT_THAT(ranker.Rank(), UnorderedElementsAre(Pair("A", FloatEq(0.09375f)),
Pair("B", FloatEq(0.125f)),
Pair("C", FloatEq(0.25f))));
}
......
......@@ -53,18 +53,15 @@ SearchResultRanker::SearchResultRanker(Profile* profile) {
if (app_list_features::IsAdaptiveResultRankerEnabled()) {
RecurrenceRankerConfigProto config;
config.set_min_seconds_between_saves(240u);
auto* predictor = config.mutable_zero_state_frecency_predictor();
predictor->set_target_limit(base::GetFieldTrialParamByFeatureAsInt(
config.set_condition_limit(0u);
config.set_condition_decay(0.5f);
config.set_target_limit(base::GetFieldTrialParamByFeatureAsInt(
app_list_features::kEnableAdaptiveResultRanker, "target_limit", 200));
predictor->set_decay_coeff(base::GetFieldTrialParamByFeatureAsDouble(
app_list_features::kEnableAdaptiveResultRanker, "decay_coeff", 0.8f));
auto* fallback = config.mutable_fallback_predictor();
fallback->set_target_limit(base::GetFieldTrialParamByFeatureAsInt(
app_list_features::kEnableAdaptiveResultRanker, "fallback_target_limit",
200));
fallback->set_decay_coeff(base::GetFieldTrialParamByFeatureAsDouble(
app_list_features::kEnableAdaptiveResultRanker, "fallback_decay_coeff",
0.8f));
config.set_target_decay(base::GetFieldTrialParamByFeatureAsDouble(
app_list_features::kEnableAdaptiveResultRanker, "target_decay", 0.8f));
config.mutable_default_predictor();
results_list_group_ranker_ = std::make_unique<RecurrenceRanker>(
profile->GetPath().AppendASCII("adaptive_result_ranker.proto"), config,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment