Commit f51abccb authored by Sophie Chang's avatar Sophie Chang Committed by Commit Bot

Add download service integration for optimization guide models

This doesn't actually do anything with the downloads (i.e. verification, storage, etc.) but just requests for the Download Service to download models based on the URLs that the Opt Guide server sends us

Bug: 1146151
Change-Id: Ifc25e7266bd2a5388036fd2f325d891a66754690
No-Try: true
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2530760
Commit-Queue: Sophie Chang <sophiechang@chromium.org>
Reviewed-by: default avatarNicolas Ouellet-Payeur <nicolaso@chromium.org>
Reviewed-by: default avatarMin Qin <qinmin@chromium.org>
Reviewed-by: default avatarMichael Crouse <mcrouse@chromium.org>
Cr-Commit-Position: refs/heads/master@{#827823}
parent 255f1eef
......@@ -975,6 +975,10 @@ static_library("browser") {
"optimization_guide/optimization_guide_web_contents_observer.h",
"optimization_guide/prediction/prediction_manager.cc",
"optimization_guide/prediction/prediction_manager.h",
"optimization_guide/prediction/prediction_model_download_client.cc",
"optimization_guide/prediction/prediction_model_download_client.h",
"optimization_guide/prediction/prediction_model_download_manager.cc",
"optimization_guide/prediction/prediction_model_download_manager.h",
"optimization_guide/prediction/prediction_model_fetcher.cc",
"optimization_guide/prediction/prediction_model_fetcher.h",
"optimization_guide/prediction/remote_decision_tree_predictor.cc",
......
......@@ -22,6 +22,7 @@
#include "chrome/browser/download/download_task_scheduler_impl.h"
#include "chrome/browser/download/simple_download_manager_coordinator_factory.h"
#include "chrome/browser/net/system_network_context_manager.h"
#include "chrome/browser/optimization_guide/prediction/prediction_model_download_client.h"
#include "chrome/browser/profiles/incognito_helpers.h"
#include "chrome/browser/profiles/profile.h"
#include "chrome/browser/profiles/profile_key.h"
......@@ -38,6 +39,7 @@
#include "components/keyed_service/core/simple_dependency_manager.h"
#include "components/leveldb_proto/public/proto_database_provider.h"
#include "components/offline_pages/buildflags/buildflags.h"
#include "components/optimization_guide/optimization_guide_features.h"
#include "content/public/browser/browser_context.h"
#include "content/public/browser/browser_task_traits.h"
#include "content/public/browser/browser_thread.h"
......@@ -67,6 +69,12 @@ std::unique_ptr<download::Client> CreatePluginVmImageDownloadClient(
}
#endif // BUILDFLAG(IS_CHROMEOS_ASH)
std::unique_ptr<download::Client>
CreateOptimizationGuidePredictionModelDownloadClient(Profile* profile) {
return std::make_unique<optimization_guide::PredictionModelDownloadClient>(
profile);
}
// Called on profile created to retrieve the BlobStorageContextGetter.
void DownloadOnProfileCreated(download::BlobContextGetterCallback callback,
Profile* profile) {
......@@ -148,6 +156,16 @@ std::unique_ptr<KeyedService> DownloadServiceFactory::BuildServiceInstanceFor(
}
#endif // BUILDFLAG(IS_CHROMEOS_ASH)
if (optimization_guide::features::IsModelDownloadingEnabled() &&
!key->IsOffTheRecord()) {
clients->insert(std::make_pair(
download::DownloadClient::OPTIMIZATION_GUIDE_PREDICTION_MODELS,
std::make_unique<download::DeferredClientWrapper>(
base::BindOnce(
&CreateOptimizationGuidePredictionModelDownloadClient),
key)));
}
// Build in memory download service for incognito profile.
if (key->IsOffTheRecord() &&
base::FeatureList::IsEnabled(download::kDownloadServiceIncognito)) {
......
......@@ -37,6 +37,7 @@ class OptimizationGuideService;
class TopHostProvider;
class PredictionManager;
class PredictionManagerBrowserTestBase;
class PredictionModelDownloadClient;
} // namespace optimization_guide
class GURL;
......@@ -95,6 +96,7 @@ class OptimizationGuideKeyedService
friend class OptimizationGuideKeyedServiceBrowserTest;
friend class OptimizationGuideWebContentsObserver;
friend class ProfileManager;
friend class optimization_guide::PredictionModelDownloadClient;
friend class optimization_guide::PredictionManagerBrowserTestBase;
friend class optimization_guide::android::OptimizationGuideBridge;
......
......@@ -24,6 +24,7 @@
#include "chrome/browser/optimization_guide/optimization_guide_navigation_data.h"
#include "chrome/browser/optimization_guide/optimization_guide_permissions_util.h"
#include "chrome/browser/optimization_guide/optimization_guide_session_statistic.h"
#include "chrome/browser/optimization_guide/prediction/prediction_model_download_manager.h"
#include "chrome/browser/optimization_guide/prediction/prediction_model_fetcher.h"
#include "chrome/browser/optimization_guide/prediction/remote_decision_tree_predictor.h"
#include "chrome/browser/profiles/profile.h"
......@@ -230,7 +231,10 @@ PredictionManager::PredictionManager(
Profile* profile)
: host_model_features_cache_(
std::max(features::MaxHostModelFeaturesCacheSize(), size_t(1))),
session_fcp_(),
prediction_model_download_manager_(
features::IsModelDownloadingEnabled()
? std::make_unique<PredictionModelDownloadManager>(profile)
: nullptr),
top_host_provider_(top_host_provider),
model_and_features_store_(std::move(model_and_features_store)),
url_loader_factory_(url_loader_factory),
......@@ -693,6 +697,13 @@ void PredictionManager::SetPredictionModelFetcherForTesting(
prediction_model_fetcher_ = std::move(prediction_model_fetcher);
}
void PredictionManager::SetPredictionModelDownloadManagerForTesting(
std::unique_ptr<PredictionModelDownloadManager>
prediction_model_download_manager) {
prediction_model_download_manager_ =
std::move(prediction_model_download_manager);
}
void PredictionManager::FetchModelsAndHostModelFeatures() {
SEQUENCE_CHECKER(sequence_checker_);
if (!IsUserPermittedToFetchFromRemoteOptimizationGuide(profile_))
......@@ -700,11 +711,23 @@ void PredictionManager::FetchModelsAndHostModelFeatures() {
ScheduleModelsAndHostModelFeaturesFetch();
// We cannot download any models from the server, so don't refresh them.
if (prediction_model_download_manager_ &&
!prediction_model_download_manager_->IsAvailableForDownloads()) {
// TODO(crbug/1146151): Add histogram for how often this happens.
return;
}
// Models and host model features should not be fetched if there are no
// optimization targets registered.
if (registered_optimization_targets_.size() == 0)
return;
// Cancel all pending downloads since the server will probably give us new
// ones to fetch.
if (prediction_model_download_manager_)
prediction_model_download_manager_->CancelAllPendingDownloads();
std::vector<std::string> top_hosts;
// If the top host provider is not available, the user has likely not seen the
// Lite mode infobar, so top hosts cannot be provided. However, prediction
......@@ -827,10 +850,17 @@ void PredictionManager::UpdatePredictionModels(
bool has_models_to_update = false;
for (const auto& model : prediction_models) {
if (model.has_model() && !model.model().download_url().empty()) {
// Skip over models that have a download URL since they will be updated
// out-of-band.
if (prediction_model_download_manager_) {
GURL download_url(model.model().download_url());
if (download_url.is_valid()) {
prediction_model_download_manager_->StartDownload(download_url);
} else {
// TODO(crbug/1146151): Add histogram for invalid download URL.
}
}
// TODO(crbug/1146151): Download model from URL.
// Skip over models that have a download URL since they will be updated
// once the download has completed successfully.
continue;
}
......
......@@ -50,6 +50,7 @@ namespace optimization_guide {
enum class OptimizationGuideDecision;
class OptimizationGuideStore;
class PredictionModel;
class PredictionModelDownloadManager;
class PredictionModelFetcher;
class TopHostProvider;
class RemoteDecisionTreePredictor;
......@@ -152,6 +153,15 @@ class PredictionManager
return prediction_model_fetcher_.get();
}
// Set the prediction model download manager for testing.
void SetPredictionModelDownloadManagerForTesting(
std::unique_ptr<PredictionModelDownloadManager>
prediction_model_download_manager);
PredictionModelDownloadManager* prediction_model_download_manager() const {
return prediction_model_download_manager_.get();
}
OptimizationGuideStore* model_and_features_store() const {
return model_and_features_store_.get();
}
......@@ -384,10 +394,15 @@ class PredictionManager
// load of a session).
base::Optional<float> previous_load_fcp_ms_;
// The fetcher than handles making requests to update the models and host
// The fetcher that handles making requests to update the models and host
// model features from the remote Optimization Guide Service.
std::unique_ptr<PredictionModelFetcher> prediction_model_fetcher_;
// The downloader that handles making requests to download the prediction
// models. Can be null if model downloading is disabled.
std::unique_ptr<PredictionModelDownloadManager>
prediction_model_download_manager_;
// The top host provider that can be queried. Not owned.
TopHostProvider* top_host_provider_ = nullptr;
......
......@@ -16,6 +16,7 @@
#include "build/build_config.h"
#include "chrome/browser/optimization_guide/optimization_guide_navigation_data.h"
#include "chrome/browser/optimization_guide/optimization_guide_web_contents_observer.h"
#include "chrome/browser/optimization_guide/prediction/prediction_model_download_manager.h"
#include "chrome/browser/optimization_guide/prediction/prediction_model_fetcher.h"
#include "chrome/browser/optimization_guide/prediction/remote_decision_tree_predictor.h"
#include "chrome/services/machine_learning/public/cpp/test_support/fake_service_connection.h"
......@@ -73,7 +74,8 @@ std::unique_ptr<proto::PredictionModel> CreatePredictionModel(
model_info->add_supported_model_types(
proto::ModelType::MODEL_TYPE_DECISION_TREE);
if (output_model_as_download_url)
prediction_model->mutable_model()->set_download_url("someurl");
prediction_model->mutable_model()->set_download_url(
"https://example.com/model");
else
prediction_model->mutable_model()->mutable_threshold()->set_value(5.0);
return prediction_model;
......@@ -166,6 +168,33 @@ class FakeTopHostProvider : public TopHostProvider {
int num_top_hosts_called_ = 0;
};
class FakePredictionModelDownloadManager
: public PredictionModelDownloadManager {
public:
explicit FakePredictionModelDownloadManager(Profile* profile)
: PredictionModelDownloadManager(profile) {}
~FakePredictionModelDownloadManager() override = default;
void StartDownload(const GURL& url) override {
last_requested_download_ = url;
}
GURL last_requested_download() const { return last_requested_download_; }
void CancelAllPendingDownloads() override { cancel_downloads_called_ = true; }
bool cancel_downloads_called() const { return cancel_downloads_called_; }
bool IsAvailableForDownloads() const override { return is_available_; }
void SetAvailableForDownloads(bool is_available) {
is_available_ = is_available;
}
private:
GURL last_requested_download_;
bool cancel_downloads_called_ = false;
bool is_available_ = true;
};
enum class PredictionModelFetcherEndState {
kFetchFailed = 0,
kFetchSuccessWithModelsAndHostsModelFeatures = 1,
......@@ -519,6 +548,12 @@ class PredictionManagerTest
prediction_manager()->prediction_model_fetcher());
}
FakePredictionModelDownloadManager* prediction_model_download_manager()
const {
return static_cast<FakePredictionModelDownloadManager*>(
prediction_manager()->prediction_model_download_manager());
}
TestOptimizationGuideStore* models_and_features_store() const {
return static_cast<TestOptimizationGuideStore*>(
prediction_manager()->model_and_features_store());
......@@ -528,6 +563,8 @@ class PredictionManagerTest
TestingPrefServiceSimple* pref_service() const { return pref_service_.get(); }
TestingProfile* profile() { return &testing_profile_; }
void RunUntilIdle() {
task_environment_.RunUntilIdle();
base::RunLoop().RunUntilIdle();
......@@ -1142,6 +1179,38 @@ TEST_P(PredictionManagerMLServiceTest,
}
}
TEST_P(PredictionManagerMLServiceTest,
DownloadManagerUnavailableShouldNotFetch) {
base::test::ScopedFeatureList scoped_feature_list;
if (UsingMLService()) {
scoped_feature_list.InitWithFeatures(
{optimization_guide::features::
kOptimizationTargetPredictionUsingMLService},
{});
SetLoadModelResult(machine_learning::mojom::LoadModelResult::kOk);
}
base::HistogramTester histogram_tester;
std::unique_ptr<content::MockNavigationHandle> navigation_handle =
CreateMockNavigationHandleWithOptimizationGuideWebContentsObserver(
GURL("https://foo.com"));
CreatePredictionManager({});
prediction_manager()->SetPredictionModelFetcherForTesting(
BuildTestPredictionModelFetcher(
PredictionModelFetcherEndState::kFetchSuccessWithModelDownloadUrls));
prediction_manager()->SetPredictionModelDownloadManagerForTesting(
std::make_unique<FakePredictionModelDownloadManager>(profile()));
prediction_model_download_manager()->SetAvailableForDownloads(false);
prediction_manager()->RegisterOptimizationTargets(
{proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD});
SetStoreInitialized();
EXPECT_FALSE(prediction_model_fetcher()->models_fetched());
}
TEST_P(PredictionManagerMLServiceTest, UpdateModelWithDownloadUrl) {
base::test::ScopedFeatureList scoped_feature_list;
if (UsingMLService()) {
......@@ -1162,12 +1231,15 @@ TEST_P(PredictionManagerMLServiceTest, UpdateModelWithDownloadUrl) {
prediction_manager()->SetPredictionModelFetcherForTesting(
BuildTestPredictionModelFetcher(
PredictionModelFetcherEndState::kFetchSuccessWithModelDownloadUrls));
prediction_manager()->SetPredictionModelDownloadManagerForTesting(
std::make_unique<FakePredictionModelDownloadManager>(profile()));
prediction_manager()->RegisterOptimizationTargets(
{proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD});
SetStoreInitialized();
EXPECT_TRUE(prediction_model_fetcher()->models_fetched());
EXPECT_TRUE(prediction_model_download_manager()->cancel_downloads_called());
models_and_features_store()->RunUpdateHostModelFeaturesCallback();
histogram_tester.ExpectUniqueSample(
......@@ -1175,7 +1247,8 @@ TEST_P(PredictionManagerMLServiceTest, UpdateModelWithDownloadUrl) {
histogram_tester.ExpectTotalCount(
"OptimizationGuide.PredictionManager.PredictionModelsStored", 0);
// TODO(crbug/1146151): Update test to incorporate downloading of model.
EXPECT_EQ(prediction_model_download_manager()->last_requested_download(),
GURL("https://example.com/model"));
}
TEST_P(PredictionManagerMLServiceTest,
......
// Copyright 2020 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/browser/optimization_guide/prediction/prediction_model_download_client.h"
#include "base/bind.h"
#include "base/threading/sequenced_task_runner_handle.h"
#include "chrome/browser/optimization_guide/optimization_guide_keyed_service.h"
#include "chrome/browser/optimization_guide/optimization_guide_keyed_service_factory.h"
#include "chrome/browser/optimization_guide/prediction/prediction_manager.h"
#include "chrome/browser/optimization_guide/prediction/prediction_model_download_manager.h"
#include "components/download/public/background_service/download_metadata.h"
namespace optimization_guide {
PredictionModelDownloadClient::PredictionModelDownloadClient(Profile* profile)
: profile_(profile) {}
PredictionModelDownloadClient::~PredictionModelDownloadClient() = default;
PredictionModelDownloadManager*
PredictionModelDownloadClient::GetPredictionModelDownloadManager() {
OptimizationGuideKeyedService* optimization_guide_keyed_service =
OptimizationGuideKeyedServiceFactory::GetForProfile(profile_);
if (!optimization_guide_keyed_service)
return nullptr;
PredictionManager* prediction_manager =
optimization_guide_keyed_service->GetPredictionManager();
if (!prediction_manager)
return nullptr;
return prediction_manager->prediction_model_download_manager();
}
void PredictionModelDownloadClient::OnServiceInitialized(
bool state_lost,
const std::vector<download::DownloadMetaData>& downloads) {
PredictionModelDownloadManager* download_manager =
GetPredictionModelDownloadManager();
if (!download_manager)
return;
std::set<std::string> outstanding_download_guids;
std::map<std::string, base::FilePath> successful_downloads;
for (const auto& download : downloads) {
if (!download.completion_info) {
outstanding_download_guids.emplace(download.guid);
continue;
}
successful_downloads.emplace(download.guid, download.completion_info->path);
}
download_manager->OnDownloadServiceReady(outstanding_download_guids,
successful_downloads);
}
void PredictionModelDownloadClient::OnServiceUnavailable() {
PredictionModelDownloadManager* download_manager =
GetPredictionModelDownloadManager();
if (download_manager)
download_manager->OnDownloadServiceUnavailable();
}
void PredictionModelDownloadClient::OnDownloadFailed(
const std::string& guid,
const download::CompletionInfo& completion_info,
download::Client::FailureReason reason) {
PredictionModelDownloadManager* download_manager =
GetPredictionModelDownloadManager();
if (download_manager)
download_manager->OnDownloadFailed(guid);
}
void PredictionModelDownloadClient::OnDownloadSucceeded(
const std::string& guid,
const download::CompletionInfo& completion_info) {
PredictionModelDownloadManager* download_manager =
GetPredictionModelDownloadManager();
if (download_manager)
download_manager->OnDownloadSucceeded(guid, completion_info.path);
}
bool PredictionModelDownloadClient::CanServiceRemoveDownloadedFile(
const std::string& guid,
bool force_delete) {
// Always return true. We immediately postprocess successful downloads and the
// file downloaded by the Download Service should already be deleted and this
// hypothetically should never be called with anything that matters.
return true;
}
void PredictionModelDownloadClient::GetUploadData(
const std::string& guid,
download::GetUploadDataCallback callback) {
base::SequencedTaskRunnerHandle::Get()->PostTask(
FROM_HERE, base::BindOnce(std::move(callback), nullptr));
}
} // namespace optimization_guide
// Copyright 2020 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef CHROME_BROWSER_OPTIMIZATION_GUIDE_PREDICTION_PREDICTION_MODEL_DOWNLOAD_CLIENT_H_
#define CHROME_BROWSER_OPTIMIZATION_GUIDE_PREDICTION_PREDICTION_MODEL_DOWNLOAD_CLIENT_H_
#include "components/download/public/background_service/client.h"
class Profile;
namespace download {
struct CompletionInfo;
struct DownloadMetaData;
} // namespace download
namespace optimization_guide {
class PredictionModelDownloadManager;
class PredictionModelDownloadClient : public download::Client {
public:
explicit PredictionModelDownloadClient(Profile* profile);
~PredictionModelDownloadClient() override;
PredictionModelDownloadClient(const PredictionModelDownloadClient&) = delete;
PredictionModelDownloadClient& operator=(
const PredictionModelDownloadClient&) = delete;
// download::Client:
void OnServiceInitialized(
bool state_lost,
const std::vector<download::DownloadMetaData>& downloads) override;
void OnServiceUnavailable() override;
void OnDownloadFailed(const std::string& guid,
const download::CompletionInfo& completion_info,
download::Client::FailureReason reason) override;
void OnDownloadSucceeded(
const std::string& guid,
const download::CompletionInfo& completion_info) override;
bool CanServiceRemoveDownloadedFile(const std::string& guid,
bool force_delete) override;
void GetUploadData(const std::string& guid,
download::GetUploadDataCallback callback) override;
private:
// Returns the PredictionModelDownloadManager for the profile.
PredictionModelDownloadManager* GetPredictionModelDownloadManager();
Profile* profile_;
};
} // namespace optimization_guide
#endif // CHROME_BROWSER_OPTIMIZATION_GUIDE_PREDICTION_PREDICTION_MODEL_DOWNLOAD_CLIENT_H_
// Copyright 2020 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/browser/optimization_guide/prediction/prediction_model_download_manager.h"
#include "base/bind.h"
#include "base/guid.h"
#include "chrome/browser/download/download_service_factory.h"
#include "chrome/browser/profiles/profile.h"
#include "chrome/browser/profiles/profile_key.h"
#include "components/download/public/background_service/download_service.h"
#include "components/optimization_guide/optimization_guide_features.h"
#include "net/traffic_annotation/network_traffic_annotation.h"
namespace optimization_guide {
namespace {
// Header for API key.
constexpr char kGoogApiKey[] = "X-Goog-Api-Key";
const net::NetworkTrafficAnnotationTag
kOptimizationGuidePredictionModelsTrafficAnnotation =
net::DefineNetworkTrafficAnnotation("optimization_guide_model_download",
R"(
semantics {
sender: "Optimization Guide"
description:
"Chromium interacts with Optimization Guide Service to download "
"non-personalized models used to improve browser behavior around "
"page load performance and features such as Translate."
trigger:
"When there are new models to download based on response from "
"Optimization Guide Service that is triggered daily."
data: "The URL provided by the Optimization Guide Service to fetch "
"an updated model. No user information is sent."
destination: GOOGLE_OWNED_SERVICE
}
policy {
cookies_allowed: NO
setting:
"This request cannot be disabled in settings. However it will "
"never be made if the "
"'OptimizationGuideModelDownloading' feature is disabled."
policy_exception_justification: "Not yet implemented."
})");
} // namespace
PredictionModelDownloadManager::PredictionModelDownloadManager(Profile* profile)
: download_service_(
DownloadServiceFactory::GetForKey(profile->GetProfileKey())),
is_available_for_downloads_(true),
api_key_(features::GetOptimizationGuideServiceAPIKey()) {}
PredictionModelDownloadManager::~PredictionModelDownloadManager() = default;
void PredictionModelDownloadManager::StartDownload(const GURL& download_url) {
download::DownloadParams download_params;
download_params.client =
download::DownloadClient::OPTIMIZATION_GUIDE_PREDICTION_MODELS;
download_params.guid = base::GenerateGUID();
download_params.callback =
base::BindRepeating(&PredictionModelDownloadManager::OnDownloadStarted,
weak_ptr_factory_.GetWeakPtr());
download_params.traffic_annotation = net::MutableNetworkTrafficAnnotationTag(
kOptimizationGuidePredictionModelsTrafficAnnotation);
download_params.request_params.url = download_url;
download_params.request_params.method = "GET";
download_params.request_params.request_headers.SetHeader(kGoogApiKey,
api_key_);
// TODO(crbug/1146151): Add feature params to control the scheduling params.
download_params.scheduling_params.priority =
download::SchedulingParams::Priority::NORMAL;
download_params.scheduling_params.battery_requirements =
download::SchedulingParams::BatteryRequirements::BATTERY_INSENSITIVE;
download_params.scheduling_params.network_requirements =
download::SchedulingParams::NetworkRequirements::OPTIMISTIC;
download_service_->StartDownload(download_params);
}
void PredictionModelDownloadManager::CancelAllPendingDownloads() {
for (const std::string& pending_download_guid : pending_download_guids_)
download_service_->CancelDownload(pending_download_guid);
}
bool PredictionModelDownloadManager::IsAvailableForDownloads() const {
return is_available_for_downloads_;
}
void PredictionModelDownloadManager::OnDownloadServiceReady(
const std::set<std::string>& pending_download_guids,
const std::map<std::string, base::FilePath>& successful_downloads) {
for (const std::string& pending_download_guid : pending_download_guids)
pending_download_guids_.insert(pending_download_guid);
for (const auto& successful_download : successful_downloads)
OnDownloadSucceeded(successful_download.first, successful_download.second);
}
void PredictionModelDownloadManager::OnDownloadServiceUnavailable() {
is_available_for_downloads_ = false;
// TODO(crbug/1146151): Log histogram.
}
void PredictionModelDownloadManager::OnDownloadStarted(
const std::string& guid,
download::DownloadParams::StartResult start_result) {
if (start_result == download::DownloadParams::StartResult::ACCEPTED)
pending_download_guids_.insert(guid);
}
void PredictionModelDownloadManager::OnDownloadSucceeded(
const std::string& guid,
const base::FilePath& file_path) {
pending_download_guids_.erase(guid);
// TODO(crbug/1146151): Verify download.
}
void PredictionModelDownloadManager::OnDownloadFailed(const std::string& guid) {
pending_download_guids_.erase(guid);
// TODO(crbug/1146151): Log histogram.
}
} // namespace optimization_guide
// Copyright 2020 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef CHROME_BROWSER_OPTIMIZATION_GUIDE_PREDICTION_PREDICTION_MODEL_DOWNLOAD_MANAGER_H_
#define CHROME_BROWSER_OPTIMIZATION_GUIDE_PREDICTION_PREDICTION_MODEL_DOWNLOAD_MANAGER_H_
#include <map>
#include <set>
#include <string>
#include "base/memory/weak_ptr.h"
#include "components/download/public/background_service/download_params.h"
class Profile;
namespace download {
class DownloadService;
} // namespace download
namespace optimization_guide {
class PredictionModelDownloadClient;
// Manages the downloads of prediction models.
class PredictionModelDownloadManager {
public:
explicit PredictionModelDownloadManager(Profile* profile);
virtual ~PredictionModelDownloadManager();
PredictionModelDownloadManager(const PredictionModelDownloadManager&) =
delete;
PredictionModelDownloadManager& operator=(
const PredictionModelDownloadManager&) = delete;
// Starts a download for |download_url|.
virtual void StartDownload(const GURL& download_url);
// Cancels all pending downloads.
virtual void CancelAllPendingDownloads();
// Returns whether the downloader can download models.
virtual bool IsAvailableForDownloads() const;
private:
friend class PredictionModelDownloadClient;
friend class PredictionModelDownloadManagerTest;
// Invoked when the Download Service is ready.
//
// |pending_download_guids| is the set of GUIDs that were previously scheduled
// to be downloaded and have still not been downloaded yet.
// |successful_downloads| is the map from GUID to the file path that it was
// successfully downloaded to.
void OnDownloadServiceReady(
const std::set<std::string>& pending_download_guids,
const std::map<std::string, base::FilePath>& successful_downloads);
// Invoked when the Download Service fails to initialize and should not be
// used for the session.
void OnDownloadServiceUnavailable();
// Invoked when the download has been accepted and persisted by the
// DownloadService.
void OnDownloadStarted(const std::string& guid,
download::DownloadParams::StartResult start_result);
// Invoked when the download as specified by |downloaded_guid| succeeded.
void OnDownloadSucceeded(const std::string& downloaded_guid,
const base::FilePath& file_path);
// Invoked when the download as specified by |failed_download_guid| failed.
void OnDownloadFailed(const std::string& failed_download_guid);
// The set of GUIDs that are still pending download.
std::set<std::string> pending_download_guids_;
// The Download Service to schedule model downloads with.
//
// Guaranteed to outlive |this|.
download::DownloadService* download_service_;
// Whether the download service is available.
bool is_available_for_downloads_;
// The API key to attach to download requests.
std::string api_key_;
base::WeakPtrFactory<PredictionModelDownloadManager> weak_ptr_factory_{this};
};
} // namespace optimization_guide
#endif // CHROME_BROWSER_OPTIMIZATION_GUIDE_PREDICTION_PREDICTION_MODEL_DOWNLOAD_MANAGER_H_
// Copyright 2020 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/browser/optimization_guide/prediction/prediction_model_download_manager.h"
#include "base/files/scoped_temp_dir.h"
#include "chrome/browser/download/download_service_factory.h"
#include "chrome/browser/profiles/profile_key.h"
#include "chrome/test/base/chrome_render_view_host_test_harness.h"
#include "components/download/public/background_service/test/mock_download_service.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace optimization_guide {
using ::testing::_;
using ::testing::Eq;
using ::testing::SaveArg;
class PredictionModelDownloadManagerTest
: public ChromeRenderViewHostTestHarness {
public:
PredictionModelDownloadManagerTest() = default;
~PredictionModelDownloadManagerTest() override = default;
void SetUp() override {
ChromeRenderViewHostTestHarness::SetUp();
ASSERT_TRUE(temp_dir_.CreateUniqueTempDir());
mock_download_service_ = static_cast<download::test::MockDownloadService*>(
DownloadServiceFactory::GetInstance()->SetTestingFactoryAndUse(
profile()->GetProfileKey(),
base::BindRepeating([](SimpleFactoryKey* key)
-> std::unique_ptr<KeyedService> {
return std::make_unique<download::test::MockDownloadService>();
})));
download_manager_ =
std::make_unique<PredictionModelDownloadManager>(profile());
}
void TearDown() override {
download_manager_.reset();
mock_download_service_ = nullptr;
ChromeRenderViewHostTestHarness::TearDown();
}
PredictionModelDownloadManager* download_manager() {
return download_manager_.get();
}
download::test::MockDownloadService* download_service() {
return mock_download_service_;
}
protected:
void SetDownloadServiceReady(const std::set<std::string>& pending_guids,
const std::set<std::string>& successful_guids) {
std::map<std::string, base::FilePath> success_map;
for (const auto& guid : successful_guids) {
success_map.emplace(guid, temp_dir_.GetPath());
}
download_manager()->OnDownloadServiceReady(pending_guids, success_map);
}
void SetDownloadServiceUnavailable() {
download_manager()->OnDownloadServiceUnavailable();
}
void SetDownloadSucceeded(const std::string& guid) {
download_manager()->OnDownloadSucceeded(guid, temp_dir_.GetPath());
}
void SetDownloadFailed(const std::string& guid) {
download_manager()->OnDownloadFailed(guid);
}
private:
base::ScopedTempDir temp_dir_;
download::test::MockDownloadService* mock_download_service_;
std::unique_ptr<PredictionModelDownloadManager> download_manager_;
};
TEST_F(PredictionModelDownloadManagerTest, DownloadServiceReadyPersistsGuids) {
SetDownloadServiceReady({"pending1", "pending2", "pending3"},
{"success1", "success2", "success3"});
// Should only persist and thus cancel the pending ones.
EXPECT_CALL(*download_service(), CancelDownload(Eq("pending1")));
EXPECT_CALL(*download_service(), CancelDownload(Eq("pending2")));
EXPECT_CALL(*download_service(), CancelDownload(Eq("pending3")));
download_manager()->CancelAllPendingDownloads();
}
TEST_F(PredictionModelDownloadManagerTest, StartDownload) {
download::DownloadParams download_params;
EXPECT_CALL(*download_service(), StartDownload(_))
.WillOnce(SaveArg<0>(&download_params));
download_manager()->StartDownload(GURL("someurl"));
// Validate parameters - basically that we attach the correct client, just do
// a passthrough of the URL, and attach the API key.
EXPECT_EQ(download_params.client,
download::DownloadClient::OPTIMIZATION_GUIDE_PREDICTION_MODELS);
EXPECT_EQ(download_params.request_params.url, GURL("someurl"));
EXPECT_EQ(download_params.request_params.method, "GET");
EXPECT_TRUE(download_params.request_params.request_headers.HasHeader(
"X-Goog-Api-Key"));
// Now invoke start callback.
std::move(download_params.callback)
.Run("someguid", download::DownloadParams::StartResult::ACCEPTED);
// Now cancel all downloads to ensure that callback persisted pending GUID.
EXPECT_CALL(*download_service(), CancelDownload(Eq("someguid")));
download_manager()->CancelAllPendingDownloads();
}
TEST_F(PredictionModelDownloadManagerTest, StartDownloadFailedToSchedule) {
download::DownloadParams download_params;
EXPECT_CALL(*download_service(), StartDownload(_))
.WillOnce(SaveArg<0>(&download_params));
download_manager()->StartDownload(GURL("someurl"));
// Now invoke start callback.
std::move(download_params.callback)
.Run("someguid", download::DownloadParams::StartResult::INTERNAL_ERROR);
// Now cancel all downloads to ensure that bad GUID was not accepted.
EXPECT_CALL(*download_service(), CancelDownload(_)).Times(0);
download_manager()->CancelAllPendingDownloads();
}
TEST_F(PredictionModelDownloadManagerTest, IsAvailableForDownloads) {
EXPECT_TRUE(download_manager()->IsAvailableForDownloads());
SetDownloadServiceUnavailable();
EXPECT_FALSE(download_manager()->IsAvailableForDownloads());
}
TEST_F(PredictionModelDownloadManagerTest,
SuccessfulDownloadShouldNoLongerBeTracked) {
SetDownloadServiceReady({"pending1", "pending2", "pending3"},
/*successful_guids=*/{});
SetDownloadSucceeded("pending1");
// Should only persist and thus cancel the pending ones.
EXPECT_CALL(*download_service(), CancelDownload(Eq("pending2")));
EXPECT_CALL(*download_service(), CancelDownload(Eq("pending3")));
download_manager()->CancelAllPendingDownloads();
}
TEST_F(PredictionModelDownloadManagerTest,
FailedDownloadShouldNoLongerBeTracked) {
SetDownloadServiceReady({"pending1", "pending2", "pending3"},
/*successful_guids=*/{});
SetDownloadSucceeded("pending2");
// Should only persist and thus cancel the pending ones.
EXPECT_CALL(*download_service(), CancelDownload(Eq("pending1")));
EXPECT_CALL(*download_service(), CancelDownload(Eq("pending3")));
download_manager()->CancelAllPendingDownloads();
}
} // namespace optimization_guide
......@@ -3514,6 +3514,7 @@ test("unit_tests") {
"../browser/optimization_guide/optimization_guide_session_statistic_unittest.cc",
"../browser/optimization_guide/optimization_guide_top_host_provider_unittest.cc",
"../browser/optimization_guide/prediction/prediction_manager_unittest.cc",
"../browser/optimization_guide/prediction/prediction_model_download_manager_unittest.cc",
"../browser/optimization_guide/prediction/prediction_model_fetcher_unittest.cc",
"../browser/optimization_guide/prediction/remote_decision_tree_predictor_unittest.cc",
"../browser/page_load_metrics/metrics_web_contents_observer_unittest.cc",
......@@ -4140,6 +4141,7 @@ test("unit_tests") {
"//components/content_settings/core/test:test_support",
"//components/data_reduction_proxy/core/browser:test_support",
"//components/data_use_measurement/core",
"//components/download/public/background_service/test:test_support",
"//components/favicon/core/test:test_support",
"//components/flags_ui:test_support",
"//components/mirroring:mirroring_tests",
......
......@@ -33,6 +33,7 @@ class MockDownloadService : public DownloadService {
MOCK_METHOD1(CancelDownload, void(const std::string& guid));
MOCK_METHOD2(ChangeDownloadCriteria,
void(const std::string& guid, const SchedulingParams& params));
MOCK_METHOD0(GetLogger, Logger*());
private:
DISALLOW_COPY_AND_ASSIGN(MockDownloadService);
......
......@@ -59,6 +59,10 @@ const base::Feature kOptimizationTargetPredictionUsingMLService{
"OptimizationGuidePredictionUsingMLService",
base::FEATURE_DISABLED_BY_DEFAULT};
// Enables the downloading of models.
const base::Feature kOptimizationGuideModelDownloading{
"OptimizationGuideModelDownloading", base::FEATURE_DISABLED_BY_DEFAULT};
size_t MaxHintsFetcherTopHostBlacklistSize() {
// The blacklist will be limited to the most engaged hosts and will hold twice
// (2*N) as many hosts that the HintsFetcher request hints for. The extra N
......@@ -307,5 +311,9 @@ bool ShouldUseMLServiceForPrediction() {
kOptimizationTargetPredictionUsingMLService);
}
bool IsModelDownloadingEnabled() {
return base::FeatureList::IsEnabled(kOptimizationGuideModelDownloading);
}
} // namespace features
} // namespace optimization_guide
......@@ -25,6 +25,7 @@ extern const base::Feature kRemoteOptimizationGuideFetching;
extern const base::Feature kRemoteOptimizationGuideFetchingAnonymousDataConsent;
extern const base::Feature kOptimizationTargetPrediction;
extern const base::Feature kOptimizationTargetPredictionUsingMLService;
extern const base::Feature kOptimizationGuideModelDownloading;
// The maximum number of hosts that can be stored in the
// |kHintsFetcherTopHostBlacklist| dictionary pref when initialized. The top
......@@ -158,6 +159,9 @@ base::flat_set<uint32_t> FieldTrialNameHashesAllowedForFetch();
// Whether out-of-process model evaluation via the ML Service is enabled.
bool ShouldUseMLServiceForPrediction();
// Whether the ability to download models is enabled.
bool IsModelDownloadingEnabled();
} // namespace features
} // namespace optimization_guide
......
......@@ -216,6 +216,7 @@ Refer to README.md for content description and update process.
<item id="openscreen_message" added_in_milestone="83" hash_code="23036184" type="0" content_hash_code="124395439" os_list="linux,windows" file_path="components/openscreen_platform/udp_socket.cc"/>
<item id="openscreen_tls_message" added_in_milestone="83" hash_code="40127335" type="0" content_hash_code="15991338" os_list="linux,windows" file_path="components/openscreen_platform/tls_connection_factory.cc"/>
<item id="optimization_guide_model" added_in_milestone="79" hash_code="106373593" type="0" content_hash_code="32403047" os_list="linux,windows" file_path="chrome/browser/optimization_guide/prediction/prediction_model_fetcher.cc"/>
<item id="optimization_guide_model_download" added_in_milestone="88" hash_code="100143055" type="0" content_hash_code="97983899" os_list="linux,windows" file_path="chrome/browser/optimization_guide/prediction/prediction_model_download_manager.cc"/>
<item id="origin_policy_loader" added_in_milestone="69" hash_code="6483617" type="0" content_hash_code="134028975" os_list="linux,windows" file_path="services/network/origin_policy/origin_policy_fetcher.cc"/>
<item id="parallel_download_job" added_in_milestone="62" hash_code="135118587" type="0" content_hash_code="105330419" os_list="linux,windows" file_path="components/download/internal/common/parallel_download_job.cc"/>
<item id="password_protection_request" added_in_milestone="62" hash_code="66322287" type="0" content_hash_code="25596947" os_list="linux,windows" file_path="components/safe_browsing/content/password_protection/password_protection_request.cc"/>
......
......@@ -295,6 +295,7 @@ hidden="true" so that these annotations don't show up in the document.
<traffic_annotation unique_id="data_reduction_proxy_warmup"/>
<traffic_annotation unique_id="hintsfetcher_gethintsrequest"/>
<traffic_annotation unique_id="optimization_guide_model"/>
<traffic_annotation unique_id="optimization_guide_model_download"/>
<traffic_annotation unique_id="previews_litepage_prober"/>
</sender>
<sender name="Network">
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment