Commit ce8485a2 authored by Yue Li's avatar Yue Li Committed by Commit Bot

Quick Answers: use FindLanguage API from text classifier

Replace the current language detect API with the FindLanguage API from
text classifier.

This is the first change of the Translation backend for the Quick
Answers.

Bug: b/150034512
Test: Manual Test
Change-Id: Ie0a53e473b8a378c2692952f350f37b3b2534894
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2377582
Commit-Queue: Yue Li <updowndota@chromium.org>
Reviewed-by: default avatarXiyuan Xia <xiyuan@chromium.org>
Cr-Commit-Position: refs/heads/master@{#806902}
parent 67301963
......@@ -10,7 +10,6 @@
#include "base/no_destructor.h"
#include "base/strings/utf_string_conversions.h"
#include "chromeos/components/quick_answers/quick_answers_model.h"
#include "chromeos/components/quick_answers/utils/language_detector.h"
#include "chromeos/constants/chromeos_features.h"
#include "chromeos/services/machine_learning/public/cpp/service_connection.h"
#include "chromeos/services/machine_learning/public/mojom/machine_learning_service.mojom.h"
......@@ -85,7 +84,6 @@ IntentType RewriteIntent(const std::string& selected_text,
IntentGenerator::IntentGenerator(IntentGeneratorCallback complete_callback)
: complete_callback_(std::move(complete_callback)) {
language_detector_ = std::make_unique<LanguageDetector>();
}
IntentGenerator::~IntentGenerator() {
......@@ -155,9 +153,16 @@ void IntentGenerator::AnnotationCallback(
MaybeGenerateTranslationIntent(request);
}
void IntentGenerator::SetLanguageDetectorForTesting(
std::unique_ptr<LanguageDetector> language_detector) {
language_detector_ = std::move(language_detector);
void IntentGenerator::FindLanguagesCallback(
const QuickAnswersRequest& request,
std::vector<machine_learning::mojom::TextLanguagePtr> languages) {
auto intent_type = IntentType::kUnknown;
// TODO(b/b/150034512): Take confidence level into consideration.
if (!languages.empty() &&
languages.front()->locale != request.context.device_properties.language) {
intent_type = IntentType::kTranslation;
}
std::move(complete_callback_).Run(request.selected_text, intent_type);
}
void IntentGenerator::MaybeGenerateTranslationIntent(
......@@ -178,16 +183,15 @@ void IntentGenerator::MaybeGenerateTranslationIntent(
.Run(request.selected_text, IntentType::kUnknown);
return;
}
auto detected_language = language_detector_->DetectLanguage(
!request.context.surrounding_text.empty()
? request.context.surrounding_text
: request.selected_text);
auto intent_type = IntentType::kUnknown;
if (!detected_language.empty() &&
detected_language != request.context.device_properties.language) {
intent_type = IntentType::kTranslation;
if (text_classifier_) {
text_classifier_->FindLanguages(
!request.context.surrounding_text.empty()
? request.context.surrounding_text
: request.selected_text,
base::BindOnce(&IntentGenerator::FindLanguagesCallback,
weak_factory_.GetWeakPtr(), request));
}
std::move(complete_callback_).Run(request.selected_text, intent_type);
}
} // namespace quick_answers
......
......@@ -9,7 +9,6 @@
#include <string>
#include "base/callback.h"
#include "chromeos/components/quick_answers/utils/language_detector.h"
#include "chromeos/services/machine_learning/public/mojom/machine_learning_service.mojom.h"
#include "chromeos/services/machine_learning/public/mojom/text_classifier.mojom.h"
#include "mojo/public/cpp/bindings/remote.h"
......@@ -37,9 +36,6 @@ class IntentGenerator {
// Generate intent from the |request|. Virtual for testing.
virtual void GenerateIntent(const QuickAnswersRequest& request);
void SetLanguageDetectorForTesting(
std::unique_ptr<LanguageDetector> language_detector);
private:
FRIEND_TEST_ALL_PREFIXES(IntentGeneratorTest,
TextAnnotationIntentNoAnnotation);
......@@ -53,11 +49,13 @@ class IntentGenerator {
void AnnotationCallback(
const QuickAnswersRequest& request,
std::vector<machine_learning::mojom::TextAnnotationPtr> annotations);
void FindLanguagesCallback(
const QuickAnswersRequest& request,
std::vector<machine_learning::mojom::TextLanguagePtr> languages);
void MaybeGenerateTranslationIntent(const QuickAnswersRequest& request);
IntentGeneratorCallback complete_callback_;
std::unique_ptr<LanguageDetector> language_detector_;
mojo::Remote<::chromeos::machine_learning::mojom::TextClassifier>
text_classifier_;
......
......@@ -11,7 +11,6 @@
#include "base/test/scoped_feature_list.h"
#include "base/test/task_environment.h"
#include "chromeos/components/quick_answers/quick_answers_model.h"
#include "chromeos/components/quick_answers/utils/language_detector.h"
#include "chromeos/constants/chromeos_features.h"
#include "chromeos/services/machine_learning/public/cpp/fake_service_connection.h"
#include "chromeos/services/machine_learning/public/mojom/machine_learning_service.mojom.h"
......@@ -29,19 +28,12 @@ using machine_learning::mojom::TextAnnotationPtr;
using machine_learning::mojom::TextEntity;
using machine_learning::mojom::TextEntityData;
using machine_learning::mojom::TextEntityPtr;
using machine_learning::mojom::TextLanguage;
using machine_learning::mojom::TextLanguagePtr;
class MockLanguageDetector : public LanguageDetector {
public:
MockLanguageDetector() = default;
MockLanguageDetector(const MockLanguageDetector&) = delete;
MockLanguageDetector& operator=(const MockLanguageDetector&) = delete;
~MockLanguageDetector() override = default;
// TestResultLoader:
MOCK_METHOD1(DetectLanguage, std::string(const std::string&));
};
TextLanguagePtr DefaultLanguage() {
return TextLanguage::New("en", /* confidence */ 1);
}
} // namespace
......@@ -57,13 +49,6 @@ class IntentGeneratorTest : public testing::Test {
base::BindOnce(&IntentGeneratorTest::IntentGeneratorTestCallback,
base::Unretained(this)));
// Mock language detector.
mock_language_detector_ = std::make_unique<MockLanguageDetector>();
EXPECT_CALL(*mock_language_detector_, DetectLanguage(::testing::_))
.WillRepeatedly(::testing::Return("en"));
intent_generator_->SetLanguageDetectorForTesting(
std::move(mock_language_detector_));
scoped_feature_list_.InitWithFeatures(
{chromeos::features::kQuickAnswersTextAnnotator,
chromeos::features::kQuickAnswersTranslation},
......@@ -80,15 +65,17 @@ class IntentGeneratorTest : public testing::Test {
protected:
void UseFakeServiceConnection(
const std::vector<TextAnnotationPtr>& annotations =
std::vector<TextAnnotationPtr>()) {
std::vector<TextAnnotationPtr>(),
const std::vector<TextLanguagePtr>& languages =
std::vector<TextLanguagePtr>()) {
chromeos::machine_learning::ServiceConnection::
UseFakeServiceConnectionForTesting(&fake_service_connection_);
fake_service_connection_.SetOutputAnnotation(annotations);
fake_service_connection_.SetOutputLanguages(languages);
}
base::test::TaskEnvironment task_environment_;
std::unique_ptr<IntentGenerator> intent_generator_;
std::unique_ptr<MockLanguageDetector> mock_language_detector_;
std::string intent_text_;
IntentType intent_type_ = IntentType::kUnknown;
base::test::ScopedFeatureList scoped_feature_list_;
......@@ -97,7 +84,9 @@ class IntentGeneratorTest : public testing::Test {
};
TEST_F(IntentGeneratorTest, TranslationIntent) {
UseFakeServiceConnection();
std::vector<TextLanguagePtr> languages;
languages.push_back(DefaultLanguage());
UseFakeServiceConnection({}, languages);
QuickAnswersRequest request;
request.selected_text = "quick answers";
......@@ -111,7 +100,9 @@ TEST_F(IntentGeneratorTest, TranslationIntent) {
}
TEST_F(IntentGeneratorTest, TranslationIntentSameLanguage) {
UseFakeServiceConnection();
std::vector<TextLanguagePtr> languages;
languages.push_back(DefaultLanguage());
UseFakeServiceConnection({}, languages);
QuickAnswersRequest request;
request.selected_text = "quick answers";
......@@ -125,7 +116,9 @@ TEST_F(IntentGeneratorTest, TranslationIntentSameLanguage) {
}
TEST_F(IntentGeneratorTest, TranslationIntentTextLengthAboveThreshold) {
UseFakeServiceConnection();
std::vector<TextLanguagePtr> languages;
languages.push_back(DefaultLanguage());
UseFakeServiceConnection({}, languages);
QuickAnswersRequest request;
request.selected_text =
......@@ -148,7 +141,9 @@ TEST_F(IntentGeneratorTest, TranslationIntentNotEnabled) {
scoped_feature_list.InitWithFeatures(
{chromeos::features::kQuickAnswersTextAnnotator},
{chromeos::features::kQuickAnswersTranslation});
UseFakeServiceConnection();
std::vector<TextLanguagePtr> languages;
languages.push_back(DefaultLanguage());
UseFakeServiceConnection({}, languages);
QuickAnswersRequest request;
request.selected_text = "quick answers";
......@@ -162,7 +157,9 @@ TEST_F(IntentGeneratorTest, TranslationIntentNotEnabled) {
}
TEST_F(IntentGeneratorTest, TranslationIntentDeviceLanguageNotSet) {
UseFakeServiceConnection();
std::vector<TextLanguagePtr> languages;
languages.push_back(DefaultLanguage());
UseFakeServiceConnection({}, languages);
QuickAnswersRequest request;
request.selected_text = "quick answers";
......
......@@ -15,7 +15,9 @@ class NNetLanguageIdentifier;
namespace chromeos {
namespace quick_answers {
// Utility class for langugage detection.
// Utility class for language detection.
// TODO(b/168541952): Cleanup this class after the new language detection API
// becomes stable.
class LanguageDetector {
public:
LanguageDetector();
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment