Commit df9a1cbf authored by Sophie Chang's avatar Sophie Chang Committed by Commit Bot

Use host model features field instead of some fake set calculation

Bug: 1001194
Change-Id: I1ea4fa11ba5560ac893cef456c00e104d5748317
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/1931143Reviewed-by: default avatarMichael Crouse <mcrouse@chromium.org>
Commit-Queue: Sophie Chang <sophiechang@chromium.org>
Cr-Commit-Position: refs/heads/master@{#718371}
parent ee4dd036
......@@ -3,15 +3,15 @@
// found in the LICENSE file.
#include "chrome/browser/optimization_guide/prediction/decision_tree_prediction_model.h"
#include "chrome/browser/optimization_guide/prediction/prediction_model.h"
#include <utility>
namespace optimization_guide {
DecisionTreePredictionModel::DecisionTreePredictionModel(
std::unique_ptr<optimization_guide::proto::PredictionModel>
prediction_model,
const base::flat_set<std::string>& host_model_features)
: PredictionModel(std::move(prediction_model), host_model_features) {}
prediction_model)
: PredictionModel(std::move(prediction_model)) {}
DecisionTreePredictionModel::~DecisionTreePredictionModel() = default;
......
......@@ -21,10 +21,9 @@ namespace optimization_guide {
// supported by the optimization guide.
class DecisionTreePredictionModel : public PredictionModel {
public:
DecisionTreePredictionModel(
explicit DecisionTreePredictionModel(
std::unique_ptr<optimization_guide::proto::PredictionModel>
prediction_model,
const base::flat_set<std::string>& host_model_features);
prediction_model);
~DecisionTreePredictionModel() override;
......
......@@ -3,6 +3,7 @@
// found in the LICENSE file.
#include <memory>
#include <utility>
#include "base/containers/flat_map.h"
#include "base/containers/flat_set.h"
......@@ -79,9 +80,10 @@ TEST(DecisionTreePredictionModel, ValidDecisionTreeModel) {
model_info->add_supported_model_features(
proto::ClientModelFeature::
CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE);
model_info->add_supported_host_model_features("agg1");
std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model), {"agg1"});
PredictionModel::Create(std::move(prediction_model));
EXPECT_TRUE(model);
double prediction_score;
......@@ -111,9 +113,10 @@ TEST(DecisionTreePredictionModel, InequalityLessThan) {
model_info->add_supported_model_features(
proto::ClientModelFeature::
CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE);
model_info->add_supported_host_model_features("agg1");
std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model), {"agg1"});
PredictionModel::Create(std::move(prediction_model));
EXPECT_TRUE(model);
double prediction_score;
......@@ -143,9 +146,10 @@ TEST(DecisionTreePredictionModel, InequalityGreaterOrEqual) {
model_info->add_supported_model_features(
proto::ClientModelFeature::
CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE);
model_info->add_supported_host_model_features("agg1");
std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model), {"agg1"});
PredictionModel::Create(std::move(prediction_model));
EXPECT_TRUE(model);
double prediction_score;
......@@ -175,9 +179,10 @@ TEST(DecisionTreePredictionModel, InequalityGreaterThan) {
model_info->add_supported_model_features(
proto::ClientModelFeature::
CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE);
model_info->add_supported_host_model_features("agg1");
std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model), {"agg1"});
PredictionModel::Create(std::move(prediction_model));
EXPECT_TRUE(model);
double prediction_score;
......@@ -207,9 +212,10 @@ TEST(DecisionTreePredictionModel, MissingInequalityTest) {
model_info->add_supported_model_features(
proto::ClientModelFeature::
CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE);
model_info->add_supported_host_model_features("agg1");
std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model), {"agg1"});
PredictionModel::Create(std::move(prediction_model));
EXPECT_FALSE(model);
}
......@@ -226,9 +232,10 @@ TEST(DecisionTreePredictionModel, NoDecisionTreeThreshold) {
model_info->add_supported_model_features(
proto::ClientModelFeature::
CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE);
model_info->add_supported_host_model_features("agg1");
std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model), {"agg1"});
PredictionModel::Create(std::move(prediction_model));
EXPECT_FALSE(model);
}
......@@ -245,9 +252,10 @@ TEST(DecisionTreePredictionModel, EmptyTree) {
model_info->add_supported_model_features(
proto::ClientModelFeature::
CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE);
model_info->add_supported_host_model_features("agg1");
std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model), {"agg1"});
PredictionModel::Create(std::move(prediction_model));
EXPECT_FALSE(model);
}
......@@ -264,9 +272,10 @@ TEST(DecisionTreePredictionModel, ModelFeatureNotInFeatureMap) {
model_info->add_supported_model_features(
proto::ClientModelFeature::
CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE);
model_info->add_supported_host_model_features("agg1");
std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model), {"agg1"});
PredictionModel::Create(std::move(prediction_model));
EXPECT_FALSE(model);
}
......@@ -287,9 +296,10 @@ TEST(DecisionTreePredictionModel, DecisionTreeMissingLeaf) {
model_info->add_supported_model_features(
proto::ClientModelFeature::
CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE);
model_info->add_supported_host_model_features("agg1");
std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model), {"agg1"});
PredictionModel::Create(std::move(prediction_model));
EXPECT_FALSE(model);
}
......@@ -311,9 +321,10 @@ TEST(DecisionTreePredictionModel, DecisionTreeLeftChildIndexInvalid) {
model_info->add_supported_model_features(
proto::ClientModelFeature::
CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE);
model_info->add_supported_host_model_features("agg1");
std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model), {"agg1"});
PredictionModel::Create(std::move(prediction_model));
EXPECT_FALSE(model);
}
......@@ -335,9 +346,10 @@ TEST(DecisionTreePredictionModel, DecisionTreeRightChildIndexInvalid) {
model_info->add_supported_model_features(
proto::ClientModelFeature::
CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE);
model_info->add_supported_host_model_features("agg1");
std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model), {"agg1"});
PredictionModel::Create(std::move(prediction_model));
EXPECT_FALSE(model);
}
......@@ -373,9 +385,10 @@ TEST(DecisionTreePredictionModel, DecisionTreeWithLoopOnLeftChild) {
model_info->add_supported_model_features(
proto::ClientModelFeature::
CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE);
model_info->add_supported_host_model_features("agg1");
std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model), {"agg1"});
PredictionModel::Create(std::move(prediction_model));
EXPECT_FALSE(model);
}
......@@ -411,9 +424,10 @@ TEST(DecisionTreePredictionModel, DecisionTreeWithLoopOnRightChild) {
model_info->add_supported_model_features(
proto::ClientModelFeature::
CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE);
model_info->add_supported_host_model_features("agg1");
std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model), {"agg1"});
PredictionModel::Create(std::move(prediction_model));
EXPECT_FALSE(model);
}
......@@ -428,9 +442,10 @@ TEST(DecisionTreePredictionModel, ValidEnsembleModel) {
model_info->add_supported_model_features(
proto::ClientModelFeature::
CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE);
model_info->add_supported_host_model_features("agg1");
std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model), {"agg1"});
PredictionModel::Create(std::move(prediction_model));
EXPECT_TRUE(model);
double prediction_score;
......@@ -457,9 +472,10 @@ TEST(DecisionTreePredictionModel, EnsembleWithNoMembers) {
model_info->add_supported_model_features(
proto::ClientModelFeature::
CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE);
model_info->add_supported_host_model_features("agg1");
std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model), {"agg1"});
PredictionModel::Create(std::move(prediction_model));
EXPECT_FALSE(model);
}
......
......@@ -4,8 +4,7 @@
#include "chrome/browser/optimization_guide/prediction/prediction_manager.h"
#include <memory>
#include <vector>
#include <utility>
#include "base/containers/flat_map.h"
#include "base/containers/flat_set.h"
......@@ -175,9 +174,7 @@ void PredictionManager::RegisterOptimizationTargets(
if (new_optimization_targets.size() == 0)
return;
// Start loading the host model features if they are not already. Models
// cannot be loaded from the store until the host model features have loaded
// from the store as they are required to construct each prediction model.
// Start loading the host model features if they are not already.
if (!host_model_features_loaded_) {
LoadHostModelFeatures();
return;
......@@ -464,15 +461,13 @@ void PredictionManager::UpdateHostModelFeatures(
SEQUENCE_CHECKER(sequence_checker_);
for (const auto& host_model_features : host_model_features)
ProcessAndStoreHostModelFeatures(host_model_features);
UpdateSupportedHostModelFeatures();
}
std::unique_ptr<PredictionModel> PredictionManager::CreatePredictionModel(
const proto::PredictionModel& model,
const base::flat_set<std::string>& host_model_features) const {
const proto::PredictionModel& model) const {
SEQUENCE_CHECKER(sequence_checker_);
return PredictionModel::Create(
std::make_unique<proto::PredictionModel>(model), host_model_features);
std::make_unique<proto::PredictionModel>(model));
}
void PredictionManager::UpdatePredictionModels(
......@@ -494,10 +489,9 @@ void PredictionManager::OnStoreInitialized() {
return;
// The store is ready so start loading host model features and the models for
// the registered optimization targets. The host model features must be loaded
// first because prediction models require them to be constructed. Once the
// host model features are loaded, prediction models for the registered
// optimization targets will be loaded.
// the registered optimization targets. Once the host model features are
// loaded, prediction models for the registered optimization targets will be
// loaded.
LoadHostModelFeatures();
MaybeScheduleModelAndHostModelFeaturesFetch();
......@@ -524,7 +518,6 @@ void PredictionManager::OnLoadHostModelFeatures(
if (all_host_model_features) {
for (const auto& host_model_features : *all_host_model_features)
ProcessAndStoreHostModelFeatures(host_model_features);
UpdateSupportedHostModelFeatures();
}
UMA_HISTOGRAM_COUNTS_1000(
"OptimizationGuide.PredictionManager.HostModelFeaturesMapSize",
......@@ -535,28 +528,6 @@ void PredictionManager::OnLoadHostModelFeatures(
LoadPredictionModels(registered_optimization_targets_);
}
void PredictionManager::UpdateSupportedHostModelFeatures() {
SEQUENCE_CHECKER(sequence_checker_);
if (host_model_features_map_.size() > 0) {
// Clear the current supported host model features if they exist.
if (supported_host_model_features_.size() != 0)
supported_host_model_features_.clear();
// TODO(crbug/1027224): Add support to collect the set of all features, not
// just for the first host in the map. This is needed when additional models
// are supported.
base::flat_map<std::string, float> host_model_features =
host_model_features_map_.begin()->second;
supported_host_model_features_.reserve(host_model_features.size());
for (const auto& model_feature : host_model_features)
supported_host_model_features_.insert(model_feature.first);
}
}
base::flat_set<std::string>
PredictionManager::GetSupportedHostModelFeaturesForTesting() const {
return supported_host_model_features_;
}
void PredictionManager::LoadPredictionModels(
const base::flat_set<proto::OptimizationTarget>& optimization_targets) {
SEQUENCE_CHECKER(sequence_checker_);
......@@ -597,7 +568,7 @@ void PredictionManager::ProcessAndStorePredictionModel(
}
std::unique_ptr<PredictionModel> prediction_model =
CreatePredictionModel(model, supported_host_model_features_);
CreatePredictionModel(model);
if (!prediction_model)
return;
......
......@@ -6,6 +6,7 @@
#define CHROME_BROWSER_OPTIMIZATION_GUIDE_PREDICTION_PREDICTION_MANAGER_H_
#include <memory>
#include <string>
#include <vector>
#include "base/containers/flat_map.h"
......@@ -137,8 +138,7 @@ class PredictionManager
// Create a PredictionModel, virtual for testing.
virtual std::unique_ptr<PredictionModel> CreatePredictionModel(
const proto::PredictionModel& model,
const base::flat_set<std::string>& host_model_features) const;
const proto::PredictionModel& model) const;
// Process |host_model_features| to be stored in |host_model_features_map|.
void UpdateHostModelFeatures(
......@@ -227,11 +227,6 @@ class PredictionManager
void ProcessAndStoreHostModelFeatures(
const proto::HostModelFeatures& host_model_features);
// Capture the set of feature names that each host in
// |host_model_features_map_| has and store them in
// |supported_host_model_features_|
void UpdateSupportedHostModelFeatures();
// Return the time when a prediction model and host model features fetch was
// last attempted.
base::Time GetLastFetchAttemptTime() const;
......@@ -265,10 +260,6 @@ class PredictionManager
base::flat_map<std::string, base::flat_map<std::string, float>>
host_model_features_map_;
// The set of features available across every host in
// |host_model_features_map_|.
base::flat_set<std::string> supported_host_model_features_;
// The current session's FCP statistics for HTTP/HTTPS navigations.
OptimizationGuideSessionStatistic session_fcp_;
......
......@@ -4,7 +4,10 @@
#include "chrome/browser/optimization_guide/prediction/prediction_manager.h"
#include <map>
#include <memory>
#include <string>
#include <utility>
#include "base/strings/string_number_conversions.h"
#include "base/test/metrics/histogram_tester.h"
......@@ -94,9 +97,9 @@ std::unique_ptr<proto::GetModelsResponse> BuildGetModelsResponse(
class TestPredictionModel : public PredictionModel {
public:
TestPredictionModel(std::unique_ptr<proto::PredictionModel> prediction_model,
const base::flat_set<std::string>& host_model_features)
: PredictionModel(std::move(prediction_model), host_model_features) {}
explicit TestPredictionModel(
std::unique_ptr<proto::PredictionModel> prediction_model)
: PredictionModel(std::move(prediction_model)) {}
~TestPredictionModel() override = default;
optimization_guide::OptimizationTargetDecision Predict(
......@@ -293,18 +296,15 @@ class TestPredictionManager : public PredictionManager {
~TestPredictionManager() override = default;
std::unique_ptr<PredictionModel> CreatePredictionModel(
const proto::PredictionModel& model,
const base::flat_set<std::string>& host_model_features) const override {
const proto::PredictionModel& model) const override {
std::unique_ptr<PredictionModel> prediction_model =
std::make_unique<TestPredictionModel>(
std::make_unique<proto::PredictionModel>(model),
host_model_features);
std::make_unique<proto::PredictionModel>(model));
return prediction_model;
}
using PredictionManager::GetHostModelFeaturesForTesting;
using PredictionManager::GetPredictionModelForTesting;
using PredictionManager::GetSupportedHostModelFeaturesForTesting;
std::unique_ptr<OptimizationGuideStore>
CreateModelAndHostModelFeaturesStore() {
......@@ -1214,22 +1214,6 @@ TEST_F(PredictionManagerTest,
EXPECT_FALSE(prediction_model_fetcher()->models_fetched());
}
TEST_F(PredictionManagerTest, SupportedHostModelFeaturesUpdated) {
CreatePredictionManager(
{optimization_guide::proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD});
prediction_manager()->SetPredictionModelFetcherForTesting(
BuildTestPredictionModelFetcher(
PredictionModelFetcherEndState::kFetchFailed));
SetStoreInitialized();
EXPECT_TRUE(models_and_features_store()->WasHostModelFeaturesLoaded());
EXPECT_TRUE(prediction_manager()->GetHostModelFeaturesForTesting().contains(
"foo.com"));
EXPECT_TRUE(
prediction_manager()->GetSupportedHostModelFeaturesForTesting().contains(
"host_feat1"));
}
TEST_F(PredictionManagerTest, ModelFetcherTimerRetryDelay) {
base::test::ScopedFeatureList feature_list;
feature_list.InitWithFeatures(
......
......@@ -3,6 +3,9 @@
// found in the LICENSE file.
#include "chrome/browser/optimization_guide/prediction/prediction_model.h"
#include <utility>
#include "chrome/browser/optimization_guide/prediction/decision_tree_prediction_model.h"
namespace optimization_guide {
......@@ -10,8 +13,7 @@ namespace optimization_guide {
// static
std::unique_ptr<PredictionModel> PredictionModel::Create(
std::unique_ptr<optimization_guide::proto::PredictionModel>
prediction_model,
const base::flat_set<std::string>& host_model_features) {
prediction_model) {
// TODO(crbug/1009123): Add a histogram to record if the provided model is
// constructed successfully or not.
// TODO(crbug/1009123): Adding timing metrics around initialization due to
......@@ -55,7 +57,7 @@ std::unique_ptr<PredictionModel> PredictionModel::Create(
return nullptr;
}
model = std::make_unique<DecisionTreePredictionModel>(
std::move(prediction_model), host_model_features);
std::move(prediction_model));
// Any constructed model must be validated for correctness according to its
// model type before being returned.
......@@ -67,12 +69,11 @@ std::unique_ptr<PredictionModel> PredictionModel::Create(
PredictionModel::PredictionModel(
std::unique_ptr<optimization_guide::proto::PredictionModel>
prediction_model,
const base::flat_set<std::string>& host_model_features) {
prediction_model) {
version_ = prediction_model->model_info().version();
model_features_.reserve(
prediction_model->model_info().supported_model_features_size() +
host_model_features.size());
prediction_model->model_info().supported_host_model_features_size());
// Insert all the client model features for the owned |model_|.
for (const auto& client_model_feature :
prediction_model->model_info().supported_model_features()) {
......@@ -80,8 +81,10 @@ PredictionModel::PredictionModel(
client_model_feature));
}
// Insert all the host model features for the owned |model_|.
for (const auto& host_model_feature : host_model_features)
for (const auto& host_model_feature :
prediction_model->model_info().supported_host_model_features()) {
model_features_.emplace(host_model_feature);
}
model_ = std::make_unique<optimization_guide::proto::Model>(
prediction_model->model());
}
......
......@@ -29,8 +29,7 @@ class PredictionModel {
// should should be called in the background.
static std::unique_ptr<PredictionModel> Create(
std::unique_ptr<optimization_guide::proto::PredictionModel>
prediction_model,
const base::flat_set<std::string>& host_model_features);
prediction_model);
// Returns the OptimizationTargetDecision by evaluating the |model_|
// using the provided |model_features|. |prediction_score| will be populated
......@@ -48,8 +47,7 @@ class PredictionModel {
protected:
PredictionModel(std::unique_ptr<optimization_guide::proto::PredictionModel>
prediction_model,
const base::flat_set<std::string>& host_model_features);
prediction_model);
// The in-memory model used for prediction.
std::unique_ptr<optimization_guide::proto::Model> model_;
......
......@@ -4,7 +4,7 @@
#include "chrome/browser/optimization_guide/prediction/prediction_model.h"
#include <memory>
#include <utility>
#include "components/optimization_guide/proto/models.pb.h"
#include "testing/gtest/include/gtest/gtest.h"
......@@ -57,9 +57,10 @@ TEST(PredictionModelTest, ValidPredictionModel) {
model_info->add_supported_model_features(
optimization_guide::proto::ClientModelFeature::
CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE);
model_info->add_supported_host_model_features("agg1");
std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model), {"agg1"});
PredictionModel::Create(std::move(prediction_model));
EXPECT_EQ(1, model->GetVersion());
EXPECT_EQ(2u, model->GetModelFeatures().size());
......@@ -73,7 +74,7 @@ TEST(PredictionModelTest, NoModel) {
std::make_unique<optimization_guide::proto::PredictionModel>();
std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model), {"agg1"});
PredictionModel::Create(std::move(prediction_model));
EXPECT_FALSE(model);
}
......@@ -86,7 +87,7 @@ TEST(PredictionModelTest, NoModelVersion) {
decision_tree_model->set_weight(2.0);
std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model), {"agg1"});
PredictionModel::Create(std::move(prediction_model));
EXPECT_FALSE(model);
}
......@@ -103,7 +104,7 @@ TEST(PredictionModelTest, NoModelType) {
model_info->set_version(1);
std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model), {"agg1"});
PredictionModel::Create(std::move(prediction_model));
EXPECT_FALSE(model);
}
......@@ -122,7 +123,7 @@ TEST(PredictionModelTest, UnknownModelType) {
optimization_guide::proto::ModelType::MODEL_TYPE_UNKNOWN);
std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model), {"agg1"});
PredictionModel::Create(std::move(prediction_model));
EXPECT_FALSE(model);
}
......@@ -143,7 +144,7 @@ TEST(PredictionModelTest, MultipleModelTypes) {
optimization_guide::proto::ModelType::MODEL_TYPE_UNKNOWN);
std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model), {"agg1"});
PredictionModel::Create(std::move(prediction_model));
EXPECT_FALSE(model);
}
......@@ -166,7 +167,7 @@ TEST(PredictionModelTest, UnknownModelClientFeature) {
CLIENT_MODEL_FEATURE_UNKNOWN);
std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model), {"agg1"});
PredictionModel::Create(std::move(prediction_model));
EXPECT_FALSE(model);
}
......
......@@ -199,6 +199,10 @@ message ModelInfo {
repeated ClientModelFeature supported_model_features = 3;
// The set of model types the requesting client can use to make predictions.
repeated ModelType supported_model_types = 4;
// The set of host model features that are referenced by the model.
//
// Note that this should only be populated if part of the response.
repeated string supported_host_model_features = 5;
}
// The scenarios for which the optimization guide has models for.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment