Commit da6a55db authored by Sophie Chang's avatar Sophie Chang Committed by Commit Bot

Skip over models that have download URL populated

Bug: 1146151
Change-Id: I11e43077984753aa8bbae25abd7841e2fa862500
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2530200Reviewed-by: default avatarMichael Crouse <mcrouse@chromium.org>
Commit-Queue: Sophie Chang <sophiechang@chromium.org>
Cr-Commit-Position: refs/heads/master@{#826008}
parent f61db86c
...@@ -824,7 +824,17 @@ void PredictionManager::UpdatePredictionModels( ...@@ -824,7 +824,17 @@ void PredictionManager::UpdatePredictionModels(
SEQUENCE_CHECKER(sequence_checker_); SEQUENCE_CHECKER(sequence_checker_);
std::unique_ptr<StoreUpdateData> prediction_model_update_data = std::unique_ptr<StoreUpdateData> prediction_model_update_data =
StoreUpdateData::CreatePredictionModelStoreUpdateData(); StoreUpdateData::CreatePredictionModelStoreUpdateData();
bool has_models_to_update = false;
for (const auto& model : prediction_models) { for (const auto& model : prediction_models) {
if (model.has_model() && !model.model().download_url().empty()) {
// Skip over models that have a download URL since they will be updated
// out-of-band.
// TODO(crbug/1146151): Download model from URL.
continue;
}
has_models_to_update = true;
// Storing the model regardless of whether the model is valid or not. Model // Storing the model regardless of whether the model is valid or not. Model
// will be removed from store if it fails to load. // will be removed from store if it fails to load.
prediction_model_update_data->CopyPredictionModelIntoUpdateData(model); prediction_model_update_data->CopyPredictionModelIntoUpdateData(model);
...@@ -835,10 +845,13 @@ void PredictionManager::UpdatePredictionModels( ...@@ -835,10 +845,13 @@ void PredictionManager::UpdatePredictionModels(
model.model_info().version()); model.model_info().version());
OnLoadPredictionModel(std::make_unique<proto::PredictionModel>(model)); OnLoadPredictionModel(std::make_unique<proto::PredictionModel>(model));
} }
model_and_features_store_->UpdatePredictionModels(
std::move(prediction_model_update_data), if (has_models_to_update) {
base::BindOnce(&PredictionManager::OnPredictionModelsStored, model_and_features_store_->UpdatePredictionModels(
ui_weak_ptr_factory_.GetWeakPtr())); std::move(prediction_model_update_data),
base::BindOnce(&PredictionManager::OnPredictionModelsStored,
ui_weak_ptr_factory_.GetWeakPtr()));
}
} }
void PredictionManager::OnPredictionModelsStored() { void PredictionManager::OnPredictionModelsStored() {
......
...@@ -56,7 +56,8 @@ constexpr int kUpdateFetchModelAndFeaturesTimeSecs = 24 * 60 * 60; // 24 hours. ...@@ -56,7 +56,8 @@ constexpr int kUpdateFetchModelAndFeaturesTimeSecs = 24 * 60 * 60; // 24 hours.
namespace optimization_guide { namespace optimization_guide {
std::unique_ptr<proto::PredictionModel> CreatePredictionModel() { std::unique_ptr<proto::PredictionModel> CreatePredictionModel(
bool output_model_as_download_url = false) {
std::unique_ptr<optimization_guide::proto::PredictionModel> prediction_model = std::unique_ptr<optimization_guide::proto::PredictionModel> prediction_model =
std::make_unique<optimization_guide::proto::PredictionModel>(); std::make_unique<optimization_guide::proto::PredictionModel>();
...@@ -71,13 +72,17 @@ std::unique_ptr<proto::PredictionModel> CreatePredictionModel() { ...@@ -71,13 +72,17 @@ std::unique_ptr<proto::PredictionModel> CreatePredictionModel() {
proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD); proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD);
model_info->add_supported_model_types( model_info->add_supported_model_types(
proto::ModelType::MODEL_TYPE_DECISION_TREE); proto::ModelType::MODEL_TYPE_DECISION_TREE);
prediction_model->mutable_model()->mutable_threshold()->set_value(5.0); if (output_model_as_download_url)
prediction_model->mutable_model()->set_download_url("someurl");
else
prediction_model->mutable_model()->mutable_threshold()->set_value(5.0);
return prediction_model; return prediction_model;
} }
std::unique_ptr<proto::GetModelsResponse> BuildGetModelsResponse( std::unique_ptr<proto::GetModelsResponse> BuildGetModelsResponse(
const std::vector<std::string>& hosts, const std::vector<std::string>& hosts,
const std::vector<proto::ClientModelFeature>& client_model_features) { const std::vector<proto::ClientModelFeature>& client_model_features,
bool output_model_as_download_url = false) {
std::unique_ptr<proto::GetModelsResponse> get_models_response = std::unique_ptr<proto::GetModelsResponse> get_models_response =
std::make_unique<proto::GetModelsResponse>(); std::make_unique<proto::GetModelsResponse>();
...@@ -92,7 +97,7 @@ std::unique_ptr<proto::GetModelsResponse> BuildGetModelsResponse( ...@@ -92,7 +97,7 @@ std::unique_ptr<proto::GetModelsResponse> BuildGetModelsResponse(
} }
std::unique_ptr<proto::PredictionModel> prediction_model = std::unique_ptr<proto::PredictionModel> prediction_model =
CreatePredictionModel(); CreatePredictionModel(output_model_as_download_url);
for (const auto& client_model_feature : client_model_features) { for (const auto& client_model_feature : client_model_features) {
prediction_model->mutable_model_info()->add_supported_model_features( prediction_model->mutable_model_info()->add_supported_model_features(
client_model_feature); client_model_feature);
...@@ -165,6 +170,7 @@ enum class PredictionModelFetcherEndState { ...@@ -165,6 +170,7 @@ enum class PredictionModelFetcherEndState {
kFetchFailed = 0, kFetchFailed = 0,
kFetchSuccessWithModelsAndHostsModelFeatures = 1, kFetchSuccessWithModelsAndHostsModelFeatures = 1,
kFetchSuccessWithEmptyResponse = 2, kFetchSuccessWithEmptyResponse = 2,
kFetchSuccessWithModelDownloadUrls = 3,
}; };
// A mock class implementation of PredictionModelFetcher. // A mock class implementation of PredictionModelFetcher.
...@@ -206,6 +212,12 @@ class TestPredictionModelFetcher : public PredictionModelFetcher { ...@@ -206,6 +212,12 @@ class TestPredictionModelFetcher : public PredictionModelFetcher {
.Run(BuildGetModelsResponse({} /* hosts */, .Run(BuildGetModelsResponse({} /* hosts */,
{} /* client model features */)); {} /* client model features */));
return true; return true;
case PredictionModelFetcherEndState::kFetchSuccessWithModelDownloadUrls:
models_fetched_ = true;
std::move(models_fetched_callback)
.Run(BuildGetModelsResponse(hosts, {},
/*output_model_as_download_url=*/true));
return true;
} }
return true; return true;
} }
...@@ -1130,6 +1142,42 @@ TEST_P(PredictionManagerMLServiceTest, ...@@ -1130,6 +1142,42 @@ TEST_P(PredictionManagerMLServiceTest,
} }
} }
TEST_P(PredictionManagerMLServiceTest, UpdateModelWithDownloadUrl) {
base::test::ScopedFeatureList scoped_feature_list;
if (UsingMLService()) {
scoped_feature_list.InitWithFeatures(
{optimization_guide::features::
kOptimizationTargetPredictionUsingMLService},
{});
SetLoadModelResult(machine_learning::mojom::LoadModelResult::kOk);
}
base::HistogramTester histogram_tester;
std::unique_ptr<content::MockNavigationHandle> navigation_handle =
CreateMockNavigationHandleWithOptimizationGuideWebContentsObserver(
GURL("https://foo.com"));
CreatePredictionManager({});
prediction_manager()->SetPredictionModelFetcherForTesting(
BuildTestPredictionModelFetcher(
PredictionModelFetcherEndState::kFetchSuccessWithModelDownloadUrls));
prediction_manager()->RegisterOptimizationTargets(
{proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD});
SetStoreInitialized();
EXPECT_TRUE(prediction_model_fetcher()->models_fetched());
models_and_features_store()->RunUpdateHostModelFeaturesCallback();
histogram_tester.ExpectUniqueSample(
"OptimizationGuide.PredictionManager.HostModelFeaturesStored", true, 1);
histogram_tester.ExpectTotalCount(
"OptimizationGuide.PredictionManager.PredictionModelsStored", 0);
// TODO(crbug/1146151): Update test to incorporate downloading of model.
}
TEST_P(PredictionManagerMLServiceTest, TEST_P(PredictionManagerMLServiceTest,
EvaluatePredictionModelPopulatesNavData) { EvaluatePredictionModelPopulatesNavData) {
base::test::ScopedFeatureList scoped_feature_list; base::test::ScopedFeatureList scoped_feature_list;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment