Commit 84f0f36d authored by tby's avatar tby Committed by Commit Bot

[Cros SR] Implement query-based item ranking model.

This is the V2 model for query-based mixed-types ranking. For
experimentation purposes V1 has been kept also, and the instantiation of
one, the other, or neither, is controlled by finch parameters.

Bug: 931149
Change-Id: I4ecdbfc98dfed330a6e7406b33e89171838170c1
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/1661317Reviewed-by: default avatarJia Meng <jiameng@chromium.org>
Commit-Queue: Tony Yeoman <tby@chromium.org>
Cr-Commit-Position: refs/heads/master@{#669916}
parent 6acf1238
...@@ -36,7 +36,7 @@ constexpr TimeDelta kMinSecondsBetweenFetches = TimeDelta::FromSeconds(1); ...@@ -36,7 +36,7 @@ constexpr TimeDelta kMinSecondsBetweenFetches = TimeDelta::FromSeconds(1);
constexpr char kLogFileOpenType[] = "RecurrenceRanker.LogFileOpenType"; constexpr char kLogFileOpenType[] = "RecurrenceRanker.LogFileOpenType";
// Represents each model used within the SearchResultRanker. // Represents each model used within the SearchResultRanker.
enum class Model { NONE, RESULTS_LIST_GROUP_RANKER }; enum class Model { NONE, MIXED_TYPES };
// Returns the model relevant for predicting launches for results with the given // Returns the model relevant for predicting launches for results with the given
// |type|. // |type|.
...@@ -48,7 +48,7 @@ Model ModelForType(RankingItemType type) { ...@@ -48,7 +48,7 @@ Model ModelForType(RankingItemType type) {
case RankingItemType::kOmniboxDocument: case RankingItemType::kOmniboxDocument:
case RankingItemType::kOmniboxHistory: case RankingItemType::kOmniboxHistory:
case RankingItemType::kOmniboxSearch: case RankingItemType::kOmniboxSearch:
return Model::RESULTS_LIST_GROUP_RANKER; return Model::MIXED_TYPES;
default: default:
return Model::NONE; return Model::NONE;
} }
...@@ -87,35 +87,45 @@ SearchResultRanker::SearchResultRanker(Profile* profile) ...@@ -87,35 +87,45 @@ SearchResultRanker::SearchResultRanker(Profile* profile)
: enable_zero_state_mixed_types_( : enable_zero_state_mixed_types_(
app_list_features::IsZeroStateMixedTypesRankerEnabled()) { app_list_features::IsZeroStateMixedTypesRankerEnabled()) {
if (app_list_features::IsQueryBasedMixedTypesRankerEnabled()) { if (app_list_features::IsQueryBasedMixedTypesRankerEnabled()) {
results_list_boost_coefficient_ = base::GetFieldTrialParamByFeatureAsDouble(
app_list_features::kEnableQueryBasedMixedTypesRanker,
"boost_coefficient", 0.1);
RecurrenceRankerConfigProto config; RecurrenceRankerConfigProto config;
config.set_min_seconds_between_saves(240u); config.set_min_seconds_between_saves(240u);
config.set_condition_limit(0u); config.set_condition_limit(0u);
config.set_condition_decay(0.5f); config.set_condition_decay(0.5f);
config.set_target_limit(base::GetFieldTrialParamByFeatureAsInt( config.set_target_limit(base::GetFieldTrialParamByFeatureAsInt(
app_list_features::kEnableQueryBasedMixedTypesRanker, "target_limit", app_list_features::kEnableQueryBasedMixedTypesRanker, "target_limit",
200)); 200));
config.set_target_decay(base::GetFieldTrialParamByFeatureAsDouble( config.set_target_decay(base::GetFieldTrialParamByFeatureAsDouble(
app_list_features::kEnableQueryBasedMixedTypesRanker, "target_decay", app_list_features::kEnableQueryBasedMixedTypesRanker, "target_decay",
0.8f)); 0.8f));
// TODO(931149): Replace this with a more sophisticated model if the
// query-based mixed type model is being used.
config.mutable_default_predictor(); config.mutable_default_predictor();
results_list_group_ranker_ = std::make_unique<RecurrenceRanker>( if (GetFieldTrialParamByFeatureAsBool(
profile->GetPath().AppendASCII("adaptive_result_ranker.proto"), config, app_list_features::kEnableQueryBasedMixedTypesRanker,
chromeos::ProfileHelper::IsEphemeralUserProfile(profile)); "use_category_model", false)) {
results_list_group_ranker_ = std::make_unique<RecurrenceRanker>(
results_list_boost_coefficient_ = base::GetFieldTrialParamByFeatureAsDouble( profile->GetPath().AppendASCII("results_list_group_ranker.pb"),
app_list_features::kEnableQueryBasedMixedTypesRanker, config, chromeos::ProfileHelper::IsEphemeralUserProfile(profile));
"boost_coefficient", 0.1); } else {
query_based_mixed_types_ranker_ = std::make_unique<RecurrenceRanker>(
profile->GetPath().AppendASCII("query_based_mixed_types_ranker.pb"),
config, chromeos::ProfileHelper::IsEphemeralUserProfile(profile));
}
} }
profile_ = profile;
profile_ = profile;
if (auto* notifier = if (auto* notifier =
file_manager::file_tasks::FileTasksNotifier::GetForProfile( file_manager::file_tasks::FileTasksNotifier::GetForProfile(
profile_)) { profile_)) {
notifier->AddObserver(this); notifier->AddObserver(this);
} }
if (enable_zero_state_mixed_types_) { if (enable_zero_state_mixed_types_) {
RecurrenceRankerConfigProto config; RecurrenceRankerConfigProto config;
config.set_min_seconds_between_saves(240u); config.set_min_seconds_between_saves(240u);
...@@ -162,6 +172,9 @@ void SearchResultRanker::FetchRankings(const base::string16& query) { ...@@ -162,6 +172,9 @@ void SearchResultRanker::FetchRankings(const base::string16& query) {
if (results_list_group_ranker_) { if (results_list_group_ranker_) {
group_ranks_.clear(); group_ranks_.clear();
group_ranks_ = results_list_group_ranker_->Rank(); group_ranks_ = results_list_group_ranker_->Rank();
} else if (query_based_mixed_types_ranker_) {
query_mixed_ranks_.clear();
query_mixed_ranks_ = query_based_mixed_types_ranker_->Rank();
} }
} }
...@@ -170,35 +183,46 @@ void SearchResultRanker::Rank(Mixer::SortedResults* results) { ...@@ -170,35 +183,46 @@ void SearchResultRanker::Rank(Mixer::SortedResults* results) {
return; return;
for (auto& result : *results) { for (auto& result : *results) {
const RankingItemType& type = const auto& type = RankingItemTypeFromSearchResult(*result.result);
RankingItemTypeFromSearchResult(*result.result); const auto& model = ModelForType(type);
const Model& model = ModelForType(type);
if (model == Model::MIXED_TYPES) {
if (model == Model::RESULTS_LIST_GROUP_RANKER && if (results_list_group_ranker_) {
results_list_group_ranker_) { const auto& rank_it =
const auto& rank_it = group_ranks_.find(base::NumberToString(static_cast<int>(type)));
group_ranks_.find(base::NumberToString(static_cast<int>(type))); // The ranker only contains entries trained with types relating to files
// The ranker only contains entries trained with types relating to files // or the omnibox. This means scores for apps, app shortcuts, and answer
// or the omnibox. This means scores for apps, app shortcuts, and answer // cards will be unchanged.
// cards will be unchanged. if (rank_it != group_ranks_.end()) {
if (rank_it != group_ranks_.end()) { // Ranker scores are guaranteed to be in [0,1]. But, enforce that the
// Ranker scores are guaranteed to be in [0,1]. But, enforce that the // result of tweaking does not put the score above 3.0, as that may
// result of tweaking does not put the score above 3.0, as that may // interfere with apps or answer cards.
// interfere with apps or answer cards. result.score = std::min(
result.score = std::min( result.score + rank_it->second * results_list_boost_coefficient_,
result.score + rank_it->second * results_list_boost_coefficient_, 3.0);
3.0); }
} else if (query_based_mixed_types_ranker_) {
// TODO(931149): Add some normalization for URLs.
const auto& rank_it = query_mixed_ranks_.find(result.result->id());
if (rank_it != query_mixed_ranks_.end()) {
result.score = std::min(
result.score + rank_it->second * results_list_boost_coefficient_,
3.0);
}
} }
} }
} }
} }
void SearchResultRanker::Train(const std::string& id, RankingItemType type) { void SearchResultRanker::Train(const std::string& id, RankingItemType type) {
const Model& model = ModelForType(type); if (ModelForType(type) == Model::MIXED_TYPES) {
// TODO(931149): Add some normalization for URLs.
if (model == Model::RESULTS_LIST_GROUP_RANKER && results_list_group_ranker_) { if (results_list_group_ranker_) {
results_list_group_ranker_->Record( results_list_group_ranker_->Record(
base::NumberToString(static_cast<int>(type))); base::NumberToString(static_cast<int>(type)));
} else if (query_based_mixed_types_ranker_) {
query_based_mixed_types_ranker_->Record(id);
}
} }
} }
......
...@@ -69,11 +69,21 @@ class SearchResultRanker : file_manager::file_tasks::FileTasksObserver { ...@@ -69,11 +69,21 @@ class SearchResultRanker : file_manager::file_tasks::FileTasksObserver {
// Stores the scores produced by |results_list_group_ranker_|. // Stores the scores produced by |results_list_group_ranker_|.
base::flat_map<std::string, float> group_ranks_; base::flat_map<std::string, float> group_ranks_;
// Stores the scores produced by |query_based_mixed_types_ranker|.
base::flat_map<std::string, float> query_mixed_ranks_;
// The |results_list_group_ranker_| and |query_based_mixed_types_ranker_| are
// models for two different experiments. Only one will be constructed.
// A model that ranks groups (eg. 'file' and 'omnibox'), which is used to // A model that ranks groups (eg. 'file' and 'omnibox'), which is used to
// tweak the results shown in the search results list only. This does not // tweak the results shown in the search results list only. This does not
// affect apps. // affect apps.
std::unique_ptr<RecurrenceRanker> results_list_group_ranker_; std::unique_ptr<RecurrenceRanker> results_list_group_ranker_;
// Ranks items shown in the results list after a search query. Currently
// these are local files and omnibox results.
std::unique_ptr<RecurrenceRanker> query_based_mixed_types_ranker_;
// Ranks files and previous queries for launcher zero-state. // Ranks files and previous queries for launcher zero-state.
std::unique_ptr<RecurrenceRanker> zero_state_mixed_types_ranker_; std::unique_ptr<RecurrenceRanker> zero_state_mixed_types_ranker_;
......
...@@ -89,14 +89,14 @@ class SearchResultRankerTest : public testing::Test { ...@@ -89,14 +89,14 @@ class SearchResultRankerTest : public testing::Test {
} }
std::unique_ptr<SearchResultRanker> MakeRanker( std::unique_ptr<SearchResultRanker> MakeRanker(
bool use_group_ranker, bool query_based_mixed_types_enabled,
const std::map<std::string, std::string>& params = {}) { const std::map<std::string, std::string>& params = {}) {
if (use_group_ranker) { if (query_based_mixed_types_enabled) {
scoped_feature_list_.InitAndEnableFeatureWithParameters( scoped_feature_list_.InitAndEnableFeatureWithParameters(
app_list_features::kEnableQueryBasedMixedTypesRanker, params); app_list_features::kEnableQueryBasedMixedTypesRanker, params);
} else { } else {
scoped_feature_list_.InitWithFeatures( scoped_feature_list_.InitAndDisableFeature(
{}, {app_list_features::kEnableQueryBasedMixedTypesRanker}); app_list_features::kEnableQueryBasedMixedTypesRanker);
} }
auto ranker = std::make_unique<SearchResultRanker>(profile_.get()); auto ranker = std::make_unique<SearchResultRanker>(profile_.get());
...@@ -132,7 +132,7 @@ class SearchResultRankerTest : public testing::Test { ...@@ -132,7 +132,7 @@ class SearchResultRankerTest : public testing::Test {
DISALLOW_COPY_AND_ASSIGN(SearchResultRankerTest); DISALLOW_COPY_AND_ASSIGN(SearchResultRankerTest);
}; };
TEST_F(SearchResultRankerTest, GroupRankerIsDisabledWithFlag) { TEST_F(SearchResultRankerTest, MixedTypesRankersAreDisabledWithFlag) {
auto ranker = MakeRanker(false); auto ranker = MakeRanker(false);
for (int i = 0; i < 20; ++i) for (int i = 0; i < 20; ++i)
ranker->Train("unused", RankingItemType::kFile); ranker->Train("unused", RankingItemType::kFile);
...@@ -150,8 +150,9 @@ TEST_F(SearchResultRankerTest, GroupRankerIsDisabledWithFlag) { ...@@ -150,8 +150,9 @@ TEST_F(SearchResultRankerTest, GroupRankerIsDisabledWithFlag) {
HasId("C"), HasId("D")))); HasId("C"), HasId("D"))));
} }
TEST_F(SearchResultRankerTest, GroupRankerImprovesScores) { TEST_F(SearchResultRankerTest, CategoryModelImprovesScores) {
auto ranker = MakeRanker(true, {{"boost_coefficient", "1.0"}}); auto ranker = MakeRanker(
true, {{"use_category_model", "true"}, {"boost_coefficient", "1.0"}});
for (int i = 0; i < 20; ++i) for (int i = 0; i < 20; ++i)
ranker->Train("unused", RankingItemType::kFile); ranker->Train("unused", RankingItemType::kFile);
ranker->FetchRankings(base::string16()); ranker->FetchRankings(base::string16());
...@@ -167,5 +168,28 @@ TEST_F(SearchResultRankerTest, GroupRankerImprovesScores) { ...@@ -167,5 +168,28 @@ TEST_F(SearchResultRankerTest, GroupRankerImprovesScores) {
HasId("B"), HasId("A")))); HasId("B"), HasId("A"))));
} }
TEST_F(SearchResultRankerTest, ItemModelImprovesScores) {
// Without the |use_category_model| parameter, the ranker defaults to the item
// model.
auto ranker = MakeRanker(true, {{"boost_coefficient", "1.0"}});
for (int i = 0; i < 10; ++i) {
ranker->Train("C", RankingItemType::kFile);
ranker->Train("D", RankingItemType::kFile);
}
ranker->FetchRankings(base::string16());
// The types associated with these results don't match what was trained on,
// to check that the type is irrelevant to the item model.
auto results = MakeSearchResults({"A", "B", "C", "D"},
{ResultType::kOmnibox, ResultType::kOmnibox,
ResultType::kOmnibox, ResultType::kOmnibox},
{0.3f, 0.2f, 0.1f, 0.1f});
ranker->Rank(&results);
EXPECT_THAT(results, WhenSorted(ElementsAre(HasId("D"), HasId("C"),
HasId("A"), HasId("B"))));
}
} // namespace test } // namespace test
} // namespace app_list } // namespace app_list
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment