Commit 58e2cde2 authored by Ce Chen, committed by Commit Bot

[omnibox] Create a dedicated task runner for the on-device head provider to
run all model operations.

Tested on iOS simulator & Pixel 2 XL.

Bug: 925072
Change-Id: I56c2873c132bef83d6c8636aa290eeeeb2a771b0
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/1927471
Commit-Queue: Ce Chen <cch@chromium.org>
Reviewed-by: Tommy Li <tommycli@chromium.org>
Cr-Commit-Position: refs/heads/master@{#720101}
parent 46c30168
......@@ -281,7 +281,10 @@ AutocompleteController::AutocompleteController(
if (provider_types & AutocompleteProvider::TYPE_ON_DEVICE_HEAD) {
on_device_head_provider_ =
OnDeviceHeadProvider::Create(provider_client_.get(), this);
providers_.push_back(on_device_head_provider_);
if (on_device_head_provider_) {
providers_.push_back(on_device_head_provider_);
on_device_head_provider_->AddModelUpdateCallback();
}
}
if (provider_types & AutocompleteProvider::TYPE_CLIPBOARD) {
#if !defined(OS_IOS)
......
......@@ -14,6 +14,7 @@
#include "base/path_service.h"
#include "base/strings/utf_string_conversions.h"
#include "base/task/post_task.h"
#include "base/task_runner_util.h"
#include "base/threading/sequenced_task_runner_handle.h"
#include "base/trace_event/trace_event.h"
#include "components/omnibox/browser/autocomplete_provider_listener.h"
......@@ -82,9 +83,21 @@ OnDeviceHeadProvider::OnDeviceHeadProvider(
: AutocompleteProvider(AutocompleteProvider::TYPE_ON_DEVICE_HEAD),
client_(client),
listener_(listener),
serving_(nullptr),
task_runner_(base::SequencedTaskRunnerHandle::Get()),
on_device_search_request_id_(0) {
worker_task_runner_(base::CreateSequencedTaskRunner(
{base::ThreadPool(), base::TaskPriority::BEST_EFFORT,
base::TaskShutdownBehavior::SKIP_ON_SHUTDOWN, base::MayBlock()})),
on_device_search_request_id_(0) {}
// Destructor: hands ownership of |model_| to the dedicated worker sequence so
// the model instance is destroyed there rather than on the sequence deleting
// the provider. NOTE(review): the worker runner is created with
// base::MayBlock(), which suggests model teardown may touch disk — confirm.
OnDeviceHeadProvider::~OnDeviceHeadProvider() {
worker_task_runner_->DeleteSoon(FROM_HERE, std::move(model_));
}
void OnDeviceHeadProvider::AddModelUpdateCallback() {
// Bail out if we have already subscribed.
if (model_update_subscription_) {
return;
}
auto* model_update_listener = OnDeviceModelUpdateListener::GetInstance();
if (model_update_listener) {
model_update_subscription_ = model_update_listener->AddModelUpdateCallback(
......@@ -93,10 +106,6 @@ OnDeviceHeadProvider::OnDeviceHeadProvider(
}
}
// Pre-refactor destructor (removed side of the diff): synchronously releases
// the serving instance on whatever sequence destroys the provider.
OnDeviceHeadProvider::~OnDeviceHeadProvider() {
serving_.reset();
}
bool OnDeviceHeadProvider::IsOnDeviceHeadProviderAllowed(
const AutocompleteInput& input,
const std::string& incognito_serve_mode) {
......@@ -149,35 +158,45 @@ void OnDeviceHeadProvider::Start(const AutocompleteInput& input,
return;
matches_.clear();
if (!input.text().empty() && serving_) {
done_ = false;
// Note |on_device_search_request_id_| has already been changed in |Stop|
// so we don't need to change it again here to get a new id for this
// request.
std::unique_ptr<OnDeviceHeadProviderParams> params = base::WrapUnique(
new OnDeviceHeadProviderParams(on_device_search_request_id_, input));
// Since the On Device provider usually runs much faster than online
// providers, it will be very likely users will see on device suggestions
// first and then the Omnibox UI gets refreshed to show suggestions fetched
// from server, if we issue both requests simultaneously.
// Therefore, we might want to delay the On Device suggest requests (and
// also apply a timeout to search default loader) to mitigate this issue.
// Note this delay is not needed for incognito where server suggestion is
// not served.
int delay = 0;
if (!client()->IsOffTheRecord()) {
delay = base::GetFieldTrialParamByFeatureAsInt(
omnibox::kOnDeviceHeadProvider, "DelayOnDeviceHeadSuggestRequestMs",
0);
}
task_runner_->PostDelayedTask(
FROM_HERE,
base::BindOnce(&OnDeviceHeadProvider::DoSearch,
weak_ptr_factory_.GetWeakPtr(), std::move(params)),
delay > 0 ? base::TimeDelta::FromMilliseconds(delay)
: base::TimeDelta());
if (!input.text().empty()) {
base::PostTaskAndReplyWithResult(
worker_task_runner_.get(), FROM_HERE,
base::BindOnce(&OnDeviceHeadProvider::IsModelInstanceReady, this),
base::BindOnce(&OnDeviceHeadProvider::StartInternal,
weak_ptr_factory_.GetWeakPtr(), input));
}
}
void OnDeviceHeadProvider::StartInternal(const AutocompleteInput& input,
bool is_model_instance_ready) {
if (!is_model_instance_ready)
return;
done_ = false;
// Note |on_device_search_request_id_| has already been changed in |Stop|
// so we don't need to change it again here to get a new id for this
// request.
std::unique_ptr<OnDeviceHeadProviderParams> params = base::WrapUnique(
new OnDeviceHeadProviderParams(on_device_search_request_id_, input));
// Since the On Device provider usually runs much faster than online
// providers, it will be very likely users will see on device suggestions
// first and then the Omnibox UI gets refreshed to show suggestions fetched
// from server, if we issue both requests simultaneously.
// Therefore, we might want to delay the On Device suggest requests (and
// also apply a timeout to search default loader) to mitigate this issue.
// Note this delay is not needed for incognito where server suggestion is
// not served.
int delay = 0;
if (!client()->IsOffTheRecord()) {
delay = base::GetFieldTrialParamByFeatureAsInt(
omnibox::kOnDeviceHeadProvider, "DelayOnDeviceHeadSuggestRequestMs", 0);
}
base::SequencedTaskRunnerHandle::Get()->PostDelayedTask(
FROM_HERE,
base::BindOnce(&OnDeviceHeadProvider::DoSearch,
weak_ptr_factory_.GetWeakPtr(), std::move(params)),
delay > 0 ? base::TimeDelta::FromMilliseconds(delay) : base::TimeDelta());
}
void OnDeviceHeadProvider::Stop(bool clear_cached_results,
......@@ -195,22 +214,41 @@ void OnDeviceHeadProvider::Stop(bool clear_cached_results,
// Invoked by the model update listener when a model file name is reported.
void OnDeviceHeadProvider::OnModelUpdate(
const std::string& new_model_filename) {
// Legacy dispatch (removed side of the diff): resets only when the filename
// is non-empty and actually changed, via |task_runner_| and a weak pointer.
if (new_model_filename != current_model_filename_ &&
!new_model_filename.empty()) {
task_runner_->PostTask(
FROM_HERE,
base::BindOnce(&OnDeviceHeadProvider::ResetServingInstanceFromNewModel,
weak_ptr_factory_.GetWeakPtr(), new_model_filename));
}
// New dispatch: unconditionally posts the reset to the dedicated worker
// sequence, retaining |this| (refcounted provider) so it outlives the task;
// the empty-filename check now lives inside ResetModelInstanceFromNewModel.
// NOTE(review): both the removed and added dispatch paths appear here because
// the diff markers were stripped — only the |worker_task_runner_| path should
// remain in the final source; verify against the original change.
worker_task_runner_->PostTask(
FROM_HERE,
base::BindOnce(&OnDeviceHeadProvider::ResetModelInstanceFromNewModel,
this, new_model_filename));
}
void OnDeviceHeadProvider::ResetServingInstanceFromNewModel(
void OnDeviceHeadProvider::ResetModelInstanceFromNewModel(
const std::string& new_model_filename) {
if (new_model_filename.empty())
return;
current_model_filename_ = new_model_filename;
serving_ = OnDeviceHeadServing::Create(current_model_filename_,
provider_max_matches_);
model_ =
OnDeviceHeadServing::Create(new_model_filename, provider_max_matches_);
}
// Runs on |worker_task_runner_|: queries the head model for suggestions
// matching the trimmed, lower-cased Autocomplete input and stores them in
// |params->suggestions|. Marks |params->failed| when no model is loaded.
std::unique_ptr<OnDeviceHeadProvider::OnDeviceHeadProviderParams>
OnDeviceHeadProvider::GetSuggestionsFromModel(
    std::unique_ptr<OnDeviceHeadProviderParams> params) {
  if (!params)
    return params;

  if (!IsModelInstanceReady()) {
    params->failed = true;
    return params;
  }

  params->creation_time = base::TimeTicks::Now();

  base::string16 trimmed_input;
  base::TrimWhitespace(params->input.text(), base::TRIM_ALL, &trimmed_input);
  const std::string lowercase_prefix =
      base::UTF16ToUTF8(base::i18n::ToLower(trimmed_input));

  params->suggestions.clear();
  for (const auto& suggestion_and_score :
       model_->GetSuggestionsForPrefix(lowercase_prefix)) {
    // The second member is the model score, which the provider does not use.
    params->suggestions.push_back(suggestion_and_score.first);
  }
  return params;
}
void OnDeviceHeadProvider::AddProviderInfo(ProvidersInfo* provider_info) const {
......@@ -222,25 +260,17 @@ void OnDeviceHeadProvider::AddProviderInfo(ProvidersInfo* provider_info) const {
void OnDeviceHeadProvider::DoSearch(
std::unique_ptr<OnDeviceHeadProviderParams> params) {
if (serving_ && params &&
params->request_id == on_device_search_request_id_) {
params->creation_time = base::TimeTicks::Now();
base::string16 trimmed_input;
base::TrimWhitespace(params->input.text(), base::TRIM_ALL, &trimmed_input);
auto results = serving_->GetSuggestionsForPrefix(
base::UTF16ToUTF8(base::i18n::ToLower(trimmed_input)));
params->suggestions.clear();
for (const auto& item : results) {
// The second member is the score which is not useful for provider.
params->suggestions.push_back(item.first);
}
} else {
params->failed = true;
if (!params || params->request_id != on_device_search_request_id_) {
SearchDone(std::move(params));
return;
}
task_runner_->PostTask(
FROM_HERE,
base::PostTaskAndReplyWithResult(
worker_task_runner_.get(), FROM_HERE,
base::BindOnce(&OnDeviceHeadProvider::GetSuggestionsFromModel, this,
std::move(params)),
base::BindOnce(&OnDeviceHeadProvider::SearchDone,
weak_ptr_factory_.GetWeakPtr(), std::move(params)));
weak_ptr_factory_.GetWeakPtr()));
}
void OnDeviceHeadProvider::SearchDone(
......@@ -251,10 +281,15 @@ void OnDeviceHeadProvider::SearchDone(
if (!params || params->request_id != on_device_search_request_id_)
return;
if (params->failed) {
done_ = true;
return;
}
const TemplateURLService* template_url_service =
client()->GetTemplateURLService();
if (IsDefaultSearchProviderGoogle(template_url_service) && !params->failed) {
if (IsDefaultSearchProviderGoogle(template_url_service)) {
UMA_HISTOGRAM_CUSTOM_COUNTS("Omnibox.OnDeviceHeadSuggest.ResultCount",
params->suggestions.size(), 1, 5, 6);
matches_.clear();
......
......@@ -28,6 +28,12 @@ class OnDeviceHeadProvider : public AutocompleteProvider {
static OnDeviceHeadProvider* Create(AutocompleteProviderClient* client,
AutocompleteProviderListener* listener);
// Adds a callback to on device head model updater listener which will create
// the model instance |model_| once the model is ready on disk.
// This function should not be called until the provider has at least one
// reference to avoid the binding error.
void AddModelUpdateCallback();
void Start(const AutocompleteInput& input, bool minimal_changes) override;
void Stop(bool clear_cached_results, bool due_to_user_inactivity) override;
void AddProviderInfo(ProvidersInfo* provider_info) const override;
......@@ -48,6 +54,9 @@ class OnDeviceHeadProvider : public AutocompleteProvider {
bool IsOnDeviceHeadProviderAllowed(const AutocompleteInput& input,
const std::string& incognito_serve_mode);
void StartInternal(const AutocompleteInput& input,
bool is_model_instance_ready);
// Helper functions used for asynchronous search to the on device head model.
// The Autocomplete input and output from the model will be passed from
// DoSearch to SearchDone via the OnDeviceHeadProviderParams object.
......@@ -62,30 +71,34 @@ class OnDeviceHeadProvider : public AutocompleteProvider {
// is available.
void OnModelUpdate(const std::string& new_model_filename);
// Resets |serving_| if new model is available and cleans up the old model if
// Resets |model_| if new model is available and cleans up the old model if
// it exists.
void ResetServingInstanceFromNewModel(const std::string& new_model_filename);
void ResetModelInstanceFromNewModel(const std::string& new_model_filename);
bool IsModelInstanceReady() const { return model_ != nullptr; }
// Fetches suggestions matching the params from the given on device head
// model instance.
std::unique_ptr<OnDeviceHeadProviderParams> GetSuggestionsFromModel(
std::unique_ptr<OnDeviceHeadProviderParams> params);
AutocompleteProviderClient* client_;
AutocompleteProviderListener* listener_;
// The instance which does the search in the head model and returns top
// suggestions matching the Autocomplete input.
std::unique_ptr<OnDeviceHeadServing> serving_;
// The task runner dedicated for on device head model operations (including
// model instance creation, deletion and suggestions fetching) which is added
// to offload expensive operations out of the UI sequence.
scoped_refptr<base::SequencedTaskRunner> worker_task_runner_;
// The task runner instance where asynchronous operations using |serving_|
// will be run. Note that SequencedTaskRunner guarantees that operations will
// be executed in sequence so we don't need to apply a lock to |serving_|.
scoped_refptr<base::SequencedTaskRunner> task_runner_;
// The model instance which serves top suggestions matching the Autocomplete
// input and is only accessed in |worker_task_runner_|.
std::unique_ptr<OnDeviceHeadServing> model_;
// The request id used to trace current request to the on device head model.
// The id will be increased whenever a new request is received from the
// AutocompleteController.
size_t on_device_search_request_id_;
// The filename for the on device model currently being used.
std::string current_model_filename_;
// Owns the subscription after adding the model update callback to the
// listener such that the callback can be removed automatically from the
// listener on provider's deconstruction.
......
......@@ -29,6 +29,7 @@ class OnDeviceHeadProviderTest : public testing::Test,
client_.reset(new FakeAutocompleteProviderClient());
SetTestOnDeviceHeadModel();
provider_ = OnDeviceHeadProvider::Create(client_.get(), this);
provider_->AddModelUpdateCallback();
task_environment_.RunUntilIdle();
}
......@@ -55,10 +56,9 @@ class OnDeviceHeadProviderTest : public testing::Test,
task_environment_.RunUntilIdle();
}
void ResetServingInstance() {
void ResetModelInstance() {
if (provider_) {
provider_->serving_.reset();
provider_->current_model_filename_.clear();
provider_->model_.reset();
}
}
......@@ -73,12 +73,12 @@ class OnDeviceHeadProviderTest : public testing::Test,
scoped_refptr<OnDeviceHeadProvider> provider_;
};
TEST_F(OnDeviceHeadProviderTest, ServingInstanceNotCreated) {
TEST_F(OnDeviceHeadProviderTest, ModelInstanceNotCreated) {
AutocompleteInput input(base::UTF8ToUTF16("M"),
metrics::OmniboxEventProto::OTHER,
TestSchemeClassifier());
input.set_want_asynchronous_matches(true);
ResetServingInstance();
ResetModelInstance();
EXPECT_CALL(*client_.get(), IsOffTheRecord()).WillRepeatedly(Return(false));
EXPECT_CALL(*client_.get(), SearchSuggestEnabled())
......@@ -87,8 +87,7 @@ TEST_F(OnDeviceHeadProviderTest, ServingInstanceNotCreated) {
ASSERT_TRUE(IsOnDeviceHeadProviderAllowed(input, ""));
provider_->Start(input, false);
if (!provider_->done())
task_environment_.RunUntilIdle();
task_environment_.RunUntilIdle();
EXPECT_TRUE(provider_->matches().empty());
EXPECT_TRUE(provider_->done());
......@@ -149,8 +148,7 @@ TEST_F(OnDeviceHeadProviderTest, NoMatches) {
ASSERT_TRUE(IsOnDeviceHeadProviderAllowed(input, ""));
provider_->Start(input, false);
if (!provider_->done())
task_environment_.RunUntilIdle();
task_environment_.RunUntilIdle();
EXPECT_TRUE(provider_->matches().empty());
EXPECT_TRUE(provider_->done());
......@@ -169,8 +167,7 @@ TEST_F(OnDeviceHeadProviderTest, HasMatches) {
ASSERT_TRUE(IsOnDeviceHeadProviderAllowed(input, ""));
provider_->Start(input, false);
if (!provider_->done())
task_environment_.RunUntilIdle();
task_environment_.RunUntilIdle();
EXPECT_TRUE(provider_->done());
ASSERT_EQ(3U, provider_->matches().size());
......@@ -197,11 +194,8 @@ TEST_F(OnDeviceHeadProviderTest, CancelInProgressRequest) {
ASSERT_TRUE(IsOnDeviceHeadProviderAllowed(input2, ""));
provider_->Start(input1, false);
EXPECT_FALSE(provider_->done());
provider_->Start(input2, false);
if (!provider_->done())
task_environment_.RunUntilIdle();
task_environment_.RunUntilIdle();
EXPECT_TRUE(provider_->done());
ASSERT_EQ(3U, provider_->matches().size());
......
......@@ -89,12 +89,13 @@ OnDeviceHeadServing::GetSuggestionsForPrefix(const std::string& prefix) {
return suggestions;
}
OpenModelFileStream(kRootNodeOffset);
MatchCandidate start_match;
if (FindStartNode(prefix, &start_match)) {
suggestions = DoSearch(start_match);
if (OpenModelFileStream(kRootNodeOffset)) {
MatchCandidate start_match;
if (FindStartNode(prefix, &start_match)) {
suggestions = DoSearch(start_match);
}
MaybeCloseModelFileStream();
}
MaybeCloseModelFileStream();
return suggestions;
}
......
......@@ -66,6 +66,7 @@
// The size of score and address will be given in the first two bytes of the
// model file.
// TODO(crbug.com/925072): rename OnDeviceHeadServing to *Model.
class OnDeviceHeadServing {
public:
// Creates and returns an instance for serving on device head model.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment