Commit 2be78178 authored by Sophie Chang's avatar Sophie Chang Committed by Commit Bot

Add an async ShouldTargetNavigation that allows for clients to inject client model features

Bug: 1099371
Change-Id: I3fe8f12e0ad827b180efec3963bb59999c2c2b75
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2288632
Commit-Queue: Sophie Chang <sophiechang@chromium.org>
Reviewed-by: default avatarMichael Crouse <mcrouse@chromium.org>
Cr-Commit-Position: refs/heads/master@{#786577}
parent a9d6f586
...@@ -174,8 +174,9 @@ OptimizationGuideKeyedService::ShouldTargetNavigation( ...@@ -174,8 +174,9 @@ OptimizationGuideKeyedService::ShouldTargetNavigation(
} }
optimization_guide::OptimizationTargetDecision optimization_target_decision = optimization_guide::OptimizationTargetDecision optimization_target_decision =
prediction_manager_->ShouldTargetNavigation(navigation_handle, prediction_manager_->ShouldTargetNavigation(
optimization_target); navigation_handle, optimization_target,
/*override_client_model_feature_values=*/{});
base::UmaHistogramExactLinear( base::UmaHistogramExactLinear(
"OptimizationGuide.TargetDecision." + "OptimizationGuide.TargetDecision." +
...@@ -187,6 +188,36 @@ OptimizationGuideKeyedService::ShouldTargetNavigation( ...@@ -187,6 +188,36 @@ OptimizationGuideKeyedService::ShouldTargetNavigation(
optimization_target_decision); optimization_target_decision);
} }
void OptimizationGuideKeyedService::ShouldTargetNavigationAsync(
content::NavigationHandle* navigation_handle,
optimization_guide::proto::OptimizationTarget optimization_target,
const base::flat_map<optimization_guide::proto::ClientModelFeature, float>&
client_model_feature_values,
optimization_guide::OptimizationGuideTargetDecisionCallback callback) {
DCHECK_CURRENTLY_ON(content::BrowserThread::UI);
DCHECK(navigation_handle->IsInMainFrame());
if (!prediction_manager_) {
// We are not initialized yet, so just return unknown.
std::move(callback).Run(
optimization_guide::OptimizationGuideDecision::kUnknown);
return;
}
optimization_guide::OptimizationTargetDecision optimization_target_decision =
prediction_manager_->ShouldTargetNavigation(
navigation_handle, optimization_target, client_model_feature_values);
base::UmaHistogramExactLinear(
"OptimizationGuide.TargetDecision." +
GetStringNameForOptimizationTarget(optimization_target),
static_cast<int>(optimization_target_decision),
static_cast<int>(
optimization_guide::OptimizationTargetDecision::kMaxValue));
std::move(callback).Run(
GetOptimizationGuideDecisionFromOptimizationTargetDecision(
optimization_target_decision));
}
void OptimizationGuideKeyedService::RegisterOptimizationTypes( void OptimizationGuideKeyedService::RegisterOptimizationTypes(
const std::vector<optimization_guide::proto::OptimizationType>& const std::vector<optimization_guide::proto::OptimizationType>&
optimization_types) { optimization_types) {
......
...@@ -58,6 +58,13 @@ class OptimizationGuideKeyedService ...@@ -58,6 +58,13 @@ class OptimizationGuideKeyedService
content::NavigationHandle* navigation_handle, content::NavigationHandle* navigation_handle,
optimization_guide::proto::OptimizationTarget optimization_target) optimization_guide::proto::OptimizationTarget optimization_target)
override; override;
void ShouldTargetNavigationAsync(
content::NavigationHandle* navigation_handle,
optimization_guide::proto::OptimizationTarget optimization_target,
const base::flat_map<optimization_guide::proto::ClientModelFeature,
float>& client_model_feature_values,
optimization_guide::OptimizationGuideTargetDecisionCallback callback)
override;
void RegisterOptimizationTypes( void RegisterOptimizationTypes(
const std::vector<optimization_guide::proto::OptimizationType>& const std::vector<optimization_guide::proto::OptimizationType>&
optimization_types) override; optimization_types) override;
......
...@@ -266,13 +266,20 @@ void PredictionManager::RegisterOptimizationTargets( ...@@ -266,13 +266,20 @@ void PredictionManager::RegisterOptimizationTargets(
base::Optional<float> PredictionManager::GetValueForClientFeature( base::Optional<float> PredictionManager::GetValueForClientFeature(
const std::string& model_feature, const std::string& model_feature,
content::NavigationHandle* navigation_handle) const { content::NavigationHandle* navigation_handle,
const base::flat_map<proto::ClientModelFeature, float>&
override_client_model_feature_values) const {
SEQUENCE_CHECKER(sequence_checker_); SEQUENCE_CHECKER(sequence_checker_);
proto::ClientModelFeature client_model_feature; proto::ClientModelFeature client_model_feature;
if (!proto::ClientModelFeature_Parse(model_feature, &client_model_feature)) if (!proto::ClientModelFeature_Parse(model_feature, &client_model_feature))
return base::nullopt; return base::nullopt;
auto cmf_value_it =
override_client_model_feature_values.find(client_model_feature);
if (cmf_value_it != override_client_model_feature_values.end())
return cmf_value_it->second;
base::Optional<float> value; base::Optional<float> value;
switch (client_model_feature) { switch (client_model_feature) {
...@@ -348,7 +355,9 @@ base::Optional<float> PredictionManager::GetValueForClientFeature( ...@@ -348,7 +355,9 @@ base::Optional<float> PredictionManager::GetValueForClientFeature(
base::flat_map<std::string, float> PredictionManager::BuildFeatureMap( base::flat_map<std::string, float> PredictionManager::BuildFeatureMap(
content::NavigationHandle* navigation_handle, content::NavigationHandle* navigation_handle,
const base::flat_set<std::string>& model_features) { const base::flat_set<std::string>& model_features,
const base::flat_map<proto::ClientModelFeature, float>&
override_client_model_feature_values) {
SEQUENCE_CHECKER(sequence_checker_); SEQUENCE_CHECKER(sequence_checker_);
base::flat_map<std::string, float> feature_map; base::flat_map<std::string, float> feature_map;
if (model_features.size() == 0) if (model_features.size() == 0)
...@@ -370,8 +379,8 @@ base::flat_map<std::string, float> PredictionManager::BuildFeatureMap( ...@@ -370,8 +379,8 @@ base::flat_map<std::string, float> PredictionManager::BuildFeatureMap(
// created for it. This ensures that the prediction model will have values for // created for it. This ensures that the prediction model will have values for
// every feature that it requires to be evaluated. // every feature that it requires to be evaluated.
for (const auto& model_feature : model_features) { for (const auto& model_feature : model_features) {
base::Optional<float> value = base::Optional<float> value = GetValueForClientFeature(
GetValueForClientFeature(model_feature, navigation_handle); model_feature, navigation_handle, override_client_model_feature_values);
if (value) { if (value) {
feature_map[model_feature] = *value; feature_map[model_feature] = *value;
continue; continue;
...@@ -388,7 +397,9 @@ base::flat_map<std::string, float> PredictionManager::BuildFeatureMap( ...@@ -388,7 +397,9 @@ base::flat_map<std::string, float> PredictionManager::BuildFeatureMap(
OptimizationTargetDecision PredictionManager::ShouldTargetNavigation( OptimizationTargetDecision PredictionManager::ShouldTargetNavigation(
content::NavigationHandle* navigation_handle, content::NavigationHandle* navigation_handle,
proto::OptimizationTarget optimization_target) { proto::OptimizationTarget optimization_target,
const base::flat_map<proto::ClientModelFeature, float>&
override_client_model_feature_values) {
SEQUENCE_CHECKER(sequence_checker_); SEQUENCE_CHECKER(sequence_checker_);
DCHECK(navigation_handle->GetURL().SchemeIsHTTPOrHTTPS()); DCHECK(navigation_handle->GetURL().SchemeIsHTTPOrHTTPS());
...@@ -434,7 +445,8 @@ OptimizationTargetDecision PredictionManager::ShouldTargetNavigation( ...@@ -434,7 +445,8 @@ OptimizationTargetDecision PredictionManager::ShouldTargetNavigation(
PredictionModel* prediction_model = it->second.get(); PredictionModel* prediction_model = it->second.get();
base::flat_map<std::string, float> feature_map = base::flat_map<std::string, float> feature_map =
BuildFeatureMap(navigation_handle, prediction_model->GetModelFeatures()); BuildFeatureMap(navigation_handle, prediction_model->GetModelFeatures(),
override_client_model_feature_values);
base::TimeTicks model_evaluation_start_time = base::TimeTicks::Now(); base::TimeTicks model_evaluation_start_time = base::TimeTicks::Now();
double prediction_score = 0.0; double prediction_score = 0.0;
......
...@@ -89,9 +89,18 @@ class PredictionManager ...@@ -89,9 +89,18 @@ class PredictionManager
// |optimization_target|. Return kUnknown if a PredictionModel for the // |optimization_target|. Return kUnknown if a PredictionModel for the
// optimization target is not registered and kModelNotAvailableOnClient if the // optimization target is not registered and kModelNotAvailableOnClient if the
// if model for the optimization target is not currently on the client. // if model for the optimization target is not currently on the client.
// If the model for the optimization target requires a client model feature
// that is present in |override_client_model_feature_values|, the value from
// |override_client_model_feature_values| will be used. The client will
// calculate the value for any required client model features not present in
// |override_client_model_feature_values| and inject any host model features
// it received from the server and send that complete feature map for
// evaluation.
OptimizationTargetDecision ShouldTargetNavigation( OptimizationTargetDecision ShouldTargetNavigation(
content::NavigationHandle* navigation_handle, content::NavigationHandle* navigation_handle,
proto::OptimizationTarget optimization_target); proto::OptimizationTarget optimization_target,
const base::flat_map<proto::ClientModelFeature, float>&
override_client_model_feature_values);
// Update |session_fcp_| and |previous_fcp_| with |fcp|. // Update |session_fcp_| and |previous_fcp_| with |fcp|.
void UpdateFCPSessionStatistics(base::TimeDelta fcp); void UpdateFCPSessionStatistics(base::TimeDelta fcp);
...@@ -184,14 +193,21 @@ class PredictionManager ...@@ -184,14 +193,21 @@ class PredictionManager
// based on if host model features were used. // based on if host model features were used.
base::flat_map<std::string, float> BuildFeatureMap( base::flat_map<std::string, float> BuildFeatureMap(
content::NavigationHandle* navigation_handle, content::NavigationHandle* navigation_handle,
const base::flat_set<std::string>& model_features); const base::flat_set<std::string>& model_features,
const base::flat_map<proto::ClientModelFeature, float>&
override_client_model_feature_values);
// Calculate and return the current value for the client feature specified // Calculate and return the current value for the client feature specified
// by |model_feature|. Return nullopt if the client does not support the // by |model_feature|. If |model_feature| is in
// |override_client_model_feature_values|, the value from
// |client_model_feature_values| will be used. Otherwise, the client will
// calculate the value or return nullopt if the client does not support the
// model feature. // model feature.
base::Optional<float> GetValueForClientFeature( base::Optional<float> GetValueForClientFeature(
const std::string& model_feature, const std::string& model_feature,
content::NavigationHandle* navigation_handle) const; content::NavigationHandle* navigation_handle,
const base::flat_map<proto::ClientModelFeature, float>&
override_client_model_feature_values) const;
// Called to make a request to fetch models and host model features from the // Called to make a request to fetch models and host model features from the
// remote Optimization Guide Service. Used to fetch models for the registered // remote Optimization Guide Service. Used to fetch models for the registered
......
...@@ -113,15 +113,13 @@ GetValidEnsemblePredictionModel() { ...@@ -113,15 +113,13 @@ GetValidEnsemblePredictionModel() {
std::unique_ptr<optimization_guide::proto::PredictionModel> prediction_model = std::unique_ptr<optimization_guide::proto::PredictionModel> prediction_model =
std::make_unique<optimization_guide::proto::PredictionModel>(); std::make_unique<optimization_guide::proto::PredictionModel>();
prediction_model->mutable_model()->mutable_threshold()->set_value(5.0); prediction_model->mutable_model()->mutable_threshold()->set_value(5.0);
optimization_guide::proto::Ensemble ensemble =
optimization_guide::proto::Ensemble();
*ensemble.add_members()->mutable_submodel() =
*GetValidDecisionTreePredictionModel()->mutable_model();
*ensemble.add_members()->mutable_submodel() = optimization_guide::proto::Model valid_decision_tree_model =
*GetValidDecisionTreePredictionModel()->mutable_model(); GetValidDecisionTreePredictionModel()->model();
optimization_guide::proto::Ensemble* ensemble =
*prediction_model->mutable_model()->mutable_ensemble() = ensemble; prediction_model->mutable_model()->mutable_ensemble();
*ensemble->add_members()->mutable_submodel() = valid_decision_tree_model;
*ensemble->add_members()->mutable_submodel() = valid_decision_tree_model;
return prediction_model; return prediction_model;
} }
...@@ -136,6 +134,7 @@ CreatePredictionModel() { ...@@ -136,6 +134,7 @@ CreatePredictionModel() {
model_info->add_supported_model_features( model_info->add_supported_model_features(
optimization_guide::proto:: optimization_guide::proto::
CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE); CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE);
model_info->add_supported_host_model_features("agg1");
model_info->set_optimization_target( model_info->set_optimization_target(
optimization_guide::proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD); optimization_guide::proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD);
model_info->add_supported_model_types( model_info->add_supported_model_types(
...@@ -158,7 +157,7 @@ BuildGetModelsResponse( ...@@ -158,7 +157,7 @@ BuildGetModelsResponse(
host_model_features->set_host(host); host_model_features->set_host(host);
optimization_guide::proto::ModelFeature* model_feature = optimization_guide::proto::ModelFeature* model_feature =
host_model_features->add_model_features(); host_model_features->add_model_features();
model_feature->set_feature_name("host_feat1"); model_feature->set_feature_name("agg1");
model_feature->set_double_value(2.0); model_feature->set_double_value(2.0);
} }
...@@ -197,10 +196,33 @@ class OptimizationGuideConsumerWebContentsObserver ...@@ -197,10 +196,33 @@ class OptimizationGuideConsumerWebContentsObserver
OptimizationGuideKeyedService* service = OptimizationGuideKeyedService* service =
OptimizationGuideKeyedServiceFactory::GetForProfile( OptimizationGuideKeyedServiceFactory::GetForProfile(
Profile::FromBrowserContext(web_contents()->GetBrowserContext())); Profile::FromBrowserContext(web_contents()->GetBrowserContext()));
service->ShouldTargetNavigation( last_should_target_decision_ = service->ShouldTargetNavigation(
navigation_handle, navigation_handle,
optimization_guide::proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD); optimization_guide::proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD);
if (callback_) {
// Intentionally do not set client model feature values to override to
// make sure decisions are the same in both sync and async variants.
service->ShouldTargetNavigationAsync(
navigation_handle,
optimization_guide::proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, {},
std::move(callback_));
}
} }
void set_callback(
optimization_guide::OptimizationGuideTargetDecisionCallback callback) {
callback_ = std::move(callback);
}
optimization_guide::OptimizationGuideDecision last_should_target_decision()
const {
return last_should_target_decision_;
}
private:
optimization_guide::OptimizationGuideTargetDecisionCallback callback_;
optimization_guide::OptimizationGuideDecision last_should_target_decision_ =
optimization_guide::OptimizationGuideDecision::kUnknown;
}; };
} // namespace } // namespace
...@@ -285,6 +307,19 @@ class PredictionManagerBrowserTest : public InProcessBrowserTest { ...@@ -285,6 +307,19 @@ class PredictionManagerBrowserTest : public InProcessBrowserTest {
{optimization_guide::proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD}); {optimization_guide::proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD});
} }
// Sets the callback on the consumer of the OptimizationGuideKeyedService. If
// set, this will call the async version of ShouldTargetNavigation.
void SetCallbackOnConsumer(
optimization_guide::OptimizationGuideTargetDecisionCallback callback) {
ASSERT_TRUE(consumer_);
consumer_->set_callback(std::move(callback));
}
OptimizationGuideConsumerWebContentsObserver* consumer() {
return consumer_.get();
}
PredictionManager* GetPredictionManager() { PredictionManager* GetPredictionManager() {
OptimizationGuideKeyedService* optimization_guide_keyed_service = OptimizationGuideKeyedService* optimization_guide_keyed_service =
OptimizationGuideKeyedServiceFactory::GetForProfile( OptimizationGuideKeyedServiceFactory::GetForProfile(
...@@ -581,4 +616,40 @@ IN_PROC_BROWSER_TEST_F(PredictionManagerBrowserSameOriginTest, ...@@ -581,4 +616,40 @@ IN_PROC_BROWSER_TEST_F(PredictionManagerBrowserSameOriginTest,
"OptimizationGuide.PredictionManager.IsSameOrigin", true, 1); "OptimizationGuide.PredictionManager.IsSameOrigin", true, 1);
} }
IN_PROC_BROWSER_TEST_F(
PredictionManagerBrowserSameOriginTest,
DISABLE_ON_WIN_MAC_CHROMEOS(
ShouldTargetNavigationAsyncAndSyncDecisionAreTheSameWithoutOverrides)) {
base::HistogramTester histogram_tester;
RegisterWithKeyedService();
// Wait until histograms have been updated before performing checks for
// correct behavior based on the response.
RetryForHistogramUntilCountReached(
&histogram_tester,
"OptimizationGuide.PredictionModelFetcher.GetModelsResponse.Status", 1);
RetryForHistogramUntilCountReached(
&histogram_tester,
"OptimizationGuide.PredictionManager.HostModelFeaturesStored", 1);
RetryForHistogramUntilCountReached(
&histogram_tester,
"OptimizationGuide.PredictionManager.PredictionModelsStored", 1);
std::unique_ptr<base::RunLoop> run_loop = std::make_unique<base::RunLoop>();
SetCallbackOnConsumer(base::BindOnce(
[](base::RunLoop* run_loop,
OptimizationGuideConsumerWebContentsObserver* consumer,
optimization_guide::OptimizationGuideDecision decision) {
EXPECT_EQ(consumer->last_should_target_decision(), decision);
run_loop->Quit();
},
run_loop.get(), consumer()));
ui_test_utils::NavigateToURL(browser(), https_url_with_content());
run_loop->Run();
}
} // namespace optimization_guide } // namespace optimization_guide
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <vector> #include <vector>
#include "base/callback_forward.h" #include "base/callback_forward.h"
#include "base/containers/flat_map.h"
#include "base/optional.h" #include "base/optional.h"
#include "components/optimization_guide/optimization_metadata.h" #include "components/optimization_guide/optimization_metadata.h"
#include "components/optimization_guide/proto/hints.pb.h" #include "components/optimization_guide/proto/hints.pb.h"
...@@ -38,6 +39,9 @@ enum class OptimizationGuideDecision { ...@@ -38,6 +39,9 @@ enum class OptimizationGuideDecision {
kMaxValue = kFalse, kMaxValue = kFalse,
}; };
using OptimizationGuideTargetDecisionCallback =
base::OnceCallback<void(optimization_guide::OptimizationGuideDecision)>;
using OptimizationGuideDecisionCallback = using OptimizationGuideDecisionCallback =
base::OnceCallback<void(optimization_guide::OptimizationGuideDecision, base::OnceCallback<void(optimization_guide::OptimizationGuideDecision,
const optimization_guide::OptimizationMetadata&)>; const optimization_guide::OptimizationMetadata&)>;
...@@ -56,6 +60,20 @@ class OptimizationGuideDecider { ...@@ -56,6 +60,20 @@ class OptimizationGuideDecider {
content::NavigationHandle* navigation_handle, content::NavigationHandle* navigation_handle,
proto::OptimizationTarget optimization_target) = 0; proto::OptimizationTarget optimization_target) = 0;
// Invokes |callback| with the decision for whether the current browser
// conditions, as expressed by |client_model_feature_values| and the
// |navigation_handle|, match |optimization_target|.
//
// Values provided in |client_model_feature_values| will be used over any
// values for features required by the model that may be calculated by the
// Optimization Guide.
virtual void ShouldTargetNavigationAsync(
content::NavigationHandle* navigation_handle,
proto::OptimizationTarget optimization_target,
const base::flat_map<proto::ClientModelFeature, float>&
client_model_feature_values,
OptimizationGuideTargetDecisionCallback callback) = 0;
// Registers the optimization types that intend to be queried during the // Registers the optimization types that intend to be queried during the
// session. It is expected for this to be called after the browser has been // session. It is expected for this to be called after the browser has been
// initialized. // initialized.
......
...@@ -20,6 +20,15 @@ OptimizationGuideDecision TestOptimizationGuideDecider::ShouldTargetNavigation( ...@@ -20,6 +20,15 @@ OptimizationGuideDecision TestOptimizationGuideDecider::ShouldTargetNavigation(
return OptimizationGuideDecision::kFalse; return OptimizationGuideDecision::kFalse;
} }
void TestOptimizationGuideDecider::ShouldTargetNavigationAsync(
content::NavigationHandle* navigation_handle,
proto::OptimizationTarget optimization_target,
const base::flat_map<proto::ClientModelFeature, float>&
client_model_feature_values,
OptimizationGuideTargetDecisionCallback callback) {
std::move(callback).Run(OptimizationGuideDecision::kFalse);
}
void TestOptimizationGuideDecider::RegisterOptimizationTypes( void TestOptimizationGuideDecider::RegisterOptimizationTypes(
const std::vector<proto::OptimizationType>& optimization_types) {} const std::vector<proto::OptimizationType>& optimization_types) {}
......
...@@ -26,6 +26,12 @@ class TestOptimizationGuideDecider : public OptimizationGuideDecider { ...@@ -26,6 +26,12 @@ class TestOptimizationGuideDecider : public OptimizationGuideDecider {
OptimizationGuideDecision ShouldTargetNavigation( OptimizationGuideDecision ShouldTargetNavigation(
content::NavigationHandle* navigation_handle, content::NavigationHandle* navigation_handle,
proto::OptimizationTarget optimization_target) override; proto::OptimizationTarget optimization_target) override;
void ShouldTargetNavigationAsync(
content::NavigationHandle* navigation_handle,
proto::OptimizationTarget optimization_target,
const base::flat_map<proto::ClientModelFeature, float>&
client_model_feature_values,
OptimizationGuideTargetDecisionCallback callback) override;
void RegisterOptimizationTypes( void RegisterOptimizationTypes(
const std::vector<proto::OptimizationType>& optimization_types) override; const std::vector<proto::OptimizationType>& optimization_types) override;
void CanApplyOptimizationAsync( void CanApplyOptimizationAsync(
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment