Commit 58e2cde2 authored by Ce Chen, committed by Commit Bot

[omnibox] Create a dedicated task runner for the on device head provider to run all model operations.

Tested on iOS simulator & Pixel 2 XL.

Bug: 925072
Change-Id: I56c2873c132bef83d6c8636aa290eeeeb2a771b0
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/1927471
Commit-Queue: Ce Chen <cch@chromium.org>
Reviewed-by: Tommy Li <tommycli@chromium.org>
Cr-Commit-Position: refs/heads/master@{#720101}
parent 46c30168
...@@ -281,7 +281,10 @@ AutocompleteController::AutocompleteController( ...@@ -281,7 +281,10 @@ AutocompleteController::AutocompleteController(
if (provider_types & AutocompleteProvider::TYPE_ON_DEVICE_HEAD) { if (provider_types & AutocompleteProvider::TYPE_ON_DEVICE_HEAD) {
on_device_head_provider_ = on_device_head_provider_ =
OnDeviceHeadProvider::Create(provider_client_.get(), this); OnDeviceHeadProvider::Create(provider_client_.get(), this);
providers_.push_back(on_device_head_provider_); if (on_device_head_provider_) {
providers_.push_back(on_device_head_provider_);
on_device_head_provider_->AddModelUpdateCallback();
}
} }
if (provider_types & AutocompleteProvider::TYPE_CLIPBOARD) { if (provider_types & AutocompleteProvider::TYPE_CLIPBOARD) {
#if !defined(OS_IOS) #if !defined(OS_IOS)
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "base/path_service.h" #include "base/path_service.h"
#include "base/strings/utf_string_conversions.h" #include "base/strings/utf_string_conversions.h"
#include "base/task/post_task.h" #include "base/task/post_task.h"
#include "base/task_runner_util.h"
#include "base/threading/sequenced_task_runner_handle.h" #include "base/threading/sequenced_task_runner_handle.h"
#include "base/trace_event/trace_event.h" #include "base/trace_event/trace_event.h"
#include "components/omnibox/browser/autocomplete_provider_listener.h" #include "components/omnibox/browser/autocomplete_provider_listener.h"
...@@ -82,9 +83,21 @@ OnDeviceHeadProvider::OnDeviceHeadProvider( ...@@ -82,9 +83,21 @@ OnDeviceHeadProvider::OnDeviceHeadProvider(
: AutocompleteProvider(AutocompleteProvider::TYPE_ON_DEVICE_HEAD), : AutocompleteProvider(AutocompleteProvider::TYPE_ON_DEVICE_HEAD),
client_(client), client_(client),
listener_(listener), listener_(listener),
serving_(nullptr), worker_task_runner_(base::CreateSequencedTaskRunner(
task_runner_(base::SequencedTaskRunnerHandle::Get()), {base::ThreadPool(), base::TaskPriority::BEST_EFFORT,
on_device_search_request_id_(0) { base::TaskShutdownBehavior::SKIP_ON_SHUTDOWN, base::MayBlock()})),
on_device_search_request_id_(0) {}
// Destructor. |model_| is only accessed on |worker_task_runner_|, so it must
// not be destroyed directly here on the owning (UI) sequence; instead its
// deletion is posted to the worker sequence via DeleteSoon, which sequences
// the destruction after any in-flight model operations.
OnDeviceHeadProvider::~OnDeviceHeadProvider() {
worker_task_runner_->DeleteSoon(FROM_HERE, std::move(model_));
}
void OnDeviceHeadProvider::AddModelUpdateCallback() {
// Bail out if we have already subscribed.
if (model_update_subscription_) {
return;
}
auto* model_update_listener = OnDeviceModelUpdateListener::GetInstance(); auto* model_update_listener = OnDeviceModelUpdateListener::GetInstance();
if (model_update_listener) { if (model_update_listener) {
model_update_subscription_ = model_update_listener->AddModelUpdateCallback( model_update_subscription_ = model_update_listener->AddModelUpdateCallback(
...@@ -93,10 +106,6 @@ OnDeviceHeadProvider::OnDeviceHeadProvider( ...@@ -93,10 +106,6 @@ OnDeviceHeadProvider::OnDeviceHeadProvider(
} }
} }
OnDeviceHeadProvider::~OnDeviceHeadProvider() {
serving_.reset();
}
bool OnDeviceHeadProvider::IsOnDeviceHeadProviderAllowed( bool OnDeviceHeadProvider::IsOnDeviceHeadProviderAllowed(
const AutocompleteInput& input, const AutocompleteInput& input,
const std::string& incognito_serve_mode) { const std::string& incognito_serve_mode) {
...@@ -149,35 +158,45 @@ void OnDeviceHeadProvider::Start(const AutocompleteInput& input, ...@@ -149,35 +158,45 @@ void OnDeviceHeadProvider::Start(const AutocompleteInput& input,
return; return;
matches_.clear(); matches_.clear();
if (!input.text().empty() && serving_) { if (!input.text().empty()) {
done_ = false; base::PostTaskAndReplyWithResult(
// Note |on_device_search_request_id_| has already been changed in |Stop| worker_task_runner_.get(), FROM_HERE,
// so we don't need to change it again here to get a new id for this base::BindOnce(&OnDeviceHeadProvider::IsModelInstanceReady, this),
// request. base::BindOnce(&OnDeviceHeadProvider::StartInternal,
std::unique_ptr<OnDeviceHeadProviderParams> params = base::WrapUnique( weak_ptr_factory_.GetWeakPtr(), input));
new OnDeviceHeadProviderParams(on_device_search_request_id_, input)); }
}
// Since the On Device provider usually runs much faster than online
// providers, it will be very likely users will see on device suggestions void OnDeviceHeadProvider::StartInternal(const AutocompleteInput& input,
// first and then the Omnibox UI gets refreshed to show suggestions fetched bool is_model_instance_ready) {
// from server, if we issue both requests simultaneously. if (!is_model_instance_ready)
// Therefore, we might want to delay the On Device suggest requests (and return;
// also apply a timeout to search default loader) to mitigate this issue.
// Note this delay is not needed for incognito where server suggestion is done_ = false;
// not served. // Note |on_device_search_request_id_| has already been changed in |Stop|
int delay = 0; // so we don't need to change it again here to get a new id for this
if (!client()->IsOffTheRecord()) { // request.
delay = base::GetFieldTrialParamByFeatureAsInt( std::unique_ptr<OnDeviceHeadProviderParams> params = base::WrapUnique(
omnibox::kOnDeviceHeadProvider, "DelayOnDeviceHeadSuggestRequestMs", new OnDeviceHeadProviderParams(on_device_search_request_id_, input));
0);
} // Since the On Device provider usually runs much faster than online
task_runner_->PostDelayedTask( // providers, it will be very likely users will see on device suggestions
FROM_HERE, // first and then the Omnibox UI gets refreshed to show suggestions fetched
base::BindOnce(&OnDeviceHeadProvider::DoSearch, // from server, if we issue both requests simultaneously.
weak_ptr_factory_.GetWeakPtr(), std::move(params)), // Therefore, we might want to delay the On Device suggest requests (and
delay > 0 ? base::TimeDelta::FromMilliseconds(delay) // also apply a timeout to search default loader) to mitigate this issue.
: base::TimeDelta()); // Note this delay is not needed for incognito where server suggestion is
// not served.
int delay = 0;
if (!client()->IsOffTheRecord()) {
delay = base::GetFieldTrialParamByFeatureAsInt(
omnibox::kOnDeviceHeadProvider, "DelayOnDeviceHeadSuggestRequestMs", 0);
} }
base::SequencedTaskRunnerHandle::Get()->PostDelayedTask(
FROM_HERE,
base::BindOnce(&OnDeviceHeadProvider::DoSearch,
weak_ptr_factory_.GetWeakPtr(), std::move(params)),
delay > 0 ? base::TimeDelta::FromMilliseconds(delay) : base::TimeDelta());
} }
void OnDeviceHeadProvider::Stop(bool clear_cached_results, void OnDeviceHeadProvider::Stop(bool clear_cached_results,
...@@ -195,22 +214,41 @@ void OnDeviceHeadProvider::Stop(bool clear_cached_results, ...@@ -195,22 +214,41 @@ void OnDeviceHeadProvider::Stop(bool clear_cached_results,
void OnDeviceHeadProvider::OnModelUpdate( void OnDeviceHeadProvider::OnModelUpdate(
const std::string& new_model_filename) { const std::string& new_model_filename) {
if (new_model_filename != current_model_filename_ && worker_task_runner_->PostTask(
!new_model_filename.empty()) { FROM_HERE,
task_runner_->PostTask( base::BindOnce(&OnDeviceHeadProvider::ResetModelInstanceFromNewModel,
FROM_HERE, this, new_model_filename));
base::BindOnce(&OnDeviceHeadProvider::ResetServingInstanceFromNewModel,
weak_ptr_factory_.GetWeakPtr(), new_model_filename));
}
} }
void OnDeviceHeadProvider::ResetServingInstanceFromNewModel( void OnDeviceHeadProvider::ResetModelInstanceFromNewModel(
const std::string& new_model_filename) { const std::string& new_model_filename) {
if (new_model_filename.empty()) if (new_model_filename.empty())
return; return;
current_model_filename_ = new_model_filename; model_ =
serving_ = OnDeviceHeadServing::Create(current_model_filename_, OnDeviceHeadServing::Create(new_model_filename, provider_max_matches_);
provider_max_matches_); }
std::unique_ptr<OnDeviceHeadProvider::OnDeviceHeadProviderParams>
OnDeviceHeadProvider::GetSuggestionsFromModel(
std::unique_ptr<OnDeviceHeadProviderParams> params) {
if (!IsModelInstanceReady() || !params) {
if (params) {
params->failed = true;
}
return params;
}
params->creation_time = base::TimeTicks::Now();
base::string16 trimmed_input;
base::TrimWhitespace(params->input.text(), base::TRIM_ALL, &trimmed_input);
auto results = model_->GetSuggestionsForPrefix(
base::UTF16ToUTF8(base::i18n::ToLower(trimmed_input)));
params->suggestions.clear();
for (const auto& item : results) {
// The second member is the score which is not useful for provider.
params->suggestions.push_back(item.first);
}
return params;
} }
void OnDeviceHeadProvider::AddProviderInfo(ProvidersInfo* provider_info) const { void OnDeviceHeadProvider::AddProviderInfo(ProvidersInfo* provider_info) const {
...@@ -222,25 +260,17 @@ void OnDeviceHeadProvider::AddProviderInfo(ProvidersInfo* provider_info) const { ...@@ -222,25 +260,17 @@ void OnDeviceHeadProvider::AddProviderInfo(ProvidersInfo* provider_info) const {
void OnDeviceHeadProvider::DoSearch( void OnDeviceHeadProvider::DoSearch(
std::unique_ptr<OnDeviceHeadProviderParams> params) { std::unique_ptr<OnDeviceHeadProviderParams> params) {
if (serving_ && params && if (!params || params->request_id != on_device_search_request_id_) {
params->request_id == on_device_search_request_id_) { SearchDone(std::move(params));
params->creation_time = base::TimeTicks::Now(); return;
base::string16 trimmed_input;
base::TrimWhitespace(params->input.text(), base::TRIM_ALL, &trimmed_input);
auto results = serving_->GetSuggestionsForPrefix(
base::UTF16ToUTF8(base::i18n::ToLower(trimmed_input)));
params->suggestions.clear();
for (const auto& item : results) {
// The second member is the score which is not useful for provider.
params->suggestions.push_back(item.first);
}
} else {
params->failed = true;
} }
task_runner_->PostTask(
FROM_HERE, base::PostTaskAndReplyWithResult(
worker_task_runner_.get(), FROM_HERE,
base::BindOnce(&OnDeviceHeadProvider::GetSuggestionsFromModel, this,
std::move(params)),
base::BindOnce(&OnDeviceHeadProvider::SearchDone, base::BindOnce(&OnDeviceHeadProvider::SearchDone,
weak_ptr_factory_.GetWeakPtr(), std::move(params))); weak_ptr_factory_.GetWeakPtr()));
} }
void OnDeviceHeadProvider::SearchDone( void OnDeviceHeadProvider::SearchDone(
...@@ -251,10 +281,15 @@ void OnDeviceHeadProvider::SearchDone( ...@@ -251,10 +281,15 @@ void OnDeviceHeadProvider::SearchDone(
if (!params || params->request_id != on_device_search_request_id_) if (!params || params->request_id != on_device_search_request_id_)
return; return;
if (params->failed) {
done_ = true;
return;
}
const TemplateURLService* template_url_service = const TemplateURLService* template_url_service =
client()->GetTemplateURLService(); client()->GetTemplateURLService();
if (IsDefaultSearchProviderGoogle(template_url_service) && !params->failed) { if (IsDefaultSearchProviderGoogle(template_url_service)) {
UMA_HISTOGRAM_CUSTOM_COUNTS("Omnibox.OnDeviceHeadSuggest.ResultCount", UMA_HISTOGRAM_CUSTOM_COUNTS("Omnibox.OnDeviceHeadSuggest.ResultCount",
params->suggestions.size(), 1, 5, 6); params->suggestions.size(), 1, 5, 6);
matches_.clear(); matches_.clear();
......
...@@ -28,6 +28,12 @@ class OnDeviceHeadProvider : public AutocompleteProvider { ...@@ -28,6 +28,12 @@ class OnDeviceHeadProvider : public AutocompleteProvider {
static OnDeviceHeadProvider* Create(AutocompleteProviderClient* client, static OnDeviceHeadProvider* Create(AutocompleteProviderClient* client,
AutocompleteProviderListener* listener); AutocompleteProviderListener* listener);
// Adds a callback to on device head model updater listener which will create
// the model instance |model_| once the model is ready on disk.
// This function should not be called until the provider has at least one
// reference to avoid the binding error.
void AddModelUpdateCallback();
void Start(const AutocompleteInput& input, bool minimal_changes) override; void Start(const AutocompleteInput& input, bool minimal_changes) override;
void Stop(bool clear_cached_results, bool due_to_user_inactivity) override; void Stop(bool clear_cached_results, bool due_to_user_inactivity) override;
void AddProviderInfo(ProvidersInfo* provider_info) const override; void AddProviderInfo(ProvidersInfo* provider_info) const override;
...@@ -48,6 +54,9 @@ class OnDeviceHeadProvider : public AutocompleteProvider { ...@@ -48,6 +54,9 @@ class OnDeviceHeadProvider : public AutocompleteProvider {
bool IsOnDeviceHeadProviderAllowed(const AutocompleteInput& input, bool IsOnDeviceHeadProviderAllowed(const AutocompleteInput& input,
const std::string& incognito_serve_mode); const std::string& incognito_serve_mode);
void StartInternal(const AutocompleteInput& input,
bool is_model_instance_ready);
// Helper functions used for asynchronous search to the on device head model. // Helper functions used for asynchronous search to the on device head model.
// The Autocomplete input and output from the model will be passed from // The Autocomplete input and output from the model will be passed from
// DoSearch to SearchDone via the OnDeviceHeadProviderParams object. // DoSearch to SearchDone via the OnDeviceHeadProviderParams object.
...@@ -62,30 +71,34 @@ class OnDeviceHeadProvider : public AutocompleteProvider { ...@@ -62,30 +71,34 @@ class OnDeviceHeadProvider : public AutocompleteProvider {
// is available. // is available.
void OnModelUpdate(const std::string& new_model_filename); void OnModelUpdate(const std::string& new_model_filename);
// Resets |serving_| if new model is available and cleans up the old model if // Resets |model_| if new model is available and cleans up the old model if
// it exists. // it exists.
void ResetServingInstanceFromNewModel(const std::string& new_model_filename); void ResetModelInstanceFromNewModel(const std::string& new_model_filename);
bool IsModelInstanceReady() const { return model_ != nullptr; }
// Fetches suggestions matching the params from the given on device head
// model instance.
std::unique_ptr<OnDeviceHeadProviderParams> GetSuggestionsFromModel(
std::unique_ptr<OnDeviceHeadProviderParams> params);
AutocompleteProviderClient* client_; AutocompleteProviderClient* client_;
AutocompleteProviderListener* listener_; AutocompleteProviderListener* listener_;
// The instance which does the search in the head model and returns top // The task runner dedicated for on device head model operations (including
// suggestions matching the Autocomplete input. // model instance creation, deletion and suggestions fetching) which is added
std::unique_ptr<OnDeviceHeadServing> serving_; // to offload expensive operations out of the UI sequence.
scoped_refptr<base::SequencedTaskRunner> worker_task_runner_;
// The task runner instance where asynchronous operations using |serving_| // The model instance which serves top suggestions matching the Autocomplete
// will be run. Note that SequencedTaskRunner guarantees that operations will // input and is only accessed in |worker_task_runner_|.
// be executed in sequence so we don't need to apply a lock to |serving_|. std::unique_ptr<OnDeviceHeadServing> model_;
scoped_refptr<base::SequencedTaskRunner> task_runner_;
// The request id used to trace current request to the on device head model. // The request id used to trace current request to the on device head model.
// The id will be increased whenever a new request is received from the // The id will be increased whenever a new request is received from the
// AutocompleteController. // AutocompleteController.
size_t on_device_search_request_id_; size_t on_device_search_request_id_;
// The filename for the on device model currently being used.
std::string current_model_filename_;
// Owns the subscription after adding the model update callback to the // Owns the subscription after adding the model update callback to the
// listener such that the callback can be removed automatically from the // listener such that the callback can be removed automatically from the
// listener on provider's deconstruction. // listener on provider's deconstruction.
......
...@@ -29,6 +29,7 @@ class OnDeviceHeadProviderTest : public testing::Test, ...@@ -29,6 +29,7 @@ class OnDeviceHeadProviderTest : public testing::Test,
client_.reset(new FakeAutocompleteProviderClient()); client_.reset(new FakeAutocompleteProviderClient());
SetTestOnDeviceHeadModel(); SetTestOnDeviceHeadModel();
provider_ = OnDeviceHeadProvider::Create(client_.get(), this); provider_ = OnDeviceHeadProvider::Create(client_.get(), this);
provider_->AddModelUpdateCallback();
task_environment_.RunUntilIdle(); task_environment_.RunUntilIdle();
} }
...@@ -55,10 +56,9 @@ class OnDeviceHeadProviderTest : public testing::Test, ...@@ -55,10 +56,9 @@ class OnDeviceHeadProviderTest : public testing::Test,
task_environment_.RunUntilIdle(); task_environment_.RunUntilIdle();
} }
void ResetServingInstance() { void ResetModelInstance() {
if (provider_) { if (provider_) {
provider_->serving_.reset(); provider_->model_.reset();
provider_->current_model_filename_.clear();
} }
} }
...@@ -73,12 +73,12 @@ class OnDeviceHeadProviderTest : public testing::Test, ...@@ -73,12 +73,12 @@ class OnDeviceHeadProviderTest : public testing::Test,
scoped_refptr<OnDeviceHeadProvider> provider_; scoped_refptr<OnDeviceHeadProvider> provider_;
}; };
TEST_F(OnDeviceHeadProviderTest, ServingInstanceNotCreated) { TEST_F(OnDeviceHeadProviderTest, ModelInstanceNotCreated) {
AutocompleteInput input(base::UTF8ToUTF16("M"), AutocompleteInput input(base::UTF8ToUTF16("M"),
metrics::OmniboxEventProto::OTHER, metrics::OmniboxEventProto::OTHER,
TestSchemeClassifier()); TestSchemeClassifier());
input.set_want_asynchronous_matches(true); input.set_want_asynchronous_matches(true);
ResetServingInstance(); ResetModelInstance();
EXPECT_CALL(*client_.get(), IsOffTheRecord()).WillRepeatedly(Return(false)); EXPECT_CALL(*client_.get(), IsOffTheRecord()).WillRepeatedly(Return(false));
EXPECT_CALL(*client_.get(), SearchSuggestEnabled()) EXPECT_CALL(*client_.get(), SearchSuggestEnabled())
...@@ -87,8 +87,7 @@ TEST_F(OnDeviceHeadProviderTest, ServingInstanceNotCreated) { ...@@ -87,8 +87,7 @@ TEST_F(OnDeviceHeadProviderTest, ServingInstanceNotCreated) {
ASSERT_TRUE(IsOnDeviceHeadProviderAllowed(input, "")); ASSERT_TRUE(IsOnDeviceHeadProviderAllowed(input, ""));
provider_->Start(input, false); provider_->Start(input, false);
if (!provider_->done()) task_environment_.RunUntilIdle();
task_environment_.RunUntilIdle();
EXPECT_TRUE(provider_->matches().empty()); EXPECT_TRUE(provider_->matches().empty());
EXPECT_TRUE(provider_->done()); EXPECT_TRUE(provider_->done());
...@@ -149,8 +148,7 @@ TEST_F(OnDeviceHeadProviderTest, NoMatches) { ...@@ -149,8 +148,7 @@ TEST_F(OnDeviceHeadProviderTest, NoMatches) {
ASSERT_TRUE(IsOnDeviceHeadProviderAllowed(input, "")); ASSERT_TRUE(IsOnDeviceHeadProviderAllowed(input, ""));
provider_->Start(input, false); provider_->Start(input, false);
if (!provider_->done()) task_environment_.RunUntilIdle();
task_environment_.RunUntilIdle();
EXPECT_TRUE(provider_->matches().empty()); EXPECT_TRUE(provider_->matches().empty());
EXPECT_TRUE(provider_->done()); EXPECT_TRUE(provider_->done());
...@@ -169,8 +167,7 @@ TEST_F(OnDeviceHeadProviderTest, HasMatches) { ...@@ -169,8 +167,7 @@ TEST_F(OnDeviceHeadProviderTest, HasMatches) {
ASSERT_TRUE(IsOnDeviceHeadProviderAllowed(input, "")); ASSERT_TRUE(IsOnDeviceHeadProviderAllowed(input, ""));
provider_->Start(input, false); provider_->Start(input, false);
if (!provider_->done()) task_environment_.RunUntilIdle();
task_environment_.RunUntilIdle();
EXPECT_TRUE(provider_->done()); EXPECT_TRUE(provider_->done());
ASSERT_EQ(3U, provider_->matches().size()); ASSERT_EQ(3U, provider_->matches().size());
...@@ -197,11 +194,8 @@ TEST_F(OnDeviceHeadProviderTest, CancelInProgressRequest) { ...@@ -197,11 +194,8 @@ TEST_F(OnDeviceHeadProviderTest, CancelInProgressRequest) {
ASSERT_TRUE(IsOnDeviceHeadProviderAllowed(input2, "")); ASSERT_TRUE(IsOnDeviceHeadProviderAllowed(input2, ""));
provider_->Start(input1, false); provider_->Start(input1, false);
EXPECT_FALSE(provider_->done());
provider_->Start(input2, false); provider_->Start(input2, false);
task_environment_.RunUntilIdle();
if (!provider_->done())
task_environment_.RunUntilIdle();
EXPECT_TRUE(provider_->done()); EXPECT_TRUE(provider_->done());
ASSERT_EQ(3U, provider_->matches().size()); ASSERT_EQ(3U, provider_->matches().size());
......
...@@ -89,12 +89,13 @@ OnDeviceHeadServing::GetSuggestionsForPrefix(const std::string& prefix) { ...@@ -89,12 +89,13 @@ OnDeviceHeadServing::GetSuggestionsForPrefix(const std::string& prefix) {
return suggestions; return suggestions;
} }
OpenModelFileStream(kRootNodeOffset); if (OpenModelFileStream(kRootNodeOffset)) {
MatchCandidate start_match; MatchCandidate start_match;
if (FindStartNode(prefix, &start_match)) { if (FindStartNode(prefix, &start_match)) {
suggestions = DoSearch(start_match); suggestions = DoSearch(start_match);
}
MaybeCloseModelFileStream();
} }
MaybeCloseModelFileStream();
return suggestions; return suggestions;
} }
......
...@@ -66,6 +66,7 @@ ...@@ -66,6 +66,7 @@
// The size of score and address will be given in the first two bytes of the // The size of score and address will be given in the first two bytes of the
// model file. // model file.
// TODO(crbug.com/925072): rename OnDeviceHeadServing to *Model.
class OnDeviceHeadServing { class OnDeviceHeadServing {
public: public:
// Creates and returns an instance for serving on device head model. // Creates and returns an instance for serving on device head model.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment