Commit e13028c5 authored by Douglas Creager's avatar Douglas Creager Committed by Commit Bot

Reporting: Don't assume report and policy have same origin

When processing a report, we can't assume that the report's origin
matches the origin that the corresponding policy belongs to.  If the
policy uses include_subdomains, then it might be used for reports for
any subdomain of the policy's origin.  We had a couple of places where
we did assume they were the same, and used a DCHECK to verify this.
That caused segfaults whenever we tried to process a subdomain report.

Bug: 854248
Change-Id: I251d82664c7deee2293144ceed199e084795152c
Reviewed-on: https://chromium-review.googlesource.com/1106304Reviewed-by: default avatarMatt Menke <mmenke@chromium.org>
Commit-Queue: Douglas Creager <dcreager@chromium.org>
Cr-Commit-Position: refs/heads/master@{#571946}
parent c995ccc1
...@@ -197,27 +197,19 @@ class ReportingCacheImpl : public ReportingCache { ...@@ -197,27 +197,19 @@ class ReportingCacheImpl : public ReportingCache {
context_->NotifyCacheUpdated(); context_->NotifyCacheUpdated();
} }
void IncrementEndpointDeliveries( void IncrementEndpointDeliveries(const url::Origin& origin,
const GURL& endpoint, const GURL& endpoint,
const std::vector<const ReportingReport*>& reports, int reports_delivered,
bool successful) override { bool successful) override {
std::unordered_map<const ReportingClient*, int> reports_per_client; const ReportingClient* client =
for (const ReportingReport* report : reports) { GetClientByOriginAndEndpoint(origin, endpoint);
DCHECK(base::ContainsKey(reports_, report)); if (client) {
url::Origin origin = url::Origin::Create(report->url); auto& metadata = client_metadata_[client];
const ReportingClient* client =
GetClientByOriginAndEndpoint(origin, endpoint);
DCHECK(client);
reports_per_client[client]++;
}
for (const auto& client_and_report_count : reports_per_client) {
auto& metadata = client_metadata_[client_and_report_count.first];
metadata.stats.attempted_uploads++; metadata.stats.attempted_uploads++;
metadata.stats.attempted_reports += client_and_report_count.second; metadata.stats.attempted_reports += reports_delivered;
if (successful) { if (successful) {
metadata.stats.successful_uploads++; metadata.stats.successful_uploads++;
metadata.stats.successful_reports += client_and_report_count.second; metadata.stats.successful_reports += reports_delivered;
} }
} }
} }
...@@ -292,10 +284,7 @@ class ReportingCacheImpl : public ReportingCache { ...@@ -292,10 +284,7 @@ class ReportingCacheImpl : public ReportingCache {
context_->NotifyCacheUpdated(); context_->NotifyCacheUpdated();
} }
void MarkClientUsed(const url::Origin& origin, void MarkClientUsed(const ReportingClient* client) override {
const GURL& endpoint) override {
const ReportingClient* client =
GetClientByOriginAndEndpoint(origin, endpoint);
DCHECK(client); DCHECK(client);
client_metadata_[client].last_used = tick_clock()->NowTicks(); client_metadata_[client].last_used = tick_clock()->NowTicks();
} }
......
...@@ -104,10 +104,10 @@ class NET_EXPORT ReportingCache { ...@@ -104,10 +104,10 @@ class NET_EXPORT ReportingCache {
// Records that we attempted (and possibly succeeded at) delivering |reports| // Records that we attempted (and possibly succeeded at) delivering |reports|
// to |endpoint|. // to |endpoint|.
virtual void IncrementEndpointDeliveries( virtual void IncrementEndpointDeliveries(const url::Origin& origin,
const GURL& endpoint, const GURL& endpoint,
const std::vector<const ReportingReport*>& reports, int reports_delivered,
bool successful) = 0; bool successful) = 0;
// Removes a set of reports. Any reports that are pending will not be removed // Removes a set of reports. Any reports that are pending will not be removed
// immediately, but rather marked doomed and removed once they are no longer // immediately, but rather marked doomed and removed once they are no longer
...@@ -134,8 +134,7 @@ class NET_EXPORT ReportingCache { ...@@ -134,8 +134,7 @@ class NET_EXPORT ReportingCache {
int priority, int priority,
int client) = 0; int client) = 0;
virtual void MarkClientUsed(const url::Origin& origin, virtual void MarkClientUsed(const ReportingClient* client) = 0;
const GURL& endpoint) = 0;
// Gets all of the clients in the cache, regardless of origin or group. // Gets all of the clients in the cache, regardless of origin or group.
// //
......
...@@ -411,18 +411,8 @@ TEST_F(ReportingCacheTest, GetClientsAsValue) { ...@@ -411,18 +411,8 @@ TEST_F(ReportingCacheTest, GetClientsAsValue) {
SetClient(kOrigin1_, kEndpoint1_, false, kGroup1_, expires); SetClient(kOrigin1_, kEndpoint1_, false, kGroup1_, expires);
SetClient(kOrigin2_, kEndpoint2_, true, kGroup1_, expires); SetClient(kOrigin2_, kEndpoint2_, true, kGroup1_, expires);
// Add some reports so that we can test the upload counts. cache()->IncrementEndpointDeliveries(kOrigin1_, kEndpoint1_, 2, true);
const ReportingReport* report1a = AddAndReturnReport( cache()->IncrementEndpointDeliveries(kOrigin2_, kEndpoint2_, 1, false);
kUrl1_, kGroup1_, kType_, std::make_unique<base::DictionaryValue>(), 0,
expires, 0);
const ReportingReport* report1b = AddAndReturnReport(
kUrl1_, kGroup1_, kType_, std::make_unique<base::DictionaryValue>(), 0,
expires, 1);
const ReportingReport* report2 = AddAndReturnReport(
kUrl2_, kGroup1_, kType_, std::make_unique<base::DictionaryValue>(), 0,
expires, 1);
cache()->IncrementEndpointDeliveries(kEndpoint1_, {report1a, report1b}, true);
cache()->IncrementEndpointDeliveries(kEndpoint2_, {report2}, false);
base::Value actual = cache()->GetClientsAsValue(); base::Value actual = cache()->GetClientsAsValue();
std::unique_ptr<base::Value> expected = base::test::ParseJson(R"json( std::unique_ptr<base::Value> expected = base::test::ParseJson(R"json(
...@@ -651,7 +641,9 @@ TEST_F(ReportingCacheTest, EvictLRUClient) { ...@@ -651,7 +641,9 @@ TEST_F(ReportingCacheTest, EvictLRUClient) {
// Use clients in reverse order, so client (max_client_count - 1) is LRU. // Use clients in reverse order, so client (max_client_count - 1) is LRU.
for (size_t i = 1; i <= max_client_count; ++i) { for (size_t i = 1; i <= max_client_count; ++i) {
cache()->MarkClientUsed(kOrigin1_, MakeEndpoint(max_client_count - i)); const ReportingClient* client = FindClientInCache(
cache(), kOrigin1_, MakeEndpoint(max_client_count - i));
cache()->MarkClientUsed(client);
tick_clock()->Advance(base::TimeDelta::FromSeconds(1)); tick_clock()->Advance(base::TimeDelta::FromSeconds(1));
} }
......
...@@ -80,13 +80,19 @@ class ReportingDeliveryAgentImpl : public ReportingDeliveryAgent, ...@@ -80,13 +80,19 @@ class ReportingDeliveryAgentImpl : public ReportingDeliveryAgent,
private: private:
class Delivery { class Delivery {
public: public:
Delivery(const GURL& endpoint, std::vector<const ReportingReport*> reports) Delivery(const GURL& endpoint) : endpoint(endpoint) {}
: endpoint(endpoint), reports(std::move(reports)) {}
~Delivery() = default; ~Delivery() = default;
void AddReports(const ReportingClient* client,
const std::vector<const ReportingReport*>& to_add) {
reports_per_client[client->origin][client->endpoint] += to_add.size();
reports.insert(reports.end(), to_add.begin(), to_add.end());
}
const GURL endpoint; const GURL endpoint;
std::vector<const ReportingReport*> reports; std::vector<const ReportingReport*> reports;
std::map<url::Origin, std::map<GURL, int>> reports_per_client;
}; };
using OriginGroup = std::pair<url::Origin, std::string>; using OriginGroup = std::pair<url::Origin, std::string>;
...@@ -146,24 +152,32 @@ class ReportingDeliveryAgentImpl : public ReportingDeliveryAgent, ...@@ -146,24 +152,32 @@ class ReportingDeliveryAgentImpl : public ReportingDeliveryAgent,
// Find endpoint for each (origin, group) bucket and sort reports into // Find endpoint for each (origin, group) bucket and sort reports into
// endpoint buckets. Don't allow concurrent deliveries to the same (origin, // endpoint buckets. Don't allow concurrent deliveries to the same (origin,
// group) bucket. // group) bucket.
std::map<GURL, std::vector<const ReportingReport*>> endpoint_reports; std::map<GURL, std::unique_ptr<Delivery>> deliveries;
for (auto& it : origin_group_reports) { for (auto& it : origin_group_reports) {
const OriginGroup& origin_group = it.first; const OriginGroup& origin_group = it.first;
if (base::ContainsKey(pending_origin_groups_, origin_group)) if (base::ContainsKey(pending_origin_groups_, origin_group))
continue; continue;
GURL endpoint_url; const ReportingClient* client =
if (!endpoint_manager()->FindEndpointForOriginAndGroup( endpoint_manager()->FindClientForOriginAndGroup(origin_group.first,
origin_group.first, origin_group.second, &endpoint_url)) { origin_group.second);
if (client == nullptr) {
continue; continue;
} }
cache()->MarkClientUsed(client);
Delivery* delivery;
auto delivery_it = deliveries.find(client->endpoint);
if (delivery_it == deliveries.end()) {
auto new_delivery = std::make_unique<Delivery>(client->endpoint);
delivery = new_delivery.get();
deliveries[client->endpoint] = std::move(new_delivery);
} else {
delivery = delivery_it->second.get();
}
cache()->MarkClientUsed(origin_group.first, endpoint_url); delivery->AddReports(client, it.second);
endpoint_reports[endpoint_url].insert(
endpoint_reports[endpoint_url].end(), it.second.begin(),
it.second.end());
pending_origin_groups_.insert(origin_group); pending_origin_groups_.insert(origin_group);
} }
...@@ -172,18 +186,18 @@ class ReportingDeliveryAgentImpl : public ReportingDeliveryAgent, ...@@ -172,18 +186,18 @@ class ReportingDeliveryAgentImpl : public ReportingDeliveryAgent,
std::unordered_set<const ReportingReport*> undelivered_reports( std::unordered_set<const ReportingReport*> undelivered_reports(
reports.begin(), reports.end()); reports.begin(), reports.end());
// Start a delivery to each endpoint. // Start an upload for each delivery.
for (auto& it : endpoint_reports) { for (auto& it : deliveries) {
const GURL& endpoint = it.first; const GURL& endpoint = it.first;
const std::vector<const ReportingReport*>& reports = it.second; std::unique_ptr<Delivery>& delivery = it.second;
endpoint_manager()->SetEndpointPending(endpoint); endpoint_manager()->SetEndpointPending(endpoint);
std::string json; std::string json;
SerializeReports(reports, tick_clock()->NowTicks(), &json); SerializeReports(delivery->reports, tick_clock()->NowTicks(), &json);
int max_depth = 0; int max_depth = 0;
for (const ReportingReport* report : reports) { for (const ReportingReport* report : delivery->reports) {
undelivered_reports.erase(report); undelivered_reports.erase(report);
if (report->depth > max_depth) if (report->depth > max_depth)
max_depth = report->depth; max_depth = report->depth;
...@@ -192,10 +206,8 @@ class ReportingDeliveryAgentImpl : public ReportingDeliveryAgent, ...@@ -192,10 +206,8 @@ class ReportingDeliveryAgentImpl : public ReportingDeliveryAgent,
// TODO: Calculate actual max depth. // TODO: Calculate actual max depth.
uploader()->StartUpload( uploader()->StartUpload(
endpoint, json, max_depth, endpoint, json, max_depth,
base::BindOnce( base::BindOnce(&ReportingDeliveryAgentImpl::OnUploadComplete,
&ReportingDeliveryAgentImpl::OnUploadComplete, weak_factory_.GetWeakPtr(), std::move(delivery)));
weak_factory_.GetWeakPtr(),
std::make_unique<Delivery>(endpoint, std::move(reports))));
} }
cache()->ClearReportsPending( cache()->ClearReportsPending(
...@@ -204,9 +216,16 @@ class ReportingDeliveryAgentImpl : public ReportingDeliveryAgent, ...@@ -204,9 +216,16 @@ class ReportingDeliveryAgentImpl : public ReportingDeliveryAgent,
void OnUploadComplete(const std::unique_ptr<Delivery>& delivery, void OnUploadComplete(const std::unique_ptr<Delivery>& delivery,
ReportingUploader::Outcome outcome) { ReportingUploader::Outcome outcome) {
cache()->IncrementEndpointDeliveries( for (const auto& origin_and_pair : delivery->reports_per_client) {
delivery->endpoint, delivery->reports, const url::Origin& origin = origin_and_pair.first;
outcome == ReportingUploader::Outcome::SUCCESS); for (const auto& endpoint_and_count : origin_and_pair.second) {
const GURL& endpoint = endpoint_and_count.first;
int report_count = endpoint_and_count.second;
cache()->IncrementEndpointDeliveries(
origin, endpoint, report_count,
outcome == ReportingUploader::Outcome::SUCCESS);
}
}
if (outcome == ReportingUploader::Outcome::SUCCESS) { if (outcome == ReportingUploader::Outcome::SUCCESS) {
cache()->RemoveReports(delivery->reports, cache()->RemoveReports(delivery->reports,
......
...@@ -44,13 +44,16 @@ class ReportingDeliveryAgentTest : public ReportingTestBase { ...@@ -44,13 +44,16 @@ class ReportingDeliveryAgentTest : public ReportingTestBase {
void SetClient(const url::Origin& origin, void SetClient(const url::Origin& origin,
const GURL& endpoint, const GURL& endpoint,
const std::string& group) { const std::string& group,
cache()->SetClient(origin, endpoint, ReportingClient::Subdomains::EXCLUDE, ReportingClient::Subdomains subdomains =
group, tomorrow(), ReportingClient::kDefaultPriority, ReportingClient::Subdomains::EXCLUDE) {
cache()->SetClient(origin, endpoint, subdomains, group, tomorrow(),
ReportingClient::kDefaultPriority,
ReportingClient::kDefaultWeight); ReportingClient::kDefaultWeight);
} }
const GURL kUrl_ = GURL("https://origin/path"); const GURL kUrl_ = GURL("https://origin/path");
const GURL kSubdomainUrl_ = GURL("https://sub.origin/path");
const url::Origin kOrigin_ = url::Origin::Create(GURL("https://origin/")); const url::Origin kOrigin_ = url::Origin::Create(GURL("https://origin/"));
const GURL kEndpoint_ = GURL("https://endpoint/"); const GURL kEndpoint_ = GURL("https://endpoint/");
const std::string kGroup_ = "group"; const std::string kGroup_ = "group";
...@@ -92,7 +95,70 @@ TEST_F(ReportingDeliveryAgentTest, SuccessfulImmediateUpload) { ...@@ -92,7 +95,70 @@ TEST_F(ReportingDeliveryAgentTest, SuccessfulImmediateUpload) {
cache()->GetReports(&reports); cache()->GetReports(&reports);
EXPECT_TRUE(reports.empty()); EXPECT_TRUE(reports.empty());
// TODO(juliatuttle): Check that BackoffEntry was informed of success. // TODO(dcreager): Check that BackoffEntry was informed of success.
}
TEST_F(ReportingDeliveryAgentTest, SuccessfulImmediateSubdomainUpload) {
base::DictionaryValue body;
body.SetString("key", "value");
SetClient(kOrigin_, kEndpoint_, kGroup_,
ReportingClient::Subdomains::INCLUDE);
cache()->AddReport(kSubdomainUrl_, kGroup_, kType_, body.CreateDeepCopy(), 0,
tick_clock()->NowTicks(), 0);
// Upload is automatically started when cache is modified.
ASSERT_EQ(1u, pending_uploads().size());
EXPECT_EQ(kEndpoint_, pending_uploads()[0]->url());
{
auto value = pending_uploads()[0]->GetValue();
base::ListValue* list;
ASSERT_TRUE(value->GetAsList(&list));
EXPECT_EQ(1u, list->GetSize());
base::DictionaryValue* report;
ASSERT_TRUE(list->GetDictionary(0, &report));
EXPECT_EQ(4u, report->size());
ExpectDictIntegerValue(0, *report, "age");
ExpectDictStringValue(kType_, *report, "type");
ExpectDictStringValue(kSubdomainUrl_.spec(), *report, "url");
ExpectDictDictionaryValue(body, *report, "report");
}
pending_uploads()[0]->Complete(ReportingUploader::Outcome::SUCCESS);
// Successful upload should remove delivered reports.
std::vector<const ReportingReport*> reports;
cache()->GetReports(&reports);
EXPECT_TRUE(reports.empty());
// TODO(dcreager): Check that BackoffEntry was informed of success.
}
TEST_F(ReportingDeliveryAgentTest,
SuccessfulImmediateSubdomainUploadWithEvictedClient) {
base::DictionaryValue body;
body.SetString("key", "value");
SetClient(kOrigin_, kEndpoint_, kGroup_,
ReportingClient::Subdomains::INCLUDE);
cache()->AddReport(kSubdomainUrl_, kGroup_, kType_, body.CreateDeepCopy(), 0,
tick_clock()->NowTicks(), 0);
// Upload is automatically started when cache is modified.
ASSERT_EQ(1u, pending_uploads().size());
// Evict the client
SetClient(kOrigin_, kEndpoint_, kGroup_,
ReportingClient::Subdomains::EXCLUDE);
pending_uploads()[0]->Complete(ReportingUploader::Outcome::SUCCESS);
// Successful upload should remove delivered reports.
std::vector<const ReportingReport*> reports;
cache()->GetReports(&reports);
EXPECT_TRUE(reports.empty());
} }
TEST_F(ReportingDeliveryAgentTest, SuccessfulDelayedUpload) { TEST_F(ReportingDeliveryAgentTest, SuccessfulDelayedUpload) {
......
...@@ -35,9 +35,9 @@ class ReportingEndpointManagerImpl : public ReportingEndpointManager { ...@@ -35,9 +35,9 @@ class ReportingEndpointManagerImpl : public ReportingEndpointManager {
~ReportingEndpointManagerImpl() override = default; ~ReportingEndpointManagerImpl() override = default;
bool FindEndpointForOriginAndGroup(const url::Origin& origin, const ReportingClient* FindClientForOriginAndGroup(
const std::string& group, const url::Origin& origin,
GURL* endpoint_url_out) override { const std::string& group) override {
std::vector<const ReportingClient*> clients; std::vector<const ReportingClient*> clients;
cache()->GetClientsForOriginAndGroup(origin, group, &clients); cache()->GetClientsForOriginAndGroup(origin, group, &clients);
...@@ -79,8 +79,7 @@ class ReportingEndpointManagerImpl : public ReportingEndpointManager { ...@@ -79,8 +79,7 @@ class ReportingEndpointManagerImpl : public ReportingEndpointManager {
} }
if (available_clients.empty()) { if (available_clients.empty()) {
*endpoint_url_out = GURL(); return nullptr;
return false;
} }
int random_index = rand_callback_.Run(0, total_weight - 1); int random_index = rand_callback_.Run(0, total_weight - 1);
...@@ -89,14 +88,13 @@ class ReportingEndpointManagerImpl : public ReportingEndpointManager { ...@@ -89,14 +88,13 @@ class ReportingEndpointManagerImpl : public ReportingEndpointManager {
const ReportingClient* client = available_clients[i]; const ReportingClient* client = available_clients[i];
weight_so_far += client->weight; weight_so_far += client->weight;
if (random_index < weight_so_far) { if (random_index < weight_so_far) {
*endpoint_url_out = client->endpoint; return client;
return true;
} }
} }
// TODO(juliatuttle): Can we reach this in some weird overflow case? // TODO(juliatuttle): Can we reach this in some weird overflow case?
NOTREACHED(); NOTREACHED();
return false; return nullptr;
} }
void SetEndpointPending(const GURL& endpoint) override { void SetEndpointPending(const GURL& endpoint) override {
......
...@@ -21,6 +21,8 @@ class Origin; ...@@ -21,6 +21,8 @@ class Origin;
namespace net { namespace net {
struct ReportingClient;
// Keeps track of which endpoints are pending (have active delivery attempts to // Keeps track of which endpoints are pending (have active delivery attempts to
// them) or in exponential backoff after one or more failures, and chooses an // them) or in exponential backoff after one or more failures, and chooses an
// endpoint from an endpoint group to receive reports for an origin. // endpoint from an endpoint group to receive reports for an origin.
...@@ -39,12 +41,11 @@ class NET_EXPORT ReportingEndpointManager { ...@@ -39,12 +41,11 @@ class NET_EXPORT ReportingEndpointManager {
// Deliberately chooses an endpoint randomly to ensure sites aren't relying on // Deliberately chooses an endpoint randomly to ensure sites aren't relying on
// any sort of fallback ordering. // any sort of fallback ordering.
// //
// Returns true and sets |*endpoint_url_out| to the endpoint URL if an // Returns the endpoint's |ReportingClient| if endpoint was chosen; returns
// endpoint was chosen; returns false (and leaves |*endpoint_url_out| invalid) // nullptr if no endpoint was found.
// if no endpoint was found. virtual const ReportingClient* FindClientForOriginAndGroup(
virtual bool FindEndpointForOriginAndGroup(const url::Origin& origin, const url::Origin& origin,
const std::string& group, const std::string& group) = 0;
GURL* endpoint_url_out) = 0;
// Adds |endpoint| to the set of pending endpoints, preventing it from being // Adds |endpoint| to the set of pending endpoints, preventing it from being
// chosen for a second parallel delivery attempt. // chosen for a second parallel delivery attempt.
......
...@@ -33,21 +33,19 @@ class ReportingEndpointManagerTest : public ReportingTestBase { ...@@ -33,21 +33,19 @@ class ReportingEndpointManagerTest : public ReportingTestBase {
}; };
TEST_F(ReportingEndpointManagerTest, NoEndpoint) { TEST_F(ReportingEndpointManagerTest, NoEndpoint) {
GURL endpoint_url; const ReportingClient* client =
bool found_endpoint = endpoint_manager()->FindEndpointForOriginAndGroup( endpoint_manager()->FindClientForOriginAndGroup(kOrigin_, kGroup_);
kOrigin_, kGroup_, &endpoint_url); EXPECT_TRUE(client == nullptr);
EXPECT_FALSE(found_endpoint);
} }
TEST_F(ReportingEndpointManagerTest, Endpoint) { TEST_F(ReportingEndpointManagerTest, Endpoint) {
SetClient(kEndpoint_, ReportingClient::kDefaultPriority, SetClient(kEndpoint_, ReportingClient::kDefaultPriority,
ReportingClient::kDefaultWeight); ReportingClient::kDefaultWeight);
GURL endpoint_url; const ReportingClient* client =
bool found_endpoint = endpoint_manager()->FindEndpointForOriginAndGroup( endpoint_manager()->FindClientForOriginAndGroup(kOrigin_, kGroup_);
kOrigin_, kGroup_, &endpoint_url); ASSERT_TRUE(client != nullptr);
EXPECT_TRUE(found_endpoint); EXPECT_EQ(kEndpoint_, client->endpoint);
EXPECT_EQ(kEndpoint_, endpoint_url);
} }
TEST_F(ReportingEndpointManagerTest, ExpiredEndpoint) { TEST_F(ReportingEndpointManagerTest, ExpiredEndpoint) {
...@@ -57,10 +55,9 @@ TEST_F(ReportingEndpointManagerTest, ExpiredEndpoint) { ...@@ -57,10 +55,9 @@ TEST_F(ReportingEndpointManagerTest, ExpiredEndpoint) {
// Default expiration is "tomorrow", so make sure we're past that. // Default expiration is "tomorrow", so make sure we're past that.
tick_clock()->Advance(base::TimeDelta::FromDays(2)); tick_clock()->Advance(base::TimeDelta::FromDays(2));
GURL endpoint_url; const ReportingClient* client =
bool found_endpoint = endpoint_manager()->FindEndpointForOriginAndGroup( endpoint_manager()->FindClientForOriginAndGroup(kOrigin_, kGroup_);
kOrigin_, kGroup_, &endpoint_url); EXPECT_TRUE(client == nullptr);
EXPECT_FALSE(found_endpoint);
} }
TEST_F(ReportingEndpointManagerTest, PendingEndpoint) { TEST_F(ReportingEndpointManagerTest, PendingEndpoint) {
...@@ -69,17 +66,15 @@ TEST_F(ReportingEndpointManagerTest, PendingEndpoint) { ...@@ -69,17 +66,15 @@ TEST_F(ReportingEndpointManagerTest, PendingEndpoint) {
endpoint_manager()->SetEndpointPending(kEndpoint_); endpoint_manager()->SetEndpointPending(kEndpoint_);
GURL endpoint_url; const ReportingClient* client =
bool found_endpoint = endpoint_manager()->FindEndpointForOriginAndGroup( endpoint_manager()->FindClientForOriginAndGroup(kOrigin_, kGroup_);
kOrigin_, kGroup_, &endpoint_url); EXPECT_TRUE(client == nullptr);
EXPECT_FALSE(found_endpoint);
endpoint_manager()->ClearEndpointPending(kEndpoint_); endpoint_manager()->ClearEndpointPending(kEndpoint_);
found_endpoint = endpoint_manager()->FindEndpointForOriginAndGroup( client = endpoint_manager()->FindClientForOriginAndGroup(kOrigin_, kGroup_);
kOrigin_, kGroup_, &endpoint_url); ASSERT_TRUE(client != nullptr);
EXPECT_TRUE(found_endpoint); EXPECT_EQ(kEndpoint_, client->endpoint);
EXPECT_EQ(kEndpoint_, endpoint_url);
} }
TEST_F(ReportingEndpointManagerTest, BackedOffEndpoint) { TEST_F(ReportingEndpointManagerTest, BackedOffEndpoint) {
...@@ -94,40 +89,35 @@ TEST_F(ReportingEndpointManagerTest, BackedOffEndpoint) { ...@@ -94,40 +89,35 @@ TEST_F(ReportingEndpointManagerTest, BackedOffEndpoint) {
endpoint_manager()->InformOfEndpointRequest(kEndpoint_, false); endpoint_manager()->InformOfEndpointRequest(kEndpoint_, false);
// After one failure, endpoint is in exponential backoff. // After one failure, endpoint is in exponential backoff.
GURL endpoint_url; const ReportingClient* client =
bool found_endpoint = endpoint_manager()->FindEndpointForOriginAndGroup( endpoint_manager()->FindClientForOriginAndGroup(kOrigin_, kGroup_);
kOrigin_, kGroup_, &endpoint_url); EXPECT_TRUE(client == nullptr);
EXPECT_FALSE(found_endpoint);
// After initial delay, endpoint is usable again. // After initial delay, endpoint is usable again.
tick_clock()->Advance(initial_delay); tick_clock()->Advance(initial_delay);
found_endpoint = endpoint_manager()->FindEndpointForOriginAndGroup( client = endpoint_manager()->FindClientForOriginAndGroup(kOrigin_, kGroup_);
kOrigin_, kGroup_, &endpoint_url); ASSERT_TRUE(client != nullptr);
EXPECT_TRUE(found_endpoint); EXPECT_EQ(kEndpoint_, client->endpoint);
EXPECT_EQ(kEndpoint_, endpoint_url);
endpoint_manager()->InformOfEndpointRequest(kEndpoint_, false); endpoint_manager()->InformOfEndpointRequest(kEndpoint_, false);
// After a second failure, endpoint is backed off again. // After a second failure, endpoint is backed off again.
found_endpoint = endpoint_manager()->FindEndpointForOriginAndGroup( client = endpoint_manager()->FindClientForOriginAndGroup(kOrigin_, kGroup_);
kOrigin_, kGroup_, &endpoint_url); EXPECT_TRUE(client == nullptr);
EXPECT_FALSE(found_endpoint);
tick_clock()->Advance(initial_delay); tick_clock()->Advance(initial_delay);
// Next backoff is longer -- 2x the first -- so endpoint isn't usable yet. // Next backoff is longer -- 2x the first -- so endpoint isn't usable yet.
found_endpoint = endpoint_manager()->FindEndpointForOriginAndGroup( client = endpoint_manager()->FindClientForOriginAndGroup(kOrigin_, kGroup_);
kOrigin_, kGroup_, &endpoint_url); EXPECT_TRUE(client == nullptr);
EXPECT_FALSE(found_endpoint);
tick_clock()->Advance(initial_delay); tick_clock()->Advance(initial_delay);
// After 2x the initial delay, the endpoint is usable again. // After 2x the initial delay, the endpoint is usable again.
found_endpoint = endpoint_manager()->FindEndpointForOriginAndGroup( client = endpoint_manager()->FindClientForOriginAndGroup(kOrigin_, kGroup_);
kOrigin_, kGroup_, &endpoint_url); ASSERT_TRUE(client != nullptr);
EXPECT_TRUE(found_endpoint); EXPECT_EQ(kEndpoint_, client->endpoint);
EXPECT_EQ(kEndpoint_, endpoint_url);
endpoint_manager()->InformOfEndpointRequest(kEndpoint_, true); endpoint_manager()->InformOfEndpointRequest(kEndpoint_, true);
endpoint_manager()->InformOfEndpointRequest(kEndpoint_, true); endpoint_manager()->InformOfEndpointRequest(kEndpoint_, true);
...@@ -136,15 +126,13 @@ TEST_F(ReportingEndpointManagerTest, BackedOffEndpoint) { ...@@ -136,15 +126,13 @@ TEST_F(ReportingEndpointManagerTest, BackedOffEndpoint) {
// again. // again.
endpoint_manager()->InformOfEndpointRequest(kEndpoint_, false); endpoint_manager()->InformOfEndpointRequest(kEndpoint_, false);
found_endpoint = endpoint_manager()->FindEndpointForOriginAndGroup( client = endpoint_manager()->FindClientForOriginAndGroup(kOrigin_, kGroup_);
kOrigin_, kGroup_, &endpoint_url); EXPECT_TRUE(client == nullptr);
EXPECT_FALSE(found_endpoint);
tick_clock()->Advance(initial_delay); tick_clock()->Advance(initial_delay);
found_endpoint = endpoint_manager()->FindEndpointForOriginAndGroup( client = endpoint_manager()->FindClientForOriginAndGroup(kOrigin_, kGroup_);
kOrigin_, kGroup_, &endpoint_url); EXPECT_TRUE(client != nullptr);
EXPECT_TRUE(found_endpoint);
} }
// Make sure that multiple endpoints will all be returned at some point, to // Make sure that multiple endpoints will all be returned at some point, to
...@@ -163,15 +151,15 @@ TEST_F(ReportingEndpointManagerTest, RandomEndpoint) { ...@@ -163,15 +151,15 @@ TEST_F(ReportingEndpointManagerTest, RandomEndpoint) {
bool endpoint2_seen = false; bool endpoint2_seen = false;
for (int i = 0; i < kMaxAttempts; ++i) { for (int i = 0; i < kMaxAttempts; ++i) {
GURL endpoint_url; const ReportingClient* client =
bool found_endpoint = endpoint_manager()->FindEndpointForOriginAndGroup( endpoint_manager()->FindClientForOriginAndGroup(kOrigin_, kGroup_);
kOrigin_, kGroup_, &endpoint_url); ASSERT_TRUE(client != nullptr);
ASSERT_TRUE(found_endpoint); ASSERT_TRUE(client->endpoint == kEndpoint1 ||
ASSERT_TRUE(endpoint_url == kEndpoint1 || endpoint_url == kEndpoint2); client->endpoint == kEndpoint2);
if (endpoint_url == kEndpoint1) if (client->endpoint == kEndpoint1)
endpoint1_seen = true; endpoint1_seen = true;
else if (endpoint_url == kEndpoint2) else if (client->endpoint == kEndpoint2)
endpoint2_seen = true; endpoint2_seen = true;
if (endpoint1_seen && endpoint2_seen) if (endpoint1_seen && endpoint2_seen)
...@@ -189,26 +177,22 @@ TEST_F(ReportingEndpointManagerTest, Priority) { ...@@ -189,26 +177,22 @@ TEST_F(ReportingEndpointManagerTest, Priority) {
SetClient(kPrimaryEndpoint, 10, ReportingClient::kDefaultWeight); SetClient(kPrimaryEndpoint, 10, ReportingClient::kDefaultWeight);
SetClient(kBackupEndpoint, 20, ReportingClient::kDefaultWeight); SetClient(kBackupEndpoint, 20, ReportingClient::kDefaultWeight);
GURL endpoint_url; const ReportingClient* client =
endpoint_manager()->FindClientForOriginAndGroup(kOrigin_, kGroup_);
bool found_endpoint = endpoint_manager()->FindEndpointForOriginAndGroup( ASSERT_TRUE(client != nullptr);
kOrigin_, kGroup_, &endpoint_url); EXPECT_EQ(kPrimaryEndpoint, client->endpoint);
ASSERT_TRUE(found_endpoint);
EXPECT_EQ(kPrimaryEndpoint, endpoint_url);
endpoint_manager()->SetEndpointPending(kPrimaryEndpoint); endpoint_manager()->SetEndpointPending(kPrimaryEndpoint);
found_endpoint = endpoint_manager()->FindEndpointForOriginAndGroup( client = endpoint_manager()->FindClientForOriginAndGroup(kOrigin_, kGroup_);
kOrigin_, kGroup_, &endpoint_url); ASSERT_TRUE(client != nullptr);
ASSERT_TRUE(found_endpoint); EXPECT_EQ(kBackupEndpoint, client->endpoint);
EXPECT_EQ(kBackupEndpoint, endpoint_url);
endpoint_manager()->ClearEndpointPending(kPrimaryEndpoint); endpoint_manager()->ClearEndpointPending(kPrimaryEndpoint);
found_endpoint = endpoint_manager()->FindEndpointForOriginAndGroup( client = endpoint_manager()->FindClientForOriginAndGroup(kOrigin_, kGroup_);
kOrigin_, kGroup_, &endpoint_url); ASSERT_TRUE(client != nullptr);
ASSERT_TRUE(found_endpoint); EXPECT_EQ(kPrimaryEndpoint, client->endpoint);
EXPECT_EQ(kPrimaryEndpoint, endpoint_url);
} }
// Note: This test depends on the deterministic mock RandIntCallback set up in // Note: This test depends on the deterministic mock RandIntCallback set up in
...@@ -229,15 +213,15 @@ TEST_F(ReportingEndpointManagerTest, Weight) { ...@@ -229,15 +213,15 @@ TEST_F(ReportingEndpointManagerTest, Weight) {
int endpoint2_count = 0; int endpoint2_count = 0;
for (int i = 0; i < kTotalEndpointWeight; ++i) { for (int i = 0; i < kTotalEndpointWeight; ++i) {
GURL endpoint_url; const ReportingClient* client =
bool found_endpoint = endpoint_manager()->FindEndpointForOriginAndGroup( endpoint_manager()->FindClientForOriginAndGroup(kOrigin_, kGroup_);
kOrigin_, kGroup_, &endpoint_url); ASSERT_TRUE(client != nullptr);
ASSERT_TRUE(found_endpoint); ASSERT_TRUE(client->endpoint == kEndpoint1 ||
ASSERT_TRUE(endpoint_url == kEndpoint1 || endpoint_url == kEndpoint2); client->endpoint == kEndpoint2);
if (endpoint_url == kEndpoint1) if (client->endpoint == kEndpoint1)
++endpoint1_count; ++endpoint1_count;
else if (endpoint_url == kEndpoint2) else if (client->endpoint == kEndpoint2)
++endpoint2_count; ++endpoint2_count;
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment