Commit b6f3b558 authored by Colin Blundell's avatar Colin Blundell Committed by Chromium LUCI CQ

[Safe Browsing] Pull out timeout tracking of token fetches for reuse

This CL separates separates the mechanism via which
SafeBrowsingPrimaryAccountTokenFetcher fetches access tokens and the
mechanisms via which it tracks and responds to ongoing token fetches,
moving the latter into a new SafeBrowsingTokenFetchTracker helper class
for reuse by WebLayer. This class includes the following:

- Tracking of ongoing access token fetch requests coming from the client
  of SafeBrowsingTokenFetcher, responding to these requests when
  notified that a given access token has been fetched, and gating these
  requests on a given time threshold (responding with an empty access
  token if a request times out)
- Responding to all active token requests with an empty access token on
  destruction

This CL also adds unittests of SafeBrowsingTokenFetchTracker.

Via this helper class, WebLayer will be able to share the key
functionality of gating safe browsing access token fetches on a timeout
rather than duplicating this functionality and risking divergence over
time.

The only subtlety is the behavior on token fetch timeout:
SafeBrowsingPrimaryAccountTokenFetcher needs to be notified in addition
to its client so that the former can destroy the corresponding
AccessTokenFetcher. To facilitate this,
SafeBrowsingTokenFetchTracker::StartTrackingTokenFetch() takes in both
the SafeBrowsingTokenFetcher::Callback that the client passed in and a
callback to be invoked on timeout, via which its owner can clean up any
state associated with the request.

Bug: 1080748
Change-Id: I79bb83096cc2094c7cf9916eccf382271a5f46a5
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2624630
Commit-Queue: Colin Blundell <blundell@chromium.org>
Reviewed-by: default avatarXinghui Lu <xinghuilu@chromium.org>
Cr-Commit-Position: refs/heads/master@{#843441}
parent 3c36515a
......@@ -288,6 +288,7 @@ test("components_unittests") {
"//components/safe_browsing/content/password_protection:password_protection_unittest",
"//components/safe_browsing/content/triggers:unit_tests",
"//components/safe_browsing/content/web_ui:unit_tests",
"//components/safe_browsing/core/browser:token_fetcher_unit_tests",
"//components/safe_browsing/core/common:unit_tests",
"//components/safe_browsing/core/realtime:unit_tests",
"//components/safe_browsing/core/triggers:unit_tests",
......
......@@ -78,7 +78,26 @@ source_set("referrer_chain_provider") {
}
source_set("token_fetcher") {
sources = [ "safe_browsing_token_fetcher.h" ]
sources = [
"safe_browsing_token_fetch_tracker.cc",
"safe_browsing_token_fetch_tracker.h",
"safe_browsing_token_fetcher.h",
]
deps = [
"//base",
"//components/safe_browsing/core/common:thread_utils",
]
}
deps = [ "//base" ]
source_set("token_fetcher_unit_tests") {
testonly = true
sources = [ "safe_browsing_token_fetch_tracker_unittest.cc" ]
deps = [
":token_fetcher",
"//base/test:test_support",
"//components/safe_browsing/core/common:test_support",
"//testing/gtest",
]
}
// Copyright 2021 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/safe_browsing/core/browser/safe_browsing_token_fetch_tracker.h"
#include "base/bind.h"
#include "base/memory/weak_ptr.h"
#include "base/task/post_task.h"
#include "base/time/time.h"
#include "components/safe_browsing/core/common/thread_utils.h"
namespace safe_browsing {
SafeBrowsingTokenFetchTracker::SafeBrowsingTokenFetchTracker()
: weak_ptr_factory_(this) {
DCHECK(CurrentlyOnThread(ThreadID::UI));
}
SafeBrowsingTokenFetchTracker::~SafeBrowsingTokenFetchTracker() {
for (auto& id_and_callback : callbacks_) {
std::move(id_and_callback.second).Run(std::string());
}
}
int SafeBrowsingTokenFetchTracker::StartTrackingTokenFetch(
SafeBrowsingTokenFetcher::Callback on_token_fetched_callback,
OnTokenFetchTimeoutCallback on_token_fetch_timeout_callback) {
DCHECK(CurrentlyOnThread(ThreadID::UI));
const int request_id = requests_sent_;
requests_sent_++;
callbacks_[request_id] = std::move(on_token_fetched_callback);
base::PostDelayedTask(
FROM_HERE, CreateTaskTraits(ThreadID::UI),
base::BindOnce(&SafeBrowsingTokenFetchTracker::OnTokenFetchTimeout,
weak_ptr_factory_.GetWeakPtr(), request_id,
std::move(on_token_fetch_timeout_callback)),
base::TimeDelta::FromMilliseconds(
kTokenFetchTimeoutDelayFromMilliseconds));
return request_id;
}
void SafeBrowsingTokenFetchTracker::OnTokenFetchComplete(
int request_id,
std::string access_token) {
Finish(request_id, access_token);
}
void SafeBrowsingTokenFetchTracker::OnTokenFetchTimeout(
int request_id,
OnTokenFetchTimeoutCallback on_token_fetch_timeout_callback) {
Finish(request_id, std::string());
std::move(on_token_fetch_timeout_callback).Run(request_id);
}
void SafeBrowsingTokenFetchTracker::Finish(int request_id,
const std::string& access_token) {
if (callbacks_.contains(request_id)) {
std::move(callbacks_[request_id]).Run(access_token);
}
callbacks_.erase(request_id);
}
} // namespace safe_browsing
// Copyright 2021 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef COMPONENTS_SAFE_BROWSING_CORE_BROWSER_SAFE_BROWSING_TOKEN_FETCH_TRACKER_H_
#define COMPONENTS_SAFE_BROWSING_CORE_BROWSER_SAFE_BROWSING_TOKEN_FETCH_TRACKER_H_
#include <memory>
#include "base/callback.h"
#include "base/containers/flat_map.h"
#include "base/memory/weak_ptr.h"
#include "build/build_config.h"
#include "components/safe_browsing/core/browser/safe_browsing_token_fetcher.h"
namespace safe_browsing {
// Exposed for unittests.
#if defined(OS_ANDROID)
constexpr int kTokenFetchTimeoutDelayFromMilliseconds = 50;
#else
constexpr int kTokenFetchTimeoutDelayFromMilliseconds = 1000;
#endif
// Helper class for use by implementations of SafeBrowsingTokenFetcher:
// tracks a set of outstanding access token fetches, timing out a fetch after a
// given delay.
class SafeBrowsingTokenFetchTracker {
public:
using OnTokenFetchTimeoutCallback = base::OnceCallback<void(int request_id)>;
SafeBrowsingTokenFetchTracker();
~SafeBrowsingTokenFetchTracker();
// Should be invoked when a safe browsing access token fetch is started. Takes
// in the callback that the client passed to SafeBrowsingTokenFetcher::Start()
// as well as a callback via which the SafeBrowsingTokenFetcher implementation
// is informed of token fetch timeouts. Returns the request ID associated with
// the fetch. If the access token is fetched before the timeout is invoked,
// the SafeBrowsingTokenFetcher implementation should invoke
// OnTokenFetchComplete(), in which case this object will invoke
// |on_token_fetched_callback| with the given access token. If the timeout
// occurs, this object will invoke |on_token_fetched_callback| with an empty
// token and invoke |on_token_fetch_timeout_callback| so that the
// SafeBrowsingTokenFetcher implementation can clean up any associated state.
int StartTrackingTokenFetch(
SafeBrowsingTokenFetcher::Callback on_token_fetched_callback,
OnTokenFetchTimeoutCallback on_token_fetch_timeout_callback);
// Should be invoked when an access token fetch has completed.
void OnTokenFetchComplete(int request_id, std::string access_token);
private:
void OnTokenFetchTimeout(
int request_id,
OnTokenFetchTimeoutCallback on_token_fetch_timeout_callback);
void Finish(int request_id, const std::string& access_token);
// The count of requests sent. This is used as an ID for requests.
int requests_sent_ = 0;
// Active callbacks, keyed by ID.
base::flat_map<int, SafeBrowsingTokenFetcher::Callback> callbacks_;
base::WeakPtrFactory<SafeBrowsingTokenFetchTracker> weak_ptr_factory_;
};
} // namespace safe_browsing
#endif // COMPONENTS_SAFE_BROWSING_CORE_BROWSER_SAFE_BROWSING_TOKEN_FETCH_TRACKER_H_
// Copyright 2021 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/safe_browsing/core/browser/safe_browsing_token_fetch_tracker.h"
#include <memory>
#include "base/run_loop.h"
#include "base/test/task_environment.h"
#include "components/safe_browsing/core/common/test_task_environment.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace safe_browsing {
class SafeBrowsingTokenFetchTrackerTest : public ::testing::Test {
public:
SafeBrowsingTokenFetchTrackerTest()
: task_environment_(CreateTestTaskEnvironment(
base::test::TaskEnvironment::TimeSource::MOCK_TIME)) {}
protected:
std::unique_ptr<base::test::TaskEnvironment> task_environment_;
};
TEST_F(SafeBrowsingTokenFetchTrackerTest, Success) {
SafeBrowsingTokenFetchTracker fetcher;
std::string access_token;
int request_id = fetcher.StartTrackingTokenFetch(
base::BindOnce([](std::string* target_token,
const std::string& token) { *target_token = token; },
&access_token),
base::BindOnce([](int request_id) {}));
fetcher.OnTokenFetchComplete(request_id, "token");
EXPECT_EQ(access_token, "token");
}
TEST_F(SafeBrowsingTokenFetchTrackerTest, MultipleRequests) {
SafeBrowsingTokenFetchTracker fetcher;
std::string access_token1;
std::string access_token2;
int request_id1 = fetcher.StartTrackingTokenFetch(
base::BindOnce([](std::string* target_token,
const std::string& token) { *target_token = token; },
&access_token1),
base::BindOnce([](int request_id) {}));
int request_id2 = fetcher.StartTrackingTokenFetch(
base::BindOnce([](std::string* target_token,
const std::string& token) { *target_token = token; },
&access_token2),
base::BindOnce([](int request_id) {}));
fetcher.OnTokenFetchComplete(request_id2, "token2");
EXPECT_EQ(access_token1, "");
EXPECT_EQ(access_token2, "token2");
fetcher.OnTokenFetchComplete(request_id1, "token1");
EXPECT_EQ(access_token1, "token1");
EXPECT_EQ(access_token2, "token2");
}
TEST_F(SafeBrowsingTokenFetchTrackerTest, FetcherDestruction) {
auto fetcher = std::make_unique<SafeBrowsingTokenFetchTracker>();
std::string access_token1;
std::string access_token2 = "dummy_value";
int request_id1 = fetcher->StartTrackingTokenFetch(
base::BindOnce([](std::string* target_token,
const std::string& token) { *target_token = token; },
&access_token1),
base::BindOnce([](int request_id) {}));
fetcher->StartTrackingTokenFetch(
base::BindOnce([](std::string* target_token,
const std::string& token) { *target_token = token; },
&access_token2),
base::BindOnce([](int request_id) {}));
fetcher->OnTokenFetchComplete(request_id1, "token1");
EXPECT_EQ(access_token1, "token1");
EXPECT_EQ(access_token2, "dummy_value");
fetcher.reset();
// The second request was outstanding when the fetcher is destroyed, so it
// should have been invoked with the empty string.
EXPECT_EQ(access_token1, "token1");
EXPECT_EQ(access_token2, "");
}
TEST_F(SafeBrowsingTokenFetchTrackerTest, Timeout) {
SafeBrowsingTokenFetchTracker fetcher;
std::string access_token1 = "dummy_value1";
std::string access_token2 = "dummy_value2";
bool on_timeout1_invoked = false;
bool on_timeout2_invoked = false;
int delay_before_second_request_from_ms =
kTokenFetchTimeoutDelayFromMilliseconds / 2;
fetcher.StartTrackingTokenFetch(
base::BindOnce([](std::string* target_token,
const std::string& token) { *target_token = token; },
&access_token1),
base::BindOnce([](bool* target_on_timeout_invoked,
int request_id) { *target_on_timeout_invoked = true; },
&on_timeout1_invoked));
task_environment_->FastForwardBy(
base::TimeDelta::FromMilliseconds(delay_before_second_request_from_ms));
fetcher.StartTrackingTokenFetch(
base::BindOnce([](std::string* target_token,
const std::string& token) { *target_token = token; },
&access_token2),
base::BindOnce([](bool* target_on_timeout_invoked,
int request_id) { *target_on_timeout_invoked = true; },
&on_timeout2_invoked));
// Fast-forward to trigger the first request's timeout threshold, but not the
// second.
int time_to_trigger_first_timeout_from_ms =
kTokenFetchTimeoutDelayFromMilliseconds -
delay_before_second_request_from_ms;
task_environment_->FastForwardBy(
base::TimeDelta::FromMilliseconds(time_to_trigger_first_timeout_from_ms));
EXPECT_EQ(access_token1, "");
EXPECT_TRUE(on_timeout1_invoked);
EXPECT_EQ(access_token2, "dummy_value2");
EXPECT_FALSE(on_timeout2_invoked);
// Fast-forward to trigger the second request's timeout threshold.
task_environment_->FastForwardBy(base::TimeDelta::FromMilliseconds(
kTokenFetchTimeoutDelayFromMilliseconds -
time_to_trigger_first_timeout_from_ms));
EXPECT_EQ(access_token2, "");
EXPECT_TRUE(on_timeout2_invoked);
}
} // namespace safe_browsing
......@@ -7,10 +7,6 @@
#include "base/bind.h"
#include "base/memory/weak_ptr.h"
#include "base/metrics/histogram_macros.h"
#include "base/optional.h"
#include "base/task/post_task.h"
#include "base/time/time.h"
#include "build/build_config.h"
#include "components/safe_browsing/core/common/thread_utils.h"
#include "components/signin/public/identity_manager/access_token_fetcher.h"
#include "components/signin/public/identity_manager/access_token_info.h"
......@@ -25,37 +21,31 @@ namespace {
const char kAPIScope[] = "https://www.googleapis.com/auth/chrome-safe-browsing";
#if defined(OS_ANDROID)
const int kTimeoutDelayFromMilliseconds = 50;
#else
const int kTimeoutDelayFromMilliseconds = 1000;
#endif
} // namespace
SafeBrowsingPrimaryAccountTokenFetcher::SafeBrowsingPrimaryAccountTokenFetcher(
signin::IdentityManager* identity_manager)
: identity_manager_(identity_manager),
requests_sent_(0),
weak_ptr_factory_(this) {
DCHECK(CurrentlyOnThread(ThreadID::UI));
}
SafeBrowsingPrimaryAccountTokenFetcher::
~SafeBrowsingPrimaryAccountTokenFetcher() {
for (auto& id_and_callback : callbacks_) {
std::move(id_and_callback.second).Run(std::string());
}
}
~SafeBrowsingPrimaryAccountTokenFetcher() = default;
void SafeBrowsingPrimaryAccountTokenFetcher::Start(
Callback callback) {
DCHECK(CurrentlyOnThread(ThreadID::UI));
const int request_id = requests_sent_;
requests_sent_++;
// NOTE: base::Unretained() is safe below as this object owns
// |token_fetch_tracker_|, and the callback will not be invoked after
// |token_fetch_tracker_| is destroyed.
const int request_id = token_fetch_tracker_.StartTrackingTokenFetch(
std::move(callback),
base::BindOnce(&SafeBrowsingPrimaryAccountTokenFetcher::OnTokenTimeout,
base::Unretained(this)));
CoreAccountId account_id = identity_manager_->GetPrimaryAccountId(
signin::ConsentLevel::kNotRequired);
callbacks_[request_id] = std::move(callback);
token_fetchers_[request_id] =
identity_manager_->CreateAccessTokenFetcherForAccount(
account_id, "safe_browsing_service", {kAPIScope},
......@@ -63,11 +53,6 @@ void SafeBrowsingPrimaryAccountTokenFetcher::Start(
&SafeBrowsingPrimaryAccountTokenFetcher::OnTokenFetched,
weak_ptr_factory_.GetWeakPtr(), request_id),
signin::AccessTokenFetcher::Mode::kImmediate);
base::PostDelayedTask(
FROM_HERE, CreateTaskTraits(ThreadID::UI),
base::BindOnce(&SafeBrowsingPrimaryAccountTokenFetcher::OnTokenTimeout,
weak_ptr_factory_.GetWeakPtr(), request_id),
base::TimeDelta::FromMilliseconds(kTimeoutDelayFromMilliseconds));
}
void SafeBrowsingPrimaryAccountTokenFetcher::OnTokenFetched(
......@@ -77,24 +62,16 @@ void SafeBrowsingPrimaryAccountTokenFetcher::OnTokenFetched(
UMA_HISTOGRAM_ENUMERATION("SafeBrowsing.TokenFetcher.ErrorType",
error.state(), GoogleServiceAuthError::NUM_STATES);
if (error.state() == GoogleServiceAuthError::NONE)
Finish(request_id, access_token_info.token);
token_fetch_tracker_.OnTokenFetchComplete(request_id,
access_token_info.token);
else
Finish(request_id, std::string());
}
token_fetch_tracker_.OnTokenFetchComplete(request_id, std::string());
void SafeBrowsingPrimaryAccountTokenFetcher::OnTokenTimeout(int request_id) {
Finish(request_id, std::string());
token_fetchers_.erase(request_id);
}
void SafeBrowsingPrimaryAccountTokenFetcher::Finish(
int request_id,
const std::string& access_token) {
if (callbacks_.contains(request_id)) {
std::move(callbacks_[request_id]).Run(access_token);
}
void SafeBrowsingPrimaryAccountTokenFetcher::OnTokenTimeout(int request_id) {
token_fetchers_.erase(request_id);
callbacks_.erase(request_id);
}
} // namespace safe_browsing
......@@ -9,6 +9,7 @@
#include "base/containers/flat_map.h"
#include "base/memory/weak_ptr.h"
#include "components/safe_browsing/core/browser/safe_browsing_token_fetch_tracker.h"
#include "components/safe_browsing/core/browser/safe_browsing_token_fetcher.h"
#include "components/signin/public/identity_manager/access_token_info.h"
#include "google_apis/gaia/google_service_auth_error.h"
......@@ -40,20 +41,15 @@ class SafeBrowsingPrimaryAccountTokenFetcher : public SafeBrowsingTokenFetcher {
GoogleServiceAuthError error,
signin::AccessTokenInfo access_token_info);
void OnTokenTimeout(int request_id);
void Finish(int request_id, const std::string& access_token);
// Reference to the identity manager to fetch from.
signin::IdentityManager* identity_manager_;
// The count of requests sent. This is used as an ID for requests.
int requests_sent_;
// Active fetchers, keyed by ID.
base::flat_map<int, std::unique_ptr<signin::AccessTokenFetcher>>
token_fetchers_;
// Active callbacks, keyed by ID.
base::flat_map<int, Callback> callbacks_;
SafeBrowsingTokenFetchTracker token_fetch_tracker_;
base::WeakPtrFactory<SafeBrowsingPrimaryAccountTokenFetcher>
weak_ptr_factory_;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment