Commit 2be78178 authored by Sophie Chang's avatar Sophie Chang Committed by Commit Bot

Add an async ShouldTargetNavigation that allows for clients to inject client model features

Bug: 1099371
Change-Id: I3fe8f12e0ad827b180efec3963bb59999c2c2b75
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2288632
Commit-Queue: Sophie Chang <sophiechang@chromium.org>
Reviewed-by: default avatarMichael Crouse <mcrouse@chromium.org>
Cr-Commit-Position: refs/heads/master@{#786577}
parent a9d6f586
......@@ -174,8 +174,9 @@ OptimizationGuideKeyedService::ShouldTargetNavigation(
}
optimization_guide::OptimizationTargetDecision optimization_target_decision =
prediction_manager_->ShouldTargetNavigation(navigation_handle,
optimization_target);
prediction_manager_->ShouldTargetNavigation(
navigation_handle, optimization_target,
/*override_client_model_feature_values=*/{});
base::UmaHistogramExactLinear(
"OptimizationGuide.TargetDecision." +
......@@ -187,6 +188,36 @@ OptimizationGuideKeyedService::ShouldTargetNavigation(
optimization_target_decision);
}
void OptimizationGuideKeyedService::ShouldTargetNavigationAsync(
content::NavigationHandle* navigation_handle,
optimization_guide::proto::OptimizationTarget optimization_target,
const base::flat_map<optimization_guide::proto::ClientModelFeature, float>&
client_model_feature_values,
optimization_guide::OptimizationGuideTargetDecisionCallback callback) {
DCHECK_CURRENTLY_ON(content::BrowserThread::UI);
DCHECK(navigation_handle->IsInMainFrame());
if (!prediction_manager_) {
// We are not initialized yet, so just return unknown.
std::move(callback).Run(
optimization_guide::OptimizationGuideDecision::kUnknown);
return;
}
optimization_guide::OptimizationTargetDecision optimization_target_decision =
prediction_manager_->ShouldTargetNavigation(
navigation_handle, optimization_target, client_model_feature_values);
base::UmaHistogramExactLinear(
"OptimizationGuide.TargetDecision." +
GetStringNameForOptimizationTarget(optimization_target),
static_cast<int>(optimization_target_decision),
static_cast<int>(
optimization_guide::OptimizationTargetDecision::kMaxValue));
std::move(callback).Run(
GetOptimizationGuideDecisionFromOptimizationTargetDecision(
optimization_target_decision));
}
void OptimizationGuideKeyedService::RegisterOptimizationTypes(
const std::vector<optimization_guide::proto::OptimizationType>&
optimization_types) {
......
......@@ -58,6 +58,13 @@ class OptimizationGuideKeyedService
content::NavigationHandle* navigation_handle,
optimization_guide::proto::OptimizationTarget optimization_target)
override;
void ShouldTargetNavigationAsync(
content::NavigationHandle* navigation_handle,
optimization_guide::proto::OptimizationTarget optimization_target,
const base::flat_map<optimization_guide::proto::ClientModelFeature,
float>& client_model_feature_values,
optimization_guide::OptimizationGuideTargetDecisionCallback callback)
override;
void RegisterOptimizationTypes(
const std::vector<optimization_guide::proto::OptimizationType>&
optimization_types) override;
......
......@@ -266,13 +266,20 @@ void PredictionManager::RegisterOptimizationTargets(
base::Optional<float> PredictionManager::GetValueForClientFeature(
const std::string& model_feature,
content::NavigationHandle* navigation_handle) const {
content::NavigationHandle* navigation_handle,
const base::flat_map<proto::ClientModelFeature, float>&
override_client_model_feature_values) const {
SEQUENCE_CHECKER(sequence_checker_);
proto::ClientModelFeature client_model_feature;
if (!proto::ClientModelFeature_Parse(model_feature, &client_model_feature))
return base::nullopt;
auto cmf_value_it =
override_client_model_feature_values.find(client_model_feature);
if (cmf_value_it != override_client_model_feature_values.end())
return cmf_value_it->second;
base::Optional<float> value;
switch (client_model_feature) {
......@@ -348,7 +355,9 @@ base::Optional<float> PredictionManager::GetValueForClientFeature(
base::flat_map<std::string, float> PredictionManager::BuildFeatureMap(
content::NavigationHandle* navigation_handle,
const base::flat_set<std::string>& model_features) {
const base::flat_set<std::string>& model_features,
const base::flat_map<proto::ClientModelFeature, float>&
override_client_model_feature_values) {
SEQUENCE_CHECKER(sequence_checker_);
base::flat_map<std::string, float> feature_map;
if (model_features.size() == 0)
......@@ -370,8 +379,8 @@ base::flat_map<std::string, float> PredictionManager::BuildFeatureMap(
// created for it. This ensures that the prediction model will have values for
// every feature that it requires to be evaluated.
for (const auto& model_feature : model_features) {
base::Optional<float> value =
GetValueForClientFeature(model_feature, navigation_handle);
base::Optional<float> value = GetValueForClientFeature(
model_feature, navigation_handle, override_client_model_feature_values);
if (value) {
feature_map[model_feature] = *value;
continue;
......@@ -388,7 +397,9 @@ base::flat_map<std::string, float> PredictionManager::BuildFeatureMap(
OptimizationTargetDecision PredictionManager::ShouldTargetNavigation(
content::NavigationHandle* navigation_handle,
proto::OptimizationTarget optimization_target) {
proto::OptimizationTarget optimization_target,
const base::flat_map<proto::ClientModelFeature, float>&
override_client_model_feature_values) {
SEQUENCE_CHECKER(sequence_checker_);
DCHECK(navigation_handle->GetURL().SchemeIsHTTPOrHTTPS());
......@@ -434,7 +445,8 @@ OptimizationTargetDecision PredictionManager::ShouldTargetNavigation(
PredictionModel* prediction_model = it->second.get();
base::flat_map<std::string, float> feature_map =
BuildFeatureMap(navigation_handle, prediction_model->GetModelFeatures());
BuildFeatureMap(navigation_handle, prediction_model->GetModelFeatures(),
override_client_model_feature_values);
base::TimeTicks model_evaluation_start_time = base::TimeTicks::Now();
double prediction_score = 0.0;
......
......@@ -89,9 +89,18 @@ class PredictionManager
// |optimization_target|. Return kUnknown if a PredictionModel for the
// optimization target is not registered and kModelNotAvailableOnClient if the
// if model for the optimization target is not currently on the client.
// If the model for the optimization target requires a client model feature
// that is present in |override_client_model_feature_values|, the value from
// |override_client_model_feature_values| will be used. The client will
// calculate the value for any required client model features not present in
// |override_client_model_feature_values| and inject any host model features
// it received from the server and send that complete feature map for
// evaluation.
OptimizationTargetDecision ShouldTargetNavigation(
content::NavigationHandle* navigation_handle,
proto::OptimizationTarget optimization_target);
proto::OptimizationTarget optimization_target,
const base::flat_map<proto::ClientModelFeature, float>&
override_client_model_feature_values);
// Update |session_fcp_| and |previous_fcp_| with |fcp|.
void UpdateFCPSessionStatistics(base::TimeDelta fcp);
......@@ -184,14 +193,21 @@ class PredictionManager
// based on if host model features were used.
base::flat_map<std::string, float> BuildFeatureMap(
content::NavigationHandle* navigation_handle,
const base::flat_set<std::string>& model_features);
const base::flat_set<std::string>& model_features,
const base::flat_map<proto::ClientModelFeature, float>&
override_client_model_feature_values);
// Calculate and return the current value for the client feature specified
// by |model_feature|. Return nullopt if the client does not support the
// by |model_feature|. If |model_feature| is in
// |override_client_model_feature_values|, the value from
// |client_model_feature_values| will be used. Otherwise, the client will
// calculate the value or return nullopt if the client does not support the
// model feature.
base::Optional<float> GetValueForClientFeature(
const std::string& model_feature,
content::NavigationHandle* navigation_handle) const;
content::NavigationHandle* navigation_handle,
const base::flat_map<proto::ClientModelFeature, float>&
override_client_model_feature_values) const;
// Called to make a request to fetch models and host model features from the
// remote Optimization Guide Service. Used to fetch models for the registered
......
......@@ -113,15 +113,13 @@ GetValidEnsemblePredictionModel() {
std::unique_ptr<optimization_guide::proto::PredictionModel> prediction_model =
std::make_unique<optimization_guide::proto::PredictionModel>();
prediction_model->mutable_model()->mutable_threshold()->set_value(5.0);
optimization_guide::proto::Ensemble ensemble =
optimization_guide::proto::Ensemble();
*ensemble.add_members()->mutable_submodel() =
*GetValidDecisionTreePredictionModel()->mutable_model();
*ensemble.add_members()->mutable_submodel() =
*GetValidDecisionTreePredictionModel()->mutable_model();
*prediction_model->mutable_model()->mutable_ensemble() = ensemble;
optimization_guide::proto::Model valid_decision_tree_model =
GetValidDecisionTreePredictionModel()->model();
optimization_guide::proto::Ensemble* ensemble =
prediction_model->mutable_model()->mutable_ensemble();
*ensemble->add_members()->mutable_submodel() = valid_decision_tree_model;
*ensemble->add_members()->mutable_submodel() = valid_decision_tree_model;
return prediction_model;
}
......@@ -136,6 +134,7 @@ CreatePredictionModel() {
model_info->add_supported_model_features(
optimization_guide::proto::
CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE);
model_info->add_supported_host_model_features("agg1");
model_info->set_optimization_target(
optimization_guide::proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD);
model_info->add_supported_model_types(
......@@ -158,7 +157,7 @@ BuildGetModelsResponse(
host_model_features->set_host(host);
optimization_guide::proto::ModelFeature* model_feature =
host_model_features->add_model_features();
model_feature->set_feature_name("host_feat1");
model_feature->set_feature_name("agg1");
model_feature->set_double_value(2.0);
}
......@@ -197,10 +196,33 @@ class OptimizationGuideConsumerWebContentsObserver
OptimizationGuideKeyedService* service =
OptimizationGuideKeyedServiceFactory::GetForProfile(
Profile::FromBrowserContext(web_contents()->GetBrowserContext()));
service->ShouldTargetNavigation(
last_should_target_decision_ = service->ShouldTargetNavigation(
navigation_handle,
optimization_guide::proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD);
if (callback_) {
// Intentionally do not set client model feature values to override to
// make sure decisions are the same in both sync and async variants.
service->ShouldTargetNavigationAsync(
navigation_handle,
optimization_guide::proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, {},
std::move(callback_));
}
}
void set_callback(
optimization_guide::OptimizationGuideTargetDecisionCallback callback) {
callback_ = std::move(callback);
}
optimization_guide::OptimizationGuideDecision last_should_target_decision()
const {
return last_should_target_decision_;
}
private:
optimization_guide::OptimizationGuideTargetDecisionCallback callback_;
optimization_guide::OptimizationGuideDecision last_should_target_decision_ =
optimization_guide::OptimizationGuideDecision::kUnknown;
};
} // namespace
......@@ -285,6 +307,19 @@ class PredictionManagerBrowserTest : public InProcessBrowserTest {
{optimization_guide::proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD});
}
// Sets the callback on the consumer of the OptimizationGuideKeyedService. If
// set, this will call the async version of ShouldTargetNavigation.
void SetCallbackOnConsumer(
optimization_guide::OptimizationGuideTargetDecisionCallback callback) {
ASSERT_TRUE(consumer_);
consumer_->set_callback(std::move(callback));
}
OptimizationGuideConsumerWebContentsObserver* consumer() {
return consumer_.get();
}
PredictionManager* GetPredictionManager() {
OptimizationGuideKeyedService* optimization_guide_keyed_service =
OptimizationGuideKeyedServiceFactory::GetForProfile(
......@@ -581,4 +616,40 @@ IN_PROC_BROWSER_TEST_F(PredictionManagerBrowserSameOriginTest,
"OptimizationGuide.PredictionManager.IsSameOrigin", true, 1);
}
IN_PROC_BROWSER_TEST_F(
PredictionManagerBrowserSameOriginTest,
DISABLE_ON_WIN_MAC_CHROMEOS(
ShouldTargetNavigationAsyncAndSyncDecisionAreTheSameWithoutOverrides)) {
base::HistogramTester histogram_tester;
RegisterWithKeyedService();
// Wait until histograms have been updated before performing checks for
// correct behavior based on the response.
RetryForHistogramUntilCountReached(
&histogram_tester,
"OptimizationGuide.PredictionModelFetcher.GetModelsResponse.Status", 1);
RetryForHistogramUntilCountReached(
&histogram_tester,
"OptimizationGuide.PredictionManager.HostModelFeaturesStored", 1);
RetryForHistogramUntilCountReached(
&histogram_tester,
"OptimizationGuide.PredictionManager.PredictionModelsStored", 1);
std::unique_ptr<base::RunLoop> run_loop = std::make_unique<base::RunLoop>();
SetCallbackOnConsumer(base::BindOnce(
[](base::RunLoop* run_loop,
OptimizationGuideConsumerWebContentsObserver* consumer,
optimization_guide::OptimizationGuideDecision decision) {
EXPECT_EQ(consumer->last_should_target_decision(), decision);
run_loop->Quit();
},
run_loop.get(), consumer()));
ui_test_utils::NavigateToURL(browser(), https_url_with_content());
run_loop->Run();
}
} // namespace optimization_guide
......@@ -8,6 +8,7 @@
#include <vector>
#include "base/callback_forward.h"
#include "base/containers/flat_map.h"
#include "base/optional.h"
#include "components/optimization_guide/optimization_metadata.h"
#include "components/optimization_guide/proto/hints.pb.h"
......@@ -38,6 +39,9 @@ enum class OptimizationGuideDecision {
kMaxValue = kFalse,
};
using OptimizationGuideTargetDecisionCallback =
base::OnceCallback<void(optimization_guide::OptimizationGuideDecision)>;
using OptimizationGuideDecisionCallback =
base::OnceCallback<void(optimization_guide::OptimizationGuideDecision,
const optimization_guide::OptimizationMetadata&)>;
......@@ -56,6 +60,20 @@ class OptimizationGuideDecider {
content::NavigationHandle* navigation_handle,
proto::OptimizationTarget optimization_target) = 0;
// Invokes |callback| with the decision for whether the current browser
// conditions, as expressed by |client_model_feature_values| and the
// |navigation_handle|, match |optimization_target|.
//
// Values provided in |client_model_feature_values| will be used over any
// values for features required by the model that may be calculated by the
// Optimization Guide.
virtual void ShouldTargetNavigationAsync(
content::NavigationHandle* navigation_handle,
proto::OptimizationTarget optimization_target,
const base::flat_map<proto::ClientModelFeature, float>&
client_model_feature_values,
OptimizationGuideTargetDecisionCallback callback) = 0;
// Registers the optimization types that intend to be queried during the
// session. It is expected for this to be called after the browser has been
// initialized.
......
......@@ -20,6 +20,15 @@ OptimizationGuideDecision TestOptimizationGuideDecider::ShouldTargetNavigation(
return OptimizationGuideDecision::kFalse;
}
void TestOptimizationGuideDecider::ShouldTargetNavigationAsync(
content::NavigationHandle* navigation_handle,
proto::OptimizationTarget optimization_target,
const base::flat_map<proto::ClientModelFeature, float>&
client_model_feature_values,
OptimizationGuideTargetDecisionCallback callback) {
std::move(callback).Run(OptimizationGuideDecision::kFalse);
}
void TestOptimizationGuideDecider::RegisterOptimizationTypes(
const std::vector<proto::OptimizationType>& optimization_types) {}
......
......@@ -26,6 +26,12 @@ class TestOptimizationGuideDecider : public OptimizationGuideDecider {
OptimizationGuideDecision ShouldTargetNavigation(
content::NavigationHandle* navigation_handle,
proto::OptimizationTarget optimization_target) override;
void ShouldTargetNavigationAsync(
content::NavigationHandle* navigation_handle,
proto::OptimizationTarget optimization_target,
const base::flat_map<proto::ClientModelFeature, float>&
client_model_feature_values,
OptimizationGuideTargetDecisionCallback callback) override;
void RegisterOptimizationTypes(
const std::vector<proto::OptimizationType>& optimization_types) override;
void CanApplyOptimizationAsync(
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment