Commit 3a724a44 authored by Michael Crouse's avatar Michael Crouse Committed by Chromium LUCI CQ

[LanguageDetection] Add model loading and return models to service.

This change loads the model from the file provided by the Opt Guide and
returns it via callback to any requestor of the model.

A future change will provide the model to the content translate driver
and the translate agent via mojo.

Bug: 1151406
Change-Id: If09871f3ad04a674d964a467586cfb0f24fd4708
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2593589Reviewed-by: default avatarTrevor  Perrier <perrier@chromium.org>
Reviewed-by: default avatarScott Little <sclittle@chromium.org>
Reviewed-by: default avatarSophie Chang <sophiechang@chromium.org>
Commit-Queue: Michael Crouse <mcrouse@chromium.org>
Cr-Commit-Position: refs/heads/master@{#837751}
parent 51baf69c
...@@ -2,10 +2,19 @@ ...@@ -2,10 +2,19 @@
// Use of this source code is governed by a BSD-style license that can be // Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file. // found in the LICENSE file.
#include "base/base_paths.h"
#include "base/bind.h"
#include "base/files/file.h"
#include "base/files/file_path.h"
#include "base/files/file_util.h"
#include "base/path_service.h"
#include "base/run_loop.h" #include "base/run_loop.h"
#include "base/task/thread_pool/thread_pool_instance.h" #include "base/task/thread_pool/thread_pool_instance.h"
#include "base/test/metrics/histogram_tester.h" #include "base/test/metrics/histogram_tester.h"
#include "base/test/scoped_feature_list.h" #include "base/test/scoped_feature_list.h"
#include "base/threading/thread_restrictions.h"
#include "chrome/browser/optimization_guide/optimization_guide_keyed_service.h"
#include "chrome/browser/optimization_guide/optimization_guide_keyed_service_factory.h"
#include "chrome/browser/profiles/profile.h" #include "chrome/browser/profiles/profile.h"
#include "chrome/browser/profiles/profile_key.h" #include "chrome/browser/profiles/profile_key.h"
#include "chrome/browser/translate/translate_model_service_factory.h" #include "chrome/browser/translate/translate_model_service_factory.h"
...@@ -14,6 +23,7 @@ ...@@ -14,6 +23,7 @@
#include "chrome/test/base/ui_test_utils.h" #include "chrome/test/base/ui_test_utils.h"
#include "components/metrics/content/subprocess_metrics_provider.h" #include "components/metrics/content/subprocess_metrics_provider.h"
#include "components/optimization_guide/optimization_guide_features.h" #include "components/optimization_guide/optimization_guide_features.h"
#include "components/optimization_guide/proto/models.pb.h"
#include "components/translate/core/common/translate_util.h" #include "components/translate/core/common/translate_util.h"
#include "content/public/test/browser_test.h" #include "content/public/test/browser_test.h"
#include "content/public/test/browser_test_utils.h" #include "content/public/test/browser_test_utils.h"
...@@ -111,14 +121,25 @@ class TranslateModelServiceBrowserTest ...@@ -111,14 +121,25 @@ class TranslateModelServiceBrowserTest
~TranslateModelServiceBrowserTest() override = default; ~TranslateModelServiceBrowserTest() override = default;
translate::TranslateModelService* translate_model_service() {
return TranslateModelServiceFactory::GetOrBuildForKey(
browser()->profile()->GetProfileKey());
}
private: private:
base::test::ScopedFeatureList scoped_feature_list_; base::test::ScopedFeatureList scoped_feature_list_;
}; };
base::FilePath model_file_path() {
base::FilePath model_file_path;
EXPECT_TRUE(base::PathService::Get(base::DIR_SOURCE_ROOT, &model_file_path));
return model_file_path.AppendASCII(
"chrome/test/data/optimization_guide/unsignedmodel.crx3");
}
IN_PROC_BROWSER_TEST_F(TranslateModelServiceBrowserTest, IN_PROC_BROWSER_TEST_F(TranslateModelServiceBrowserTest,
TranslateModelServiceEnabled) { TranslateModelServiceEnabled) {
EXPECT_TRUE(TranslateModelServiceFactory::GetOrBuildForKey( EXPECT_TRUE(translate_model_service());
browser()->profile()->GetProfileKey()));
} }
IN_PROC_BROWSER_TEST_F(TranslateModelServiceBrowserTest, IN_PROC_BROWSER_TEST_F(TranslateModelServiceBrowserTest,
...@@ -127,6 +148,77 @@ IN_PROC_BROWSER_TEST_F(TranslateModelServiceBrowserTest, ...@@ -127,6 +148,77 @@ IN_PROC_BROWSER_TEST_F(TranslateModelServiceBrowserTest,
browser()->profile()->GetPrimaryOTRProfile()->GetProfileKey())); browser()->profile()->GetPrimaryOTRProfile()->GetProfileKey()));
} }
IN_PROC_BROWSER_TEST_F(TranslateModelServiceBrowserTest,
LanguageDetectionModelReadyOnRequest) {
base::ScopedAllowBlockingForTesting allow_io_for_test_setup;
base::HistogramTester histogram_tester;
ASSERT_TRUE(translate_model_service());
OptimizationGuideKeyedServiceFactory::GetForProfile(browser()->profile())
->OverrideTargetModelFileForTesting(
optimization_guide::proto::OPTIMIZATION_TARGET_LANGUAGE_DETECTION,
model_file_path());
RetryForHistogramUntilCountReached(
&histogram_tester,
"TranslateModelService.LanguageDetectionModel.WasLoaded", 1);
histogram_tester.ExpectUniqueSample(
"TranslateModelService.LanguageDetectionModel.WasLoaded", true, 1);
std::unique_ptr<base::RunLoop> run_loop = std::make_unique<base::RunLoop>();
translate_model_service()->GetLanguageDetectionModelFile(base::BindOnce(
[](base::RunLoop* run_loop, base::File model_file) {
EXPECT_TRUE(model_file.IsValid());
run_loop->Quit();
},
run_loop.get()));
run_loop->Run();
}
IN_PROC_BROWSER_TEST_F(TranslateModelServiceBrowserTest,
LanguageDetectionModelLoadedAfterRequest) {
base::ScopedAllowBlockingForTesting allow_io_for_test_setup;
base::HistogramTester histogram_tester;
ASSERT_TRUE(translate_model_service());
std::unique_ptr<base::RunLoop> run_loop = std::make_unique<base::RunLoop>();
translate_model_service()->GetLanguageDetectionModelFile(base::BindOnce(
[](base::RunLoop* run_loop, base::File model_file) {
EXPECT_TRUE(model_file.IsValid());
run_loop->Quit();
},
run_loop.get()));
OptimizationGuideKeyedServiceFactory::GetForProfile(browser()->profile())
->OverrideTargetModelFileForTesting(
optimization_guide::proto::OPTIMIZATION_TARGET_LANGUAGE_DETECTION,
model_file_path());
RetryForHistogramUntilCountReached(
&histogram_tester,
"TranslateModelService.LanguageDetectionModel.WasLoaded", 1);
histogram_tester.ExpectUniqueSample(
"TranslateModelService.LanguageDetectionModel.WasLoaded", true, 1);
run_loop->Run();
}
IN_PROC_BROWSER_TEST_F(TranslateModelServiceBrowserTest,
InvalidModelWhenLoading) {
base::ScopedAllowBlockingForTesting allow_io_for_test_setup;
base::HistogramTester histogram_tester;
ASSERT_TRUE(translate_model_service());
OptimizationGuideKeyedServiceFactory::GetForProfile(browser()->profile())
->OverrideTargetModelFileForTesting(
optimization_guide::proto::OPTIMIZATION_TARGET_LANGUAGE_DETECTION,
base::FilePath());
RetryForHistogramUntilCountReached(
&histogram_tester,
"TranslateModelService.LanguageDetectionModel.WasValid", 1);
histogram_tester.ExpectUniqueSample(
"TranslateModelService.LanguageDetectionModel.WasValid", false, 1);
}
IN_PROC_BROWSER_TEST_F(TranslateModelServiceBrowserTest, IN_PROC_BROWSER_TEST_F(TranslateModelServiceBrowserTest,
LanguageDetectionModelCreated) { LanguageDetectionModelCreated) {
base::HistogramTester histogram_tester; base::HistogramTester histogram_tester;
......
...@@ -4,6 +4,9 @@ ...@@ -4,6 +4,9 @@
#include "chrome/browser/translate/translate_model_service_factory.h" #include "chrome/browser/translate/translate_model_service_factory.h"
#include "base/memory/scoped_refptr.h"
#include "base/sequenced_task_runner.h"
#include "base/task/task_traits.h"
#include "chrome/browser/optimization_guide/optimization_guide_keyed_service.h" #include "chrome/browser/optimization_guide/optimization_guide_keyed_service.h"
#include "chrome/browser/optimization_guide/optimization_guide_keyed_service_factory.h" #include "chrome/browser/optimization_guide/optimization_guide_keyed_service_factory.h"
#include "chrome/browser/profiles/profile.h" #include "chrome/browser/profiles/profile.h"
...@@ -43,8 +46,13 @@ TranslateModelServiceFactory::BuildServiceInstanceFor( ...@@ -43,8 +46,13 @@ TranslateModelServiceFactory::BuildServiceInstanceFor(
auto* opt_guide = OptimizationGuideKeyedServiceFactory::GetForProfile( auto* opt_guide = OptimizationGuideKeyedServiceFactory::GetForProfile(
ProfileManager::GetProfileFromProfileKey( ProfileManager::GetProfileFromProfileKey(
ProfileKey::FromSimpleFactoryKey(key))); ProfileKey::FromSimpleFactoryKey(key)));
if (opt_guide) if (opt_guide) {
return std::make_unique<translate::TranslateModelService>(opt_guide); scoped_refptr<base::SequencedTaskRunner> background_task_runner =
base::ThreadPool::CreateSequencedTaskRunner(
{base::MayBlock(), base::TaskPriority::BEST_EFFORT});
return std::make_unique<translate::TranslateModelService>(
opt_guide, background_task_runner);
}
return nullptr; return nullptr;
} }
......
...@@ -73,8 +73,6 @@ class OptimizationGuideDecider { ...@@ -73,8 +73,6 @@ class OptimizationGuideDecider {
// //
// It is assumed that any model retrieved this way will be passed to the // It is assumed that any model retrieved this way will be passed to the
// Machine Learning Service for inference. // Machine Learning Service for inference.
//
// Still being implemented - DO NOT USE YET.
virtual void AddObserverForOptimizationTargetModel( virtual void AddObserverForOptimizationTargetModel(
proto::OptimizationTarget optimization_target, proto::OptimizationTarget optimization_target,
OptimizationTargetModelObserver* observer) = 0; OptimizationTargetModelObserver* observer) = 0;
...@@ -84,8 +82,6 @@ class OptimizationGuideDecider { ...@@ -84,8 +82,6 @@ class OptimizationGuideDecider {
// If |observer| is registered for multiple targets, |observer| must be // If |observer| is registered for multiple targets, |observer| must be
// removed for all targets that it is added for in order for it to be fully // removed for all targets that it is added for in order for it to be fully
// removed from receiving any calls. // removed from receiving any calls.
//
// Still being implemented - DO NOT USE YET.
virtual void RemoveObserverForOptimizationTargetModel( virtual void RemoveObserverForOptimizationTargetModel(
proto::OptimizationTarget optimization_target, proto::OptimizationTarget optimization_target,
OptimizationTargetModelObserver* observer) = 0; OptimizationTargetModelObserver* observer) = 0;
......
...@@ -4,16 +4,36 @@ ...@@ -4,16 +4,36 @@
#include "components/translate/content/browser/translate_model_service.h" #include "components/translate/content/browser/translate_model_service.h"
#include "base/bind.h"
#include "base/files/file.h" #include "base/files/file.h"
#include "base/files/file_path.h" #include "base/files/file_path.h"
#include "base/files/file_util.h"
#include "base/metrics/histogram_macros_local.h"
#include "base/task/post_task.h"
#include "components/optimization_guide/optimization_guide_decider.h" #include "components/optimization_guide/optimization_guide_decider.h"
#include "components/optimization_guide/proto/models.pb.h" #include "components/optimization_guide/proto/models.pb.h"
#include "content/public/browser/browser_thread.h"
#include "mojo/public/cpp/bindings/pending_remote.h"
namespace {
// Load the model file at the provided file path.
base::File LoadModelFile(const base::FilePath& model_file_path) {
if (!base::PathExists(model_file_path))
return base::File();
return base::File(model_file_path,
base::File::FLAG_OPEN | base::File::FLAG_READ);
}
} // namespace
namespace translate { namespace translate {
TranslateModelService::TranslateModelService( TranslateModelService::TranslateModelService(
optimization_guide::OptimizationGuideDecider* opt_guide) optimization_guide::OptimizationGuideDecider* opt_guide,
: opt_guide_(opt_guide) { const scoped_refptr<base::SequencedTaskRunner>& background_task_runner)
: opt_guide_(opt_guide), background_task_runner_(background_task_runner) {
opt_guide_->AddObserverForOptimizationTargetModel( opt_guide_->AddObserverForOptimizationTargetModel(
optimization_guide::proto::OPTIMIZATION_TARGET_LANGUAGE_DETECTION, this); optimization_guide::proto::OPTIMIZATION_TARGET_LANGUAGE_DETECTION, this);
} }
...@@ -29,19 +49,45 @@ void TranslateModelService::Shutdown() { ...@@ -29,19 +49,45 @@ void TranslateModelService::Shutdown() {
void TranslateModelService::OnModelFileUpdated( void TranslateModelService::OnModelFileUpdated(
optimization_guide::proto::OptimizationTarget optimization_target, optimization_guide::proto::OptimizationTarget optimization_target,
const base::FilePath& file_path) { const base::FilePath& file_path) {
DCHECK_CURRENTLY_ON(content::BrowserThread::UI);
if (optimization_target != if (optimization_target !=
optimization_guide::proto::OPTIMIZATION_TARGET_LANGUAGE_DETECTION) { optimization_guide::proto::OPTIMIZATION_TARGET_LANGUAGE_DETECTION) {
return; return;
} }
// TODO(crbug.com/1151406): Implement loading the model on a background thread background_task_runner_->PostTaskAndReplyWithResult(
// and return it for use by translate. FROM_HERE, base::BindOnce(&LoadModelFile, file_path),
base::BindOnce(&TranslateModelService::OnModelFileLoaded,
base::Unretained(this)));
} }
base::Optional<base::File> void TranslateModelService::OnModelFileLoaded(base::File model_file) {
TranslateModelService::GetLanguageDetectionModelFile() { DCHECK_CURRENTLY_ON(content::BrowserThread::UI);
// TODO(crbug.com/1151406): Implement loading the model on a background thread if (!model_file.IsValid()) {
// and return it for use by translate. // TODO(crbug.com/1157661): add histogram to log the model failed to load.
return base::nullopt; LOCAL_HISTOGRAM_BOOLEAN(
"TranslateModelService.LanguageDetectionModel.WasValid", false);
return;
}
language_detection_model_file_ = std::move(model_file);
LOCAL_HISTOGRAM_BOOLEAN(
"TranslateModelService.LanguageDetectionModel.WasLoaded", true);
for (auto& pending_request : pending_model_requests_) {
std::move(pending_request).Run(language_detection_model_file_->Duplicate());
}
}
void TranslateModelService::GetLanguageDetectionModelFile(
GetModelCallback callback) {
if (!language_detection_model_file_) {
// TODO(crbug.com/1157661): add histogram record the number of callbacks
// held.
pending_model_requests_.emplace_back(std::move(callback));
return;
}
// The model must be valid at this point.
DCHECK(language_detection_model_file_->IsValid());
std::move(callback).Run(language_detection_model_file_->Duplicate());
} }
} // namespace translate } // namespace translate
...@@ -5,14 +5,16 @@ ...@@ -5,14 +5,16 @@
#ifndef COMPONENTS_TRANSLATE_CONTENT_BROWSER_TRANSLATE_MODEL_SERVICE_H_ #ifndef COMPONENTS_TRANSLATE_CONTENT_BROWSER_TRANSLATE_MODEL_SERVICE_H_
#define COMPONENTS_TRANSLATE_CONTENT_BROWSER_TRANSLATE_MODEL_SERVICE_H_ #define COMPONENTS_TRANSLATE_CONTENT_BROWSER_TRANSLATE_MODEL_SERVICE_H_
#include <memory>
#include <vector>
#include "base/callback.h"
#include "base/files/file.h"
#include "base/files/file_path.h"
#include "base/optional.h" #include "base/optional.h"
#include "components/keyed_service/core/keyed_service.h" #include "components/keyed_service/core/keyed_service.h"
#include "components/optimization_guide/optimization_target_model_observer.h" #include "components/optimization_guide/optimization_target_model_observer.h"
namespace base {
class File;
class FilePath;
} // namespace base
namespace optimization_guide { namespace optimization_guide {
class OptimizationGuideDecider; class OptimizationGuideDecider;
...@@ -27,8 +29,11 @@ class TranslateModelService ...@@ -27,8 +29,11 @@ class TranslateModelService
: public KeyedService, : public KeyedService,
public optimization_guide::OptimizationTargetModelObserver { public optimization_guide::OptimizationTargetModelObserver {
public: public:
explicit TranslateModelService( using GetModelCallback = base::OnceCallback<void(base::File)>;
optimization_guide::OptimizationGuideDecider* opt_guide);
TranslateModelService(
optimization_guide::OptimizationGuideDecider* opt_guide,
const scoped_refptr<base::SequencedTaskRunner>& background_task_runner);
~TranslateModelService() override; ~TranslateModelService() override;
// KeyedService implementation: // KeyedService implementation:
...@@ -39,16 +44,28 @@ class TranslateModelService ...@@ -39,16 +44,28 @@ class TranslateModelService
optimization_guide::proto::OptimizationTarget optimization_target, optimization_guide::proto::OptimizationTarget optimization_target,
const base::FilePath& file_path) override; const base::FilePath& file_path) override;
// Returns a loaded file containing the TFLite model capable of detecting the // Invokes |callback| with a language detection model file when it is
// language of a web page's text. // available.
base::Optional<base::File> GetLanguageDetectionModelFile(); void GetLanguageDetectionModelFile(GetModelCallback callback);
private: private:
// Optimization Guide Service that provides model files for this void OnModelFileLoaded(base::File model_file);
// service. Optimization Guide Service is a
// BrowserContextKeyedServiceFactory and should not // Optimization Guide Service that provides model files for this service.
// be used after ShutDown. // Optimization Guide Service is a BrowserContextKeyedServiceFactory and
// should not be used after Shutdown.
optimization_guide::OptimizationGuideDecider* opt_guide_; optimization_guide::OptimizationGuideDecider* opt_guide_;
// The file that contains the language detection model. Available when the
// file path has been provided by the Optimization Guide and has been
// successfully loaded.
base::Optional<base::File> language_detection_model_file_;
// The set of callbacks associated with requests for the language detection
// model.
std::vector<GetModelCallback> pending_model_requests_;
scoped_refptr<base::SequencedTaskRunner> background_task_runner_;
}; };
} // namespace translate } // namespace translate
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment