Commit e9be6c5d authored by alanlxl's avatar alanlxl Committed by Chromium LUCI CQ

Make ServiceConnection thread_safe

1. Add a task_runner_ to ServiceConnectionImpl, it's initialized when
   ServiceConnection::GetInstance() is first called. All the calls to
   top-level ml service interfaces will run on this task_runner_. Now
   they can be called from any sequence.
2. Add BindMachineLearningServiceReceiver, customers can use it to bind
   their own remote and call ml service interfaces via it. Eventually
   clients should all use this method, rather than any of existing
   public methods, which will become deprecated.

Bug: chromium:916760
Test: pass the unit_test
Change-Id: I7e6e693289cd8d29bbb4bde8a949af7374855139
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2626950Reviewed-by: default avatarAndrew Moylan <amoylan@chromium.org>
Reviewed-by: default avatarHonglin Yu <honglinyu@chromium.org>
Commit-Queue: Xinglong Luan <alanlxl@chromium.org>
Cr-Commit-Position: refs/heads/master@{#844616}
parent 2fb17d86
......@@ -24,6 +24,16 @@ FakeServiceConnectionImpl::FakeServiceConnectionImpl()
FakeServiceConnectionImpl::~FakeServiceConnectionImpl() {}
void FakeServiceConnectionImpl::BindMachineLearningService(
mojo::PendingReceiver<mojom::MachineLearningService> receiver) {
Clone(std::move(receiver));
}
void FakeServiceConnectionImpl::Clone(
mojo::PendingReceiver<mojom::MachineLearningService> receiver) {
clone_ml_service_receivers_.Add(this, std::move(receiver));
}
void FakeServiceConnectionImpl::LoadBuiltinModel(
mojom::BuiltinModelSpecPtr spec,
mojo::PendingReceiver<mojom::Model> receiver,
......@@ -279,9 +289,9 @@ void FakeServiceConnectionImpl::SetOutputGrammarCheckerResult(
void FakeServiceConnectionImpl::Annotate(
mojom::TextAnnotationRequestPtr request,
mojom::TextClassifier::AnnotateCallback callback) {
ScheduleCall(base::BindOnce(
&FakeServiceConnectionImpl::HandleAnnotateCall,
base::Unretained(this), std::move(request), std::move(callback)));
ScheduleCall(base::BindOnce(&FakeServiceConnectionImpl::HandleAnnotateCall,
base::Unretained(this), std::move(request),
std::move(callback)));
}
void FakeServiceConnectionImpl::SuggestSelection(
......@@ -295,9 +305,9 @@ void FakeServiceConnectionImpl::SuggestSelection(
void FakeServiceConnectionImpl::FindLanguages(
const std::string& text,
mojom::TextClassifier::FindLanguagesCallback callback) {
ScheduleCall(base::BindOnce(
&FakeServiceConnectionImpl::HandleFindLanguagesCall,
base::Unretained(this), text, std::move(callback)));
ScheduleCall(
base::BindOnce(&FakeServiceConnectionImpl::HandleFindLanguagesCall,
base::Unretained(this), text, std::move(callback)));
}
void FakeServiceConnectionImpl::Recognize(
......
......@@ -14,6 +14,7 @@
#include "chromeos/services/machine_learning/public/mojom/grammar_checker.mojom.h"
#include "chromeos/services/machine_learning/public/mojom/graph_executor.mojom.h"
#include "chromeos/services/machine_learning/public/mojom/handwriting_recognizer.mojom.h"
#include "chromeos/services/machine_learning/public/mojom/machine_learning_service.mojom.h"
#include "chromeos/services/machine_learning/public/mojom/model.mojom.h"
#include "chromeos/services/machine_learning/public/mojom/tensor.mojom.h"
#include "chromeos/services/machine_learning/public/mojom/text_classifier.mojom.h"
......@@ -34,6 +35,7 @@ namespace machine_learning {
// specified by a previous call to SetOutputSelection.
// For use with ServiceConnection::UseFakeServiceConnectionForTesting().
class FakeServiceConnectionImpl : public ServiceConnection,
public mojom::MachineLearningService,
public mojom::Model,
public mojom::TextClassifier,
public mojom::HandwritingRecognizer,
......@@ -44,6 +46,15 @@ class FakeServiceConnectionImpl : public ServiceConnection,
FakeServiceConnectionImpl();
~FakeServiceConnectionImpl() override;
// ServiceConnection:
void BindMachineLearningService(
mojo::PendingReceiver<mojom::MachineLearningService> receiver) override;
// mojom::MachineLearningService:
void Clone(
mojo::PendingReceiver<mojom::MachineLearningService> receiver) override;
// mojom::MachineLearningService and ServiceConnection:
// It's safe to execute LoadBuiltinModel, LoadFlatBufferModel and
// LoadTextClassifier for multi times, but all the receivers will be bound to
// the same instance.
......@@ -228,6 +239,9 @@ class FakeServiceConnectionImpl : public ServiceConnection,
void HandleStartCall();
void HandleMarkDoneCall();
// Additional receivers bound via `Clone`.
mojo::ReceiverSet<mojom::MachineLearningService> clone_ml_service_receivers_;
mojo::ReceiverSet<mojom::Model> model_receivers_;
mojo::ReceiverSet<mojom::GraphExecutor> graph_receivers_;
mojo::ReceiverSet<mojom::TextClassifier> text_classifier_receivers_;
......
......@@ -28,6 +28,9 @@ class ServiceConnectionImpl : public ServiceConnection {
ServiceConnectionImpl();
~ServiceConnectionImpl() override = default;
void BindMachineLearningService(
mojo::PendingReceiver<mojom::MachineLearningService> receiver) override;
void LoadBuiltinModel(mojom::BuiltinModelSpecPtr spec,
mojo::PendingReceiver<mojom::Model> receiver,
mojom::MachineLearningService::LoadBuiltinModelCallback
......@@ -41,8 +44,8 @@ class ServiceConnectionImpl : public ServiceConnection {
void LoadTextClassifier(
mojo::PendingReceiver<mojom::TextClassifier> receiver,
mojom::MachineLearningService::LoadTextClassifierCallback
result_callback) override;
mojom::MachineLearningService::LoadTextClassifierCallback result_callback)
override;
void LoadHandwritingModel(
mojom::HandwritingRecognizerSpecPtr spec,
......@@ -82,17 +85,40 @@ class ServiceConnectionImpl : public ServiceConnection {
void OnBootstrapMojoConnectionResponse(bool success);
mojo::Remote<mojom::MachineLearningService> machine_learning_service_;
scoped_refptr<base::SequencedTaskRunner> task_runner_;
SEQUENCE_CHECKER(sequence_checker_);
DISALLOW_COPY_AND_ASSIGN(ServiceConnectionImpl);
};
void ServiceConnectionImpl::BindMachineLearningService(
mojo::PendingReceiver<mojom::MachineLearningService> receiver) {
if (!task_runner_->RunsTasksInCurrentSequence()) {
task_runner_->PostTask(
FROM_HERE,
base::BindOnce(&ServiceConnectionImpl::BindMachineLearningService,
base::Unretained(this), std::move(receiver)));
return;
}
BindMachineLearningServiceIfNeeded();
machine_learning_service_->Clone(std::move(receiver));
}
void ServiceConnectionImpl::LoadBuiltinModel(
mojom::BuiltinModelSpecPtr spec,
mojo::PendingReceiver<mojom::Model> receiver,
mojom::MachineLearningService::LoadBuiltinModelCallback result_callback) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
if (!task_runner_->RunsTasksInCurrentSequence()) {
task_runner_->PostTask(
FROM_HERE,
base::BindOnce(&ServiceConnectionImpl::LoadBuiltinModel,
base::Unretained(this), std::move(spec),
std::move(receiver), std::move(result_callback)));
return;
}
BindMachineLearningServiceIfNeeded();
machine_learning_service_->LoadBuiltinModel(
std::move(spec), std::move(receiver), std::move(result_callback));
......@@ -103,7 +129,15 @@ void ServiceConnectionImpl::LoadFlatBufferModel(
mojo::PendingReceiver<mojom::Model> receiver,
mojom::MachineLearningService::LoadFlatBufferModelCallback
result_callback) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
if (!task_runner_->RunsTasksInCurrentSequence()) {
task_runner_->PostTask(
FROM_HERE,
base::BindOnce(&ServiceConnectionImpl::LoadFlatBufferModel,
base::Unretained(this), std::move(spec),
std::move(receiver), std::move(result_callback)));
return;
}
BindMachineLearningServiceIfNeeded();
machine_learning_service_->LoadFlatBufferModel(
std::move(spec), std::move(receiver), std::move(result_callback));
......@@ -112,7 +146,14 @@ void ServiceConnectionImpl::LoadFlatBufferModel(
void ServiceConnectionImpl::LoadTextClassifier(
mojo::PendingReceiver<mojom::TextClassifier> receiver,
mojom::MachineLearningService::LoadTextClassifierCallback result_callback) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
if (!task_runner_->RunsTasksInCurrentSequence()) {
task_runner_->PostTask(
FROM_HERE, base::BindOnce(&ServiceConnectionImpl::LoadTextClassifier,
base::Unretained(this), std::move(receiver),
std::move(result_callback)));
return;
}
BindMachineLearningServiceIfNeeded();
machine_learning_service_->LoadTextClassifier(std::move(receiver),
std::move(result_callback));
......@@ -123,7 +164,15 @@ void ServiceConnectionImpl::LoadHandwritingModel(
mojo::PendingReceiver<mojom::HandwritingRecognizer> receiver,
mojom::MachineLearningService::LoadHandwritingModelCallback
result_callback) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
if (!task_runner_->RunsTasksInCurrentSequence()) {
task_runner_->PostTask(
FROM_HERE,
base::BindOnce(&ServiceConnectionImpl::LoadHandwritingModel,
base::Unretained(this), std::move(spec),
std::move(receiver), std::move(result_callback)));
return;
}
BindMachineLearningServiceIfNeeded();
machine_learning_service_->LoadHandwritingModel(
std::move(spec), std::move(receiver), std::move(result_callback));
......@@ -134,7 +183,15 @@ void ServiceConnectionImpl::LoadHandwritingModelWithSpec(
mojo::PendingReceiver<mojom::HandwritingRecognizer> receiver,
mojom::MachineLearningService::LoadHandwritingModelWithSpecCallback
result_callback) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
if (!task_runner_->RunsTasksInCurrentSequence()) {
task_runner_->PostTask(
FROM_HERE,
base::BindOnce(&ServiceConnectionImpl::LoadHandwritingModelWithSpec,
base::Unretained(this), std::move(spec),
std::move(receiver), std::move(result_callback)));
return;
}
BindMachineLearningServiceIfNeeded();
machine_learning_service_->LoadHandwritingModelWithSpec(
std::move(spec), std::move(receiver), std::move(result_callback));
......@@ -143,7 +200,14 @@ void ServiceConnectionImpl::LoadHandwritingModelWithSpec(
void ServiceConnectionImpl::LoadGrammarChecker(
mojo::PendingReceiver<mojom::GrammarChecker> receiver,
mojom::MachineLearningService::LoadGrammarCheckerCallback result_callback) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
if (!task_runner_->RunsTasksInCurrentSequence()) {
task_runner_->PostTask(
FROM_HERE, base::BindOnce(&ServiceConnectionImpl::LoadGrammarChecker,
base::Unretained(this), std::move(receiver),
std::move(result_callback)));
return;
}
BindMachineLearningServiceIfNeeded();
machine_learning_service_->LoadGrammarChecker(std::move(receiver),
std::move(result_callback));
......@@ -154,7 +218,16 @@ void ServiceConnectionImpl::LoadSpeechRecognizer(
mojo::PendingRemote<mojom::SodaClient> soda_client,
mojo::PendingReceiver<mojom::SodaRecognizer> soda_recognizer,
mojom::MachineLearningService::LoadSpeechRecognizerCallback callback) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
if (!task_runner_->RunsTasksInCurrentSequence()) {
task_runner_->PostTask(
FROM_HERE,
base::BindOnce(&ServiceConnectionImpl::LoadSpeechRecognizer,
base::Unretained(this), std::move(soda_config),
std::move(soda_client), std::move(soda_recognizer),
std::move(callback)));
return;
}
BindMachineLearningServiceIfNeeded();
machine_learning_service_->LoadSpeechRecognizer(
std::move(soda_config), std::move(soda_client),
......@@ -194,7 +267,8 @@ void ServiceConnectionImpl::BindMachineLearningServiceIfNeeded() {
base::Unretained(this)));
}
ServiceConnectionImpl::ServiceConnectionImpl() {
ServiceConnectionImpl::ServiceConnectionImpl()
: task_runner_(base::SequencedTaskRunnerHandle::Get()) {
DETACH_FROM_SEQUENCE(sequence_checker_);
}
......
......@@ -13,7 +13,16 @@ namespace machine_learning {
// Encapsulates a connection to the Chrome OS ML Service daemon via its Mojo
// interface.
// Usage for Built-in models:
//
// Usage for BindMachineLearningService:
// mojo::Remote<mojom::MachineLearningService> ml_service;
// chromeos::machine_learning::ServiceConnection::GetInstance()
// ->BindMachineLearningService(
// ml_service.BindNewPipeAndPassReceiver());
// // Use ml_service to LoadBuiltinModel(), LoadFlatBufferModel() etc. e.g
// ml_service->LoadBuiltinModel(...);
//
// Usage for Built-in models (will be deprecated soon):
// mojo::Remote<chromeos::machine_learning::mojom::Model> model;
// chromeos::machine_learning::mojom::BuiltinModelSpecPtr spec =
// chromeos::machine_learning::mojom::BuiltinModelSpec::New();
......@@ -22,7 +31,7 @@ namespace machine_learning {
// ->LoadBuiltinModel(std::move(spec), model.BindNewPipeAndPassReceiver(),
// base::BindOnce(&MyCallBack));
// // Use |model| or wait for |MyCallBack|.
// Usage for Flatbuffer models:
// Usage for Flatbuffer models (will be deprecated soon):
// mojo::Remote<chromeos::machine_learning::mojom::Model> model;
// chromeos::machine_learning::mojom::FlatBufferModelSpecPtr spec =
// chromeos::machine_learning::mojom::FlatBufferModelSpec::New();
......@@ -35,7 +44,7 @@ namespace machine_learning {
// model.BindNewPipeAndPassReceiver(),
// base::BindOnce(&MyCallBack));
//
// Sequencing: Must be used on a single sequence (may be created on another).
// Sequencing: can be called from any sequence.
class ServiceConnection {
public:
static ServiceConnection* GetInstance();
......@@ -44,6 +53,10 @@ class ServiceConnection {
static void UseFakeServiceConnectionForTesting(
ServiceConnection* fake_service_connection);
// Binds the receiver to the implementation in the ml_service daemon.
virtual void BindMachineLearningService(
mojo::PendingReceiver<mojom::MachineLearningService> receiver) = 0;
// Instruct ML daemon to load the builtin model specified in |spec|, binding a
// Model implementation to |receiver|. Bootstraps the initial Mojo connection
// to the daemon if necessary.
......
......@@ -94,6 +94,14 @@ TEST_F(ServiceConnectionTest, LoadHandwritingModelWithSpec) {
base::BindOnce([](mojom::LoadModelResult result) {}));
}
// Tests that LoadGrammarChecker runs OK (no crash) in a basic Mojo environment.
TEST_F(ServiceConnectionTest, LoadGrammarModel) {
mojo::Remote<mojom::GrammarChecker> grammar_checker;
ServiceConnection::GetInstance()->LoadGrammarChecker(
grammar_checker.BindNewPipeAndPassReceiver(),
base::BindOnce([](mojom::LoadModelResult result) {}));
}
class TestSodaClient : public mojom::SodaClient {};
// Tests that LoadSpeechRecognizer runs OK without a crash in a basic Mojo
......@@ -120,14 +128,6 @@ TEST_F(ServiceConnectionTest, LoadSpeechRecognizerAndCallback) {
ASSERT_TRUE(callback_done);
}
// Tests that LoadGrammarChecker runs OK (no crash) in a basic Mojo environment.
TEST_F(ServiceConnectionTest, LoadGrammarModel) {
mojo::Remote<mojom::GrammarChecker> grammar_checker;
ServiceConnection::GetInstance()->LoadGrammarChecker(
grammar_checker.BindNewPipeAndPassReceiver(),
base::BindOnce([](mojom::LoadModelResult result) {}));
}
// Tests the fake ML service for builtin model.
TEST_F(ServiceConnectionTest, FakeServiceConnectionForBuiltinModel) {
mojo::Remote<mojom::Model> model;
......@@ -543,6 +543,37 @@ TEST_F(ServiceConnectionTest, FakeGrammarChecker) {
ASSERT_TRUE(infer_callback_done);
}
// Tests the fake ML service for binding ml_service receiver.
TEST_F(ServiceConnectionTest, BindMachineLearningService) {
FakeServiceConnectionImpl fake_service_connection;
ServiceConnection::UseFakeServiceConnectionForTesting(
&fake_service_connection);
mojo::Remote<mojom::MachineLearningService> ml_service;
ServiceConnection::GetInstance()->BindMachineLearningService(
ml_service.BindNewPipeAndPassReceiver());
base::RunLoop().RunUntilIdle();
ASSERT_TRUE(ml_service.is_bound());
// Check the bound ml_service remote can be used to call
// MachineLearningService methods.
mojo::Remote<mojom::Model> model;
bool callback_done = false;
ml_service->LoadBuiltinModel(
mojom::BuiltinModelSpec::New(mojom::BuiltinModelId::TEST_MODEL),
model.BindNewPipeAndPassReceiver(),
base::BindOnce(
[](bool* callback_done, mojom::LoadModelResult result) {
EXPECT_EQ(result, mojom::LoadModelResult::OK);
*callback_done = true;
},
&callback_done));
base::RunLoop().RunUntilIdle();
EXPECT_TRUE(callback_done);
EXPECT_TRUE(model.is_bound());
}
} // namespace
} // namespace machine_learning
} // namespace chromeos
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment