Commit a95db0a3 authored by alanlxl's avatar alanlxl Committed by Chromium LUCI CQ

Poke ServiceConnection::Initialize in PostBrowserStart

To make sure ServiceConnection is bound to UI sequence task_runner.

Bug: chromium:916760
Test: pass the unit_test
Change-Id: Ia03094fca3dcaa86774520f681b522f77508d4e6
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2631805Reviewed-by: default avatarXiyuan Xia <xiyuan@chromium.org>
Reviewed-by: default avatarXinglong Luan <alanlxl@chromium.org>
Reviewed-by: default avatarAndrew Moylan <amoylan@chromium.org>
Commit-Queue: Xinglong Luan <alanlxl@chromium.org>
Cr-Commit-Position: refs/heads/master@{#845344}
parent 10ece731
...@@ -176,6 +176,7 @@ ...@@ -176,6 +176,7 @@
#include "chromeos/network/network_handler.h" #include "chromeos/network/network_handler.h"
#include "chromeos/network/portal_detector/network_portal_detector_stub.h" #include "chromeos/network/portal_detector/network_portal_detector_stub.h"
#include "chromeos/services/cros_healthd/public/cpp/service_connection.h" #include "chromeos/services/cros_healthd/public/cpp/service_connection.h"
#include "chromeos/services/machine_learning/public/cpp/service_connection.h"
#include "chromeos/system/statistics_provider.h" #include "chromeos/system/statistics_provider.h"
#include "chromeos/tpm/install_attributes.h" #include "chromeos/tpm/install_attributes.h"
#include "chromeos/tpm/tpm_token_loader.h" #include "chromeos/tpm/tpm_token_loader.h"
...@@ -1120,6 +1121,8 @@ void ChromeBrowserMainPartsChromeos::PostBrowserStart() { ...@@ -1120,6 +1121,8 @@ void ChromeBrowserMainPartsChromeos::PostBrowserStart() {
dark_resume_controller_ = std::make_unique<system::DarkResumeController>( dark_resume_controller_ = std::make_unique<system::DarkResumeController>(
std::move(wake_lock_provider)); std::move(wake_lock_provider));
chromeos::machine_learning::ServiceConnection::GetInstance()->Initialize();
ChromeBrowserMainPartsLinux::PostBrowserStart(); ChromeBrowserMainPartsLinux::PostBrowserStart();
} }
......
...@@ -34,6 +34,8 @@ void FakeServiceConnectionImpl::Clone( ...@@ -34,6 +34,8 @@ void FakeServiceConnectionImpl::Clone(
clone_ml_service_receivers_.Add(this, std::move(receiver)); clone_ml_service_receivers_.Add(this, std::move(receiver));
} }
void FakeServiceConnectionImpl::Initialize() {}
void FakeServiceConnectionImpl::LoadBuiltinModel( void FakeServiceConnectionImpl::LoadBuiltinModel(
mojom::BuiltinModelSpecPtr spec, mojom::BuiltinModelSpecPtr spec,
mojo::PendingReceiver<mojom::Model> receiver, mojo::PendingReceiver<mojom::Model> receiver,
......
...@@ -49,6 +49,7 @@ class FakeServiceConnectionImpl : public ServiceConnection, ...@@ -49,6 +49,7 @@ class FakeServiceConnectionImpl : public ServiceConnection,
// ServiceConnection: // ServiceConnection:
void BindMachineLearningService( void BindMachineLearningService(
mojo::PendingReceiver<mojom::MachineLearningService> receiver) override; mojo::PendingReceiver<mojom::MachineLearningService> receiver) override;
void Initialize() override;
// mojom::MachineLearningService: // mojom::MachineLearningService:
void Clone( void Clone(
...@@ -228,8 +229,9 @@ class FakeServiceConnectionImpl : public ServiceConnection, ...@@ -228,8 +229,9 @@ class FakeServiceConnectionImpl : public ServiceConnection,
void HandleLoadGrammarCheckerCall( void HandleLoadGrammarCheckerCall(
mojo::PendingReceiver<mojom::GrammarChecker> receiver, mojo::PendingReceiver<mojom::GrammarChecker> receiver,
mojom::MachineLearningService::LoadGrammarCheckerCallback callback); mojom::MachineLearningService::LoadGrammarCheckerCallback callback);
void HandleGrammarCheckerQueryCall(mojom::GrammarCheckerQueryPtr query, void HandleGrammarCheckerQueryCall(
mojom::GrammarChecker::CheckCallback callback); mojom::GrammarCheckerQueryPtr query,
mojom::GrammarChecker::CheckCallback callback);
void HandleLoadSpeechRecognizerCall( void HandleLoadSpeechRecognizerCall(
mojo::PendingRemote<mojom::SodaClient> soda_client, mojo::PendingRemote<mojom::SodaClient> soda_client,
mojo::PendingReceiver<mojom::SodaRecognizer> soda_recognizer, mojo::PendingReceiver<mojom::SodaRecognizer> soda_recognizer,
......
...@@ -31,6 +31,8 @@ class ServiceConnectionImpl : public ServiceConnection { ...@@ -31,6 +31,8 @@ class ServiceConnectionImpl : public ServiceConnection {
void BindMachineLearningService( void BindMachineLearningService(
mojo::PendingReceiver<mojom::MachineLearningService> receiver) override; mojo::PendingReceiver<mojom::MachineLearningService> receiver) override;
void Initialize() override;
void LoadBuiltinModel(mojom::BuiltinModelSpecPtr spec, void LoadBuiltinModel(mojom::BuiltinModelSpecPtr spec,
mojo::PendingReceiver<mojom::Model> receiver, mojo::PendingReceiver<mojom::Model> receiver,
mojom::MachineLearningService::LoadBuiltinModelCallback mojom::MachineLearningService::LoadBuiltinModelCallback
...@@ -72,10 +74,10 @@ class ServiceConnectionImpl : public ServiceConnection { ...@@ -72,10 +74,10 @@ class ServiceConnectionImpl : public ServiceConnection {
override; override;
private: private:
// Binds the top level interface |machine_learning_service_| to an // Binds the primordial, top-level interface |machine_learning_service_| to an
// implementation in the ML Service daemon, if it is not already bound. The // implementation in the ML Service daemon, if it is not already bound. The
// binding is accomplished via D-Bus bootstrap. // binding is accomplished via D-Bus bootstrap.
void BindMachineLearningServiceIfNeeded(); void BindPrimordialMachineLearningServiceIfNeeded();
// Mojo disconnect handler. Resets |machine_learning_service_|, which // Mojo disconnect handler. Resets |machine_learning_service_|, which
// will be reconnected upon next use. // will be reconnected upon next use.
...@@ -94,6 +96,8 @@ class ServiceConnectionImpl : public ServiceConnection { ...@@ -94,6 +96,8 @@ class ServiceConnectionImpl : public ServiceConnection {
void ServiceConnectionImpl::BindMachineLearningService( void ServiceConnectionImpl::BindMachineLearningService(
mojo::PendingReceiver<mojom::MachineLearningService> receiver) { mojo::PendingReceiver<mojom::MachineLearningService> receiver) {
DCHECK(task_runner_)
<< "Call Initialize before first use of ServiceConnection.";
if (!task_runner_->RunsTasksInCurrentSequence()) { if (!task_runner_->RunsTasksInCurrentSequence()) {
task_runner_->PostTask( task_runner_->PostTask(
FROM_HERE, FROM_HERE,
...@@ -102,10 +106,17 @@ void ServiceConnectionImpl::BindMachineLearningService( ...@@ -102,10 +106,17 @@ void ServiceConnectionImpl::BindMachineLearningService(
return; return;
} }
BindMachineLearningServiceIfNeeded(); BindPrimordialMachineLearningServiceIfNeeded();
machine_learning_service_->Clone(std::move(receiver)); machine_learning_service_->Clone(std::move(receiver));
} }
void ServiceConnectionImpl::Initialize() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
DCHECK(!task_runner_) << "Initialize must be called only once.";
task_runner_ = base::SequencedTaskRunnerHandle::Get();
}
void ServiceConnectionImpl::LoadBuiltinModel( void ServiceConnectionImpl::LoadBuiltinModel(
mojom::BuiltinModelSpecPtr spec, mojom::BuiltinModelSpecPtr spec,
mojo::PendingReceiver<mojom::Model> receiver, mojo::PendingReceiver<mojom::Model> receiver,
...@@ -119,7 +130,7 @@ void ServiceConnectionImpl::LoadBuiltinModel( ...@@ -119,7 +130,7 @@ void ServiceConnectionImpl::LoadBuiltinModel(
return; return;
} }
BindMachineLearningServiceIfNeeded(); BindPrimordialMachineLearningServiceIfNeeded();
machine_learning_service_->LoadBuiltinModel( machine_learning_service_->LoadBuiltinModel(
std::move(spec), std::move(receiver), std::move(result_callback)); std::move(spec), std::move(receiver), std::move(result_callback));
} }
...@@ -138,7 +149,7 @@ void ServiceConnectionImpl::LoadFlatBufferModel( ...@@ -138,7 +149,7 @@ void ServiceConnectionImpl::LoadFlatBufferModel(
return; return;
} }
BindMachineLearningServiceIfNeeded(); BindPrimordialMachineLearningServiceIfNeeded();
machine_learning_service_->LoadFlatBufferModel( machine_learning_service_->LoadFlatBufferModel(
std::move(spec), std::move(receiver), std::move(result_callback)); std::move(spec), std::move(receiver), std::move(result_callback));
} }
...@@ -154,7 +165,7 @@ void ServiceConnectionImpl::LoadTextClassifier( ...@@ -154,7 +165,7 @@ void ServiceConnectionImpl::LoadTextClassifier(
return; return;
} }
BindMachineLearningServiceIfNeeded(); BindPrimordialMachineLearningServiceIfNeeded();
machine_learning_service_->LoadTextClassifier(std::move(receiver), machine_learning_service_->LoadTextClassifier(std::move(receiver),
std::move(result_callback)); std::move(result_callback));
} }
...@@ -173,7 +184,7 @@ void ServiceConnectionImpl::LoadHandwritingModel( ...@@ -173,7 +184,7 @@ void ServiceConnectionImpl::LoadHandwritingModel(
return; return;
} }
BindMachineLearningServiceIfNeeded(); BindPrimordialMachineLearningServiceIfNeeded();
machine_learning_service_->LoadHandwritingModel( machine_learning_service_->LoadHandwritingModel(
std::move(spec), std::move(receiver), std::move(result_callback)); std::move(spec), std::move(receiver), std::move(result_callback));
} }
...@@ -192,7 +203,7 @@ void ServiceConnectionImpl::LoadHandwritingModelWithSpec( ...@@ -192,7 +203,7 @@ void ServiceConnectionImpl::LoadHandwritingModelWithSpec(
return; return;
} }
BindMachineLearningServiceIfNeeded(); BindPrimordialMachineLearningServiceIfNeeded();
machine_learning_service_->LoadHandwritingModelWithSpec( machine_learning_service_->LoadHandwritingModelWithSpec(
std::move(spec), std::move(receiver), std::move(result_callback)); std::move(spec), std::move(receiver), std::move(result_callback));
} }
...@@ -208,7 +219,7 @@ void ServiceConnectionImpl::LoadGrammarChecker( ...@@ -208,7 +219,7 @@ void ServiceConnectionImpl::LoadGrammarChecker(
return; return;
} }
BindMachineLearningServiceIfNeeded(); BindPrimordialMachineLearningServiceIfNeeded();
machine_learning_service_->LoadGrammarChecker(std::move(receiver), machine_learning_service_->LoadGrammarChecker(std::move(receiver),
std::move(result_callback)); std::move(result_callback));
} }
...@@ -228,13 +239,13 @@ void ServiceConnectionImpl::LoadSpeechRecognizer( ...@@ -228,13 +239,13 @@ void ServiceConnectionImpl::LoadSpeechRecognizer(
return; return;
} }
BindMachineLearningServiceIfNeeded(); BindPrimordialMachineLearningServiceIfNeeded();
machine_learning_service_->LoadSpeechRecognizer( machine_learning_service_->LoadSpeechRecognizer(
std::move(soda_config), std::move(soda_client), std::move(soda_config), std::move(soda_client),
std::move(soda_recognizer), std::move(callback)); std::move(soda_recognizer), std::move(callback));
} }
void ServiceConnectionImpl::BindMachineLearningServiceIfNeeded() { void ServiceConnectionImpl::BindPrimordialMachineLearningServiceIfNeeded() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
if (machine_learning_service_) { if (machine_learning_service_) {
return; return;
...@@ -267,8 +278,7 @@ void ServiceConnectionImpl::BindMachineLearningServiceIfNeeded() { ...@@ -267,8 +278,7 @@ void ServiceConnectionImpl::BindMachineLearningServiceIfNeeded() {
base::Unretained(this))); base::Unretained(this)));
} }
ServiceConnectionImpl::ServiceConnectionImpl() ServiceConnectionImpl::ServiceConnectionImpl() {
: task_runner_(base::SequencedTaskRunnerHandle::Get()) {
DETACH_FROM_SEQUENCE(sequence_checker_); DETACH_FROM_SEQUENCE(sequence_checker_);
} }
......
...@@ -53,10 +53,15 @@ class ServiceConnection { ...@@ -53,10 +53,15 @@ class ServiceConnection {
static void UseFakeServiceConnectionForTesting( static void UseFakeServiceConnectionForTesting(
ServiceConnection* fake_service_connection); ServiceConnection* fake_service_connection);
// Binds the receiver to the implementation in the ml_service daemon. // Binds the receiver to a Clone of the primordial top-level interface.
// May be called from any sequence.
virtual void BindMachineLearningService( virtual void BindMachineLearningService(
mojo::PendingReceiver<mojom::MachineLearningService> receiver) = 0; mojo::PendingReceiver<mojom::MachineLearningService> receiver) = 0;
// Call this once at startup (e.g. PostBrowserStart) on the sequence that
// should own the Mojo connection to MachineLearningService (e.g. UI thread).
virtual void Initialize() = 0;
// Instruct ML daemon to load the builtin model specified in |spec|, binding a // Instruct ML daemon to load the builtin model specified in |spec|, binding a
// Model implementation to |receiver|. Bootstraps the initial Mojo connection // Model implementation to |receiver|. Bootstraps the initial Mojo connection
// to the daemon if necessary. // to the daemon if necessary.
......
...@@ -57,6 +57,7 @@ class ServiceConnectionTest : public testing::Test { ...@@ -57,6 +57,7 @@ class ServiceConnectionTest : public testing::Test {
// Tests that LoadBuiltinModel runs OK (no crash) in a basic Mojo // Tests that LoadBuiltinModel runs OK (no crash) in a basic Mojo
// environment. // environment.
TEST_F(ServiceConnectionTest, LoadBuiltinModel) { TEST_F(ServiceConnectionTest, LoadBuiltinModel) {
ServiceConnection::GetInstance()->Initialize();
mojo::Remote<mojom::Model> model; mojo::Remote<mojom::Model> model;
mojom::BuiltinModelSpecPtr spec = mojom::BuiltinModelSpecPtr spec =
mojom::BuiltinModelSpec::New(mojom::BuiltinModelId::TEST_MODEL); mojom::BuiltinModelSpec::New(mojom::BuiltinModelId::TEST_MODEL);
...@@ -68,6 +69,7 @@ TEST_F(ServiceConnectionTest, LoadBuiltinModel) { ...@@ -68,6 +69,7 @@ TEST_F(ServiceConnectionTest, LoadBuiltinModel) {
// Tests that LoadFlatBufferModel runs OK (no crash) in a basic Mojo // Tests that LoadFlatBufferModel runs OK (no crash) in a basic Mojo
// environment. // environment.
TEST_F(ServiceConnectionTest, LoadFlatBufferModel) { TEST_F(ServiceConnectionTest, LoadFlatBufferModel) {
ServiceConnection::GetInstance()->Initialize();
mojo::Remote<mojom::Model> model; mojo::Remote<mojom::Model> model;
mojom::FlatBufferModelSpecPtr spec = mojom::FlatBufferModelSpec::New(); mojom::FlatBufferModelSpecPtr spec = mojom::FlatBufferModelSpec::New();
ServiceConnection::GetInstance()->LoadFlatBufferModel( ServiceConnection::GetInstance()->LoadFlatBufferModel(
...@@ -78,6 +80,7 @@ TEST_F(ServiceConnectionTest, LoadFlatBufferModel) { ...@@ -78,6 +80,7 @@ TEST_F(ServiceConnectionTest, LoadFlatBufferModel) {
// Tests that LoadTextClassifier runs OK (no crash) in a basic Mojo // Tests that LoadTextClassifier runs OK (no crash) in a basic Mojo
// environment. // environment.
TEST_F(ServiceConnectionTest, LoadTextClassifier) { TEST_F(ServiceConnectionTest, LoadTextClassifier) {
ServiceConnection::GetInstance()->Initialize();
mojo::Remote<mojom::TextClassifier> text_classifier; mojo::Remote<mojom::TextClassifier> text_classifier;
ServiceConnection::GetInstance()->LoadTextClassifier( ServiceConnection::GetInstance()->LoadTextClassifier(
text_classifier.BindNewPipeAndPassReceiver(), text_classifier.BindNewPipeAndPassReceiver(),
...@@ -87,6 +90,7 @@ TEST_F(ServiceConnectionTest, LoadTextClassifier) { ...@@ -87,6 +90,7 @@ TEST_F(ServiceConnectionTest, LoadTextClassifier) {
// Tests that LoadHandwritingModelWithSpec runs OK (no crash) in a basic Mojo // Tests that LoadHandwritingModelWithSpec runs OK (no crash) in a basic Mojo
// environment. // environment.
TEST_F(ServiceConnectionTest, LoadHandwritingModelWithSpec) { TEST_F(ServiceConnectionTest, LoadHandwritingModelWithSpec) {
ServiceConnection::GetInstance()->Initialize();
mojo::Remote<mojom::HandwritingRecognizer> handwriting_recognizer; mojo::Remote<mojom::HandwritingRecognizer> handwriting_recognizer;
ServiceConnection::GetInstance()->LoadHandwritingModelWithSpec( ServiceConnection::GetInstance()->LoadHandwritingModelWithSpec(
mojom::HandwritingRecognizerSpec::New("en"), mojom::HandwritingRecognizerSpec::New("en"),
...@@ -96,6 +100,7 @@ TEST_F(ServiceConnectionTest, LoadHandwritingModelWithSpec) { ...@@ -96,6 +100,7 @@ TEST_F(ServiceConnectionTest, LoadHandwritingModelWithSpec) {
// Tests that LoadGrammarChecker runs OK (no crash) in a basic Mojo environment. // Tests that LoadGrammarChecker runs OK (no crash) in a basic Mojo environment.
TEST_F(ServiceConnectionTest, LoadGrammarModel) { TEST_F(ServiceConnectionTest, LoadGrammarModel) {
ServiceConnection::GetInstance()->Initialize();
mojo::Remote<mojom::GrammarChecker> grammar_checker; mojo::Remote<mojom::GrammarChecker> grammar_checker;
ServiceConnection::GetInstance()->LoadGrammarChecker( ServiceConnection::GetInstance()->LoadGrammarChecker(
grammar_checker.BindNewPipeAndPassReceiver(), grammar_checker.BindNewPipeAndPassReceiver(),
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment