Commit 7825f4da authored by Jeremy Roman's avatar Jeremy Roman Committed by Chromium LUCI CQ

optimization_guide: Miscellaneous cleanup around PredictionModel.

* PredictionModel was using SequenceChecker incorrectly, to no effect.
  Since the fields it was guarding are immutable, this sequence check
  was in any event unnecessary. Instead, the members have been made
  const.

* PredictionModel::model_info_ was never set or read. It is removed.

* The PredictionModel constructor has been made explicit, consistent
  with Chromium style.

* The code was very inconsistent about whether optimization_guide::
  qualification was used. For consistency, brevity and style it has
  been rewritten in these places to not be qualified.

* It is inefficient to construct a flat_set by repeated insertion.
  It is equally easy, and preferred (see base/containers/flat_set.h)
  to construct a vector that can be sorted on construction. This
  reduces complexity from O(n^2) to O(n log n).

* PredictionModel::GetModelFeatures returned a copy, unneecessarily.
  It now returns a const reference.

* Fewer conversions between std::unique_ptr<proto::PredictionModel>
  and const proto::PredictionModel& are now done. std::unique_ptr
  is left where ownership is passed or where it interfaces with
  other code that uses unique_ptr in a more complicated way. This
  reduces copies and heap allocations. Regrettably it requires a
  number of trivial changes to the unit tests.

* DecisionTreePredictionModel::ValidateTreeNode's node_index argument
  is changed to int, which is more efficient and more usual than
  const int&.

Change-Id: Ib6b9d77d4f5941f578d3213c323ead01319715b2
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2607485
Commit-Queue: Jeremy Roman <jbroman@chromium.org>
Reviewed-by: default avatarSophie Chang <sophiechang@chromium.org>
Cr-Commit-Position: refs/heads/master@{#840333}
parent 2eb448a6
...@@ -735,8 +735,7 @@ void PredictionManager::UpdateHostModelFeatures( ...@@ -735,8 +735,7 @@ void PredictionManager::UpdateHostModelFeatures(
std::unique_ptr<PredictionModel> PredictionManager::CreatePredictionModel( std::unique_ptr<PredictionModel> PredictionManager::CreatePredictionModel(
const proto::PredictionModel& model) const { const proto::PredictionModel& model) const {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
return PredictionModel::Create( return PredictionModel::Create(model);
std::make_unique<proto::PredictionModel>(model));
} }
void PredictionManager::UpdatePredictionModels( void PredictionManager::UpdatePredictionModels(
...@@ -916,26 +915,26 @@ void PredictionManager::OnLoadPredictionModel( ...@@ -916,26 +915,26 @@ void PredictionManager::OnLoadPredictionModel(
return; return;
bool success = ProcessAndStoreLoadedModel(*model); bool success = ProcessAndStoreLoadedModel(*model);
OnProcessLoadedModel(std::move(model), success); OnProcessLoadedModel(*model, success);
} }
void PredictionManager::OnProcessLoadedModel( void PredictionManager::OnProcessLoadedModel(
std::unique_ptr<proto::PredictionModel> model, const proto::PredictionModel& model,
bool success) { bool success) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
if (success) { if (success) {
base::UmaHistogramSparse( base::UmaHistogramSparse(
"OptimizationGuide.PredictionModelLoadedVersion." + "OptimizationGuide.PredictionModelLoadedVersion." +
optimization_guide::GetStringNameForOptimizationTarget( optimization_guide::GetStringNameForOptimizationTarget(
model->model_info().optimization_target()), model.model_info().optimization_target()),
model->model_info().version()); model.model_info().version());
return; return;
} }
// Remove model from store if it exists. // Remove model from store if it exists.
OptimizationGuideStore::EntryKey model_entry_key; OptimizationGuideStore::EntryKey model_entry_key;
if (model_and_features_store_->FindPredictionModelEntryKey( if (model_and_features_store_->FindPredictionModelEntryKey(
model->model_info().optimization_target(), &model_entry_key)) { model.model_info().optimization_target(), &model_entry_key)) {
model_and_features_store_->RemovePredictionModelFromEntryKey( model_and_features_store_->RemovePredictionModelFromEntryKey(
model_entry_key); model_entry_key);
} }
......
...@@ -318,8 +318,7 @@ class PredictionManager ...@@ -318,8 +318,7 @@ class PredictionManager
std::unique_ptr<PredictionModel> prediction_model); std::unique_ptr<PredictionModel> prediction_model);
// Post-processing callback invoked after processing |model|. // Post-processing callback invoked after processing |model|.
void OnProcessLoadedModel(std::unique_ptr<proto::PredictionModel> model, void OnProcessLoadedModel(const proto::PredictionModel& model, bool success);
bool success);
// Process |host_model_features| from the into host model features // Process |host_model_features| from the into host model features
// usable by the PredictionManager. The processed host model features are // usable by the PredictionManager. The processed host model features are
......
...@@ -9,18 +9,17 @@ ...@@ -9,18 +9,17 @@
namespace optimization_guide { namespace optimization_guide {
DecisionTreePredictionModel::DecisionTreePredictionModel( DecisionTreePredictionModel::DecisionTreePredictionModel(
std::unique_ptr<optimization_guide::proto::PredictionModel> const proto::PredictionModel& prediction_model)
prediction_model) : PredictionModel(prediction_model) {}
: PredictionModel(std::move(prediction_model)) {}
DecisionTreePredictionModel::~DecisionTreePredictionModel() = default; DecisionTreePredictionModel::~DecisionTreePredictionModel() = default;
bool DecisionTreePredictionModel::ValidatePredictionModel() const { bool DecisionTreePredictionModel::ValidatePredictionModel() const {
// Only the top-level ensemble or decision tree must have a threshold. Any // Only the top-level ensemble or decision tree must have a threshold. Any
// submodels of an ensemble will have model weights but no threshold. // submodels of an ensemble will have model weights but no threshold.
if (!model_->has_threshold()) if (!model_.has_threshold())
return false; return false;
return ValidateModel(*model_.get()); return ValidateModel(model_);
} }
bool DecisionTreePredictionModel::ValidateModel( bool DecisionTreePredictionModel::ValidateModel(
...@@ -77,7 +76,7 @@ bool DecisionTreePredictionModel::ValidateInequalityTest( ...@@ -77,7 +76,7 @@ bool DecisionTreePredictionModel::ValidateInequalityTest(
bool DecisionTreePredictionModel::ValidateTreeNode( bool DecisionTreePredictionModel::ValidateTreeNode(
const proto::DecisionTree& tree, const proto::DecisionTree& tree,
const proto::TreeNode& node, const proto::TreeNode& node,
const int& node_index) const { int node_index) const {
if (node.has_leaf()) if (node.has_leaf())
return ValidateLeaf(node.leaf()); return ValidateLeaf(node.leaf());
...@@ -119,19 +118,18 @@ bool DecisionTreePredictionModel::ValidateTreeNode( ...@@ -119,19 +118,18 @@ bool DecisionTreePredictionModel::ValidateTreeNode(
return true; return true;
} }
optimization_guide::OptimizationTargetDecision OptimizationTargetDecision DecisionTreePredictionModel::Predict(
DecisionTreePredictionModel::Predict(
const base::flat_map<std::string, float>& model_features, const base::flat_map<std::string, float>& model_features,
double* prediction_score) { double* prediction_score) {
SEQUENCE_CHECKER(sequence_checker_); SEQUENCE_CHECKER(sequence_checker_);
*prediction_score = 0.0; *prediction_score = 0.0;
// TODO(mcrouse): Add metrics to record if the model evaluation fails. // TODO(mcrouse): Add metrics to record if the model evaluation fails.
if (!EvaluateModel(*model_.get(), model_features, prediction_score)) if (!EvaluateModel(model_, model_features, prediction_score))
return optimization_guide::OptimizationTargetDecision::kUnknown; return OptimizationTargetDecision::kUnknown;
if (*prediction_score > model_->threshold().value()) if (*prediction_score > model_.threshold().value())
return optimization_guide::OptimizationTargetDecision::kPageLoadMatches; return OptimizationTargetDecision::kPageLoadMatches;
return optimization_guide::OptimizationTargetDecision::kPageLoadDoesNotMatch; return OptimizationTargetDecision::kPageLoadDoesNotMatch;
} }
bool DecisionTreePredictionModel::TraverseTree( bool DecisionTreePredictionModel::TraverseTree(
......
...@@ -22,13 +22,12 @@ namespace optimization_guide { ...@@ -22,13 +22,12 @@ namespace optimization_guide {
class DecisionTreePredictionModel : public PredictionModel { class DecisionTreePredictionModel : public PredictionModel {
public: public:
explicit DecisionTreePredictionModel( explicit DecisionTreePredictionModel(
std::unique_ptr<optimization_guide::proto::PredictionModel> const proto::PredictionModel& prediction_model);
prediction_model);
~DecisionTreePredictionModel() override; ~DecisionTreePredictionModel() override;
// PredictionModel implementation: // PredictionModel implementation:
optimization_guide::OptimizationTargetDecision Predict( OptimizationTargetDecision Predict(
const base::flat_map<std::string, float>& model_features, const base::flat_map<std::string, float>& model_features,
double* prediction_score) override; double* prediction_score) override;
...@@ -89,7 +88,7 @@ class DecisionTreePredictionModel : public PredictionModel { ...@@ -89,7 +88,7 @@ class DecisionTreePredictionModel : public PredictionModel {
// node of the |tree|. Returns false if any part of the tree is invalid. // node of the |tree|. Returns false if any part of the tree is invalid.
bool ValidateTreeNode(const proto::DecisionTree& tree, bool ValidateTreeNode(const proto::DecisionTree& tree,
const proto::TreeNode& node, const proto::TreeNode& node,
const int& node_index) const; int node_index) const;
SEQUENCE_CHECKER(sequence_checker_); SEQUENCE_CHECKER(sequence_checker_);
......
...@@ -12,52 +12,50 @@ namespace optimization_guide { ...@@ -12,52 +12,50 @@ namespace optimization_guide {
// static // static
std::unique_ptr<PredictionModel> PredictionModel::Create( std::unique_ptr<PredictionModel> PredictionModel::Create(
std::unique_ptr<optimization_guide::proto::PredictionModel> const proto::PredictionModel& prediction_model) {
prediction_model) {
// TODO(crbug/1009123): Add a histogram to record if the provided model is // TODO(crbug/1009123): Add a histogram to record if the provided model is
// constructed successfully or not. // constructed successfully or not.
// TODO(crbug/1009123): Adding timing metrics around initialization due to // TODO(crbug/1009123): Adding timing metrics around initialization due to
// potential validation overhead. // potential validation overhead.
if (!prediction_model->has_model()) if (!prediction_model.has_model())
return nullptr; return nullptr;
if (!prediction_model->has_model_info()) if (!prediction_model.has_model_info())
return nullptr; return nullptr;
if (!prediction_model->model_info().has_version()) if (!prediction_model.model_info().has_version())
return nullptr; return nullptr;
// Enforce that only one ModelType is specified for the PredictionModel. // Enforce that only one ModelType is specified for the PredictionModel.
if (prediction_model->model_info().supported_model_types_size() != 1) { if (prediction_model.model_info().supported_model_types_size() != 1) {
return nullptr; return nullptr;
} }
// Check that the client supports this type of model and is not an unknown // Check that the client supports this type of model and is not an unknown
// type. // type.
if (!optimization_guide::proto::ModelType_IsValid( if (!proto::ModelType_IsValid(
prediction_model->model_info().supported_model_types(0)) || prediction_model.model_info().supported_model_types(0)) ||
prediction_model->model_info().supported_model_types(0) == prediction_model.model_info().supported_model_types(0) ==
optimization_guide::proto::ModelType::MODEL_TYPE_UNKNOWN) { proto::ModelType::MODEL_TYPE_UNKNOWN) {
return nullptr; return nullptr;
} }
// Check that the client supports the model features for |prediction model|. // Check that the client supports the model features for |prediction model|.
for (const auto& model_feature : for (const auto& model_feature :
prediction_model->model_info().supported_model_features()) { prediction_model.model_info().supported_model_features()) {
if (!optimization_guide::proto::ClientModelFeature_IsValid(model_feature) || if (!proto::ClientModelFeature_IsValid(model_feature) ||
model_feature == optimization_guide::proto::ClientModelFeature:: model_feature ==
CLIENT_MODEL_FEATURE_UNKNOWN) proto::ClientModelFeature::CLIENT_MODEL_FEATURE_UNKNOWN)
return nullptr; return nullptr;
} }
std::unique_ptr<PredictionModel> model; std::unique_ptr<PredictionModel> model;
// The Decision Tree model type is currently the only supported model type. // The Decision Tree model type is currently the only supported model type.
if (prediction_model->model_info().supported_model_types(0) != if (prediction_model.model_info().supported_model_types(0) !=
optimization_guide::proto::ModelType::MODEL_TYPE_DECISION_TREE) { proto::ModelType::MODEL_TYPE_DECISION_TREE) {
return nullptr; return nullptr;
} }
model = std::make_unique<DecisionTreePredictionModel>( model = std::make_unique<DecisionTreePredictionModel>(prediction_model);
std::move(prediction_model));
// Any constructed model must be validated for correctness according to its // Any constructed model must be validated for correctness according to its
// model type before being returned. // model type before being returned.
...@@ -67,37 +65,32 @@ std::unique_ptr<PredictionModel> PredictionModel::Create( ...@@ -67,37 +65,32 @@ std::unique_ptr<PredictionModel> PredictionModel::Create(
return model; return model;
} }
PredictionModel::PredictionModel( namespace {
std::unique_ptr<optimization_guide::proto::PredictionModel>
prediction_model) { std::vector<std::string> ComputeModelFeatures(
version_ = prediction_model->model_info().version(); const proto::ModelInfo& model_info) {
model_features_.reserve( std::vector<std::string> features;
prediction_model->model_info().supported_model_features_size() + features.reserve(model_info.supported_model_features_size() +
prediction_model->model_info().supported_host_model_features_size()); model_info.supported_host_model_features_size());
// Insert all the client model features for the owned |model_|. // Insert all the client model features for the owned |model_|.
for (const auto& client_model_feature : for (const auto& client_model_feature :
prediction_model->model_info().supported_model_features()) { model_info.supported_model_features()) {
model_features_.emplace(optimization_guide::proto::ClientModelFeature_Name( features.push_back(proto::ClientModelFeature_Name(client_model_feature));
client_model_feature));
} }
// Insert all the host model features for the owned |model_|. // Insert all the host model features for the owned |model_|.
for (const auto& host_model_feature : for (const auto& host_model_feature :
prediction_model->model_info().supported_host_model_features()) { model_info.supported_host_model_features()) {
model_features_.emplace(host_model_feature); features.push_back(host_model_feature);
} }
model_ = std::make_unique<optimization_guide::proto::Model>( return features;
prediction_model->model());
} }
int64_t PredictionModel::GetVersion() const { } // namespace
SEQUENCE_CHECKER(sequence_checker_);
return version_;
}
base::flat_set<std::string> PredictionModel::GetModelFeatures() const { PredictionModel::PredictionModel(const proto::PredictionModel& prediction_model)
SEQUENCE_CHECKER(sequence_checker_); : model_(prediction_model.model()),
return model_features_; model_features_(ComputeModelFeatures(prediction_model.model_info())),
} version_(prediction_model.model_info().version()) {}
PredictionModel::~PredictionModel() = default; PredictionModel::~PredictionModel() = default;
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
#include "base/containers/flat_map.h" #include "base/containers/flat_map.h"
#include "base/containers/flat_set.h" #include "base/containers/flat_set.h"
#include "base/macros.h" #include "base/macros.h"
#include "base/sequence_checker.h"
#include "components/optimization_guide/optimization_guide_enums.h" #include "components/optimization_guide/optimization_guide_enums.h"
#include "components/optimization_guide/proto/models.pb.h" #include "components/optimization_guide/proto/models.pb.h"
...@@ -28,45 +27,40 @@ class PredictionModel { ...@@ -28,45 +27,40 @@ class PredictionModel {
// |prediction_model|. The validation overhead of this factory can be high and // |prediction_model|. The validation overhead of this factory can be high and
// should should be called in the background. // should should be called in the background.
static std::unique_ptr<PredictionModel> Create( static std::unique_ptr<PredictionModel> Create(
std::unique_ptr<optimization_guide::proto::PredictionModel> const proto::PredictionModel& prediction_model);
prediction_model);
// Returns the OptimizationTargetDecision by evaluating the |model_| // Returns the OptimizationTargetDecision by evaluating the |model_|
// using the provided |model_features|. |prediction_score| will be populated // using the provided |model_features|. |prediction_score| will be populated
// with the score output by the model. // with the score output by the model.
virtual optimization_guide::OptimizationTargetDecision Predict( virtual OptimizationTargetDecision Predict(
const base::flat_map<std::string, float>& model_features, const base::flat_map<std::string, float>& model_features,
double* prediction_score) = 0; double* prediction_score) = 0;
// Provide the version of the |model_| by |this|. // Provide the version of the |model_| by |this|.
int64_t GetVersion() const; int64_t GetVersion() const { return version_; }
// Provide the model features required for evaluation of the |model_| by // Provide the model features required for evaluation of the |model_| by
// |this|. // |this|.
base::flat_set<std::string> GetModelFeatures() const; const base::flat_set<std::string>& GetModelFeatures() const {
return model_features_;
}
protected: protected:
PredictionModel(std::unique_ptr<optimization_guide::proto::PredictionModel> explicit PredictionModel(const proto::PredictionModel& prediction_model);
prediction_model);
// The in-memory model used for prediction. // The in-memory model used for prediction.
std::unique_ptr<optimization_guide::proto::Model> model_; const proto::Model model_;
private: private:
// Determines if the |model_| is complete and can be successfully evaluated by // Determines if the |model_| is complete and can be successfully evaluated by
// |this|. // |this|.
virtual bool ValidatePredictionModel() const = 0; virtual bool ValidatePredictionModel() const = 0;
// The information that describes the |model_|
std::unique_ptr<optimization_guide::proto::ModelInfo> model_info_;
// The set of features required by the |model_| to be evaluated. // The set of features required by the |model_| to be evaluated.
base::flat_set<std::string> model_features_; const base::flat_set<std::string> model_features_;
// The version of the |model_|. // The version of the |model_|.
int64_t version_; const int64_t version_;
SEQUENCE_CHECKER(sequence_checker_);
DISALLOW_COPY_AND_ASSIGN(PredictionModel); DISALLOW_COPY_AND_ASSIGN(PredictionModel);
}; };
......
...@@ -12,9 +12,8 @@ ...@@ -12,9 +12,8 @@
namespace optimization_guide { namespace optimization_guide {
TEST(PredictionModelTest, ValidPredictionModel) { TEST(PredictionModelTest, ValidPredictionModel) {
std::unique_ptr<proto::PredictionModel> prediction_model = proto::PredictionModel prediction_model;
std::make_unique<proto::PredictionModel>(); prediction_model.mutable_model()->mutable_threshold()->set_value(5.0);
prediction_model->mutable_model()->mutable_threshold()->set_value(5.0);
proto::DecisionTree decision_tree_model = proto::DecisionTree(); proto::DecisionTree decision_tree_model = proto::DecisionTree();
decision_tree_model.set_weight(2.0); decision_tree_model.set_weight(2.0);
...@@ -46,21 +45,20 @@ TEST(PredictionModelTest, ValidPredictionModel) { ...@@ -46,21 +45,20 @@ TEST(PredictionModelTest, ValidPredictionModel) {
tree_node->mutable_leaf()->mutable_vector()->add_value()->set_double_value( tree_node->mutable_leaf()->mutable_vector()->add_value()->set_double_value(
4.); 4.);
*prediction_model->mutable_model()->mutable_decision_tree() = *prediction_model.mutable_model()->mutable_decision_tree() =
decision_tree_model; decision_tree_model;
optimization_guide::proto::ModelInfo* model_info = proto::ModelInfo* model_info = prediction_model.mutable_model_info();
prediction_model->mutable_model_info();
model_info->set_version(1); model_info->set_version(1);
model_info->add_supported_model_types( model_info->add_supported_model_types(
optimization_guide::proto::ModelType::MODEL_TYPE_DECISION_TREE); proto::ModelType::MODEL_TYPE_DECISION_TREE);
model_info->add_supported_model_features( model_info->add_supported_model_features(
optimization_guide::proto::ClientModelFeature:: proto::ClientModelFeature::
CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE); CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE);
model_info->add_supported_host_model_features("agg1"); model_info->add_supported_host_model_features("agg1");
std::unique_ptr<PredictionModel> model = std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model)); PredictionModel::Create(prediction_model);
EXPECT_EQ(1, model->GetVersion()); EXPECT_EQ(1, model->GetVersion());
EXPECT_EQ(2u, model->GetModelFeatures().size()); EXPECT_EQ(2u, model->GetModelFeatures().size());
...@@ -70,37 +68,33 @@ TEST(PredictionModelTest, ValidPredictionModel) { ...@@ -70,37 +68,33 @@ TEST(PredictionModelTest, ValidPredictionModel) {
} }
TEST(PredictionModelTest, NoModel) { TEST(PredictionModelTest, NoModel) {
std::unique_ptr<optimization_guide::proto::PredictionModel> prediction_model = proto::PredictionModel prediction_model;
std::make_unique<optimization_guide::proto::PredictionModel>();
std::unique_ptr<PredictionModel> model = std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model)); PredictionModel::Create(prediction_model);
EXPECT_FALSE(model); EXPECT_FALSE(model);
} }
TEST(PredictionModelTest, NoModelVersion) { TEST(PredictionModelTest, NoModelVersion) {
std::unique_ptr<optimization_guide::proto::PredictionModel> prediction_model = proto::PredictionModel prediction_model;
std::make_unique<optimization_guide::proto::PredictionModel>();
optimization_guide::proto::DecisionTree* decision_tree_model = proto::DecisionTree* decision_tree_model =
prediction_model->mutable_model()->mutable_decision_tree(); prediction_model.mutable_model()->mutable_decision_tree();
decision_tree_model->set_weight(2.0); decision_tree_model->set_weight(2.0);
std::unique_ptr<PredictionModel> model = std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model)); PredictionModel::Create(prediction_model);
EXPECT_FALSE(model); EXPECT_FALSE(model);
} }
TEST(PredictionModelTest, NoModelType) { TEST(PredictionModelTest, NoModelType) {
std::unique_ptr<optimization_guide::proto::PredictionModel> prediction_model = proto::PredictionModel prediction_model;
std::make_unique<optimization_guide::proto::PredictionModel>();
optimization_guide::proto::DecisionTree* decision_tree_model = proto::DecisionTree* decision_tree_model =
prediction_model->mutable_model()->mutable_decision_tree(); prediction_model.mutable_model()->mutable_decision_tree();
decision_tree_model->set_weight(2.0); decision_tree_model->set_weight(2.0);
optimization_guide::proto::ModelInfo* model_info = proto::ModelInfo* model_info = prediction_model.mutable_model_info();
prediction_model->mutable_model_info();
model_info->set_version(1); model_info->set_version(1);
std::unique_ptr<PredictionModel> model = std::unique_ptr<PredictionModel> model =
...@@ -109,65 +103,56 @@ TEST(PredictionModelTest, NoModelType) { ...@@ -109,65 +103,56 @@ TEST(PredictionModelTest, NoModelType) {
} }
TEST(PredictionModelTest, UnknownModelType) { TEST(PredictionModelTest, UnknownModelType) {
std::unique_ptr<optimization_guide::proto::PredictionModel> prediction_model = proto::PredictionModel prediction_model;
std::make_unique<optimization_guide::proto::PredictionModel>();
optimization_guide::proto::DecisionTree* decision_tree_model = proto::DecisionTree* decision_tree_model =
prediction_model->mutable_model()->mutable_decision_tree(); prediction_model.mutable_model()->mutable_decision_tree();
decision_tree_model->set_weight(2.0); decision_tree_model->set_weight(2.0);
optimization_guide::proto::ModelInfo* model_info = proto::ModelInfo* model_info = prediction_model.mutable_model_info();
prediction_model->mutable_model_info();
model_info->set_version(1); model_info->set_version(1);
model_info->add_supported_model_types( model_info->add_supported_model_types(proto::ModelType::MODEL_TYPE_UNKNOWN);
optimization_guide::proto::ModelType::MODEL_TYPE_UNKNOWN);
std::unique_ptr<PredictionModel> model = std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model)); PredictionModel::Create(prediction_model);
EXPECT_FALSE(model); EXPECT_FALSE(model);
} }
TEST(PredictionModelTest, MultipleModelTypes) { TEST(PredictionModelTest, MultipleModelTypes) {
std::unique_ptr<optimization_guide::proto::PredictionModel> prediction_model = proto::PredictionModel prediction_model;
std::make_unique<optimization_guide::proto::PredictionModel>();
optimization_guide::proto::DecisionTree* decision_tree_model = proto::DecisionTree* decision_tree_model =
prediction_model->mutable_model()->mutable_decision_tree(); prediction_model.mutable_model()->mutable_decision_tree();
decision_tree_model->set_weight(2.0); decision_tree_model->set_weight(2.0);
optimization_guide::proto::ModelInfo* model_info = proto::ModelInfo* model_info = prediction_model.mutable_model_info();
prediction_model->mutable_model_info();
model_info->set_version(1); model_info->set_version(1);
model_info->add_supported_model_types( model_info->add_supported_model_types(
optimization_guide::proto::ModelType::MODEL_TYPE_DECISION_TREE); proto::ModelType::MODEL_TYPE_DECISION_TREE);
model_info->add_supported_model_types( model_info->add_supported_model_types(proto::ModelType::MODEL_TYPE_UNKNOWN);
optimization_guide::proto::ModelType::MODEL_TYPE_UNKNOWN);
std::unique_ptr<PredictionModel> model = std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model)); PredictionModel::Create(prediction_model);
EXPECT_FALSE(model); EXPECT_FALSE(model);
} }
TEST(PredictionModelTest, UnknownModelClientFeature) { TEST(PredictionModelTest, UnknownModelClientFeature) {
std::unique_ptr<optimization_guide::proto::PredictionModel> prediction_model = proto::PredictionModel prediction_model;
std::make_unique<optimization_guide::proto::PredictionModel>();
optimization_guide::proto::DecisionTree* decision_tree_model = proto::DecisionTree* decision_tree_model =
prediction_model->mutable_model()->mutable_decision_tree(); prediction_model.mutable_model()->mutable_decision_tree();
decision_tree_model->set_weight(2.0); decision_tree_model->set_weight(2.0);
optimization_guide::proto::ModelInfo* model_info = proto::ModelInfo* model_info = prediction_model.mutable_model_info();
prediction_model->mutable_model_info();
model_info->set_version(1); model_info->set_version(1);
model_info->add_supported_model_types( model_info->add_supported_model_types(
optimization_guide::proto::ModelType::MODEL_TYPE_DECISION_TREE); proto::ModelType::MODEL_TYPE_DECISION_TREE);
model_info->add_supported_model_features( model_info->add_supported_model_features(
optimization_guide::proto::ClientModelFeature:: proto::ClientModelFeature::CLIENT_MODEL_FEATURE_UNKNOWN);
CLIENT_MODEL_FEATURE_UNKNOWN);
std::unique_ptr<PredictionModel> model = std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model)); PredictionModel::Create(prediction_model);
EXPECT_FALSE(model); EXPECT_FALSE(model);
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment