Commit f101a923 authored by Thanh Nguyen's avatar Thanh Nguyen Committed by Commit Bot

[offline-logging] Add inference code for Search Ranker

This CL:
1. Adds necessary code for inference.
2. Put the inference part under a flag in SearchResultRanker.

Bug: 1006133

Change-Id: I4185cf77f55e6539b831dcebb9ec9b9da46667f3
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/1831730Reviewed-by: default avatarcalamity <calamity@chromium.org>
Reviewed-by: default avatarDominick Ng <dominickn@chromium.org>
Reviewed-by: default avatarJia Meng <jiameng@chromium.org>
Commit-Queue: Thanh Nguyen <thanhdng@chromium.org>
Cr-Commit-Position: refs/heads/master@{#702757}
parent a6bb644a
...@@ -612,6 +612,7 @@ ...@@ -612,6 +612,7 @@
<include name="IDR_SMART_DIM_20181115_EXAMPLE_PREPROCESSOR_CONFIG_PB" file="chromeos\power\ml\smart_dim\20181115_example_preprocessor_config.pb" type="BINDATA" /> <include name="IDR_SMART_DIM_20181115_EXAMPLE_PREPROCESSOR_CONFIG_PB" file="chromeos\power\ml\smart_dim\20181115_example_preprocessor_config.pb" type="BINDATA" />
<include name="IDR_SMART_DIM_20190521_EXAMPLE_PREPROCESSOR_CONFIG_PB" file="chromeos\power\ml\smart_dim\20190521_example_preprocessor_config.pb" type="BINDATA" /> <include name="IDR_SMART_DIM_20190521_EXAMPLE_PREPROCESSOR_CONFIG_PB" file="chromeos\power\ml\smart_dim\20190521_example_preprocessor_config.pb" type="BINDATA" />
<include name="IDR_TOP_CAT_20190722_EXAMPLE_PREPROCESSOR_CONFIG_PB" file="ui\app_list\search\search_result_ranker\20190722_example_preprocessor_config.pb" type="BINDATA" /> <include name="IDR_TOP_CAT_20190722_EXAMPLE_PREPROCESSOR_CONFIG_PB" file="ui\app_list\search\search_result_ranker\20190722_example_preprocessor_config.pb" type="BINDATA" />
<include name="IDR_SEARCH_RANKER_20190923_EXAMPLE_PREPROCESSOR_CONFIG_PB" file="ui\app_list\search\search_result_ranker\search_ranker_assets\20190923_example_preprocessor_config.pb" type="BINDATA" />
</if> </if>
<if expr="chromeos"> <if expr="chromeos">
<include name="IDR_ARC_GRAPHICS_TRACING_HTML" file="resources\chromeos\arc_graphics_tracing\arc_graphics_tracing.html" compress="gzip" type="BINDATA"/> <include name="IDR_ARC_GRAPHICS_TRACING_HTML" file="resources\chromeos\arc_graphics_tracing\arc_graphics_tracing.html" compress="gzip" type="BINDATA"/>
......
...@@ -11,16 +11,33 @@ ...@@ -11,16 +11,33 @@
#include "chrome/browser/ui/app_list/search/chrome_search_result.h" #include "chrome/browser/ui/app_list/search/chrome_search_result.h"
#include "chrome/browser/ui/app_list/search/omnibox_result.h" #include "chrome/browser/ui/app_list/search/omnibox_result.h"
#include "chrome/browser/ui/app_list/search/search_result_ranker/search_ranking_event.pb.h" #include "chrome/browser/ui/app_list/search/search_result_ranker/search_ranking_event.pb.h"
#include "chrome/grit/browser_resources.h"
#include "chromeos/constants/devicetype.h" #include "chromeos/constants/devicetype.h"
#include "chromeos/services/machine_learning/public/cpp/service_connection.h"
#include "chromeos/services/machine_learning/public/mojom/machine_learning_service.mojom.h"
#include "components/assist_ranker/example_preprocessing.h"
#include "components/crx_file/id_util.h"
#include "components/omnibox/browser/autocomplete_match_type.h" #include "components/omnibox/browser/autocomplete_match_type.h"
#include "mojo/public/cpp/bindings/map.h"
#include "services/metrics/public/cpp/metrics_utils.h" #include "services/metrics/public/cpp/metrics_utils.h"
#include "services/metrics/public/cpp/ukm_builders.h" #include "services/metrics/public/cpp/ukm_builders.h"
#include "ui/base/resource/resource_bundle.h"
#include "url/gurl.h" #include "url/gurl.h"
#include "url/origin.h" #include "url/origin.h"
namespace app_list { namespace app_list {
namespace { namespace {
using chromeos::machine_learning::mojom::BuiltinModelId;
using chromeos::machine_learning::mojom::BuiltinModelSpec;
using chromeos::machine_learning::mojom::CreateGraphExecutorResult;
using chromeos::machine_learning::mojom::ExecuteResult;
using chromeos::machine_learning::mojom::FloatList;
using chromeos::machine_learning::mojom::Int64List;
using chromeos::machine_learning::mojom::LoadModelResult;
using chromeos::machine_learning::mojom::Tensor;
using chromeos::machine_learning::mojom::TensorPtr;
using chromeos::machine_learning::mojom::ValueList;
using ukm::GetExponentialBucketMinForCounts1000; using ukm::GetExponentialBucketMinForCounts1000;
// How long to wait for a URL to enter the history service before querying it // How long to wait for a URL to enter the history service before querying it
...@@ -119,6 +136,81 @@ Category CategoryFromResultType(ash::SearchResultType type, int subtype) { ...@@ -119,6 +136,81 @@ Category CategoryFromResultType(ash::SearchResultType type, int subtype) {
int GetExponentialBucketMinForSeconds(int64_t sample) { int GetExponentialBucketMinForSeconds(int64_t sample) {
return ukm::GetExponentialBucketMin(sample, kBucketExponentForSeconds); return ukm::GetExponentialBucketMin(sample, kBucketExponentForSeconds);
} }
void LoadModelCallback(LoadModelResult result) {
if (result != LoadModelResult::OK) {
LOG(ERROR) << "Failed to load Search Ranker model.";
// TODO(crbug.com/1006133): Add UMA metrics here.
}
}
void CreateGraphExecutorCallback(CreateGraphExecutorResult result) {
if (result != CreateGraphExecutorResult::OK) {
LOG(ERROR) << "Failed to create a Search Ranker Graph Executor.";
// TODO(crbug.com/1006133): Add UMA metrics here.
}
}
// Populates |example| using |features|.
void PopulateRankerExample(const SearchRankingItem::Features& features,
assist_ranker::RankerExample* example) {
CHECK(example);
auto& ranker_example_features = *example->mutable_features();
ranker_example_features["QueryLength"].set_int32_value(
features.query_length());
ranker_example_features["RelevanceScore"].set_int32_value(
features.relevance_score());
ranker_example_features["Category"].set_int32_value(features.category());
ranker_example_features["HourOfDay"].set_int32_value(features.hour_of_day());
ranker_example_features["DayOfWeek"].set_int32_value(features.day_of_week());
ranker_example_features["LaunchesThisSession"].set_int32_value(
features.launches_this_session());
if (features.has_file_extension()) {
ranker_example_features["FileExtension"].set_int32_value(
features.file_extension());
}
if (features.has_time_since_last_launch()) {
ranker_example_features["TimeSinceLastLaunch"].set_int32_value(
features.time_since_last_launch());
ranker_example_features["TimeOfLastLaunch"].set_int32_value(
features.time_of_last_launch());
}
const auto& launches = features.launches_at_hour();
for (int hour = 0; hour < launches.size(); hour++) {
ranker_example_features["LaunchesAtHour" + base::StringPrintf("%02d", hour)]
.set_int32_value(launches[hour]);
}
if (features.has_domain()) {
ranker_example_features["Domain"].set_string_value(features.domain());
ranker_example_features["HasDomain"].set_int32_value(1);
}
}
// Loads the preprocessor config protobuf, which will be used later to convert
// a RankerExample to a vectorized float for inactivity score calculation.
// Returns nullptr if cannot load or parse the config.
std::unique_ptr<assist_ranker::ExamplePreprocessorConfig>
LoadExamplePreprocessorConfig() {
auto config = std::make_unique<assist_ranker::ExamplePreprocessorConfig>();
const int res_id = IDR_SEARCH_RANKER_20190923_EXAMPLE_PREPROCESSOR_CONFIG_PB;
scoped_refptr<base::RefCountedMemory> raw_config =
ui::ResourceBundle::GetSharedInstance().LoadDataResourceBytes(res_id);
if (!raw_config || !raw_config->front()) {
LOG(ERROR) << "Failed to load SearchRanker example preprocessor config.";
// TODO(crbug.com/1006133): Add UMA metrics here.
return nullptr;
}
if (!config->ParseFromArray(raw_config->front(), raw_config->size())) {
LOG(ERROR) << "Failed to parse SearchRanker example preprocessor config.";
// TODO(crbug.com/1006133): Add UMA metrics here.
return nullptr;
}
return config;
}
} // namespace } // namespace
SearchRankingEventLogger::SearchRankingEventLogger( SearchRankingEventLogger::SearchRankingEventLogger(
...@@ -391,4 +483,130 @@ void SearchRankingEventLogger::LogEvent( ...@@ -391,4 +483,130 @@ void SearchRankingEventLogger::LogEvent(
std::move(event_recorded_for_testing_).Run(); std::move(event_recorded_for_testing_).Run();
} }
void SearchRankingEventLogger::CreateRankings(Mixer::SortedResults* results,
int query_length) {
for (const auto& result : *results) {
if (!result.result) {
continue;
}
SearchRankingItem proto;
std::vector<float> vectorized_features;
PopulateSearchRankingItem(&proto, result.result, query_length,
false /*use_for_logging*/);
if (!PreprocessInput(proto.features(), &vectorized_features)) {
return;
}
DoInference(vectorized_features, result.result->id());
}
}
std::map<std::string, float> SearchRankingEventLogger::RetrieveRankings() {
return prediction_;
}
void SearchRankingEventLogger::LazyInitialize() {
if (!preprocessor_config_) {
preprocessor_config_ = LoadExamplePreprocessorConfig();
}
}
bool SearchRankingEventLogger::PreprocessInput(
const SearchRankingItem::Features& features,
std::vector<float>* vectorized_features) {
DCHECK(vectorized_features);
LazyInitialize();
if (!preprocessor_config_) {
LOG(ERROR) << "Failed to create preprocessor config.";
// TODO(crbug.com/1006133): Add UMA metrics here.
return false;
}
assist_ranker::RankerExample ranker_example;
PopulateRankerExample(features, &ranker_example);
int preprocessor_error = assist_ranker::ExamplePreprocessor::Process(
*preprocessor_config_, &ranker_example, true);
// kNoFeatureIndexFound can occur normally (e.g., when the domain name
// isn't known to the model or a rarely seen enum value is used).
if (preprocessor_error != assist_ranker::ExamplePreprocessor::kSuccess &&
preprocessor_error !=
assist_ranker::ExamplePreprocessor::kNoFeatureIndexFound) {
LOG(ERROR) << "Failed to vectorize features using ExamplePreprocessor.";
// TODO(crbug.com/1006133): Add UMA metrics here.
return false;
}
const auto& extracted_features =
ranker_example.features()
.at(assist_ranker::ExamplePreprocessor::kVectorizedFeatureDefaultName)
.float_list()
.float_value();
vectorized_features->assign(extracted_features.begin(),
extracted_features.end());
return true;
}
void SearchRankingEventLogger::DoInference(const std::vector<float>& features,
const std::string& id) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
BindGraphExecutorIfNeeded();
// Prepare the input tensor.
std::map<std::string, TensorPtr> inputs;
auto tensor = Tensor::New();
tensor->shape = Int64List::New();
tensor->shape->value = std::vector<int64_t>({1, features.size()});
tensor->data = ValueList::New();
tensor->data->set_float_list(FloatList::New());
tensor->data->get_float_list()->value =
std::vector<double>(std::begin(features), std::end(features));
inputs.emplace(std::string("input"), std::move(tensor));
const std::vector<std::string> outputs({std::string("output")});
// Execute
executor_->Execute(mojo::MapToFlatMap(std::move(inputs)), std::move(outputs),
base::BindOnce(&SearchRankingEventLogger::ExecuteCallback,
weak_factory_.GetWeakPtr(), id));
}
void SearchRankingEventLogger::BindGraphExecutorIfNeeded() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
if (!model_) {
// Load the model.
auto spec = BuiltinModelSpec::New(BuiltinModelId::SEARCH_RANKER_20190923);
chromeos::machine_learning::ServiceConnection::GetInstance()
->LoadBuiltinModel(std::move(spec), model_.BindNewPipeAndPassReceiver(),
base::BindOnce(&LoadModelCallback));
}
if (!executor_) {
// Get the graph executor.
model_->CreateGraphExecutor(executor_.BindNewPipeAndPassReceiver(),
base::BindOnce(&CreateGraphExecutorCallback));
executor_.set_disconnect_handler(base::BindOnce(
&SearchRankingEventLogger::OnConnectionError, base::Unretained(this)));
}
}
void SearchRankingEventLogger::OnConnectionError() {
LOG(WARNING) << "Mojo connection for ML service closed.";
executor_.reset();
model_.reset();
}
void SearchRankingEventLogger::ExecuteCallback(
const std::string& id,
ExecuteResult result,
const base::Optional<std::vector<TensorPtr>> outputs) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
if (result != ExecuteResult::OK) {
LOG(ERROR) << "Search Ranker inference execution failed.";
// TODO(crbug.com/1006133): Add UMA metrics here.
return;
}
prediction_[id] = outputs.value()[0]->data->get_float_list()->value[0];
}
} // namespace app_list } // namespace app_list
...@@ -19,13 +19,20 @@ ...@@ -19,13 +19,20 @@
#include "chrome/browser/profiles/profile.h" #include "chrome/browser/profiles/profile.h"
#include "chrome/browser/ui/app_list/search/search_controller.h" #include "chrome/browser/ui/app_list/search/search_controller.h"
#include "chrome/browser/ui/app_list/search/search_result_ranker/search_ranking_event.pb.h" #include "chrome/browser/ui/app_list/search/search_result_ranker/search_ranking_event.pb.h"
#include "chromeos/services/machine_learning/public/mojom/graph_executor.mojom.h"
#include "chromeos/services/machine_learning/public/mojom/model.mojom.h"
#include "chromeos/services/machine_learning/public/mojom/tensor.mojom.h"
#include "components/assist_ranker/proto/example_preprocessor.pb.h"
#include "mojo/public/cpp/bindings/remote.h"
#include "services/metrics/public/cpp/ukm_recorder.h" #include "services/metrics/public/cpp/ukm_recorder.h"
class ChromeSearchResult; class ChromeSearchResult;
namespace app_list { namespace app_list {
// Performs UKM logging of search ranking events in the launcher's results list. // TODO(crbug.com/1006133): This is class logs and does inference on a search
// ranking event. The class name doesn't reflect what it does. We should
// refactor it in future CLs.
class SearchRankingEventLogger { class SearchRankingEventLogger {
public: public:
SearchRankingEventLogger(Profile* profile, SearchRankingEventLogger(Profile* profile,
...@@ -46,6 +53,11 @@ class SearchRankingEventLogger { ...@@ -46,6 +53,11 @@ class SearchRankingEventLogger {
// recorded. // recorded.
void SetEventRecordedForTesting(base::OnceClosure closure); void SetEventRecordedForTesting(base::OnceClosure closure);
// Computes scores for a list of search result item using ML Service.
void CreateRankings(Mixer::SortedResults* results, int query_length);
// Retrieve the scores.
std::map<std::string, float> RetrieveRankings();
private: private:
// Stores state necessary for logging a given search result that is // Stores state necessary for logging a given search result that is
// accumulated throughout the session. // accumulated throughout the session.
...@@ -76,6 +88,39 @@ class SearchRankingEventLogger { ...@@ -76,6 +88,39 @@ class SearchRankingEventLogger {
void LogEvent(const SearchRankingItem& result, void LogEvent(const SearchRankingItem& result,
base::Optional<ukm::SourceId> source_id); base::Optional<ukm::SourceId> source_id);
// Create vectorized features from SearchRankingItem. Returns true if
// |vectorized_features| is successfully populated.
bool PreprocessInput(const SearchRankingItem::Features& features,
std::vector<float>* vectorized_features);
// Call ML Service to do the inference.
void DoInference(const std::vector<float>& features, const std::string& id);
// Stores the ranking score for an |app_id| in the |ranking_map_|.
// Executed by the ML Service when an Execute call is complete.
void ExecuteCallback(
const std::string& id,
::chromeos::machine_learning::mojom::ExecuteResult result,
base::Optional<
std::vector<::chromeos::machine_learning::mojom::TensorPtr>> outputs);
void LazyInitialize();
// Initializes the graph executor for the ML service if it's not already
// available.
void BindGraphExecutorIfNeeded();
void OnConnectionError();
std::map<std::string, float> prediction_;
// Remotes used to execute functions in the ML service server end.
mojo::Remote<::chromeos::machine_learning::mojom::Model> model_;
mojo::Remote<::chromeos::machine_learning::mojom::GraphExecutor> executor_;
std::unique_ptr<assist_ranker::ExamplePreprocessorConfig>
preprocessor_config_;
SearchController* search_controller_; SearchController* search_controller_;
// Some events do not have an associated URL and so are logged directly with // Some events do not have an associated URL and so are logged directly with
// |ukm_recorder_| using a blank source ID. Other events need to validate the // |ukm_recorder_| using a blank source ID. Other events need to validate the
......
...@@ -208,7 +208,10 @@ void SearchResultRanker::InitializeRankers( ...@@ -208,7 +208,10 @@ void SearchResultRanker::InitializeRankers(
"QueryBasedMixedTypesGroup", "QueryBasedMixedTypesGroup",
profile_->GetPath().AppendASCII("results_list_group_ranker.pb"), profile_->GetPath().AppendASCII("results_list_group_ranker.pb"),
config, chromeos::ProfileHelper::IsEphemeralUserProfile(profile_)); config, chromeos::ProfileHelper::IsEphemeralUserProfile(profile_));
} else if (GetFieldTrialParamByFeatureAsBool(
app_list_features::kEnableQueryBasedMixedTypesRanker,
"use_aggregated_model", false)) {
use_aggregated_search_ranking_inference_ = true;
} else { } else {
// Item ranker model. // Item ranker model.
const std::string config_json = GetFieldTrialParamValueByFeature( const std::string config_json = GetFieldTrialParamValueByFeature(
...@@ -356,6 +359,11 @@ void SearchResultRanker::Rank(Mixer::SortedResults* results) { ...@@ -356,6 +359,11 @@ void SearchResultRanker::Rank(Mixer::SortedResults* results) {
[](const Mixer::SortData& a, const Mixer::SortData& b) { [](const Mixer::SortData& a, const Mixer::SortData& b) {
return a.score > b.score; return a.score > b.score;
}); });
std::map<std::string, float> search_ranker_score_map;
if (!last_query_.empty() && use_aggregated_search_ranking_inference_) {
search_ranking_event_logger_->CreateRankings(results, last_query_.size());
search_ranker_score_map = search_ranking_event_logger_->RetrieveRankings();
}
std::map<std::string, float> ranking_map; std::map<std::string, float> ranking_map;
if (using_aggregated_app_inference_) if (using_aggregated_app_inference_)
...@@ -396,6 +404,9 @@ void SearchResultRanker::Rank(Mixer::SortedResults* results) { ...@@ -396,6 +404,9 @@ void SearchResultRanker::Rank(Mixer::SortedResults* results) {
result.score + rank_it->second * results_list_boost_coefficient_, result.score + rank_it->second * results_list_boost_coefficient_,
3.0); 3.0);
} }
} else if (!last_query_.empty() &&
use_aggregated_search_ranking_inference_) {
result.score = search_ranker_score_map[result.result->id()];
} }
} else if (model == Model::APPS) { } else if (model == Model::APPS) {
if (using_aggregated_app_inference_) { if (using_aggregated_app_inference_) {
......
...@@ -180,6 +180,7 @@ class SearchResultRanker : file_manager::file_tasks::FileTasksObserver, ...@@ -180,6 +180,7 @@ class SearchResultRanker : file_manager::file_tasks::FileTasksObserver,
// Logs impressions and stores feature data for aggregated model. // Logs impressions and stores feature data for aggregated model.
std::unique_ptr<app_list::SearchRankingEventLogger> std::unique_ptr<app_list::SearchRankingEventLogger>
search_ranking_event_logger_; search_ranking_event_logger_;
bool use_aggregated_search_ranking_inference_ = false;
// Stores the time of the last histogram logging event for each zero state // Stores the time of the last histogram logging event for each zero state
// search provider. Used to prevent scores from being logged multiple times // search provider. Used to prevent scores from being logged multiple times
......
...@@ -34,6 +34,8 @@ enum BuiltinModelId { ...@@ -34,6 +34,8 @@ enum BuiltinModelId {
TOP_CAT_20190722 = 4, TOP_CAT_20190722 = 4,
// The Smart Dim (20190521) ML model. // The Smart Dim (20190521) ML model.
SMART_DIM_20190521 = 5, SMART_DIM_20190521 = 5,
// The Search Ranker (20190923) ML model.
SEARCH_RANKER_20190923 = 6,
}; };
// These values are persisted to logs. Entries should not be renumbered and // These values are persisted to logs. Entries should not be renumbered and
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment