Commit 5a849c90 authored by Michael Crouse's avatar Michael Crouse Committed by Chromium LUCI CQ

[LanguageDetection] Pass language model to TranslateAgent.

This change provides the file for the language detection model needed
in the TranslateAgent.

Bug: 1151422
Change-Id: I24870758adf8b00ee378ca7ccc55bd45edb164d1
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2599580Reviewed-by: default avatarJohn Abd-El-Malek <jam@chromium.org>
Reviewed-by: default avatarScott Little <sclittle@chromium.org>
Reviewed-by: default avatarDaniel Cheng <dcheng@chromium.org>
Reviewed-by: default avatarTrevor  Perrier <perrier@chromium.org>
Reviewed-by: default avatarSophie Chang <sophiechang@chromium.org>
Commit-Queue: Michael Crouse <mcrouse@chromium.org>
Cr-Commit-Position: refs/heads/master@{#841678}
parent 2bbfae08
...@@ -20,7 +20,9 @@ ...@@ -20,7 +20,9 @@
#include "chrome/browser/language/language_model_manager_factory.h" #include "chrome/browser/language/language_model_manager_factory.h"
#include "chrome/browser/language/url_language_histogram_factory.h" #include "chrome/browser/language/url_language_histogram_factory.h"
#include "chrome/browser/profiles/profile.h" #include "chrome/browser/profiles/profile.h"
#include "chrome/browser/profiles/profile_key.h"
#include "chrome/browser/translate/translate_accept_languages_factory.h" #include "chrome/browser/translate/translate_accept_languages_factory.h"
#include "chrome/browser/translate/translate_model_service_factory.h"
#include "chrome/browser/translate/translate_ranker_factory.h" #include "chrome/browser/translate/translate_ranker_factory.h"
#include "chrome/browser/translate/translate_service.h" #include "chrome/browser/translate/translate_service.h"
#include "chrome/browser/ui/translate/translate_bubble_factory.h" #include "chrome/browser/ui/translate/translate_bubble_factory.h"
...@@ -31,6 +33,7 @@ ...@@ -31,6 +33,7 @@
#include "components/language/core/browser/language_model_manager.h" #include "components/language/core/browser/language_model_manager.h"
#include "components/language/core/browser/pref_names.h" #include "components/language/core/browser/pref_names.h"
#include "components/prefs/pref_service.h" #include "components/prefs/pref_service.h"
#include "components/translate/content/browser/translate_model_service.h"
#include "components/translate/core/browser/language_state.h" #include "components/translate/core/browser/language_state.h"
#include "components/translate/core/browser/page_translated_details.h" #include "components/translate/core/browser/page_translated_details.h"
#include "components/translate/core/browser/translate_accept_languages.h" #include "components/translate/core/browser/translate_accept_languages.h"
...@@ -99,7 +102,10 @@ ChromeTranslateClient::ChromeTranslateClient(content::WebContents* web_contents) ...@@ -99,7 +102,10 @@ ChromeTranslateClient::ChromeTranslateClient(content::WebContents* web_contents)
translate_driver_ = std::make_unique<translate::ContentTranslateDriver>( translate_driver_ = std::make_unique<translate::ContentTranslateDriver>(
&web_contents->GetController(), &web_contents->GetController(),
UrlLanguageHistogramFactory::GetForBrowserContext( UrlLanguageHistogramFactory::GetForBrowserContext(
web_contents->GetBrowserContext())); web_contents->GetBrowserContext()),
TranslateModelServiceFactory::GetOrBuildForKey(
Profile::FromBrowserContext(web_contents->GetBrowserContext())
->GetProfileKey()));
} }
translate_manager_ = std::make_unique<translate::TranslateManager>( translate_manager_ = std::make_unique<translate::TranslateManager>(
this, this,
......
...@@ -25,8 +25,10 @@ ...@@ -25,8 +25,10 @@
#include "components/optimization_guide/core/optimization_guide_features.h" #include "components/optimization_guide/core/optimization_guide_features.h"
#include "components/optimization_guide/proto/models.pb.h" #include "components/optimization_guide/proto/models.pb.h"
#include "components/translate/core/common/translate_util.h" #include "components/translate/core/common/translate_util.h"
#include "components/translate/core/language_detection/language_detection_model.h"
#include "content/public/test/browser_test.h" #include "content/public/test/browser_test.h"
#include "content/public/test/browser_test_utils.h" #include "content/public/test/browser_test_utils.h"
#include "net/test/embedded_test_server/embedded_test_server.h"
#include "testing/gtest/include/gtest/gtest.h" #include "testing/gtest/include/gtest/gtest.h"
namespace { namespace {
...@@ -119,6 +121,15 @@ class TranslateModelServiceBrowserTest ...@@ -119,6 +121,15 @@ class TranslateModelServiceBrowserTest
{}); {});
} }
void SetUp() override {
origin_server_ = std::make_unique<net::EmbeddedTestServer>(
net::EmbeddedTestServer::TYPE_HTTPS);
origin_server_->ServeFilesFromSourceDirectory("chrome/test/data/translate");
ASSERT_TRUE(origin_server_->Start());
english_url_ = origin_server_->GetURL("/english_page.html");
InProcessBrowserTest::SetUp();
}
~TranslateModelServiceBrowserTest() override = default; ~TranslateModelServiceBrowserTest() override = default;
translate::TranslateModelService* translate_model_service() { translate::TranslateModelService* translate_model_service() {
...@@ -126,8 +137,12 @@ class TranslateModelServiceBrowserTest ...@@ -126,8 +137,12 @@ class TranslateModelServiceBrowserTest
browser()->profile()->GetProfileKey()); browser()->profile()->GetProfileKey());
} }
const GURL& english_url() const { return english_url_; }
private: private:
base::test::ScopedFeatureList scoped_feature_list_; base::test::ScopedFeatureList scoped_feature_list_;
GURL english_url_;
std::unique_ptr<net::EmbeddedTestServer> origin_server_;
}; };
base::FilePath model_file_path() { base::FilePath model_file_path() {
...@@ -222,10 +237,33 @@ IN_PROC_BROWSER_TEST_F(TranslateModelServiceBrowserTest, ...@@ -222,10 +237,33 @@ IN_PROC_BROWSER_TEST_F(TranslateModelServiceBrowserTest,
IN_PROC_BROWSER_TEST_F(TranslateModelServiceBrowserTest, IN_PROC_BROWSER_TEST_F(TranslateModelServiceBrowserTest,
LanguageDetectionModelCreated) { LanguageDetectionModelCreated) {
base::HistogramTester histogram_tester; base::HistogramTester histogram_tester;
ui_test_utils::NavigateToURL(browser(), GURL("https://test.com")); ui_test_utils::NavigateToURL(browser(), english_url());
RetryForHistogramUntilCountReached( RetryForHistogramUntilCountReached(
&histogram_tester, &histogram_tester,
"LanguageDetection.TFLiteModel.WasModelAvailableForDetection", 1); "LanguageDetection.TFLiteModel.WasModelAvailableForDetection", 1);
histogram_tester.ExpectUniqueSample( histogram_tester.ExpectUniqueSample(
"LanguageDetection.TFLiteModel.WasModelAvailableForDetection", false, 1); "LanguageDetection.TFLiteModel.WasModelAvailableForDetection", false, 1);
} }
IN_PROC_BROWSER_TEST_F(TranslateModelServiceBrowserTest,
LanguageDetectionModelAvailableForDetection) {
base::HistogramTester histogram_tester;
OptimizationGuideKeyedServiceFactory::GetForProfile(browser()->profile())
->OverrideTargetModelFileForTesting(
optimization_guide::proto::OPTIMIZATION_TARGET_LANGUAGE_DETECTION,
model_file_path());
RetryForHistogramUntilCountReached(
&histogram_tester,
"LanguageDetection.TFLiteModel.LanguageDetectionModelState", 1);
histogram_tester.ExpectUniqueSample(
"LanguageDetection.TFLiteModel.LanguageDetectionModelState",
translate::LanguageDetectionModelState::kModelFileValidAndMemoryMapped,
1);
ui_test_utils::NavigateToURL(browser(), english_url());
RetryForHistogramUntilCountReached(
&histogram_tester,
"LanguageDetection.TFLiteModel.WasModelAvailableForDetection", 1);
histogram_tester.ExpectBucketCount(
"LanguageDetection.TFLiteModel.WasModelAvailableForDetection", true, 1);
}
...@@ -44,6 +44,8 @@ class FakeContentTranslateDriver ...@@ -44,6 +44,8 @@ class FakeContentTranslateDriver
called_new_page_ = true; called_new_page_ = true;
page_level_translation_critiera_met_ = page_level_translation_critiera_met; page_level_translation_critiera_met_ = page_level_translation_critiera_met;
} }
void GetLanguageDetectionModel(
GetLanguageDetectionModelCallback callback) override {}
bool called_new_page_ = false; bool called_new_page_ = false;
bool page_level_translation_critiera_met_ = false; bool page_level_translation_critiera_met_ = false;
......
...@@ -51,6 +51,8 @@ class FakeContentTranslateDriver ...@@ -51,6 +51,8 @@ class FakeContentTranslateDriver
details_ = details; details_ = details;
page_level_translation_critiera_met_ = page_level_translation_critiera_met; page_level_translation_critiera_met_ = page_level_translation_critiera_met;
} }
void GetLanguageDetectionModel(
GetLanguageDetectionModelCallback callback) override {}
void ResetNewPageValues() { void ResetNewPageValues() {
called_new_page_ = false; called_new_page_ = false;
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include "components/google/core/common/google_util.h" #include "components/google/core/common/google_util.h"
#include "components/language/core/browser/url_language_histogram.h" #include "components/language/core/browser/url_language_histogram.h"
#include "components/translate/content/browser/content_record_page_language.h" #include "components/translate/content/browser/content_record_page_language.h"
#include "components/translate/content/browser/translate_model_service.h"
#include "components/translate/core/browser/translate_download_manager.h" #include "components/translate/core/browser/translate_download_manager.h"
#include "components/translate/core/browser/translate_manager.h" #include "components/translate/core/browser/translate_manager.h"
#include "components/translate/core/browser/translate_metrics_logger.h" #include "components/translate/core/browser/translate_metrics_logger.h"
...@@ -57,13 +58,15 @@ const base::Feature kAutoHrefTranslateAllOrigins{ ...@@ -57,13 +58,15 @@ const base::Feature kAutoHrefTranslateAllOrigins{
ContentTranslateDriver::ContentTranslateDriver( ContentTranslateDriver::ContentTranslateDriver(
content::NavigationController* nav_controller, content::NavigationController* nav_controller,
language::UrlLanguageHistogram* url_language_histogram) language::UrlLanguageHistogram* url_language_histogram,
translate::TranslateModelService* translate_model_service)
: content::WebContentsObserver(nav_controller->GetWebContents()), : content::WebContentsObserver(nav_controller->GetWebContents()),
navigation_controller_(nav_controller), navigation_controller_(nav_controller),
translate_manager_(nullptr), translate_manager_(nullptr),
max_reload_check_attempts_(kMaxTranslateLoadCheckAttempts), max_reload_check_attempts_(kMaxTranslateLoadCheckAttempts),
next_page_seq_no_(0), next_page_seq_no_(0),
language_histogram_(url_language_histogram) { language_histogram_(url_language_histogram),
translate_model_service_(translate_model_service) {
DCHECK(navigation_controller_); DCHECK(navigation_controller_);
} }
...@@ -339,4 +342,22 @@ void ContentTranslateDriver::OnPageTranslated( ...@@ -339,4 +342,22 @@ void ContentTranslateDriver::OnPageTranslated(
observer.OnPageTranslated(original_lang, translated_lang, error_type); observer.OnPageTranslated(original_lang, translated_lang, error_type);
} }
void ContentTranslateDriver::GetLanguageDetectionModel(
GetLanguageDetectionModelCallback callback) {
if (!translate_model_service_) {
std::move(callback).Run(base::File());
return;
}
translate_model_service_->GetLanguageDetectionModelFile(
base::BindOnce(&ContentTranslateDriver::OnLanguageDetectionModelFile,
weak_pointer_factory_.GetWeakPtr(), std::move(callback)));
}
void ContentTranslateDriver::OnLanguageDetectionModelFile(
GetLanguageDetectionModelCallback callback,
base::File model_file) {
DCHECK(model_file.IsValid());
std::move(callback).Run(std::move(model_file));
}
} // namespace translate } // namespace translate
...@@ -34,6 +34,7 @@ namespace translate { ...@@ -34,6 +34,7 @@ namespace translate {
struct LanguageDetectionDetails; struct LanguageDetectionDetails;
class TranslateManager; class TranslateManager;
class TranslateModelService;
// Content implementation of TranslateDriver. // Content implementation of TranslateDriver.
class ContentTranslateDriver : public TranslateDriver, class ContentTranslateDriver : public TranslateDriver,
...@@ -55,9 +56,9 @@ class ContentTranslateDriver : public TranslateDriver, ...@@ -55,9 +56,9 @@ class ContentTranslateDriver : public TranslateDriver,
} }
}; };
ContentTranslateDriver( ContentTranslateDriver(content::NavigationController* nav_controller,
content::NavigationController* nav_controller, language::UrlLanguageHistogram* url_language_histogram,
language::UrlLanguageHistogram* url_language_histogram); TranslateModelService* translate_model_service);
~ContentTranslateDriver() override; ~ContentTranslateDriver() override;
// Adds or removes observers. // Adds or removes observers.
...@@ -108,12 +109,17 @@ class ContentTranslateDriver : public TranslateDriver, ...@@ -108,12 +109,17 @@ class ContentTranslateDriver : public TranslateDriver,
// Adds a receiver in |receivers_| for the passed |receiver|. // Adds a receiver in |receivers_| for the passed |receiver|.
void AddReceiver( void AddReceiver(
mojo::PendingReceiver<translate::mojom::ContentTranslateDriver> receiver); mojo::PendingReceiver<translate::mojom::ContentTranslateDriver> receiver);
// Called when a page has been loaded and can be potentially translated. // Called when a page has been loaded and can be potentially translated.
void RegisterPage( void RegisterPage(
mojo::PendingRemote<translate::mojom::TranslateAgent> translate_agent, mojo::PendingRemote<translate::mojom::TranslateAgent> translate_agent,
const translate::LanguageDetectionDetails& details, const translate::LanguageDetectionDetails& details,
bool page_level_translation_critiera_met) override; bool page_level_translation_critiera_met) override;
// translate::mojom::ContentTranslateDriver implementation:
void GetLanguageDetectionModel(
GetLanguageDetectionModelCallback callback) override;
protected: protected:
const base::ObserverList<TranslationObserver, true>& translation_observers() const base::ObserverList<TranslationObserver, true>& translation_observers()
const { const {
...@@ -131,6 +137,11 @@ class ContentTranslateDriver : public TranslateDriver, ...@@ -131,6 +137,11 @@ class ContentTranslateDriver : public TranslateDriver,
private: private:
void OnPageAway(int page_seq_no); void OnPageAway(int page_seq_no);
// Runs the provided callback with the loaded model file
// to pass it to the connected translate agent.
void OnLanguageDetectionModelFile(GetLanguageDetectionModelCallback callback,
base::File model_file);
// The navigation controller of the tab we are associated with. // The navigation controller of the tab we are associated with.
content::NavigationController* navigation_controller_; content::NavigationController* navigation_controller_;
...@@ -162,6 +173,10 @@ class ContentTranslateDriver : public TranslateDriver, ...@@ -162,6 +173,10 @@ class ContentTranslateDriver : public TranslateDriver,
// page language is determined. // page language is determined.
base::TimeTicks finish_navigation_time_; base::TimeTicks finish_navigation_time_;
// The service that provides the model files needed for translate. Not owned
// but guaranteed to outlive |this|.
TranslateModelService* const translate_model_service_;
base::WeakPtrFactory<ContentTranslateDriver> weak_pointer_factory_{this}; base::WeakPtrFactory<ContentTranslateDriver> weak_pointer_factory_{this};
DISALLOW_COPY_AND_ASSIGN(ContentTranslateDriver); DISALLOW_COPY_AND_ASSIGN(ContentTranslateDriver);
......
...@@ -130,7 +130,9 @@ void PerFrameContentTranslateDriver::PendingRequestStats::Report() { ...@@ -130,7 +130,9 @@ void PerFrameContentTranslateDriver::PendingRequestStats::Report() {
PerFrameContentTranslateDriver::PerFrameContentTranslateDriver( PerFrameContentTranslateDriver::PerFrameContentTranslateDriver(
content::NavigationController* nav_controller, content::NavigationController* nav_controller,
language::UrlLanguageHistogram* url_language_histogram) language::UrlLanguageHistogram* url_language_histogram)
: ContentTranslateDriver(nav_controller, url_language_histogram) {} : ContentTranslateDriver(nav_controller,
url_language_histogram,
/*translate_model_service=*/nullptr) {}
PerFrameContentTranslateDriver::~PerFrameContentTranslateDriver() = default; PerFrameContentTranslateDriver::~PerFrameContentTranslateDriver() = default;
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
module translate.mojom; module translate.mojom;
import "mojo/public/mojom/base/file.mojom";
import "mojo/public/mojom/base/time.mojom"; import "mojo/public/mojom/base/time.mojom";
import "mojo/public/mojom/base/string16.mojom"; import "mojo/public/mojom/base/string16.mojom";
import "url/mojom/url.mojom"; import "url/mojom/url.mojom";
...@@ -78,4 +79,10 @@ interface ContentTranslateDriver { ...@@ -78,4 +79,10 @@ interface ContentTranslateDriver {
// and the language for it has been determined. // and the language for it has been determined.
RegisterPage(pending_remote<TranslateAgent> translate_agent, RegisterPage(pending_remote<TranslateAgent> translate_agent,
LanguageDetectionDetails details, bool translation_critiera_met); LanguageDetectionDetails details, bool translation_critiera_met);
// Request that the language detection model being loaded and returned
// for use by the TranslateAgent.
GetLanguageDetectionModel()
=> (mojo_base.mojom.File? model_file);
}; };
...@@ -35,6 +35,7 @@ ...@@ -35,6 +35,7 @@
#include "third_party/blink/public/web/web_local_frame.h" #include "third_party/blink/public/web/web_local_frame.h"
#include "third_party/blink/public/web/web_script_source.h" #include "third_party/blink/public/web/web_script_source.h"
#include "url/gurl.h" #include "url/gurl.h"
#include "url/url_constants.h"
#include "v8/include/v8.h" #include "v8/include/v8.h"
using blink::WebDocument; using blink::WebDocument;
...@@ -82,6 +83,19 @@ TranslateAgent::TranslateAgent(content::RenderFrame* render_frame, ...@@ -82,6 +83,19 @@ TranslateAgent::TranslateAgent(content::RenderFrame* render_frame,
extension_scheme_(extension_scheme) { extension_scheme_(extension_scheme) {
translate_task_runner_ = this->render_frame()->GetTaskRunner( translate_task_runner_ = this->render_frame()->GetTaskRunner(
blink::TaskType::kInternalTranslation); blink::TaskType::kInternalTranslation);
if (translate::IsTFLiteLanguageDetectionEnabled()) {
translate::LanguageDetectionModel& language_detection_model =
GetLanguageDetectionModel();
if (!language_detection_model.IsAvailable()) {
// TODO(crbug.com/1160948): Consider tracking if another agent associated
// with the same LanguageDetectionModel has already requested a model be
// provided by the translate host.
GetTranslateHandler()->GetLanguageDetectionModel(
base::BindOnce(&TranslateAgent::UpdateLanguageDetectionModel,
weak_pointer_factory_.GetWeakPtr()));
}
}
} }
TranslateAgent::~TranslateAgent() {} TranslateAgent::~TranslateAgent() {}
...@@ -99,7 +113,7 @@ void TranslateAgent::PageCaptured(const base::string16& contents) { ...@@ -99,7 +113,7 @@ void TranslateAgent::PageCaptured(const base::string16& contents) {
// original intent of http-equiv to be an equivalent) with the former // original intent of http-equiv to be an equivalent) with the former
// being the language of the document and the latter being the // being the language of the document and the latter being the
// language of the intended audience (a distinction really only // language of the intended audience (a distinction really only
// relevant for things like langauge textbooks). This distinction // relevant for things like language textbooks). This distinction
// shouldn't affect translation. // shouldn't affect translation.
WebLocalFrame* main_frame = render_frame()->GetWebFrame(); WebLocalFrame* main_frame = render_frame()->GetWebFrame();
if (!main_frame) if (!main_frame)
...@@ -115,6 +129,12 @@ void TranslateAgent::PageCaptured(const base::string16& contents) { ...@@ -115,6 +129,12 @@ void TranslateAgent::PageCaptured(const base::string16& contents) {
std::string language; std::string language;
if (translate::IsTFLiteLanguageDetectionEnabled()) { if (translate::IsTFLiteLanguageDetectionEnabled()) {
if (!document.Url().ProtocolIs(url::kHttpsScheme) &&
!document.Url().ProtocolIs(url::kHttpScheme)) {
// TFLite-based language detection only supports HTTP/HTTPS pages.
// Others should be ignored, for example the New Tab Page.
return;
}
translate::LanguageDetectionModel& language_detection_model = translate::LanguageDetectionModel& language_detection_model =
GetLanguageDetectionModel(); GetLanguageDetectionModel();
bool is_available = language_detection_model.IsAvailable(); bool is_available = language_detection_model.IsAvailable();
...@@ -502,4 +522,10 @@ std::string TranslateAgent::BuildTranslationScript( ...@@ -502,4 +522,10 @@ std::string TranslateAgent::BuildTranslationScript(
base::GetQuotedJSONString(target_lang) + ")"; base::GetQuotedJSONString(target_lang) + ")";
} }
void TranslateAgent::UpdateLanguageDetectionModel(base::File model_file) {
translate::LanguageDetectionModel& language_detection_model =
GetLanguageDetectionModel();
language_detection_model.UpdateWithFile(std::move(model_file));
}
} // namespace translate } // namespace translate
...@@ -154,6 +154,10 @@ class TranslateAgent : public content::RenderFrameObserver, ...@@ -154,6 +154,10 @@ class TranslateAgent : public content::RenderFrameObserver,
// if the page is being closed. // if the page is being closed.
blink::WebLocalFrame* GetMainFrame(); blink::WebLocalFrame* GetMainFrame();
// Called by the translate host when a new language detection model file
// has been loaded and is available.
void UpdateLanguageDetectionModel(base::File model_file);
// The states associated with the current translation. // The states associated with the current translation.
TranslateFrameCallback translate_callback_pending_; TranslateFrameCallback translate_callback_pending_;
std::string source_lang_; std::string source_lang_;
...@@ -184,6 +188,9 @@ class TranslateAgent : public content::RenderFrameObserver, ...@@ -184,6 +188,9 @@ class TranslateAgent : public content::RenderFrameObserver,
// Method factory used to make calls to TranslatePageImpl. // Method factory used to make calls to TranslatePageImpl.
base::WeakPtrFactory<TranslateAgent> weak_method_factory_{this}; base::WeakPtrFactory<TranslateAgent> weak_method_factory_{this};
// Weak pointer factory used to provide references to the translate host.
base::WeakPtrFactory<TranslateAgent> weak_pointer_factory_{this};
DISALLOW_COPY_AND_ASSIGN(TranslateAgent); DISALLOW_COPY_AND_ASSIGN(TranslateAgent);
}; };
......
...@@ -53,7 +53,8 @@ std::unique_ptr<translate::TranslatePrefs> CreateTranslatePrefs( ...@@ -53,7 +53,8 @@ std::unique_ptr<translate::TranslatePrefs> CreateTranslatePrefs(
TranslateClientImpl::TranslateClientImpl(content::WebContents* web_contents) TranslateClientImpl::TranslateClientImpl(content::WebContents* web_contents)
: content::WebContentsObserver(web_contents), : content::WebContentsObserver(web_contents),
translate_driver_(&web_contents->GetController(), translate_driver_(&web_contents->GetController(),
/*url_language_histogram=*/nullptr), /*url_language_histogram=*/nullptr,
/*translate_model_service=*/nullptr),
translate_manager_(new translate::TranslateManager( translate_manager_(new translate::TranslateManager(
this, this,
TranslateRankerFactory::GetForBrowserContext( TranslateRankerFactory::GetForBrowserContext(
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment