Commit 28763b7d authored by Honglin Yu's avatar Honglin Yu Committed by Commit Bot

ml: add mojo API for language identification

Also roll the edits on mojom in CL:
https://chromium-review.googlesource.com/c/chromiumos/platform2/+/2291868

BUG=chromium:1086044
TEST=on device (eve), can call langid function
TEST=in the chrome://machine-learning-internal webpage.

Change-Id: I8abdeac7e318397069bf7961ad4ae43667df15eb
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2289573Reviewed-by: default avatarAndrew Moylan <amoylan@chromium.org>
Reviewed-by: default avatarSam McNally <sammc@chromium.org>
Commit-Queue: Honglin Yu <honglinyu@chromium.org>
Cr-Commit-Position: refs/heads/master@{#791737}
parent 40c5adcc
......@@ -211,6 +211,16 @@ void FakeServiceConnectionImpl::HandleSuggestSelectionCall(
std::move(callback).Run(std::move(selection));
}
void FakeServiceConnectionImpl::HandleFindLanguagesCall(
std::string request,
mojom::TextClassifier::FindLanguagesCallback callback) {
std::vector<mojom::TextLanguagePtr> languages;
for (auto const& language : find_languages_result_) {
languages.emplace_back(language.Clone());
}
std::move(callback).Run(std::move(languages));
}
void FakeServiceConnectionImpl::SetOutputAnnotation(
const std::vector<mojom::TextAnnotationPtr>& annotations) {
annotate_result_.clear();
......@@ -224,6 +234,14 @@ void FakeServiceConnectionImpl::SetOutputSelection(
suggest_selection_result_ = selection.Clone();
}
void FakeServiceConnectionImpl::SetOutputLanguages(
const std::vector<mojom::TextLanguagePtr>& languages) {
find_languages_result_.clear();
for (auto const& language : languages) {
find_languages_result_.emplace_back(language.Clone());
}
}
void FakeServiceConnectionImpl::SetOutputHandwritingRecognizerResult(
const mojom::HandwritingRecognizerResultPtr& result) {
handwriting_result_ = result.Clone();
......@@ -245,6 +263,14 @@ void FakeServiceConnectionImpl::SuggestSelection(
base::Unretained(this), std::move(request), std::move(callback)));
}
void FakeServiceConnectionImpl::FindLanguages(
const std::string& text,
mojom::TextClassifier::FindLanguagesCallback callback) {
ScheduleCall(base::BindOnce(
&FakeServiceConnectionImpl::HandleFindLanguagesCall,
base::Unretained(this), text, std::move(callback)));
}
void FakeServiceConnectionImpl::Recognize(
mojom::HandwritingRecognitionQueryPtr query,
mojom::HandwritingRecognizer::RecognizeCallback callback) {
......
......@@ -120,6 +120,10 @@ class FakeServiceConnectionImpl : public ServiceConnection,
// selection.
void SetOutputSelection(const mojom::CodepointSpanPtr& selection);
// Call SetOutputLanguages() before FindLanguages() to set the output
// languages.
void SetOutputLanguages(const std::vector<mojom::TextLanguagePtr>& languages);
// Call SetOutputHandwritingRecognizerResult() before Recognize() to set the
// output of handwriting.
void SetOutputHandwritingRecognizerResult(
......@@ -134,6 +138,11 @@ class FakeServiceConnectionImpl : public ServiceConnection,
mojom::TextSuggestSelectionRequestPtr request,
mojom::TextClassifier::SuggestSelectionCallback callback) override;
// mojom::TextClassifier:
void FindLanguages(
const std::string& text,
mojom::TextClassifier::FindLanguagesCallback callback) override;
// mojom::HandwritingRecognizer:
void Recognize(
mojom::HandwritingRecognitionQueryPtr query,
......@@ -159,6 +168,9 @@ class FakeServiceConnectionImpl : public ServiceConnection,
void HandleSuggestSelectionCall(
mojom::TextSuggestSelectionRequestPtr request,
mojom::TextClassifier::SuggestSelectionCallback callback);
void HandleFindLanguagesCall(
std::string text,
mojom::TextClassifier::FindLanguagesCallback callback);
void HandleLoadHandwritingModel(
mojo::PendingReceiver<mojom::HandwritingRecognizer> receiver,
mojom::MachineLearningService::LoadHandwritingModelCallback callback);
......@@ -177,6 +189,7 @@ class FakeServiceConnectionImpl : public ServiceConnection,
mojom::ExecuteResult execute_result_;
std::vector<mojom::TextAnnotationPtr> annotate_result_;
mojom::CodepointSpanPtr suggest_selection_result_;
std::vector<mojom::TextLanguagePtr> find_languages_result_;
mojom::HandwritingRecognizerResultPtr handwriting_result_;
bool async_mode_;
......
......@@ -315,6 +315,51 @@ TEST_F(ServiceConnectionTest,
ASSERT_TRUE(infer_callback_done);
}
// Tests the fake ML service for text classifier language identification.
TEST_F(ServiceConnectionTest,
FakeServiceConnectionForTextClassifierFindLanguages) {
mojo::Remote<mojom::TextClassifier> text_classifier;
bool callback_done = false;
FakeServiceConnectionImpl fake_service_connection;
ServiceConnection::UseFakeServiceConnectionForTesting(
&fake_service_connection);
std::vector<mojom::TextLanguagePtr> languages;
languages.emplace_back(mojom::TextLanguage::New("en", 0.9));
languages.emplace_back(mojom::TextLanguage::New("fr", 0.1));
fake_service_connection.SetOutputLanguages(languages);
ServiceConnection::GetInstance()->LoadTextClassifier(
text_classifier.BindNewPipeAndPassReceiver(),
base::BindOnce(
[](bool* callback_done, mojom::LoadModelResult result) {
EXPECT_EQ(result, mojom::LoadModelResult::OK);
*callback_done = true;
},
&callback_done));
base::RunLoop().RunUntilIdle();
ASSERT_TRUE(callback_done);
ASSERT_TRUE(text_classifier.is_bound());
std::string input_text = "dummy input text";
bool infer_callback_done = false;
text_classifier->FindLanguages(
input_text, base::Bind(
[](bool* infer_callback_done,
std::vector<mojom::TextLanguagePtr> languages) {
*infer_callback_done = true;
// Check if the suggestion is correct.
ASSERT_EQ(languages.size(), 2ul);
EXPECT_EQ(languages[0]->locale, "en");
EXPECT_EQ(languages[0]->confidence, 0.9f);
EXPECT_EQ(languages[1]->locale, "fr");
EXPECT_EQ(languages[1]->confidence, 0.1f);
},
&infer_callback_done));
base::RunLoop().RunUntilIdle();
ASSERT_TRUE(infer_callback_done);
}
// Tests the fake ML service for handwriting.
TEST_F(ServiceConnectionTest, FakeHandWritingRecognizerWithSpec) {
mojo::Remote<mojom::HandwritingRecognizer> recognizer;
......
......@@ -33,24 +33,29 @@ enum LoadModelResult {
};
// Top-level interface between Chromium and the ML Service daemon.
// Next ordinal: 6
interface MachineLearningService {
// Binds another pipe to this instance.
Clone@5(pending_receiver<MachineLearningService> receiver);
// The BuiltinModelId inside BuiltinModelSpec is used to specify the model to
// be loaded.
LoadBuiltinModel(BuiltinModelSpec spec, pending_receiver<Model> receiver)
LoadBuiltinModel@0(BuiltinModelSpec spec, pending_receiver<Model> receiver)
=> (LoadModelResult result);
// The FlatbufferModelSpec contains both of the flatbuffer content and the
// metadata.
LoadFlatBufferModel(FlatBufferModelSpec spec,
LoadFlatBufferModel@1(FlatBufferModelSpec spec,
pending_receiver<Model> receiver)
=> (LoadModelResult result);
// Create a new TextClassifier.
LoadTextClassifier(pending_receiver<TextClassifier> receiver)
LoadTextClassifier@2(pending_receiver<TextClassifier> receiver)
=> (LoadModelResult result);
// Create and initialize a handwriting recognizer.
LoadHandwritingModel(pending_receiver<HandwritingRecognizer> receiver)
LoadHandwritingModel@3(pending_receiver<HandwritingRecognizer> receiver)
=> (LoadModelResult result);
// Create and initialize a handwriting recognizer with given |spec|.
LoadHandwritingModelWithSpec(HandwritingRecognizerSpec spec,
// Create and initialize a handwriting recognizer with given `spec`.
LoadHandwritingModelWithSpec@4(
HandwritingRecognizerSpec spec,
pending_receiver<HandwritingRecognizer> receiver)
=> (LoadModelResult result);
};
......@@ -22,20 +22,6 @@ module chromeos.machine_learning.mojom;
// under mojo folder, that is, "mojo/public/mojom/base/time.mojom".
import "mojo/public/mojom/base/time.mojom";
// These values are persisted to logs. Entries should not be renumbered and
// numeric values should never be reused.
enum TextAnnotationResult {
OK = 0,
ERROR = 1,
};
// These values are persisted to logs. Entries should not be renumbered and
// numeric values should never be reused.
enum SuggestSelectionResult {
OK = 0,
ERROR = 1,
};
// Enum for specifying the annotation usecase.
// Must be consistent with `AnnotationUsecase` in model.fb in libtextclassifier.
enum AnnotationUsecase {
......@@ -138,6 +124,14 @@ struct TextSuggestSelectionRequest {
AnnotationUsecase annotation_usecase@4 = ANNOTATION_USECASE_SMART;
};
// Represent a language detection result.
struct TextLanguage {
// The BCP-47 language code like "en", "fr", "zh" etc.
string locale;
// The confidence score of the language detected (range: 0~1).
float confidence;
};
// Used to annotate entities within text strings.
interface TextClassifier {
// Annotate a text string and returns the detected substrings and possible
......@@ -151,4 +145,10 @@ interface TextClassifier {
// UTF8 codepoints (not bytes).
SuggestSelection@1(TextSuggestSelectionRequest request) =>
(CodepointSpan outputs);
// Identify the languages the text is possibly written in.
// The returned results are sorted according to the confidence score, from the
// highest to the lowest.
// The maximum number of results returned is determined internally.
// Will return an empty array if the language can not be determined.
FindLanguages@2(string text) => (array<TextLanguage> outputs);
};
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment