Commit 4aeed825 authored by Amanda Deacon's avatar Amanda Deacon Committed by Chromium LUCI CQ

Remove Top Cat inference code and references to tflite model.

Tests: unit tests

Bug: 1098112
Change-Id: Ia17ebf3407b4f9429b6b2ab006210b38a0d7fab1
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2570878Reviewed-by: default avatarTony Yeoman <tby@chromium.org>
Reviewed-by: default avatarAndrew Moylan <amoylan@chromium.org>
Reviewed-by: default avatarSam McNally <sammc@chromium.org>
Reviewed-by: default avatarcalamity <calamity@chromium.org>
Commit-Queue: Amanda Deacon <amandadeacon@chromium.org>
Cr-Commit-Position: refs/heads/master@{#833595}
parent 790d98e7
...@@ -505,7 +505,6 @@ ...@@ -505,7 +505,6 @@
<if expr="chromeos"> <if expr="chromeos">
<include name="IDR_SMART_DIM_20181115_EXAMPLE_PREPROCESSOR_CONFIG_PB" file="chromeos\power\ml\smart_dim\20181115_example_preprocessor_config.pb" type="BINDATA" /> <include name="IDR_SMART_DIM_20181115_EXAMPLE_PREPROCESSOR_CONFIG_PB" file="chromeos\power\ml\smart_dim\20181115_example_preprocessor_config.pb" type="BINDATA" />
<include name="IDR_SMART_DIM_20190521_EXAMPLE_PREPROCESSOR_CONFIG_PB" file="chromeos\power\ml\smart_dim\20190521_example_preprocessor_config.pb" type="BINDATA" /> <include name="IDR_SMART_DIM_20190521_EXAMPLE_PREPROCESSOR_CONFIG_PB" file="chromeos\power\ml\smart_dim\20190521_example_preprocessor_config.pb" type="BINDATA" />
<include name="IDR_TOP_CAT_20190722_EXAMPLE_PREPROCESSOR_CONFIG_PB" file="ui\app_list\search\search_result_ranker\20190722_example_preprocessor_config.pb" type="BINDATA" />
<include name="IDR_SEARCH_RANKER_20190923_EXAMPLE_PREPROCESSOR_CONFIG_PB" file="ui\app_list\search\search_result_ranker\search_ranker_assets\20190923_example_preprocessor_config.pb" type="BINDATA" /> <include name="IDR_SEARCH_RANKER_20190923_EXAMPLE_PREPROCESSOR_CONFIG_PB" file="ui\app_list\search\search_result_ranker\search_ranker_assets\20190923_example_preprocessor_config.pb" type="BINDATA" />
</if> </if>
<if expr="chromeos"> <if expr="chromeos">
......
...@@ -1811,8 +1811,6 @@ static_library("ui") { ...@@ -1811,8 +1811,6 @@ static_library("ui") {
"app_list/search/search_result_ranker/frecency_store.h", "app_list/search/search_result_ranker/frecency_store.h",
"app_list/search/search_result_ranker/histogram_util.cc", "app_list/search/search_result_ranker/histogram_util.cc",
"app_list/search/search_result_ranker/histogram_util.h", "app_list/search/search_result_ranker/histogram_util.h",
"app_list/search/search_result_ranker/ml_app_rank_provider.cc",
"app_list/search/search_result_ranker/ml_app_rank_provider.h",
"app_list/search/search_result_ranker/ranking_item_util.cc", "app_list/search/search_result_ranker/ranking_item_util.cc",
"app_list/search/search_result_ranker/ranking_item_util.h", "app_list/search/search_result_ranker/ranking_item_util.h",
"app_list/search/search_result_ranker/recurrence_predictor.cc", "app_list/search/search_result_ranker/recurrence_predictor.cc",
......
...@@ -134,26 +134,6 @@ void AppLaunchEventLogger::OnGridClicked(const std::string& id) { ...@@ -134,26 +134,6 @@ void AppLaunchEventLogger::OnGridClicked(const std::string& id) {
weak_factory_.GetWeakPtr(), event)); weak_factory_.GetWeakPtr(), event));
} }
void AppLaunchEventLogger::CreateRankings() {
const base::TimeDelta duration = base::Time::Now() - start_time_;
if (!ml_app_rank_provider_) {
ml_app_rank_provider_ = std::make_unique<MlAppRankProvider>();
}
ml_app_rank_provider_->CreateRankings(
app_features_map_,
ExponentialBucket(duration.InHours(), kTotalHoursBucketSizeMultiplier),
Bucketize(all_clicks_last_hour_->GetTotal(duration), kClickBuckets),
Bucketize(all_clicks_last_24_hours_->GetTotal(duration), kClickBuckets));
}
std::map<std::string, float> AppLaunchEventLogger::RetrieveRankings() {
if (!ml_app_rank_provider_) {
return {};
}
return ml_app_rank_provider_->RetrieveRankings();
}
std::string AppLaunchEventLogger::RemoveScheme(const std::string& id) { std::string AppLaunchEventLogger::RemoveScheme(const std::string& id) {
std::string app_id(id); std::string app_id(id);
if (!app_id.compare(0, strlen(kExtensionSchemeWithDelimiter), if (!app_id.compare(0, strlen(kExtensionSchemeWithDelimiter),
......
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
#include "base/sequenced_task_runner.h" #include "base/sequenced_task_runner.h"
#include "base/values.h" #include "base/values.h"
#include "chrome/browser/ui/app_list/search/search_result_ranker/app_launch_event_logger.pb.h" #include "chrome/browser/ui/app_list/search/search_result_ranker/app_launch_event_logger.pb.h"
#include "chrome/browser/ui/app_list/search/search_result_ranker/ml_app_rank_provider.h"
#include "extensions/browser/extension_registry.h" #include "extensions/browser/extension_registry.h"
#include "services/metrics/public/cpp/ukm_source_id.h" #include "services/metrics/public/cpp/ukm_source_id.h"
...@@ -154,9 +153,6 @@ class AppLaunchEventLogger { ...@@ -154,9 +153,6 @@ class AppLaunchEventLogger {
const std::unique_ptr<chromeos::power::ml::RecentEventsCounter> const std::unique_ptr<chromeos::power::ml::RecentEventsCounter>
all_clicks_last_24_hours_; all_clicks_last_24_hours_;
// Empty until/unless CreateRankings is called.
std::unique_ptr<MlAppRankProvider> ml_app_rank_provider_;
scoped_refptr<base::SequencedTaskRunner> task_runner_; scoped_refptr<base::SequencedTaskRunner> task_runner_;
base::WeakPtrFactory<AppLaunchEventLogger> weak_factory_; base::WeakPtrFactory<AppLaunchEventLogger> weak_factory_;
......
// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/browser/ui/app_list/search/search_result_ranker/ml_app_rank_provider.h"
#include <utility>
#include "base/bind.h"
#include "base/callback.h"
#include "base/location.h"
#include "base/memory/ref_counted_memory.h"
#include "base/strings/stringprintf.h"
#include "base/task/post_task.h"
#include "base/task/task_traits.h"
#include "base/task/thread_pool.h"
#include "base/time/time.h"
#include "chrome/browser/chromeos/power/ml/user_activity_ukm_logger_helpers.h"
#include "chrome/browser/ui/app_list/search/search_result_ranker/app_launch_event_logger_helper.h"
#include "chrome/grit/browser_resources.h"
#include "chromeos/services/machine_learning/public/cpp/service_connection.h"
#include "chromeos/services/machine_learning/public/mojom/machine_learning_service.mojom.h"
#include "components/assist_ranker/example_preprocessing.h"
#include "components/crx_file/id_util.h"
#include "content/public/browser/browser_task_traits.h"
#include "content/public/browser/browser_thread.h"
#include "ui/base/resource/resource_bundle.h"
using ::chromeos::machine_learning::mojom::BuiltinModelId;
using ::chromeos::machine_learning::mojom::BuiltinModelSpec;
using ::chromeos::machine_learning::mojom::BuiltinModelSpecPtr;
using ::chromeos::machine_learning::mojom::CreateGraphExecutorResult;
using ::chromeos::machine_learning::mojom::ExecuteResult;
using ::chromeos::machine_learning::mojom::FloatList;
using ::chromeos::machine_learning::mojom::Int64List;
using ::chromeos::machine_learning::mojom::LoadModelResult;
using ::chromeos::machine_learning::mojom::Tensor;
using ::chromeos::machine_learning::mojom::TensorPtr;
using ::chromeos::machine_learning::mojom::ValueList;
namespace app_list {
namespace {
void LoadModelCallback(LoadModelResult result) {
if (result != LoadModelResult::OK) {
LOG(ERROR) << "Failed to load Top Cat model.";
}
}
void CreateGraphExecutorCallback(CreateGraphExecutorResult result) {
if (result != CreateGraphExecutorResult::OK) {
LOG(ERROR) << "Failed to create a Top Cat Graph Executor.";
}
}
// Returns: true if preprocessor config loaded, false if it could not be loaded.
bool LoadExamplePreprocessorConfig(
assist_ranker::ExamplePreprocessorConfig* preprocessor_config) {
DCHECK(preprocessor_config);
const int resource_id = IDR_TOP_CAT_20190722_EXAMPLE_PREPROCESSOR_CONFIG_PB;
const scoped_refptr<base::RefCountedMemory> raw_config =
ui::ResourceBundle::GetSharedInstance().LoadDataResourceBytes(
resource_id);
if (!raw_config || !raw_config->front()) {
LOG(ERROR) << "Failed to load TopCatModel example preprocessor config.";
return false;
}
if (!preprocessor_config->ParseFromArray(raw_config->front(),
raw_config->size())) {
LOG(ERROR) << "Failed to parse TopCatModel example preprocessor config.";
return false;
}
return true;
}
// Perform the inference given the |features| and |app_id| of an app.
// Posts |callback| to |task_runner| to perform the actual inference.
void DoInference(const std::string& app_id,
const std::vector<float>& features,
scoped_refptr<base::SequencedTaskRunner> task_runner,
const base::RepeatingCallback<
void(base::flat_map<std::string, TensorPtr> inputs,
const std::vector<std::string> outputs,
const std::string app_id)> callback) {
// Prepare the input tensor.
base::flat_map<std::string, TensorPtr> inputs;
auto tensor = Tensor::New();
tensor->shape = Int64List::New();
tensor->shape->value = std::vector<int64_t>({1, features.size()});
tensor->data = ValueList::New();
tensor->data->set_float_list(FloatList::New());
tensor->data->get_float_list()->value =
std::vector<double>(std::begin(features), std::end(features));
inputs.emplace(std::string("input"), std::move(tensor));
const std::vector<std::string> outputs({std::string("output")});
DCHECK(task_runner);
task_runner->PostTask(FROM_HERE, base::BindOnce(callback, std::move(inputs),
std::move(outputs), app_id));
}
// Process the RankerExample to vectorize the feature list for inference.
// Returns true on success.
bool RankerExampleToVectorizedFeatures(
const assist_ranker::ExamplePreprocessorConfig& preprocessor_config,
assist_ranker::RankerExample* example,
std::vector<float>* vectorized_features) {
int preprocessor_error = assist_ranker::ExamplePreprocessor::Process(
preprocessor_config, example, true);
// kNoFeatureIndexFound can occur normally (e.g., when the app URL
// isn't known to the model or a rarely seen enum value is used).
if (preprocessor_error != assist_ranker::ExamplePreprocessor::kSuccess &&
preprocessor_error !=
assist_ranker::ExamplePreprocessor::kNoFeatureIndexFound) {
// TODO: Log to UMA.
return false;
}
const auto& extracted_features =
example->features()
.at(assist_ranker::ExamplePreprocessor::kVectorizedFeatureDefaultName)
.float_list()
.float_value();
vectorized_features->assign(extracted_features.begin(),
extracted_features.end());
return true;
}
// Does the CPU-intensive part of CreateRankings (preparing the Tensor inputs
// from |app_features_map|, intended to be called on a low-priority
// background thread. Invokes |callback| on |task_runner| once for each app in
// |app_features_map|.
void CreateRankingsImpl(
base::flat_map<std::string, AppLaunchFeatures> app_features_map,
int total_hours,
int all_clicks_last_hour,
int all_clicks_last_24_hours,
scoped_refptr<base::SequencedTaskRunner> task_runner,
const base::RepeatingCallback<
void(base::flat_map<std::string, TensorPtr> inputs,
const std::vector<std::string> outputs,
const std::string app_id)>& callback) {
const base::Time now(base::Time::Now());
const int hour = HourOfDay(now);
const int day = DayOfWeek(now);
assist_ranker::ExamplePreprocessorConfig preprocessor_config;
if (!LoadExamplePreprocessorConfig(&preprocessor_config)) {
return;
}
for (auto& app : app_features_map) {
assist_ranker::RankerExample example(
CreateRankerExample(app.second,
now.ToDeltaSinceWindowsEpoch().InSeconds() -
app.second.time_of_last_click_sec(),
total_hours, day, hour, all_clicks_last_hour,
all_clicks_last_24_hours));
std::vector<float> vectorized_features;
if (RankerExampleToVectorizedFeatures(preprocessor_config, &example,
&vectorized_features)) {
DoInference(app.first, vectorized_features, task_runner, callback);
}
}
}
} // namespace
assist_ranker::RankerExample CreateRankerExample(
const AppLaunchFeatures& features,
int time_since_last_click,
int total_hours,
int day_of_week,
int hour_of_day,
int all_clicks_last_hour,
int all_clicks_last_24_hours) {
assist_ranker::RankerExample example;
auto& ranker_example_features = *example.mutable_features();
ranker_example_features["DayOfWeek"].set_int32_value(day_of_week);
ranker_example_features["HourOfDay"].set_int32_value(hour_of_day);
ranker_example_features["AllClicksLastHour"].set_int32_value(
all_clicks_last_hour);
ranker_example_features["AllClicksLast24Hours"].set_int32_value(
all_clicks_last_24_hours);
ranker_example_features["AppType"].set_int32_value(features.app_type());
ranker_example_features["ClickRank"].set_int32_value(features.click_rank());
ranker_example_features["ClicksLastHour"].set_int32_value(
features.clicks_last_hour());
ranker_example_features["ClicksLast24Hours"].set_int32_value(
features.clicks_last_24_hours());
ranker_example_features["LastLaunchedFrom"].set_int32_value(
features.last_launched_from());
ranker_example_features["HasClick"].set_bool_value(
features.has_most_recently_used_index());
ranker_example_features["MostRecentlyUsedIndex"].set_int32_value(
features.most_recently_used_index());
ranker_example_features["TimeSinceLastClick"].set_int32_value(
Bucketize(time_since_last_click, kTimeSinceLastClickBuckets));
ranker_example_features["TotalClicks"].set_int32_value(
features.total_clicks());
ranker_example_features["TotalClicksPerHour"].set_float_value(
static_cast<float>(features.total_clicks()) / (total_hours + 1));
ranker_example_features["TotalHours"].set_int32_value(total_hours);
// Calculate FourHourClicksN and SixHourClicksN, which sum clicks for four
// and six hour periods respectively.
int four_hour_count = 0;
int six_hour_count = 0;
// Apps that have been clicked will have 24 clicks_each_hour values. Apps that
// have not been clicked will have no clicks_each_hour values, so can skip
// the FourHourClicksN and SixHourClicksN calculations.
if (features.clicks_each_hour_size() == 24) {
for (int hour = 0; hour < 24; hour++) {
int clicks = Bucketize(features.clicks_each_hour(hour), kClickBuckets);
ranker_example_features["ClicksEachHour" +
base::StringPrintf("%02d", hour)]
.set_int32_value(clicks);
ranker_example_features["ClicksPerHour" +
base::StringPrintf("%02d", hour)]
.set_float_value(static_cast<float>(clicks) / (total_hours + 1));
four_hour_count += clicks;
six_hour_count += clicks;
// Divide day into periods of 4 hours each.
if (hour % 4 == 3 && four_hour_count != 0) {
ranker_example_features["FourHourClicks" +
base::StringPrintf("%01d", hour / 4)]
.set_int32_value(four_hour_count);
four_hour_count = 0;
}
// Divide day into periods of 6 hours each.
if (hour % 6 == 5 && six_hour_count != 0) {
ranker_example_features["SixHourClicks" +
base::StringPrintf("%01d", hour / 6)]
.set_int32_value(six_hour_count);
six_hour_count = 0;
}
}
}
if (features.app_type() == AppLaunchEvent_AppType_CHROME) {
ranker_example_features["URL"].set_string_value(
kExtensionSchemeWithDelimiter + features.app_id());
} else if (features.app_type() == AppLaunchEvent_AppType_PWA) {
ranker_example_features["URL"].set_string_value(features.pwa_url());
} else if (features.app_type() == AppLaunchEvent_AppType_PLAY) {
ranker_example_features["URL"].set_string_value(
kAppScheme +
crx_file::id_util::GenerateId(features.arc_package_name()));
} else {
// TODO(crbug.com/1027782): Add DCHECK that this branch is not reached.
}
return example;
}
MlAppRankProvider::MlAppRankProvider()
: creation_task_runner_(base::SequencedTaskRunnerHandle::Get()),
background_task_runner_(base::ThreadPool::CreateSequencedTaskRunner(
{base::TaskPriority::BEST_EFFORT,
base::TaskShutdownBehavior::SKIP_ON_SHUTDOWN})) {}
MlAppRankProvider::~MlAppRankProvider() = default;
void MlAppRankProvider::CreateRankings(
const base::flat_map<std::string, AppLaunchFeatures>& app_features_map,
int total_hours,
int all_clicks_last_hour,
int all_clicks_last_24_hours) {
DCHECK_CALLED_ON_VALID_SEQUENCE(creation_sequence_checker_);
// TODO(jennyz): Add start-to-end latency metrics for the work on each
// sequence.
background_task_runner_->PostTask(
FROM_HERE,
base::BindOnce(&CreateRankingsImpl, app_features_map, total_hours,
all_clicks_last_hour, all_clicks_last_24_hours,
creation_task_runner_,
base::BindRepeating(&MlAppRankProvider::RunExecutor,
weak_factory_.GetWeakPtr())));
}
std::map<std::string, float> MlAppRankProvider::RetrieveRankings() {
DCHECK_CALLED_ON_VALID_SEQUENCE(creation_sequence_checker_);
return ranking_map_;
}
void MlAppRankProvider::RunExecutor(
base::flat_map<std::string, TensorPtr> inputs,
const std::vector<std::string> outputs,
const std::string app_id) {
DCHECK_CALLED_ON_VALID_SEQUENCE(creation_sequence_checker_);
BindGraphExecutorIfNeeded();
executor_->Execute(std::move(inputs), std::move(outputs),
base::BindOnce(&MlAppRankProvider::ExecuteCallback,
base::Unretained(this), app_id));
}
void MlAppRankProvider::ExecuteCallback(
std::string app_id,
ExecuteResult result,
const base::Optional<std::vector<TensorPtr>> outputs) {
DCHECK_CALLED_ON_VALID_SEQUENCE(creation_sequence_checker_);
if (result != ExecuteResult::OK) {
LOG(ERROR) << "Top Cat inference execution failed.";
return;
}
ranking_map_[app_id] = outputs.value()[0]->data->get_float_list()->value[0];
}
void MlAppRankProvider::BindGraphExecutorIfNeeded() {
DCHECK_CALLED_ON_VALID_SEQUENCE(creation_sequence_checker_);
if (!model_) {
// Load the model.
BuiltinModelSpecPtr spec =
BuiltinModelSpec::New(BuiltinModelId::TOP_CAT_20190722);
chromeos::machine_learning::ServiceConnection::GetInstance()
->LoadBuiltinModel(std::move(spec), model_.BindNewPipeAndPassReceiver(),
base::BindOnce(&LoadModelCallback));
}
if (!executor_) {
// Get the graph executor.
model_->CreateGraphExecutor(executor_.BindNewPipeAndPassReceiver(),
base::BindOnce(&CreateGraphExecutorCallback));
executor_.set_disconnect_handler(base::BindOnce(
&MlAppRankProvider::OnConnectionError, base::Unretained(this)));
}
}
void MlAppRankProvider::OnConnectionError() {
LOG(WARNING) << "Mojo connection for ML service closed.";
executor_.reset();
model_.reset();
}
} // namespace app_list
// Copyright (c) 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef CHROME_BROWSER_UI_APP_LIST_SEARCH_SEARCH_RESULT_RANKER_ML_APP_RANK_PROVIDER_H_
#define CHROME_BROWSER_UI_APP_LIST_SEARCH_SEARCH_RESULT_RANKER_ML_APP_RANK_PROVIDER_H_
#include <map>
#include <string>
#include <vector>
#include "base/containers/flat_map.h"
#include "base/macros.h"
#include "base/memory/weak_ptr.h"
#include "base/optional.h"
#include "base/sequence_checker.h"
#include "base/sequenced_task_runner.h"
#include "chrome/browser/ui/app_list/search/search_result_ranker/app_launch_event_logger.pb.h"
#include "chromeos/services/machine_learning/public/mojom/graph_executor.mojom.h"
#include "chromeos/services/machine_learning/public/mojom/model.mojom.h"
#include "chromeos/services/machine_learning/public/mojom/tensor.mojom.h"
#include "components/assist_ranker/proto/example_preprocessor.pb.h"
#include "components/assist_ranker/proto/ranker_example.pb.h"
#include "mojo/public/cpp/bindings/remote.h"
namespace app_list {
// Creates a RankerExample with the given |features| and provided parameters.
// Calculates other features (ClicksEachHour, ClickPerHour, FourHourClicks,
// SixHourClicks). Converts the app id into the URL format used in the ML model.
assist_ranker::RankerExample CreateRankerExample(
const AppLaunchFeatures& features,
int time_since_last_click,
int total_hours,
int day_of_week,
int hour_of_day,
int all_clicks_last_hour,
int all_clicks_last_24_hours);
// Provide the app ranking using an ML model.
// Rankings are created asynchronously using the ML Service and retrieved
// synchronously at any time.
// Sequencing: Must be created and used on the same sequence (typically the UI
// thread).
class MlAppRankProvider {
public:
MlAppRankProvider();
~MlAppRankProvider();
// Asynchronously generates ranking scores for the apps in |app_features_map|.
void CreateRankings(
const base::flat_map<std::string, AppLaunchFeatures>& app_features_map,
int total_hours,
int all_clicks_last_hour,
int all_clicks_last_24_hours);
// Returns a map of the ranking scores keyed by app id.
// This will return an empty map until some time after the first call to
// CreateRankings().
std::map<std::string, float> RetrieveRankings();
private:
// Execute the |executor_| on the creation thread.
void RunExecutor(
base::flat_map<std::string,
::chromeos::machine_learning::mojom::TensorPtr> inputs,
std::vector<std::string> outputs,
std::string app_id);
// Stores the ranking score for an |app_id| in the |ranking_map_|.
// Executed by the ML Service when an Execute call is complete.
void ExecuteCallback(
std::string app_id,
::chromeos::machine_learning::mojom::ExecuteResult result,
base::Optional<
std::vector<::chromeos::machine_learning::mojom::TensorPtr>> outputs);
// Initializes the graph executor for the ML service if it's not already
// available.
void BindGraphExecutorIfNeeded();
void OnConnectionError();
// Remotes used to execute functions in the ML service server end.
mojo::Remote<::chromeos::machine_learning::mojom::Model> model_;
mojo::Remote<::chromeos::machine_learning::mojom::GraphExecutor> executor_;
// Map from app id to ranking score.
std::map<std::string, float> ranking_map_;
// Runner for tasks that should run on the creation sequence.
scoped_refptr<base::SequencedTaskRunner> creation_task_runner_;
// Runner for low priority background tasks.
scoped_refptr<base::SequencedTaskRunner> background_task_runner_;
// Sequence checker for methods that must run on the creation sequence.
SEQUENCE_CHECKER(creation_sequence_checker_);
base::WeakPtrFactory<MlAppRankProvider> weak_factory_{this};
DISALLOW_COPY_AND_ASSIGN(MlAppRankProvider);
};
} // namespace app_list
#endif // CHROME_BROWSER_UI_APP_LIST_SEARCH_SEARCH_RESULT_RANKER_ML_APP_RANK_PROVIDER_H_
// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/browser/ui/app_list/search/search_result_ranker/ml_app_rank_provider.h"
#include <string>
#include "base/test/task_environment.h"
#include "chrome/browser/ui/app_list/search/search_result_ranker/app_launch_event_logger.pb.h"
#include "chromeos/services/machine_learning/public/cpp/fake_service_connection.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace app_list {
const char kAppId[] = "app_id";
TEST(MlAppRankProviderTest, MlInferenceTest) {
base::test::TaskEnvironment task_environment_;
chromeos::machine_learning::FakeServiceConnectionImpl fake_service_connection;
const double expected_value = 1.234;
fake_service_connection.SetOutputValue(std::vector<int64_t>{1L},
std::vector<double>{expected_value});
chromeos::machine_learning::ServiceConnection::
UseFakeServiceConnectionForTesting(&fake_service_connection);
MlAppRankProvider ml_app_rank_provider;
base::flat_map<std::string, AppLaunchFeatures> app_features_map;
AppLaunchFeatures features;
features.set_app_id(kAppId);
features.set_app_type(AppLaunchEvent_AppType_CHROME);
features.set_click_rank(1);
for (int hour = 0; hour < 24; ++hour) {
features.add_clicks_each_hour(1);
}
app_features_map[kAppId] = features;
ml_app_rank_provider.CreateRankings(app_features_map, 3, 1, 7);
EXPECT_EQ(0UL, ml_app_rank_provider.RetrieveRankings().size());
task_environment_.RunUntilIdle();
const std::map<std::string, float> ranking_map =
ml_app_rank_provider.RetrieveRankings();
task_environment_.RunUntilIdle();
ASSERT_EQ(1UL, ranking_map.size());
const auto it = ranking_map.find(kAppId);
ASSERT_NE(ranking_map.end(), it);
EXPECT_NEAR(expected_value, it->second, 0.001);
}
TEST(MlAppRankProviderTest, ExecutionAfterDestructorTest) {
base::test::TaskEnvironment task_environment_;
chromeos::machine_learning::FakeServiceConnectionImpl fake_service_connection;
const double expected_value = 1.234;
fake_service_connection.SetOutputValue(std::vector<int64_t>{1L},
std::vector<double>{expected_value});
chromeos::machine_learning::ServiceConnection::
UseFakeServiceConnectionForTesting(&fake_service_connection);
{
MlAppRankProvider ml_app_rank_provider;
base::flat_map<std::string, AppLaunchFeatures> app_features_map;
AppLaunchFeatures features;
features.set_app_id(kAppId);
features.set_app_type(AppLaunchEvent_AppType_CHROME);
app_features_map[kAppId] = features;
ml_app_rank_provider.CreateRankings(app_features_map, 3, 1, 7);
}
// Run the background tasks after ml_app_rank_provider has been destroyed.
// If this does not crash it is a success.
task_environment_.RunUntilIdle();
}
TEST(MlAppRankProviderTest, CreateRankerExampleTest) {
base::test::TaskEnvironment task_environment_;
MlAppRankProvider ml_app_rank_provider;
base::flat_map<std::string, AppLaunchFeatures> app_features_map;
AppLaunchFeatures features;
features.set_app_id(kAppId);
features.set_app_type(AppLaunchEvent_AppType_CHROME);
features.set_click_rank(1);
features.set_clicks_last_hour(3);
features.set_clicks_last_24_hours(4);
features.set_last_launched_from(AppLaunchEvent_LaunchedFrom_GRID);
features.set_most_recently_used_index(2);
features.set_total_clicks(100);
for (int hour = 0; hour < 24; ++hour) {
features.add_clicks_each_hour(hour + 10);
}
app_features_map[kAppId] = features;
assist_ranker::RankerExample actual =
CreateRankerExample(features, 120, 4, 3, 19, 7, 17);
auto* actual_feature_map(actual.mutable_features());
EXPECT_EQ(3, (*actual_feature_map)["DayOfWeek"].int32_value());
EXPECT_EQ(19, (*actual_feature_map)["HourOfDay"].int32_value());
EXPECT_EQ(7, (*actual_feature_map)["AllClicksLastHour"].int32_value());
EXPECT_EQ(17, (*actual_feature_map)["AllClicksLast24Hours"].int32_value());
EXPECT_EQ(1, (*actual_feature_map)["AppType"].int32_value());
EXPECT_EQ(1, (*actual_feature_map)["ClickRank"].int32_value());
EXPECT_EQ(3, (*actual_feature_map)["ClicksLastHour"].int32_value());
EXPECT_EQ(4, (*actual_feature_map)["ClicksLast24Hours"].int32_value());
EXPECT_EQ(1, (*actual_feature_map)["LastLaunchedFrom"].int32_value());
EXPECT_EQ(true, (*actual_feature_map)["HasClick"].bool_value());
EXPECT_EQ(2, (*actual_feature_map)["MostRecentlyUsedIndex"].int32_value());
EXPECT_EQ(120, (*actual_feature_map)["TimeSinceLastClick"].int32_value());
EXPECT_EQ(100, (*actual_feature_map)["TotalClicks"].int32_value());
EXPECT_NEAR(20.0, (*actual_feature_map)["TotalClicksPerHour"].float_value(),
0.1);
EXPECT_EQ(4, (*actual_feature_map)["TotalHours"].int32_value());
EXPECT_EQ(std::string("chrome-extension://") + kAppId,
(*actual_feature_map)["URL"].string_value());
EXPECT_EQ(10, (*actual_feature_map)["ClicksEachHour00"].int32_value());
EXPECT_EQ(11, (*actual_feature_map)["ClicksEachHour01"].int32_value());
EXPECT_EQ(19, (*actual_feature_map)["ClicksEachHour09"].int32_value());
EXPECT_EQ(20, (*actual_feature_map)["ClicksEachHour10"].int32_value());
// Bucketizing rounds 21-29 down to 20, 31-39 down to 30.
EXPECT_EQ(20, (*actual_feature_map)["ClicksEachHour11"].int32_value());
EXPECT_EQ(20, (*actual_feature_map)["ClicksEachHour19"].int32_value());
EXPECT_EQ(30, (*actual_feature_map)["ClicksEachHour20"].int32_value());
EXPECT_EQ(30, (*actual_feature_map)["ClicksEachHour23"].int32_value());
EXPECT_NEAR(2.0, (*actual_feature_map)["ClicksPerHour00"].float_value(), 0.1);
EXPECT_NEAR(2.2, (*actual_feature_map)["ClicksPerHour01"].float_value(), 0.1);
EXPECT_NEAR(6.0, (*actual_feature_map)["ClicksPerHour23"].float_value(), 0.1);
EXPECT_EQ(10 + 11 + 12 + 13,
(*actual_feature_map)["FourHourClicks0"].int32_value());
EXPECT_EQ(30 + 30 + 30 + 30,
(*actual_feature_map)["FourHourClicks5"].int32_value());
EXPECT_EQ(10 + 11 + 12 + 13 + 14 + 15,
(*actual_feature_map)["SixHourClicks0"].int32_value());
EXPECT_EQ(20 + 20 + 30 + 30 + 30 + 30,
(*actual_feature_map)["SixHourClicks3"].int32_value());
}
} // namespace app_list
...@@ -4996,7 +4996,6 @@ test("unit_tests") { ...@@ -4996,7 +4996,6 @@ test("unit_tests") {
"../browser/ui/app_list/search/search_result_ranker/app_search_result_ranker_unittest.cc", "../browser/ui/app_list/search/search_result_ranker/app_search_result_ranker_unittest.cc",
"../browser/ui/app_list/search/search_result_ranker/chip_ranker_unittest.cc", "../browser/ui/app_list/search/search_result_ranker/chip_ranker_unittest.cc",
"../browser/ui/app_list/search/search_result_ranker/frecency_store_unittest.cc", "../browser/ui/app_list/search/search_result_ranker/frecency_store_unittest.cc",
"../browser/ui/app_list/search/search_result_ranker/ml_app_rank_provider_unittest.cc",
"../browser/ui/app_list/search/search_result_ranker/ranking_item_util_unittest.cc", "../browser/ui/app_list/search/search_result_ranker/ranking_item_util_unittest.cc",
"../browser/ui/app_list/search/search_result_ranker/recurrence_predictor_unittest.cc", "../browser/ui/app_list/search/search_result_ranker/recurrence_predictor_unittest.cc",
"../browser/ui/app_list/search/search_result_ranker/recurrence_ranker_unittest.cc", "../browser/ui/app_list/search/search_result_ranker/recurrence_ranker_unittest.cc",
......
...@@ -34,7 +34,7 @@ enum BuiltinModelId { ...@@ -34,7 +34,7 @@ enum BuiltinModelId {
// The Smart Dim (20190221) ML model. // The Smart Dim (20190221) ML model.
SMART_DIM_20190221 = 3, SMART_DIM_20190221 = 3,
// The Top Cat (20190722) ML model. // The Top Cat (20190722) ML model.
TOP_CAT_20190722 = 4, UNSUPPORTED_TOP_CAT_20190722 = 4,
// The Smart Dim (20190521) ML model. // The Smart Dim (20190521) ML model.
SMART_DIM_20190521 = 5, SMART_DIM_20190521 = 5,
// The Search Ranker (20190923) ML model. // The Search Ranker (20190923) ML model.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment