Commit a16b542b authored by Sophie Chang's avatar Sophie Chang Committed by Commit Bot

Delete model files when models are updated/removed in the store

Bug: 1146151
Change-Id: I713082856f769e17f01955f721abd1c4b1375da0
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2548707Reviewed-by: default avatarMichael Crouse <mcrouse@chromium.org>
Reviewed-by: default avatarRobert Ogden <robertogden@chromium.org>
Commit-Queue: Sophie Chang <sophiechang@chromium.org>
Cr-Commit-Position: refs/heads/master@{#829490}
parent 87f7d073
......@@ -280,9 +280,10 @@ class TestPredictionModelFetcher : public PredictionModelFetcher {
class TestOptimizationGuideStore : public OptimizationGuideStore {
public:
explicit TestOptimizationGuideStore(
std::unique_ptr<StoreEntryProtoDatabase> database)
: OptimizationGuideStore(std::move(database)) {}
TestOptimizationGuideStore(
std::unique_ptr<StoreEntryProtoDatabase> database,
scoped_refptr<base::SequencedTaskRunner> store_task_runner)
: OptimizationGuideStore(std::move(database), store_task_runner) {}
~TestOptimizationGuideStore() override = default;
......@@ -388,9 +389,10 @@ class TestPredictionManager : public PredictionManager {
TopHostProvider* top_host_provider,
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
PrefService* pref_service,
Profile* profile)
Profile* profile,
scoped_refptr<base::SequencedTaskRunner> task_runner)
: PredictionManager(optimization_targets_at_initialization,
CreateModelAndHostModelFeaturesStore(),
CreateModelAndHostModelFeaturesStore(task_runner),
top_host_provider,
url_loader_factory,
pref_service,
......@@ -416,13 +418,14 @@ class TestPredictionManager : public PredictionManager {
using PredictionManager::GetHostModelFeaturesForTesting;
using PredictionManager::GetPredictionModelForTesting;
std::unique_ptr<OptimizationGuideStore>
CreateModelAndHostModelFeaturesStore() {
std::unique_ptr<OptimizationGuideStore> CreateModelAndHostModelFeaturesStore(
scoped_refptr<base::SequencedTaskRunner> task_runner) {
// Setup the fake db and the class under test.
auto db = std::make_unique<FakeDB<StoreEntry>>(&db_store_);
std::unique_ptr<OptimizationGuideStore> model_and_features_store =
std::make_unique<TestOptimizationGuideStore>(std::move(db));
std::make_unique<TestOptimizationGuideStore>(std::move(db),
task_runner);
return model_and_features_store;
}
......@@ -478,7 +481,7 @@ class PredictionManagerTest
prediction_manager_ = std::make_unique<TestPredictionManager>(
optimization_targets_at_initialization, temp_dir(), db_provider_.get(),
top_host_provider_.get(), url_loader_factory_, pref_service_.get(),
&testing_profile_);
&testing_profile_, task_environment_.GetMainThreadTaskRunner());
prediction_manager_->SetClockForTesting(task_environment_.GetMockClock());
}
......@@ -490,7 +493,8 @@ class PredictionManagerTest
prediction_manager_ = std::make_unique<TestPredictionManager>(
optimization_targets_at_initialization, temp_dir(), db_provider_.get(),
nullptr, url_loader_factory_, pref_service_.get(), &testing_profile_);
nullptr, url_loader_factory_, pref_service_.get(), &testing_profile_,
task_environment_.GetMainThreadTaskRunner());
prediction_manager_->SetClockForTesting(task_environment_.GetMockClock());
}
......
......@@ -6,6 +6,7 @@
#include "base/bind.h"
#include "base/callback_helpers.h"
#include "base/files/file_util.h"
#include "base/metrics/histogram_functions.h"
#include "base/metrics/histogram_macros.h"
#include "base/sequence_checker.h"
......@@ -17,6 +18,7 @@
#include "components/leveldb_proto/public/shared_proto_database_client_list.h"
#include "components/optimization_guide/memory_hint.h"
#include "components/optimization_guide/optimization_guide_prefs.h"
#include "components/optimization_guide/optimization_guide_util.h"
#include "components/optimization_guide/proto/hint_cache.pb.h"
namespace optimization_guide {
......@@ -88,10 +90,10 @@ bool DatabasePrefixFilter(const std::string& key_prefix,
return base::StartsWith(key, key_prefix, base::CompareCase::SENSITIVE);
}
// Returns true if |key| is in |keys_to_remove|.
bool ExpiredKeyFilter(const base::flat_set<std::string>& keys_to_remove,
const std::string& key) {
return keys_to_remove.find(key) != keys_to_remove.end();
// Returns true if |key| is in |key_set|.
bool KeySetFilter(const base::flat_set<std::string>& key_set,
const std::string& key) {
return key_set.find(key) != key_set.end();
}
} // namespace
......@@ -99,17 +101,19 @@ bool ExpiredKeyFilter(const base::flat_set<std::string>& keys_to_remove,
OptimizationGuideStore::OptimizationGuideStore(
leveldb_proto::ProtoDatabaseProvider* database_provider,
const base::FilePath& database_dir,
scoped_refptr<base::SequencedTaskRunner> store_task_runner) {
scoped_refptr<base::SequencedTaskRunner> store_task_runner)
: store_task_runner_(store_task_runner) {
database_ = database_provider->GetDB<proto::StoreEntry>(
leveldb_proto::ProtoDbType::HINT_CACHE_STORE, database_dir,
store_task_runner);
store_task_runner_);
RecordStatusChange(status_);
}
OptimizationGuideStore::OptimizationGuideStore(
std::unique_ptr<leveldb_proto::ProtoDatabase<proto::StoreEntry>> database)
: database_(std::move(database)) {
std::unique_ptr<leveldb_proto::ProtoDatabase<proto::StoreEntry>> database,
scoped_refptr<base::SequencedTaskRunner> store_task_runner)
: database_(std::move(database)), store_task_runner_(store_task_runner) {
RecordStatusChange(status_);
}
......@@ -300,11 +304,9 @@ void OptimizationGuideStore::OnLoadEntriesToPurgeExpired(
entry_keys_.reset();
auto empty_entries = std::make_unique<EntryVector>();
database_->UpdateEntriesWithRemoveFilter(
std::move(empty_entries),
base::BindRepeating(&ExpiredKeyFilter, std::move(expired_keys_to_remove)),
std::make_unique<EntryVector>(),
base::BindRepeating(&KeySetFilter, std::move(expired_keys_to_remove)),
base::BindOnce(&OptimizationGuideStore::OnUpdateStore,
weak_ptr_factory_.GetWeakPtr(), base::DoNothing::Once()));
}
......@@ -822,11 +824,48 @@ void OptimizationGuideStore::UpdatePredictionModels(
return;
}
std::unique_ptr<EntryVector> entry_vectors =
std::unique_ptr<EntryVector> entry_vector =
prediction_models_update_data->TakeUpdateEntries();
EntryKeySet keys_to_update;
for (const auto& entry : *entry_vector)
keys_to_update.insert(entry.first);
// Load the models that are to be updated and delete the old model file, if
// applicable.
database_->LoadKeysAndEntriesWithFilter(
base::BindRepeating(&KeySetFilter, std::move(keys_to_update)),
base::BindOnce(&OptimizationGuideStore::OnLoadModelsToBeUpdated,
weak_ptr_factory_.GetWeakPtr(), std::move(entry_vector),
std::make_unique<leveldb_proto::KeyVector>(),
std::move(callback)));
}
void OptimizationGuideStore::OnLoadModelsToBeUpdated(
std::unique_ptr<EntryVector> update_vector,
std::unique_ptr<leveldb_proto::KeyVector> remove_vector,
base::OnceClosure callback,
bool success,
std::unique_ptr<EntryMap> entries) {
if (!success) {
std::move(callback).Run();
return;
}
for (const auto& entry : *entries) {
// Delete models that are provided via file.
if (entry.second.has_prediction_model() &&
entry.second.prediction_model().model().has_download_url()) {
store_task_runner_->PostTask(
FROM_HERE, base::BindOnce(base::GetDeleteFileCallback(),
GetFilePathFromPredictionModel(
entry.second.prediction_model())
.value()));
}
}
database_->UpdateEntries(
std::move(entry_vectors), std::make_unique<leveldb_proto::KeyVector>(),
std::move(update_vector), std::move(remove_vector),
base::BindOnce(&OptimizationGuideStore::OnUpdateStore,
weak_ptr_factory_.GetWeakPtr(), std::move(callback)));
}
......@@ -857,20 +896,18 @@ bool OptimizationGuideStore::RemovePredictionModelFromEntryKey(
auto key_to_remove = std::make_unique<leveldb_proto::KeyVector>();
key_to_remove->push_back(entry_key);
database_->UpdateEntries(
std::make_unique<EntryVector>(), std::move(key_to_remove),
base::BindOnce(
&OptimizationGuideStore::OnRemovePredictionModelFromEntryKey,
weak_ptr_factory_.GetWeakPtr(), entry_key));
return true;
}
EntryKeySet key_set;
key_set.insert(entry_key);
// Load the model that is to be removed and delete the old model file, if
// applicable.
database_->LoadKeysAndEntriesWithFilter(
base::BindRepeating(&KeySetFilter, std::move(key_set)),
base::BindOnce(&OptimizationGuideStore::OnLoadModelsToBeUpdated,
weak_ptr_factory_.GetWeakPtr(),
std::make_unique<EntryVector>(), std::move(key_to_remove),
base::DoNothing::Once()));
void OptimizationGuideStore::OnRemovePredictionModelFromEntryKey(
const EntryKey& entry_key,
bool success) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
if (success)
entry_keys_->erase(entry_key);
return true;
}
void OptimizationGuideStore::LoadPredictionModel(
......@@ -899,9 +936,8 @@ void OptimizationGuideStore::OnLoadPredictionModel(
// request was started, then the loaded model should not be considered valid.
// Reset the entry so that nothing is returned to
// the requester.
if (!success || !IsAvailable()) {
if (!success || !IsAvailable())
entry.reset();
}
if (!entry || !entry->has_prediction_model()) {
std::unique_ptr<proto::PredictionModel> loaded_prediction_model(nullptr);
......@@ -911,7 +947,41 @@ void OptimizationGuideStore::OnLoadPredictionModel(
std::unique_ptr<proto::PredictionModel> loaded_prediction_model(
entry->release_prediction_model());
std::move(callback).Run(std::move(loaded_prediction_model));
if (!loaded_prediction_model->model().has_download_url()) {
std::move(callback).Run(std::move(loaded_prediction_model));
return;
}
// Make sure the path still exists before we send it back to the load
// initiator.
base::FilePath file_path =
GetFilePathFromPredictionModel(*loaded_prediction_model).value();
store_task_runner_->PostTaskAndReplyWithResult(
FROM_HERE, base::BindOnce(&base::PathExists, file_path),
base::BindOnce(&OptimizationGuideStore::OnModelFilePathVerified,
weak_ptr_factory_.GetWeakPtr(),
std::move(loaded_prediction_model), std::move(callback)));
}
void OptimizationGuideStore::OnModelFilePathVerified(
std::unique_ptr<proto::PredictionModel> loaded_model,
PredictionModelLoadedCallback callback,
bool success) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
if (success) {
std::move(callback).Run(std::move(loaded_model));
return;
}
// If the model no longer exists, remove the prediction model from the store.
DCHECK(loaded_model);
OptimizationGuideStore::EntryKey model_entry_key;
if (FindPredictionModelEntryKey(
loaded_model->model_info().optimization_target(), &model_entry_key)) {
RemovePredictionModelFromEntryKey(model_entry_key);
}
std::move(callback).Run(nullptr);
}
std::unique_ptr<StoreUpdateData>
......
......@@ -99,7 +99,8 @@ class OptimizationGuideStore {
scoped_refptr<base::SequencedTaskRunner> store_task_runner);
// For tests only.
explicit OptimizationGuideStore(
std::unique_ptr<StoreEntryProtoDatabase> database);
std::unique_ptr<StoreEntryProtoDatabase> database,
scoped_refptr<base::SequencedTaskRunner> store_task_runner);
virtual ~OptimizationGuideStore();
// Initializes the store. If |purge_existing_data| is set to true,
......@@ -415,11 +416,24 @@ class OptimizationGuideStore {
bool success,
std::unique_ptr<proto::StoreEntry> entry);
// Callback that runs after a removal attempt for the prediction model
// specified by |entry_key| with status |success|. It removes |entry_key| from
// |entry_keys_| if |success| is true, and no-op if false.
void OnRemovePredictionModelFromEntryKey(const EntryKey& entry_key,
bool success);
// Callback that runs when the prediction models that need to be updated and
// removed are loaded from the database. This will remove the files associated
// with those models and run the update routine with |update_vector| and
// |remove_vector| after that.
void OnLoadModelsToBeUpdated(
std::unique_ptr<EntryVector> update_vector,
std::unique_ptr<leveldb_proto::KeyVector> remove_vector,
base::OnceClosure callback,
bool success,
std::unique_ptr<EntryMap> entries);
// Callback that runs after the download URL in |loaded_model| has been
// verified. If |success| is false, the associated entry from |database_| will
// be removed and |callback| will run as if the model is not loaded.
void OnModelFilePathVerified(
std::unique_ptr<proto::PredictionModel> loaded_model,
PredictionModelLoadedCallback callback,
bool success);
// Callback that runs after a host model features entry is loaded from the
// database. If there's currently an in-flight update, then the data could be
......@@ -474,6 +488,9 @@ class OptimizationGuideStore {
// The keys of the entries available within the store.
std::unique_ptr<EntryKeySet> entry_keys_;
// The background task runner used to perform operations on the store.
scoped_refptr<base::SequencedTaskRunner> store_task_runner_;
SEQUENCE_CHECKER(sequence_checker_);
base::WeakPtrFactory<OptimizationGuideStore> weak_ptr_factory_{this};
......
......@@ -6,6 +6,8 @@
#include "base/containers/flat_set.h"
#include "base/notreached.h"
#include "base/strings/utf_string_conversions.h"
#include "build/build_config.h"
#include "components/optimization_guide/optimization_guide_features.h"
#include "components/variations/active_field_trials.h"
#include "net/base/url_util.h"
......@@ -96,4 +98,27 @@ GetActiveFieldTrialsAllowedForFetch() {
return filtered_active_field_trials;
}
base::Optional<base::FilePath> GetFilePathFromPredictionModel(
const proto::PredictionModel& model) {
if (!model.model().has_download_url())
return base::nullopt;
#if defined(OS_WIN)
return base::FilePath(base::UTF8ToWide(model.model().download_url()));
#else
return base::FilePath(model.model().download_url());
#endif
}
void SetFilePathInPredictionModel(const base::FilePath& file_path,
proto::PredictionModel* model) {
DCHECK(model);
#if defined(OS_WIN)
model->mutable_model()->set_download_url(base::WideToUTF8(file_path.value()));
#else
model->mutable_model()->set_download_url(file_path.value());
#endif
}
} // namespace optimization_guide
......@@ -7,6 +7,8 @@
#include <string>
#include "base/files/file_path.h"
#include "base/optional.h"
#include "components/optimization_guide/optimization_guide_decider.h"
#include "components/optimization_guide/optimization_guide_enums.h"
#include "components/optimization_guide/proto/common_types.pb.h"
......@@ -35,6 +37,15 @@ GetOptimizationGuideDecisionFromOptimizationTypeDecision(
google::protobuf::RepeatedPtrField<proto::FieldTrial>
GetActiveFieldTrialsAllowedForFetch();
// Returns the file path that holds the model file for |model|, if applicable.
base::Optional<base::FilePath> GetFilePathFromPredictionModel(
const proto::PredictionModel& model);
// Fills |model| with the path for which the corresponding model file can be
// found.
void SetFilePathInPredictionModel(const base::FilePath& file_path,
proto::PredictionModel* model);
} // namespace optimization_guide
#endif // COMPONENTS_OPTIMIZATION_GUIDE_OPTIMIZATION_GUIDE_UTIL_H_
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment