Commit dafbfe24 authored by Sophie Chang's avatar Sophie Chang Committed by Commit Bot

Rip out ML Service support from Optimization Guide

This makes it easier to add support for the model file path that we
are supporting in the near future for TFLite models

Bug: 1146151
Change-Id: Ib63e2c68f3e5a227cca8db9d73db9765de41bb16
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2545872Reviewed-by: default avatarMichael Crouse <mcrouse@chromium.org>
Commit-Queue: Sophie Chang <sophiechang@chromium.org>
Cr-Commit-Position: refs/heads/master@{#828536}
parent 0120e99b
...@@ -982,8 +982,6 @@ static_library("browser") { ...@@ -982,8 +982,6 @@ static_library("browser") {
"optimization_guide/prediction/prediction_model_download_observer.h", "optimization_guide/prediction/prediction_model_download_observer.h",
"optimization_guide/prediction/prediction_model_fetcher.cc", "optimization_guide/prediction/prediction_model_fetcher.cc",
"optimization_guide/prediction/prediction_model_fetcher.h", "optimization_guide/prediction/prediction_model_fetcher.h",
"optimization_guide/prediction/remote_decision_tree_predictor.cc",
"optimization_guide/prediction/remote_decision_tree_predictor.h",
"page_load_metrics/observers/aborts_page_load_metrics_observer.cc", "page_load_metrics/observers/aborts_page_load_metrics_observer.cc",
"page_load_metrics/observers/aborts_page_load_metrics_observer.h", "page_load_metrics/observers/aborts_page_load_metrics_observer.h",
"page_load_metrics/observers/ad_metrics/ads_page_load_metrics_observer.cc", "page_load_metrics/observers/ad_metrics/ads_page_load_metrics_observer.cc",
......
...@@ -203,11 +203,11 @@ void OptimizationGuideKeyedService::ShouldTargetNavigationAsync( ...@@ -203,11 +203,11 @@ void OptimizationGuideKeyedService::ShouldTargetNavigationAsync(
return; return;
} }
prediction_manager_->ShouldTargetNavigationAsync( optimization_guide::OptimizationTargetDecision target_decision =
navigation_handle, optimization_target, client_model_feature_values, prediction_manager_->ShouldTargetNavigation(
base::BindOnce( navigation_handle, optimization_target, client_model_feature_values);
&LogOptimizationTargetDecisionAndPassOptimizationGuideDecision, LogOptimizationTargetDecisionAndPassOptimizationGuideDecision(
optimization_target, std::move(callback))); optimization_target, std::move(callback), target_decision);
} }
void OptimizationGuideKeyedService::RegisterOptimizationTypes( void OptimizationGuideKeyedService::RegisterOptimizationTypes(
......
...@@ -19,11 +19,8 @@ ...@@ -19,11 +19,8 @@
#include "base/timer/timer.h" #include "base/timer/timer.h"
#include "chrome/browser/optimization_guide/optimization_guide_session_statistic.h" #include "chrome/browser/optimization_guide/optimization_guide_session_statistic.h"
#include "chrome/browser/optimization_guide/prediction/prediction_model_download_observer.h" #include "chrome/browser/optimization_guide/prediction/prediction_model_download_observer.h"
#include "chrome/services/machine_learning/public/mojom/decision_tree.mojom.h"
#include "chrome/services/machine_learning/public/mojom/machine_learning_service.mojom-forward.h"
#include "components/optimization_guide/optimization_guide_enums.h" #include "components/optimization_guide/optimization_guide_enums.h"
#include "components/optimization_guide/proto/models.pb.h" #include "components/optimization_guide/proto/models.pb.h"
#include "mojo/public/cpp/bindings/remote.h"
#include "services/network/public/cpp/network_quality_tracker.h" #include "services/network/public/cpp/network_quality_tracker.h"
#include "url/origin.h" #include "url/origin.h"
...@@ -54,11 +51,6 @@ class PredictionModel; ...@@ -54,11 +51,6 @@ class PredictionModel;
class PredictionModelDownloadManager; class PredictionModelDownloadManager;
class PredictionModelFetcher; class PredictionModelFetcher;
class TopHostProvider; class TopHostProvider;
class RemoteDecisionTreePredictor;
// Parameters to be passed to PredictionManager::OnModelEvaluated for post
// processing after the model prediction decision and score are obtained.
struct PredictionDecisionParams;
using HostModelFeaturesMRUCache = using HostModelFeaturesMRUCache =
base::HashingMRUCache<std::string, base::flat_map<std::string, float>>; base::HashingMRUCache<std::string, base::flat_map<std::string, float>>;
...@@ -119,22 +111,6 @@ class PredictionManager ...@@ -119,22 +111,6 @@ class PredictionManager
const base::flat_map<proto::ClientModelFeature, float>& const base::flat_map<proto::ClientModelFeature, float>&
override_client_model_feature_values); override_client_model_feature_values);
// Invokes |callback| with the decision for whether the navigation matches the
// criteria for |optimization_target|. Passes kUnknown if a PredictionModel
// for the optimization target is not registered
// and kModelNotAvailableOnClient if the model for the optimization target is
// not currently on the client.
//
// Values provided in |client_model_feature_values| will be used over any
// values for features required by the model that may be calculated by the
// Optimization Guide.
void ShouldTargetNavigationAsync(
content::NavigationHandle* navigation_handle,
proto::OptimizationTarget optimization_target,
const base::flat_map<proto::ClientModelFeature, float>&
override_client_model_feature_values,
OptimizationTargetDecisionCallback callback);
// Update |session_fcp_| and |previous_fcp_| with |fcp|. // Update |session_fcp_| and |previous_fcp_| with |fcp|.
void UpdateFCPSessionStatistics(base::TimeDelta fcp); void UpdateFCPSessionStatistics(base::TimeDelta fcp);
...@@ -197,11 +173,6 @@ class PredictionManager ...@@ -197,11 +173,6 @@ class PredictionManager
PredictionModel* GetPredictionModelForTesting( PredictionModel* GetPredictionModelForTesting(
proto::OptimizationTarget optimization_target) const; proto::OptimizationTarget optimization_target) const;
// Return the remote model predictor handle for the optimization target used
// by this PredictionManager for testing.
RemoteDecisionTreePredictor* GetRemoteDecisionTreePredictorForTesting(
proto::OptimizationTarget optimization_target) const;
// Return the host model features for all hosts used by this // Return the host model features for all hosts used by this
// PredictionManager for testing. // PredictionManager for testing.
const HostModelFeaturesMRUCache* GetHostModelFeaturesForTesting() const; const HostModelFeaturesMRUCache* GetHostModelFeaturesForTesting() const;
...@@ -322,20 +293,6 @@ class PredictionManager ...@@ -322,20 +293,6 @@ class PredictionManager
// model object was created and successfully stored, otherwise false. // model object was created and successfully stored, otherwise false.
bool ProcessAndStorePredictionModel(const proto::PredictionModel& model); bool ProcessAndStorePredictionModel(const proto::PredictionModel& model);
// Send |model| to the ML service and bind the predictor handle to the
// |optimization_target_remote_model_predictor_map_|, then run |callback|
// for post-processing.
bool SendPredictionModelToMLService(
std::unique_ptr<proto::PredictionModel> model,
PostModelLoadCallback callback);
// Callback run after a prediction |model| is sent to the ML service.
void OnPredictionModelSentToMLService(
PostModelLoadCallback callback,
std::unique_ptr<proto::PredictionModel> model,
std::unique_ptr<RemoteDecisionTreePredictor> predictor_handle,
machine_learning::mojom::LoadModelResult result);
// Post-processing callback invoked after processing |model| or sending it to // Post-processing callback invoked after processing |model| or sending it to
// the ML Service. // the ML Service.
void OnProcessOrSendPredictionModel( void OnProcessOrSendPredictionModel(
...@@ -349,14 +306,6 @@ class PredictionManager ...@@ -349,14 +306,6 @@ class PredictionManager
bool ProcessAndStoreHostModelFeatures( bool ProcessAndStoreHostModelFeatures(
const proto::HostModelFeatures& host_model_features); const proto::HostModelFeatures& host_model_features);
// Callback to be passed to the ML Service via the predictor handle and to
// retrieve |result| and |prediction_score|. Performs post processing using
// information passed via |params|.
void OnModelEvaluated(
std::unique_ptr<PredictionDecisionParams> params,
machine_learning::mojom::DecisionTreePredictionResult result,
double prediction_score);
// Return the time when a prediction model and host model features fetch was // Return the time when a prediction model and host model features fetch was
// last attempted. // last attempted.
base::Time GetLastFetchAttemptTime() const; base::Time GetLastFetchAttemptTime() const;
...@@ -379,12 +328,6 @@ class PredictionManager ...@@ -379,12 +328,6 @@ class PredictionManager
base::flat_map<proto::OptimizationTarget, std::unique_ptr<PredictionModel>> base::flat_map<proto::OptimizationTarget, std::unique_ptr<PredictionModel>>
optimization_target_prediction_model_map_; optimization_target_prediction_model_map_;
// A map of optimization target to the model predictor handle capable of
// sending prediction calls to the prediction model loaded in the ML Service.
base::flat_map<proto::OptimizationTarget,
std::unique_ptr<RemoteDecisionTreePredictor>>
optimization_target_remote_model_predictor_map_;
// The set of optimization targets that have been registered with the // The set of optimization targets that have been registered with the
// prediction manager. // prediction manager.
base::flat_set<proto::OptimizationTarget> registered_optimization_targets_; base::flat_set<proto::OptimizationTarget> registered_optimization_targets_;
......
// Copyright 2020 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/browser/optimization_guide/prediction/remote_decision_tree_predictor.h"
#include <string>
#include "base/containers/flat_set.h"
namespace optimization_guide {
RemoteDecisionTreePredictor::RemoteDecisionTreePredictor(
const proto::PredictionModel& model) {
// The Decision Tree model type is currently the only supported model type.
DCHECK(model.model_info().supported_model_types(0) ==
optimization_guide::proto::ModelType::MODEL_TYPE_DECISION_TREE);
version_ = model.model_info().version();
model_features_.reserve(
model.model_info().supported_model_features_size() +
model.model_info().supported_host_model_features_size());
// Insert all the client model features for the owned |model_|.
for (const auto& client_model_feature :
model.model_info().supported_model_features()) {
model_features_.emplace(
proto::ClientModelFeature_Name(client_model_feature));
}
// Insert all the host model features for the owned |model_|.
for (const auto& host_model_feature :
model.model_info().supported_host_model_features()) {
model_features_.emplace(host_model_feature);
}
}
RemoteDecisionTreePredictor::~RemoteDecisionTreePredictor() = default;
machine_learning::mojom::DecisionTreePredictorProxy*
RemoteDecisionTreePredictor::Get() const {
if (!remote_)
return nullptr;
return remote_.get();
}
bool RemoteDecisionTreePredictor::IsConnected() const {
return remote_.is_connected();
}
void RemoteDecisionTreePredictor::FlushForTesting() {
remote_.FlushForTesting();
}
mojo::PendingReceiver<machine_learning::mojom::DecisionTreePredictor>
RemoteDecisionTreePredictor::BindNewPipeAndPassReceiver() {
return remote_.BindNewPipeAndPassReceiver();
}
const base::flat_set<std::string>& RemoteDecisionTreePredictor::model_features()
const {
return model_features_;
}
int64_t RemoteDecisionTreePredictor::version() const {
return version_;
}
} // namespace optimization_guide
// Copyright 2020 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef CHROME_BROWSER_OPTIMIZATION_GUIDE_PREDICTION_REMOTE_DECISION_TREE_PREDICTOR_H_
#define CHROME_BROWSER_OPTIMIZATION_GUIDE_PREDICTION_REMOTE_DECISION_TREE_PREDICTOR_H_
#include <string>
#include "base/containers/flat_set.h"
#include "chrome/services/machine_learning/public/mojom/decision_tree.mojom.h"
#include "components/optimization_guide/proto/models.pb.h"
#include "mojo/public/cpp/bindings/remote.h"
namespace optimization_guide {
// Holds a Mojo remote handle connected to a DecisionTreePredictor instance in
// the ML Service, together with model information necessary for building
// feature maps.
class RemoteDecisionTreePredictor {
public:
// Initializes an unbound remote handle and saves necessary information from
// |model|. |model| must be a valid model.
explicit RemoteDecisionTreePredictor(const proto::PredictionModel& model);
~RemoteDecisionTreePredictor();
RemoteDecisionTreePredictor(const RemoteDecisionTreePredictor&) = delete;
RemoteDecisionTreePredictor& operator=(const RemoteDecisionTreePredictor&) =
delete;
// Exposes access to callable interface methods directed at the |remote_|'s
// receiver. Returns nullptr if |remote_| is unbound.
machine_learning::mojom::DecisionTreePredictorProxy* Get() const;
// Whether |remote_| is connected.
bool IsConnected() const;
// Flushes |remote_| for testing purpose.
void FlushForTesting();
// Calls the |BindNewPipeAndPassReceiver| method of the |remote_|. Must only
// be called on a bound |remote_|.
mojo::PendingReceiver<machine_learning::mojom::DecisionTreePredictor>
BindNewPipeAndPassReceiver();
// A set of model features bound to the predictor handle.
const base::flat_set<std::string>& model_features() const;
// Version of the model bound to the predictor handle.
int64_t version() const;
private:
mojo::Remote<machine_learning::mojom::DecisionTreePredictor> remote_;
base::flat_set<std::string> model_features_;
int64_t version_;
};
} // namespace optimization_guide
#endif // CHROME_BROWSER_OPTIMIZATION_GUIDE_PREDICTION_REMOTE_DECISION_TREE_PREDICTOR_H_
// Copyright 2020 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/browser/optimization_guide/prediction/remote_decision_tree_predictor.h"
#include <memory>
#include "base/test/task_environment.h"
#include "components/optimization_guide/optimization_guide_test_util.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace optimization_guide {
namespace {
const double kThreshold = 1.0;
const double kWeight = 1.0;
const double kModelValueDiff = 0.2;
const int64_t kVersion = 1;
std::unique_ptr<proto::PredictionModel> GetValidPredictionModel() {
// This model will return true upon evaluation.
auto model = GetSingleLeafDecisionTreePredictionModel(
kThreshold, kWeight, (kThreshold + kModelValueDiff) / kWeight);
model->mutable_model_info()->set_version(kVersion);
model->mutable_model_info()->add_supported_model_types(
optimization_guide::proto::MODEL_TYPE_DECISION_TREE);
model->mutable_model_info()->set_optimization_target(
proto::OptimizationTarget::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD);
return model;
}
} // namespace
TEST(RemoteDecisionTreePredictorTest, Initialization) {
base::test::SingleThreadTaskEnvironment task_environment;
auto model = GetValidPredictionModel();
RemoteDecisionTreePredictor predictor(*model);
EXPECT_EQ(kVersion, predictor.version());
EXPECT_TRUE(predictor.model_features().empty());
EXPECT_FALSE(predictor.Get());
}
TEST(RemoteDecisionTreePredictorTest, BindPredictorToReceiver) {
base::test::SingleThreadTaskEnvironment task_environment;
auto model = GetValidPredictionModel();
RemoteDecisionTreePredictor predictor(*model);
auto pending_receiver = predictor.BindNewPipeAndPassReceiver();
EXPECT_TRUE(predictor.Get());
EXPECT_TRUE(pending_receiver);
EXPECT_TRUE(predictor.IsConnected());
pending_receiver.reset();
predictor.FlushForTesting();
EXPECT_TRUE(predictor.Get());
EXPECT_FALSE(predictor.IsConnected());
}
} // namespace optimization_guide
...@@ -3518,7 +3518,6 @@ test("unit_tests") { ...@@ -3518,7 +3518,6 @@ test("unit_tests") {
"../browser/optimization_guide/prediction/prediction_manager_unittest.cc", "../browser/optimization_guide/prediction/prediction_manager_unittest.cc",
"../browser/optimization_guide/prediction/prediction_model_download_manager_unittest.cc", "../browser/optimization_guide/prediction/prediction_model_download_manager_unittest.cc",
"../browser/optimization_guide/prediction/prediction_model_fetcher_unittest.cc", "../browser/optimization_guide/prediction/prediction_model_fetcher_unittest.cc",
"../browser/optimization_guide/prediction/remote_decision_tree_predictor_unittest.cc",
"../browser/page_load_metrics/metrics_web_contents_observer_unittest.cc", "../browser/page_load_metrics/metrics_web_contents_observer_unittest.cc",
"../browser/page_load_metrics/observers/aborts_page_load_metrics_observer_unittest.cc", "../browser/page_load_metrics/observers/aborts_page_load_metrics_observer_unittest.cc",
"../browser/page_load_metrics/observers/ad_metrics/ads_page_load_metrics_observer_unittest.cc", "../browser/page_load_metrics/observers/ad_metrics/ads_page_load_metrics_observer_unittest.cc",
......
...@@ -54,11 +54,6 @@ const base::Feature kRemoteOptimizationGuideFetchingAnonymousDataConsent{ ...@@ -54,11 +54,6 @@ const base::Feature kRemoteOptimizationGuideFetchingAnonymousDataConsent{
const base::Feature kOptimizationTargetPrediction{ const base::Feature kOptimizationTargetPrediction{
"OptimizationTargetPrediction", base::FEATURE_ENABLED_BY_DEFAULT}; "OptimizationTargetPrediction", base::FEATURE_ENABLED_BY_DEFAULT};
// Enables out-of-service evaluation of prediction models via the ML Service.
const base::Feature kOptimizationTargetPredictionUsingMLService{
"OptimizationGuidePredictionUsingMLService",
base::FEATURE_DISABLED_BY_DEFAULT};
// Enables the downloading of models. // Enables the downloading of models.
const base::Feature kOptimizationGuideModelDownloading{ const base::Feature kOptimizationGuideModelDownloading{
"OptimizationGuideModelDownloading", base::FEATURE_DISABLED_BY_DEFAULT}; "OptimizationGuideModelDownloading", base::FEATURE_DISABLED_BY_DEFAULT};
...@@ -306,11 +301,6 @@ base::flat_set<uint32_t> FieldTrialNameHashesAllowedForFetch() { ...@@ -306,11 +301,6 @@ base::flat_set<uint32_t> FieldTrialNameHashesAllowedForFetch() {
return allowed_field_trial_name_hashes; return allowed_field_trial_name_hashes;
} }
bool ShouldUseMLServiceForPrediction() {
return base::FeatureList::IsEnabled(
kOptimizationTargetPredictionUsingMLService);
}
bool IsModelDownloadingEnabled() { bool IsModelDownloadingEnabled() {
return base::FeatureList::IsEnabled(kOptimizationGuideModelDownloading); return base::FeatureList::IsEnabled(kOptimizationGuideModelDownloading);
} }
......
...@@ -24,7 +24,6 @@ extern const base::Feature kOptimizationHintsFieldTrials; ...@@ -24,7 +24,6 @@ extern const base::Feature kOptimizationHintsFieldTrials;
extern const base::Feature kRemoteOptimizationGuideFetching; extern const base::Feature kRemoteOptimizationGuideFetching;
extern const base::Feature kRemoteOptimizationGuideFetchingAnonymousDataConsent; extern const base::Feature kRemoteOptimizationGuideFetchingAnonymousDataConsent;
extern const base::Feature kOptimizationTargetPrediction; extern const base::Feature kOptimizationTargetPrediction;
extern const base::Feature kOptimizationTargetPredictionUsingMLService;
extern const base::Feature kOptimizationGuideModelDownloading; extern const base::Feature kOptimizationGuideModelDownloading;
// The maximum number of hosts that can be stored in the // The maximum number of hosts that can be stored in the
...@@ -156,9 +155,6 @@ base::flat_set<std::string> ExternalAppPackageNamesApprovedForFetch(); ...@@ -156,9 +155,6 @@ base::flat_set<std::string> ExternalAppPackageNamesApprovedForFetch();
// specified field trials. // specified field trials.
base::flat_set<uint32_t> FieldTrialNameHashesAllowedForFetch(); base::flat_set<uint32_t> FieldTrialNameHashesAllowedForFetch();
// Whether out-of-process model evaluation via the ML Service is enabled.
bool ShouldUseMLServiceForPrediction();
// Whether the ability to download models is enabled. // Whether the ability to download models is enabled.
bool IsModelDownloadingEnabled(); bool IsModelDownloadingEnabled();
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment