Commit 3a724a44 authored by Michael Crouse's avatar Michael Crouse Committed by Chromium LUCI CQ

[LanguageDetection] Add model loading and return models to service.

This change loads the model from the file provided by the Opt Guide and
returns it via callback to any requestor of the model.

A future change will provide the model to the content translate driver
and the translate agent via mojo.

Bug: 1151406
Change-Id: If09871f3ad04a674d964a467586cfb0f24fd4708
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2593589Reviewed-by: default avatarTrevor  Perrier <perrier@chromium.org>
Reviewed-by: default avatarScott Little <sclittle@chromium.org>
Reviewed-by: default avatarSophie Chang <sophiechang@chromium.org>
Commit-Queue: Michael Crouse <mcrouse@chromium.org>
Cr-Commit-Position: refs/heads/master@{#837751}
parent 51baf69c
......@@ -2,10 +2,19 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "base/base_paths.h"
#include "base/bind.h"
#include "base/files/file.h"
#include "base/files/file_path.h"
#include "base/files/file_util.h"
#include "base/path_service.h"
#include "base/run_loop.h"
#include "base/task/thread_pool/thread_pool_instance.h"
#include "base/test/metrics/histogram_tester.h"
#include "base/test/scoped_feature_list.h"
#include "base/threading/thread_restrictions.h"
#include "chrome/browser/optimization_guide/optimization_guide_keyed_service.h"
#include "chrome/browser/optimization_guide/optimization_guide_keyed_service_factory.h"
#include "chrome/browser/profiles/profile.h"
#include "chrome/browser/profiles/profile_key.h"
#include "chrome/browser/translate/translate_model_service_factory.h"
......@@ -14,6 +23,7 @@
#include "chrome/test/base/ui_test_utils.h"
#include "components/metrics/content/subprocess_metrics_provider.h"
#include "components/optimization_guide/optimization_guide_features.h"
#include "components/optimization_guide/proto/models.pb.h"
#include "components/translate/core/common/translate_util.h"
#include "content/public/test/browser_test.h"
#include "content/public/test/browser_test_utils.h"
......@@ -111,14 +121,25 @@ class TranslateModelServiceBrowserTest
~TranslateModelServiceBrowserTest() override = default;
translate::TranslateModelService* translate_model_service() {
return TranslateModelServiceFactory::GetOrBuildForKey(
browser()->profile()->GetProfileKey());
}
private:
base::test::ScopedFeatureList scoped_feature_list_;
};
base::FilePath model_file_path() {
base::FilePath model_file_path;
EXPECT_TRUE(base::PathService::Get(base::DIR_SOURCE_ROOT, &model_file_path));
return model_file_path.AppendASCII(
"chrome/test/data/optimization_guide/unsignedmodel.crx3");
}
IN_PROC_BROWSER_TEST_F(TranslateModelServiceBrowserTest,
TranslateModelServiceEnabled) {
EXPECT_TRUE(TranslateModelServiceFactory::GetOrBuildForKey(
browser()->profile()->GetProfileKey()));
EXPECT_TRUE(translate_model_service());
}
IN_PROC_BROWSER_TEST_F(TranslateModelServiceBrowserTest,
......@@ -127,6 +148,77 @@ IN_PROC_BROWSER_TEST_F(TranslateModelServiceBrowserTest,
browser()->profile()->GetPrimaryOTRProfile()->GetProfileKey()));
}
IN_PROC_BROWSER_TEST_F(TranslateModelServiceBrowserTest,
LanguageDetectionModelReadyOnRequest) {
base::ScopedAllowBlockingForTesting allow_io_for_test_setup;
base::HistogramTester histogram_tester;
ASSERT_TRUE(translate_model_service());
OptimizationGuideKeyedServiceFactory::GetForProfile(browser()->profile())
->OverrideTargetModelFileForTesting(
optimization_guide::proto::OPTIMIZATION_TARGET_LANGUAGE_DETECTION,
model_file_path());
RetryForHistogramUntilCountReached(
&histogram_tester,
"TranslateModelService.LanguageDetectionModel.WasLoaded", 1);
histogram_tester.ExpectUniqueSample(
"TranslateModelService.LanguageDetectionModel.WasLoaded", true, 1);
std::unique_ptr<base::RunLoop> run_loop = std::make_unique<base::RunLoop>();
translate_model_service()->GetLanguageDetectionModelFile(base::BindOnce(
[](base::RunLoop* run_loop, base::File model_file) {
EXPECT_TRUE(model_file.IsValid());
run_loop->Quit();
},
run_loop.get()));
run_loop->Run();
}
IN_PROC_BROWSER_TEST_F(TranslateModelServiceBrowserTest,
LanguageDetectionModelLoadedAfterRequest) {
base::ScopedAllowBlockingForTesting allow_io_for_test_setup;
base::HistogramTester histogram_tester;
ASSERT_TRUE(translate_model_service());
std::unique_ptr<base::RunLoop> run_loop = std::make_unique<base::RunLoop>();
translate_model_service()->GetLanguageDetectionModelFile(base::BindOnce(
[](base::RunLoop* run_loop, base::File model_file) {
EXPECT_TRUE(model_file.IsValid());
run_loop->Quit();
},
run_loop.get()));
OptimizationGuideKeyedServiceFactory::GetForProfile(browser()->profile())
->OverrideTargetModelFileForTesting(
optimization_guide::proto::OPTIMIZATION_TARGET_LANGUAGE_DETECTION,
model_file_path());
RetryForHistogramUntilCountReached(
&histogram_tester,
"TranslateModelService.LanguageDetectionModel.WasLoaded", 1);
histogram_tester.ExpectUniqueSample(
"TranslateModelService.LanguageDetectionModel.WasLoaded", true, 1);
run_loop->Run();
}
IN_PROC_BROWSER_TEST_F(TranslateModelServiceBrowserTest,
InvalidModelWhenLoading) {
base::ScopedAllowBlockingForTesting allow_io_for_test_setup;
base::HistogramTester histogram_tester;
ASSERT_TRUE(translate_model_service());
OptimizationGuideKeyedServiceFactory::GetForProfile(browser()->profile())
->OverrideTargetModelFileForTesting(
optimization_guide::proto::OPTIMIZATION_TARGET_LANGUAGE_DETECTION,
base::FilePath());
RetryForHistogramUntilCountReached(
&histogram_tester,
"TranslateModelService.LanguageDetectionModel.WasValid", 1);
histogram_tester.ExpectUniqueSample(
"TranslateModelService.LanguageDetectionModel.WasValid", false, 1);
}
IN_PROC_BROWSER_TEST_F(TranslateModelServiceBrowserTest,
LanguageDetectionModelCreated) {
base::HistogramTester histogram_tester;
......
......@@ -4,6 +4,9 @@
#include "chrome/browser/translate/translate_model_service_factory.h"
#include "base/memory/scoped_refptr.h"
#include "base/sequenced_task_runner.h"
#include "base/task/task_traits.h"
#include "chrome/browser/optimization_guide/optimization_guide_keyed_service.h"
#include "chrome/browser/optimization_guide/optimization_guide_keyed_service_factory.h"
#include "chrome/browser/profiles/profile.h"
......@@ -43,8 +46,13 @@ TranslateModelServiceFactory::BuildServiceInstanceFor(
auto* opt_guide = OptimizationGuideKeyedServiceFactory::GetForProfile(
ProfileManager::GetProfileFromProfileKey(
ProfileKey::FromSimpleFactoryKey(key)));
if (opt_guide)
return std::make_unique<translate::TranslateModelService>(opt_guide);
if (opt_guide) {
scoped_refptr<base::SequencedTaskRunner> background_task_runner =
base::ThreadPool::CreateSequencedTaskRunner(
{base::MayBlock(), base::TaskPriority::BEST_EFFORT});
return std::make_unique<translate::TranslateModelService>(
opt_guide, background_task_runner);
}
return nullptr;
}
......
......@@ -73,8 +73,6 @@ class OptimizationGuideDecider {
//
// It is assumed that any model retrieved this way will be passed to the
// Machine Learning Service for inference.
//
// Still being implemented - DO NOT USE YET.
virtual void AddObserverForOptimizationTargetModel(
proto::OptimizationTarget optimization_target,
OptimizationTargetModelObserver* observer) = 0;
......@@ -84,8 +82,6 @@ class OptimizationGuideDecider {
// If |observer| is registered for multiple targets, |observer| must be
// removed for all targets that it is added for in order for it to be fully
// removed from receiving any calls.
//
// Still being implemented - DO NOT USE YET.
virtual void RemoveObserverForOptimizationTargetModel(
proto::OptimizationTarget optimization_target,
OptimizationTargetModelObserver* observer) = 0;
......
......@@ -4,16 +4,36 @@
#include "components/translate/content/browser/translate_model_service.h"
#include "base/bind.h"
#include "base/files/file.h"
#include "base/files/file_path.h"
#include "base/files/file_util.h"
#include "base/metrics/histogram_macros_local.h"
#include "base/task/post_task.h"
#include "components/optimization_guide/optimization_guide_decider.h"
#include "components/optimization_guide/proto/models.pb.h"
#include "content/public/browser/browser_thread.h"
#include "mojo/public/cpp/bindings/pending_remote.h"
namespace {
// Load the model file at the provided file path.
base::File LoadModelFile(const base::FilePath& model_file_path) {
if (!base::PathExists(model_file_path))
return base::File();
return base::File(model_file_path,
base::File::FLAG_OPEN | base::File::FLAG_READ);
}
} // namespace
namespace translate {
TranslateModelService::TranslateModelService(
optimization_guide::OptimizationGuideDecider* opt_guide)
: opt_guide_(opt_guide) {
optimization_guide::OptimizationGuideDecider* opt_guide,
const scoped_refptr<base::SequencedTaskRunner>& background_task_runner)
: opt_guide_(opt_guide), background_task_runner_(background_task_runner) {
opt_guide_->AddObserverForOptimizationTargetModel(
optimization_guide::proto::OPTIMIZATION_TARGET_LANGUAGE_DETECTION, this);
}
......@@ -29,19 +49,45 @@ void TranslateModelService::Shutdown() {
void TranslateModelService::OnModelFileUpdated(
optimization_guide::proto::OptimizationTarget optimization_target,
const base::FilePath& file_path) {
DCHECK_CURRENTLY_ON(content::BrowserThread::UI);
if (optimization_target !=
optimization_guide::proto::OPTIMIZATION_TARGET_LANGUAGE_DETECTION) {
return;
}
// TODO(crbug.com/1151406): Implement loading the model on a background thread
// and return it for use by translate.
background_task_runner_->PostTaskAndReplyWithResult(
FROM_HERE, base::BindOnce(&LoadModelFile, file_path),
base::BindOnce(&TranslateModelService::OnModelFileLoaded,
base::Unretained(this)));
}
base::Optional<base::File>
TranslateModelService::GetLanguageDetectionModelFile() {
// TODO(crbug.com/1151406): Implement loading the model on a background thread
// and return it for use by translate.
return base::nullopt;
void TranslateModelService::OnModelFileLoaded(base::File model_file) {
DCHECK_CURRENTLY_ON(content::BrowserThread::UI);
if (!model_file.IsValid()) {
// TODO(crbug.com/1157661): add histogram to log the model failed to load.
LOCAL_HISTOGRAM_BOOLEAN(
"TranslateModelService.LanguageDetectionModel.WasValid", false);
return;
}
language_detection_model_file_ = std::move(model_file);
LOCAL_HISTOGRAM_BOOLEAN(
"TranslateModelService.LanguageDetectionModel.WasLoaded", true);
for (auto& pending_request : pending_model_requests_) {
std::move(pending_request).Run(language_detection_model_file_->Duplicate());
}
}
void TranslateModelService::GetLanguageDetectionModelFile(
GetModelCallback callback) {
if (!language_detection_model_file_) {
// TODO(crbug.com/1157661): add histogram record the number of callbacks
// held.
pending_model_requests_.emplace_back(std::move(callback));
return;
}
// The model must be valid at this point.
DCHECK(language_detection_model_file_->IsValid());
std::move(callback).Run(language_detection_model_file_->Duplicate());
}
} // namespace translate
......@@ -5,14 +5,16 @@
#ifndef COMPONENTS_TRANSLATE_CONTENT_BROWSER_TRANSLATE_MODEL_SERVICE_H_
#define COMPONENTS_TRANSLATE_CONTENT_BROWSER_TRANSLATE_MODEL_SERVICE_H_
#include <memory>
#include <vector>
#include "base/callback.h"
#include "base/files/file.h"
#include "base/files/file_path.h"
#include "base/optional.h"
#include "components/keyed_service/core/keyed_service.h"
#include "components/optimization_guide/optimization_target_model_observer.h"
namespace base {
class File;
class FilePath;
} // namespace base
namespace optimization_guide {
class OptimizationGuideDecider;
......@@ -27,8 +29,11 @@ class TranslateModelService
: public KeyedService,
public optimization_guide::OptimizationTargetModelObserver {
public:
explicit TranslateModelService(
optimization_guide::OptimizationGuideDecider* opt_guide);
using GetModelCallback = base::OnceCallback<void(base::File)>;
TranslateModelService(
optimization_guide::OptimizationGuideDecider* opt_guide,
const scoped_refptr<base::SequencedTaskRunner>& background_task_runner);
~TranslateModelService() override;
// KeyedService implementation:
......@@ -39,16 +44,28 @@ class TranslateModelService
optimization_guide::proto::OptimizationTarget optimization_target,
const base::FilePath& file_path) override;
// Returns a loaded file containing the TFLite model capable of detecting the
// language of a web page's text.
base::Optional<base::File> GetLanguageDetectionModelFile();
// Invokes |callback| with a language detection model file when it is
// available.
void GetLanguageDetectionModelFile(GetModelCallback callback);
private:
// Optimization Guide Service that provides model files for this
// service. Optimization Guide Service is a
// BrowserContextKeyedServiceFactory and should not
// be used after ShutDown.
void OnModelFileLoaded(base::File model_file);
// Optimization Guide Service that provides model files for this service.
// Optimization Guide Service is a BrowserContextKeyedServiceFactory and
// should not be used after Shutdown.
optimization_guide::OptimizationGuideDecider* opt_guide_;
// The file that contains the language detection model. Available when the
// file path has been provided by the Optimization Guide and has been
// successfully loaded.
base::Optional<base::File> language_detection_model_file_;
// The set of callbacks associated with requests for the language detection
// model.
std::vector<GetModelCallback> pending_model_requests_;
scoped_refptr<base::SequencedTaskRunner> background_task_runner_;
};
} // namespace translate
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment