Commit 659976a2 authored by Findit's avatar Findit

Revert "[Dolphin] Refactor in preparation for query-based predictions."

This reverts commit 6d39702e.

Reason for revert:

Findit (https://goo.gl/kROfz5) identified CL at revision 630682 as the
culprit for failures in the build cycles as shown on:
https://findit-for-me.appspot.com/waterfall/culprit?key=ag9zfmZpbmRpdC1mb3ItbWVyRAsSDVdmU3VzcGVjdGVkQ0wiMWNocm9taXVtLzZkMzk3MDJlNzQ4YWUzMzg0MmU1MzlhZmNmNTVlYjI2MjM1ODJkODAM

Sample Failed Build: https://ci.chromium.org/buildbot/chromium.chromiumos/linux-chromeos-rel/19798

Sample Failed Step: non_single_process_mash_unit_tests

Original change's description:
> [Dolphin] Refactor in preparation for query-based predictions.
> 
>  - Extra methods added to RecurrenceRanker and RecurrencePredictor to accept queries.
> 
>  - Storage of targets has been moved from the RecurrenceRanker to individual
>    RecurrencePredictors. As a result, RecurrencePredictor must now handle targets
>    renames and removes themselves. Methods for this have been added to their API.
> 
>    This is necessary because query-based predictions will require a different data
>    structure to store target + query than zero-state predictions, which only require
>    a target. In order to keep the RecurrenceRanker usable for both tasks, choice of
>    data structure needs to moved into the RecurrencePredictors, where individual
>    predictors can use what's suitable.
> 
>  - The original FrecencyPredictor has been renamed ZeroStateFrecencyPredictor.
> 
>  - Some extra tests added.
> 
> Bug: 921444
> Change-Id: I95a6dca135928726c779261ec6b1663b1023bf1f
> Reviewed-on: https://chromium-review.googlesource.com/c/1459856
> Reviewed-by: Jia Meng <jiameng@chromium.org>
> Commit-Queue: Tony Yeoman <tby@chromium.org>
> Cr-Commit-Position: refs/heads/master@{#630682}

Change-Id: I9ac4e9efc7abf31d64e30cd9e3f5a1fa55bd9778
No-Presubmit: true
No-Tree-Checks: true
No-Try: true
Bug: 921444
Reviewed-on: https://chromium-review.googlesource.com/c/1462426
Cr-Commit-Position: refs/heads/master@{#630687}
parent f56337f3
...@@ -217,6 +217,8 @@ DEFINE_EQUIVTO_PROTO_LITE_1(FakeAppLaunchPredictorProto, rank_result); ...@@ -217,6 +217,8 @@ DEFINE_EQUIVTO_PROTO_LITE_1(FakeAppLaunchPredictorProto, rank_result);
DEFINE_EQUIVTO_PROTO_LITE_1(FakePredictorProto, counts); DEFINE_EQUIVTO_PROTO_LITE_1(FakePredictorProto, counts);
DEFINE_EQUIVTO_PROTO_LITE_DEFAULT(FrecencyPredictorProto);
DEFINE_EQUIVTO_PROTO_LITE_5(FrecencyStoreProto, DEFINE_EQUIVTO_PROTO_LITE_5(FrecencyStoreProto,
values, values,
value_limit, value_limit,
...@@ -238,9 +240,12 @@ DEFINE_EQUIVTO_PROTO_LITE_2(HourAppLaunchPredictorProto_FrequencyTable, ...@@ -238,9 +240,12 @@ DEFINE_EQUIVTO_PROTO_LITE_2(HourAppLaunchPredictorProto_FrequencyTable,
DEFINE_EQUIVTO_PROTO_LITE_2(RecurrencePredictorProto, DEFINE_EQUIVTO_PROTO_LITE_2(RecurrencePredictorProto,
fake_predictor, fake_predictor,
zero_state_frecency_predictor); frecency_predictor);
DEFINE_EQUIVTO_PROTO_LITE_2(RecurrenceRankerProto, config_hash, predictor); DEFINE_EQUIVTO_PROTO_LITE_3(RecurrenceRankerProto,
config_hash,
predictor,
targets);
DEFINE_EQUIVTO_PROTO_LITE_2(SerializedMrfuAppLaunchPredictorProto, DEFINE_EQUIVTO_PROTO_LITE_2(SerializedMrfuAppLaunchPredictorProto,
num_of_trains, num_of_trains,
...@@ -249,9 +254,6 @@ DEFINE_EQUIVTO_PROTO_LITE_2(SerializedMrfuAppLaunchPredictorProto, ...@@ -249,9 +254,6 @@ DEFINE_EQUIVTO_PROTO_LITE_2(SerializedMrfuAppLaunchPredictorProto,
DEFINE_EQUIVTO_PROTO_LITE_2(SerializedMrfuAppLaunchPredictorProto_Score, DEFINE_EQUIVTO_PROTO_LITE_2(SerializedMrfuAppLaunchPredictorProto_Score,
num_of_trains_at_last_update, num_of_trains_at_last_update,
last_score); last_score);
DEFINE_EQUIVTO_PROTO_LITE_1(ZeroStateFrecencyPredictorProto, targets);
} // namespace internal } // namespace internal
template <typename Proto> template <typename Proto>
......
...@@ -22,13 +22,7 @@ static constexpr float kMinScoreBeforeDelete = ...@@ -22,13 +22,7 @@ static constexpr float kMinScoreBeforeDelete =
} // namespace } // namespace
FrecencyStore::FrecencyStore(int value_limit, float decay_coeff) FrecencyStore::FrecencyStore(int value_limit, float decay_coeff)
: value_limit_(value_limit), decay_coeff_(decay_coeff) { : value_limit_(value_limit), decay_coeff_(decay_coeff) {}
if (decay_coeff <= 0.0f || decay_coeff >= 1.0f) {
LOG(ERROR) << "FrecencyStore decay_coeff has invalid value: " << decay_coeff
<< ", resetting to default.";
decay_coeff_ = 0.75f;
}
}
FrecencyStore::~FrecencyStore() {} FrecencyStore::~FrecencyStore() {}
......
...@@ -7,42 +7,24 @@ ...@@ -7,42 +7,24 @@
#include <cmath> #include <cmath>
#include "base/logging.h" #include "base/logging.h"
#include "chrome/browser/ui/app_list/search/search_result_ranker/frecency_store.h"
#include "chrome/browser/ui/app_list/search/search_result_ranker/frecency_store.pb.h"
#include "chrome/browser/ui/app_list/search/search_result_ranker/recurrence_predictor.pb.h" #include "chrome/browser/ui/app_list/search/search_result_ranker/recurrence_predictor.pb.h"
namespace app_list { namespace app_list {
FakePredictor::FakePredictor(FakePredictorConfig config) {} FakePredictor::FakePredictor() = default;
FakePredictor::~FakePredictor() = default; FakePredictor::~FakePredictor() = default;
const char FakePredictor::kPredictorName[] = "FakePredictor"; void FakePredictor::Train(unsigned int target) {
const char* FakePredictor::GetPredictorName() const {
return kPredictorName;
}
void FakePredictor::Train(const std::string& target, const std::string& query) {
counts_[target] += 1.0f; counts_[target] += 1.0f;
} }
base::flat_map<std::string, float> FakePredictor::Rank( base::flat_map<unsigned int, float> FakePredictor::Rank() {
const std::string& query) {
return counts_; return counts_;
} }
void FakePredictor::Rename(const std::string& target, const char FakePredictor::kPredictorName[] = "FakePredictor";
const std::string& new_target) { const char* FakePredictor::GetPredictorName() const {
auto it = counts_.find(target); return kPredictorName;
if (it != counts_.end()) {
counts_[new_target] = it->second;
counts_.erase(it);
}
}
void FakePredictor::Remove(const std::string& target) {
auto it = counts_.find(target);
if (it != counts_.end())
counts_.erase(it);
} }
void FakePredictor::ToProto(RecurrencePredictorProto* proto) const { void FakePredictor::ToProto(RecurrencePredictorProto* proto) const {
...@@ -55,72 +37,36 @@ void FakePredictor::FromProto(const RecurrencePredictorProto& proto) { ...@@ -55,72 +37,36 @@ void FakePredictor::FromProto(const RecurrencePredictorProto& proto) {
if (!proto.has_fake_predictor()) if (!proto.has_fake_predictor())
return; return;
auto predictor = proto.fake_predictor(); auto fake_predictor_proto = proto.fake_predictor();
for (const auto& pair : predictor.counts()) for (const auto& pair : fake_predictor_proto.counts()) {
counts_[pair.first] = pair.second; counts_[pair.first] = pair.second;
}
ZeroStateFrecencyPredictor::ZeroStateFrecencyPredictor(
ZeroStateFrecencyPredictorConfig config)
: targets_(std::make_unique<FrecencyStore>(config.target_limit(),
config.decay_coeff())) {}
ZeroStateFrecencyPredictor::~ZeroStateFrecencyPredictor() = default;
const char ZeroStateFrecencyPredictor::kPredictorName[] =
"ZeroStateFrecencyPredictor";
const char* ZeroStateFrecencyPredictor::GetPredictorName() const {
return kPredictorName;
}
void ZeroStateFrecencyPredictor::Train(const std::string& target,
const std::string& query) {
if (!query.empty()) {
NOTREACHED();
LOG(ERROR) << "ZeroStateFrecencyPredictor was passed a query.";
return;
} }
targets_->Update(target);
} }
base::flat_map<std::string, float> ZeroStateFrecencyPredictor::Rank( // |FrecencyPredictor| uses the |RecurrenceRanker|'s store of targets for
const std::string& query) { // ranking, which is updated on calls to |RecurrenceRanker::Record|. So, there
if (!query.empty()) { // is no training work for the predictor itself to do.
void FrecencyPredictor::Train(unsigned int target) {}
// |FrecencyPredictor::Rank| is special-cased inside |RecurrenceRanker::Record|
// for efficiency reasons, so as not to do uneccessary conversion of targets to
// ids and back.
base::flat_map<unsigned int, float> FrecencyPredictor::Rank() {
NOTREACHED(); NOTREACHED();
LOG(ERROR) << "ZeroStateFrecencyPredictor was passed a query.";
return {}; return {};
}
base::flat_map<std::string, float> ranks;
for (const auto& target : targets_->GetAll())
ranks[target.first] = target.second.last_score;
return ranks;
} }
void ZeroStateFrecencyPredictor::Rename(const std::string& target, const char FrecencyPredictor::kPredictorName[] = "FrecencyPredictor";
const std::string& new_target) { const char* FrecencyPredictor::GetPredictorName() const {
targets_->Rename(target, new_target); return kPredictorName;
}
void ZeroStateFrecencyPredictor::Remove(const std::string& target) {
targets_->Remove(target);
}
void ZeroStateFrecencyPredictor::ToProto(
RecurrencePredictorProto* proto) const {
auto* targets =
proto->mutable_zero_state_frecency_predictor()->mutable_targets();
targets_->ToProto(targets);
} }
void ZeroStateFrecencyPredictor::FromProto( // Empty as all data used by the frecency predictor is serialised with
const RecurrencePredictorProto& proto) { // |RecurrenceRanker|.
if (!proto.has_zero_state_frecency_predictor()) void FrecencyPredictor::ToProto(RecurrencePredictorProto* proto) const {}
return;
const auto& predictor = proto.zero_state_frecency_predictor(); // Empty as all data used by the frecency predictor is serialised with
if (predictor.has_targets()) // |RecurrenceRanker|.
targets_->FromProto(predictor.targets()); void FrecencyPredictor::FromProto(const RecurrencePredictorProto& proto) {}
}
} // namespace app_list } // namespace app_list
...@@ -5,23 +5,15 @@ ...@@ -5,23 +5,15 @@
#ifndef CHROME_BROWSER_UI_APP_LIST_SEARCH_SEARCH_RESULT_RANKER_RECURRENCE_PREDICTOR_H_ #ifndef CHROME_BROWSER_UI_APP_LIST_SEARCH_SEARCH_RESULT_RANKER_RECURRENCE_PREDICTOR_H_
#define CHROME_BROWSER_UI_APP_LIST_SEARCH_SEARCH_RESULT_RANKER_RECURRENCE_PREDICTOR_H_ #define CHROME_BROWSER_UI_APP_LIST_SEARCH_SEARCH_RESULT_RANKER_RECURRENCE_PREDICTOR_H_
#include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
#include "base/containers/flat_map.h" #include "base/containers/flat_map.h"
#include "base/macros.h" #include "base/macros.h"
#include "chrome/browser/ui/app_list/search/search_result_ranker/recurrence_predictor.pb.h" #include "chrome/browser/ui/app_list/search/search_result_ranker/recurrence_predictor.pb.h"
#include "chrome/browser/ui/app_list/search/search_result_ranker/recurrence_ranker_config.pb.h"
namespace app_list { namespace app_list {
using FakePredictorConfig = RecurrenceRankerConfigProto::FakePredictorConfig;
using ZeroStateFrecencyPredictorConfig =
RecurrenceRankerConfigProto::ZeroStateFrecencyPredictorConfig;
class FrecencyStore;
// |RecurrencePredictor| is the interface for all predictors used by // |RecurrencePredictor| is the interface for all predictors used by
// |RecurrenceRanker| to drive rankings. If a predictor has some form of // |RecurrenceRanker| to drive rankings. If a predictor has some form of
// serialisation, it should have a corresponding proto in // serialisation, it should have a corresponding proto in
...@@ -30,21 +22,13 @@ class RecurrencePredictor { ...@@ -30,21 +22,13 @@ class RecurrencePredictor {
public: public:
virtual ~RecurrencePredictor() = default; virtual ~RecurrencePredictor() = default;
// Train the predictor on an occurrence of |target| coinciding with |query|. // Train the predictor on an occurrence of |target|. The predictor will
// The predictor will collect its own contextual information, eg. time of day, // collect its own contextual information, eg. time of day, as part of
// as part of training. Zero-state scenarios should use an empty string for // training.
// |query|. virtual void Train(unsigned int target) = 0;
virtual void Train(const std::string& target, const std::string& query) = 0; // Return a map of all known targets to their scores under this predictor.
// Return a map of all known targets to their scores for the given query // Scores must be within the range [0,1].
// under this predictor. Scores must be within the range [0,1]. Zero-state virtual base::flat_map<unsigned int, float> Rank() = 0;
// scenarios should use an empty string for |query|.
virtual base::flat_map<std::string, float> Rank(const std::string& query) = 0;
// Rename a target, while keeping learned information on it.
virtual void Rename(const std::string& target,
const std::string& new_target) = 0;
// Remove a target entirely.
virtual void Remove(const std::string& target) = 0;
virtual void ToProto(RecurrencePredictorProto* proto) const = 0; virtual void ToProto(RecurrencePredictorProto* proto) const = 0;
virtual void FromProto(const RecurrencePredictorProto& proto) = 0; virtual void FromProto(const RecurrencePredictorProto& proto) = 0;
...@@ -52,57 +36,51 @@ class RecurrencePredictor { ...@@ -52,57 +36,51 @@ class RecurrencePredictor {
}; };
// |FakePredictor| is a simple 'predictor' used for testing. |Rank| returns the // |FakePredictor| is a simple 'predictor' used for testing. |Rank| returns the
// numbers of times each target has been trained on, and ignores the query // numbers of times each target has been trained on.
// altogether.
// //
// WARNING: this breaks the guarantees on the range of values a score can take, // WARNING: this breaks the guarantees on the range of values a score can take,
// so should not be used for anything except testing. // so should not be used for anything except testing.
class FakePredictor : public RecurrencePredictor { class FakePredictor : public RecurrencePredictor {
public: public:
explicit FakePredictor(FakePredictorConfig config); FakePredictor();
~FakePredictor() override; ~FakePredictor() override;
// RecurrencePredictor: // RecurrencePredictor overrides:
void Train(const std::string& target, const std::string& query) override; void Train(unsigned int target) override;
base::flat_map<std::string, float> Rank(const std::string& query) override; base::flat_map<unsigned int, float> Rank() override;
void Rename(const std::string& target, const char* GetPredictorName() const override;
const std::string& new_target) override;
void Remove(const std::string& target) override;
void ToProto(RecurrencePredictorProto* proto) const override; void ToProto(RecurrencePredictorProto* proto) const override;
void FromProto(const RecurrencePredictorProto& proto) override; void FromProto(const RecurrencePredictorProto& proto) override;
const char* GetPredictorName() const override;
static const char kPredictorName[]; static const char kPredictorName[];
private: private:
base::flat_map<std::string, float> counts_; base::flat_map<unsigned int, float> counts_;
DISALLOW_COPY_AND_ASSIGN(FakePredictor); DISALLOW_COPY_AND_ASSIGN(FakePredictor);
}; };
// |ZeroStateFrecencyPredictor| ranks targets according to their frecency, and // |FrecencyPredictor| simply returns targets in frecency order. To do this, it
// can only be used for zero-state predictions, that is, an empty query string. // piggybacks off the existing targets FrecencyStore used in |RecurrenceRanker|.
class ZeroStateFrecencyPredictor : public RecurrencePredictor { // For efficiency reasons the ranker itself has a special case that handles the
// logic of |FrecencyPredictor::Rank|. This leaves |FrecencyPredictor| as a
// mostly empty class.
class FrecencyPredictor : public RecurrencePredictor {
public: public:
explicit ZeroStateFrecencyPredictor(ZeroStateFrecencyPredictorConfig config); FrecencyPredictor() = default;
~ZeroStateFrecencyPredictor() override; ~FrecencyPredictor() override = default;
// RecurrencePredictor: // RecurrencePredictor overrides:
void Train(const std::string& target, const std::string& query) override; void Train(unsigned int target) override;
base::flat_map<std::string, float> Rank(const std::string& query) override; base::flat_map<unsigned int, float> Rank() override;
void Rename(const std::string& target, const char* GetPredictorName() const override;
const std::string& new_target) override;
void Remove(const std::string& target) override;
void ToProto(RecurrencePredictorProto* proto) const override; void ToProto(RecurrencePredictorProto* proto) const override;
void FromProto(const RecurrencePredictorProto& proto) override; void FromProto(const RecurrencePredictorProto& proto) override;
const char* GetPredictorName() const override;
static const char kPredictorName[]; static const char kPredictorName[];
private: private:
std::unique_ptr<FrecencyStore> targets_; DISALLOW_COPY_AND_ASSIGN(FrecencyPredictor);
DISALLOW_COPY_AND_ASSIGN(ZeroStateFrecencyPredictor);
}; };
} // namespace app_list } // namespace app_list
......
...@@ -4,27 +4,23 @@ ...@@ -4,27 +4,23 @@
syntax = "proto2"; syntax = "proto2";
import "frecency_store.proto";
option optimize_for = LITE_RUNTIME; option optimize_for = LITE_RUNTIME;
package app_list; package app_list;
// Fake predictor used for testing. // Fake predictor used for testing.
message FakePredictorProto { message FakePredictorProto {
// Maps targets to their score. map<uint32, float> counts = 1;
map<string, float> counts = 1;
} }
// Zero-state frecency predictor. // Frecency predictor. Uses the targets stored in the ranker, so serialises
message ZeroStateFrecencyPredictorProto { // nothing of its own.
optional FrecencyStoreProto targets = 1; message FrecencyPredictorProto {}
}
// Represents the serialisation of one particular predictor. // Represents the serialisation of one particular predictor.
message RecurrencePredictorProto { message RecurrencePredictorProto {
oneof predictor { oneof predictor {
FakePredictorProto fake_predictor = 1; FakePredictorProto fake_predictor = 1;
ZeroStateFrecencyPredictorProto zero_state_frecency_predictor = 2; FrecencyPredictorProto frecency_predictor = 2;
} }
} }
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/browser/ui/app_list/search/search_result_ranker/recurrence_predictor.h"
#include <memory>
#include <vector>
#include "ash/public/cpp/app_list/app_list_features.h"
#include "base/files/scoped_temp_dir.h"
#include "base/hash.h"
#include "base/test/scoped_feature_list.h"
#include "base/test/scoped_task_environment.h"
#include "chrome/browser/ui/app_list/search/search_result_ranker/app_launch_predictor_test_util.h"
#include "chrome/browser/ui/app_list/search/search_result_ranker/frecency_store.h"
#include "chrome/browser/ui/app_list/search/search_result_ranker/recurrence_ranker_config.pb.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
using testing::_;
using testing::ElementsAre;
using testing::FloatEq;
using testing::IsSupersetOf;
using testing::NiceMock;
using testing::Pair;
using testing::Return;
using testing::StrEq;
using testing::UnorderedElementsAre;
namespace app_list {
class ZeroStateFrecencyPredictorTest : public testing::Test {
protected:
void SetUp() override {
Test::SetUp();
config_.set_target_limit(100u);
config_.set_decay_coeff(0.5f);
predictor_ = std::make_unique<ZeroStateFrecencyPredictor>(config_);
}
ZeroStateFrecencyPredictorConfig config_;
std::unique_ptr<ZeroStateFrecencyPredictor> predictor_;
};
TEST_F(ZeroStateFrecencyPredictorTest, RankWithNoTargets) {
EXPECT_TRUE(predictor_->Rank("").empty());
}
TEST_F(ZeroStateFrecencyPredictorTest, RecordAndRank) {
predictor_->Train("A", "");
predictor_->Train("B", "");
predictor_->Train("C", "");
EXPECT_THAT(predictor_->Rank(""),
UnorderedElementsAre(Pair("A", FloatEq(0.125f)),
Pair("B", FloatEq(0.25f)),
Pair("C", FloatEq(0.5f))));
}
TEST_F(ZeroStateFrecencyPredictorTest, Rename) {
predictor_->Train("A", "");
predictor_->Train("B", "");
predictor_->Train("B", "");
predictor_->Rename("B", "A");
EXPECT_THAT(predictor_->Rank(""),
UnorderedElementsAre(Pair("A", FloatEq(0.75f))));
}
TEST_F(ZeroStateFrecencyPredictorTest, RenameNonexistentTarget) {
predictor_->Train("A", "");
predictor_->Rename("B", "C");
EXPECT_THAT(predictor_->Rank(""),
UnorderedElementsAre(Pair("A", FloatEq(0.5f))));
}
TEST_F(ZeroStateFrecencyPredictorTest, Remove) {
predictor_->Train("A", "");
predictor_->Train("B", "");
predictor_->Remove("B");
EXPECT_THAT(predictor_->Rank(""),
UnorderedElementsAre(Pair("A", FloatEq(0.25f))));
}
TEST_F(ZeroStateFrecencyPredictorTest, RemoveNonexistentTarget) {
predictor_->Train("A", "");
predictor_->Remove("B");
EXPECT_THAT(predictor_->Rank(""),
UnorderedElementsAre(Pair("A", FloatEq(0.5f))));
}
TEST_F(ZeroStateFrecencyPredictorTest, TargetLimitExceeded) {
ZeroStateFrecencyPredictorConfig config;
config.set_target_limit(5u);
config.set_decay_coeff(0.9999f);
ZeroStateFrecencyPredictor predictor(config);
// Insert many more targets than the target limit.
for (int i = 0; i < 50; i++) {
predictor.Train(std::to_string(i), "");
}
// Check that some of the values have been deleted, and the most recent ones
// remain. We check loose bounds on these requirements, to prevent the test
// from being tied to implementation details of the |FrecencyStore| cleanup
// logic. See |FrecencyStoreTest::CleanupOnOverflow| for a corresponding, more
// precise test.
auto ranks = predictor.Rank("");
EXPECT_LE(ranks.size(), 10ul);
EXPECT_THAT(
ranks, testing::IsSupersetOf({Pair("45", _), Pair("46", _), Pair("47", _),
Pair("48", _), Pair("49", _)}));
}
TEST_F(ZeroStateFrecencyPredictorTest, ToAndFromProto) {
predictor_->Train("A", "");
predictor_->Train("B", "");
predictor_->Train("C", "");
const auto ranks = predictor_->Rank("");
RecurrencePredictorProto proto;
predictor_->ToProto(&proto);
predictor_->FromProto(proto);
EXPECT_EQ(ranks, predictor_->Rank(""));
}
TEST_F(ZeroStateFrecencyPredictorTest, FailsWithQuery) {
ASSERT_DEATH(predictor_->Rank("query"), "");
}
} // namespace app_list
...@@ -74,12 +74,11 @@ std::unique_ptr<RecurrenceRankerProto> LoadProtoFromDisk( ...@@ -74,12 +74,11 @@ std::unique_ptr<RecurrenceRankerProto> LoadProtoFromDisk(
// Returns a new, configured instance of the predictor defined in |config|. // Returns a new, configured instance of the predictor defined in |config|.
std::unique_ptr<RecurrencePredictor> MakePredictor( std::unique_ptr<RecurrencePredictor> MakePredictor(
const RecurrenceRankerConfigProto& config) { RecurrenceRankerConfigProto config) {
if (config.has_frecency_predictor())
return std::make_unique<FrecencyPredictor>();
if (config.has_fake_predictor()) if (config.has_fake_predictor())
return std::make_unique<FakePredictor>(config.fake_predictor()); return std::make_unique<FakePredictor>();
if (config.has_zero_state_frecency_predictor())
return std::make_unique<ZeroStateFrecencyPredictor>(
config.zero_state_frecency_predictor());
NOTREACHED(); NOTREACHED();
return nullptr; return nullptr;
...@@ -101,12 +100,13 @@ RecurrenceRanker::RecurrenceRanker(const base::FilePath& filepath, ...@@ -101,12 +100,13 @@ RecurrenceRanker::RecurrenceRanker(const base::FilePath& filepath,
{base::TaskPriority::BEST_EFFORT, base::MayBlock(), {base::TaskPriority::BEST_EFFORT, base::MayBlock(),
base::TaskShutdownBehavior::SKIP_ON_SHUTDOWN}); base::TaskShutdownBehavior::SKIP_ON_SHUTDOWN});
targets_ = std::make_unique<FrecencyStore>(config.target_limit(),
config.target_decay_coeff());
if (is_ephemeral_user_) { if (is_ephemeral_user_) {
// Ephemeral users have no persistent storage, so we don't try and load the // Ephemeral users have no persistent storage, so we don't try and load the
// proto from disk. Instead, we fall back on using a frecency predictor, // proto from disk. Instead, we fall back on using a frecency predictor,
// which is still useful with only data from the current session. // which is still useful with only data from the current session.
predictor_ = std::make_unique<ZeroStateFrecencyPredictor>( predictor_ = std::make_unique<FrecencyPredictor>();
config.fallback_predictor());
} else { } else {
predictor_ = MakePredictor(config); predictor_ = MakePredictor(config);
...@@ -132,21 +132,25 @@ void RecurrenceRanker::OnLoadProtoFromDiskComplete( ...@@ -132,21 +132,25 @@ void RecurrenceRanker::OnLoadProtoFromDiskComplete(
return; return;
} }
if (proto->has_targets())
targets_->FromProto(proto->targets());
if (proto->has_predictor()) if (proto->has_predictor())
predictor_->FromProto(proto->predictor()); predictor_->FromProto(proto->predictor());
load_from_disk_completed_ = true; load_from_disk_completed_ = true;
} }
void RecurrenceRanker::Record(const std::string& target) { void RecurrenceRanker::Record(const std::string& target) {
Record(target, "");
}
void RecurrenceRanker::Record(const std::string& target,
const std::string& query) {
if (!load_from_disk_completed_) if (!load_from_disk_completed_)
return; return;
predictor_->Train(target, query); targets_->Update(target);
// It might be possible that, despite just being updated, the target was
// removed from the store. Only train if the target is still valid.
Optional<unsigned int> id = targets_->GetId(target);
if (id.has_value())
predictor_->Train(id.value());
MaybeSave(); MaybeSave();
} }
...@@ -155,7 +159,7 @@ void RecurrenceRanker::Rename(const std::string& target, ...@@ -155,7 +159,7 @@ void RecurrenceRanker::Rename(const std::string& target,
if (!load_from_disk_completed_) if (!load_from_disk_completed_)
return; return;
predictor_->Rename(target, new_target); targets_->Rename(target, new_target);
MaybeSave(); MaybeSave();
} }
...@@ -163,33 +167,44 @@ void RecurrenceRanker::Remove(const std::string& target) { ...@@ -163,33 +167,44 @@ void RecurrenceRanker::Remove(const std::string& target) {
if (!load_from_disk_completed_) if (!load_from_disk_completed_)
return; return;
predictor_->Remove(target); targets_->Remove(target);
MaybeSave(); MaybeSave();
} }
base::flat_map<std::string, float> RecurrenceRanker::Rank() { base::flat_map<std::string, float> RecurrenceRanker::Rank() {
return Rank("");
}
base::flat_map<std::string, float> RecurrenceRanker::Rank(
const std::string& query) {
if (!load_from_disk_completed_) if (!load_from_disk_completed_)
return {}; return {};
return predictor_->Rank(query); // Special case for a frecency predictor. Because this is simply a wrapper
} // around the |RecurrenceRanker|'s targets store, we can directly return the
// contents of the store and avoid an uneccessary iteration through targets.
if (predictor_->GetPredictorName() == FrecencyPredictor::kPredictorName) {
base::flat_map<std::string, float> ranks;
for (const auto& pair : targets_->GetAll())
ranks[pair.first] = pair.second.last_score;
return ranks;
}
std::vector<std::pair<std::string, float>> RecurrenceRanker::RankTopN(int n) { const base::flat_map<unsigned int, float> id_ranks = predictor_->Rank();
return RankTopN(n, ""); const base::flat_map<std::string, FrecencyStore::ValueData>& targets =
targets_->GetAll();
base::flat_map<std::string, float> ranks;
for (const auto& pair : targets) {
const auto& data = pair.second;
const auto it = id_ranks.find(data.id);
if (it == id_ranks.end())
continue;
ranks[pair.first] = it->second;
}
return ranks;
} }
std::vector<std::pair<std::string, float>> RecurrenceRanker::RankTopN( std::vector<std::pair<std::string, float>> RecurrenceRanker::RankTopN(int n) {
int n,
const std::string& query) {
if (!load_from_disk_completed_) if (!load_from_disk_completed_)
return {}; return {};
base::flat_map<std::string, float> ranks = Rank(query); base::flat_map<std::string, float> ranks = Rank();
std::vector<std::pair<std::string, float>> sorted_ranks(ranks.begin(), std::vector<std::pair<std::string, float>> sorted_ranks(ranks.begin(),
ranks.end()); ranks.end());
std::sort(sorted_ranks.begin(), sorted_ranks.end(), std::sort(sorted_ranks.begin(), sorted_ranks.end(),
...@@ -221,6 +236,7 @@ void RecurrenceRanker::MaybeSave() { ...@@ -221,6 +236,7 @@ void RecurrenceRanker::MaybeSave() {
void RecurrenceRanker::ToProto(RecurrenceRankerProto* proto) { void RecurrenceRanker::ToProto(RecurrenceRankerProto* proto) {
proto->set_config_hash(config_hash_); proto->set_config_hash(config_hash_);
predictor_->ToProto(proto->mutable_predictor()); predictor_->ToProto(proto->mutable_predictor());
targets_->ToProto(proto->mutable_targets());
} }
void RecurrenceRanker::ForceSaveOnNextUpdateForTesting() { void RecurrenceRanker::ForceSaveOnNextUpdateForTesting() {
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
namespace app_list { namespace app_list {
class FrecencyStore;
class RecurrencePredictor; class RecurrencePredictor;
class RecurrenceRankerProto; class RecurrenceRankerProto;
...@@ -39,11 +40,8 @@ class RecurrenceRanker { ...@@ -39,11 +40,8 @@ class RecurrenceRanker {
~RecurrenceRanker(); ~RecurrenceRanker();
// Record the use of a given target, and train the predictor on it. The one // Record the use of a given target, and train the predictor on it.
// argument version is a shortcut for an empty query string, useful for
// zero-state scenarios.
void Record(const std::string& target); void Record(const std::string& target);
void Record(const std::string& target, const std::string& query);
// Rename a target, while keeping learned information on it. // Rename a target, while keeping learned information on it.
void Rename(const std::string& target, const std::string& new_target); void Rename(const std::string& target, const std::string& new_target);
// Remove a target entirely. // Remove a target entirely.
...@@ -52,19 +50,14 @@ class RecurrenceRanker { ...@@ -52,19 +50,14 @@ class RecurrenceRanker {
// Returns a map of target to score. // Returns a map of target to score.
// - Higher scores are better. // - Higher scores are better.
// - Score are guaranteed to be in the range [0,1]. // - Score are guaranteed to be in the range [0,1].
// The zero-argument version is a shortcut for an empty query string.
base::flat_map<std::string, float> Rank(); base::flat_map<std::string, float> Rank();
base::flat_map<std::string, float> Rank(const std::string& query);
// Returns a sorted vector of <target, score> pairs. // Returns a sorted vector of <target, score> pairs.
// - Higher scores are better. // - Higher scores are better.
// - Score are guaranteed to be in the range [0,1]. // - Score are guaranteed to be in the range [0,1].
// - Pairs are sorted in descending order of score. // - Pairs are sorted in descending order of score.
// - At most n results will be returned. // - At most n results will be returned.
// The one-argument version is a shortcut for an empty query string.
std::vector<std::pair<std::string, float>> RankTopN(int n); std::vector<std::pair<std::string, float>> RankTopN(int n);
std::vector<std::pair<std::string, float>> RankTopN(int n,
const std::string& query);
private: private:
FRIEND_TEST_ALL_PREFIXES(RecurrenceRankerTest, FRIEND_TEST_ALL_PREFIXES(RecurrenceRankerTest,
...@@ -89,6 +82,8 @@ class RecurrenceRanker { ...@@ -89,6 +82,8 @@ class RecurrenceRanker {
// Internal predictor that drives ranking. // Internal predictor that drives ranking.
std::unique_ptr<RecurrencePredictor> predictor_; std::unique_ptr<RecurrencePredictor> predictor_;
// A store of all possible targets to return.
std::unique_ptr<FrecencyStore> targets_;
// Where to save the ranker. // Where to save the ranker.
const base::FilePath proto_filepath_; const base::FilePath proto_filepath_;
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
syntax = "proto2"; syntax = "proto2";
import "frecency_store.proto";
import "recurrence_predictor.proto"; import "recurrence_predictor.proto";
option optimize_for = LITE_RUNTIME; option optimize_for = LITE_RUNTIME;
...@@ -17,4 +18,6 @@ message RecurrenceRankerProto { ...@@ -17,4 +18,6 @@ message RecurrenceRankerProto {
optional uint32 config_hash = 1; optional uint32 config_hash = 1;
// Serialisation of the predictor used by the ranker. // Serialisation of the predictor used by the ranker.
optional RecurrencePredictorProto predictor = 2; optional RecurrencePredictorProto predictor = 2;
// Serialisation of the store of targets.
optional FrecencyStoreProto targets = 3;
} }
...@@ -13,33 +13,23 @@ package app_list; ...@@ -13,33 +13,23 @@ package app_list;
// Warning: this cannot contain any map fields, as they cannot be relied upon // Warning: this cannot contain any map fields, as they cannot be relied upon
// for a consistent hash. // for a consistent hash.
message RecurrenceRankerConfigProto { message RecurrenceRankerConfigProto {
// Fields with IDs 1 (target_limit) and 2 (target_decay_coeff) have been // The soft-maximum number of targets that are stored within the ranker.
// deleted. required uint32 target_limit = 1;
reserved 1; // The frecency parameter used to control the frequency-recency tradeoff that
reserved 2; // determines when targets are removed. Must be in [0.5, 1.0], with 0.5
// meaning only-recency and 1.0 meaning only-frequency.
required float target_decay_coeff = 2;
required uint32 min_seconds_between_saves = 3; required uint32 min_seconds_between_saves = 3;
// Config for a fake predictor, used for testing. // Config for a fake predictor, used for testing.
message FakePredictorConfig {} message FakePredictorConfig {}
// Config for a frecency predictor. // Config for a frecency predictor.
message ZeroStateFrecencyPredictorConfig { message FrecencyPredictorConfig {}
// The soft-maximum number of targets that are stored within the predictor.
required uint32 target_limit = 201;
// The frecency parameter used to control the frequency-recency tradeoff
// that determines when targets are removed. Must be in [0.5, 1.0], with 0.5
// meaning only-recency and 1.0 meaning only-frequency.
required float decay_coeff = 202;
}
// The choice of which kind of predictor to use, and its configuration. // The choice of which kind of predictor to use, and its configuration.
oneof predictor_config { oneof predictor_config {
FakePredictorConfig fake_predictor = 10001; FakePredictorConfig fake_predictor = 4;
ZeroStateFrecencyPredictorConfig zero_state_frecency_predictor = 10002; FrecencyPredictorConfig frecency_predictor = 5;
} }
// Configuration for a frecency predictor used as a fallback if the user is
// ephemeral.
required ZeroStateFrecencyPredictorConfig fallback_predictor = 11000;
} }
...@@ -36,14 +36,13 @@ class RecurrenceRankerTest : public testing::Test { ...@@ -36,14 +36,13 @@ class RecurrenceRankerTest : public testing::Test {
ASSERT_TRUE(temp_dir_.CreateUniqueTempDir()); ASSERT_TRUE(temp_dir_.CreateUniqueTempDir());
ranker_filepath_ = temp_dir_.GetPath().AppendASCII("recurrence_ranker"); ranker_filepath_ = temp_dir_.GetPath().AppendASCII("recurrence_ranker");
auto* fallback = config_.mutable_fallback_predictor(); config_.set_target_limit(1000u);
fallback->set_target_limit(0u); config_.set_target_decay_coeff(0.75f);
fallback->set_decay_coeff(0.0f); config_.set_min_seconds_between_saves(120);
// Even if empty, the setting of the oneof |predictor_config| in // Even if empty, the setting of the oneof |predictor_config| in
// |RecurrenceRankerConfigProto| is used to determine which predictor is // |RecurrenceRankerConfigProto| is used to determine which predictor is
// constructed. // constructed.
config_.mutable_fake_predictor(); config_.mutable_fake_predictor();
config_.set_min_seconds_between_saves(5);
ranker_ = ranker_ =
std::make_unique<RecurrenceRanker>(ranker_filepath_, config_, false); std::make_unique<RecurrenceRanker>(ranker_filepath_, config_, false);
...@@ -58,9 +57,17 @@ class RecurrenceRankerTest : public testing::Test { ...@@ -58,9 +57,17 @@ class RecurrenceRankerTest : public testing::Test {
auto* counts = auto* counts =
proto.mutable_predictor()->mutable_fake_predictor()->mutable_counts(); proto.mutable_predictor()->mutable_fake_predictor()->mutable_counts();
(*counts)["A"] = 1.0f; (*counts)[0u] = 1.0f;
(*counts)["B"] = 2.0f; (*counts)[1u] = 2.0f;
(*counts)["C"] = 1.0f; (*counts)[2u] = 1.0f;
FrecencyStore store(1000, 0.75f);
store.Update("A");
store.Update("B");
store.Update("B");
store.Update("C");
FrecencyStoreProto targets_proto;
store.ToProto(proto.mutable_targets());
return proto; return proto;
} }
...@@ -74,7 +81,7 @@ class RecurrenceRankerTest : public testing::Test { ...@@ -74,7 +81,7 @@ class RecurrenceRankerTest : public testing::Test {
base::FilePath ranker_filepath_; base::FilePath ranker_filepath_;
}; };
TEST_F(RecurrenceRankerTest, Record) { TEST_F(RecurrenceRankerTest, CheckRecord) {
ranker_->Record("A"); ranker_->Record("A");
ranker_->Record("B"); ranker_->Record("B");
ranker_->Record("B"); ranker_->Record("B");
...@@ -83,7 +90,7 @@ TEST_F(RecurrenceRankerTest, Record) { ...@@ -83,7 +90,7 @@ TEST_F(RecurrenceRankerTest, Record) {
Pair("B", FloatEq(2.0f)))); Pair("B", FloatEq(2.0f))));
} }
TEST_F(RecurrenceRankerTest, Rename) { TEST_F(RecurrenceRankerTest, CheckRename) {
ranker_->Record("A"); ranker_->Record("A");
ranker_->Record("B"); ranker_->Record("B");
ranker_->Record("B"); ranker_->Record("B");
...@@ -92,7 +99,7 @@ TEST_F(RecurrenceRankerTest, Rename) { ...@@ -92,7 +99,7 @@ TEST_F(RecurrenceRankerTest, Rename) {
EXPECT_THAT(ranker_->Rank(), ElementsAre(Pair("A", FloatEq(2.0f)))); EXPECT_THAT(ranker_->Rank(), ElementsAre(Pair("A", FloatEq(2.0f))));
} }
TEST_F(RecurrenceRankerTest, Remove) { TEST_F(RecurrenceRankerTest, CheckRemove) {
ranker_->Record("A"); ranker_->Record("A");
ranker_->Record("B"); ranker_->Record("B");
ranker_->Record("B"); ranker_->Record("B");
...@@ -206,14 +213,10 @@ TEST_F(RecurrenceRankerTest, SavedRankerRejectedIfConfigMismatched) { ...@@ -206,14 +213,10 @@ TEST_F(RecurrenceRankerTest, SavedRankerRejectedIfConfigMismatched) {
// Construct a second ranker with a slightly different config. // Construct a second ranker with a slightly different config.
RecurrenceRankerConfigProto other_config; RecurrenceRankerConfigProto other_config;
auto* fallback = other_config.mutable_fallback_predictor(); other_config.set_target_limit(1000u);
fallback->set_target_limit(0u); other_config.set_target_decay_coeff(0.76f);
fallback->set_decay_coeff(0.0f); other_config.set_min_seconds_between_saves(120);
other_config.mutable_fake_predictor(); other_config.mutable_fake_predictor();
// This is different.
other_config.set_min_seconds_between_saves(
config_.min_seconds_between_saves() + 1);
RecurrenceRanker other_ranker(ranker_filepath_, other_config, false); RecurrenceRanker other_ranker(ranker_filepath_, other_config, false);
Wait(); Wait();
...@@ -226,36 +229,31 @@ TEST_F(RecurrenceRankerTest, SavedRankerRejectedIfConfigMismatched) { ...@@ -226,36 +229,31 @@ TEST_F(RecurrenceRankerTest, SavedRankerRejectedIfConfigMismatched) {
} }
TEST_F(RecurrenceRankerTest, EphemeralUsersUseFrecencyPredictor) { TEST_F(RecurrenceRankerTest, EphemeralUsersUseFrecencyPredictor) {
RecurrenceRanker ephemeral_ranker(ranker_filepath_, config_, true); auto ephemeral_ranker =
std::make_unique<RecurrenceRanker>(ranker_filepath_, config_, true);
Wait(); Wait();
EXPECT_THAT(ephemeral_ranker.GetPredictorNameForTesting(), EXPECT_THAT(ephemeral_ranker->GetPredictorNameForTesting(),
StrEq(ZeroStateFrecencyPredictor::kPredictorName)); StrEq(FrecencyPredictor().GetPredictorName()));
} }
TEST_F(RecurrenceRankerTest, IntegrationWithZeroStateFrecencyPredictor) { TEST_F(RecurrenceRankerTest, CheckFrecencyPredictor) {
RecurrenceRankerConfigProto config; RecurrenceRankerConfigProto config;
config.set_min_seconds_between_saves(5); config.set_target_limit(1000u);
auto* predictor = config.mutable_zero_state_frecency_predictor(); config.set_target_decay_coeff(0.5f);
predictor->set_target_limit(100u); config.set_min_seconds_between_saves(120);
predictor->set_decay_coeff(0.5f); config.mutable_frecency_predictor();
auto* fallback = config.mutable_fallback_predictor();
fallback->set_target_limit(100u); RecurrenceRanker frecency_ranker(ranker_filepath_, config, false);
fallback->set_decay_coeff(0.5f);
RecurrenceRanker ranker(ranker_filepath_, config, false);
Wait(); Wait();
ranker.Record("A"); frecency_ranker.Record("A");
ranker.Record("D"); frecency_ranker.Record("B");
ranker.Record("C"); frecency_ranker.Record("C");
ranker.Record("E");
ranker.Rename("D", "B"); EXPECT_THAT(frecency_ranker.Rank(),
ranker.Remove("E"); UnorderedElementsAre(Pair("A", FloatEq(0.125f)),
ranker.Rename("E", "A"); Pair("B", FloatEq(0.25f)),
Pair("C", FloatEq(0.5f))));
EXPECT_THAT(ranker.Rank(), UnorderedElementsAre(Pair("A", FloatEq(0.0625f)),
Pair("B", FloatEq(0.125f)),
Pair("C", FloatEq(0.25f))));
} }
} // namespace app_list } // namespace app_list
...@@ -4632,7 +4632,6 @@ test("unit_tests") { ...@@ -4632,7 +4632,6 @@ test("unit_tests") {
"../browser/ui/app_list/search/search_result_ranker/app_launch_predictor_unittest.cc", "../browser/ui/app_list/search/search_result_ranker/app_launch_predictor_unittest.cc",
"../browser/ui/app_list/search/search_result_ranker/app_search_result_ranker_unittest.cc", "../browser/ui/app_list/search/search_result_ranker/app_search_result_ranker_unittest.cc",
"../browser/ui/app_list/search/search_result_ranker/frecency_store_unittest.cc", "../browser/ui/app_list/search/search_result_ranker/frecency_store_unittest.cc",
"../browser/ui/app_list/search/search_result_ranker/recurrence_predictor_unittest.cc",
"../browser/ui/app_list/search/search_result_ranker/recurrence_ranker_unittest.cc", "../browser/ui/app_list/search/search_result_ranker/recurrence_ranker_unittest.cc",
"../browser/ui/app_list/search/settings_shortcut/settings_shortcut_provider_unittest.cc", "../browser/ui/app_list/search/settings_shortcut/settings_shortcut_provider_unittest.cc",
"../browser/ui/app_list/search/settings_shortcut/settings_shortcut_result_unittest.cc", "../browser/ui/app_list/search/settings_shortcut/settings_shortcut_result_unittest.cc",
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment