Commit 9279abb7, authored by tby, committed by Commit Bot

[Dolphin] Remove the distinction between conditional and zero-state.

Previously, we had two copies of each Record, Train, and Rank method for
the ranker and each predictor: one that took a query and one that didn't.
The idea was that predictors (and therefore the ranker) were set up to
work in either a zero-state or a condition-based environment, and clients
should only call the appropriate methods.

This has limitations; in particular, zero-state predictors can't be used
in an ensemble model along with condition-based ones. This CL removes
the distinction between these methods.

Specific changes are as follows:

1. Predictors now only have Train(target, condition) and
   Rank(condition) methods. The non-condition versions have been
   deleted.

2. The Ranker's zero-state Train and Rank methods are now just
   shortcuts for having an empty-string condition.

3. All predictors labelled ZeroStateX have been renamed to just X. It
   seems most of our predictors won't use the condition and, in order
   to incorporate a condition, will be wrapped in some kind of
   ConditionalPredictor. So it makes sense for the default naming to not
   specify zero-state-ness, as the current names are very verbose.

Bug: 921444
Change-Id: I5273715501319cda951b717d755e1abb559d4dd6
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/1670668
Commit-Queue: Tony Yeoman <tby@chromium.org>
Reviewed-by: Charles . <charleszhao@chromium.org>
Cr-Commit-Position: refs/heads/master@{#671894}
parent d851243b
...@@ -229,10 +229,9 @@ DEFINE_EQUIVTO_PROTO_LITE_3(FrecencyStoreProto_ValueData, ...@@ -229,10 +229,9 @@ DEFINE_EQUIVTO_PROTO_LITE_3(FrecencyStoreProto_ValueData,
last_score, last_score,
last_num_updates); last_num_updates);
DEFINE_EQUIVTO_PROTO_LITE_1(ZeroStateHourBinPredictorProto, DEFINE_EQUIVTO_PROTO_LITE_1(HourBinPredictorProto, binned_frequency_table);
binned_frequency_table);
DEFINE_EQUIVTO_PROTO_LITE_2(ZeroStateHourBinPredictorProto_FrequencyTable, DEFINE_EQUIVTO_PROTO_LITE_2(HourBinPredictorProto_FrequencyTable,
total_counts, total_counts,
frequency); frequency);
...@@ -245,7 +244,7 @@ DEFINE_EQUIVTO_PROTO_LITE_2(HourAppLaunchPredictorProto_FrequencyTable, ...@@ -245,7 +244,7 @@ DEFINE_EQUIVTO_PROTO_LITE_2(HourAppLaunchPredictorProto_FrequencyTable,
DEFINE_EQUIVTO_PROTO_LITE_2(RecurrencePredictorProto, DEFINE_EQUIVTO_PROTO_LITE_2(RecurrencePredictorProto,
fake_predictor, fake_predictor,
zero_state_frecency_predictor); frecency_predictor);
DEFINE_EQUIVTO_PROTO_LITE_2(RecurrenceRankerProto, config_hash, predictor); DEFINE_EQUIVTO_PROTO_LITE_2(RecurrenceRankerProto, config_hash, predictor);
...@@ -257,11 +256,9 @@ DEFINE_EQUIVTO_PROTO_LITE_2(SerializedMrfuAppLaunchPredictorProto_Score, ...@@ -257,11 +256,9 @@ DEFINE_EQUIVTO_PROTO_LITE_2(SerializedMrfuAppLaunchPredictorProto_Score,
num_of_trains_at_last_update, num_of_trains_at_last_update,
last_score); last_score);
DEFINE_EQUIVTO_PROTO_LITE_2(ZeroStateFrecencyPredictorProto, DEFINE_EQUIVTO_PROTO_LITE_2(FrecencyPredictorProto, targets, num_updates);
targets,
num_updates);
DEFINE_EQUIVTO_PROTO_LITE_2(ZeroStateFrecencyPredictorProto_TargetData, DEFINE_EQUIVTO_PROTO_LITE_2(FrecencyPredictorProto_TargetData,
last_score, last_score,
last_num_updates); last_num_updates);
......
...@@ -29,9 +29,9 @@ enum class SerializationError { ...@@ -29,9 +29,9 @@ enum class SerializationError {
kTargetsMissingError = 5, kTargetsMissingError = 5,
kConditionsMissingError = 6, kConditionsMissingError = 6,
kFakePredictorLoadingError = 7, kFakePredictorLoadingError = 7,
kZeroStateFrecencyPredictorLoadingError = 8, kFrecencyPredictorLoadingError = 8,
kZeroStateHourBinnedPredictorLoadingError = 9, kHourBinnedPredictorLoadingError = 9,
kMaxValue = kZeroStateHourBinnedPredictorLoadingError, kMaxValue = kHourBinnedPredictorLoadingError,
}; };
// Represents errors where a RecurrenceRanker is used in a way not supported by // Represents errors where a RecurrenceRanker is used in a way not supported by
......
...@@ -15,29 +15,6 @@ constexpr int kHoursADay = 24; ...@@ -15,29 +15,6 @@ constexpr int kHoursADay = 24;
} // namespace } // namespace
void RecurrencePredictor::Train(unsigned int target) {
LogUsageError(UsageError::kInvalidTrainCall);
NOTREACHED();
}
void RecurrencePredictor::Train(unsigned int target, unsigned int condition) {
LogUsageError(UsageError::kInvalidTrainCall);
NOTREACHED();
}
base::flat_map<unsigned int, float> RecurrencePredictor::Rank() {
LogUsageError(UsageError::kInvalidRankCall);
NOTREACHED();
return {};
}
base::flat_map<unsigned int, float> RecurrencePredictor::Rank(
unsigned int condition) {
LogUsageError(UsageError::kInvalidRankCall);
NOTREACHED();
return {};
}
FakePredictor::FakePredictor(const FakePredictorConfig& config) { FakePredictor::FakePredictor(const FakePredictorConfig& config) {
// The fake predictor should only be used for testing, not in production. // The fake predictor should only be used for testing, not in production.
// Record an error so we know if it is being used. // Record an error so we know if it is being used.
...@@ -51,11 +28,12 @@ const char* FakePredictor::GetPredictorName() const { ...@@ -51,11 +28,12 @@ const char* FakePredictor::GetPredictorName() const {
return kPredictorName; return kPredictorName;
} }
void FakePredictor::Train(unsigned int target) { void FakePredictor::Train(unsigned int target, unsigned int condition) {
counts_[target] += 1.0f; counts_[target] += 1.0f;
} }
base::flat_map<unsigned int, float> FakePredictor::Rank() { base::flat_map<unsigned int, float> FakePredictor::Rank(
unsigned int condition) {
return counts_; return counts_;
} }
...@@ -70,15 +48,23 @@ void FakePredictor::FromProto(const RecurrencePredictorProto& proto) { ...@@ -70,15 +48,23 @@ void FakePredictor::FromProto(const RecurrencePredictorProto& proto) {
LogSerializationError(SerializationError::kFakePredictorLoadingError); LogSerializationError(SerializationError::kFakePredictorLoadingError);
return; return;
} }
auto predictor = proto.fake_predictor();
for (const auto& pair : predictor.counts()) for (const auto& pair : proto.fake_predictor().counts())
counts_[pair.first] = pair.second; counts_[pair.first] = pair.second;
} }
DefaultPredictor::DefaultPredictor(const DefaultPredictorConfig& config) {} DefaultPredictor::DefaultPredictor(const DefaultPredictorConfig& config) {}
DefaultPredictor::~DefaultPredictor() {} DefaultPredictor::~DefaultPredictor() {}
void DefaultPredictor::Train(unsigned int target, unsigned int condition) {}
base::flat_map<unsigned int, float> DefaultPredictor::Rank(
unsigned int condition) {
LogUsageError(UsageError::kInvalidRankCall);
NOTREACHED();
return {};
}
const char DefaultPredictor::kPredictorName[] = "DefaultPredictor"; const char DefaultPredictor::kPredictorName[] = "DefaultPredictor";
const char* DefaultPredictor::GetPredictorName() const { const char* DefaultPredictor::GetPredictorName() const {
return kPredictorName; return kPredictorName;
...@@ -88,25 +74,24 @@ void DefaultPredictor::ToProto(RecurrencePredictorProto* proto) const {} ...@@ -88,25 +74,24 @@ void DefaultPredictor::ToProto(RecurrencePredictorProto* proto) const {}
void DefaultPredictor::FromProto(const RecurrencePredictorProto& proto) {} void DefaultPredictor::FromProto(const RecurrencePredictorProto& proto) {}
ZeroStateFrecencyPredictor::ZeroStateFrecencyPredictor( FrecencyPredictor::FrecencyPredictor(const FrecencyPredictorConfig& config)
const ZeroStateFrecencyPredictorConfig& config)
: decay_coeff_(config.decay_coeff()) {} : decay_coeff_(config.decay_coeff()) {}
ZeroStateFrecencyPredictor::~ZeroStateFrecencyPredictor() = default; FrecencyPredictor::~FrecencyPredictor() = default;
const char ZeroStateFrecencyPredictor::kPredictorName[] = const char FrecencyPredictor::kPredictorName[] = "FrecencyPredictor";
"ZeroStateFrecencyPredictor"; const char* FrecencyPredictor::GetPredictorName() const {
const char* ZeroStateFrecencyPredictor::GetPredictorName() const {
return kPredictorName; return kPredictorName;
} }
void ZeroStateFrecencyPredictor::Train(unsigned int target) { void FrecencyPredictor::Train(unsigned int target, unsigned int condition) {
++num_updates_; ++num_updates_;
TargetData& data = targets_[target]; TargetData& data = targets_[target];
DecayScore(&data); DecayScore(&data);
data.last_score += 1.0f - decay_coeff_; data.last_score += 1.0f - decay_coeff_;
} }
base::flat_map<unsigned int, float> ZeroStateFrecencyPredictor::Rank() { base::flat_map<unsigned int, float> FrecencyPredictor::Rank(
unsigned int condition) {
base::flat_map<unsigned int, float> result; base::flat_map<unsigned int, float> result;
for (auto& pair : targets_) { for (auto& pair : targets_) {
DecayScore(&pair.second); DecayScore(&pair.second);
...@@ -115,9 +100,8 @@ base::flat_map<unsigned int, float> ZeroStateFrecencyPredictor::Rank() { ...@@ -115,9 +100,8 @@ base::flat_map<unsigned int, float> ZeroStateFrecencyPredictor::Rank() {
return result; return result;
} }
void ZeroStateFrecencyPredictor::ToProto( void FrecencyPredictor::ToProto(RecurrencePredictorProto* proto) const {
RecurrencePredictorProto* proto) const { auto* predictor = proto->mutable_frecency_predictor();
auto* predictor = proto->mutable_zero_state_frecency_predictor();
predictor->set_num_updates(num_updates_); predictor->set_num_updates(num_updates_);
...@@ -129,14 +113,12 @@ void ZeroStateFrecencyPredictor::ToProto( ...@@ -129,14 +113,12 @@ void ZeroStateFrecencyPredictor::ToProto(
} }
} }
void ZeroStateFrecencyPredictor::FromProto( void FrecencyPredictor::FromProto(const RecurrencePredictorProto& proto) {
const RecurrencePredictorProto& proto) { if (!proto.has_frecency_predictor()) {
if (!proto.has_zero_state_frecency_predictor()) { LogSerializationError(SerializationError::kFrecencyPredictorLoadingError);
LogSerializationError(
SerializationError::kZeroStateFrecencyPredictorLoadingError);
return; return;
} }
const auto& predictor = proto.zero_state_frecency_predictor(); const auto& predictor = proto.frecency_predictor();
num_updates_ = predictor.num_updates(); num_updates_ = predictor.num_updates();
...@@ -148,7 +130,7 @@ void ZeroStateFrecencyPredictor::FromProto( ...@@ -148,7 +130,7 @@ void ZeroStateFrecencyPredictor::FromProto(
targets_.swap(targets); targets_.swap(targets);
} }
void ZeroStateFrecencyPredictor::DecayScore(TargetData* data) { void FrecencyPredictor::DecayScore(TargetData* data) {
int time_since_update = num_updates_ - data->last_num_updates; int time_since_update = num_updates_ - data->last_num_updates;
if (time_since_update > 0) { if (time_since_update > 0) {
...@@ -157,25 +139,22 @@ void ZeroStateFrecencyPredictor::DecayScore(TargetData* data) { ...@@ -157,25 +139,22 @@ void ZeroStateFrecencyPredictor::DecayScore(TargetData* data) {
} }
} }
ZeroStateHourBinPredictor::ZeroStateHourBinPredictor( HourBinPredictor::HourBinPredictor(const HourBinPredictorConfig& config)
const ZeroStateHourBinPredictorConfig& config)
: config_(config) { : config_(config) {
if (!proto_.has_last_decay_timestamp()) if (!proto_.has_last_decay_timestamp())
SetLastDecayTimestamp( SetLastDecayTimestamp(
base::Time::Now().ToDeltaSinceWindowsEpoch().InDays()); base::Time::Now().ToDeltaSinceWindowsEpoch().InDays());
} }
ZeroStateHourBinPredictor::~ZeroStateHourBinPredictor() = default; HourBinPredictor::~HourBinPredictor() = default;
const char ZeroStateHourBinPredictor::kPredictorName[] = const char HourBinPredictor::kPredictorName[] = "HourBinPredictor";
"ZeroStateHourBinPredictor";
const char* ZeroStateHourBinPredictor::GetPredictorName() const { const char* HourBinPredictor::GetPredictorName() const {
return kPredictorName; return kPredictorName;
} }
int ZeroStateHourBinPredictor::GetBinFromHourDifference( int HourBinPredictor::GetBinFromHourDifference(int hour_difference) const {
int hour_difference) const {
base::Time shifted_time = base::Time shifted_time =
base::Time::Now() + base::TimeDelta::FromHours(hour_difference); base::Time::Now() + base::TimeDelta::FromHours(hour_difference);
base::Time::Exploded exploded_time; base::Time::Exploded exploded_time;
...@@ -193,18 +172,19 @@ int ZeroStateHourBinPredictor::GetBinFromHourDifference( ...@@ -193,18 +172,19 @@ int ZeroStateHourBinPredictor::GetBinFromHourDifference(
} }
} }
int ZeroStateHourBinPredictor::GetBin() const { int HourBinPredictor::GetBin() const {
return GetBinFromHourDifference(0); return GetBinFromHourDifference(0);
} }
void ZeroStateHourBinPredictor::Train(unsigned int target) { void HourBinPredictor::Train(unsigned int target, unsigned int condition) {
int hour = GetBin(); int hour = GetBin();
auto& frequency_table = (*proto_.mutable_binned_frequency_table())[hour]; auto& frequency_table = (*proto_.mutable_binned_frequency_table())[hour];
frequency_table.set_total_counts(frequency_table.total_counts() + 1); frequency_table.set_total_counts(frequency_table.total_counts() + 1);
(*frequency_table.mutable_frequency())[target] += 1; (*frequency_table.mutable_frequency())[target] += 1;
} }
base::flat_map<unsigned int, float> ZeroStateHourBinPredictor::Rank() { base::flat_map<unsigned int, float> HourBinPredictor::Rank(
unsigned int condition) {
base::flat_map<unsigned int, float> ranks; base::flat_map<unsigned int, float> ranks;
const auto& frequency_table_map = proto_.binned_frequency_table(); const auto& frequency_table_map = proto_.binned_frequency_table();
for (const auto& hour_and_weight : config_.bin_weights_map()) { for (const auto& hour_and_weight : config_.bin_weights_map()) {
...@@ -229,27 +209,26 @@ base::flat_map<unsigned int, float> ZeroStateHourBinPredictor::Rank() { ...@@ -229,27 +209,26 @@ base::flat_map<unsigned int, float> ZeroStateHourBinPredictor::Rank() {
return ranks; return ranks;
} }
void ZeroStateHourBinPredictor::ToProto(RecurrencePredictorProto* proto) const { void HourBinPredictor::ToProto(RecurrencePredictorProto* proto) const {
*proto->mutable_zero_state_hour_bin_predictor() = proto_; *proto->mutable_hour_bin_predictor() = proto_;
} }
void ZeroStateHourBinPredictor::FromProto( void HourBinPredictor::FromProto(const RecurrencePredictorProto& proto) {
const RecurrencePredictorProto& proto) { if (!proto.has_hour_bin_predictor())
if (!proto.has_zero_state_hour_bin_predictor())
return; return;
proto_ = proto.zero_state_hour_bin_predictor(); proto_ = proto.hour_bin_predictor();
if (ShouldDecay()) if (ShouldDecay())
DecayAll(); DecayAll();
} }
bool ZeroStateHourBinPredictor::ShouldDecay() { bool HourBinPredictor::ShouldDecay() {
const int today = base::Time::Now().ToDeltaSinceWindowsEpoch().InDays(); const int today = base::Time::Now().ToDeltaSinceWindowsEpoch().InDays();
// Check if we should decay the frequency // Check if we should decay the frequency
return today - proto_.last_decay_timestamp() > 7; return today - proto_.last_decay_timestamp() > 7;
} }
void ZeroStateHourBinPredictor::DecayAll() { void HourBinPredictor::DecayAll() {
SetLastDecayTimestamp(base::Time::Now().ToDeltaSinceWindowsEpoch().InDays()); SetLastDecayTimestamp(base::Time::Now().ToDeltaSinceWindowsEpoch().InDays());
auto& frequency_table_map = *proto_.mutable_binned_frequency_table(); auto& frequency_table_map = *proto_.mutable_binned_frequency_table();
for (auto it_table = frequency_table_map.begin(); for (auto it_table = frequency_table_map.begin();
......
...@@ -19,10 +19,10 @@ namespace app_list { ...@@ -19,10 +19,10 @@ namespace app_list {
using FakePredictorConfig = RecurrenceRankerConfigProto::FakePredictorConfig; using FakePredictorConfig = RecurrenceRankerConfigProto::FakePredictorConfig;
using DefaultPredictorConfig = using DefaultPredictorConfig =
RecurrenceRankerConfigProto::DefaultPredictorConfig; RecurrenceRankerConfigProto::DefaultPredictorConfig;
using ZeroStateFrecencyPredictorConfig = using FrecencyPredictorConfig =
RecurrenceRankerConfigProto::ZeroStateFrecencyPredictorConfig; RecurrenceRankerConfigProto::FrecencyPredictorConfig;
using ZeroStateHourBinPredictorConfig = using HourBinPredictorConfig =
RecurrenceRankerConfigProto::ZeroStateHourBinPredictorConfig; RecurrenceRankerConfigProto::HourBinPredictorConfig;
// |RecurrencePredictor| is the interface for all predictors used by // |RecurrencePredictor| is the interface for all predictors used by
// |RecurrenceRanker| to drive rankings. If a predictor has some form of // |RecurrenceRanker| to drive rankings. If a predictor has some form of
...@@ -34,16 +34,12 @@ class RecurrencePredictor { ...@@ -34,16 +34,12 @@ class RecurrencePredictor {
// Train the predictor on an occurrence of |target| coinciding with // Train the predictor on an occurrence of |target| coinciding with
// |condition|. The predictor will collect its own contextual information, eg. // |condition|. The predictor will collect its own contextual information, eg.
// time of day, as part of training. Zero-state predictors should use the // time of day, as part of training.
// one-argument version. virtual void Train(unsigned int target, unsigned int condition) = 0;
virtual void Train(unsigned int target);
virtual void Train(unsigned int target, unsigned int condition);
// Return a map of all known targets to their scores for the given condition // Return a map of all known targets to their scores for the given condition
// under this predictor. Scores must be within the range [0,1]. Zero-state // under this predictor. Scores must be within the range [0,1].
// predictors should use the zero-argument version. virtual base::flat_map<unsigned int, float> Rank(unsigned int condition) = 0;
virtual base::flat_map<unsigned int, float> Rank();
virtual base::flat_map<unsigned int, float> Rank(unsigned int condition);
virtual void ToProto(RecurrencePredictorProto* proto) const = 0; virtual void ToProto(RecurrencePredictorProto* proto) const = 0;
virtual void FromProto(const RecurrencePredictorProto& proto) = 0; virtual void FromProto(const RecurrencePredictorProto& proto) = 0;
...@@ -62,8 +58,8 @@ class FakePredictor : public RecurrencePredictor { ...@@ -62,8 +58,8 @@ class FakePredictor : public RecurrencePredictor {
~FakePredictor() override; ~FakePredictor() override;
// RecurrencePredictor: // RecurrencePredictor:
void Train(unsigned int target) override; void Train(unsigned int target, unsigned int condition) override;
base::flat_map<unsigned int, float> Rank() override; base::flat_map<unsigned int, float> Rank(unsigned int condition) override;
void ToProto(RecurrencePredictorProto* proto) const override; void ToProto(RecurrencePredictorProto* proto) const override;
void FromProto(const RecurrencePredictorProto& proto) override; void FromProto(const RecurrencePredictorProto& proto) override;
const char* GetPredictorName() const override; const char* GetPredictorName() const override;
...@@ -85,6 +81,8 @@ class DefaultPredictor : public RecurrencePredictor { ...@@ -85,6 +81,8 @@ class DefaultPredictor : public RecurrencePredictor {
~DefaultPredictor() override; ~DefaultPredictor() override;
// RecurrencePredictor: // RecurrencePredictor:
void Train(unsigned int target, unsigned int condition) override;
base::flat_map<unsigned int, float> Rank(unsigned int condition) override;
void ToProto(RecurrencePredictorProto* proto) const override; void ToProto(RecurrencePredictorProto* proto) const override;
void FromProto(const RecurrencePredictorProto& proto) override; void FromProto(const RecurrencePredictorProto& proto) override;
const char* GetPredictorName() const override; const char* GetPredictorName() const override;
...@@ -95,16 +93,15 @@ class DefaultPredictor : public RecurrencePredictor { ...@@ -95,16 +93,15 @@ class DefaultPredictor : public RecurrencePredictor {
DISALLOW_COPY_AND_ASSIGN(DefaultPredictor); DISALLOW_COPY_AND_ASSIGN(DefaultPredictor);
}; };
// ZeroStateFrecencyPredictor ranks targets according to their frecency, and // FrecencyPredictor ranks targets according to their frecency, and
// can only be used for zero-state predictions. This predictor allows for // can only be used for zero-state predictions. This predictor allows for
// frecency-based rankings with different configuration to that of the ranker's // frecency-based rankings with different configuration to that of the ranker's
// FrecencyStore. If frecency-based rankings with the same configuration as the // FrecencyStore. If frecency-based rankings with the same configuration as the
// store are needed, the DefaultPredictor should be used instead. // store are needed, the DefaultPredictor should be used instead.
class ZeroStateFrecencyPredictor : public RecurrencePredictor { class FrecencyPredictor : public RecurrencePredictor {
public: public:
explicit ZeroStateFrecencyPredictor( explicit FrecencyPredictor(const FrecencyPredictorConfig& config);
const ZeroStateFrecencyPredictorConfig& config); ~FrecencyPredictor() override;
~ZeroStateFrecencyPredictor() override;
// Records all information about a target: its id and score, along with the // Records all information about a target: its id and score, along with the
// number of updates that had occurred when the score was last calculated. // number of updates that had occurred when the score was last calculated.
...@@ -115,8 +112,8 @@ class ZeroStateFrecencyPredictor : public RecurrencePredictor { ...@@ -115,8 +112,8 @@ class ZeroStateFrecencyPredictor : public RecurrencePredictor {
}; };
// RecurrencePredictor: // RecurrencePredictor:
void Train(unsigned int target) override; void Train(unsigned int target, unsigned int condition) override;
base::flat_map<unsigned int, float> Rank() override; base::flat_map<unsigned int, float> Rank(unsigned int condition) override;
void ToProto(RecurrencePredictorProto* proto) const override; void ToProto(RecurrencePredictorProto* proto) const override;
void FromProto(const RecurrencePredictorProto& proto) override; void FromProto(const RecurrencePredictorProto& proto) override;
const char* GetPredictorName() const override; const char* GetPredictorName() const override;
...@@ -139,23 +136,22 @@ class ZeroStateFrecencyPredictor : public RecurrencePredictor { ...@@ -139,23 +136,22 @@ class ZeroStateFrecencyPredictor : public RecurrencePredictor {
// This stores all the data of the frecency predictor. // This stores all the data of the frecency predictor.
// TODO(tby): benchmark which map is best in practice for our use. // TODO(tby): benchmark which map is best in practice for our use.
base::flat_map<unsigned int, ZeroStateFrecencyPredictor::TargetData> targets_; base::flat_map<unsigned int, FrecencyPredictor::TargetData> targets_;
DISALLOW_COPY_AND_ASSIGN(ZeroStateFrecencyPredictor); DISALLOW_COPY_AND_ASSIGN(FrecencyPredictor);
}; };
// |ZeroStateHourBinPredictor| ranks targets according to their frequency during // |HourBinPredictor| ranks targets according to their frequency during
// the current and neighbor hour bins. It can only be used for zero-state // the current and neighbor hour bins. It can only be used for zero-state
// predictions. // predictions.
class ZeroStateHourBinPredictor : public RecurrencePredictor { class HourBinPredictor : public RecurrencePredictor {
public: public:
explicit ZeroStateHourBinPredictor( explicit HourBinPredictor(const HourBinPredictorConfig& config);
const ZeroStateHourBinPredictorConfig& config); ~HourBinPredictor() override;
~ZeroStateHourBinPredictor() override;
// RecurrencePredictor: // RecurrencePredictor:
void Train(unsigned int target) override; void Train(unsigned int target, unsigned int condition) override;
base::flat_map<unsigned int, float> Rank() override; base::flat_map<unsigned int, float> Rank(unsigned int condition) override;
void ToProto(RecurrencePredictorProto* proto) const override; void ToProto(RecurrencePredictorProto* proto) const override;
void FromProto(const RecurrencePredictorProto& proto) override; void FromProto(const RecurrencePredictorProto& proto) override;
const char* GetPredictorName() const override; const char* GetPredictorName() const override;
...@@ -163,13 +159,11 @@ class ZeroStateHourBinPredictor : public RecurrencePredictor { ...@@ -163,13 +159,11 @@ class ZeroStateHourBinPredictor : public RecurrencePredictor {
static const char kPredictorName[]; static const char kPredictorName[];
private: private:
FRIEND_TEST_ALL_PREFIXES(ZeroStateHourBinPredictorTest, GetTheRightBin); FRIEND_TEST_ALL_PREFIXES(HourBinPredictorTest, GetTheRightBin);
FRIEND_TEST_ALL_PREFIXES(ZeroStateHourBinPredictorTest, FRIEND_TEST_ALL_PREFIXES(HourBinPredictorTest, TrainAndRankSingleBin);
TrainAndRankSingleBin); FRIEND_TEST_ALL_PREFIXES(HourBinPredictorTest, TrainAndRankMultipleBin);
FRIEND_TEST_ALL_PREFIXES(ZeroStateHourBinPredictorTest, FRIEND_TEST_ALL_PREFIXES(HourBinPredictorTest, ToProto);
TrainAndRankMultipleBin); FRIEND_TEST_ALL_PREFIXES(HourBinPredictorTest, FromProtoDecays);
FRIEND_TEST_ALL_PREFIXES(ZeroStateHourBinPredictorTest, ToProto);
FRIEND_TEST_ALL_PREFIXES(ZeroStateHourBinPredictorTest, FromProtoDecays);
// Return the bin index that is |hour_difference| away from the current bin // Return the bin index that is |hour_difference| away from the current bin
// index. // index.
int GetBinFromHourDifference(int hour_difference) const; int GetBinFromHourDifference(int hour_difference) const;
...@@ -182,9 +176,9 @@ class ZeroStateHourBinPredictor : public RecurrencePredictor { ...@@ -182,9 +176,9 @@ class ZeroStateHourBinPredictor : public RecurrencePredictor {
void SetLastDecayTimestamp(float value) { void SetLastDecayTimestamp(float value) {
proto_.set_last_decay_timestamp(value); proto_.set_last_decay_timestamp(value);
} }
ZeroStateHourBinPredictorProto proto_; HourBinPredictorProto proto_;
ZeroStateHourBinPredictorConfig config_; HourBinPredictorConfig config_;
DISALLOW_COPY_AND_ASSIGN(ZeroStateHourBinPredictor); DISALLOW_COPY_AND_ASSIGN(HourBinPredictor);
}; };
} // namespace app_list } // namespace app_list
......
...@@ -14,8 +14,7 @@ message FakePredictorProto { ...@@ -14,8 +14,7 @@ message FakePredictorProto {
map<uint32, float> counts = 1; map<uint32, float> counts = 1;
} }
// Zero-state frecency predictor. message FrecencyPredictorProto {
message ZeroStateFrecencyPredictorProto {
// Field 1 (targets) has been deleted. // Field 1 (targets) has been deleted.
reserved 1; reserved 1;
...@@ -34,8 +33,7 @@ message ZeroStateFrecencyPredictorProto { ...@@ -34,8 +33,7 @@ message ZeroStateFrecencyPredictorProto {
required uint32 num_updates = 5; required uint32 num_updates = 5;
} }
// Zero-state hour bin predictor message HourBinPredictorProto {
message ZeroStateHourBinPredictorProto {
// Records all data related to a single hour bin. // Records all data related to a single hour bin.
message FrequencyTable { message FrequencyTable {
// Total number of training counts in a bin. // Total number of training counts in a bin.
...@@ -53,7 +51,7 @@ message ZeroStateHourBinPredictorProto { ...@@ -53,7 +51,7 @@ message ZeroStateHourBinPredictorProto {
message RecurrencePredictorProto { message RecurrencePredictorProto {
oneof predictor { oneof predictor {
FakePredictorProto fake_predictor = 1; FakePredictorProto fake_predictor = 1;
ZeroStateFrecencyPredictorProto zero_state_frecency_predictor = 2; FrecencyPredictorProto frecency_predictor = 2;
ZeroStateHourBinPredictorProto zero_state_hour_bin_predictor = 3; HourBinPredictorProto hour_bin_predictor = 3;
} }
} }
...@@ -81,12 +81,10 @@ std::unique_ptr<RecurrencePredictor> MakePredictor( ...@@ -81,12 +81,10 @@ std::unique_ptr<RecurrencePredictor> MakePredictor(
return std::make_unique<FakePredictor>(config.fake_predictor()); return std::make_unique<FakePredictor>(config.fake_predictor());
if (config.has_default_predictor()) if (config.has_default_predictor())
return std::make_unique<DefaultPredictor>(config.default_predictor()); return std::make_unique<DefaultPredictor>(config.default_predictor());
if (config.has_zero_state_frecency_predictor()) if (config.has_frecency_predictor())
return std::make_unique<ZeroStateFrecencyPredictor>( return std::make_unique<FrecencyPredictor>(config.frecency_predictor());
config.zero_state_frecency_predictor()); if (config.has_hour_bin_predictor())
if (config.has_zero_state_hour_bin_predictor()) return std::make_unique<HourBinPredictor>(config.hour_bin_predictor());
return std::make_unique<ZeroStateHourBinPredictor>(
config.zero_state_hour_bin_predictor());
LogConfigurationError(ConfigurationError::kInvalidPredictor); LogConfigurationError(ConfigurationError::kInvalidPredictor);
NOTREACHED(); NOTREACHED();
...@@ -215,34 +213,12 @@ void RecurrenceRanker::OnLoadProtoFromDiskComplete( ...@@ -215,34 +213,12 @@ void RecurrenceRanker::OnLoadProtoFromDiskComplete(
LogSerializationError(SerializationError::kConditionsMissingError); LogSerializationError(SerializationError::kConditionsMissingError);
} }
void RecurrenceRanker::Record(const std::string& target) {
if (!load_from_disk_completed_)
return;
if (predictor_->GetPredictorName() == DefaultPredictor::kPredictorName) {
targets_->Update(target);
} else {
predictor_->Train(targets_->Update(target));
}
MaybeSave();
}
void RecurrenceRanker::Record(const std::string& target, void RecurrenceRanker::Record(const std::string& target,
const std::string& condition) { const std::string& condition) {
if (!load_from_disk_completed_) if (!load_from_disk_completed_)
return; return;
if (predictor_->GetPredictorName() == DefaultPredictor::kPredictorName) { predictor_->Train(targets_->Update(target), conditions_->Update(condition));
// TODO(921444): The default predictor does not support queries, so we fail
// here. Once we have a suitable query-based default predictor implemented,
// change this.
LogUsageError(UsageError::kInvalidTrainCall);
NOTREACHED();
} else {
predictor_->Train(targets_->Update(target), conditions_->Update(condition));
}
MaybeSave(); MaybeSave();
} }
...@@ -282,43 +258,23 @@ void RecurrenceRanker::RemoveCondition(const std::string& condition) { ...@@ -282,43 +258,23 @@ void RecurrenceRanker::RemoveCondition(const std::string& condition) {
MaybeSave(); MaybeSave();
} }
base::flat_map<std::string, float> RecurrenceRanker::Rank() {
if (!load_from_disk_completed_)
return {};
if (predictor_->GetPredictorName() == DefaultPredictor::kPredictorName)
return GetScoresFromFrecencyStore(targets_->GetAll());
return ZipTargetsWithScores(targets_->GetAll(), predictor_->Rank());
}
base::flat_map<std::string, float> RecurrenceRanker::Rank( base::flat_map<std::string, float> RecurrenceRanker::Rank(
const std::string& condition) { const std::string& condition) {
if (!load_from_disk_completed_) if (!load_from_disk_completed_)
return {}; return {};
// Special case the default predictor, and return the scores from the target
// frecency store.
if (predictor_->GetPredictorName() == DefaultPredictor::kPredictorName)
return GetScoresFromFrecencyStore(targets_->GetAll());
base::Optional<unsigned int> condition_id = conditions_->GetId(condition); base::Optional<unsigned int> condition_id = conditions_->GetId(condition);
if (condition_id == base::nullopt) if (condition_id == base::nullopt)
return {}; return {};
if (predictor_->GetPredictorName() == DefaultPredictor::kPredictorName) {
// TODO(921444): The default predictor does not support queries, so we fail
// here. Once we have a suitable query-based default predictor implemented,
// change this.
LogUsageError(UsageError::kInvalidRankCall);
NOTREACHED();
return {};
}
return ZipTargetsWithScores(targets_->GetAll(), return ZipTargetsWithScores(targets_->GetAll(),
predictor_->Rank(condition_id.value())); predictor_->Rank(condition_id.value()));
} }
std::vector<std::pair<std::string, float>> RecurrenceRanker::RankTopN(int n) {
if (!load_from_disk_completed_)
return {};
return SortAndTruncateRanks(n, Rank());
}
std::vector<std::pair<std::string, float>> RecurrenceRanker::RankTopN( std::vector<std::pair<std::string, float>> RecurrenceRanker::RankTopN(
int n, int n,
const std::string& condition) { const std::string& condition) {
......
...@@ -33,12 +33,11 @@ class RecurrenceRanker { ...@@ -33,12 +33,11 @@ class RecurrenceRanker {
bool is_ephemeral_user); bool is_ephemeral_user);
~RecurrenceRanker(); ~RecurrenceRanker();
// Record the use of a given target, and train the predictor on it. The // Record the use of a given target, and train the predictor on it. This may
// one-argument version should be used for zero-state predictions, and the // save to disk, but is not guaranteed to. The user-supplied |condition| can
// two-argument version for condition-based predictions. This may save to // be ignored if it isn't needed.
// disk, but is not guaranteed to. void Record(const std::string& target,
void Record(const std::string& target); const std::string& condition = std::string());
void Record(const std::string& target, const std::string& condition);
// Rename a target, while keeping learned information on it. This may save to // Rename a target, while keeping learned information on it. This may save to
// disk, but is not guaranteed to. // disk, but is not guaranteed to.
...@@ -56,22 +55,19 @@ class RecurrenceRanker { ...@@ -56,22 +55,19 @@ class RecurrenceRanker {
// Returns a map of target to score. // Returns a map of target to score.
// - Higher scores are better. // - Higher scores are better.
// - Scores are guaranteed to be in the range [0,1]. // - Scores are guaranteed to be in the range [0,1].
// The zero-argument version should be used for zero-state predictions, and // The user-supplied |condition| can be ignored if it isn't needed.
// the one-argument version for condition-based predictions. base::flat_map<std::string, float> Rank(
base::flat_map<std::string, float> Rank(); const std::string& condition = std::string());
base::flat_map<std::string, float> Rank(const std::string& condition);
// Returns a sorted vector of <target, score> pairs. // Returns a sorted vector of <target, score> pairs.
// - Higher scores are better. // - Higher scores are better.
// - Scores are guaranteed to be in the range [0,1]. // - Scores are guaranteed to be in the range [0,1].
// - Pairs are sorted in descending order of score. // - Pairs are sorted in descending order of score.
// - At most n results will be returned. // - At most n results will be returned.
// The zero-argument version should be used for zero-state predictions, and // The user-supplied |condition| can be ignored if it isn't needed.
// the one-argument version for condition-based predictions.
std::vector<std::pair<std::string, float>> RankTopN(int n);
std::vector<std::pair<std::string, float>> RankTopN( std::vector<std::pair<std::string, float>> RankTopN(
int n, int n,
const std::string& condition); const std::string& condition = std::string());
// TODO(921444): Create a system for cleaning up internal predictor state that // TODO(921444): Create a system for cleaning up internal predictor state that
// is stored independently of the target/condition frecency stores. // is stored independently of the target/condition frecency stores.
......
...@@ -28,14 +28,14 @@ message RecurrenceRankerConfigProto { ...@@ -28,14 +28,14 @@ message RecurrenceRankerConfigProto {
message DefaultPredictorConfig {} message DefaultPredictorConfig {}
// Config for a frecency predictor. // Config for a frecency predictor.
message ZeroStateFrecencyPredictorConfig { message FrecencyPredictorConfig {
// The frecency parameter used to control the frequency-recency tradeoff // The frecency parameter used to control the frequency-recency tradeoff
// that determines when targets are removed. Must be in [0.5, 1.0], with 0.5 // that determines when targets are removed. Must be in [0.5, 1.0], with 0.5
// meaning only-recency and 1.0 meaning only-frequency. // meaning only-recency and 1.0 meaning only-frequency.
required float decay_coeff = 1; required float decay_coeff = 1;
} }
message ZeroStateHourBinPredictorConfig { message HourBinPredictorConfig {
// The decay coefficient that controls the decay rate. The decay is // The decay coefficient that controls the decay rate. The decay is
// applied once a week. // applied once a week.
required float weekly_decay_coeff = 1; required float weekly_decay_coeff = 1;
...@@ -45,8 +45,8 @@ message RecurrenceRankerConfigProto { ...@@ -45,8 +45,8 @@ message RecurrenceRankerConfigProto {
// The choice of which kind of predictor to use, and its configuration. // The choice of which kind of predictor to use, and its configuration.
oneof predictor_config { oneof predictor_config {
FakePredictorConfig fake_predictor = 10001; FakePredictorConfig fake_predictor = 10001;
ZeroStateFrecencyPredictorConfig zero_state_frecency_predictor = 10002; FrecencyPredictorConfig frecency_predictor = 10002;
DefaultPredictorConfig default_predictor = 10003; DefaultPredictorConfig default_predictor = 10003;
ZeroStateHourBinPredictorConfig zero_state_hour_bin_predictor = 10004; HourBinPredictorConfig hour_bin_predictor = 10004;
} }
} }
...@@ -104,13 +104,19 @@ class RecurrenceRankerTest : public testing::Test { ...@@ -104,13 +104,19 @@ class RecurrenceRankerTest : public testing::Test {
value_data.set_last_num_updates(4); value_data.set_last_num_updates(4);
(*target_values)["C"] = value_data; (*target_values)["C"] = value_data;
// Make empty conditions frecency store. // Make conditions frecency store.
auto* conditions = proto.mutable_conditions(); auto* conditions = proto.mutable_conditions();
conditions->set_value_limit(0u); conditions->set_value_limit(10u);
conditions->set_decay_coeff(0.0f); conditions->set_decay_coeff(0.5f);
conditions->set_num_updates(0); conditions->set_num_updates(0);
conditions->set_next_id(0); conditions->set_next_id(0);
conditions->mutable_values();
auto* condition_values = conditions->mutable_values();
value_data.set_id(0u);
value_data.set_last_score(0.5f);
value_data.set_last_num_updates(1);
(*condition_values)[""] = value_data;
// Make FakePredictor counts. // Make FakePredictor counts.
auto* counts = auto* counts =
...@@ -369,7 +375,7 @@ TEST_F(RecurrenceRankerTest, IntegrationWithDefaultPredictor) { ...@@ -369,7 +375,7 @@ TEST_F(RecurrenceRankerTest, IntegrationWithDefaultPredictor) {
TEST_F(RecurrenceRankerTest, IntegrationWithZeroStateFrecencyPredictor) { TEST_F(RecurrenceRankerTest, IntegrationWithZeroStateFrecencyPredictor) {
RecurrenceRankerConfigProto config; RecurrenceRankerConfigProto config;
PartiallyPopulateConfig(&config); PartiallyPopulateConfig(&config);
auto* predictor = config.mutable_zero_state_frecency_predictor(); auto* predictor = config.mutable_frecency_predictor();
predictor->set_decay_coeff(0.5f); predictor->set_decay_coeff(0.5f);
RecurrenceRanker ranker(ranker_filepath_, config, false); RecurrenceRanker ranker(ranker_filepath_, config, false);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment