Commit 9279abb7 authored by tby, committed by Commit Bot

[Dolphin] Remove the distinction between conditional and zero-state.

Previously, we had two copies of each Record, Train, and Rank method for
the ranker and each predictor: one that took a query and one that didn't.
The idea was that predictors (and therefore the ranker) were set up to
either work in a zero-state or condition-based environment, and clients
should only call the appropriate methods.

This has limitations; in particular, zero-state predictors can't be used
in an ensemble model along with condition-based ones. This CL removes
the distinction between these methods.

Specific changes are as follows:

1. Predictors now only have Train(target, condition) and
   Rank(condition) methods. The non-condition versions have been
   deleted.

2. The Ranker's zero-state Train and Rank methods are now just
   shortcuts for having an empty-string condition.

3. All predictors labelled ZeroStateX have been renamed to just X. It
   seems most of our predictors won't use the condition and, in order
   to incorporate a condition, will be wrapped in some kind of
   ConditionalPredictor. So it makes sense for the default naming to not
   specify zero-state-ness, as the current names are very verbose.

Bug: 921444
Change-Id: I5273715501319cda951b717d755e1abb559d4dd6
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/1670668
Commit-Queue: Tony Yeoman <tby@chromium.org>
Reviewed-by: Charles . <charleszhao@chromium.org>
Cr-Commit-Position: refs/heads/master@{#671894}
parent d851243b
...@@ -229,10 +229,9 @@ DEFINE_EQUIVTO_PROTO_LITE_3(FrecencyStoreProto_ValueData, ...@@ -229,10 +229,9 @@ DEFINE_EQUIVTO_PROTO_LITE_3(FrecencyStoreProto_ValueData,
last_score, last_score,
last_num_updates); last_num_updates);
DEFINE_EQUIVTO_PROTO_LITE_1(ZeroStateHourBinPredictorProto, DEFINE_EQUIVTO_PROTO_LITE_1(HourBinPredictorProto, binned_frequency_table);
binned_frequency_table);
DEFINE_EQUIVTO_PROTO_LITE_2(ZeroStateHourBinPredictorProto_FrequencyTable, DEFINE_EQUIVTO_PROTO_LITE_2(HourBinPredictorProto_FrequencyTable,
total_counts, total_counts,
frequency); frequency);
...@@ -245,7 +244,7 @@ DEFINE_EQUIVTO_PROTO_LITE_2(HourAppLaunchPredictorProto_FrequencyTable, ...@@ -245,7 +244,7 @@ DEFINE_EQUIVTO_PROTO_LITE_2(HourAppLaunchPredictorProto_FrequencyTable,
DEFINE_EQUIVTO_PROTO_LITE_2(RecurrencePredictorProto, DEFINE_EQUIVTO_PROTO_LITE_2(RecurrencePredictorProto,
fake_predictor, fake_predictor,
zero_state_frecency_predictor); frecency_predictor);
DEFINE_EQUIVTO_PROTO_LITE_2(RecurrenceRankerProto, config_hash, predictor); DEFINE_EQUIVTO_PROTO_LITE_2(RecurrenceRankerProto, config_hash, predictor);
...@@ -257,11 +256,9 @@ DEFINE_EQUIVTO_PROTO_LITE_2(SerializedMrfuAppLaunchPredictorProto_Score, ...@@ -257,11 +256,9 @@ DEFINE_EQUIVTO_PROTO_LITE_2(SerializedMrfuAppLaunchPredictorProto_Score,
num_of_trains_at_last_update, num_of_trains_at_last_update,
last_score); last_score);
DEFINE_EQUIVTO_PROTO_LITE_2(ZeroStateFrecencyPredictorProto, DEFINE_EQUIVTO_PROTO_LITE_2(FrecencyPredictorProto, targets, num_updates);
targets,
num_updates);
DEFINE_EQUIVTO_PROTO_LITE_2(ZeroStateFrecencyPredictorProto_TargetData, DEFINE_EQUIVTO_PROTO_LITE_2(FrecencyPredictorProto_TargetData,
last_score, last_score,
last_num_updates); last_num_updates);
......
...@@ -29,9 +29,9 @@ enum class SerializationError { ...@@ -29,9 +29,9 @@ enum class SerializationError {
kTargetsMissingError = 5, kTargetsMissingError = 5,
kConditionsMissingError = 6, kConditionsMissingError = 6,
kFakePredictorLoadingError = 7, kFakePredictorLoadingError = 7,
kZeroStateFrecencyPredictorLoadingError = 8, kFrecencyPredictorLoadingError = 8,
kZeroStateHourBinnedPredictorLoadingError = 9, kHourBinnedPredictorLoadingError = 9,
kMaxValue = kZeroStateHourBinnedPredictorLoadingError, kMaxValue = kHourBinnedPredictorLoadingError,
}; };
// Represents errors where a RecurrenceRanker is used in a way not supported by // Represents errors where a RecurrenceRanker is used in a way not supported by
......
...@@ -15,29 +15,6 @@ constexpr int kHoursADay = 24; ...@@ -15,29 +15,6 @@ constexpr int kHoursADay = 24;
} // namespace } // namespace
void RecurrencePredictor::Train(unsigned int target) {
LogUsageError(UsageError::kInvalidTrainCall);
NOTREACHED();
}
void RecurrencePredictor::Train(unsigned int target, unsigned int condition) {
LogUsageError(UsageError::kInvalidTrainCall);
NOTREACHED();
}
base::flat_map<unsigned int, float> RecurrencePredictor::Rank() {
LogUsageError(UsageError::kInvalidRankCall);
NOTREACHED();
return {};
}
base::flat_map<unsigned int, float> RecurrencePredictor::Rank(
unsigned int condition) {
LogUsageError(UsageError::kInvalidRankCall);
NOTREACHED();
return {};
}
FakePredictor::FakePredictor(const FakePredictorConfig& config) { FakePredictor::FakePredictor(const FakePredictorConfig& config) {
// The fake predictor should only be used for testing, not in production. // The fake predictor should only be used for testing, not in production.
// Record an error so we know if it is being used. // Record an error so we know if it is being used.
...@@ -51,11 +28,12 @@ const char* FakePredictor::GetPredictorName() const { ...@@ -51,11 +28,12 @@ const char* FakePredictor::GetPredictorName() const {
return kPredictorName; return kPredictorName;
} }
void FakePredictor::Train(unsigned int target) { void FakePredictor::Train(unsigned int target, unsigned int condition) {
counts_[target] += 1.0f; counts_[target] += 1.0f;
} }
base::flat_map<unsigned int, float> FakePredictor::Rank() { base::flat_map<unsigned int, float> FakePredictor::Rank(
unsigned int condition) {
return counts_; return counts_;
} }
...@@ -70,15 +48,23 @@ void FakePredictor::FromProto(const RecurrencePredictorProto& proto) { ...@@ -70,15 +48,23 @@ void FakePredictor::FromProto(const RecurrencePredictorProto& proto) {
LogSerializationError(SerializationError::kFakePredictorLoadingError); LogSerializationError(SerializationError::kFakePredictorLoadingError);
return; return;
} }
auto predictor = proto.fake_predictor();
for (const auto& pair : predictor.counts()) for (const auto& pair : proto.fake_predictor().counts())
counts_[pair.first] = pair.second; counts_[pair.first] = pair.second;
} }
DefaultPredictor::DefaultPredictor(const DefaultPredictorConfig& config) {} DefaultPredictor::DefaultPredictor(const DefaultPredictorConfig& config) {}
DefaultPredictor::~DefaultPredictor() {} DefaultPredictor::~DefaultPredictor() {}
void DefaultPredictor::Train(unsigned int target, unsigned int condition) {}
base::flat_map<unsigned int, float> DefaultPredictor::Rank(
unsigned int condition) {
LogUsageError(UsageError::kInvalidRankCall);
NOTREACHED();
return {};
}
const char DefaultPredictor::kPredictorName[] = "DefaultPredictor"; const char DefaultPredictor::kPredictorName[] = "DefaultPredictor";
const char* DefaultPredictor::GetPredictorName() const { const char* DefaultPredictor::GetPredictorName() const {
return kPredictorName; return kPredictorName;
...@@ -88,25 +74,24 @@ void DefaultPredictor::ToProto(RecurrencePredictorProto* proto) const {} ...@@ -88,25 +74,24 @@ void DefaultPredictor::ToProto(RecurrencePredictorProto* proto) const {}
void DefaultPredictor::FromProto(const RecurrencePredictorProto& proto) {} void DefaultPredictor::FromProto(const RecurrencePredictorProto& proto) {}
ZeroStateFrecencyPredictor::ZeroStateFrecencyPredictor( FrecencyPredictor::FrecencyPredictor(const FrecencyPredictorConfig& config)
const ZeroStateFrecencyPredictorConfig& config)
: decay_coeff_(config.decay_coeff()) {} : decay_coeff_(config.decay_coeff()) {}
ZeroStateFrecencyPredictor::~ZeroStateFrecencyPredictor() = default; FrecencyPredictor::~FrecencyPredictor() = default;
const char ZeroStateFrecencyPredictor::kPredictorName[] = const char FrecencyPredictor::kPredictorName[] = "FrecencyPredictor";
"ZeroStateFrecencyPredictor"; const char* FrecencyPredictor::GetPredictorName() const {
const char* ZeroStateFrecencyPredictor::GetPredictorName() const {
return kPredictorName; return kPredictorName;
} }
void ZeroStateFrecencyPredictor::Train(unsigned int target) { void FrecencyPredictor::Train(unsigned int target, unsigned int condition) {
++num_updates_; ++num_updates_;
TargetData& data = targets_[target]; TargetData& data = targets_[target];
DecayScore(&data); DecayScore(&data);
data.last_score += 1.0f - decay_coeff_; data.last_score += 1.0f - decay_coeff_;
} }
base::flat_map<unsigned int, float> ZeroStateFrecencyPredictor::Rank() { base::flat_map<unsigned int, float> FrecencyPredictor::Rank(
unsigned int condition) {
base::flat_map<unsigned int, float> result; base::flat_map<unsigned int, float> result;
for (auto& pair : targets_) { for (auto& pair : targets_) {
DecayScore(&pair.second); DecayScore(&pair.second);
...@@ -115,9 +100,8 @@ base::flat_map<unsigned int, float> ZeroStateFrecencyPredictor::Rank() { ...@@ -115,9 +100,8 @@ base::flat_map<unsigned int, float> ZeroStateFrecencyPredictor::Rank() {
return result; return result;
} }
void ZeroStateFrecencyPredictor::ToProto( void FrecencyPredictor::ToProto(RecurrencePredictorProto* proto) const {
RecurrencePredictorProto* proto) const { auto* predictor = proto->mutable_frecency_predictor();
auto* predictor = proto->mutable_zero_state_frecency_predictor();
predictor->set_num_updates(num_updates_); predictor->set_num_updates(num_updates_);
...@@ -129,14 +113,12 @@ void ZeroStateFrecencyPredictor::ToProto( ...@@ -129,14 +113,12 @@ void ZeroStateFrecencyPredictor::ToProto(
} }
} }
void ZeroStateFrecencyPredictor::FromProto( void FrecencyPredictor::FromProto(const RecurrencePredictorProto& proto) {
const RecurrencePredictorProto& proto) { if (!proto.has_frecency_predictor()) {
if (!proto.has_zero_state_frecency_predictor()) { LogSerializationError(SerializationError::kFrecencyPredictorLoadingError);
LogSerializationError(
SerializationError::kZeroStateFrecencyPredictorLoadingError);
return; return;
} }
const auto& predictor = proto.zero_state_frecency_predictor(); const auto& predictor = proto.frecency_predictor();
num_updates_ = predictor.num_updates(); num_updates_ = predictor.num_updates();
...@@ -148,7 +130,7 @@ void ZeroStateFrecencyPredictor::FromProto( ...@@ -148,7 +130,7 @@ void ZeroStateFrecencyPredictor::FromProto(
targets_.swap(targets); targets_.swap(targets);
} }
void ZeroStateFrecencyPredictor::DecayScore(TargetData* data) { void FrecencyPredictor::DecayScore(TargetData* data) {
int time_since_update = num_updates_ - data->last_num_updates; int time_since_update = num_updates_ - data->last_num_updates;
if (time_since_update > 0) { if (time_since_update > 0) {
...@@ -157,25 +139,22 @@ void ZeroStateFrecencyPredictor::DecayScore(TargetData* data) { ...@@ -157,25 +139,22 @@ void ZeroStateFrecencyPredictor::DecayScore(TargetData* data) {
} }
} }
ZeroStateHourBinPredictor::ZeroStateHourBinPredictor( HourBinPredictor::HourBinPredictor(const HourBinPredictorConfig& config)
const ZeroStateHourBinPredictorConfig& config)
: config_(config) { : config_(config) {
if (!proto_.has_last_decay_timestamp()) if (!proto_.has_last_decay_timestamp())
SetLastDecayTimestamp( SetLastDecayTimestamp(
base::Time::Now().ToDeltaSinceWindowsEpoch().InDays()); base::Time::Now().ToDeltaSinceWindowsEpoch().InDays());
} }
ZeroStateHourBinPredictor::~ZeroStateHourBinPredictor() = default; HourBinPredictor::~HourBinPredictor() = default;
const char ZeroStateHourBinPredictor::kPredictorName[] = const char HourBinPredictor::kPredictorName[] = "HourBinPredictor";
"ZeroStateHourBinPredictor";
const char* ZeroStateHourBinPredictor::GetPredictorName() const { const char* HourBinPredictor::GetPredictorName() const {
return kPredictorName; return kPredictorName;
} }
int ZeroStateHourBinPredictor::GetBinFromHourDifference( int HourBinPredictor::GetBinFromHourDifference(int hour_difference) const {
int hour_difference) const {
base::Time shifted_time = base::Time shifted_time =
base::Time::Now() + base::TimeDelta::FromHours(hour_difference); base::Time::Now() + base::TimeDelta::FromHours(hour_difference);
base::Time::Exploded exploded_time; base::Time::Exploded exploded_time;
...@@ -193,18 +172,19 @@ int ZeroStateHourBinPredictor::GetBinFromHourDifference( ...@@ -193,18 +172,19 @@ int ZeroStateHourBinPredictor::GetBinFromHourDifference(
} }
} }
int ZeroStateHourBinPredictor::GetBin() const { int HourBinPredictor::GetBin() const {
return GetBinFromHourDifference(0); return GetBinFromHourDifference(0);
} }
void ZeroStateHourBinPredictor::Train(unsigned int target) { void HourBinPredictor::Train(unsigned int target, unsigned int condition) {
int hour = GetBin(); int hour = GetBin();
auto& frequency_table = (*proto_.mutable_binned_frequency_table())[hour]; auto& frequency_table = (*proto_.mutable_binned_frequency_table())[hour];
frequency_table.set_total_counts(frequency_table.total_counts() + 1); frequency_table.set_total_counts(frequency_table.total_counts() + 1);
(*frequency_table.mutable_frequency())[target] += 1; (*frequency_table.mutable_frequency())[target] += 1;
} }
base::flat_map<unsigned int, float> ZeroStateHourBinPredictor::Rank() { base::flat_map<unsigned int, float> HourBinPredictor::Rank(
unsigned int condition) {
base::flat_map<unsigned int, float> ranks; base::flat_map<unsigned int, float> ranks;
const auto& frequency_table_map = proto_.binned_frequency_table(); const auto& frequency_table_map = proto_.binned_frequency_table();
for (const auto& hour_and_weight : config_.bin_weights_map()) { for (const auto& hour_and_weight : config_.bin_weights_map()) {
...@@ -229,27 +209,26 @@ base::flat_map<unsigned int, float> ZeroStateHourBinPredictor::Rank() { ...@@ -229,27 +209,26 @@ base::flat_map<unsigned int, float> ZeroStateHourBinPredictor::Rank() {
return ranks; return ranks;
} }
void ZeroStateHourBinPredictor::ToProto(RecurrencePredictorProto* proto) const { void HourBinPredictor::ToProto(RecurrencePredictorProto* proto) const {
*proto->mutable_zero_state_hour_bin_predictor() = proto_; *proto->mutable_hour_bin_predictor() = proto_;
} }
void ZeroStateHourBinPredictor::FromProto( void HourBinPredictor::FromProto(const RecurrencePredictorProto& proto) {
const RecurrencePredictorProto& proto) { if (!proto.has_hour_bin_predictor())
if (!proto.has_zero_state_hour_bin_predictor())
return; return;
proto_ = proto.zero_state_hour_bin_predictor(); proto_ = proto.hour_bin_predictor();
if (ShouldDecay()) if (ShouldDecay())
DecayAll(); DecayAll();
} }
bool ZeroStateHourBinPredictor::ShouldDecay() { bool HourBinPredictor::ShouldDecay() {
const int today = base::Time::Now().ToDeltaSinceWindowsEpoch().InDays(); const int today = base::Time::Now().ToDeltaSinceWindowsEpoch().InDays();
// Check if we should decay the frequency // Check if we should decay the frequency
return today - proto_.last_decay_timestamp() > 7; return today - proto_.last_decay_timestamp() > 7;
} }
void ZeroStateHourBinPredictor::DecayAll() { void HourBinPredictor::DecayAll() {
SetLastDecayTimestamp(base::Time::Now().ToDeltaSinceWindowsEpoch().InDays()); SetLastDecayTimestamp(base::Time::Now().ToDeltaSinceWindowsEpoch().InDays());
auto& frequency_table_map = *proto_.mutable_binned_frequency_table(); auto& frequency_table_map = *proto_.mutable_binned_frequency_table();
for (auto it_table = frequency_table_map.begin(); for (auto it_table = frequency_table_map.begin();
......
...@@ -19,10 +19,10 @@ namespace app_list { ...@@ -19,10 +19,10 @@ namespace app_list {
using FakePredictorConfig = RecurrenceRankerConfigProto::FakePredictorConfig; using FakePredictorConfig = RecurrenceRankerConfigProto::FakePredictorConfig;
using DefaultPredictorConfig = using DefaultPredictorConfig =
RecurrenceRankerConfigProto::DefaultPredictorConfig; RecurrenceRankerConfigProto::DefaultPredictorConfig;
using ZeroStateFrecencyPredictorConfig = using FrecencyPredictorConfig =
RecurrenceRankerConfigProto::ZeroStateFrecencyPredictorConfig; RecurrenceRankerConfigProto::FrecencyPredictorConfig;
using ZeroStateHourBinPredictorConfig = using HourBinPredictorConfig =
RecurrenceRankerConfigProto::ZeroStateHourBinPredictorConfig; RecurrenceRankerConfigProto::HourBinPredictorConfig;
// |RecurrencePredictor| is the interface for all predictors used by // |RecurrencePredictor| is the interface for all predictors used by
// |RecurrenceRanker| to drive rankings. If a predictor has some form of // |RecurrenceRanker| to drive rankings. If a predictor has some form of
...@@ -34,16 +34,12 @@ class RecurrencePredictor { ...@@ -34,16 +34,12 @@ class RecurrencePredictor {
// Train the predictor on an occurrence of |target| coinciding with // Train the predictor on an occurrence of |target| coinciding with
// |condition|. The predictor will collect its own contextual information, eg. // |condition|. The predictor will collect its own contextual information, eg.
// time of day, as part of training. Zero-state predictors should use the // time of day, as part of training.
// one-argument version. virtual void Train(unsigned int target, unsigned int condition) = 0;
virtual void Train(unsigned int target);
virtual void Train(unsigned int target, unsigned int condition);
// Return a map of all known targets to their scores for the given condition // Return a map of all known targets to their scores for the given condition
// under this predictor. Scores must be within the range [0,1]. Zero-state // under this predictor. Scores must be within the range [0,1].
// predictors should use the zero-argument version. virtual base::flat_map<unsigned int, float> Rank(unsigned int condition) = 0;
virtual base::flat_map<unsigned int, float> Rank();
virtual base::flat_map<unsigned int, float> Rank(unsigned int condition);
virtual void ToProto(RecurrencePredictorProto* proto) const = 0; virtual void ToProto(RecurrencePredictorProto* proto) const = 0;
virtual void FromProto(const RecurrencePredictorProto& proto) = 0; virtual void FromProto(const RecurrencePredictorProto& proto) = 0;
...@@ -62,8 +58,8 @@ class FakePredictor : public RecurrencePredictor { ...@@ -62,8 +58,8 @@ class FakePredictor : public RecurrencePredictor {
~FakePredictor() override; ~FakePredictor() override;
// RecurrencePredictor: // RecurrencePredictor:
void Train(unsigned int target) override; void Train(unsigned int target, unsigned int condition) override;
base::flat_map<unsigned int, float> Rank() override; base::flat_map<unsigned int, float> Rank(unsigned int condition) override;
void ToProto(RecurrencePredictorProto* proto) const override; void ToProto(RecurrencePredictorProto* proto) const override;
void FromProto(const RecurrencePredictorProto& proto) override; void FromProto(const RecurrencePredictorProto& proto) override;
const char* GetPredictorName() const override; const char* GetPredictorName() const override;
...@@ -85,6 +81,8 @@ class DefaultPredictor : public RecurrencePredictor { ...@@ -85,6 +81,8 @@ class DefaultPredictor : public RecurrencePredictor {
~DefaultPredictor() override; ~DefaultPredictor() override;
// RecurrencePredictor: // RecurrencePredictor:
void Train(unsigned int target, unsigned int condition) override;
base::flat_map<unsigned int, float> Rank(unsigned int condition) override;
void ToProto(RecurrencePredictorProto* proto) const override; void ToProto(RecurrencePredictorProto* proto) const override;
void FromProto(const RecurrencePredictorProto& proto) override; void FromProto(const RecurrencePredictorProto& proto) override;
const char* GetPredictorName() const override; const char* GetPredictorName() const override;
...@@ -95,16 +93,15 @@ class DefaultPredictor : public RecurrencePredictor { ...@@ -95,16 +93,15 @@ class DefaultPredictor : public RecurrencePredictor {
DISALLOW_COPY_AND_ASSIGN(DefaultPredictor); DISALLOW_COPY_AND_ASSIGN(DefaultPredictor);
}; };
// ZeroStateFrecencyPredictor ranks targets according to their frecency, and // FrecencyPredictor ranks targets according to their frecency, and
// can only be used for zero-state predictions. This predictor allows for // can only be used for zero-state predictions. This predictor allows for
// frecency-based rankings with different configuration to that of the ranker's // frecency-based rankings with different configuration to that of the ranker's
// FrecencyStore. If frecency-based rankings with the same configuration as the // FrecencyStore. If frecency-based rankings with the same configuration as the
// store are needed, the DefaultPredictor should be used instead. // store are needed, the DefaultPredictor should be used instead.
class ZeroStateFrecencyPredictor : public RecurrencePredictor { class FrecencyPredictor : public RecurrencePredictor {
public: public:
explicit ZeroStateFrecencyPredictor( explicit FrecencyPredictor(const FrecencyPredictorConfig& config);
const ZeroStateFrecencyPredictorConfig& config); ~FrecencyPredictor() override;
~ZeroStateFrecencyPredictor() override;
// Records all information about a target: its id and score, along with the // Records all information about a target: its id and score, along with the
// number of updates that had occurred when the score was last calculated. // number of updates that had occurred when the score was last calculated.
...@@ -115,8 +112,8 @@ class ZeroStateFrecencyPredictor : public RecurrencePredictor { ...@@ -115,8 +112,8 @@ class ZeroStateFrecencyPredictor : public RecurrencePredictor {
}; };
// RecurrencePredictor: // RecurrencePredictor:
void Train(unsigned int target) override; void Train(unsigned int target, unsigned int condition) override;
base::flat_map<unsigned int, float> Rank() override; base::flat_map<unsigned int, float> Rank(unsigned int condition) override;
void ToProto(RecurrencePredictorProto* proto) const override; void ToProto(RecurrencePredictorProto* proto) const override;
void FromProto(const RecurrencePredictorProto& proto) override; void FromProto(const RecurrencePredictorProto& proto) override;
const char* GetPredictorName() const override; const char* GetPredictorName() const override;
...@@ -139,23 +136,22 @@ class ZeroStateFrecencyPredictor : public RecurrencePredictor { ...@@ -139,23 +136,22 @@ class ZeroStateFrecencyPredictor : public RecurrencePredictor {
// This stores all the data of the frecency predictor. // This stores all the data of the frecency predictor.
// TODO(tby): benchmark which map is best in practice for our use. // TODO(tby): benchmark which map is best in practice for our use.
base::flat_map<unsigned int, ZeroStateFrecencyPredictor::TargetData> targets_; base::flat_map<unsigned int, FrecencyPredictor::TargetData> targets_;
DISALLOW_COPY_AND_ASSIGN(ZeroStateFrecencyPredictor); DISALLOW_COPY_AND_ASSIGN(FrecencyPredictor);
}; };
// |ZeroStateHourBinPredictor| ranks targets according to their frequency during // |HourBinPredictor| ranks targets according to their frequency during
// the current and neighbor hour bins. It can only be used for zero-state // the current and neighbor hour bins. It can only be used for zero-state
// predictions. // predictions.
class ZeroStateHourBinPredictor : public RecurrencePredictor { class HourBinPredictor : public RecurrencePredictor {
public: public:
explicit ZeroStateHourBinPredictor( explicit HourBinPredictor(const HourBinPredictorConfig& config);
const ZeroStateHourBinPredictorConfig& config); ~HourBinPredictor() override;
~ZeroStateHourBinPredictor() override;
// RecurrencePredictor: // RecurrencePredictor:
void Train(unsigned int target) override; void Train(unsigned int target, unsigned int condition) override;
base::flat_map<unsigned int, float> Rank() override; base::flat_map<unsigned int, float> Rank(unsigned int condition) override;
void ToProto(RecurrencePredictorProto* proto) const override; void ToProto(RecurrencePredictorProto* proto) const override;
void FromProto(const RecurrencePredictorProto& proto) override; void FromProto(const RecurrencePredictorProto& proto) override;
const char* GetPredictorName() const override; const char* GetPredictorName() const override;
...@@ -163,13 +159,11 @@ class ZeroStateHourBinPredictor : public RecurrencePredictor { ...@@ -163,13 +159,11 @@ class ZeroStateHourBinPredictor : public RecurrencePredictor {
static const char kPredictorName[]; static const char kPredictorName[];
private: private:
FRIEND_TEST_ALL_PREFIXES(ZeroStateHourBinPredictorTest, GetTheRightBin); FRIEND_TEST_ALL_PREFIXES(HourBinPredictorTest, GetTheRightBin);
FRIEND_TEST_ALL_PREFIXES(ZeroStateHourBinPredictorTest, FRIEND_TEST_ALL_PREFIXES(HourBinPredictorTest, TrainAndRankSingleBin);
TrainAndRankSingleBin); FRIEND_TEST_ALL_PREFIXES(HourBinPredictorTest, TrainAndRankMultipleBin);
FRIEND_TEST_ALL_PREFIXES(ZeroStateHourBinPredictorTest, FRIEND_TEST_ALL_PREFIXES(HourBinPredictorTest, ToProto);
TrainAndRankMultipleBin); FRIEND_TEST_ALL_PREFIXES(HourBinPredictorTest, FromProtoDecays);
FRIEND_TEST_ALL_PREFIXES(ZeroStateHourBinPredictorTest, ToProto);
FRIEND_TEST_ALL_PREFIXES(ZeroStateHourBinPredictorTest, FromProtoDecays);
// Return the bin index that is |hour_difference| away from the current bin // Return the bin index that is |hour_difference| away from the current bin
// index. // index.
int GetBinFromHourDifference(int hour_difference) const; int GetBinFromHourDifference(int hour_difference) const;
...@@ -182,9 +176,9 @@ class ZeroStateHourBinPredictor : public RecurrencePredictor { ...@@ -182,9 +176,9 @@ class ZeroStateHourBinPredictor : public RecurrencePredictor {
void SetLastDecayTimestamp(float value) { void SetLastDecayTimestamp(float value) {
proto_.set_last_decay_timestamp(value); proto_.set_last_decay_timestamp(value);
} }
ZeroStateHourBinPredictorProto proto_; HourBinPredictorProto proto_;
ZeroStateHourBinPredictorConfig config_; HourBinPredictorConfig config_;
DISALLOW_COPY_AND_ASSIGN(ZeroStateHourBinPredictor); DISALLOW_COPY_AND_ASSIGN(HourBinPredictor);
}; };
} // namespace app_list } // namespace app_list
......
...@@ -14,8 +14,7 @@ message FakePredictorProto { ...@@ -14,8 +14,7 @@ message FakePredictorProto {
map<uint32, float> counts = 1; map<uint32, float> counts = 1;
} }
// Zero-state frecency predictor. message FrecencyPredictorProto {
message ZeroStateFrecencyPredictorProto {
// Field 1 (targets) has been deleted. // Field 1 (targets) has been deleted.
reserved 1; reserved 1;
...@@ -34,8 +33,7 @@ message ZeroStateFrecencyPredictorProto { ...@@ -34,8 +33,7 @@ message ZeroStateFrecencyPredictorProto {
required uint32 num_updates = 5; required uint32 num_updates = 5;
} }
// Zero-state hour bin predictor message HourBinPredictorProto {
message ZeroStateHourBinPredictorProto {
// Records all data related to a single hour bin. // Records all data related to a single hour bin.
message FrequencyTable { message FrequencyTable {
// Total number of training counts in a bin. // Total number of training counts in a bin.
...@@ -53,7 +51,7 @@ message ZeroStateHourBinPredictorProto { ...@@ -53,7 +51,7 @@ message ZeroStateHourBinPredictorProto {
message RecurrencePredictorProto { message RecurrencePredictorProto {
oneof predictor { oneof predictor {
FakePredictorProto fake_predictor = 1; FakePredictorProto fake_predictor = 1;
ZeroStateFrecencyPredictorProto zero_state_frecency_predictor = 2; FrecencyPredictorProto frecency_predictor = 2;
ZeroStateHourBinPredictorProto zero_state_hour_bin_predictor = 3; HourBinPredictorProto hour_bin_predictor = 3;
} }
} }
...@@ -29,41 +29,38 @@ namespace app_list { ...@@ -29,41 +29,38 @@ namespace app_list {
RecurrencePredictorProto MakeTestingProto() { RecurrencePredictorProto MakeTestingProto() {
RecurrencePredictorProto proto; RecurrencePredictorProto proto;
auto* zero_state_hour_bin_proto = auto* hour_bin_proto = proto.mutable_hour_bin_predictor();
proto.mutable_zero_state_hour_bin_predictor(); hour_bin_proto->set_last_decay_timestamp(365);
zero_state_hour_bin_proto->set_last_decay_timestamp(365);
ZeroStateHourBinPredictorProto::FrequencyTable frequency_table; HourBinPredictorProto::FrequencyTable frequency_table;
(*frequency_table.mutable_frequency())[1u] = 3; (*frequency_table.mutable_frequency())[1u] = 3;
(*frequency_table.mutable_frequency())[2u] = 1; (*frequency_table.mutable_frequency())[2u] = 1;
frequency_table.set_total_counts(4); frequency_table.set_total_counts(4);
(*zero_state_hour_bin_proto->mutable_binned_frequency_table())[10] = (*hour_bin_proto->mutable_binned_frequency_table())[10] = frequency_table;
frequency_table;
frequency_table = ZeroStateHourBinPredictorProto::FrequencyTable(); frequency_table = HourBinPredictorProto::FrequencyTable();
(*frequency_table.mutable_frequency())[1u] = 1; (*frequency_table.mutable_frequency())[1u] = 1;
(*frequency_table.mutable_frequency())[3u] = 1; (*frequency_table.mutable_frequency())[3u] = 1;
frequency_table.set_total_counts(2); frequency_table.set_total_counts(2);
(*zero_state_hour_bin_proto->mutable_binned_frequency_table())[11] = (*hour_bin_proto->mutable_binned_frequency_table())[11] = frequency_table;
frequency_table;
return proto; return proto;
} }
class ZeroStateFrecencyPredictorTest : public testing::Test { class FrecencyPredictorTest : public testing::Test {
protected: protected:
void SetUp() override { void SetUp() override {
Test::SetUp(); Test::SetUp();
config_.set_decay_coeff(0.5f); config_.set_decay_coeff(0.5f);
predictor_ = std::make_unique<ZeroStateFrecencyPredictor>(config_); predictor_ = std::make_unique<FrecencyPredictor>(config_);
} }
ZeroStateFrecencyPredictorConfig config_; FrecencyPredictorConfig config_;
std::unique_ptr<ZeroStateFrecencyPredictor> predictor_; std::unique_ptr<FrecencyPredictor> predictor_;
}; };
class ZeroStateHourBinPredictorTest : public testing::Test { class HourBinPredictorTest : public testing::Test {
protected: protected:
void SetUp() override { void SetUp() override {
Test::SetUp(); Test::SetUp();
...@@ -74,7 +71,7 @@ class ZeroStateHourBinPredictorTest : public testing::Test { ...@@ -74,7 +71,7 @@ class ZeroStateHourBinPredictorTest : public testing::Test {
(*config_.mutable_bin_weights_map())[2] = 0.05; (*config_.mutable_bin_weights_map())[2] = 0.05;
(*config_.mutable_bin_weights_map())[-1] = 0.15; (*config_.mutable_bin_weights_map())[-1] = 0.15;
(*config_.mutable_bin_weights_map())[-2] = 0.05; (*config_.mutable_bin_weights_map())[-2] = 0.05;
predictor_ = std::make_unique<ZeroStateHourBinPredictor>(config_); predictor_ = std::make_unique<HourBinPredictor>(config_);
} }
// Sets local time according to |day_of_week| and |hour_of_day|. // Sets local time according to |day_of_week| and |hour_of_day|.
...@@ -88,8 +85,8 @@ class ZeroStateHourBinPredictorTest : public testing::Test { ...@@ -88,8 +85,8 @@ class ZeroStateHourBinPredictorTest : public testing::Test {
} }
base::ScopedMockClockOverride time_; base::ScopedMockClockOverride time_;
ZeroStateHourBinPredictorConfig config_; HourBinPredictorConfig config_;
std::unique_ptr<ZeroStateHourBinPredictor> predictor_; std::unique_ptr<HourBinPredictor> predictor_;
private: private:
// Advances time to be 0am next Sunday. // Advances time to be 0am next Sunday.
...@@ -107,58 +104,58 @@ class ZeroStateHourBinPredictorTest : public testing::Test { ...@@ -107,58 +104,58 @@ class ZeroStateHourBinPredictorTest : public testing::Test {
} }
}; };
TEST_F(ZeroStateFrecencyPredictorTest, RankWithNoTargets) { TEST_F(FrecencyPredictorTest, RankWithNoTargets) {
EXPECT_TRUE(predictor_->Rank().empty()); EXPECT_TRUE(predictor_->Rank(0u).empty());
} }
TEST_F(ZeroStateFrecencyPredictorTest, RecordAndRankSimple) { TEST_F(FrecencyPredictorTest, RecordAndRankSimple) {
predictor_->Train(2u); predictor_->Train(2u, 0u);
predictor_->Train(4u); predictor_->Train(4u, 0u);
predictor_->Train(6u); predictor_->Train(6u, 0u);
EXPECT_THAT( EXPECT_THAT(
predictor_->Rank(), predictor_->Rank(0u),
UnorderedElementsAre(Pair(2u, FloatEq(0.125f)), Pair(4u, FloatEq(0.25f)), UnorderedElementsAre(Pair(2u, FloatEq(0.125f)), Pair(4u, FloatEq(0.25f)),
Pair(6u, FloatEq(0.5f)))); Pair(6u, FloatEq(0.5f))));
} }
TEST_F(ZeroStateFrecencyPredictorTest, RecordAndRankComplex) { TEST_F(FrecencyPredictorTest, RecordAndRankComplex) {
predictor_->Train(2u); predictor_->Train(2u, 0u);
predictor_->Train(4u); predictor_->Train(4u, 0u);
predictor_->Train(6u); predictor_->Train(6u, 0u);
predictor_->Train(4u); predictor_->Train(4u, 0u);
predictor_->Train(2u); predictor_->Train(2u, 0u);
// Ranks should be deterministic. // Ranks should be deterministic.
for (int i = 0; i < 3; ++i) { for (int i = 0; i < 3; ++i) {
EXPECT_THAT(predictor_->Rank(), EXPECT_THAT(predictor_->Rank(0u),
UnorderedElementsAre(Pair(2u, FloatEq(0.53125f)), UnorderedElementsAre(Pair(2u, FloatEq(0.53125f)),
Pair(4u, FloatEq(0.3125f)), Pair(4u, FloatEq(0.3125f)),
Pair(6u, FloatEq(0.125f)))); Pair(6u, FloatEq(0.125f))));
} }
} }
TEST_F(ZeroStateFrecencyPredictorTest, ToAndFromProto) { TEST_F(FrecencyPredictorTest, ToAndFromProto) {
predictor_->Train(1u); predictor_->Train(1u, 0u);
predictor_->Train(3u); predictor_->Train(3u, 0u);
predictor_->Train(5u); predictor_->Train(5u, 0u);
RecurrencePredictorProto proto; RecurrencePredictorProto proto;
predictor_->ToProto(&proto); predictor_->ToProto(&proto);
ZeroStateFrecencyPredictor new_predictor(config_); FrecencyPredictor new_predictor(config_);
new_predictor.FromProto(proto); new_predictor.FromProto(proto);
EXPECT_TRUE(proto.has_zero_state_frecency_predictor()); EXPECT_TRUE(proto.has_frecency_predictor());
EXPECT_EQ(proto.zero_state_frecency_predictor().num_updates(), 3u); EXPECT_EQ(proto.frecency_predictor().num_updates(), 3u);
EXPECT_EQ(predictor_->Rank(), new_predictor.Rank()); EXPECT_EQ(predictor_->Rank(0u), new_predictor.Rank(0u));
} }
TEST_F(ZeroStateHourBinPredictorTest, RankWithNoTargets) { TEST_F(HourBinPredictorTest, RankWithNoTargets) {
EXPECT_TRUE(predictor_->Rank().empty()); EXPECT_TRUE(predictor_->Rank(0u).empty());
} }
TEST_F(ZeroStateHourBinPredictorTest, GetTheRightBin) { TEST_F(HourBinPredictorTest, GetTheRightBin) {
// Monday. // Monday.
for (int i = 0; i <= 23; ++i) { for (int i = 0; i <= 23; ++i) {
SetLocalTime(1, i); SetLocalTime(1, i);
...@@ -200,64 +197,64 @@ TEST_F(ZeroStateHourBinPredictorTest, GetTheRightBin) { ...@@ -200,64 +197,64 @@ TEST_F(ZeroStateHourBinPredictorTest, GetTheRightBin) {
EXPECT_EQ(predictor_->GetBinFromHourDifference(-5), 22); EXPECT_EQ(predictor_->GetBinFromHourDifference(-5), 22);
} }
TEST_F(ZeroStateHourBinPredictorTest, TrainAndRankSingleBin) { TEST_F(HourBinPredictorTest, TrainAndRankSingleBin) {
base::flat_map<int, float> weights( base::flat_map<int, float> weights(
predictor_->config_.bin_weights_map().begin(), predictor_->config_.bin_weights_map().begin(),
predictor_->config_.bin_weights_map().end()); predictor_->config_.bin_weights_map().end());
SetLocalTime(1, 10); SetLocalTime(1, 10);
predictor_->Train(1u); predictor_->Train(1u, 0u);
SetLocalTime(2, 10); SetLocalTime(2, 10);
predictor_->Train(1u); predictor_->Train(1u, 0u);
SetLocalTime(3, 10); SetLocalTime(3, 10);
predictor_->Train(2u); predictor_->Train(2u, 0u);
SetLocalTime(4, 10); SetLocalTime(4, 10);
predictor_->Train(1u); predictor_->Train(1u, 0u);
SetLocalTime(5, 10); SetLocalTime(5, 10);
predictor_->Train(2u); predictor_->Train(2u, 0u);
// Training on the weekend doesn't affect the result during the week. // Training on the weekend doesn't affect the result during the week.
SetLocalTime(0, 10); SetLocalTime(0, 10);
predictor_->Train(1u); predictor_->Train(1u, 0u);
SetLocalTime(0, 10); SetLocalTime(0, 10);
predictor_->Train(2u); predictor_->Train(2u, 0u);
SetLocalTime(1, 10); SetLocalTime(1, 10);
EXPECT_THAT(predictor_->Rank(), EXPECT_THAT(predictor_->Rank(0u),
UnorderedElementsAre(Pair(1u, FloatEq(weights[0] * 0.6)), UnorderedElementsAre(Pair(1u, FloatEq(weights[0] * 0.6)),
Pair(2u, FloatEq((weights)[0] * 0.4)))); Pair(2u, FloatEq((weights)[0] * 0.4))));
} }
TEST_F(ZeroStateHourBinPredictorTest, TrainAndRankMultipleBin) { TEST_F(HourBinPredictorTest, TrainAndRankMultipleBin) {
base::flat_map<int, float> weights( base::flat_map<int, float> weights(
predictor_->config_.bin_weights_map().begin(), predictor_->config_.bin_weights_map().begin(),
predictor_->config_.bin_weights_map().end()); predictor_->config_.bin_weights_map().end());
// For bin 10 // For bin 10
SetLocalTime(1, 10); SetLocalTime(1, 10);
predictor_->Train(1u); predictor_->Train(1u, 0u);
predictor_->Train(1u); predictor_->Train(1u, 0u);
SetLocalTime(2, 10); SetLocalTime(2, 10);
predictor_->Train(2u); predictor_->Train(2u, 0u);
// For bin 11 // For bin 11
SetLocalTime(3, 11); SetLocalTime(3, 11);
predictor_->Train(1u); predictor_->Train(1u, 0u);
predictor_->Train(2u); predictor_->Train(2u, 0u);
// For bin 12 // For bin 12
SetLocalTime(5, 12); SetLocalTime(5, 12);
predictor_->Train(2u); predictor_->Train(2u, 0u);
// Train on weekend. // Train on weekend.
SetLocalTime(6, 10); SetLocalTime(6, 10);
predictor_->Train(1u); predictor_->Train(1u, 0u);
predictor_->Train(2u); predictor_->Train(2u, 0u);
SetLocalTime(0, 11); SetLocalTime(0, 11);
predictor_->Train(2u); predictor_->Train(2u, 0u);
// Check workdays. // Check workdays.
SetLocalTime(1, 10); SetLocalTime(1, 10);
EXPECT_THAT( EXPECT_THAT(
predictor_->Rank(), predictor_->Rank(0u),
UnorderedElementsAre( UnorderedElementsAre(
Pair(1u, FloatEq((weights)[0] * 2.0 / 3.0 + weights[1] * 0.5)), Pair(1u, FloatEq((weights)[0] * 2.0 / 3.0 + weights[1] * 0.5)),
Pair(2u, FloatEq(weights[0] * 1.0 / 3.0 + weights[1] * 0.5 + Pair(2u, FloatEq(weights[0] * 1.0 / 3.0 + weights[1] * 0.5 +
...@@ -265,28 +262,28 @@ TEST_F(ZeroStateHourBinPredictorTest, TrainAndRankMultipleBin) { ...@@ -265,28 +262,28 @@ TEST_F(ZeroStateHourBinPredictorTest, TrainAndRankMultipleBin) {
// Check weekends. // Check weekends.
SetLocalTime(0, 9); SetLocalTime(0, 9);
EXPECT_THAT(predictor_->Rank(), EXPECT_THAT(predictor_->Rank(0u),
UnorderedElementsAre( UnorderedElementsAre(
Pair(1u, FloatEq(weights[1] * 1.0 / 2.0)), Pair(1u, FloatEq(weights[1] * 1.0 / 2.0)),
Pair(2u, FloatEq(weights[1] * 1.0 / 2.0 + weights[2])))); Pair(2u, FloatEq(weights[1] * 1.0 / 2.0 + weights[2]))));
} }
TEST_F(ZeroStateHourBinPredictorTest, FromProto) { TEST_F(HourBinPredictorTest, FromProto) {
RecurrencePredictorProto proto = MakeTestingProto(); RecurrencePredictorProto proto = MakeTestingProto();
predictor_->FromProto(proto); predictor_->FromProto(proto);
SetLocalTime(1, 11); SetLocalTime(1, 11);
EXPECT_THAT( EXPECT_THAT(
predictor_->Rank(), predictor_->Rank(0u),
UnorderedElementsAre(Pair(1u, FloatEq(0.4125)), Pair(2u, FloatEq(0.0375)), UnorderedElementsAre(Pair(1u, FloatEq(0.4125)), Pair(2u, FloatEq(0.0375)),
Pair(3u, FloatEq(0.3)))); Pair(3u, FloatEq(0.3))));
} }
TEST_F(ZeroStateHourBinPredictorTest, FromProtoDecays) { TEST_F(HourBinPredictorTest, FromProtoDecays) {
RecurrencePredictorProto proto = MakeTestingProto(); RecurrencePredictorProto proto = MakeTestingProto();
proto.mutable_zero_state_hour_bin_predictor()->set_last_decay_timestamp(350); proto.mutable_hour_bin_predictor()->set_last_decay_timestamp(350);
predictor_->FromProto(proto); predictor_->FromProto(proto);
SetLocalTime(1, 11); SetLocalTime(1, 11);
EXPECT_THAT(predictor_->Rank(), EXPECT_THAT(predictor_->Rank(0u),
UnorderedElementsAre(Pair(1u, FloatEq(0.15)))); UnorderedElementsAre(Pair(1u, FloatEq(0.15))));
// Check if empty items got deleted during decay. // Check if empty items got deleted during decay.
...@@ -299,25 +296,25 @@ TEST_F(ZeroStateHourBinPredictorTest, FromProtoDecays) { ...@@ -299,25 +296,25 @@ TEST_F(ZeroStateHourBinPredictorTest, FromProtoDecays) {
1); 1);
} }
TEST_F(ZeroStateHourBinPredictorTest, ToProto) { TEST_F(HourBinPredictorTest, ToProto) {
RecurrencePredictorProto proto; RecurrencePredictorProto proto;
SetLocalTime(1, 10); SetLocalTime(1, 10);
predictor_->Train(1u); predictor_->Train(1u, 0u);
predictor_->Train(1u); predictor_->Train(1u, 0u);
predictor_->Train(1u); predictor_->Train(1u, 0u);
predictor_->Train(2u); predictor_->Train(2u, 0u);
SetLocalTime(1, 11); SetLocalTime(1, 11);
predictor_->Train(1u); predictor_->Train(1u, 0u);
predictor_->Train(3u); predictor_->Train(3u, 0u);
predictor_->SetLastDecayTimestamp(365); predictor_->SetLastDecayTimestamp(365);
predictor_->ToProto(&proto); predictor_->ToProto(&proto);
RecurrencePredictorProto target_proto = MakeTestingProto(); RecurrencePredictorProto target_proto = MakeTestingProto();
EXPECT_TRUE(proto.has_zero_state_hour_bin_predictor()); EXPECT_TRUE(proto.has_hour_bin_predictor());
EXPECT_TRUE(EquivToProtoLite(proto.zero_state_hour_bin_predictor(), EXPECT_TRUE(EquivToProtoLite(proto.hour_bin_predictor(),
target_proto.zero_state_hour_bin_predictor())); target_proto.hour_bin_predictor()));
} }
} // namespace app_list } // namespace app_list
...@@ -81,12 +81,10 @@ std::unique_ptr<RecurrencePredictor> MakePredictor( ...@@ -81,12 +81,10 @@ std::unique_ptr<RecurrencePredictor> MakePredictor(
return std::make_unique<FakePredictor>(config.fake_predictor()); return std::make_unique<FakePredictor>(config.fake_predictor());
if (config.has_default_predictor()) if (config.has_default_predictor())
return std::make_unique<DefaultPredictor>(config.default_predictor()); return std::make_unique<DefaultPredictor>(config.default_predictor());
if (config.has_zero_state_frecency_predictor()) if (config.has_frecency_predictor())
return std::make_unique<ZeroStateFrecencyPredictor>( return std::make_unique<FrecencyPredictor>(config.frecency_predictor());
config.zero_state_frecency_predictor()); if (config.has_hour_bin_predictor())
if (config.has_zero_state_hour_bin_predictor()) return std::make_unique<HourBinPredictor>(config.hour_bin_predictor());
return std::make_unique<ZeroStateHourBinPredictor>(
config.zero_state_hour_bin_predictor());
LogConfigurationError(ConfigurationError::kInvalidPredictor); LogConfigurationError(ConfigurationError::kInvalidPredictor);
NOTREACHED(); NOTREACHED();
...@@ -215,34 +213,12 @@ void RecurrenceRanker::OnLoadProtoFromDiskComplete( ...@@ -215,34 +213,12 @@ void RecurrenceRanker::OnLoadProtoFromDiskComplete(
LogSerializationError(SerializationError::kConditionsMissingError); LogSerializationError(SerializationError::kConditionsMissingError);
} }
void RecurrenceRanker::Record(const std::string& target) {
if (!load_from_disk_completed_)
return;
if (predictor_->GetPredictorName() == DefaultPredictor::kPredictorName) {
targets_->Update(target);
} else {
predictor_->Train(targets_->Update(target));
}
MaybeSave();
}
void RecurrenceRanker::Record(const std::string& target, void RecurrenceRanker::Record(const std::string& target,
const std::string& condition) { const std::string& condition) {
if (!load_from_disk_completed_) if (!load_from_disk_completed_)
return; return;
if (predictor_->GetPredictorName() == DefaultPredictor::kPredictorName) { predictor_->Train(targets_->Update(target), conditions_->Update(condition));
// TODO(921444): The default predictor does not support queries, so we fail
// here. Once we have a suitable query-based default predictor implemented,
// change this.
LogUsageError(UsageError::kInvalidTrainCall);
NOTREACHED();
} else {
predictor_->Train(targets_->Update(target), conditions_->Update(condition));
}
MaybeSave(); MaybeSave();
} }
...@@ -282,43 +258,23 @@ void RecurrenceRanker::RemoveCondition(const std::string& condition) { ...@@ -282,43 +258,23 @@ void RecurrenceRanker::RemoveCondition(const std::string& condition) {
MaybeSave(); MaybeSave();
} }
base::flat_map<std::string, float> RecurrenceRanker::Rank() {
if (!load_from_disk_completed_)
return {};
if (predictor_->GetPredictorName() == DefaultPredictor::kPredictorName)
return GetScoresFromFrecencyStore(targets_->GetAll());
return ZipTargetsWithScores(targets_->GetAll(), predictor_->Rank());
}
base::flat_map<std::string, float> RecurrenceRanker::Rank( base::flat_map<std::string, float> RecurrenceRanker::Rank(
const std::string& condition) { const std::string& condition) {
if (!load_from_disk_completed_) if (!load_from_disk_completed_)
return {}; return {};
// Special case the default predictor, and return the scores from the target
// frecency store.
if (predictor_->GetPredictorName() == DefaultPredictor::kPredictorName)
return GetScoresFromFrecencyStore(targets_->GetAll());
base::Optional<unsigned int> condition_id = conditions_->GetId(condition); base::Optional<unsigned int> condition_id = conditions_->GetId(condition);
if (condition_id == base::nullopt) if (condition_id == base::nullopt)
return {}; return {};
if (predictor_->GetPredictorName() == DefaultPredictor::kPredictorName) {
// TODO(921444): The default predictor does not support queries, so we fail
// here. Once we have a suitable query-based default predictor implemented,
// change this.
LogUsageError(UsageError::kInvalidRankCall);
NOTREACHED();
return {};
}
return ZipTargetsWithScores(targets_->GetAll(), return ZipTargetsWithScores(targets_->GetAll(),
predictor_->Rank(condition_id.value())); predictor_->Rank(condition_id.value()));
} }
std::vector<std::pair<std::string, float>> RecurrenceRanker::RankTopN(int n) {
if (!load_from_disk_completed_)
return {};
return SortAndTruncateRanks(n, Rank());
}
std::vector<std::pair<std::string, float>> RecurrenceRanker::RankTopN( std::vector<std::pair<std::string, float>> RecurrenceRanker::RankTopN(
int n, int n,
const std::string& condition) { const std::string& condition) {
......
...@@ -33,12 +33,11 @@ class RecurrenceRanker { ...@@ -33,12 +33,11 @@ class RecurrenceRanker {
bool is_ephemeral_user); bool is_ephemeral_user);
~RecurrenceRanker(); ~RecurrenceRanker();
// Record the use of a given target, and train the predictor on it. The // Record the use of a given target, and train the predictor on it. This may
// one-argument version should be used for zero-state predictions, and the // save to disk, but is not guaranteed to. The user-supplied |condition| can
// two-argument version for condition-based predictions. This may save to // be ignored if it isn't needed.
// disk, but is not guaranteed to. void Record(const std::string& target,
void Record(const std::string& target); const std::string& condition = std::string());
void Record(const std::string& target, const std::string& condition);
// Rename a target, while keeping learned information on it. This may save to // Rename a target, while keeping learned information on it. This may save to
// disk, but is not guaranteed to. // disk, but is not guaranteed to.
...@@ -56,22 +55,19 @@ class RecurrenceRanker { ...@@ -56,22 +55,19 @@ class RecurrenceRanker {
// Returns a map of target to score. // Returns a map of target to score.
// - Higher scores are better. // - Higher scores are better.
// - Scores are guaranteed to be in the range [0,1]. // - Scores are guaranteed to be in the range [0,1].
// The zero-argument version should be used for zero-state predictions, and // The user-supplied |condition| can be ignored if it isn't needed.
// the one-argument version for condition-based predictions. base::flat_map<std::string, float> Rank(
base::flat_map<std::string, float> Rank(); const std::string& condition = std::string());
base::flat_map<std::string, float> Rank(const std::string& condition);
// Returns a sorted vector of <target, score> pairs. // Returns a sorted vector of <target, score> pairs.
// - Higher scores are better. // - Higher scores are better.
// - Scores are guaranteed to be in the range [0,1]. // - Scores are guaranteed to be in the range [0,1].
// - Pairs are sorted in descending order of score. // - Pairs are sorted in descending order of score.
// - At most n results will be returned. // - At most n results will be returned.
// The zero-argument version should be used for zero-state predictions, and // The user-supplied |condition| can be ignored if it isn't needed.
// the one-argument version for condition-based predictions.
std::vector<std::pair<std::string, float>> RankTopN(int n);
std::vector<std::pair<std::string, float>> RankTopN( std::vector<std::pair<std::string, float>> RankTopN(
int n, int n,
const std::string& condition); const std::string& condition = std::string());
// TODO(921444): Create a system for cleaning up internal predictor state that // TODO(921444): Create a system for cleaning up internal predictor state that
// is stored independently of the target/condition frecency stores. // is stored independently of the target/condition frecency stores.
......
...@@ -28,14 +28,14 @@ message RecurrenceRankerConfigProto { ...@@ -28,14 +28,14 @@ message RecurrenceRankerConfigProto {
message DefaultPredictorConfig {} message DefaultPredictorConfig {}
// Config for a frecency predictor. // Config for a frecency predictor.
message ZeroStateFrecencyPredictorConfig { message FrecencyPredictorConfig {
// The frecency parameter used to control the frequency-recency tradeoff // The frecency parameter used to control the frequency-recency tradeoff
// that determines when targets are removed. Must be in [0.5, 1.0], with 0.5 // that determines when targets are removed. Must be in [0.5, 1.0], with 0.5
// meaning only-recency and 1.0 meaning only-frequency. // meaning only-recency and 1.0 meaning only-frequency.
required float decay_coeff = 1; required float decay_coeff = 1;
} }
message ZeroStateHourBinPredictorConfig { message HourBinPredictorConfig {
// The decay coefficient that controls the decay rate. The decay is // The decay coefficient that controls the decay rate. The decay is
// once a week. // once a week.
required float weekly_decay_coeff = 1; required float weekly_decay_coeff = 1;
...@@ -45,8 +45,8 @@ message RecurrenceRankerConfigProto { ...@@ -45,8 +45,8 @@ message RecurrenceRankerConfigProto {
// The choice of which kind of predictor to use, and its configuration. // The choice of which kind of predictor to use, and its configuration.
oneof predictor_config { oneof predictor_config {
FakePredictorConfig fake_predictor = 10001; FakePredictorConfig fake_predictor = 10001;
ZeroStateFrecencyPredictorConfig zero_state_frecency_predictor = 10002; FrecencyPredictorConfig frecency_predictor = 10002;
DefaultPredictorConfig default_predictor = 10003; DefaultPredictorConfig default_predictor = 10003;
ZeroStateHourBinPredictorConfig zero_state_hour_bin_predictor = 10004; HourBinPredictorConfig hour_bin_predictor = 10004;
} }
} }
...@@ -104,13 +104,19 @@ class RecurrenceRankerTest : public testing::Test { ...@@ -104,13 +104,19 @@ class RecurrenceRankerTest : public testing::Test {
value_data.set_last_num_updates(4); value_data.set_last_num_updates(4);
(*target_values)["C"] = value_data; (*target_values)["C"] = value_data;
// Make empty conditions frecency store. // Make conditions frecency store.
auto* conditions = proto.mutable_conditions(); auto* conditions = proto.mutable_conditions();
conditions->set_value_limit(0u); conditions->set_value_limit(10u);
conditions->set_decay_coeff(0.0f); conditions->set_decay_coeff(0.5f);
conditions->set_num_updates(0); conditions->set_num_updates(0);
conditions->set_next_id(0); conditions->set_next_id(0);
conditions->mutable_values();
auto* condition_values = conditions->mutable_values();
value_data.set_id(0u);
value_data.set_last_score(0.5f);
value_data.set_last_num_updates(1);
(*condition_values)[""] = value_data;
// Make FakePredictor counts. // Make FakePredictor counts.
auto* counts = auto* counts =
...@@ -369,7 +375,7 @@ TEST_F(RecurrenceRankerTest, IntegrationWithDefaultPredictor) { ...@@ -369,7 +375,7 @@ TEST_F(RecurrenceRankerTest, IntegrationWithDefaultPredictor) {
TEST_F(RecurrenceRankerTest, IntegrationWithZeroStateFrecencyPredictor) { TEST_F(RecurrenceRankerTest, IntegrationWithZeroStateFrecencyPredictor) {
RecurrenceRankerConfigProto config; RecurrenceRankerConfigProto config;
PartiallyPopulateConfig(&config); PartiallyPopulateConfig(&config);
auto* predictor = config.mutable_zero_state_frecency_predictor(); auto* predictor = config.mutable_frecency_predictor();
predictor->set_decay_coeff(0.5f); predictor->set_decay_coeff(0.5f);
RecurrenceRanker ranker(ranker_filepath_, config, false); RecurrenceRanker ranker(ranker_filepath_, config, false);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment