Commit 8f5bb31b authored by Michael Crouse's avatar Michael Crouse Committed by Commit Bot

Limit the number of host model features being requested/stored.

This change limits the number of hosts that are requested
by the prediction model fetcher from the remote
optimization guide service and is finch controllable. This will
improve the request latency and load on the remote optimization guide
service. It will also limit the total number of host model features
maintained on the client. Additionally, only hosts that are not already
in the host model features cache will be included in the request.

This change also implements a least-recently-used in-memory cache for
the host model features that the prediction manager maintains.
The maximum size is controllable via finch and the default is set
to hold 95% of all users' hosts without needing to evict any host model
features.

Bug: 1036399
Change-Id: Ice9ab7c3c3505fc970ab72e1c78f5d628cd5157b
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/1981632
Commit-Queue: Michael Crouse <mcrouse@chromium.org>
Reviewed-by: Tarun Bansal <tbansal@chromium.org>
Reviewed-by: Sophie Chang <sophiechang@chromium.org>
Cr-Commit-Position: refs/heads/master@{#727838}
parent 732a75d5
......@@ -5,6 +5,8 @@
#include "chrome/browser/optimization_guide/optimization_guide_util.h"
#include "base/logging.h"
#include "net/base/url_util.h"
#include "url/url_canon.h"
std::string GetStringNameForOptimizationTarget(
optimization_guide::proto::OptimizationTarget optimization_target) {
......@@ -17,3 +19,15 @@ std::string GetStringNameForOptimizationTarget(
NOTREACHED();
return std::string();
}
bool IsHostValidToFetchFromRemoteOptimizationGuide(const std::string& host) {
if (net::HostStringIsLocalhost(host))
return false;
url::CanonHostInfo host_info;
std::string canonicalized_host(net::CanonicalizeHost(host, &host_info));
if (host_info.IsIPAddress() ||
!net::IsCanonicalizedHostCompliant(canonicalized_host)) {
return false;
}
return true;
}
......@@ -16,4 +16,8 @@
std::string GetStringNameForOptimizationTarget(
optimization_guide::proto::OptimizationTarget optimization_target);
// Returns false if the host is an IP address, localhost, or an invalid
// host that is not supported by the remote optimization guide.
bool IsHostValidToFetchFromRemoteOptimizationGuide(const std::string& host);
#endif // CHROME_BROWSER_OPTIMIZATION_GUIDE_OPTIMIZATION_GUIDE_UTIL_H_
......@@ -186,7 +186,9 @@ PredictionManager::PredictionManager(
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
PrefService* pref_service,
Profile* profile)
: session_fcp_(),
: host_model_features_cache_(
std::max(features::MaxHostModelFeaturesCacheSize(), size_t(1))),
session_fcp_(),
top_host_provider_(top_host_provider),
model_and_features_store_(std::move(model_and_features_store)),
url_loader_factory_(url_loader_factory),
......@@ -327,7 +329,7 @@ base::Optional<float> PredictionManager::GetValueForClientFeature(
base::flat_map<std::string, float> PredictionManager::BuildFeatureMap(
content::NavigationHandle* navigation_handle,
const base::flat_set<std::string>& model_features) const {
const base::flat_set<std::string>& model_features) {
SEQUENCE_CHECKER(sequence_checker_);
base::flat_map<std::string, float> feature_map;
if (model_features.size() == 0)
......@@ -336,8 +338,8 @@ base::flat_map<std::string, float> PredictionManager::BuildFeatureMap(
const base::flat_map<std::string, float>* host_model_features = nullptr;
std::string host = navigation_handle->GetURL().host();
auto it = host_model_features_map_.find(host);
if (it != host_model_features_map_.end())
auto it = host_model_features_cache_.Get(host);
if (it != host_model_features_cache_.end())
host_model_features = &(it->second);
UMA_HISTOGRAM_BOOLEAN(
......@@ -367,7 +369,7 @@ base::flat_map<std::string, float> PredictionManager::BuildFeatureMap(
OptimizationTargetDecision PredictionManager::ShouldTargetNavigation(
content::NavigationHandle* navigation_handle,
proto::OptimizationTarget optimization_target) const {
proto::OptimizationTarget optimization_target) {
SEQUENCE_CHECKER(sequence_checker_);
DCHECK(navigation_handle->GetURL().SchemeIsHTTPOrHTTPS());
......@@ -457,9 +459,9 @@ PredictionModel* PredictionManager::GetPredictionModelForTesting(
return nullptr;
}
base::flat_map<std::string, base::flat_map<std::string, float>>
const HostModelFeaturesMRUCache*
PredictionManager::GetHostModelFeaturesForTesting() const {
return host_model_features_map_;
return &host_model_features_cache_;
}
void PredictionManager::SetPredictionModelFetcherForTesting(
......@@ -486,6 +488,19 @@ void PredictionManager::FetchModelsAndHostModelFeatures() {
std::vector<std::string> top_hosts = top_host_provider_->GetTopHosts();
// Remove hosts that are already available in the host model features cache.
// The request should still be made in case there is a new model or a model
// that does not rely on host model features to be fetched.
auto it = top_hosts.begin();
while (it != top_hosts.end()) {
if (host_model_features_cache_.Peek(*it) !=
host_model_features_cache_.end()) {
it = top_hosts.erase(it);
continue;
}
++it;
}
if (!prediction_model_fetcher_) {
prediction_model_fetcher_ = std::make_unique<PredictionModelFetcher>(
url_loader_factory_,
......@@ -620,6 +635,9 @@ void PredictionManager::OnHostModelFeaturesStored() {
// Clear any data remaining in the stored get models response.
get_models_response_data_to_store_.reset();
// Purge any expired host model features from the store.
model_and_features_store_->PurgeExpiredHostModelFeatures();
// TODO(crbug/1027596): Stopping the timer can be removed once the fetch
// callback refactor is done. Otherwise, at the start of a fetch, a timer is
// running to handle the cases that a fetch fails but the callback is not run.
......@@ -668,7 +686,7 @@ void PredictionManager::OnLoadHostModelFeatures(
}
UMA_HISTOGRAM_COUNTS_1000(
"OptimizationGuide.PredictionManager.HostModelFeaturesMapSize",
host_model_features_map_.size());
host_model_features_cache_.size());
// Load the prediction models for all the registered optimization targets now
// that it is not blocked by loading the host model features.
......@@ -771,8 +789,8 @@ bool PredictionManager::ProcessAndStoreHostModelFeatures(
}
if (model_features_for_host.size() == 0)
return false;
host_model_features_map_[host_model_features.host()] =
model_features_for_host;
host_model_features_cache_.Put(host_model_features.host(),
model_features_for_host);
return true;
}
......@@ -831,9 +849,17 @@ void PredictionManager::SetClockForTesting(const base::Clock* clock) {
}
void PredictionManager::ClearHostModelFeatures() {
host_model_features_map_.clear();
host_model_features_cache_.Clear();
if (model_and_features_store_)
model_and_features_store_->ClearHostModelFeaturesFromDatabase();
}
// Returns a copy of the host model features cached for |host|, or
// base::nullopt when the cache holds no entry for it. The lookup uses Peek()
// and the method is const.
base::Optional<base::flat_map<std::string, float>>
PredictionManager::GetHostModelFeaturesForHost(const std::string& host) const {
  const auto cached_entry = host_model_features_cache_.Peek(host);
  if (cached_entry != host_model_features_cache_.end())
    return cached_entry->second;
  return base::nullopt;
}
} // namespace optimization_guide
......@@ -11,6 +11,7 @@
#include "base/containers/flat_map.h"
#include "base/containers/flat_set.h"
#include "base/containers/mru_cache.h"
#include "base/macros.h"
#include "base/memory/weak_ptr.h"
#include "base/sequence_checker.h"
......@@ -49,6 +50,9 @@ class PredictionModel;
class PredictionModelFetcher;
class TopHostProvider;
using HostModelFeaturesMRUCache =
base::HashingMRUCache<std::string, base::flat_map<std::string, float>>;
// A PredictionManager supported by the optimization guide that makes an
// OptimizationTargetDecision by evaluating the corresponding prediction model
// for an OptimizationTarget.
......@@ -87,7 +91,7 @@ class PredictionManager
// if model for the optimization target is not currently on the client.
OptimizationTargetDecision ShouldTargetNavigation(
content::NavigationHandle* navigation_handle,
proto::OptimizationTarget optimization_target) const;
proto::OptimizationTarget optimization_target);
// Update |session_fcp_| and |previous_fcp_| with |fcp|.
void UpdateFCPSessionStatistics(base::TimeDelta fcp);
......@@ -133,8 +137,11 @@ class PredictionManager
// Return the host model features for all hosts used by this
// PredictionManager for testing.
base::flat_map<std::string, base::flat_map<std::string, float>>
GetHostModelFeaturesForTesting() const;
const HostModelFeaturesMRUCache* GetHostModelFeaturesForTesting() const;
// Returns the host model features for a host if available.
base::Optional<base::flat_map<std::string, float>>
GetHostModelFeaturesForHost(const std::string& host) const;
// Return the set of features that each host in |host_model_features_cache_|
// contains for testing.
......@@ -166,10 +173,11 @@ class PredictionManager
optimization_targets_at_intialization);
// Construct and return a map containing the current feature values for the
// requested set of model features.
// requested set of model features. The host model features cache is updated
// based on if host model features were used.
base::flat_map<std::string, float> BuildFeatureMap(
content::NavigationHandle* navigation_handle,
const base::flat_set<std::string>& model_features) const;
const base::flat_set<std::string>& model_features);
// Calculate and return the current value for the client feature specified
// by |model_feature|. Return nullopt if the client does not support the
......@@ -273,12 +281,8 @@ class PredictionManager
// prediction manager.
base::flat_set<proto::OptimizationTarget> registered_optimization_targets_;
// A map of host to host model features known to the prediction manager.
//
// TODO(crbug/1001194): When loading features for the map, the size should be
// restricted.
base::flat_map<std::string, base::flat_map<std::string, float>>
host_model_features_map_;
// A MRU cache of host to host model features known to the prediction manager.
HostModelFeaturesMRUCache host_model_features_cache_;
// The current session's FCP statistics for HTTP/HTTPS navigations.
OptimizationGuideSessionStatistic session_fcp_;
......
......@@ -12,6 +12,7 @@
#include "base/feature_list.h"
#include "base/metrics/histogram_functions.h"
#include "base/metrics/histogram_macros.h"
#include "chrome/browser/optimization_guide/optimization_guide_util.h"
#include "components/optimization_guide/optimization_guide_features.h"
#include "components/optimization_guide/proto/models.pb.h"
#include "content/public/browser/network_service_instance.h"
......@@ -70,8 +71,19 @@ bool PredictionModelFetcher::FetchOptimizationGuideServiceModels(
pending_models_request_->set_request_context(request_context);
for (const auto& host : hosts)
// Limit the number of hosts to fetch features for, the list of hosts
// is assumed to be ordered from most to least important by the top
// host provider.
for (const auto& host : hosts) {
// Skip over localhosts, IP addresses, and invalid hosts.
if (!IsHostValidToFetchFromRemoteOptimizationGuide(host))
continue;
pending_models_request_->add_hosts(host);
if (static_cast<size_t>(pending_models_request_->hosts_size()) >=
features::MaxHostsForOptimizationGuideServiceModelsFetch()) {
break;
}
}
for (const auto& model_request_info : models_request_info)
*pending_models_request_->add_requested_models() = model_request_info;
......@@ -122,7 +134,7 @@ bool PredictionModelFetcher::FetchOptimizationGuideServiceModels(
UMA_HISTOGRAM_COUNTS_100(
"OptimizationGuide.PredictionModelFetcher."
"GetModelsRequest.HostCount",
hosts.size());
pending_models_request_->hosts_size());
// |url_loader_| should not retry on 5xx errors since the server may already
// be overloaded. |url_loader_| should retry on network changes since the
......
......@@ -3,6 +3,7 @@
// found in the LICENSE file.
#include <memory>
#include <string>
#include <vector>
#include "base/callback.h"
......@@ -10,10 +11,12 @@
#include "base/memory/scoped_refptr.h"
#include "base/optional.h"
#include "base/run_loop.h"
#include "base/strings/string_number_conversions.h"
#include "base/test/metrics/histogram_tester.h"
#include "base/test/scoped_feature_list.h"
#include "base/test/task_environment.h"
#include "chrome/browser/optimization_guide/prediction/prediction_model_fetcher.h"
#include "components/optimization_guide/optimization_guide_features.h"
#include "components/optimization_guide/proto/models.pb.h"
#include "net/base/url_util.h"
#include "services/network/public/cpp/shared_url_loader_factory.h"
......@@ -138,6 +141,59 @@ TEST_F(PredictionModelFetcherTest, FetchOptimizationGuideServiceModels) {
0, 1);
}
// Verifies that the number of hosts placed in a GetModels request is capped
// at MaxHostsForOptimizationGuideServiceModelsFetch() when more hosts than
// that are supplied.
TEST_F(PredictionModelFetcherTest,
       FetchOptimizationGuideServiceModelsLimitHosts) {
  base::HistogramTester histogram_tester;
  std::string response_content;

  // Supply two more hosts than the fetcher is allowed to request.
  const size_t max_hosts =
      features::MaxHostsForOptimizationGuideServiceModelsFetch();
  const size_t supplied_host_count = max_hosts + 2;
  std::vector<std::string> hosts;
  for (size_t index = 0; index < supplied_host_count; ++index) {
    hosts.push_back("host" + base::NumberToString(index) + ".com");
  }

  std::vector<optimization_guide::proto::ModelInfo> models_request_info;
  const bool fetch_started = FetchModels(
      models_request_info, hosts,
      optimization_guide::proto::RequestContext::CONTEXT_BATCH_UPDATE);
  EXPECT_TRUE(fetch_started);
  VerifyHasPendingFetchRequests();

  // Only the capped number of hosts is reported for the request.
  histogram_tester.ExpectUniqueSample(
      "OptimizationGuide.PredictionModelFetcher.GetModelsRequest.HostCount",
      max_hosts, 1);

  EXPECT_TRUE(SimulateResponse(response_content, net::HTTP_OK));
  EXPECT_TRUE(models_fetched());

  // No HostModelFeatures are returned.
  histogram_tester.ExpectUniqueSample(
      "OptimizationGuide.PredictionModelFetcher.GetModelsResponse."
      "HostModelFeatureCount",
      0, 1);
}
// Verifies that only one of the four supplied hosts (an IP address, "_abc",
// "localhost" and "foo.com") is counted in the GetModels request.
TEST_F(PredictionModelFetcherTest, FetchFilterInvalidHosts) {
  base::HistogramTester histogram_tester;
  std::string response_content;

  std::vector<std::string> hosts;
  hosts.push_back("192.168.1.1");
  hosts.push_back("_abc");
  hosts.push_back("localhost");
  hosts.push_back("foo.com");

  std::vector<optimization_guide::proto::ModelInfo> models_request_info;
  const bool fetch_started = FetchModels(
      models_request_info, hosts,
      optimization_guide::proto::RequestContext::CONTEXT_BATCH_UPDATE);
  EXPECT_TRUE(fetch_started);
  VerifyHasPendingFetchRequests();

  // Exactly one host is reported for the request.
  histogram_tester.ExpectUniqueSample(
      "OptimizationGuide.PredictionModelFetcher.GetModelsRequest.HostCount", 1,
      1);

  EXPECT_TRUE(SimulateResponse(response_content, net::HTTP_OK));
  EXPECT_TRUE(models_fetched());

  // No HostModelFeatures are returned.
  histogram_tester.ExpectUniqueSample(
      "OptimizationGuide.PredictionModelFetcher.GetModelsResponse."
      "HostModelFeatureCount",
      0, 1);
}
// Tests 404 response from request.
TEST_F(PredictionModelFetcherTest, FetchReturned404) {
base::HistogramTester histogram_tester;
......
......@@ -193,6 +193,17 @@ base::TimeDelta StoredHostModelFeaturesFreshnessDuration() {
"max_store_duration_for_host_model_features_in_days", 7));
}
// Returns the maximum number of hosts to include in a models fetch request,
// read from the "max_hosts_for_optimization_guide_service_models_fetch" field
// trial param of |kOptimizationTargetPrediction| (default 30).
//
// The param is read as a signed int but returned as size_t. A negative value
// would otherwise wrap to a very large unsigned number and effectively remove
// the limit, so negative values are treated as invalid and replaced with the
// default.
size_t MaxHostsForOptimizationGuideServiceModelsFetch() {
  constexpr int kDefaultMaxHosts = 30;
  const int max_hosts = GetFieldTrialParamByFeatureAsInt(
      kOptimizationTargetPrediction,
      "max_hosts_for_optimization_guide_service_models_fetch",
      kDefaultMaxHosts);
  return static_cast<size_t>(max_hosts < 0 ? kDefaultMaxHosts : max_hosts);
}
// Returns the maximum number of hosts whose model features may be cached,
// read from the "max_host_model_features_cache_size" field trial param of
// |kOptimizationTargetPrediction| (default 100).
//
// The param is read as a signed int but returned as size_t. A negative value
// would otherwise wrap to a very large unsigned number and make the cache
// effectively unbounded, so negative values are treated as invalid and
// replaced with the default.
size_t MaxHostModelFeaturesCacheSize() {
  constexpr int kDefaultMaxCacheSize = 100;
  const int max_cache_size = GetFieldTrialParamByFeatureAsInt(
      kOptimizationTargetPrediction, "max_host_model_features_cache_size",
      kDefaultMaxCacheSize);
  return static_cast<size_t>(max_cache_size < 0 ? kDefaultMaxCacheSize
                                                : max_cache_size);
}
bool IsOptimizationTargetPredictionEnabled() {
return base::FeatureList::IsEnabled(kOptimizationTargetPrediction);
}
......
......@@ -99,6 +99,14 @@ bool IsOptimizationTargetPredictionEnabled();
// to be used and remain in the OptimizationGuideStore.
base::TimeDelta StoredHostModelFeaturesFreshnessDuration();
// The maximum number of hosts allowed to be requested by the client to the
// remote Optimization Guide Service for use by prediction models.
size_t MaxHostsForOptimizationGuideServiceModelsFetch();
// The maximum number of hosts allowed to be maintained in a least-recently-used
// cache by the prediction manager.
size_t MaxHostModelFeaturesCacheSize();
// Returns true if the optimization target decision for |optimization_target|
// should not be propagated to the caller in an effort to fully understand the
// statistics for the served model and not taint the resulting data.
......
......@@ -88,6 +88,12 @@ bool DatabasePrefixFilter(const std::string& key_prefix,
return base::StartsWith(key, key_prefix, base::CompareCase::SENSITIVE);
}
// Key filter predicate: reports whether |key| is a member of
// |keys_to_remove|.
bool ExpiredKeyFilter(const base::flat_set<std::string>& keys_to_remove,
                      const std::string& key) {
  return keys_to_remove.count(key) != 0;
}
} // namespace
OptimizationGuideStore::OptimizationGuideStore(
......@@ -258,40 +264,50 @@ void OptimizationGuideStore::UpdateFetchedHints(
void OptimizationGuideStore::PurgeExpiredFetchedHints() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
if (!IsAvailable()) {
if (!IsAvailable())
return;
}
// Load all the fetched hints to check their expiry times.
database_->LoadKeysAndEntriesWithFilter(
base::BindRepeating(&DatabasePrefixFilter,
GetFetchedHintEntryKeyPrefix()),
base::BindOnce(&OptimizationGuideStore::OnLoadFetchedHintsToPurgeExpired,
base::BindOnce(&OptimizationGuideStore::OnLoadEntriesToPurgeExpired,
weak_ptr_factory_.GetWeakPtr()));
}
// Starts the purge of expired host model features. When the store is
// available, loads every entry whose key starts with the host model features
// prefix and hands the result to OnLoadEntriesToPurgeExpired(), which checks
// the expiry times. Does nothing when the store is unavailable.
void OptimizationGuideStore::PurgeExpiredHostModelFeatures() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);

  if (IsAvailable()) {
    // Restrict the load to host model features entries.
    auto host_model_features_filter = base::BindRepeating(
        &DatabasePrefixFilter, GetHostModelFeaturesEntryKeyPrefix());
    // Bound through a weak pointer to this store.
    auto on_entries_loaded =
        base::BindOnce(&OptimizationGuideStore::OnLoadEntriesToPurgeExpired,
                       weak_ptr_factory_.GetWeakPtr());
    database_->LoadKeysAndEntriesWithFilter(host_model_features_filter,
                                            std::move(on_entries_loaded));
  }
}
void OptimizationGuideStore::OnLoadFetchedHintsToPurgeExpired(
void OptimizationGuideStore::OnLoadEntriesToPurgeExpired(
bool success,
std::unique_ptr<EntryMap> fetched_entries) {
std::unique_ptr<EntryMap> entries) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
if (!success) {
if (!success)
return;
}
auto keys_to_remove = std::make_unique<EntryKeySet>();
EntryKeySet expired_keys_to_remove;
int64_t now_since_epoch =
base::Time::Now().ToDeltaSinceWindowsEpoch().InSeconds();
for (const auto& entry : *fetched_entries) {
if (entry.second.expiry_time_secs() <= now_since_epoch) {
keys_to_remove->insert(entry.first);
for (const auto& entry : *entries) {
if (entry.second.has_expiry_time_secs() &&
entry.second.expiry_time_secs() <= now_since_epoch) {
expired_keys_to_remove.insert(entry.first);
}
}
// TODO(mcrouse): Record the number of hints that will be expired from the
// store.
data_update_in_flight_ = true;
entry_keys_.reset();
......@@ -299,11 +315,7 @@ void OptimizationGuideStore::OnLoadFetchedHintsToPurgeExpired(
database_->UpdateEntriesWithRemoveFilter(
std::move(empty_entries),
base::BindRepeating(
[](EntryKeySet* keys_to_remove, const std::string& key) {
return keys_to_remove->find(key) != keys_to_remove->end();
},
keys_to_remove.get()),
base::BindRepeating(&ExpiredKeyFilter, std::move(expired_keys_to_remove)),
base::BindOnce(&OptimizationGuideStore::OnUpdateStore,
weak_ptr_factory_.GetWeakPtr(), base::DoNothing::Once()));
}
......
......@@ -170,6 +170,11 @@ class OptimizationGuideStore {
// removed.
void PurgeExpiredFetchedHints();
// Removes all host model features that have expired from the store.
// |entry_keys_| is updated after the expired host model features are
// removed.
void PurgeExpiredHostModelFeatures();
// Creates and returns a StoreUpdateData object for Prediction Models. This
// object is used to collect a batch of prediction models in a format that is
// usable to update the store on a background thread. This is always created
......@@ -344,11 +349,10 @@ class OptimizationGuideStore {
EntryKey* out_entry_key,
const EntryKeyPrefix& entry_key_prefix) const;
// Callback that identifies any expired hints from |fetched_entries| and
// Callback that identifies any expired |entries| and
// asynchronously removes them from the store.
void OnLoadFetchedHintsToPurgeExpired(
bool success,
std::unique_ptr<EntryMap> fetched_entries);
void OnLoadEntriesToPurgeExpired(bool success,
std::unique_ptr<EntryMap> entries);
// Callback that runs after the database finishes being initialized. If
// |purge_existing_data| is true, then unconditionally purges the database;
......
......@@ -314,7 +314,18 @@ class OptimizationGuideStoreTest : public testing::Test {
void PurgeExpiredFetchedHints() {
guide_store()->PurgeExpiredFetchedHints();
// OnFetchedHintsLoadedToMaybePurge
// OnLoadExpiredEntriesToPurge
db()->LoadCallback(true);
// OnUpdateStore
db()->UpdateCallback(true);
// OnLoadEntryKeys callback
db()->LoadCallback(true);
}
void PurgeExpiredHostModelFeatures() {
guide_store()->PurgeExpiredHostModelFeatures();
// OnLoadExpiredEntriesToPurge
db()->LoadCallback(true);
// OnUpdateStore
db()->UpdateCallback(true);
......@@ -2019,4 +2030,41 @@ TEST_F(OptimizationGuideStoreTest, ClearHostModelFeatures) {
}
}
// Verifies that host model features stored with an expiry time in the past
// are removed from the store by PurgeExpiredHostModelFeatures().
TEST_F(OptimizationGuideStoreTest, PurgeExpiredHostModelFeatures) {
  const size_t update_host_model_features_count = 5;
  const MetadataSchemaState schema_state = MetadataSchemaState::kValid;
  // Now() is a static member, so call it on the type rather than through a
  // default-constructed temporary.
  const base::Time update_time = base::Time::Now();
  SeedInitialData(schema_state, 0, base::Time::Now());
  CreateDatabase();
  InitializeStore(schema_state);

  // The second argument is |update_time| minus the freshness duration, so it
  // lies in the past and the stored entries are expired when purged below.
  std::unique_ptr<StoreUpdateData> update_data =
      guide_store()->CreateUpdateDataForHostModelFeatures(
          update_time, update_time -
                           optimization_guide::features::
                               StoredHostModelFeaturesFreshnessDuration());
  ASSERT_TRUE(update_data);
  SeedHostModelFeaturesUpdateData(update_data.get(),
                                  update_host_model_features_count);
  UpdateHostModelFeatures(std::move(update_data));

  // Every seeded host must be present before the purge.
  for (size_t i = 0; i < update_host_model_features_count; ++i) {
    std::string host_suffix = GetHostSuffix(i);
    OptimizationGuideStore::EntryKey entry_key;
    EXPECT_TRUE(
        guide_store()->FindHostModelFeaturesEntryKey(host_suffix, &entry_key));
  }

  // Remove expired host model features from the opt. guide store.
  PurgeExpiredHostModelFeatures();

  // None of the seeded hosts may remain after the purge.
  for (size_t i = 0; i < update_host_model_features_count; ++i) {
    std::string host_suffix = GetHostSuffix(i);
    OptimizationGuideStore::EntryKey entry_key;
    EXPECT_FALSE(
        guide_store()->FindHostModelFeaturesEntryKey(host_suffix, &entry_key));
  }
}
} // namespace optimization_guide
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment