Commit 9279abb7 authored by tby's avatar tby Committed by Commit Bot

[Dolphin] Remove the distinction between conditional and zero-state.

Previously, we had two copies of each Record, Train, and Rank method for
the ranker and each predictor: one that took a query and one that didn't.
The idea was that predictors (and therefore the ranker) was set up to
either work in a zero-state or condition-based environment, and clients
should only call the appropriate methods.

This has limitations, in particular, zero-state predictors can't be used
in an ensemble model along with condition-based ones. This CL removes
the distinction between these methods.

Specific changes are as follows:

1. Predictors now only have Train(target, condition) and
   Rank(condition) methods. The non-condition versions have been
   deleted.

2. The Ranker's zero-state Train and Rank methods are now just
   shortcuts for having an empty-string condition.

3. All predictors labelled ZeroStateX have been renamed to just X. It
   seems most of our predictors won't use the condition and, in order
   to incorporate a condition, will be wrapped in some kind of
   ConditionalPredictor. So it makes sense for the default naming to not
   specify zero-state-ness, as the current names are very verbose.

Bug: 921444
Change-Id: I5273715501319cda951b717d755e1abb559d4dd6
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/1670668
Commit-Queue: Tony Yeoman <tby@chromium.org>
Reviewed-by: default avatarCharles . <charleszhao@chromium.org>
Cr-Commit-Position: refs/heads/master@{#671894}
parent d851243b
......@@ -229,10 +229,9 @@ DEFINE_EQUIVTO_PROTO_LITE_3(FrecencyStoreProto_ValueData,
last_score,
last_num_updates);
DEFINE_EQUIVTO_PROTO_LITE_1(ZeroStateHourBinPredictorProto,
binned_frequency_table);
DEFINE_EQUIVTO_PROTO_LITE_1(HourBinPredictorProto, binned_frequency_table);
DEFINE_EQUIVTO_PROTO_LITE_2(ZeroStateHourBinPredictorProto_FrequencyTable,
DEFINE_EQUIVTO_PROTO_LITE_2(HourBinPredictorProto_FrequencyTable,
total_counts,
frequency);
......@@ -245,7 +244,7 @@ DEFINE_EQUIVTO_PROTO_LITE_2(HourAppLaunchPredictorProto_FrequencyTable,
DEFINE_EQUIVTO_PROTO_LITE_2(RecurrencePredictorProto,
fake_predictor,
zero_state_frecency_predictor);
frecency_predictor);
DEFINE_EQUIVTO_PROTO_LITE_2(RecurrenceRankerProto, config_hash, predictor);
......@@ -257,11 +256,9 @@ DEFINE_EQUIVTO_PROTO_LITE_2(SerializedMrfuAppLaunchPredictorProto_Score,
num_of_trains_at_last_update,
last_score);
DEFINE_EQUIVTO_PROTO_LITE_2(ZeroStateFrecencyPredictorProto,
targets,
num_updates);
DEFINE_EQUIVTO_PROTO_LITE_2(FrecencyPredictorProto, targets, num_updates);
DEFINE_EQUIVTO_PROTO_LITE_2(ZeroStateFrecencyPredictorProto_TargetData,
DEFINE_EQUIVTO_PROTO_LITE_2(FrecencyPredictorProto_TargetData,
last_score,
last_num_updates);
......
......@@ -29,9 +29,9 @@ enum class SerializationError {
kTargetsMissingError = 5,
kConditionsMissingError = 6,
kFakePredictorLoadingError = 7,
kZeroStateFrecencyPredictorLoadingError = 8,
kZeroStateHourBinnedPredictorLoadingError = 9,
kMaxValue = kZeroStateHourBinnedPredictorLoadingError,
kFrecencyPredictorLoadingError = 8,
kHourBinnedPredictorLoadingError = 9,
kMaxValue = kHourBinnedPredictorLoadingError,
};
// Represents errors where a RecurrenceRanker is used in a way not supported by
......
......@@ -15,29 +15,6 @@ constexpr int kHoursADay = 24;
} // namespace
void RecurrencePredictor::Train(unsigned int target) {
LogUsageError(UsageError::kInvalidTrainCall);
NOTREACHED();
}
void RecurrencePredictor::Train(unsigned int target, unsigned int condition) {
LogUsageError(UsageError::kInvalidTrainCall);
NOTREACHED();
}
base::flat_map<unsigned int, float> RecurrencePredictor::Rank() {
LogUsageError(UsageError::kInvalidRankCall);
NOTREACHED();
return {};
}
base::flat_map<unsigned int, float> RecurrencePredictor::Rank(
unsigned int condition) {
LogUsageError(UsageError::kInvalidRankCall);
NOTREACHED();
return {};
}
FakePredictor::FakePredictor(const FakePredictorConfig& config) {
// The fake predictor should only be used for testing, not in production.
// Record an error so we know if it is being used.
......@@ -51,11 +28,12 @@ const char* FakePredictor::GetPredictorName() const {
return kPredictorName;
}
void FakePredictor::Train(unsigned int target) {
void FakePredictor::Train(unsigned int target, unsigned int condition) {
counts_[target] += 1.0f;
}
base::flat_map<unsigned int, float> FakePredictor::Rank() {
base::flat_map<unsigned int, float> FakePredictor::Rank(
unsigned int condition) {
return counts_;
}
......@@ -70,15 +48,23 @@ void FakePredictor::FromProto(const RecurrencePredictorProto& proto) {
LogSerializationError(SerializationError::kFakePredictorLoadingError);
return;
}
auto predictor = proto.fake_predictor();
for (const auto& pair : predictor.counts())
for (const auto& pair : proto.fake_predictor().counts())
counts_[pair.first] = pair.second;
}
DefaultPredictor::DefaultPredictor(const DefaultPredictorConfig& config) {}
DefaultPredictor::~DefaultPredictor() {}
void DefaultPredictor::Train(unsigned int target, unsigned int condition) {}
base::flat_map<unsigned int, float> DefaultPredictor::Rank(
unsigned int condition) {
LogUsageError(UsageError::kInvalidRankCall);
NOTREACHED();
return {};
}
const char DefaultPredictor::kPredictorName[] = "DefaultPredictor";
const char* DefaultPredictor::GetPredictorName() const {
return kPredictorName;
......@@ -88,25 +74,24 @@ void DefaultPredictor::ToProto(RecurrencePredictorProto* proto) const {}
void DefaultPredictor::FromProto(const RecurrencePredictorProto& proto) {}
ZeroStateFrecencyPredictor::ZeroStateFrecencyPredictor(
const ZeroStateFrecencyPredictorConfig& config)
FrecencyPredictor::FrecencyPredictor(const FrecencyPredictorConfig& config)
: decay_coeff_(config.decay_coeff()) {}
ZeroStateFrecencyPredictor::~ZeroStateFrecencyPredictor() = default;
FrecencyPredictor::~FrecencyPredictor() = default;
const char ZeroStateFrecencyPredictor::kPredictorName[] =
"ZeroStateFrecencyPredictor";
const char* ZeroStateFrecencyPredictor::GetPredictorName() const {
const char FrecencyPredictor::kPredictorName[] = "FrecencyPredictor";
const char* FrecencyPredictor::GetPredictorName() const {
return kPredictorName;
}
void ZeroStateFrecencyPredictor::Train(unsigned int target) {
void FrecencyPredictor::Train(unsigned int target, unsigned int condition) {
++num_updates_;
TargetData& data = targets_[target];
DecayScore(&data);
data.last_score += 1.0f - decay_coeff_;
}
base::flat_map<unsigned int, float> ZeroStateFrecencyPredictor::Rank() {
base::flat_map<unsigned int, float> FrecencyPredictor::Rank(
unsigned int condition) {
base::flat_map<unsigned int, float> result;
for (auto& pair : targets_) {
DecayScore(&pair.second);
......@@ -115,9 +100,8 @@ base::flat_map<unsigned int, float> ZeroStateFrecencyPredictor::Rank() {
return result;
}
void ZeroStateFrecencyPredictor::ToProto(
RecurrencePredictorProto* proto) const {
auto* predictor = proto->mutable_zero_state_frecency_predictor();
void FrecencyPredictor::ToProto(RecurrencePredictorProto* proto) const {
auto* predictor = proto->mutable_frecency_predictor();
predictor->set_num_updates(num_updates_);
......@@ -129,14 +113,12 @@ void ZeroStateFrecencyPredictor::ToProto(
}
}
void ZeroStateFrecencyPredictor::FromProto(
const RecurrencePredictorProto& proto) {
if (!proto.has_zero_state_frecency_predictor()) {
LogSerializationError(
SerializationError::kZeroStateFrecencyPredictorLoadingError);
void FrecencyPredictor::FromProto(const RecurrencePredictorProto& proto) {
if (!proto.has_frecency_predictor()) {
LogSerializationError(SerializationError::kFrecencyPredictorLoadingError);
return;
}
const auto& predictor = proto.zero_state_frecency_predictor();
const auto& predictor = proto.frecency_predictor();
num_updates_ = predictor.num_updates();
......@@ -148,7 +130,7 @@ void ZeroStateFrecencyPredictor::FromProto(
targets_.swap(targets);
}
void ZeroStateFrecencyPredictor::DecayScore(TargetData* data) {
void FrecencyPredictor::DecayScore(TargetData* data) {
int time_since_update = num_updates_ - data->last_num_updates;
if (time_since_update > 0) {
......@@ -157,25 +139,22 @@ void ZeroStateFrecencyPredictor::DecayScore(TargetData* data) {
}
}
ZeroStateHourBinPredictor::ZeroStateHourBinPredictor(
const ZeroStateHourBinPredictorConfig& config)
HourBinPredictor::HourBinPredictor(const HourBinPredictorConfig& config)
: config_(config) {
if (!proto_.has_last_decay_timestamp())
SetLastDecayTimestamp(
base::Time::Now().ToDeltaSinceWindowsEpoch().InDays());
}
ZeroStateHourBinPredictor::~ZeroStateHourBinPredictor() = default;
HourBinPredictor::~HourBinPredictor() = default;
const char ZeroStateHourBinPredictor::kPredictorName[] =
"ZeroStateHourBinPredictor";
const char HourBinPredictor::kPredictorName[] = "HourBinPredictor";
const char* ZeroStateHourBinPredictor::GetPredictorName() const {
const char* HourBinPredictor::GetPredictorName() const {
return kPredictorName;
}
int ZeroStateHourBinPredictor::GetBinFromHourDifference(
int hour_difference) const {
int HourBinPredictor::GetBinFromHourDifference(int hour_difference) const {
base::Time shifted_time =
base::Time::Now() + base::TimeDelta::FromHours(hour_difference);
base::Time::Exploded exploded_time;
......@@ -193,18 +172,19 @@ int ZeroStateHourBinPredictor::GetBinFromHourDifference(
}
}
int ZeroStateHourBinPredictor::GetBin() const {
int HourBinPredictor::GetBin() const {
return GetBinFromHourDifference(0);
}
void ZeroStateHourBinPredictor::Train(unsigned int target) {
void HourBinPredictor::Train(unsigned int target, unsigned int condition) {
int hour = GetBin();
auto& frequency_table = (*proto_.mutable_binned_frequency_table())[hour];
frequency_table.set_total_counts(frequency_table.total_counts() + 1);
(*frequency_table.mutable_frequency())[target] += 1;
}
base::flat_map<unsigned int, float> ZeroStateHourBinPredictor::Rank() {
base::flat_map<unsigned int, float> HourBinPredictor::Rank(
unsigned int condition) {
base::flat_map<unsigned int, float> ranks;
const auto& frequency_table_map = proto_.binned_frequency_table();
for (const auto& hour_and_weight : config_.bin_weights_map()) {
......@@ -229,27 +209,26 @@ base::flat_map<unsigned int, float> ZeroStateHourBinPredictor::Rank() {
return ranks;
}
void ZeroStateHourBinPredictor::ToProto(RecurrencePredictorProto* proto) const {
*proto->mutable_zero_state_hour_bin_predictor() = proto_;
void HourBinPredictor::ToProto(RecurrencePredictorProto* proto) const {
*proto->mutable_hour_bin_predictor() = proto_;
}
void ZeroStateHourBinPredictor::FromProto(
const RecurrencePredictorProto& proto) {
if (!proto.has_zero_state_hour_bin_predictor())
void HourBinPredictor::FromProto(const RecurrencePredictorProto& proto) {
if (!proto.has_hour_bin_predictor())
return;
proto_ = proto.zero_state_hour_bin_predictor();
proto_ = proto.hour_bin_predictor();
if (ShouldDecay())
DecayAll();
}
bool ZeroStateHourBinPredictor::ShouldDecay() {
bool HourBinPredictor::ShouldDecay() {
const int today = base::Time::Now().ToDeltaSinceWindowsEpoch().InDays();
// Check if we should decay the frequency
return today - proto_.last_decay_timestamp() > 7;
}
void ZeroStateHourBinPredictor::DecayAll() {
void HourBinPredictor::DecayAll() {
SetLastDecayTimestamp(base::Time::Now().ToDeltaSinceWindowsEpoch().InDays());
auto& frequency_table_map = *proto_.mutable_binned_frequency_table();
for (auto it_table = frequency_table_map.begin();
......
......@@ -19,10 +19,10 @@ namespace app_list {
using FakePredictorConfig = RecurrenceRankerConfigProto::FakePredictorConfig;
using DefaultPredictorConfig =
RecurrenceRankerConfigProto::DefaultPredictorConfig;
using ZeroStateFrecencyPredictorConfig =
RecurrenceRankerConfigProto::ZeroStateFrecencyPredictorConfig;
using ZeroStateHourBinPredictorConfig =
RecurrenceRankerConfigProto::ZeroStateHourBinPredictorConfig;
using FrecencyPredictorConfig =
RecurrenceRankerConfigProto::FrecencyPredictorConfig;
using HourBinPredictorConfig =
RecurrenceRankerConfigProto::HourBinPredictorConfig;
// |RecurrencePredictor| is the interface for all predictors used by
// |RecurrenceRanker| to drive rankings. If a predictor has some form of
......@@ -34,16 +34,12 @@ class RecurrencePredictor {
// Train the predictor on an occurrence of |target| coinciding with
// |condition|. The predictor will collect its own contextual information, eg.
// time of day, as part of training. Zero-state predictors should use the
// one-argument version.
virtual void Train(unsigned int target);
virtual void Train(unsigned int target, unsigned int condition);
// time of day, as part of training.
virtual void Train(unsigned int target, unsigned int condition) = 0;
// Return a map of all known targets to their scores for the given condition
// under this predictor. Scores must be within the range [0,1]. Zero-state
// predictors should use the zero-argument version.
virtual base::flat_map<unsigned int, float> Rank();
virtual base::flat_map<unsigned int, float> Rank(unsigned int condition);
// under this predictor. Scores must be within the range [0,1].
virtual base::flat_map<unsigned int, float> Rank(unsigned int condition) = 0;
virtual void ToProto(RecurrencePredictorProto* proto) const = 0;
virtual void FromProto(const RecurrencePredictorProto& proto) = 0;
......@@ -62,8 +58,8 @@ class FakePredictor : public RecurrencePredictor {
~FakePredictor() override;
// RecurrencePredictor:
void Train(unsigned int target) override;
base::flat_map<unsigned int, float> Rank() override;
void Train(unsigned int target, unsigned int condition) override;
base::flat_map<unsigned int, float> Rank(unsigned int condition) override;
void ToProto(RecurrencePredictorProto* proto) const override;
void FromProto(const RecurrencePredictorProto& proto) override;
const char* GetPredictorName() const override;
......@@ -85,6 +81,8 @@ class DefaultPredictor : public RecurrencePredictor {
~DefaultPredictor() override;
// RecurrencePredictor:
void Train(unsigned int target, unsigned int condition) override;
base::flat_map<unsigned int, float> Rank(unsigned int condition) override;
void ToProto(RecurrencePredictorProto* proto) const override;
void FromProto(const RecurrencePredictorProto& proto) override;
const char* GetPredictorName() const override;
......@@ -95,16 +93,15 @@ class DefaultPredictor : public RecurrencePredictor {
DISALLOW_COPY_AND_ASSIGN(DefaultPredictor);
};
// ZeroStateFrecencyPredictor ranks targets according to their frecency, and
// FrecencyPredictor ranks targets according to their frecency, and
// can only be used for zero-state predictions. This predictor allows for
// frecency-based rankings with different configuration to that of the ranker's
// FrecencyStore. If frecency-based rankings with the same configuration as the
// store are needed, the DefaultPredictor should be used instead.
class ZeroStateFrecencyPredictor : public RecurrencePredictor {
class FrecencyPredictor : public RecurrencePredictor {
public:
explicit ZeroStateFrecencyPredictor(
const ZeroStateFrecencyPredictorConfig& config);
~ZeroStateFrecencyPredictor() override;
explicit FrecencyPredictor(const FrecencyPredictorConfig& config);
~FrecencyPredictor() override;
// Records all information about a target: its id and score, along with the
// number of updates that had occurred when the score was last calculated.
......@@ -115,8 +112,8 @@ class ZeroStateFrecencyPredictor : public RecurrencePredictor {
};
// RecurrencePredictor:
void Train(unsigned int target) override;
base::flat_map<unsigned int, float> Rank() override;
void Train(unsigned int target, unsigned int condition) override;
base::flat_map<unsigned int, float> Rank(unsigned int condition) override;
void ToProto(RecurrencePredictorProto* proto) const override;
void FromProto(const RecurrencePredictorProto& proto) override;
const char* GetPredictorName() const override;
......@@ -139,23 +136,22 @@ class ZeroStateFrecencyPredictor : public RecurrencePredictor {
// This stores all the data of the frecency predictor.
// TODO(tby): benchmark which map is best in practice for our use.
base::flat_map<unsigned int, ZeroStateFrecencyPredictor::TargetData> targets_;
base::flat_map<unsigned int, FrecencyPredictor::TargetData> targets_;
DISALLOW_COPY_AND_ASSIGN(ZeroStateFrecencyPredictor);
DISALLOW_COPY_AND_ASSIGN(FrecencyPredictor);
};
// |ZeroStateHourBinPredictor| ranks targets according to their frequency during
// |HourBinPredictor| ranks targets according to their frequency during
// the current and neighbor hour bins. It can only be used for zero-state
// predictions.
class ZeroStateHourBinPredictor : public RecurrencePredictor {
class HourBinPredictor : public RecurrencePredictor {
public:
explicit ZeroStateHourBinPredictor(
const ZeroStateHourBinPredictorConfig& config);
~ZeroStateHourBinPredictor() override;
explicit HourBinPredictor(const HourBinPredictorConfig& config);
~HourBinPredictor() override;
// RecurrencePredictor:
void Train(unsigned int target) override;
base::flat_map<unsigned int, float> Rank() override;
void Train(unsigned int target, unsigned int condition) override;
base::flat_map<unsigned int, float> Rank(unsigned int condition) override;
void ToProto(RecurrencePredictorProto* proto) const override;
void FromProto(const RecurrencePredictorProto& proto) override;
const char* GetPredictorName() const override;
......@@ -163,13 +159,11 @@ class ZeroStateHourBinPredictor : public RecurrencePredictor {
static const char kPredictorName[];
private:
FRIEND_TEST_ALL_PREFIXES(ZeroStateHourBinPredictorTest, GetTheRightBin);
FRIEND_TEST_ALL_PREFIXES(ZeroStateHourBinPredictorTest,
TrainAndRankSingleBin);
FRIEND_TEST_ALL_PREFIXES(ZeroStateHourBinPredictorTest,
TrainAndRankMultipleBin);
FRIEND_TEST_ALL_PREFIXES(ZeroStateHourBinPredictorTest, ToProto);
FRIEND_TEST_ALL_PREFIXES(ZeroStateHourBinPredictorTest, FromProtoDecays);
FRIEND_TEST_ALL_PREFIXES(HourBinPredictorTest, GetTheRightBin);
FRIEND_TEST_ALL_PREFIXES(HourBinPredictorTest, TrainAndRankSingleBin);
FRIEND_TEST_ALL_PREFIXES(HourBinPredictorTest, TrainAndRankMultipleBin);
FRIEND_TEST_ALL_PREFIXES(HourBinPredictorTest, ToProto);
FRIEND_TEST_ALL_PREFIXES(HourBinPredictorTest, FromProtoDecays);
// Return the bin index that is |hour_difference| away from the current bin
// index.
int GetBinFromHourDifference(int hour_difference) const;
......@@ -182,9 +176,9 @@ class ZeroStateHourBinPredictor : public RecurrencePredictor {
void SetLastDecayTimestamp(float value) {
proto_.set_last_decay_timestamp(value);
}
ZeroStateHourBinPredictorProto proto_;
ZeroStateHourBinPredictorConfig config_;
DISALLOW_COPY_AND_ASSIGN(ZeroStateHourBinPredictor);
HourBinPredictorProto proto_;
HourBinPredictorConfig config_;
DISALLOW_COPY_AND_ASSIGN(HourBinPredictor);
};
} // namespace app_list
......
......@@ -14,8 +14,7 @@ message FakePredictorProto {
map<uint32, float> counts = 1;
}
// Zero-state frecency predictor.
message ZeroStateFrecencyPredictorProto {
message FrecencyPredictorProto {
// Field 1 (targets) has been deleted.
reserved 1;
......@@ -34,8 +33,7 @@ message ZeroStateFrecencyPredictorProto {
required uint32 num_updates = 5;
}
// Zero-state hour bin predictor
message ZeroStateHourBinPredictorProto {
message HourBinPredictorProto {
// Records all data related to a single hour bin.
message FrequencyTable {
// Total number of training counts in a bin.
......@@ -53,7 +51,7 @@ message ZeroStateHourBinPredictorProto {
message RecurrencePredictorProto {
oneof predictor {
FakePredictorProto fake_predictor = 1;
ZeroStateFrecencyPredictorProto zero_state_frecency_predictor = 2;
ZeroStateHourBinPredictorProto zero_state_hour_bin_predictor = 3;
FrecencyPredictorProto frecency_predictor = 2;
HourBinPredictorProto hour_bin_predictor = 3;
}
}
......@@ -81,12 +81,10 @@ std::unique_ptr<RecurrencePredictor> MakePredictor(
return std::make_unique<FakePredictor>(config.fake_predictor());
if (config.has_default_predictor())
return std::make_unique<DefaultPredictor>(config.default_predictor());
if (config.has_zero_state_frecency_predictor())
return std::make_unique<ZeroStateFrecencyPredictor>(
config.zero_state_frecency_predictor());
if (config.has_zero_state_hour_bin_predictor())
return std::make_unique<ZeroStateHourBinPredictor>(
config.zero_state_hour_bin_predictor());
if (config.has_frecency_predictor())
return std::make_unique<FrecencyPredictor>(config.frecency_predictor());
if (config.has_hour_bin_predictor())
return std::make_unique<HourBinPredictor>(config.hour_bin_predictor());
LogConfigurationError(ConfigurationError::kInvalidPredictor);
NOTREACHED();
......@@ -215,34 +213,12 @@ void RecurrenceRanker::OnLoadProtoFromDiskComplete(
LogSerializationError(SerializationError::kConditionsMissingError);
}
void RecurrenceRanker::Record(const std::string& target) {
if (!load_from_disk_completed_)
return;
if (predictor_->GetPredictorName() == DefaultPredictor::kPredictorName) {
targets_->Update(target);
} else {
predictor_->Train(targets_->Update(target));
}
MaybeSave();
}
void RecurrenceRanker::Record(const std::string& target,
const std::string& condition) {
if (!load_from_disk_completed_)
return;
if (predictor_->GetPredictorName() == DefaultPredictor::kPredictorName) {
// TODO(921444): The default predictor does not support queries, so we fail
// here. Once we have a suitable query-based default predictor implemented,
// change this.
LogUsageError(UsageError::kInvalidTrainCall);
NOTREACHED();
} else {
predictor_->Train(targets_->Update(target), conditions_->Update(condition));
}
predictor_->Train(targets_->Update(target), conditions_->Update(condition));
MaybeSave();
}
......@@ -282,43 +258,23 @@ void RecurrenceRanker::RemoveCondition(const std::string& condition) {
MaybeSave();
}
base::flat_map<std::string, float> RecurrenceRanker::Rank() {
if (!load_from_disk_completed_)
return {};
if (predictor_->GetPredictorName() == DefaultPredictor::kPredictorName)
return GetScoresFromFrecencyStore(targets_->GetAll());
return ZipTargetsWithScores(targets_->GetAll(), predictor_->Rank());
}
base::flat_map<std::string, float> RecurrenceRanker::Rank(
const std::string& condition) {
if (!load_from_disk_completed_)
return {};
// Special case the default predictor, and return the scores from the target
// frecency store.
if (predictor_->GetPredictorName() == DefaultPredictor::kPredictorName)
return GetScoresFromFrecencyStore(targets_->GetAll());
base::Optional<unsigned int> condition_id = conditions_->GetId(condition);
if (condition_id == base::nullopt)
return {};
if (predictor_->GetPredictorName() == DefaultPredictor::kPredictorName) {
// TODO(921444): The default predictor does not support queries, so we fail
// here. Once we have a suitable query-based default predictor implemented,
// change this.
LogUsageError(UsageError::kInvalidRankCall);
NOTREACHED();
return {};
}
return ZipTargetsWithScores(targets_->GetAll(),
predictor_->Rank(condition_id.value()));
}
std::vector<std::pair<std::string, float>> RecurrenceRanker::RankTopN(int n) {
if (!load_from_disk_completed_)
return {};
return SortAndTruncateRanks(n, Rank());
}
std::vector<std::pair<std::string, float>> RecurrenceRanker::RankTopN(
int n,
const std::string& condition) {
......
......@@ -33,12 +33,11 @@ class RecurrenceRanker {
bool is_ephemeral_user);
~RecurrenceRanker();
// Record the use of a given target, and train the predictor on it. The
// one-argument version should be used for zero-state predictions, and the
// two-argument version for condition-based predictions. This may save to
// disk, but is not guaranteed to.
void Record(const std::string& target);
void Record(const std::string& target, const std::string& condition);
// Record the use of a given target, and train the predictor on it. This may
// save to disk, but is not guaranteed to. The user-supplied |condition| can
// be ignored if it isn't needed.
void Record(const std::string& target,
const std::string& condition = std::string());
// Rename a target, while keeping learned information on it. This may save to
// disk, but is not guaranteed to.
......@@ -56,22 +55,19 @@ class RecurrenceRanker {
// Returns a map of target to score.
// - Higher scores are better.
// - Score are guaranteed to be in the range [0,1].
// The zero-argument version should be used for zero-state predictions, and
// the one-argument version for condition-based predictions.
base::flat_map<std::string, float> Rank();
base::flat_map<std::string, float> Rank(const std::string& condition);
// The user-supplied |condition| can be ignored if it isn't needed.
base::flat_map<std::string, float> Rank(
const std::string& condition = std::string());
// Returns a sorted vector of <target, score> pairs.
// - Higher scores are better.
// - Score are guaranteed to be in the range [0,1].
// - Pairs are sorted in descending order of score.
// - At most n results will be returned.
// The zero-argument version should be used for zero-state predictions, and
// the one-argument version for condition-based predictions.
std::vector<std::pair<std::string, float>> RankTopN(int n);
// The user-supplied |condition| can be ignored if it isn't needed.
std::vector<std::pair<std::string, float>> RankTopN(
int n,
const std::string& condition);
const std::string& condition = std::string());
// TODO(921444): Create a system for cleaning up internal predictor state that
// is stored indepent of the target/condition frecency stores.
......
......@@ -28,14 +28,14 @@ message RecurrenceRankerConfigProto {
message DefaultPredictorConfig {}
// Config for a frecency predictor.
message ZeroStateFrecencyPredictorConfig {
message FrecencyPredictorConfig {
// The frecency parameter used to control the frequency-recency tradeoff
// that determines when targets are removed. Must be in [0.5, 1.0], with 0.5
// meaning only-recency and 1.0 meaning only-frequency.
required float decay_coeff = 1;
}
message ZeroStateHourBinPredictorConfig {
message HourBinPredictorConfig {
// The decay coeffficient number that control the decay rate. The decay is
// once a week.
required float weekly_decay_coeff = 1;
......@@ -45,8 +45,8 @@ message RecurrenceRankerConfigProto {
// The choice of which kind of predictor to use, and its configuration.
oneof predictor_config {
FakePredictorConfig fake_predictor = 10001;
ZeroStateFrecencyPredictorConfig zero_state_frecency_predictor = 10002;
FrecencyPredictorConfig frecency_predictor = 10002;
DefaultPredictorConfig default_predictor = 10003;
ZeroStateHourBinPredictorConfig zero_state_hour_bin_predictor = 10004;
HourBinPredictorConfig hour_bin_predictor = 10004;
}
}
......@@ -104,13 +104,19 @@ class RecurrenceRankerTest : public testing::Test {
value_data.set_last_num_updates(4);
(*target_values)["C"] = value_data;
// Make empty conditions frecency store.
// Make conditions frecency store.
auto* conditions = proto.mutable_conditions();
conditions->set_value_limit(0u);
conditions->set_decay_coeff(0.0f);
conditions->set_value_limit(10u);
conditions->set_decay_coeff(0.5f);
conditions->set_num_updates(0);
conditions->set_next_id(0);
conditions->mutable_values();
auto* condition_values = conditions->mutable_values();
value_data.set_id(0u);
value_data.set_last_score(0.5f);
value_data.set_last_num_updates(1);
(*condition_values)[""] = value_data;
// Make FakePredictor counts.
auto* counts =
......@@ -369,7 +375,7 @@ TEST_F(RecurrenceRankerTest, IntegrationWithDefaultPredictor) {
TEST_F(RecurrenceRankerTest, IntegrationWithZeroStateFrecencyPredictor) {
RecurrenceRankerConfigProto config;
PartiallyPopulateConfig(&config);
auto* predictor = config.mutable_zero_state_frecency_predictor();
auto* predictor = config.mutable_frecency_predictor();
predictor->set_decay_coeff(0.5f);
RecurrenceRanker ranker(ranker_filepath_, config, false);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment