Commit df9a1cbf authored by Sophie Chang's avatar Sophie Chang Committed by Commit Bot

Use host model features field instead of some fake set calculation

Bug: 1001194
Change-Id: I1ea4fa11ba5560ac893cef456c00e104d5748317
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/1931143Reviewed-by: default avatarMichael Crouse <mcrouse@chromium.org>
Commit-Queue: Sophie Chang <sophiechang@chromium.org>
Cr-Commit-Position: refs/heads/master@{#718371}
parent ee4dd036
...@@ -3,15 +3,15 @@ ...@@ -3,15 +3,15 @@
// found in the LICENSE file. // found in the LICENSE file.
#include "chrome/browser/optimization_guide/prediction/decision_tree_prediction_model.h" #include "chrome/browser/optimization_guide/prediction/decision_tree_prediction_model.h"
#include "chrome/browser/optimization_guide/prediction/prediction_model.h"
#include <utility>
namespace optimization_guide { namespace optimization_guide {
DecisionTreePredictionModel::DecisionTreePredictionModel( DecisionTreePredictionModel::DecisionTreePredictionModel(
std::unique_ptr<optimization_guide::proto::PredictionModel> std::unique_ptr<optimization_guide::proto::PredictionModel>
prediction_model, prediction_model)
const base::flat_set<std::string>& host_model_features) : PredictionModel(std::move(prediction_model)) {}
: PredictionModel(std::move(prediction_model), host_model_features) {}
DecisionTreePredictionModel::~DecisionTreePredictionModel() = default; DecisionTreePredictionModel::~DecisionTreePredictionModel() = default;
......
...@@ -21,10 +21,9 @@ namespace optimization_guide { ...@@ -21,10 +21,9 @@ namespace optimization_guide {
// supported by the optimization guide. // supported by the optimization guide.
class DecisionTreePredictionModel : public PredictionModel { class DecisionTreePredictionModel : public PredictionModel {
public: public:
DecisionTreePredictionModel( explicit DecisionTreePredictionModel(
std::unique_ptr<optimization_guide::proto::PredictionModel> std::unique_ptr<optimization_guide::proto::PredictionModel>
prediction_model, prediction_model);
const base::flat_set<std::string>& host_model_features);
~DecisionTreePredictionModel() override; ~DecisionTreePredictionModel() override;
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
// found in the LICENSE file. // found in the LICENSE file.
#include <memory> #include <memory>
#include <utility>
#include "base/containers/flat_map.h" #include "base/containers/flat_map.h"
#include "base/containers/flat_set.h" #include "base/containers/flat_set.h"
...@@ -79,9 +80,10 @@ TEST(DecisionTreePredictionModel, ValidDecisionTreeModel) { ...@@ -79,9 +80,10 @@ TEST(DecisionTreePredictionModel, ValidDecisionTreeModel) {
model_info->add_supported_model_features( model_info->add_supported_model_features(
proto::ClientModelFeature:: proto::ClientModelFeature::
CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE); CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE);
model_info->add_supported_host_model_features("agg1");
std::unique_ptr<PredictionModel> model = std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model), {"agg1"}); PredictionModel::Create(std::move(prediction_model));
EXPECT_TRUE(model); EXPECT_TRUE(model);
double prediction_score; double prediction_score;
...@@ -111,9 +113,10 @@ TEST(DecisionTreePredictionModel, InequalityLessThan) { ...@@ -111,9 +113,10 @@ TEST(DecisionTreePredictionModel, InequalityLessThan) {
model_info->add_supported_model_features( model_info->add_supported_model_features(
proto::ClientModelFeature:: proto::ClientModelFeature::
CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE); CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE);
model_info->add_supported_host_model_features("agg1");
std::unique_ptr<PredictionModel> model = std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model), {"agg1"}); PredictionModel::Create(std::move(prediction_model));
EXPECT_TRUE(model); EXPECT_TRUE(model);
double prediction_score; double prediction_score;
...@@ -143,9 +146,10 @@ TEST(DecisionTreePredictionModel, InequalityGreaterOrEqual) { ...@@ -143,9 +146,10 @@ TEST(DecisionTreePredictionModel, InequalityGreaterOrEqual) {
model_info->add_supported_model_features( model_info->add_supported_model_features(
proto::ClientModelFeature:: proto::ClientModelFeature::
CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE); CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE);
model_info->add_supported_host_model_features("agg1");
std::unique_ptr<PredictionModel> model = std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model), {"agg1"}); PredictionModel::Create(std::move(prediction_model));
EXPECT_TRUE(model); EXPECT_TRUE(model);
double prediction_score; double prediction_score;
...@@ -175,9 +179,10 @@ TEST(DecisionTreePredictionModel, InequalityGreaterThan) { ...@@ -175,9 +179,10 @@ TEST(DecisionTreePredictionModel, InequalityGreaterThan) {
model_info->add_supported_model_features( model_info->add_supported_model_features(
proto::ClientModelFeature:: proto::ClientModelFeature::
CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE); CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE);
model_info->add_supported_host_model_features("agg1");
std::unique_ptr<PredictionModel> model = std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model), {"agg1"}); PredictionModel::Create(std::move(prediction_model));
EXPECT_TRUE(model); EXPECT_TRUE(model);
double prediction_score; double prediction_score;
...@@ -207,9 +212,10 @@ TEST(DecisionTreePredictionModel, MissingInequalityTest) { ...@@ -207,9 +212,10 @@ TEST(DecisionTreePredictionModel, MissingInequalityTest) {
model_info->add_supported_model_features( model_info->add_supported_model_features(
proto::ClientModelFeature:: proto::ClientModelFeature::
CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE); CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE);
model_info->add_supported_host_model_features("agg1");
std::unique_ptr<PredictionModel> model = std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model), {"agg1"}); PredictionModel::Create(std::move(prediction_model));
EXPECT_FALSE(model); EXPECT_FALSE(model);
} }
...@@ -226,9 +232,10 @@ TEST(DecisionTreePredictionModel, NoDecisionTreeThreshold) { ...@@ -226,9 +232,10 @@ TEST(DecisionTreePredictionModel, NoDecisionTreeThreshold) {
model_info->add_supported_model_features( model_info->add_supported_model_features(
proto::ClientModelFeature:: proto::ClientModelFeature::
CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE); CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE);
model_info->add_supported_host_model_features("agg1");
std::unique_ptr<PredictionModel> model = std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model), {"agg1"}); PredictionModel::Create(std::move(prediction_model));
EXPECT_FALSE(model); EXPECT_FALSE(model);
} }
...@@ -245,9 +252,10 @@ TEST(DecisionTreePredictionModel, EmptyTree) { ...@@ -245,9 +252,10 @@ TEST(DecisionTreePredictionModel, EmptyTree) {
model_info->add_supported_model_features( model_info->add_supported_model_features(
proto::ClientModelFeature:: proto::ClientModelFeature::
CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE); CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE);
model_info->add_supported_host_model_features("agg1");
std::unique_ptr<PredictionModel> model = std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model), {"agg1"}); PredictionModel::Create(std::move(prediction_model));
EXPECT_FALSE(model); EXPECT_FALSE(model);
} }
...@@ -264,9 +272,10 @@ TEST(DecisionTreePredictionModel, ModelFeatureNotInFeatureMap) { ...@@ -264,9 +272,10 @@ TEST(DecisionTreePredictionModel, ModelFeatureNotInFeatureMap) {
model_info->add_supported_model_features( model_info->add_supported_model_features(
proto::ClientModelFeature:: proto::ClientModelFeature::
CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE); CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE);
model_info->add_supported_host_model_features("agg1");
std::unique_ptr<PredictionModel> model = std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model), {"agg1"}); PredictionModel::Create(std::move(prediction_model));
EXPECT_FALSE(model); EXPECT_FALSE(model);
} }
...@@ -287,9 +296,10 @@ TEST(DecisionTreePredictionModel, DecisionTreeMissingLeaf) { ...@@ -287,9 +296,10 @@ TEST(DecisionTreePredictionModel, DecisionTreeMissingLeaf) {
model_info->add_supported_model_features( model_info->add_supported_model_features(
proto::ClientModelFeature:: proto::ClientModelFeature::
CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE); CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE);
model_info->add_supported_host_model_features("agg1");
std::unique_ptr<PredictionModel> model = std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model), {"agg1"}); PredictionModel::Create(std::move(prediction_model));
EXPECT_FALSE(model); EXPECT_FALSE(model);
} }
...@@ -311,9 +321,10 @@ TEST(DecisionTreePredictionModel, DecisionTreeLeftChildIndexInvalid) { ...@@ -311,9 +321,10 @@ TEST(DecisionTreePredictionModel, DecisionTreeLeftChildIndexInvalid) {
model_info->add_supported_model_features( model_info->add_supported_model_features(
proto::ClientModelFeature:: proto::ClientModelFeature::
CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE); CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE);
model_info->add_supported_host_model_features("agg1");
std::unique_ptr<PredictionModel> model = std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model), {"agg1"}); PredictionModel::Create(std::move(prediction_model));
EXPECT_FALSE(model); EXPECT_FALSE(model);
} }
...@@ -335,9 +346,10 @@ TEST(DecisionTreePredictionModel, DecisionTreeRightChildIndexInvalid) { ...@@ -335,9 +346,10 @@ TEST(DecisionTreePredictionModel, DecisionTreeRightChildIndexInvalid) {
model_info->add_supported_model_features( model_info->add_supported_model_features(
proto::ClientModelFeature:: proto::ClientModelFeature::
CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE); CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE);
model_info->add_supported_host_model_features("agg1");
std::unique_ptr<PredictionModel> model = std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model), {"agg1"}); PredictionModel::Create(std::move(prediction_model));
EXPECT_FALSE(model); EXPECT_FALSE(model);
} }
...@@ -373,9 +385,10 @@ TEST(DecisionTreePredictionModel, DecisionTreeWithLoopOnLeftChild) { ...@@ -373,9 +385,10 @@ TEST(DecisionTreePredictionModel, DecisionTreeWithLoopOnLeftChild) {
model_info->add_supported_model_features( model_info->add_supported_model_features(
proto::ClientModelFeature:: proto::ClientModelFeature::
CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE); CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE);
model_info->add_supported_host_model_features("agg1");
std::unique_ptr<PredictionModel> model = std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model), {"agg1"}); PredictionModel::Create(std::move(prediction_model));
EXPECT_FALSE(model); EXPECT_FALSE(model);
} }
...@@ -411,9 +424,10 @@ TEST(DecisionTreePredictionModel, DecisionTreeWithLoopOnRightChild) { ...@@ -411,9 +424,10 @@ TEST(DecisionTreePredictionModel, DecisionTreeWithLoopOnRightChild) {
model_info->add_supported_model_features( model_info->add_supported_model_features(
proto::ClientModelFeature:: proto::ClientModelFeature::
CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE); CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE);
model_info->add_supported_host_model_features("agg1");
std::unique_ptr<PredictionModel> model = std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model), {"agg1"}); PredictionModel::Create(std::move(prediction_model));
EXPECT_FALSE(model); EXPECT_FALSE(model);
} }
...@@ -428,9 +442,10 @@ TEST(DecisionTreePredictionModel, ValidEnsembleModel) { ...@@ -428,9 +442,10 @@ TEST(DecisionTreePredictionModel, ValidEnsembleModel) {
model_info->add_supported_model_features( model_info->add_supported_model_features(
proto::ClientModelFeature:: proto::ClientModelFeature::
CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE); CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE);
model_info->add_supported_host_model_features("agg1");
std::unique_ptr<PredictionModel> model = std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model), {"agg1"}); PredictionModel::Create(std::move(prediction_model));
EXPECT_TRUE(model); EXPECT_TRUE(model);
double prediction_score; double prediction_score;
...@@ -457,9 +472,10 @@ TEST(DecisionTreePredictionModel, EnsembleWithNoMembers) { ...@@ -457,9 +472,10 @@ TEST(DecisionTreePredictionModel, EnsembleWithNoMembers) {
model_info->add_supported_model_features( model_info->add_supported_model_features(
proto::ClientModelFeature:: proto::ClientModelFeature::
CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE); CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE);
model_info->add_supported_host_model_features("agg1");
std::unique_ptr<PredictionModel> model = std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model), {"agg1"}); PredictionModel::Create(std::move(prediction_model));
EXPECT_FALSE(model); EXPECT_FALSE(model);
} }
......
...@@ -4,8 +4,7 @@ ...@@ -4,8 +4,7 @@
#include "chrome/browser/optimization_guide/prediction/prediction_manager.h" #include "chrome/browser/optimization_guide/prediction/prediction_manager.h"
#include <memory> #include <utility>
#include <vector>
#include "base/containers/flat_map.h" #include "base/containers/flat_map.h"
#include "base/containers/flat_set.h" #include "base/containers/flat_set.h"
...@@ -175,9 +174,7 @@ void PredictionManager::RegisterOptimizationTargets( ...@@ -175,9 +174,7 @@ void PredictionManager::RegisterOptimizationTargets(
if (new_optimization_targets.size() == 0) if (new_optimization_targets.size() == 0)
return; return;
// Start loading the host model features if they are not already. Models // Start loading the host model features if they are not already.
// cannot be loaded from the store until the host model features have loaded
// from the store as they are required to construct each prediction model.
if (!host_model_features_loaded_) { if (!host_model_features_loaded_) {
LoadHostModelFeatures(); LoadHostModelFeatures();
return; return;
...@@ -464,15 +461,13 @@ void PredictionManager::UpdateHostModelFeatures( ...@@ -464,15 +461,13 @@ void PredictionManager::UpdateHostModelFeatures(
SEQUENCE_CHECKER(sequence_checker_); SEQUENCE_CHECKER(sequence_checker_);
for (const auto& host_model_features : host_model_features) for (const auto& host_model_features : host_model_features)
ProcessAndStoreHostModelFeatures(host_model_features); ProcessAndStoreHostModelFeatures(host_model_features);
UpdateSupportedHostModelFeatures();
} }
std::unique_ptr<PredictionModel> PredictionManager::CreatePredictionModel( std::unique_ptr<PredictionModel> PredictionManager::CreatePredictionModel(
const proto::PredictionModel& model, const proto::PredictionModel& model) const {
const base::flat_set<std::string>& host_model_features) const {
SEQUENCE_CHECKER(sequence_checker_); SEQUENCE_CHECKER(sequence_checker_);
return PredictionModel::Create( return PredictionModel::Create(
std::make_unique<proto::PredictionModel>(model), host_model_features); std::make_unique<proto::PredictionModel>(model));
} }
void PredictionManager::UpdatePredictionModels( void PredictionManager::UpdatePredictionModels(
...@@ -494,10 +489,9 @@ void PredictionManager::OnStoreInitialized() { ...@@ -494,10 +489,9 @@ void PredictionManager::OnStoreInitialized() {
return; return;
// The store is ready so start loading host model features and the models for // The store is ready so start loading host model features and the models for
// the registered optimization targets. The host model features must be loaded // the registered optimization targets. Once the host model features are
// first because prediction models require them to be constructed. Once the // loaded, prediction models for the registered optimization targets will be
// host model features are loaded, prediction models for the registered // loaded.
// optimization targets will be loaded.
LoadHostModelFeatures(); LoadHostModelFeatures();
MaybeScheduleModelAndHostModelFeaturesFetch(); MaybeScheduleModelAndHostModelFeaturesFetch();
...@@ -524,7 +518,6 @@ void PredictionManager::OnLoadHostModelFeatures( ...@@ -524,7 +518,6 @@ void PredictionManager::OnLoadHostModelFeatures(
if (all_host_model_features) { if (all_host_model_features) {
for (const auto& host_model_features : *all_host_model_features) for (const auto& host_model_features : *all_host_model_features)
ProcessAndStoreHostModelFeatures(host_model_features); ProcessAndStoreHostModelFeatures(host_model_features);
UpdateSupportedHostModelFeatures();
} }
UMA_HISTOGRAM_COUNTS_1000( UMA_HISTOGRAM_COUNTS_1000(
"OptimizationGuide.PredictionManager.HostModelFeaturesMapSize", "OptimizationGuide.PredictionManager.HostModelFeaturesMapSize",
...@@ -535,28 +528,6 @@ void PredictionManager::OnLoadHostModelFeatures( ...@@ -535,28 +528,6 @@ void PredictionManager::OnLoadHostModelFeatures(
LoadPredictionModels(registered_optimization_targets_); LoadPredictionModels(registered_optimization_targets_);
} }
void PredictionManager::UpdateSupportedHostModelFeatures() {
SEQUENCE_CHECKER(sequence_checker_);
if (host_model_features_map_.size() > 0) {
// Clear the current supported host model features if they exist.
if (supported_host_model_features_.size() != 0)
supported_host_model_features_.clear();
// TODO(crbug/1027224): Add support to collect the set of all features, not
// just for the first host in the map. This is needed when additional models
// are supported.
base::flat_map<std::string, float> host_model_features =
host_model_features_map_.begin()->second;
supported_host_model_features_.reserve(host_model_features.size());
for (const auto& model_feature : host_model_features)
supported_host_model_features_.insert(model_feature.first);
}
}
base::flat_set<std::string>
PredictionManager::GetSupportedHostModelFeaturesForTesting() const {
return supported_host_model_features_;
}
void PredictionManager::LoadPredictionModels( void PredictionManager::LoadPredictionModels(
const base::flat_set<proto::OptimizationTarget>& optimization_targets) { const base::flat_set<proto::OptimizationTarget>& optimization_targets) {
SEQUENCE_CHECKER(sequence_checker_); SEQUENCE_CHECKER(sequence_checker_);
...@@ -597,7 +568,7 @@ void PredictionManager::ProcessAndStorePredictionModel( ...@@ -597,7 +568,7 @@ void PredictionManager::ProcessAndStorePredictionModel(
} }
std::unique_ptr<PredictionModel> prediction_model = std::unique_ptr<PredictionModel> prediction_model =
CreatePredictionModel(model, supported_host_model_features_); CreatePredictionModel(model);
if (!prediction_model) if (!prediction_model)
return; return;
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#define CHROME_BROWSER_OPTIMIZATION_GUIDE_PREDICTION_PREDICTION_MANAGER_H_ #define CHROME_BROWSER_OPTIMIZATION_GUIDE_PREDICTION_PREDICTION_MANAGER_H_
#include <memory> #include <memory>
#include <string>
#include <vector> #include <vector>
#include "base/containers/flat_map.h" #include "base/containers/flat_map.h"
...@@ -137,8 +138,7 @@ class PredictionManager ...@@ -137,8 +138,7 @@ class PredictionManager
// Create a PredictionModel, virtual for testing. // Create a PredictionModel, virtual for testing.
virtual std::unique_ptr<PredictionModel> CreatePredictionModel( virtual std::unique_ptr<PredictionModel> CreatePredictionModel(
const proto::PredictionModel& model, const proto::PredictionModel& model) const;
const base::flat_set<std::string>& host_model_features) const;
// Process |host_model_features| to be stored in |host_model_features_map|. // Process |host_model_features| to be stored in |host_model_features_map|.
void UpdateHostModelFeatures( void UpdateHostModelFeatures(
...@@ -227,11 +227,6 @@ class PredictionManager ...@@ -227,11 +227,6 @@ class PredictionManager
void ProcessAndStoreHostModelFeatures( void ProcessAndStoreHostModelFeatures(
const proto::HostModelFeatures& host_model_features); const proto::HostModelFeatures& host_model_features);
// Capture the set of feature names that each host in
// |host_model_features_map_| has and store them in
// |supported_host_model_features_|
void UpdateSupportedHostModelFeatures();
// Return the time when a prediction model and host model features fetch was // Return the time when a prediction model and host model features fetch was
// last attempted. // last attempted.
base::Time GetLastFetchAttemptTime() const; base::Time GetLastFetchAttemptTime() const;
...@@ -265,10 +260,6 @@ class PredictionManager ...@@ -265,10 +260,6 @@ class PredictionManager
base::flat_map<std::string, base::flat_map<std::string, float>> base::flat_map<std::string, base::flat_map<std::string, float>>
host_model_features_map_; host_model_features_map_;
// The set of features available across every host in
// |host_model_features_map_|.
base::flat_set<std::string> supported_host_model_features_;
// The current session's FCP statistics for HTTP/HTTPS navigations. // The current session's FCP statistics for HTTP/HTTPS navigations.
OptimizationGuideSessionStatistic session_fcp_; OptimizationGuideSessionStatistic session_fcp_;
......
...@@ -4,7 +4,10 @@ ...@@ -4,7 +4,10 @@
#include "chrome/browser/optimization_guide/prediction/prediction_manager.h" #include "chrome/browser/optimization_guide/prediction/prediction_manager.h"
#include <map>
#include <memory> #include <memory>
#include <string>
#include <utility>
#include "base/strings/string_number_conversions.h" #include "base/strings/string_number_conversions.h"
#include "base/test/metrics/histogram_tester.h" #include "base/test/metrics/histogram_tester.h"
...@@ -94,9 +97,9 @@ std::unique_ptr<proto::GetModelsResponse> BuildGetModelsResponse( ...@@ -94,9 +97,9 @@ std::unique_ptr<proto::GetModelsResponse> BuildGetModelsResponse(
class TestPredictionModel : public PredictionModel { class TestPredictionModel : public PredictionModel {
public: public:
TestPredictionModel(std::unique_ptr<proto::PredictionModel> prediction_model, explicit TestPredictionModel(
const base::flat_set<std::string>& host_model_features) std::unique_ptr<proto::PredictionModel> prediction_model)
: PredictionModel(std::move(prediction_model), host_model_features) {} : PredictionModel(std::move(prediction_model)) {}
~TestPredictionModel() override = default; ~TestPredictionModel() override = default;
optimization_guide::OptimizationTargetDecision Predict( optimization_guide::OptimizationTargetDecision Predict(
...@@ -293,18 +296,15 @@ class TestPredictionManager : public PredictionManager { ...@@ -293,18 +296,15 @@ class TestPredictionManager : public PredictionManager {
~TestPredictionManager() override = default; ~TestPredictionManager() override = default;
std::unique_ptr<PredictionModel> CreatePredictionModel( std::unique_ptr<PredictionModel> CreatePredictionModel(
const proto::PredictionModel& model, const proto::PredictionModel& model) const override {
const base::flat_set<std::string>& host_model_features) const override {
std::unique_ptr<PredictionModel> prediction_model = std::unique_ptr<PredictionModel> prediction_model =
std::make_unique<TestPredictionModel>( std::make_unique<TestPredictionModel>(
std::make_unique<proto::PredictionModel>(model), std::make_unique<proto::PredictionModel>(model));
host_model_features);
return prediction_model; return prediction_model;
} }
using PredictionManager::GetHostModelFeaturesForTesting; using PredictionManager::GetHostModelFeaturesForTesting;
using PredictionManager::GetPredictionModelForTesting; using PredictionManager::GetPredictionModelForTesting;
using PredictionManager::GetSupportedHostModelFeaturesForTesting;
std::unique_ptr<OptimizationGuideStore> std::unique_ptr<OptimizationGuideStore>
CreateModelAndHostModelFeaturesStore() { CreateModelAndHostModelFeaturesStore() {
...@@ -1214,22 +1214,6 @@ TEST_F(PredictionManagerTest, ...@@ -1214,22 +1214,6 @@ TEST_F(PredictionManagerTest,
EXPECT_FALSE(prediction_model_fetcher()->models_fetched()); EXPECT_FALSE(prediction_model_fetcher()->models_fetched());
} }
TEST_F(PredictionManagerTest, SupportedHostModelFeaturesUpdated) {
CreatePredictionManager(
{optimization_guide::proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD});
prediction_manager()->SetPredictionModelFetcherForTesting(
BuildTestPredictionModelFetcher(
PredictionModelFetcherEndState::kFetchFailed));
SetStoreInitialized();
EXPECT_TRUE(models_and_features_store()->WasHostModelFeaturesLoaded());
EXPECT_TRUE(prediction_manager()->GetHostModelFeaturesForTesting().contains(
"foo.com"));
EXPECT_TRUE(
prediction_manager()->GetSupportedHostModelFeaturesForTesting().contains(
"host_feat1"));
}
TEST_F(PredictionManagerTest, ModelFetcherTimerRetryDelay) { TEST_F(PredictionManagerTest, ModelFetcherTimerRetryDelay) {
base::test::ScopedFeatureList feature_list; base::test::ScopedFeatureList feature_list;
feature_list.InitWithFeatures( feature_list.InitWithFeatures(
......
...@@ -3,6 +3,9 @@ ...@@ -3,6 +3,9 @@
// found in the LICENSE file. // found in the LICENSE file.
#include "chrome/browser/optimization_guide/prediction/prediction_model.h" #include "chrome/browser/optimization_guide/prediction/prediction_model.h"
#include <utility>
#include "chrome/browser/optimization_guide/prediction/decision_tree_prediction_model.h" #include "chrome/browser/optimization_guide/prediction/decision_tree_prediction_model.h"
namespace optimization_guide { namespace optimization_guide {
...@@ -10,8 +13,7 @@ namespace optimization_guide { ...@@ -10,8 +13,7 @@ namespace optimization_guide {
// static // static
std::unique_ptr<PredictionModel> PredictionModel::Create( std::unique_ptr<PredictionModel> PredictionModel::Create(
std::unique_ptr<optimization_guide::proto::PredictionModel> std::unique_ptr<optimization_guide::proto::PredictionModel>
prediction_model, prediction_model) {
const base::flat_set<std::string>& host_model_features) {
// TODO(crbug/1009123): Add a histogram to record if the provided model is // TODO(crbug/1009123): Add a histogram to record if the provided model is
// constructed successfully or not. // constructed successfully or not.
// TODO(crbug/1009123): Adding timing metrics around initialization due to // TODO(crbug/1009123): Adding timing metrics around initialization due to
...@@ -55,7 +57,7 @@ std::unique_ptr<PredictionModel> PredictionModel::Create( ...@@ -55,7 +57,7 @@ std::unique_ptr<PredictionModel> PredictionModel::Create(
return nullptr; return nullptr;
} }
model = std::make_unique<DecisionTreePredictionModel>( model = std::make_unique<DecisionTreePredictionModel>(
std::move(prediction_model), host_model_features); std::move(prediction_model));
// Any constructed model must be validated for correctness according to its // Any constructed model must be validated for correctness according to its
// model type before being returned. // model type before being returned.
...@@ -67,12 +69,11 @@ std::unique_ptr<PredictionModel> PredictionModel::Create( ...@@ -67,12 +69,11 @@ std::unique_ptr<PredictionModel> PredictionModel::Create(
PredictionModel::PredictionModel( PredictionModel::PredictionModel(
std::unique_ptr<optimization_guide::proto::PredictionModel> std::unique_ptr<optimization_guide::proto::PredictionModel>
prediction_model, prediction_model) {
const base::flat_set<std::string>& host_model_features) {
version_ = prediction_model->model_info().version(); version_ = prediction_model->model_info().version();
model_features_.reserve( model_features_.reserve(
prediction_model->model_info().supported_model_features_size() + prediction_model->model_info().supported_model_features_size() +
host_model_features.size()); prediction_model->model_info().supported_host_model_features_size());
// Insert all the client model features for the owned |model_|. // Insert all the client model features for the owned |model_|.
for (const auto& client_model_feature : for (const auto& client_model_feature :
prediction_model->model_info().supported_model_features()) { prediction_model->model_info().supported_model_features()) {
...@@ -80,8 +81,10 @@ PredictionModel::PredictionModel( ...@@ -80,8 +81,10 @@ PredictionModel::PredictionModel(
client_model_feature)); client_model_feature));
} }
// Insert all the host model features for the owned |model_|. // Insert all the host model features for the owned |model_|.
for (const auto& host_model_feature : host_model_features) for (const auto& host_model_feature :
prediction_model->model_info().supported_host_model_features()) {
model_features_.emplace(host_model_feature); model_features_.emplace(host_model_feature);
}
model_ = std::make_unique<optimization_guide::proto::Model>( model_ = std::make_unique<optimization_guide::proto::Model>(
prediction_model->model()); prediction_model->model());
} }
......
...@@ -29,8 +29,7 @@ class PredictionModel { ...@@ -29,8 +29,7 @@ class PredictionModel {
// should should be called in the background. // should should be called in the background.
static std::unique_ptr<PredictionModel> Create( static std::unique_ptr<PredictionModel> Create(
std::unique_ptr<optimization_guide::proto::PredictionModel> std::unique_ptr<optimization_guide::proto::PredictionModel>
prediction_model, prediction_model);
const base::flat_set<std::string>& host_model_features);
// Returns the OptimizationTargetDecision by evaluating the |model_| // Returns the OptimizationTargetDecision by evaluating the |model_|
// using the provided |model_features|. |prediction_score| will be populated // using the provided |model_features|. |prediction_score| will be populated
...@@ -48,8 +47,7 @@ class PredictionModel { ...@@ -48,8 +47,7 @@ class PredictionModel {
protected: protected:
PredictionModel(std::unique_ptr<optimization_guide::proto::PredictionModel> PredictionModel(std::unique_ptr<optimization_guide::proto::PredictionModel>
prediction_model, prediction_model);
const base::flat_set<std::string>& host_model_features);
// The in-memory model used for prediction. // The in-memory model used for prediction.
std::unique_ptr<optimization_guide::proto::Model> model_; std::unique_ptr<optimization_guide::proto::Model> model_;
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
#include "chrome/browser/optimization_guide/prediction/prediction_model.h" #include "chrome/browser/optimization_guide/prediction/prediction_model.h"
#include <memory> #include <utility>
#include "components/optimization_guide/proto/models.pb.h" #include "components/optimization_guide/proto/models.pb.h"
#include "testing/gtest/include/gtest/gtest.h" #include "testing/gtest/include/gtest/gtest.h"
...@@ -57,9 +57,10 @@ TEST(PredictionModelTest, ValidPredictionModel) { ...@@ -57,9 +57,10 @@ TEST(PredictionModelTest, ValidPredictionModel) {
model_info->add_supported_model_features( model_info->add_supported_model_features(
optimization_guide::proto::ClientModelFeature:: optimization_guide::proto::ClientModelFeature::
CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE); CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE);
model_info->add_supported_host_model_features("agg1");
std::unique_ptr<PredictionModel> model = std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model), {"agg1"}); PredictionModel::Create(std::move(prediction_model));
EXPECT_EQ(1, model->GetVersion()); EXPECT_EQ(1, model->GetVersion());
EXPECT_EQ(2u, model->GetModelFeatures().size()); EXPECT_EQ(2u, model->GetModelFeatures().size());
...@@ -73,7 +74,7 @@ TEST(PredictionModelTest, NoModel) { ...@@ -73,7 +74,7 @@ TEST(PredictionModelTest, NoModel) {
std::make_unique<optimization_guide::proto::PredictionModel>(); std::make_unique<optimization_guide::proto::PredictionModel>();
std::unique_ptr<PredictionModel> model = std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model), {"agg1"}); PredictionModel::Create(std::move(prediction_model));
EXPECT_FALSE(model); EXPECT_FALSE(model);
} }
...@@ -86,7 +87,7 @@ TEST(PredictionModelTest, NoModelVersion) { ...@@ -86,7 +87,7 @@ TEST(PredictionModelTest, NoModelVersion) {
decision_tree_model->set_weight(2.0); decision_tree_model->set_weight(2.0);
std::unique_ptr<PredictionModel> model = std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model), {"agg1"}); PredictionModel::Create(std::move(prediction_model));
EXPECT_FALSE(model); EXPECT_FALSE(model);
} }
...@@ -103,7 +104,7 @@ TEST(PredictionModelTest, NoModelType) { ...@@ -103,7 +104,7 @@ TEST(PredictionModelTest, NoModelType) {
model_info->set_version(1); model_info->set_version(1);
std::unique_ptr<PredictionModel> model = std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model), {"agg1"}); PredictionModel::Create(std::move(prediction_model));
EXPECT_FALSE(model); EXPECT_FALSE(model);
} }
...@@ -122,7 +123,7 @@ TEST(PredictionModelTest, UnknownModelType) { ...@@ -122,7 +123,7 @@ TEST(PredictionModelTest, UnknownModelType) {
optimization_guide::proto::ModelType::MODEL_TYPE_UNKNOWN); optimization_guide::proto::ModelType::MODEL_TYPE_UNKNOWN);
std::unique_ptr<PredictionModel> model = std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model), {"agg1"}); PredictionModel::Create(std::move(prediction_model));
EXPECT_FALSE(model); EXPECT_FALSE(model);
} }
...@@ -143,7 +144,7 @@ TEST(PredictionModelTest, MultipleModelTypes) { ...@@ -143,7 +144,7 @@ TEST(PredictionModelTest, MultipleModelTypes) {
optimization_guide::proto::ModelType::MODEL_TYPE_UNKNOWN); optimization_guide::proto::ModelType::MODEL_TYPE_UNKNOWN);
std::unique_ptr<PredictionModel> model = std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model), {"agg1"}); PredictionModel::Create(std::move(prediction_model));
EXPECT_FALSE(model); EXPECT_FALSE(model);
} }
...@@ -166,7 +167,7 @@ TEST(PredictionModelTest, UnknownModelClientFeature) { ...@@ -166,7 +167,7 @@ TEST(PredictionModelTest, UnknownModelClientFeature) {
CLIENT_MODEL_FEATURE_UNKNOWN); CLIENT_MODEL_FEATURE_UNKNOWN);
std::unique_ptr<PredictionModel> model = std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model), {"agg1"}); PredictionModel::Create(std::move(prediction_model));
EXPECT_FALSE(model); EXPECT_FALSE(model);
} }
......
...@@ -199,6 +199,10 @@ message ModelInfo { ...@@ -199,6 +199,10 @@ message ModelInfo {
repeated ClientModelFeature supported_model_features = 3; repeated ClientModelFeature supported_model_features = 3;
// The set of model types the requesting client can use to make predictions. // The set of model types the requesting client can use to make predictions.
repeated ModelType supported_model_types = 4; repeated ModelType supported_model_types = 4;
// The set of host model features that are referenced by the model.
//
// Note that this should only be populated if part of the response.
repeated string supported_host_model_features = 5;
} }
// The scenarios for which the optimization guide has models for. // The scenarios for which the optimization guide has models for.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment