Commit 05a58a95 authored by Sophie Chang's avatar Sophie Chang Committed by Commit Bot

Add painful page load prediction ukm for model version and prediction score

Bug: 1001194
Change-Id: Iae8a3fbf9d0e65c543deee41849c2a22db0547cb
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/1906880
Commit-Queue: Sophie Chang <sophiechang@chromium.org>
Reviewed-by: default avatarRobert Kaplow <rkaplow@chromium.org>
Reviewed-by: default avatarMichael Crouse <mcrouse@chromium.org>
Cr-Commit-Position: refs/heads/master@{#714570}
parent d4f3bb8a
......@@ -105,21 +105,6 @@ bool CanProcessComponentVersion(PrefService* pref_service,
return true;
}
// Returns the OptimizationGuideNavigationData for |navigation_handle| if the
// OptimizationGuideWebContentsObserver is registered.
OptimizationGuideNavigationData* GetNavigationDataForNavigationHandle(
content::NavigationHandle* navigation_handle) {
OptimizationGuideWebContentsObserver*
optimization_guide_web_contents_observer =
OptimizationGuideWebContentsObserver::FromWebContents(
navigation_handle->GetWebContents());
if (!optimization_guide_web_contents_observer)
return nullptr;
return optimization_guide_web_contents_observer
->GetOrCreateOptimizationGuideNavigationData(navigation_handle);
}
// Returns the page hint for the navigation, if applicable. It will use the
// cached page hint stored in |navigation_handle| if we have already done the
// computation to find the page hint in a previous request to the hints manager.
......@@ -129,7 +114,8 @@ const optimization_guide::proto::PageHint* GetPageHintForNavigation(
content::NavigationHandle* navigation_handle,
const optimization_guide::proto::Hint* loaded_hint) {
OptimizationGuideNavigationData* navigation_data =
GetNavigationDataForNavigationHandle(navigation_handle);
OptimizationGuideNavigationData::GetFromNavigationHandle(
navigation_handle);
// If we already know we had a page hint for the navigation, then just return
// that.
......@@ -588,7 +574,8 @@ void OptimizationGuideHintsManager::LoadHintForNavigation(
}
OptimizationGuideNavigationData* navigation_data =
GetNavigationDataForNavigationHandle(navigation_handle);
OptimizationGuideNavigationData::GetFromNavigationHandle(
navigation_handle);
if (navigation_data) {
bool has_hint = hint_cache_->HasHint(url.host());
if (navigation_handle->HasCommitted()) {
......@@ -788,7 +775,8 @@ void OptimizationGuideHintsManager::CanApplyOptimization(
// Populate navigation data with hint information.
OptimizationGuideNavigationData* navigation_data =
GetNavigationDataForNavigationHandle(navigation_handle);
OptimizationGuideNavigationData::GetFromNavigationHandle(
navigation_handle);
if (navigation_data) {
navigation_data->set_has_hint_after_commit(has_hint_in_cache);
......
......@@ -10,7 +10,6 @@
#include "chrome/browser/optimization_guide/optimization_guide_navigation_data.h"
#include "chrome/browser/optimization_guide/optimization_guide_session_statistic.h"
#include "chrome/browser/optimization_guide/optimization_guide_top_host_provider.h"
#include "chrome/browser/optimization_guide/optimization_guide_web_contents_observer.h"
#include "chrome/browser/optimization_guide/prediction/prediction_manager.h"
#include "chrome/browser/profiles/profile.h"
#include "components/leveldb_proto/public/proto_database_provider.h"
......@@ -48,18 +47,13 @@ void LogOptimizationTargetDecision(
optimization_guide::proto::OptimizationTarget optimization_target,
optimization_guide::OptimizationTargetDecision
optimization_target_decision) {
OptimizationGuideWebContentsObserver*
optimization_guide_web_contents_observer =
OptimizationGuideWebContentsObserver::FromWebContents(
navigation_handle->GetWebContents());
if (!optimization_guide_web_contents_observer)
return;
OptimizationGuideNavigationData* navigation_data =
optimization_guide_web_contents_observer
->GetOrCreateOptimizationGuideNavigationData(navigation_handle);
OptimizationGuideNavigationData::GetFromNavigationHandle(
navigation_handle);
if (navigation_data) {
navigation_data->SetDecisionForOptimizationTarget(
optimization_target, optimization_target_decision);
}
}
// Logs the |optimization_type_decision| for |optimization_type| in the current
......@@ -68,18 +62,13 @@ void LogOptimizationTypeDecision(
content::NavigationHandle* navigation_handle,
optimization_guide::proto::OptimizationType optimization_type,
optimization_guide::OptimizationTypeDecision optimization_type_decision) {
OptimizationGuideWebContentsObserver*
optimization_guide_web_contents_observer =
OptimizationGuideWebContentsObserver::FromWebContents(
navigation_handle->GetWebContents());
if (!optimization_guide_web_contents_observer)
return;
OptimizationGuideNavigationData* navigation_data =
optimization_guide_web_contents_observer
->GetOrCreateOptimizationGuideNavigationData(navigation_handle);
OptimizationGuideNavigationData::GetFromNavigationHandle(
navigation_handle);
if (navigation_data) {
navigation_data->SetDecisionForOptimizationType(optimization_type,
optimization_type_decision);
}
}
// Returns the OptimizationGuideDecision from |optimization_target_decision|.
......@@ -245,12 +234,6 @@ OptimizationGuideKeyedService::ShouldTargetNavigation(
if (prediction_manager_) {
optimization_target_decision = prediction_manager_->ShouldTargetNavigation(
navigation_handle, optimization_target);
if (optimization_guide::features::
ShouldOverrideOptimizationTargetDecisionForMetricsPurposes(
optimization_target)) {
optimization_target_decision = optimization_guide::
OptimizationTargetDecision::kModelPredictionHoldback;
}
} else {
DCHECK(hints_manager_);
optimization_guide::OptimizationTypeDecision
......
......@@ -817,57 +817,3 @@ IN_PROC_BROWSER_TEST_F(
kModelNotAvailableOnClient),
1);
}
class OptimizationGuideKeyedServiceModelPredictionHoldbackBrowserTest
: public OptimizationGuideKeyedServiceBrowserTest {
public:
OptimizationGuideKeyedServiceModelPredictionHoldbackBrowserTest() {
scoped_feature_list_.InitWithFeaturesAndParameters(
{base::test::ScopedFeatureList::FeatureAndParams(
optimization_guide::features::kOptimizationTargetPrediction,
{{"painful_page_load_metrics_only", "true"}}),
base::test::ScopedFeatureList::FeatureAndParams(
optimization_guide::features::kOptimizationHintsFetching, {{}})},
{});
}
~OptimizationGuideKeyedServiceModelPredictionHoldbackBrowserTest() override =
default;
void SetUpCommandLine(base::CommandLine* cmd) override {
cmd->AppendSwitchASCII(optimization_guide::switches::kFetchHintsOverride,
"whatever.com,somehost.com");
}
private:
base::test::ScopedFeatureList scoped_feature_list_;
};
IN_PROC_BROWSER_TEST_F(
OptimizationGuideKeyedServiceModelPredictionHoldbackBrowserTest,
ModelPredictionHoldbackOverridesActualTargetDecision) {
PushHintsComponentAndWaitForCompletion();
RegisterWithKeyedService();
ukm::TestAutoSetUkmRecorder ukm_recorder;
base::HistogramTester histogram_tester;
ui_test_utils::NavigateToURL(browser(), url_with_hints());
EXPECT_EQ(RetryForHistogramUntilCountReached(
histogram_tester, "OptimizationGuide.LoadedHint.Result", 1),
1);
// There should be a hint that matches this URL.
histogram_tester.ExpectUniqueSample("OptimizationGuide.LoadedHint.Result",
true, 1);
EXPECT_EQ(optimization_guide::OptimizationGuideDecision::kFalse,
last_should_target_navigation_decision());
EXPECT_EQ(optimization_guide::OptimizationGuideDecision::kTrue,
last_can_apply_optimization_decision());
EXPECT_EQ(optimization_guide::OptimizationGuideDecision::kFalse,
last_consumer_decision());
histogram_tester.ExpectUniqueSample(
"OptimizationGuide.TargetDecision.PainfulPageLoad",
static_cast<int>(optimization_guide::OptimizationTargetDecision::
kModelPredictionHoldback),
1);
}
......@@ -8,7 +8,9 @@
#include "base/metrics/histogram_functions.h"
#include "base/metrics/histogram_macros.h"
#include "base/strings/stringprintf.h"
#include "chrome/browser/optimization_guide/optimization_guide_web_contents_observer.h"
#include "components/optimization_guide/hints_processing_util.h"
#include "content/public/browser/navigation_handle.h"
#include "services/metrics/public/cpp/ukm_builders.h"
#include "services/metrics/public/cpp/ukm_recorder.h"
#include "services/metrics/public/cpp/ukm_source.h"
......@@ -45,6 +47,10 @@ OptimizationGuideNavigationData::OptimizationGuideNavigationData(
serialized_hint_version_string_(other.serialized_hint_version_string_),
optimization_type_decisions_(other.optimization_type_decisions_),
optimization_target_decisions_(other.optimization_target_decisions_),
optimization_target_model_versions_(
other.optimization_target_model_versions_),
optimization_target_model_prediction_scores_(
other.optimization_target_model_prediction_scores_),
has_hint_before_commit_(other.has_hint_before_commit_),
has_hint_after_commit_(other.has_hint_after_commit_),
was_host_covered_by_fetch_at_navigation_start_(
......@@ -55,6 +61,20 @@ OptimizationGuideNavigationData::OptimizationGuideNavigationData(
}
}
// static
OptimizationGuideNavigationData*
OptimizationGuideNavigationData::GetFromNavigationHandle(
content::NavigationHandle* navigation_handle) {
OptimizationGuideWebContentsObserver*
optimization_guide_web_contents_observer =
OptimizationGuideWebContentsObserver::FromWebContents(
navigation_handle->GetWebContents());
if (!optimization_guide_web_contents_observer)
return nullptr;
return optimization_guide_web_contents_observer
->GetOrCreateOptimizationGuideNavigationData(navigation_handle);
}
void OptimizationGuideNavigationData::RecordMetrics(bool has_committed) const {
RecordHintCacheMatch(has_committed);
RecordOptimizationTypeAndTargetDecisions();
......@@ -130,26 +150,40 @@ void OptimizationGuideNavigationData::RecordOptimizationTypeAndTargetDecisions()
}
void OptimizationGuideNavigationData::RecordOptimizationGuideUKM() const {
if (!serialized_hint_version_string_.has_value() ||
serialized_hint_version_string_.value().empty())
return;
// Deserialize the serialized version string into its protobuffer.
std::string binary_version_pb;
if (!base::Base64Decode(serialized_hint_version_string_.value(),
&binary_version_pb))
return;
optimization_guide::proto::Version hint_version;
if (!hint_version.ParseFromString(binary_version_pb))
return;
// Record the UKM.
bool did_record_metric = false;
ukm::SourceId ukm_source_id =
ukm::ConvertToSourceId(navigation_id_, ukm::SourceIdType::NAVIGATION_ID);
ukm::builders::OptimizationGuide builder(ukm_source_id);
bool did_record_metric = false;
// Record model metrics.
for (const auto& optimization_target_model_version :
optimization_target_model_versions_) {
if (optimization_target_model_version.first ==
optimization_guide::proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD) {
did_record_metric = true;
builder.SetPainfulPageLoadModelVersion(
optimization_target_model_version.second);
}
}
for (const auto& optimization_target_model_prediction_score :
optimization_target_model_prediction_scores_) {
if (optimization_target_model_prediction_score.first ==
optimization_guide::proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD) {
did_record_metric = true;
builder.SetPainfulPageLoadModelPredictionScore(static_cast<int64_t>(
100 * optimization_target_model_prediction_score.second));
}
}
// Record hint metrics.
if (serialized_hint_version_string_.has_value() &&
!serialized_hint_version_string_.value().empty()) {
// Deserialize the serialized version string into its protobuffer.
std::string binary_version_pb;
if (base::Base64Decode(serialized_hint_version_string_.value(),
&binary_version_pb)) {
optimization_guide::proto::Version hint_version;
if (hint_version.ParseFromString(binary_version_pb)) {
if (hint_version.has_generation_timestamp() &&
hint_version.generation_timestamp().seconds() > 0) {
did_record_metric = true;
......@@ -162,6 +196,9 @@ void OptimizationGuideNavigationData::RecordOptimizationGuideUKM() const {
did_record_metric = true;
builder.SetHintSource(static_cast<int>(hint_version.hint_source()));
}
}
}
}
// Only record UKM if a metric was recorded.
if (did_record_metric)
......@@ -199,3 +236,39 @@ void OptimizationGuideNavigationData::SetDecisionForOptimizationTarget(
optimization_guide::OptimizationTargetDecision decision) {
optimization_target_decisions_[optimization_target] = decision;
}
base::Optional<int64_t>
OptimizationGuideNavigationData::GetModelVersionForOptimizationTarget(
optimization_guide::proto::OptimizationTarget optimization_target) const {
auto optimization_target_model_version_iter =
optimization_target_model_versions_.find(optimization_target);
if (optimization_target_model_version_iter ==
optimization_target_model_versions_.end())
return base::nullopt;
return optimization_target_model_version_iter->second;
}
void OptimizationGuideNavigationData::SetModelVersionForOptimizationTarget(
optimization_guide::proto::OptimizationTarget optimization_target,
int64_t model_version) {
optimization_target_model_versions_[optimization_target] = model_version;
}
base::Optional<double>
OptimizationGuideNavigationData::GetModelPredictionScoreForOptimizationTarget(
optimization_guide::proto::OptimizationTarget optimization_target) const {
auto optimization_target_model_prediction_score_iter =
optimization_target_model_prediction_scores_.find(optimization_target);
if (optimization_target_model_prediction_score_iter ==
optimization_target_model_prediction_scores_.end())
return base::nullopt;
return optimization_target_model_prediction_score_iter->second;
}
void OptimizationGuideNavigationData::
SetModelPredictionScoreForOptimizationTarget(
optimization_guide::proto::OptimizationTarget optimization_target,
double model_prediction_score) {
optimization_target_model_prediction_scores_[optimization_target] =
model_prediction_score;
}
......@@ -26,6 +26,11 @@ class OptimizationGuideNavigationData {
OptimizationGuideNavigationData(const OptimizationGuideNavigationData& other);
// Returns the OptimizationGuideNavigationData for |navigation_handle|. Will
// return nullptr if one cannot be created for it for any reason.
static OptimizationGuideNavigationData* GetFromNavigationHandle(
content::NavigationHandle* navigation_handle);
// Records metrics based on data currently held in |this|. |has_committed|
// indicates whether commit-time metrics should be recorded.
void RecordMetrics(bool has_committed) const;
......@@ -61,6 +66,23 @@ class OptimizationGuideNavigationData {
optimization_guide::proto::OptimizationTarget optimization_target,
optimization_guide::OptimizationTargetDecision decision);
// Returns the version of the model evaluated for |optimization_target|.
base::Optional<int64_t> GetModelVersionForOptimizationTarget(
optimization_guide::proto::OptimizationTarget optimization_target) const;
// Sets the |model_version| for |optimization_target|.
void SetModelVersionForOptimizationTarget(
optimization_guide::proto::OptimizationTarget optimization_target,
int64_t model_version);
// Returns the prediction score of the model evaluated for
// |optimization_target|.
base::Optional<double> GetModelPredictionScoreForOptimizationTarget(
optimization_guide::proto::OptimizationTarget optimization_target) const;
// Sets the |model_prediction_score| for |optimization_target|.
void SetModelPredictionScoreForOptimizationTarget(
optimization_guide::proto::OptimizationTarget optimization_target,
double model_prediction_score);
// Whether the hint cache had a hint for the navigation before commit.
base::Optional<bool> has_hint_before_commit() const {
return has_hint_before_commit_;
......@@ -127,6 +149,17 @@ class OptimizationGuideNavigationData {
optimization_guide::OptimizationTargetDecision>
optimization_target_decisions_;
// The version of the painful page load model that was evaluated for the
// page load.
base::flat_map<optimization_guide::proto::OptimizationTarget, int64_t>
optimization_target_model_versions_;
// The score output after evaluating the painful page load model. If
// populated, this is 100x the fractional value output by the model
// evaluation.
base::flat_map<optimization_guide::proto::OptimizationTarget, double>
optimization_target_model_prediction_scores_;
// Whether the hint cache had a hint for the navigation before commit.
base::Optional<bool> has_hint_before_commit_;
......
......@@ -433,6 +433,87 @@ TEST(OptimizationGuideNavigationDataTest,
optimization_guide::proto::HINT_SOURCE_OPTIMIZATION_GUIDE_SERVICE));
}
TEST(OptimizationGuideNavigationDataTest,
RecordMetricsOptimizationTargetModelVersion) {
base::test::TaskEnvironment env;
ukm::TestAutoSetUkmRecorder ukm_recorder;
OptimizationGuideNavigationData data(/*navigation_id=*/3);
data.SetModelVersionForOptimizationTarget(
optimization_guide::proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, 2);
data.RecordMetrics(/*has_committed=*/false);
auto entries = ukm_recorder.GetEntriesByName(
ukm::builders::OptimizationGuide::kEntryName);
EXPECT_EQ(1u, entries.size());
auto* entry = entries[0];
EXPECT_TRUE(ukm_recorder.EntryHasMetric(
entry,
ukm::builders::OptimizationGuide::kPainfulPageLoadModelVersionName));
ukm_recorder.ExpectEntryMetric(
entry, ukm::builders::OptimizationGuide::kPainfulPageLoadModelVersionName,
2);
}
TEST(OptimizationGuideNavigationDataTest,
RecordMetricsModelVersionForOptimizationTargetHasNoCorrespondingUkm) {
base::test::TaskEnvironment env;
ukm::TestAutoSetUkmRecorder ukm_recorder;
OptimizationGuideNavigationData data(/*navigation_id=*/3);
data.SetModelVersionForOptimizationTarget(
optimization_guide::proto::OPTIMIZATION_TARGET_UNKNOWN, 2);
data.RecordMetrics(/*has_committed=*/false);
// Make sure UKM not recorded for all empty values.
auto entries = ukm_recorder.GetEntriesByName(
ukm::builders::OptimizationGuide::kEntryName);
EXPECT_TRUE(entries.empty());
}
TEST(OptimizationGuideNavigationDataTest,
RecordMetricsOptimizationTargetModelPredictionScore) {
base::test::TaskEnvironment env;
ukm::TestAutoSetUkmRecorder ukm_recorder;
OptimizationGuideNavigationData data(/*navigation_id=*/3);
data.SetModelPredictionScoreForOptimizationTarget(
optimization_guide::proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, 0.123);
data.RecordMetrics(/*has_committed=*/false);
auto entries = ukm_recorder.GetEntriesByName(
ukm::builders::OptimizationGuide::kEntryName);
EXPECT_EQ(1u, entries.size());
auto* entry = entries[0];
EXPECT_TRUE(ukm_recorder.EntryHasMetric(
entry, ukm::builders::OptimizationGuide::
kPainfulPageLoadModelPredictionScoreName));
ukm_recorder.ExpectEntryMetric(entry,
ukm::builders::OptimizationGuide::
kPainfulPageLoadModelPredictionScoreName,
12);
}
TEST(OptimizationGuideNavigationDataTest,
RecordMetricsModelPredicitonScoreOptimizationTargetHasNoCorrespondingUkm) {
base::test::TaskEnvironment env;
ukm::TestAutoSetUkmRecorder ukm_recorder;
OptimizationGuideNavigationData data(/*navigation_id=*/3);
data.SetModelPredictionScoreForOptimizationTarget(
optimization_guide::proto::OPTIMIZATION_TARGET_UNKNOWN, 0.123);
data.RecordMetrics(/*has_committed=*/false);
// Make sure UKM not recorded for all empty values.
auto entries = ukm_recorder.GetEntriesByName(
ukm::builders::OptimizationGuide::kEntryName);
EXPECT_TRUE(entries.empty());
}
TEST(OptimizationGuideNavigationDataTest,
RecordMetricsMultipleOptimizationTypes) {
base::HistogramTester histogram_tester;
......@@ -537,6 +618,14 @@ TEST(OptimizationGuideNavigationDataTest, DeepCopy) {
EXPECT_EQ(base::nullopt, data->has_hint_before_commit());
EXPECT_EQ(base::nullopt, data->has_hint_after_commit());
EXPECT_FALSE(data->has_page_hint_value());
EXPECT_EQ(
base::nullopt,
data->GetModelVersionForOptimizationTarget(
optimization_guide::proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD));
EXPECT_EQ(
base::nullopt,
data->GetModelPredictionScoreForOptimizationTarget(
optimization_guide::proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD));
data->set_serialized_hint_version_string("123abc");
data->SetDecisionForOptimizationType(
......@@ -552,6 +641,10 @@ TEST(OptimizationGuideNavigationDataTest, DeepCopy) {
page_hint.set_page_pattern("pagepattern");
data->set_page_hint(
std::make_unique<optimization_guide::proto::PageHint>(page_hint));
data->SetModelVersionForOptimizationTarget(
optimization_guide::proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, 123);
data->SetModelPredictionScoreForOptimizationTarget(
optimization_guide::proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, 0.12);
OptimizationGuideNavigationData data_copy(*data);
EXPECT_EQ(3, data_copy.navigation_id());
......@@ -567,4 +660,12 @@ TEST(OptimizationGuideNavigationDataTest, DeepCopy) {
EXPECT_EQ("123abc", *(data_copy.serialized_hint_version_string()));
EXPECT_TRUE(data_copy.has_page_hint_value());
EXPECT_EQ("pagepattern", data_copy.page_hint()->page_pattern());
EXPECT_EQ(
123,
*data_copy.GetModelVersionForOptimizationTarget(
optimization_guide::proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD));
EXPECT_EQ(
0.12,
*data_copy.GetModelPredictionScoreForOptimizationTarget(
optimization_guide::proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD));
}
......@@ -123,14 +123,15 @@ bool DecisionTreePredictionModel::ValidateTreeNode(
optimization_guide::OptimizationTargetDecision
DecisionTreePredictionModel::Predict(
const base::flat_map<std::string, float>& model_features) {
const base::flat_map<std::string, float>& model_features,
double* prediction_score) {
SEQUENCE_CHECKER(sequence_checker_);
double result = 0.0;
*prediction_score = 0.0;
// TODO(mcrouse): Add metrics to record if the model evaluation fails.
if (!EvaluateModel(*model_.get(), model_features, &result))
if (!EvaluateModel(*model_.get(), model_features, prediction_score))
return optimization_guide::OptimizationTargetDecision::kUnknown;
if (result > model_->threshold().value())
if (*prediction_score > model_->threshold().value())
return optimization_guide::OptimizationTargetDecision::kPageLoadMatches;
return optimization_guide::OptimizationTargetDecision::kPageLoadDoesNotMatch;
}
......
......@@ -30,7 +30,8 @@ class DecisionTreePredictionModel : public PredictionModel {
// PredictionModel implementation:
optimization_guide::OptimizationTargetDecision Predict(
const base::flat_map<std::string, float>& model_features) override;
const base::flat_map<std::string, float>& model_features,
double* prediction_score) override;
private:
// Evaluates the provided model, either an ensemble or decision tree model,
......
......@@ -83,10 +83,14 @@ TEST(DecisionTreePredictionModel, ValidDecisionTreeModel) {
std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model), {"agg1"});
EXPECT_TRUE(model);
double prediction_score;
EXPECT_EQ(OptimizationTargetDecision::kPageLoadDoesNotMatch,
model->Predict({{"agg1", 1.0}}));
model->Predict({{"agg1", 1.0}}, &prediction_score));
EXPECT_EQ(4., prediction_score);
EXPECT_EQ(OptimizationTargetDecision::kPageLoadMatches,
model->Predict({{"agg1", 2.0}}));
model->Predict({{"agg1", 2.0}}, &prediction_score));
EXPECT_EQ(8., prediction_score);
}
TEST(DecisionTreePredictionModel, InequalityLessThan) {
......@@ -111,10 +115,14 @@ TEST(DecisionTreePredictionModel, InequalityLessThan) {
std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model), {"agg1"});
EXPECT_TRUE(model);
double prediction_score;
EXPECT_EQ(OptimizationTargetDecision::kPageLoadDoesNotMatch,
model->Predict({{"agg1", 0.5}}));
model->Predict({{"agg1", 0.5}}, &prediction_score));
EXPECT_EQ(4., prediction_score);
EXPECT_EQ(OptimizationTargetDecision::kPageLoadMatches,
model->Predict({{"agg1", 2.0}}));
model->Predict({{"agg1", 2.0}}, &prediction_score));
EXPECT_EQ(8., prediction_score);
}
TEST(DecisionTreePredictionModel, InequalityGreaterOrEqual) {
......@@ -139,10 +147,14 @@ TEST(DecisionTreePredictionModel, InequalityGreaterOrEqual) {
std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model), {"agg1"});
EXPECT_TRUE(model);
double prediction_score;
EXPECT_EQ(OptimizationTargetDecision::kPageLoadMatches,
model->Predict({{"agg1", 0.5}}));
model->Predict({{"agg1", 0.5}}, &prediction_score));
EXPECT_EQ(8., prediction_score);
EXPECT_EQ(OptimizationTargetDecision::kPageLoadDoesNotMatch,
model->Predict({{"agg1", 1.0}}));
model->Predict({{"agg1", 1.0}}, &prediction_score));
EXPECT_EQ(4., prediction_score);
}
TEST(DecisionTreePredictionModel, InequalityGreaterThan) {
......@@ -167,10 +179,14 @@ TEST(DecisionTreePredictionModel, InequalityGreaterThan) {
std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model), {"agg1"});
EXPECT_TRUE(model);
double prediction_score;
EXPECT_EQ(OptimizationTargetDecision::kPageLoadMatches,
model->Predict({{"agg1", 0.5}}));
model->Predict({{"agg1", 0.5}}, &prediction_score));
EXPECT_EQ(8., prediction_score);
EXPECT_EQ(OptimizationTargetDecision::kPageLoadDoesNotMatch,
model->Predict({{"agg1", 2.0}}));
model->Predict({{"agg1", 2.0}}, &prediction_score));
EXPECT_EQ(4., prediction_score);
}
TEST(DecisionTreePredictionModel, MissingInequalityTest) {
......@@ -416,10 +432,14 @@ TEST(DecisionTreePredictionModel, ValidEnsembleModel) {
std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model), {"agg1"});
EXPECT_TRUE(model);
double prediction_score;
EXPECT_EQ(OptimizationTargetDecision::kPageLoadDoesNotMatch,
model->Predict({{"agg1", 1.0}}));
model->Predict({{"agg1", 1.0}}, &prediction_score));
EXPECT_EQ(4., prediction_score);
EXPECT_EQ(OptimizationTargetDecision::kPageLoadMatches,
model->Predict({{"agg1", 2.0}}));
model->Predict({{"agg1", 2.0}}, &prediction_score));
EXPECT_EQ(8., prediction_score);
}
TEST(DecisionTreePredictionModel, EnsembleWithNoMembers) {
......
......@@ -12,6 +12,7 @@
#include "base/metrics/histogram_macros.h"
#include "chrome/browser/browser_process.h"
#include "chrome/browser/engagement/site_engagement_service.h"
#include "chrome/browser/optimization_guide/optimization_guide_navigation_data.h"
#include "chrome/browser/optimization_guide/optimization_guide_session_statistic.h"
#include "chrome/browser/optimization_guide/prediction/prediction_model.h"
#include "chrome/browser/optimization_guide/prediction/prediction_model_fetcher.h"
......@@ -182,7 +183,28 @@ OptimizationTargetDecision PredictionManager::ShouldTargetNavigation(
base::flat_map<std::string, float> feature_map =
BuildFeatureMap(navigation_handle, prediction_model->GetModelFeatures());
return prediction_model->Predict(feature_map);
double prediction_score = 0.0;
optimization_guide::OptimizationTargetDecision target_decision =
prediction_model->Predict(feature_map, &prediction_score);
OptimizationGuideNavigationData* navigation_data =
OptimizationGuideNavigationData::GetFromNavigationHandle(
navigation_handle);
if (navigation_data) {
navigation_data->SetModelVersionForOptimizationTarget(
optimization_target, prediction_model->GetVersion());
navigation_data->SetModelPredictionScoreForOptimizationTarget(
optimization_target, prediction_score);
}
if (optimization_guide::features::
ShouldOverrideOptimizationTargetDecisionForMetricsPurposes(
optimization_target)) {
return optimization_guide::OptimizationTargetDecision::
kModelPredictionHoldback;
}
return target_decision;
}
void PredictionManager::OnEffectiveConnectionTypeChanged(
......
......@@ -33,9 +33,11 @@ class PredictionModel {
const base::flat_set<std::string>& host_model_features);
// Returns the OptimizationTargetDecision by evaluating the |model_|
// using the provided |model_features|.
// using the provided |model_features|. |prediction_score| will be populated
// with the score output by the model.
virtual optimization_guide::OptimizationTargetDecision Predict(
const base::flat_map<std::string, float>& model_features) = 0;
const base::flat_map<std::string, float>& model_features,
double* prediction_score) = 0;
// Provide the version of the |model_| by |this|.
int64_t GetVersion() const;
......
......@@ -5332,6 +5332,19 @@ be describing additional metrics about the same event.
applied on the page load.
</summary>
</metric>
<metric name="PainfulPageLoadModelPredictionScore">
<summary>
The score output after evaluating the painful page load model. If
populated, this is 100x the fractional value output by the model
evaluation. This will be a value between 0 and 100.
</summary>
</metric>
<metric name="PainfulPageLoadModelVersion">
<summary>
The server-generated version of the painful page load model that was
evaluated for the page load.
</summary>
</metric>
</event>
<event name="PageDomainInfo">
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment