Commit 86cb6592 authored by juliatuttle's avatar juliatuttle Committed by Commit bot

Reporting: Cap number of clients in cache.

BUG=704259

Review-Url: https://codereview.chromium.org/2851603002
Cr-Commit-Position: refs/heads/master@{#471071}
parent 0d001e69
......@@ -12,6 +12,7 @@
#include "base/memory/ptr_util.h"
#include "base/stl_util.h"
#include "base/time/tick_clock.h"
#include "base/time/time.h"
#include "net/reporting/reporting_client.h"
#include "net/reporting/reporting_context.h"
......@@ -191,56 +192,77 @@ void ReportingCache::SetClient(const url::Origin& origin,
base::TimeTicks expires) {
DCHECK(endpoint.SchemeIsCryptographic());
// Since |subdomains| may differ from a previous call to SetClient for this
// origin and endpoint, the cache needs to remove and re-add the client to the
// index of wildcard clients, if applicable.
if (base::ContainsKey(clients_, origin) &&
base::ContainsKey(clients_[origin], endpoint)) {
MaybeRemoveWildcardClient(clients_[origin][endpoint].get());
}
base::TimeTicks last_used = tick_clock()->NowTicks();
clients_[origin][endpoint] = base::MakeUnique<ReportingClient>(
origin, endpoint, subdomains, group, expires);
const ReportingClient* old_client =
GetClientByOriginAndEndpoint(origin, endpoint);
if (old_client) {
last_used = client_last_used_[old_client];
RemoveClient(old_client);
}
MaybeAddWildcardClient(clients_[origin][endpoint].get());
AddClient(base::MakeUnique<ReportingClient>(origin, endpoint, subdomains,
group, expires),
last_used);
if (client_last_used_.size() > context_->policy().max_client_count) {
// There should only ever be one extra client, added above.
DCHECK_EQ(context_->policy().max_client_count + 1,
client_last_used_.size());
// And that shouldn't happen if it was replaced, not added.
DCHECK(!old_client);
const ReportingClient* to_evict =
FindClientToEvict(tick_clock()->NowTicks());
DCHECK(to_evict);
RemoveClient(to_evict);
}
context_->NotifyCacheUpdated();
}
void ReportingCache::MarkClientUsed(const url::Origin& origin,
const GURL& endpoint) {
const ReportingClient* client =
GetClientByOriginAndEndpoint(origin, endpoint);
DCHECK(client);
client_last_used_[client] = tick_clock()->NowTicks();
}
void ReportingCache::RemoveClients(
const std::vector<const ReportingClient*>& clients_to_remove) {
for (const ReportingClient* client : clients_to_remove) {
MaybeRemoveWildcardClient(client);
size_t erased = clients_[client->origin].erase(client->endpoint);
DCHECK_EQ(1u, erased);
}
for (const ReportingClient* client : clients_to_remove)
RemoveClient(client);
context_->NotifyCacheUpdated();
}
void ReportingCache::RemoveClientForOriginAndEndpoint(const url::Origin& origin,
const GURL& endpoint) {
MaybeRemoveWildcardClient(clients_[origin][endpoint].get());
size_t erased = clients_[origin].erase(endpoint);
DCHECK_EQ(1u, erased);
const ReportingClient* client =
GetClientByOriginAndEndpoint(origin, endpoint);
RemoveClient(client);
context_->NotifyCacheUpdated();
}
void ReportingCache::RemoveClientsForEndpoint(const GURL& endpoint) {
for (auto& origin_and_endpoints : clients_) {
if (base::ContainsKey(origin_and_endpoints.second, endpoint)) {
MaybeRemoveWildcardClient(origin_and_endpoints.second[endpoint].get());
origin_and_endpoints.second.erase(endpoint);
}
}
std::vector<const ReportingClient*> clients_to_remove;
context_->NotifyCacheUpdated();
for (auto& origin_and_endpoints : clients_)
if (base::ContainsKey(origin_and_endpoints.second, endpoint))
clients_to_remove.push_back(origin_and_endpoints.second[endpoint].get());
for (const ReportingClient* client : clients_to_remove)
RemoveClient(client);
if (!clients_to_remove.empty())
context_->NotifyCacheUpdated();
}
void ReportingCache::RemoveAllClients() {
clients_.clear();
wildcard_clients_.clear();
client_last_used_.clear();
context_->NotifyCacheUpdated();
}
......@@ -260,22 +282,68 @@ const ReportingReport* ReportingCache::FindReportToEvict() const {
return earliest_queued;
}
void ReportingCache::MaybeAddWildcardClient(const ReportingClient* client) {
if (client->subdomains != ReportingClient::Subdomains::INCLUDE)
return;
void ReportingCache::AddClient(std::unique_ptr<ReportingClient> client,
base::TimeTicks last_used) {
DCHECK(client);
const std::string& domain = client->origin.host();
auto inserted = wildcard_clients_[domain].insert(client);
DCHECK(inserted.second);
url::Origin origin = client->origin;
GURL endpoint = client->endpoint;
auto inserted_last_used =
client_last_used_.insert(std::make_pair(client.get(), last_used));
DCHECK(inserted_last_used.second);
if (client->subdomains == ReportingClient::Subdomains::INCLUDE) {
const std::string& domain = origin.host();
auto inserted_wildcard_client =
wildcard_clients_[domain].insert(client.get());
DCHECK(inserted_wildcard_client.second);
}
auto inserted_client =
clients_[origin].insert(std::make_pair(endpoint, std::move(client)));
DCHECK(inserted_client.second);
}
void ReportingCache::MaybeRemoveWildcardClient(const ReportingClient* client) {
if (client->subdomains != ReportingClient::Subdomains::INCLUDE)
return;
void ReportingCache::RemoveClient(const ReportingClient* client) {
DCHECK(client);
const std::string& domain = client->origin.host();
size_t erased = wildcard_clients_[domain].erase(client);
DCHECK_EQ(1u, erased);
url::Origin origin = client->origin;
GURL endpoint = client->endpoint;
if (client->subdomains == ReportingClient::Subdomains::INCLUDE) {
const std::string& domain = origin.host();
size_t erased_wildcard_client = wildcard_clients_[domain].erase(client);
DCHECK_EQ(1u, erased_wildcard_client);
if (wildcard_clients_[domain].empty()) {
size_t erased_wildcard_domain = wildcard_clients_.erase(domain);
DCHECK_EQ(1u, erased_wildcard_domain);
}
}
size_t erased_last_used = client_last_used_.erase(client);
DCHECK_EQ(1u, erased_last_used);
size_t erased_endpoint = clients_[origin].erase(endpoint);
DCHECK_EQ(1u, erased_endpoint);
if (clients_[origin].empty()) {
size_t erased_origin = clients_.erase(origin);
DCHECK_EQ(1u, erased_origin);
}
}
const ReportingClient* ReportingCache::GetClientByOriginAndEndpoint(
const url::Origin& origin,
const GURL& endpoint) const {
const auto& origin_it = clients_.find(origin);
if (origin_it == clients_.end())
return nullptr;
const auto& endpoint_it = origin_it->second.find(endpoint);
if (endpoint_it == origin_it->second.end())
return nullptr;
return endpoint_it->second.get();
}
void ReportingCache::GetWildcardClientsForDomainAndGroup(
......@@ -295,4 +363,37 @@ void ReportingCache::GetWildcardClientsForDomainAndGroup(
}
}
const ReportingClient* ReportingCache::FindClientToEvict(
base::TimeTicks now) const {
DCHECK(!client_last_used_.empty());
const ReportingClient* earliest_used = nullptr;
base::TimeTicks earliest_used_last_used;
const ReportingClient* earliest_expired = nullptr;
for (const auto& it : client_last_used_) {
const ReportingClient* client = it.first;
base::TimeTicks client_last_used = it.second;
if (earliest_used == nullptr ||
client_last_used < earliest_used_last_used) {
earliest_used = client;
earliest_used_last_used = client_last_used;
}
if (earliest_expired == nullptr ||
client->expires < earliest_expired->expires) {
earliest_expired = client;
}
}
// If there are expired clients, return the earliest-expired.
if (earliest_expired->expires < now)
return earliest_expired;
else
return earliest_used;
}
base::TickClock* ReportingCache::tick_clock() {
return context_->tick_clock();
}
} // namespace net
......@@ -22,6 +22,10 @@
#include "url/gurl.h"
#include "url/origin.h"
namespace base {
class TickClock;
} // namespace base
namespace net {
class ReportingContext;
......@@ -98,6 +102,8 @@ class NET_EXPORT ReportingCache {
const std::string& group,
base::TimeTicks expires);
void MarkClientUsed(const url::Origin& origin, const GURL& endpoint);
// Gets all of the clients in the cache, regardless of origin or group.
//
// (Clears any existing data in |*clients_out|.)
......@@ -161,27 +167,25 @@ class NET_EXPORT ReportingCache {
private:
const ReportingReport* FindReportToEvict() const;
void MaybeAddWildcardClient(const ReportingClient* client);
void AddClient(std::unique_ptr<ReportingClient> client,
base::TimeTicks last_used);
void MaybeRemoveWildcardClient(const ReportingClient* client);
void RemoveClient(const ReportingClient* client);
const ReportingClient* GetClientByOriginAndEndpoint(
const url::Origin& origin,
const GURL& endpoint) const;
void GetWildcardClientsForDomainAndGroup(
const std::string& domain,
const std::string& group,
std::vector<const ReportingClient*>* clients_out) const;
ReportingContext* context_;
const ReportingClient* FindClientToEvict(base::TimeTicks now) const;
// Owns all clients, keyed by origin, then endpoint URL.
// (These would be unordered_map, but neither url::Origin nor GURL has a hash
// function implemented.)
std::map<url::Origin, std::map<GURL, std::unique_ptr<ReportingClient>>>
clients_;
base::TickClock* tick_clock();
// References but does not own all clients with includeSubdomains set, keyed
// by domain name.
std::unordered_map<std::string, std::unordered_set<const ReportingClient*>>
wildcard_clients_;
ReportingContext* context_;
// Owns all reports, keyed by const raw pointer for easier lookup.
std::unordered_map<const ReportingReport*, std::unique_ptr<ReportingReport>>
......@@ -195,6 +199,20 @@ class NET_EXPORT ReportingCache {
// pending when the deletion was requested).
std::unordered_set<const ReportingReport*> doomed_reports_;
// Owns all clients, keyed by origin, then endpoint URL.
// (These would be unordered_map, but neither url::Origin nor GURL has a hash
// function implemented.)
std::map<url::Origin, std::map<GURL, std::unique_ptr<ReportingClient>>>
clients_;
// References but does not own all clients with includeSubdomains set, keyed
// by domain name.
std::unordered_map<std::string, std::unordered_set<const ReportingClient*>>
wildcard_clients_;
// The time that each client has last been used.
std::unordered_map<const ReportingClient*, base::TimeTicks> client_last_used_;
DISALLOW_COPY_AND_ASSIGN(ReportingCache);
};
......
......@@ -7,6 +7,7 @@
#include <string>
#include "base/memory/ptr_util.h"
#include "base/strings/stringprintf.h"
#include "base/test/simple_test_tick_clock.h"
#include "base/time/time.h"
#include "base/values.h"
......@@ -36,6 +37,11 @@ class TestReportingObserver : public ReportingObserver {
class ReportingCacheTest : public ReportingTestBase {
protected:
ReportingCacheTest() : ReportingTestBase() {
ReportingPolicy policy;
policy.max_report_count = 5;
policy.max_client_count = 5;
UsePolicy(policy);
context()->AddObserver(&observer_);
}
......@@ -49,6 +55,12 @@ class ReportingCacheTest : public ReportingTestBase {
return reports.size();
}
size_t client_count() {
std::vector<const ReportingClient*> clients;
cache()->GetClients(&clients);
return clients.size();
}
const GURL kUrl1_ = GURL("https://origin1/path");
const url::Origin kOrigin1_ = url::Origin(GURL("https://origin1/"));
const url::Origin kOrigin2_ = url::Origin(GURL("https://origin2/"));
......@@ -407,20 +419,22 @@ TEST_F(ReportingCacheTest, IncludeSubdomainsPreferMoreSpecificSuperdomain) {
EXPECT_EQ(kSuperOrigin, clients[0]->origin);
}
TEST_F(ReportingCacheTest, EvictOldest) {
ASSERT_LT(0u, policy().max_report_count);
ASSERT_GT(std::numeric_limits<size_t>::max(), policy().max_report_count);
TEST_F(ReportingCacheTest, EvictOldestReport) {
size_t max_report_count = policy().max_report_count;
ASSERT_LT(0u, max_report_count);
ASSERT_GT(std::numeric_limits<size_t>::max(), max_report_count);
base::TimeTicks earliest_queued = tick_clock()->NowTicks();
// Enqueue the maximum number of reports, spaced apart in time.
for (size_t i = 0; i < policy().max_report_count; ++i) {
for (size_t i = 0; i < max_report_count; ++i) {
cache()->AddReport(kUrl1_, kGroup1_, kType_,
base::MakeUnique<base::DictionaryValue>(),
tick_clock()->NowTicks(), 0);
tick_clock()->Advance(base::TimeDelta::FromMinutes(1));
}
EXPECT_EQ(policy().max_report_count, report_count());
EXPECT_EQ(max_report_count, report_count());
// Add one more report to force the cache to evict one.
cache()->AddReport(kUrl1_, kGroup1_, kType_,
......@@ -430,23 +444,25 @@ TEST_F(ReportingCacheTest, EvictOldest) {
// sure the report evicted was the earliest-queued one.
std::vector<const ReportingReport*> reports;
cache()->GetReports(&reports);
EXPECT_EQ(policy().max_report_count, reports.size());
EXPECT_EQ(max_report_count, reports.size());
for (const ReportingReport* report : reports)
EXPECT_NE(earliest_queued, report->queued);
}
TEST_F(ReportingCacheTest, DontEvictPendingReports) {
ASSERT_LT(0u, policy().max_report_count);
ASSERT_GT(std::numeric_limits<size_t>::max(), policy().max_report_count);
size_t max_report_count = policy().max_report_count;
ASSERT_LT(0u, max_report_count);
ASSERT_GT(std::numeric_limits<size_t>::max(), max_report_count);
// Enqueue the maximum number of reports, spaced apart in time.
for (size_t i = 0; i < policy().max_report_count; ++i) {
for (size_t i = 0; i < max_report_count; ++i) {
cache()->AddReport(kUrl1_, kGroup1_, kType_,
base::MakeUnique<base::DictionaryValue>(),
tick_clock()->NowTicks(), 0);
tick_clock()->Advance(base::TimeDelta::FromMinutes(1));
}
EXPECT_EQ(policy().max_report_count, report_count());
EXPECT_EQ(max_report_count, report_count());
// Mark all of the queued reports pending.
std::vector<const ReportingReport*> queued_reports;
......@@ -462,10 +478,65 @@ TEST_F(ReportingCacheTest, DontEvictPendingReports) {
// the new, non-pending one.
std::vector<const ReportingReport*> reports;
cache()->GetReports(&reports);
EXPECT_EQ(policy().max_report_count, reports.size());
EXPECT_EQ(max_report_count, reports.size());
for (const ReportingReport* report : reports)
EXPECT_TRUE(cache()->IsReportPendingForTesting(report));
}
GURL MakeEndpoint(size_t index) {
return GURL(base::StringPrintf("https://endpoint/%zd", index));
}
TEST_F(ReportingCacheTest, EvictLRUClient) {
size_t max_client_count = policy().max_client_count;
ASSERT_LT(0u, max_client_count);
ASSERT_GT(std::numeric_limits<size_t>::max(), max_client_count);
for (size_t i = 0; i < max_client_count; ++i) {
cache()->SetClient(kOrigin1_, MakeEndpoint(i),
ReportingClient::Subdomains::EXCLUDE, kGroup1_,
tomorrow());
}
EXPECT_EQ(max_client_count, client_count());
// Use clients in reverse order, so client (max_client_count - 1) is LRU.
for (size_t i = 1; i <= max_client_count; ++i) {
cache()->MarkClientUsed(kOrigin1_, MakeEndpoint(max_client_count - i));
tick_clock()->Advance(base::TimeDelta::FromSeconds(1));
}
// Add one more client, forcing the cache to evict the LRU.
cache()->SetClient(kOrigin1_, MakeEndpoint(max_client_count),
ReportingClient::Subdomains::EXCLUDE, kGroup1_,
tomorrow());
EXPECT_EQ(max_client_count, client_count());
EXPECT_FALSE(FindClientInCache(cache(), kOrigin1_,
MakeEndpoint(max_client_count - 1)));
}
TEST_F(ReportingCacheTest, EvictExpiredClient) {
size_t max_client_count = policy().max_client_count;
ASSERT_LT(0u, max_client_count);
ASSERT_GT(std::numeric_limits<size_t>::max(), max_client_count);
for (size_t i = 0; i < max_client_count; ++i) {
base::TimeTicks expires =
(i == max_client_count - 1) ? yesterday() : tomorrow();
cache()->SetClient(kOrigin1_, MakeEndpoint(i),
ReportingClient::Subdomains::EXCLUDE, kGroup1_, expires);
}
EXPECT_EQ(max_client_count, client_count());
// Add one more client, forcing the cache to evict the expired one.
cache()->SetClient(kOrigin1_, MakeEndpoint(max_client_count),
ReportingClient::Subdomains::EXCLUDE, kGroup1_,
tomorrow());
EXPECT_EQ(max_client_count, client_count());
EXPECT_FALSE(FindClientInCache(cache(), kOrigin1_,
MakeEndpoint(max_client_count - 1)));
}
} // namespace
} // namespace net
......@@ -135,6 +135,8 @@ class ReportingDeliveryAgentImpl : public ReportingDeliveryAgent,
continue;
}
cache()->MarkClientUsed(origin_group.first, endpoint_url);
endpoint_reports[endpoint_url].insert(
endpoint_reports[endpoint_url].end(), it.second.begin(),
it.second.end());
......
......@@ -10,6 +10,7 @@ namespace net {
ReportingPolicy::ReportingPolicy()
: max_report_count(100u),
max_client_count(1000u),
delivery_interval(base::TimeDelta::FromMinutes(1)),
persistence_interval(base::TimeDelta::FromMinutes(1)),
persist_reports_across_restarts(false),
......@@ -30,6 +31,7 @@ ReportingPolicy::ReportingPolicy()
ReportingPolicy::ReportingPolicy(const ReportingPolicy& other)
: max_report_count(other.max_report_count),
max_client_count(other.max_client_count),
delivery_interval(other.delivery_interval),
endpoint_backoff_policy(other.endpoint_backoff_policy),
persistence_interval(other.persistence_interval),
......
......@@ -21,6 +21,9 @@ struct NET_EXPORT ReportingPolicy {
// Maximum number of reports to queue before evicting the oldest.
size_t max_report_count;
// Maximum number of clients to remember before evicting least-recently-used.
size_t max_client_count;
// Minimum interval at which to attempt delivery of queued reports.
base::TimeDelta delivery_interval;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment