Commit 515136d0 authored by Matt Menke's avatar Matt Menke Committed by Commit Bot

Double-key ReportingEndpointManager.

Also turn the ReportingEndpointManager tests into unit tests, so that
this is testable without wiring up NetworkIsolationKey through all the
reporting code.

Bug: 993805
Change-Id: I0d511dd7eb5f1418bf0a29e853198c39b4b4ea97
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/1842131
Commit-Queue: Matt Menke <mmenke@chromium.org>
Reviewed-by: default avatarLily Chen <chlily@chromium.org>
Cr-Commit-Position: refs/heads/master@{#703107}
parent c06fd192
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
namespace net { namespace net {
class NetworkIsolationKey;
class ReportingContext; class ReportingContext;
// The cache holds undelivered reports and clients (per-origin endpoint // The cache holds undelivered reports and clients (per-origin endpoint
...@@ -189,6 +190,7 @@ class NET_EXPORT ReportingCache { ...@@ -189,6 +190,7 @@ class NET_EXPORT ReportingCache {
// name |group| with include_subdomains enabled, this method would return // name |group| with include_subdomains enabled, this method would return
// endpoints from that group from the earliest-inserted origin. // endpoints from that group from the earliest-inserted origin.
virtual std::vector<ReportingEndpoint> GetCandidateEndpointsForDelivery( virtual std::vector<ReportingEndpoint> GetCandidateEndpointsForDelivery(
const NetworkIsolationKey& network_isolation_key,
const url::Origin& origin, const url::Origin& origin,
const std::string& group_name) = 0; const std::string& group_name) = 0;
......
...@@ -492,12 +492,14 @@ void ReportingCacheImpl::AddClientsLoadedFromStore( ...@@ -492,12 +492,14 @@ void ReportingCacheImpl::AddClientsLoadedFromStore(
std::vector<ReportingEndpoint> std::vector<ReportingEndpoint>
ReportingCacheImpl::GetCandidateEndpointsForDelivery( ReportingCacheImpl::GetCandidateEndpointsForDelivery(
const NetworkIsolationKey& network_isolation_key,
const url::Origin& origin, const url::Origin& origin,
const std::string& group_name) { const std::string& group_name) {
base::Time now = clock().Now(); base::Time now = clock().Now();
SanityCheckClients(); SanityCheckClients();
// Look for an exact origin match for |origin| and |group|. // Look for an exact origin match for |origin| and |group|.
// TODO(mmenke): Respect NetworkIsolationKey.
EndpointGroupMap::iterator group_it = EndpointGroupMap::iterator group_it =
FindEndpointGroupIt(ReportingEndpointGroupKey(origin, group_name)); FindEndpointGroupIt(ReportingEndpointGroupKey(origin, group_name));
if (group_it != endpoint_groups_.end() && group_it->second.expires > now) { if (group_it != endpoint_groups_.end() && group_it->second.expires > now) {
......
...@@ -29,6 +29,8 @@ ...@@ -29,6 +29,8 @@
namespace net { namespace net {
class NetworkIsolationKey;
class ReportingCacheImpl : public ReportingCache { class ReportingCacheImpl : public ReportingCache {
public: public:
ReportingCacheImpl(ReportingContext* context); ReportingCacheImpl(ReportingContext* context);
...@@ -80,6 +82,7 @@ class ReportingCacheImpl : public ReportingCache { ...@@ -80,6 +82,7 @@ class ReportingCacheImpl : public ReportingCache {
std::vector<CachedReportingEndpointGroup> loaded_endpoint_groups) std::vector<CachedReportingEndpointGroup> loaded_endpoint_groups)
override; override;
std::vector<ReportingEndpoint> GetCandidateEndpointsForDelivery( std::vector<ReportingEndpoint> GetCandidateEndpointsForDelivery(
const NetworkIsolationKey& network_isolation_key,
const url::Origin& origin, const url::Origin& origin,
const std::string& group_name) override; const std::string& group_name) override;
base::Value GetClientsAsValue() const override; base::Value GetClientsAsValue() const override;
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "base/test/values_test_util.h" #include "base/test/values_test_util.h"
#include "base/time/time.h" #include "base/time/time.h"
#include "base/values.h" #include "base/values.h"
#include "net/base/network_isolation_key.h"
#include "net/reporting/mock_persistent_reporting_store.h" #include "net/reporting/mock_persistent_reporting_store.h"
#include "net/reporting/reporting_cache_impl.h" #include "net/reporting/reporting_cache_impl.h"
#include "net/reporting/reporting_cache_observer.h" #include "net/reporting/reporting_cache_observer.h"
...@@ -750,15 +751,16 @@ TEST_P(ReportingCacheTest, GetCandidateEndpointsForDelivery) { ...@@ -750,15 +751,16 @@ TEST_P(ReportingCacheTest, GetCandidateEndpointsForDelivery) {
ASSERT_TRUE(SetEndpointInCache(kOrigin2_, kGroup1_, kEndpoint1_, kExpires1_)); ASSERT_TRUE(SetEndpointInCache(kOrigin2_, kGroup1_, kEndpoint1_, kExpires1_));
ASSERT_TRUE(SetEndpointInCache(kOrigin2_, kGroup2_, kEndpoint2_, kExpires1_)); ASSERT_TRUE(SetEndpointInCache(kOrigin2_, kGroup2_, kEndpoint2_, kExpires1_));
std::vector<ReportingEndpoint> candidate_endpoints = std::vector<ReportingEndpoint> candidate_endpoints =
cache()->GetCandidateEndpointsForDelivery(kOrigin1_, kGroup1_); cache()->GetCandidateEndpointsForDelivery(NetworkIsolationKey(),
kOrigin1_, kGroup1_);
ASSERT_EQ(2u, candidate_endpoints.size()); ASSERT_EQ(2u, candidate_endpoints.size());
for (const ReportingEndpoint& endpoint : candidate_endpoints) { for (const ReportingEndpoint& endpoint : candidate_endpoints) {
EXPECT_EQ(kOrigin1_, endpoint.group_key.origin); EXPECT_EQ(kOrigin1_, endpoint.group_key.origin);
EXPECT_EQ(kGroup1_, endpoint.group_key.group_name); EXPECT_EQ(kGroup1_, endpoint.group_key.group_name);
} }
candidate_endpoints = candidate_endpoints = cache()->GetCandidateEndpointsForDelivery(
cache()->GetCandidateEndpointsForDelivery(kOrigin2_, kGroup1_); NetworkIsolationKey(), kOrigin2_, kGroup1_);
ASSERT_EQ(1u, candidate_endpoints.size()); ASSERT_EQ(1u, candidate_endpoints.size());
EXPECT_EQ(kOrigin2_, candidate_endpoints[0].group_key.origin); EXPECT_EQ(kOrigin2_, candidate_endpoints[0].group_key.origin);
EXPECT_EQ(kGroup1_, candidate_endpoints[0].group_key.group_name); EXPECT_EQ(kGroup1_, candidate_endpoints[0].group_key.group_name);
...@@ -777,15 +779,16 @@ TEST_P(ReportingCacheTest, GetCandidateEndpointsExcludesExpired) { ...@@ -777,15 +779,16 @@ TEST_P(ReportingCacheTest, GetCandidateEndpointsExcludesExpired) {
ASSERT_LT(clock()->Now(), kExpires2_); ASSERT_LT(clock()->Now(), kExpires2_);
std::vector<ReportingEndpoint> candidate_endpoints = std::vector<ReportingEndpoint> candidate_endpoints =
cache()->GetCandidateEndpointsForDelivery(kOrigin1_, kGroup1_); cache()->GetCandidateEndpointsForDelivery(NetworkIsolationKey(),
kOrigin1_, kGroup1_);
ASSERT_EQ(0u, candidate_endpoints.size()); ASSERT_EQ(0u, candidate_endpoints.size());
candidate_endpoints = candidate_endpoints = cache()->GetCandidateEndpointsForDelivery(
cache()->GetCandidateEndpointsForDelivery(kOrigin2_, kGroup1_); NetworkIsolationKey(), kOrigin2_, kGroup1_);
ASSERT_EQ(0u, candidate_endpoints.size()); ASSERT_EQ(0u, candidate_endpoints.size());
candidate_endpoints = candidate_endpoints = cache()->GetCandidateEndpointsForDelivery(
cache()->GetCandidateEndpointsForDelivery(kOrigin2_, kGroup2_); NetworkIsolationKey(), kOrigin2_, kGroup2_);
ASSERT_EQ(1u, candidate_endpoints.size()); ASSERT_EQ(1u, candidate_endpoints.size());
EXPECT_EQ(kEndpoint2_, candidate_endpoints[0].info.url); EXPECT_EQ(kEndpoint2_, candidate_endpoints[0].info.url);
} }
...@@ -801,7 +804,8 @@ TEST_P(ReportingCacheTest, ExcludeSubdomainsDifferentPort) { ...@@ -801,7 +804,8 @@ TEST_P(ReportingCacheTest, ExcludeSubdomainsDifferentPort) {
kExpires1_, OriginSubdomains::EXCLUDE)); kExpires1_, OriginSubdomains::EXCLUDE));
std::vector<ReportingEndpoint> candidate_endpoints = std::vector<ReportingEndpoint> candidate_endpoints =
cache()->GetCandidateEndpointsForDelivery(kOrigin, kGroup1_); cache()->GetCandidateEndpointsForDelivery(NetworkIsolationKey(), kOrigin,
kGroup1_);
ASSERT_EQ(0u, candidate_endpoints.size()); ASSERT_EQ(0u, candidate_endpoints.size());
} }
...@@ -816,7 +820,8 @@ TEST_P(ReportingCacheTest, ExcludeSubdomainsSuperdomain) { ...@@ -816,7 +820,8 @@ TEST_P(ReportingCacheTest, ExcludeSubdomainsSuperdomain) {
kExpires1_, OriginSubdomains::EXCLUDE)); kExpires1_, OriginSubdomains::EXCLUDE));
std::vector<ReportingEndpoint> candidate_endpoints = std::vector<ReportingEndpoint> candidate_endpoints =
cache()->GetCandidateEndpointsForDelivery(kOrigin, kGroup1_); cache()->GetCandidateEndpointsForDelivery(NetworkIsolationKey(), kOrigin,
kGroup1_);
ASSERT_EQ(0u, candidate_endpoints.size()); ASSERT_EQ(0u, candidate_endpoints.size());
} }
...@@ -831,7 +836,8 @@ TEST_P(ReportingCacheTest, IncludeSubdomainsDifferentPort) { ...@@ -831,7 +836,8 @@ TEST_P(ReportingCacheTest, IncludeSubdomainsDifferentPort) {
kExpires1_, OriginSubdomains::INCLUDE)); kExpires1_, OriginSubdomains::INCLUDE));
std::vector<ReportingEndpoint> candidate_endpoints = std::vector<ReportingEndpoint> candidate_endpoints =
cache()->GetCandidateEndpointsForDelivery(kOrigin, kGroup1_); cache()->GetCandidateEndpointsForDelivery(NetworkIsolationKey(), kOrigin,
kGroup1_);
ASSERT_EQ(1u, candidate_endpoints.size()); ASSERT_EQ(1u, candidate_endpoints.size());
EXPECT_EQ(kDifferentPortOrigin, candidate_endpoints[0].group_key.origin); EXPECT_EQ(kDifferentPortOrigin, candidate_endpoints[0].group_key.origin);
} }
...@@ -847,7 +853,8 @@ TEST_P(ReportingCacheTest, IncludeSubdomainsSuperdomain) { ...@@ -847,7 +853,8 @@ TEST_P(ReportingCacheTest, IncludeSubdomainsSuperdomain) {
kExpires1_, OriginSubdomains::INCLUDE)); kExpires1_, OriginSubdomains::INCLUDE));
std::vector<ReportingEndpoint> candidate_endpoints = std::vector<ReportingEndpoint> candidate_endpoints =
cache()->GetCandidateEndpointsForDelivery(kOrigin, kGroup1_); cache()->GetCandidateEndpointsForDelivery(NetworkIsolationKey(), kOrigin,
kGroup1_);
ASSERT_EQ(1u, candidate_endpoints.size()); ASSERT_EQ(1u, candidate_endpoints.size());
EXPECT_EQ(kSuperOrigin, candidate_endpoints[0].group_key.origin); EXPECT_EQ(kSuperOrigin, candidate_endpoints[0].group_key.origin);
} }
...@@ -865,7 +872,8 @@ TEST_P(ReportingCacheTest, IncludeSubdomainsPreferOriginToDifferentPort) { ...@@ -865,7 +872,8 @@ TEST_P(ReportingCacheTest, IncludeSubdomainsPreferOriginToDifferentPort) {
kExpires1_, OriginSubdomains::INCLUDE)); kExpires1_, OriginSubdomains::INCLUDE));
std::vector<ReportingEndpoint> candidate_endpoints = std::vector<ReportingEndpoint> candidate_endpoints =
cache()->GetCandidateEndpointsForDelivery(kOrigin, kGroup1_); cache()->GetCandidateEndpointsForDelivery(NetworkIsolationKey(), kOrigin,
kGroup1_);
ASSERT_EQ(1u, candidate_endpoints.size()); ASSERT_EQ(1u, candidate_endpoints.size());
EXPECT_EQ(kOrigin, candidate_endpoints[0].group_key.origin); EXPECT_EQ(kOrigin, candidate_endpoints[0].group_key.origin);
} }
...@@ -883,7 +891,8 @@ TEST_P(ReportingCacheTest, IncludeSubdomainsPreferOriginToSuperdomain) { ...@@ -883,7 +891,8 @@ TEST_P(ReportingCacheTest, IncludeSubdomainsPreferOriginToSuperdomain) {
kExpires1_, OriginSubdomains::INCLUDE)); kExpires1_, OriginSubdomains::INCLUDE));
std::vector<ReportingEndpoint> candidate_endpoints = std::vector<ReportingEndpoint> candidate_endpoints =
cache()->GetCandidateEndpointsForDelivery(kOrigin, kGroup1_); cache()->GetCandidateEndpointsForDelivery(NetworkIsolationKey(), kOrigin,
kGroup1_);
ASSERT_EQ(1u, candidate_endpoints.size()); ASSERT_EQ(1u, candidate_endpoints.size());
EXPECT_EQ(kOrigin, candidate_endpoints[0].group_key.origin); EXPECT_EQ(kOrigin, candidate_endpoints[0].group_key.origin);
} }
...@@ -904,7 +913,8 @@ TEST_P(ReportingCacheTest, IncludeSubdomainsPreferMoreSpecificSuperdomain) { ...@@ -904,7 +913,8 @@ TEST_P(ReportingCacheTest, IncludeSubdomainsPreferMoreSpecificSuperdomain) {
kExpires1_, OriginSubdomains::INCLUDE)); kExpires1_, OriginSubdomains::INCLUDE));
std::vector<ReportingEndpoint> candidate_endpoints = std::vector<ReportingEndpoint> candidate_endpoints =
cache()->GetCandidateEndpointsForDelivery(kOrigin, kGroup1_); cache()->GetCandidateEndpointsForDelivery(NetworkIsolationKey(), kOrigin,
kGroup1_);
ASSERT_EQ(1u, candidate_endpoints.size()); ASSERT_EQ(1u, candidate_endpoints.size());
EXPECT_EQ(kSuperOrigin, candidate_endpoints[0].group_key.origin); EXPECT_EQ(kSuperOrigin, candidate_endpoints[0].group_key.origin);
} }
...@@ -1005,7 +1015,8 @@ TEST_P(ReportingCacheTest, EvictExpiredGroups) { ...@@ -1005,7 +1015,8 @@ TEST_P(ReportingCacheTest, EvictExpiredGroups) {
// Make the group expired (but not stale). // Make the group expired (but not stale).
clock()->SetNow(kExpires1_ - base::TimeDelta::FromMinutes(1)); clock()->SetNow(kExpires1_ - base::TimeDelta::FromMinutes(1));
cache()->GetCandidateEndpointsForDelivery(kOrigin1_, kGroup1_); cache()->GetCandidateEndpointsForDelivery(NetworkIsolationKey(), kOrigin1_,
kGroup1_);
clock()->SetNow(kExpires1_ + base::TimeDelta::FromMinutes(1)); clock()->SetNow(kExpires1_ + base::TimeDelta::FromMinutes(1));
// Insert one more endpoint in a different group (not expired); eviction // Insert one more endpoint in a different group (not expired); eviction
...@@ -1054,7 +1065,7 @@ TEST_P(ReportingCacheTest, EvictFromStalestGroup) { ...@@ -1054,7 +1065,7 @@ TEST_P(ReportingCacheTest, EvictFromStalestGroup) {
EXPECT_TRUE(EndpointGroupExistsInCache(kOrigin1_, base::NumberToString(i), EXPECT_TRUE(EndpointGroupExistsInCache(kOrigin1_, base::NumberToString(i),
OriginSubdomains::DEFAULT)); OriginSubdomains::DEFAULT));
// Mark group used. // Mark group used.
cache()->GetCandidateEndpointsForDelivery(kOrigin1_, cache()->GetCandidateEndpointsForDelivery(NetworkIsolationKey(), kOrigin1_,
base::NumberToString(i)); base::NumberToString(i));
clock()->Advance(base::TimeDelta::FromMinutes(1)); clock()->Advance(base::TimeDelta::FromMinutes(1));
} }
...@@ -1098,7 +1109,8 @@ TEST_P(ReportingCacheTest, EvictFromLargestGroup) { ...@@ -1098,7 +1109,8 @@ TEST_P(ReportingCacheTest, EvictFromLargestGroup) {
OriginSubdomains::DEFAULT)); OriginSubdomains::DEFAULT));
// Count the number of endpoints remaining in kGroup2_. // Count the number of endpoints remaining in kGroup2_.
std::vector<ReportingEndpoint> endpoints_in_group = std::vector<ReportingEndpoint> endpoints_in_group =
cache()->GetCandidateEndpointsForDelivery(kOrigin1_, kGroup2_); cache()->GetCandidateEndpointsForDelivery(NetworkIsolationKey(),
kOrigin1_, kGroup2_);
EXPECT_EQ(1u, endpoints_in_group.size()); EXPECT_EQ(1u, endpoints_in_group.size());
} }
......
...@@ -19,7 +19,6 @@ ...@@ -19,7 +19,6 @@
#include "net/reporting/reporting_cache_observer.h" #include "net/reporting/reporting_cache_observer.h"
#include "net/reporting/reporting_delegate.h" #include "net/reporting/reporting_delegate.h"
#include "net/reporting/reporting_delivery_agent.h" #include "net/reporting/reporting_delivery_agent.h"
#include "net/reporting/reporting_endpoint_manager.h"
#include "net/reporting/reporting_garbage_collector.h" #include "net/reporting/reporting_garbage_collector.h"
#include "net/reporting/reporting_network_change_observer.h" #include "net/reporting/reporting_network_change_observer.h"
#include "net/reporting/reporting_policy.h" #include "net/reporting/reporting_policy.h"
...@@ -104,8 +103,7 @@ ReportingContext::ReportingContext( ...@@ -104,8 +103,7 @@ ReportingContext::ReportingContext(
delegate_(std::move(delegate)), delegate_(std::move(delegate)),
cache_(ReportingCache::Create(this)), cache_(ReportingCache::Create(this)),
store_(store), store_(store),
endpoint_manager_(ReportingEndpointManager::Create(this, rand_callback)), delivery_agent_(ReportingDeliveryAgent::Create(this, rand_callback)),
delivery_agent_(ReportingDeliveryAgent::Create(this)),
garbage_collector_(ReportingGarbageCollector::Create(this)), garbage_collector_(ReportingGarbageCollector::Create(this)),
network_change_observer_(ReportingNetworkChangeObserver::Create(this)) {} network_change_observer_(ReportingNetworkChangeObserver::Create(this)) {}
......
...@@ -25,7 +25,6 @@ namespace net { ...@@ -25,7 +25,6 @@ namespace net {
class ReportingCacheObserver; class ReportingCacheObserver;
class ReportingDelegate; class ReportingDelegate;
class ReportingDeliveryAgent; class ReportingDeliveryAgent;
class ReportingEndpointManager;
class ReportingGarbageCollector; class ReportingGarbageCollector;
class ReportingNetworkChangeObserver; class ReportingNetworkChangeObserver;
class ReportingUploader; class ReportingUploader;
...@@ -51,9 +50,6 @@ class NET_EXPORT ReportingContext { ...@@ -51,9 +50,6 @@ class NET_EXPORT ReportingContext {
ReportingDelegate* delegate() { return delegate_.get(); } ReportingDelegate* delegate() { return delegate_.get(); }
ReportingCache* cache() { return cache_.get(); } ReportingCache* cache() { return cache_.get(); }
ReportingCache::PersistentReportingStore* store() { return store_; } ReportingCache::PersistentReportingStore* store() { return store_; }
ReportingEndpointManager* endpoint_manager() {
return endpoint_manager_.get();
}
ReportingDeliveryAgent* delivery_agent() { return delivery_agent_.get(); } ReportingDeliveryAgent* delivery_agent() { return delivery_agent_.get(); }
ReportingGarbageCollector* garbage_collector() { ReportingGarbageCollector* garbage_collector() {
return garbage_collector_.get(); return garbage_collector_.get();
...@@ -97,11 +93,8 @@ class NET_EXPORT ReportingContext { ...@@ -97,11 +93,8 @@ class NET_EXPORT ReportingContext {
ReportingCache::PersistentReportingStore* const store_; ReportingCache::PersistentReportingStore* const store_;
// |endpoint_manager_| must come after |tick_clock_| and |cache_|.
std::unique_ptr<ReportingEndpointManager> endpoint_manager_;
// |delivery_agent_| must come after |tick_clock_|, |delegate_|, |uploader_|, // |delivery_agent_| must come after |tick_clock_|, |delegate_|, |uploader_|,
// |cache_|, and |endpoint_manager_|. // and |cache_|.
std::unique_ptr<ReportingDeliveryAgent> delivery_agent_; std::unique_ptr<ReportingDeliveryAgent> delivery_agent_;
// |garbage_collector_| must come after |tick_clock_| and |cache_|. // |garbage_collector_| must come after |tick_clock_| and |cache_|.
......
...@@ -16,8 +16,10 @@ ...@@ -16,8 +16,10 @@
#include "base/time/tick_clock.h" #include "base/time/tick_clock.h"
#include "base/timer/timer.h" #include "base/timer/timer.h"
#include "base/values.h" #include "base/values.h"
#include "net/base/network_isolation_key.h"
#include "net/reporting/reporting_cache.h" #include "net/reporting/reporting_cache.h"
#include "net/reporting/reporting_cache_observer.h" #include "net/reporting/reporting_cache_observer.h"
#include "net/reporting/reporting_context.h"
#include "net/reporting/reporting_delegate.h" #include "net/reporting/reporting_delegate.h"
#include "net/reporting/reporting_endpoint_manager.h" #include "net/reporting/reporting_endpoint_manager.h"
#include "net/reporting/reporting_report.h" #include "net/reporting/reporting_report.h"
...@@ -54,8 +56,16 @@ void SerializeReports(const std::vector<const ReportingReport*>& reports, ...@@ -54,8 +56,16 @@ void SerializeReports(const std::vector<const ReportingReport*>& reports,
class ReportingDeliveryAgentImpl : public ReportingDeliveryAgent, class ReportingDeliveryAgentImpl : public ReportingDeliveryAgent,
public ReportingCacheObserver { public ReportingCacheObserver {
public: public:
ReportingDeliveryAgentImpl(ReportingContext* context) ReportingDeliveryAgentImpl(ReportingContext* context,
: context_(context), timer_(std::make_unique<base::OneShotTimer>()) { const RandIntCallback& rand_callback)
: context_(context),
timer_(std::make_unique<base::OneShotTimer>()),
endpoint_manager_(
ReportingEndpointManager::Create(&context->policy(),
&context->tick_clock(),
context->delegate(),
context->cache(),
rand_callback)) {
context_->AddCacheObserver(this); context_->AddCacheObserver(this);
} }
...@@ -171,8 +181,10 @@ class ReportingDeliveryAgentImpl : public ReportingDeliveryAgent, ...@@ -171,8 +181,10 @@ class ReportingDeliveryAgentImpl : public ReportingDeliveryAgent,
if (base::Contains(pending_origin_groups_, origin_group)) if (base::Contains(pending_origin_groups_, origin_group))
continue; continue;
// TODO(mmenke): Populate NetworkIsolationKey argument.
const ReportingEndpoint endpoint = const ReportingEndpoint endpoint =
endpoint_manager()->FindEndpointForDelivery(report_origin, group); endpoint_manager_->FindEndpointForDelivery(NetworkIsolationKey(),
report_origin, group);
if (!endpoint) { if (!endpoint) {
// TODO(chlily): Remove reports for which there are no valid // TODO(chlily): Remove reports for which there are no valid
// delivery endpoints. // delivery endpoints.
...@@ -242,10 +254,14 @@ class ReportingDeliveryAgentImpl : public ReportingDeliveryAgent, ...@@ -242,10 +254,14 @@ class ReportingDeliveryAgentImpl : public ReportingDeliveryAgent,
if (outcome == ReportingUploader::Outcome::SUCCESS) { if (outcome == ReportingUploader::Outcome::SUCCESS) {
cache()->RemoveReports(delivery->reports, cache()->RemoveReports(delivery->reports,
ReportingReport::Outcome::DELIVERED); ReportingReport::Outcome::DELIVERED);
endpoint_manager()->InformOfEndpointRequest(delivery->endpoint, true); // TODO(mmenke): Populate NetworkIsolationKey argument.
endpoint_manager_->InformOfEndpointRequest(NetworkIsolationKey(),
delivery->endpoint, true);
} else { } else {
cache()->IncrementReportsAttempts(delivery->reports); cache()->IncrementReportsAttempts(delivery->reports);
endpoint_manager()->InformOfEndpointRequest(delivery->endpoint, false); // TODO(mmenke): Populate NetworkIsolationKey argument.
endpoint_manager_->InformOfEndpointRequest(NetworkIsolationKey(),
delivery->endpoint, false);
} }
if (outcome == ReportingUploader::Outcome::REMOVE_ENDPOINT) if (outcome == ReportingUploader::Outcome::REMOVE_ENDPOINT)
...@@ -264,9 +280,6 @@ class ReportingDeliveryAgentImpl : public ReportingDeliveryAgent, ...@@ -264,9 +280,6 @@ class ReportingDeliveryAgentImpl : public ReportingDeliveryAgent,
ReportingDelegate* delegate() { return context_->delegate(); } ReportingDelegate* delegate() { return context_->delegate(); }
ReportingCache* cache() { return context_->cache(); } ReportingCache* cache() { return context_->cache(); }
ReportingUploader* uploader() { return context_->uploader(); } ReportingUploader* uploader() { return context_->uploader(); }
ReportingEndpointManager* endpoint_manager() {
return context_->endpoint_manager();
}
ReportingContext* context_; ReportingContext* context_;
...@@ -276,6 +289,8 @@ class ReportingDeliveryAgentImpl : public ReportingDeliveryAgent, ...@@ -276,6 +289,8 @@ class ReportingDeliveryAgentImpl : public ReportingDeliveryAgent,
// (Would be an unordered_set, but there's no hash on pair.) // (Would be an unordered_set, but there's no hash on pair.)
std::set<OriginGroup> pending_origin_groups_; std::set<OriginGroup> pending_origin_groups_;
std::unique_ptr<ReportingEndpointManager> endpoint_manager_;
base::WeakPtrFactory<ReportingDeliveryAgentImpl> weak_factory_{this}; base::WeakPtrFactory<ReportingDeliveryAgentImpl> weak_factory_{this};
DISALLOW_COPY_AND_ASSIGN(ReportingDeliveryAgentImpl); DISALLOW_COPY_AND_ASSIGN(ReportingDeliveryAgentImpl);
...@@ -285,8 +300,9 @@ class ReportingDeliveryAgentImpl : public ReportingDeliveryAgent, ...@@ -285,8 +300,9 @@ class ReportingDeliveryAgentImpl : public ReportingDeliveryAgent,
// static // static
std::unique_ptr<ReportingDeliveryAgent> ReportingDeliveryAgent::Create( std::unique_ptr<ReportingDeliveryAgent> ReportingDeliveryAgent::Create(
ReportingContext* context) { ReportingContext* context,
return std::make_unique<ReportingDeliveryAgentImpl>(context); const RandIntCallback& rand_callback) {
return std::make_unique<ReportingDeliveryAgentImpl>(context, rand_callback);
} }
ReportingDeliveryAgent::~ReportingDeliveryAgent() = default; ReportingDeliveryAgent::~ReportingDeliveryAgent() = default;
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include "base/macros.h" #include "base/macros.h"
#include "net/base/net_export.h" #include "net/base/net_export.h"
#include "net/base/rand_callback.h"
namespace base { namespace base {
class OneShotTimer; class OneShotTimer;
...@@ -51,7 +52,8 @@ class NET_EXPORT ReportingDeliveryAgent { ...@@ -51,7 +52,8 @@ class NET_EXPORT ReportingDeliveryAgent {
public: public:
// Creates a ReportingDeliveryAgent. |context| must outlive the agent. // Creates a ReportingDeliveryAgent. |context| must outlive the agent.
static std::unique_ptr<ReportingDeliveryAgent> Create( static std::unique_ptr<ReportingDeliveryAgent> Create(
ReportingContext* context); ReportingContext* context,
const RandIntCallback& rand_callback);
virtual ~ReportingDeliveryAgent(); virtual ~ReportingDeliveryAgent();
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <map> #include <map>
#include <set> #include <set>
#include <string> #include <string>
#include <utility>
#include <vector> #include <vector>
#include "base/logging.h" #include "base/logging.h"
...@@ -15,6 +16,7 @@ ...@@ -15,6 +16,7 @@
#include "base/stl_util.h" #include "base/stl_util.h"
#include "base/time/tick_clock.h" #include "base/time/tick_clock.h"
#include "net/base/backoff_entry.h" #include "net/base/backoff_entry.h"
#include "net/base/network_isolation_key.h"
#include "net/base/rand_callback.h" #include "net/base/rand_callback.h"
#include "net/reporting/reporting_cache.h" #include "net/reporting/reporting_cache.h"
#include "net/reporting/reporting_delegate.h" #include "net/reporting/reporting_delegate.h"
...@@ -29,19 +31,33 @@ namespace { ...@@ -29,19 +31,33 @@ namespace {
class ReportingEndpointManagerImpl : public ReportingEndpointManager { class ReportingEndpointManagerImpl : public ReportingEndpointManager {
public: public:
ReportingEndpointManagerImpl(ReportingContext* context, ReportingEndpointManagerImpl(const ReportingPolicy* policy,
const base::TickClock* tick_clock,
const ReportingDelegate* delegate,
ReportingCache* cache,
const RandIntCallback& rand_callback) const RandIntCallback& rand_callback)
: context_(context), rand_callback_(rand_callback) {} : policy_(policy),
tick_clock_(tick_clock),
delegate_(delegate),
cache_(cache),
rand_callback_(rand_callback) {
DCHECK(policy);
DCHECK(tick_clock);
DCHECK(delegate);
DCHECK(cache);
}
~ReportingEndpointManagerImpl() override = default; ~ReportingEndpointManagerImpl() override = default;
const ReportingEndpoint FindEndpointForDelivery( const ReportingEndpoint FindEndpointForDelivery(
const NetworkIsolationKey& network_isolation_key,
const url::Origin& origin, const url::Origin& origin,
const std::string& group) override { const std::string& group) override {
// Get unexpired endpoints that apply to a delivery to |origin| and |group|. // Get unexpired endpoints that apply to a delivery to |origin| and |group|.
// May have been configured by a superdomain of |origin|. // May have been configured by a superdomain of |origin|.
std::vector<ReportingEndpoint> endpoints = std::vector<ReportingEndpoint> endpoints =
cache()->GetCandidateEndpointsForDelivery(origin, group); cache_->GetCandidateEndpointsForDelivery(network_isolation_key, origin,
group);
// Highest-priority endpoint(s) that are not expired, failing, or // Highest-priority endpoint(s) that are not expired, failing, or
// forbidden for use by the ReportingDelegate. // forbidden for use by the ReportingDelegate.
...@@ -50,12 +66,15 @@ class ReportingEndpointManagerImpl : public ReportingEndpointManager { ...@@ -50,12 +66,15 @@ class ReportingEndpointManagerImpl : public ReportingEndpointManager {
int total_weight = 0; int total_weight = 0;
for (const ReportingEndpoint endpoint : endpoints) { for (const ReportingEndpoint endpoint : endpoints) {
if (base::Contains(endpoint_backoff_, endpoint.info.url) && auto endpoint_backoff_it = endpoint_backoff_.find(
endpoint_backoff_[endpoint.info.url]->ShouldRejectRequest()) { EndpointBackoffKey(network_isolation_key, endpoint.info.url));
if (endpoint_backoff_it != endpoint_backoff_.end() &&
endpoint_backoff_it->second->ShouldRejectRequest()) {
continue; continue;
} }
if (!delegate()->CanUseClient(endpoint.group_key.origin,
endpoint.info.url)) { if (!delegate_->CanUseClient(endpoint.group_key.origin,
endpoint.info.url)) {
continue; continue;
} }
...@@ -101,21 +120,29 @@ class ReportingEndpointManagerImpl : public ReportingEndpointManager { ...@@ -101,21 +120,29 @@ class ReportingEndpointManagerImpl : public ReportingEndpointManager {
return ReportingEndpoint(); return ReportingEndpoint();
} }
void InformOfEndpointRequest(const GURL& endpoint, bool succeeded) override { void InformOfEndpointRequest(const NetworkIsolationKey& network_isolation_key,
if (!base::Contains(endpoint_backoff_, endpoint)) { const GURL& endpoint,
endpoint_backoff_[endpoint] = std::make_unique<BackoffEntry>( bool succeeded) override {
&policy().endpoint_backoff_policy, &tick_clock()); EndpointBackoffKey endpoint_backoff_key(network_isolation_key, endpoint);
auto endpoint_backoff_it = endpoint_backoff_.find(endpoint_backoff_key);
if (endpoint_backoff_it == endpoint_backoff_.end()) {
endpoint_backoff_it =
endpoint_backoff_
.emplace(std::move(endpoint_backoff_key),
std::make_unique<BackoffEntry>(
&policy_->endpoint_backoff_policy, tick_clock_))
.first;
} }
endpoint_backoff_[endpoint]->InformOfRequest(succeeded); endpoint_backoff_it->second->InformOfRequest(succeeded);
} }
private: private:
const ReportingPolicy& policy() const { return context_->policy(); } using EndpointBackoffKey = std::pair<NetworkIsolationKey, GURL>;
const base::TickClock& tick_clock() const { return context_->tick_clock(); }
ReportingDelegate* delegate() { return context_->delegate(); }
ReportingCache* cache() { return context_->cache(); }
ReportingContext* context_; const ReportingPolicy* const policy_;
const base::TickClock* const tick_clock_;
const ReportingDelegate* const delegate_;
ReportingCache* const cache_;
RandIntCallback rand_callback_; RandIntCallback rand_callback_;
...@@ -124,7 +151,8 @@ class ReportingEndpointManagerImpl : public ReportingEndpointManager { ...@@ -124,7 +151,8 @@ class ReportingEndpointManagerImpl : public ReportingEndpointManager {
// to be cleared as well. // to be cleared as well.
// TODO(chlily): clear this data when endpoints are deleted to avoid unbounded // TODO(chlily): clear this data when endpoints are deleted to avoid unbounded
// growth of this map. // growth of this map.
std::map<GURL, std::unique_ptr<net::BackoffEntry>> endpoint_backoff_; std::map<EndpointBackoffKey, std::unique_ptr<net::BackoffEntry>>
endpoint_backoff_;
DISALLOW_COPY_AND_ASSIGN(ReportingEndpointManagerImpl); DISALLOW_COPY_AND_ASSIGN(ReportingEndpointManagerImpl);
}; };
...@@ -133,9 +161,13 @@ class ReportingEndpointManagerImpl : public ReportingEndpointManager { ...@@ -133,9 +161,13 @@ class ReportingEndpointManagerImpl : public ReportingEndpointManager {
// static // static
std::unique_ptr<ReportingEndpointManager> ReportingEndpointManager::Create( std::unique_ptr<ReportingEndpointManager> ReportingEndpointManager::Create(
ReportingContext* context, const ReportingPolicy* policy,
const base::TickClock* tick_clock,
const ReportingDelegate* delegate,
ReportingCache* cache,
const RandIntCallback& rand_callback) { const RandIntCallback& rand_callback) {
return std::make_unique<ReportingEndpointManagerImpl>(context, rand_callback); return std::make_unique<ReportingEndpointManagerImpl>(
policy, tick_clock, delegate, cache, rand_callback);
} }
ReportingEndpointManager::~ReportingEndpointManager() = default; ReportingEndpointManager::~ReportingEndpointManager() = default;
......
...@@ -11,26 +11,37 @@ ...@@ -11,26 +11,37 @@
#include "base/macros.h" #include "base/macros.h"
#include "net/base/net_export.h" #include "net/base/net_export.h"
#include "net/base/rand_callback.h" #include "net/base/rand_callback.h"
#include "net/reporting/reporting_context.h"
class GURL; class GURL;
namespace base {
class TickClock;
}
namespace url { namespace url {
class Origin; class Origin;
} // namespace url } // namespace url
namespace net { namespace net {
class NetworkIsolationKey;
class ReportingCache;
class ReportingDelegate;
struct ReportingEndpoint; struct ReportingEndpoint;
struct ReportingPolicy;
// Keeps track of which endpoints are pending (have active delivery attempts to // Keeps track of which endpoints are pending (have active delivery attempts to
// them) or in exponential backoff after one or more failures, and chooses an // them) or in exponential backoff after one or more failures, and chooses an
// endpoint from an endpoint group to receive reports for an origin. // endpoint from an endpoint group to receive reports for an origin.
class NET_EXPORT ReportingEndpointManager { class NET_EXPORT ReportingEndpointManager {
public: public:
// |context| must outlive the ReportingEndpointManager. // The ReportingEndpointManager must not be used after any of the objects
// passed to its constructor are destroyed.
static std::unique_ptr<ReportingEndpointManager> Create( static std::unique_ptr<ReportingEndpointManager> Create(
ReportingContext* context, const ReportingPolicy* policy,
const base::TickClock* tick_clock,
const ReportingDelegate* delegate,
ReportingCache* cache,
const RandIntCallback& rand_callback); const RandIntCallback& rand_callback);
virtual ~ReportingEndpointManager(); virtual ~ReportingEndpointManager();
...@@ -43,13 +54,16 @@ class NET_EXPORT ReportingEndpointManager { ...@@ -43,13 +54,16 @@ class NET_EXPORT ReportingEndpointManager {
// If no suitable endpoint was found, returns an endpoint with is_valid() // If no suitable endpoint was found, returns an endpoint with is_valid()
// false. // false.
virtual const ReportingEndpoint FindEndpointForDelivery( virtual const ReportingEndpoint FindEndpointForDelivery(
const NetworkIsolationKey& network_isolation_key,
const url::Origin& origin, const url::Origin& origin,
const std::string& group) = 0; const std::string& group) = 0;
// Informs the EndpointManager of a successful or unsuccessful request made to // Informs the EndpointManager of a successful or unsuccessful request made to
// |endpoint| so it can manage exponential backoff of failing endpoints. // |endpoint| so it can manage exponential backoff of failing endpoints.
virtual void InformOfEndpointRequest(const GURL& endpoint, virtual void InformOfEndpointRequest(
bool succeeded) = 0; const NetworkIsolationKey& network_isolation_key,
const GURL& endpoint,
bool succeeded) = 0;
}; };
} // namespace net } // namespace net
......
...@@ -16,7 +16,6 @@ ...@@ -16,7 +16,6 @@
#include "base/test/simple_test_clock.h" #include "base/test/simple_test_clock.h"
#include "base/test/simple_test_tick_clock.h" #include "base/test/simple_test_tick_clock.h"
#include "base/timer/mock_timer.h" #include "base/timer/mock_timer.h"
#include "net/base/rand_callback.h"
#include "net/reporting/reporting_cache.h" #include "net/reporting/reporting_cache.h"
#include "net/reporting/reporting_context.h" #include "net/reporting/reporting_context.h"
#include "net/reporting/reporting_delegate.h" #include "net/reporting/reporting_delegate.h"
...@@ -25,7 +24,6 @@ ...@@ -25,7 +24,6 @@
#include "net/reporting/reporting_garbage_collector.h" #include "net/reporting/reporting_garbage_collector.h"
#include "net/reporting/reporting_policy.h" #include "net/reporting/reporting_policy.h"
#include "net/reporting/reporting_uploader.h" #include "net/reporting/reporting_uploader.h"
#include "net/url_request/url_request_test_util.h"
#include "testing/gtest/include/gtest/gtest.h" #include "testing/gtest/include/gtest/gtest.h"
#include "url/gurl.h" #include "url/gurl.h"
#include "url/origin.h" #include "url/origin.h"
...@@ -85,6 +83,15 @@ void ErasePendingUpload( ...@@ -85,6 +83,15 @@ void ErasePendingUpload(
} // namespace } // namespace
RandIntCallback TestReportingRandIntCallback() {
return base::BindRepeating(
[](int* rand_counter, int min, int max) {
DCHECK_LE(min, max);
return min + ((*rand_counter)++ % (max - min + 1));
},
base::Owned(std::make_unique<int>(0)));
}
TestReportingUploader::PendingUpload::~PendingUpload() = default; TestReportingUploader::PendingUpload::~PendingUpload() = default;
TestReportingUploader::PendingUpload::PendingUpload() = default; TestReportingUploader::PendingUpload::PendingUpload() = default;
...@@ -109,8 +116,7 @@ int TestReportingUploader::GetPendingUploadCountForTesting() const { ...@@ -109,8 +116,7 @@ int TestReportingUploader::GetPendingUploadCountForTesting() const {
return pending_uploads_.size(); return pending_uploads_.size();
} }
TestReportingDelegate::TestReportingDelegate() TestReportingDelegate::TestReportingDelegate() = default;
: test_request_context_(std::make_unique<TestURLRequestContext>()) {}
TestReportingDelegate::~TestReportingDelegate() = default; TestReportingDelegate::~TestReportingDelegate() = default;
...@@ -157,16 +163,13 @@ TestReportingContext::TestReportingContext( ...@@ -157,16 +163,13 @@ TestReportingContext::TestReportingContext(
const base::TickClock* tick_clock, const base::TickClock* tick_clock,
const ReportingPolicy& policy, const ReportingPolicy& policy,
ReportingCache::PersistentReportingStore* store) ReportingCache::PersistentReportingStore* store)
: ReportingContext( : ReportingContext(policy,
policy, clock,
clock, tick_clock,
tick_clock, TestReportingRandIntCallback(),
base::BindRepeating(&TestReportingContext::RandIntCallback, std::make_unique<TestReportingUploader>(),
base::Unretained(this)), std::make_unique<TestReportingDelegate>(),
std::make_unique<TestReportingUploader>(), store),
std::make_unique<TestReportingDelegate>(),
store),
rand_counter_(0),
delivery_timer_(new base::MockOneShotTimer()), delivery_timer_(new base::MockOneShotTimer()),
garbage_collection_timer_(new base::MockOneShotTimer()) { garbage_collection_timer_(new base::MockOneShotTimer()) {
garbage_collector()->SetTimerForTesting( garbage_collector()->SetTimerForTesting(
...@@ -179,11 +182,6 @@ TestReportingContext::~TestReportingContext() { ...@@ -179,11 +182,6 @@ TestReportingContext::~TestReportingContext() {
garbage_collection_timer_ = nullptr; garbage_collection_timer_ = nullptr;
} }
int TestReportingContext::RandIntCallback(int min, int max) {
DCHECK_LE(min, max);
return min + (rand_counter_++ % (max - min + 1));
}
ReportingTestBase::ReportingTestBase() : store_(nullptr) { ReportingTestBase::ReportingTestBase() : store_(nullptr) {
// For tests, disable jitter. // For tests, disable jitter.
ReportingPolicy policy; ReportingPolicy policy;
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
#include "base/macros.h" #include "base/macros.h"
#include "base/test/simple_test_clock.h" #include "base/test/simple_test_clock.h"
#include "base/test/simple_test_tick_clock.h" #include "base/test/simple_test_tick_clock.h"
#include "net/base/rand_callback.h"
#include "net/reporting/reporting_cache.h" #include "net/reporting/reporting_cache.h"
#include "net/reporting/reporting_context.h" #include "net/reporting/reporting_context.h"
#include "net/reporting/reporting_delegate.h" #include "net/reporting/reporting_delegate.h"
...@@ -38,7 +39,6 @@ namespace net { ...@@ -38,7 +39,6 @@ namespace net {
struct ReportingEndpoint; struct ReportingEndpoint;
class ReportingGarbageCollector; class ReportingGarbageCollector;
class TestURLRequestContext;
// A matcher for ReportingReports, which checks that the url of the report is // A matcher for ReportingReports, which checks that the url of the report is
// the given url. // the given url.
...@@ -49,6 +49,8 @@ MATCHER_P(ReportUrlIs, url, "") { ...@@ -49,6 +49,8 @@ MATCHER_P(ReportUrlIs, url, "") {
return arg.url == url; return arg.url == url;
} }
RandIntCallback TestReportingRandIntCallback();
// A test implementation of ReportingUploader that holds uploads for tests to // A test implementation of ReportingUploader that holds uploads for tests to
// examine and complete with a specified outcome. // examine and complete with a specified outcome.
class TestReportingUploader : public ReportingUploader { class TestReportingUploader : public ReportingUploader {
...@@ -128,7 +130,6 @@ class TestReportingDelegate : public ReportingDelegate { ...@@ -128,7 +130,6 @@ class TestReportingDelegate : public ReportingDelegate {
const GURL& endpoint) const override; const GURL& endpoint) const override;
private: private:
std::unique_ptr<TestURLRequestContext> test_request_context_;
bool disallow_report_uploads_ = false; bool disallow_report_uploads_ = false;
bool pause_permissions_check_ = false; bool pause_permissions_check_ = false;
...@@ -162,10 +163,6 @@ class TestReportingContext : public ReportingContext { ...@@ -162,10 +163,6 @@ class TestReportingContext : public ReportingContext {
} }
private: private:
int RandIntCallback(int min, int max);
int rand_counter_;
// Owned by the DeliveryAgent and GarbageCollector, respectively, but // Owned by the DeliveryAgent and GarbageCollector, respectively, but
// referenced here to preserve type: // referenced here to preserve type:
...@@ -250,9 +247,6 @@ class ReportingTestBase : public TestWithTaskEnvironment { ...@@ -250,9 +247,6 @@ class ReportingTestBase : public TestWithTaskEnvironment {
TestReportingUploader* uploader() { return context_->test_uploader(); } TestReportingUploader* uploader() { return context_->test_uploader(); }
ReportingCache* cache() { return context_->cache(); } ReportingCache* cache() { return context_->cache(); }
ReportingEndpointManager* endpoint_manager() {
return context_->endpoint_manager();
}
ReportingDeliveryAgent* delivery_agent() { ReportingDeliveryAgent* delivery_agent() {
return context_->delivery_agent(); return context_->delivery_agent();
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment