Commit 17ad51e3 authored by Jing Wang's avatar Jing Wang Committed by Commit Bot

Uprev mojo API of grammar check from chromeos

(1) The mojom interface is generated with the script
"chromeos/services/machine_learning/public/mojom/roll_mojoms.sh"
from https://https://chromium-review.googlesource.com/c/chromiumos/platform2/+/2497369

(2) An implementation and fake implementation are also added for
unit_tests.

Bug: 1132699
TEST: unit tests passed.
Change-Id: Ib80675eace10b0abf331518fe61b9d63506ade17
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2521973
Commit-Queue: Jing Wang <jiwan@chromium.org>
Reviewed-by: default avatarSam McNally <sammc@chromium.org>
Reviewed-by: default avatarHonglin Yu <honglinyu@chromium.org>
Cr-Commit-Position: refs/heads/master@{#825232}
parent 97692bbb
......@@ -84,6 +84,14 @@ void FakeServiceConnectionImpl::LoadHandwritingModelWithSpec(
base::Unretained(this), std::move(receiver), std::move(callback)));
}
void FakeServiceConnectionImpl::LoadGrammarChecker(
mojo::PendingReceiver<mojom::GrammarChecker> receiver,
mojom::MachineLearningService::LoadGrammarCheckerCallback callback) {
ScheduleCall(base::BindOnce(
&FakeServiceConnectionImpl::HandleLoadGrammarChecker,
base::Unretained(this), std::move(receiver), std::move(callback)));
}
void FakeServiceConnectionImpl::Execute(
base::flat_map<std::string, mojom::TensorPtr> inputs,
const std::vector<std::string>& output_names,
......@@ -252,6 +260,11 @@ void FakeServiceConnectionImpl::SetOutputHandwritingRecognizerResult(
handwriting_result_ = result.Clone();
}
void FakeServiceConnectionImpl::SetOutputGrammarCheckerResult(
const mojom::GrammarCheckerResultPtr& result) {
grammar_checker_result_ = result.Clone();
}
void FakeServiceConnectionImpl::Annotate(
mojom::TextAnnotationRequestPtr request,
mojom::TextClassifier::AnnotateCallback callback) {
......@@ -284,6 +297,14 @@ void FakeServiceConnectionImpl::Recognize(
std::move(callback)));
}
void FakeServiceConnectionImpl::Check(
mojom::GrammarCheckerQueryPtr query,
mojom::GrammarChecker::CheckCallback callback) {
ScheduleCall(base::BindOnce(
&FakeServiceConnectionImpl::HandleGrammarCheckerQuery,
base::Unretained(this), std::move(query), std::move(callback)));
}
void FakeServiceConnectionImpl::HandleLoadHandwritingModel(
mojo::PendingReceiver<mojom::HandwritingRecognizer> receiver,
mojom::MachineLearningService::LoadHandwritingModelCallback callback) {
......@@ -309,5 +330,20 @@ void FakeServiceConnectionImpl::HandleRecognize(
std::move(callback).Run(handwriting_result_.Clone());
}
void FakeServiceConnectionImpl::HandleLoadGrammarChecker(
mojo::PendingReceiver<mojom::GrammarChecker> receiver,
mojom::MachineLearningService::LoadGrammarCheckerCallback callback) {
if (load_model_result_ == mojom::LoadModelResult::OK)
grammar_checker_receivers_.Add(this, std::move(receiver));
std::move(callback).Run(load_model_result_);
}
void FakeServiceConnectionImpl::HandleGrammarCheckerQuery(
mojom::GrammarCheckerQueryPtr query,
mojom::GrammarChecker::CheckCallback callback) {
std::move(callback).Run(grammar_checker_result_.Clone());
}
} // namespace machine_learning
} // namespace chromeos
......@@ -11,6 +11,7 @@
#include "base/callback_forward.h"
#include "base/macros.h"
#include "chromeos/services/machine_learning/public/cpp/service_connection.h"
#include "chromeos/services/machine_learning/public/mojom/grammar_checker.mojom.h"
#include "chromeos/services/machine_learning/public/mojom/graph_executor.mojom.h"
#include "chromeos/services/machine_learning/public/mojom/handwriting_recognizer.mojom.h"
#include "chromeos/services/machine_learning/public/mojom/model.mojom.h"
......@@ -35,6 +36,7 @@ class FakeServiceConnectionImpl : public ServiceConnection,
public mojom::Model,
public mojom::TextClassifier,
public mojom::HandwritingRecognizer,
public mojom::GrammarChecker,
public mojom::GraphExecutor {
public:
FakeServiceConnectionImpl();
......@@ -71,6 +73,11 @@ class FakeServiceConnectionImpl : public ServiceConnection,
mojom::MachineLearningService::LoadHandwritingModelWithSpecCallback
result_callback) override;
void LoadGrammarChecker(
mojo::PendingReceiver<mojom::GrammarChecker> receiver,
mojom::MachineLearningService::LoadGrammarCheckerCallback callback)
override;
// mojom::Model:
void CreateGraphExecutor(
mojo::PendingReceiver<mojom::GraphExecutor> receiver,
......@@ -126,6 +133,11 @@ class FakeServiceConnectionImpl : public ServiceConnection,
// languages.
void SetOutputLanguages(const std::vector<mojom::TextLanguagePtr>& languages);
// Call SetOutputGrammarCheckerResult() before Check() to set the output of
// grammar checker.
void SetOutputGrammarCheckerResult(
const mojom::GrammarCheckerResultPtr& result);
// Call SetOutputHandwritingRecognizerResult() before Recognize() to set the
// output of handwriting.
void SetOutputHandwritingRecognizerResult(
......@@ -150,6 +162,10 @@ class FakeServiceConnectionImpl : public ServiceConnection,
mojom::HandwritingRecognitionQueryPtr query,
mojom::HandwritingRecognizer::RecognizeCallback callback) override;
// mojom::GrammarChecker:
void Check(mojom::GrammarCheckerQueryPtr query,
mojom::GrammarChecker::CheckCallback callback) override;
private:
void ScheduleCall(base::OnceClosure call);
void HandleLoadBuiltinModelCall(
......@@ -183,11 +199,17 @@ class FakeServiceConnectionImpl : public ServiceConnection,
void HandleRecognize(
mojom::HandwritingRecognitionQueryPtr query,
mojom::HandwritingRecognizer::RecognizeCallback callback);
void HandleLoadGrammarChecker(
mojo::PendingReceiver<mojom::GrammarChecker> receiver,
mojom::MachineLearningService::LoadGrammarCheckerCallback callback);
void HandleGrammarCheckerQuery(mojom::GrammarCheckerQueryPtr query,
mojom::GrammarChecker::CheckCallback callback);
mojo::ReceiverSet<mojom::Model> model_receivers_;
mojo::ReceiverSet<mojom::GraphExecutor> graph_receivers_;
mojo::ReceiverSet<mojom::TextClassifier> text_classifier_receivers_;
mojo::ReceiverSet<mojom::HandwritingRecognizer> handwriting_receivers_;
mojo::ReceiverSet<mojom::GrammarChecker> grammar_checker_receivers_;
mojom::TensorPtr output_tensor_;
mojom::LoadHandwritingModelResult load_handwriting_model_result_;
mojom::LoadModelResult load_model_result_;
......@@ -198,6 +220,7 @@ class FakeServiceConnectionImpl : public ServiceConnection,
mojom::CodepointSpanPtr suggest_selection_result_;
std::vector<mojom::TextLanguagePtr> find_languages_result_;
mojom::HandwritingRecognizerResultPtr handwriting_result_;
mojom::GrammarCheckerResultPtr grammar_checker_result_;
bool async_mode_;
std::vector<base::OnceClosure> pending_calls_;
......
......@@ -56,6 +56,11 @@ class ServiceConnectionImpl : public ServiceConnection {
mojom::MachineLearningService::LoadHandwritingModelWithSpecCallback
result_callback) override;
void LoadGrammarChecker(
mojo::PendingReceiver<mojom::GrammarChecker> receiver,
mojom::MachineLearningService::LoadGrammarCheckerCallback result_callback)
override;
private:
// Binds the top level interface |machine_learning_service_| to an
// implementation in the ML Service daemon, if it is not already bound. The
......@@ -128,6 +133,15 @@ void ServiceConnectionImpl::LoadHandwritingModelWithSpec(
std::move(spec), std::move(receiver), std::move(result_callback));
}
void ServiceConnectionImpl::LoadGrammarChecker(
mojo::PendingReceiver<mojom::GrammarChecker> receiver,
mojom::MachineLearningService::LoadGrammarCheckerCallback result_callback) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
BindMachineLearningServiceIfNeeded();
machine_learning_service_->LoadGrammarChecker(std::move(receiver),
std::move(result_callback));
}
void ServiceConnectionImpl::BindMachineLearningServiceIfNeeded() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
if (machine_learning_service_) {
......
......@@ -86,6 +86,14 @@ class ServiceConnection {
mojom::MachineLearningService::LoadHandwritingModelWithSpecCallback
result_callback) = 0;
// Instruct ML daemon to load the Grammar model, binding a GrammarChecker
// implementation to |receiver|. Bootstraps the initial Mojo connection to the
// daemon if necessary.
virtual void LoadGrammarChecker(
mojo::PendingReceiver<mojom::GrammarChecker> receiver,
mojom::MachineLearningService::LoadGrammarCheckerCallback
result_callback) = 0;
protected:
ServiceConnection() = default;
virtual ~ServiceConnection() {}
......
......@@ -93,6 +93,14 @@ TEST_F(ServiceConnectionTest, LoadHandwritingModelWithSpec) {
base::BindOnce([](mojom::LoadModelResult result) {}));
}
// Tests that LoadGrammarChecker runs OK (no crash) in a basic Mojo environment.
TEST_F(ServiceConnectionTest, LoadGrammarModel) {
mojo::Remote<mojom::GrammarChecker> grammar_checker;
ServiceConnection::GetInstance()->LoadGrammarChecker(
grammar_checker.BindNewPipeAndPassReceiver(),
base::BindOnce([](mojom::LoadModelResult result) {}));
}
// Tests the fake ML service for builtin model.
TEST_F(ServiceConnectionTest, FakeServiceConnectionForBuiltinModel) {
mojo::Remote<mojom::Model> model;
......@@ -462,6 +470,52 @@ TEST_F(ServiceConnectionTest, FakeHandWritingRecognizerWithSpec) {
ASSERT_TRUE(infer_callback_done);
}
TEST_F(ServiceConnectionTest, FakeGrammarChecker) {
mojo::Remote<mojom::GrammarChecker> checker;
bool callback_done = false;
FakeServiceConnectionImpl fake_service_connection;
ServiceConnection::UseFakeServiceConnectionForTesting(
&fake_service_connection);
ServiceConnection::GetInstance()->LoadGrammarChecker(
checker.BindNewPipeAndPassReceiver(),
base::BindOnce(
[](bool* callback_done, mojom::LoadModelResult result) {
EXPECT_EQ(result, mojom::LoadModelResult::OK);
*callback_done = true;
},
&callback_done));
base::RunLoop().RunUntilIdle();
ASSERT_TRUE(callback_done);
ASSERT_TRUE(checker.is_bound());
// Construct fake output
mojom::GrammarCheckerResultPtr result = mojom::GrammarCheckerResult::New();
result->status = mojom::GrammarCheckerResult::Status::OK;
mojom::GrammarCheckerCandidatePtr candidate =
mojom::GrammarCheckerCandidate::New();
candidate->text = "cat";
candidate->score = 0.5f;
result->candidates.emplace_back(std::move(candidate));
fake_service_connection.SetOutputGrammarCheckerResult(result);
auto query = mojom::GrammarCheckerQuery::New();
bool infer_callback_done = false;
checker->Check(
std::move(query),
base::BindOnce(
[](bool* infer_callback_done, mojom::GrammarCheckerResultPtr result) {
*infer_callback_done = true;
// Check if the annotation is correct.
ASSERT_EQ(result->status, mojom::GrammarCheckerResult::Status::OK);
EXPECT_EQ(result->candidates.at(0)->text, "cat");
EXPECT_EQ(result->candidates.at(0)->score, 0.5f);
},
&infer_callback_done));
base::RunLoop().RunUntilIdle();
ASSERT_TRUE(infer_callback_done);
}
} // namespace
} // namespace machine_learning
} // namespace chromeos
......@@ -6,6 +6,7 @@ import("//mojo/public/tools/bindings/mojom.gni")
mojom("mojom") {
sources = [
"grammar_checker.mojom",
"graph_executor.mojom",
"handwriting_recognizer.mojom",
"handwriting_recognizer_requestor.mojom",
......
// Copyright 2020 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
// Datatypes and interfaces of grammar checker API.
// NOTE: This mojom exists in two places and must be kept in sync:
// Chromium: //chromeos/services/machine_learning/public/mojom/
// Chrome OS: src/platform2/ml/mojom/
// Note: Other repos downstream of Chromium might also use this mojom.
// Example: A backwards-compatible mojom change (and corresponding
// implementation change) can be made in Chrome OS first, then replicated to the
// clients (Chromium, other downstream repos) later.
// Use //chromeos/services/machine_learning/public/mojom/roll_mojom.sh to help
// replicate Chrome OS-side changes over to Chromium.
module chromeos.machine_learning.mojom;
// Defines a grammar check query.
struct GrammarCheckerQuery {
// Required: Text to be checked. This is expected to be a full sentence.
string text;
// Required: Language of the text to be checked, in BCP-47 format.
string language;
};
// One possible candidate returned from the grammar checker model.
struct GrammarCheckerCandidate {
// Corrected text.
string text;
// Score of the text. Log of conditional probability.
float score;
};
// The grammar check response.
struct GrammarCheckerResult {
// Status of the response.
enum Status {
// Grammar check succeeded.
OK = 0,
// Grammar check failed. In this case, candidates will be empty.
ERROR = 1,
};
Status status;
// Candidates of corrected text and their scores, sorted by higher score
// first.
array<GrammarCheckerCandidate> candidates;
};
// The mojom interface for performing the grammar check.
interface GrammarChecker {
// Performs grammar check on a piece of text, and returns a set of
// candidates of corrected text and their scores.
Check(GrammarCheckerQuery query) => (GrammarCheckerResult result);
};
......@@ -18,6 +18,7 @@ module chromeos.machine_learning.mojom;
// NOTE: The base directory for 'import' statements is expected to differ
// between Chromium and Chrome OS versions of this file.
import "chromeos/services/machine_learning/public/mojom/grammar_checker.mojom";
import "chromeos/services/machine_learning/public/mojom/handwriting_recognizer.mojom";
import "chromeos/services/machine_learning/public/mojom/model.mojom";
import "chromeos/services/machine_learning/public/mojom/text_classifier.mojom";
......@@ -33,7 +34,7 @@ enum LoadModelResult {
};
// Top-level interface between Chromium and the ML Service daemon.
// Next ordinal: 6
// Next ordinal: 8
interface MachineLearningService {
// Binds another pipe to this instance.
Clone@5(pending_receiver<MachineLearningService> receiver);
......@@ -60,4 +61,7 @@ interface MachineLearningService {
HandwritingRecognizerSpec spec,
pending_receiver<HandwritingRecognizer> receiver)
=> (LoadModelResult result);
// Create and initialize a grammar checker.
LoadGrammarChecker@7(pending_receiver<GrammarChecker> receiver)
=> (LoadModelResult result);
};
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment