Commit 2a0ff283 authored by zelidrag@chromium.org's avatar zelidrag@chromium.org

Handling of multiple concurrent requests from different clients in OAuth2TokenService

BUG=268937
TEST=OAuth2TokenServiceTest.SameScopesRequestedForDifferentClients
TBR=tim (for chrome/browser/sync)

Review URL: https://chromiumcodereview.appspot.com/22581003

git-svn-id: svn://svn.chromium.org/chrome/trunk/src@221303 0039d316-1c4b-4281-b951-d872f2087c98
parent 64c15a3b
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
#include "content/public/test/test_browser_thread_bundle.h" #include "content/public/test/test_browser_thread_bundle.h"
#include "google_apis/gaia/gaia_constants.h" #include "google_apis/gaia/gaia_constants.h"
#include "google_apis/gaia/google_service_auth_error.h" #include "google_apis/gaia/google_service_auth_error.h"
#include "google_apis/gaia/oauth2_access_token_fetcher.h"
#include "net/http/http_status_code.h" #include "net/http/http_status_code.h"
#include "net/url_request/test_url_fetcher_factory.h" #include "net/url_request/test_url_fetcher_factory.h"
#include "net/url_request/url_request_context_getter.h" #include "net/url_request/url_request_context_getter.h"
...@@ -142,6 +143,8 @@ class FakeAndroidProfileOAuth2TokenService ...@@ -142,6 +143,8 @@ class FakeAndroidProfileOAuth2TokenService
#endif #endif
} // namespace
class UserPolicySigninServiceTest : public testing::Test { class UserPolicySigninServiceTest : public testing::Test {
public: public:
UserPolicySigninServiceTest() UserPolicySigninServiceTest()
...@@ -219,6 +222,7 @@ class UserPolicySigninServiceTest : public testing::Test { ...@@ -219,6 +222,7 @@ class UserPolicySigninServiceTest : public testing::Test {
testing_browser_process->SetBrowserPolicyConnector(NULL); testing_browser_process->SetBrowserPolicyConnector(NULL);
base::RunLoop run_loop; base::RunLoop run_loop;
run_loop.RunUntilIdle(); run_loop.RunUntilIdle();
ResetLastFetcherId();
} }
#if defined(OS_ANDROID) #if defined(OS_ANDROID)
...@@ -348,6 +352,10 @@ class UserPolicySigninServiceTest : public testing::Test { ...@@ -348,6 +352,10 @@ class UserPolicySigninServiceTest : public testing::Test {
Mock::VerifyAndClearExpectations(this); Mock::VerifyAndClearExpectations(this);
} }
void ResetLastFetcherId() {
OAuth2AccessTokenFetcher::ResetLastFetcherIdForTest();
}
scoped_ptr<TestingProfile> profile_; scoped_ptr<TestingProfile> profile_;
// Weak pointer to a MockUserCloudPolicyStore - lifetime is managed by the // Weak pointer to a MockUserCloudPolicyStore - lifetime is managed by the
// UserCloudPolicyManager. // UserCloudPolicyManager.
...@@ -786,6 +794,7 @@ TEST_F(UserPolicySigninServiceTest, SignOutThenSignInAgain) { ...@@ -786,6 +794,7 @@ TEST_F(UserPolicySigninServiceTest, SignOutThenSignInAgain) {
ASSERT_FALSE(manager_->core()->service()); ASSERT_FALSE(manager_->core()->service());
// Now sign in again. // Now sign in again.
ResetLastFetcherId();
ASSERT_NO_FATAL_FAILURE(TestSuccessfulSignin()); ASSERT_NO_FATAL_FAILURE(TestSuccessfulSignin());
} }
...@@ -843,6 +852,4 @@ TEST_F(UserPolicySigninServiceTest, PolicyFetchFailureDisableManagement) { ...@@ -843,6 +852,4 @@ TEST_F(UserPolicySigninServiceTest, PolicyFetchFailureDisableManagement) {
#endif #endif
} }
} // namespace
} // namespace policy } // namespace policy
...@@ -168,13 +168,15 @@ GoogleServiceAuthError ProfileOAuth2TokenService::GetAuthStatus() const { ...@@ -168,13 +168,15 @@ GoogleServiceAuthError ProfileOAuth2TokenService::GetAuthStatus() const {
} }
void ProfileOAuth2TokenService::RegisterCacheEntry( void ProfileOAuth2TokenService::RegisterCacheEntry(
const std::string& client_id,
const std::string& refresh_token, const std::string& refresh_token,
const ScopeSet& scopes, const ScopeSet& scopes,
const std::string& access_token, const std::string& access_token,
const base::Time& expiration_date) { const base::Time& expiration_date) {
if (ShouldCacheForRefreshToken(TokenServiceFactory::GetForProfile(profile_), if (ShouldCacheForRefreshToken(TokenServiceFactory::GetForProfile(profile_),
refresh_token)) { refresh_token)) {
OAuth2TokenService::RegisterCacheEntry(refresh_token, OAuth2TokenService::RegisterCacheEntry(client_id,
refresh_token,
scopes, scopes,
access_token, access_token,
expiration_date); expiration_date);
......
...@@ -103,7 +103,8 @@ class ProfileOAuth2TokenService : public OAuth2TokenService, ...@@ -103,7 +103,8 @@ class ProfileOAuth2TokenService : public OAuth2TokenService,
// logs back in with a different account, then any in-flight token // logs back in with a different account, then any in-flight token
// fetches will be for the old account's refresh token. Therefore // fetches will be for the old account's refresh token. Therefore
// when they come back, they shouldn't be cached. // when they come back, they shouldn't be cached.
virtual void RegisterCacheEntry(const std::string& refresh_token, virtual void RegisterCacheEntry(const std::string& client_id,
const std::string& refresh_token,
const ScopeSet& scopes, const ScopeSet& scopes,
const std::string& access_token, const std::string& access_token,
const base::Time& expiration_date) OVERRIDE; const base::Time& expiration_date) OVERRIDE;
......
...@@ -331,7 +331,7 @@ TEST_F(ProfileOAuth2TokenServiceTest, TokenServiceUpdateClearsCache) { ...@@ -331,7 +331,7 @@ TEST_F(ProfileOAuth2TokenServiceTest, TokenServiceUpdateClearsCache) {
request = oauth2_service_->StartRequest(scope_list, &consumer_); request = oauth2_service_->StartRequest(scope_list, &consumer_);
base::RunLoop().RunUntilIdle(); base::RunLoop().RunUntilIdle();
fetcher = factory_.GetFetcherByID(0); fetcher = factory_.GetFetcherByID(1);
fetcher->set_response_code(net::HTTP_OK); fetcher->set_response_code(net::HTTP_OK);
fetcher->SetResponseString(GetValidTokenResponse("another token", 3600)); fetcher->SetResponseString(GetValidTokenResponse("another token", 3600));
fetcher->delegate()->OnURLFetchComplete(fetcher); fetcher->delegate()->OnURLFetchComplete(fetcher);
......
...@@ -319,8 +319,11 @@ scoped_ptr<OAuth2TokenService::Request> FakeOAuth2TokenService::StartRequest( ...@@ -319,8 +319,11 @@ scoped_ptr<OAuth2TokenService::Request> FakeOAuth2TokenService::StartRequest(
OAuth2TokenService::Consumer* consumer) { OAuth2TokenService::Consumer* consumer) {
// Ensure token in question is cached and never expires. Request will succeed // Ensure token in question is cached and never expires. Request will succeed
// without network IO. // without network IO.
RegisterCacheEntry(GetRefreshToken(), scopes, "access_token", RegisterCacheEntry("test_client_id",
base::Time::Max()); GetRefreshToken(),
scopes,
"access_token",
base::Time::Max());
return ProfileOAuth2TokenService::StartRequest(scopes, consumer); return ProfileOAuth2TokenService::StartRequest(scopes, consumer);
} }
......
...@@ -56,13 +56,14 @@ static GoogleServiceAuthError CreateAuthError(URLRequestStatus status) { ...@@ -56,13 +56,14 @@ static GoogleServiceAuthError CreateAuthError(URLRequestStatus status) {
} }
} }
static URLFetcher* CreateFetcher(URLRequestContextGetter* getter, static URLFetcher* CreateFetcher(int id,
URLRequestContextGetter* getter,
const GURL& url, const GURL& url,
const std::string& body, const std::string& body,
URLFetcherDelegate* delegate) { URLFetcherDelegate* delegate) {
bool empty_body = body.empty(); bool empty_body = body.empty();
URLFetcher* result = net::URLFetcher::Create( URLFetcher* result = net::URLFetcher::Create(
0, url, id, url,
empty_body ? URLFetcher::GET : URLFetcher::POST, empty_body ? URLFetcher::GET : URLFetcher::POST,
delegate); delegate);
...@@ -82,6 +83,8 @@ static URLFetcher* CreateFetcher(URLRequestContextGetter* getter, ...@@ -82,6 +83,8 @@ static URLFetcher* CreateFetcher(URLRequestContextGetter* getter,
} }
} // namespace } // namespace
int OAuth2AccessTokenFetcher::last_fetcher_id_ = 0;
OAuth2AccessTokenFetcher::OAuth2AccessTokenFetcher( OAuth2AccessTokenFetcher::OAuth2AccessTokenFetcher(
OAuth2AccessTokenConsumer* consumer, OAuth2AccessTokenConsumer* consumer,
URLRequestContextGetter* getter) URLRequestContextGetter* getter)
...@@ -110,6 +113,7 @@ void OAuth2AccessTokenFetcher::StartGetAccessToken() { ...@@ -110,6 +113,7 @@ void OAuth2AccessTokenFetcher::StartGetAccessToken() {
CHECK_EQ(INITIAL, state_); CHECK_EQ(INITIAL, state_);
state_ = GET_ACCESS_TOKEN_STARTED; state_ = GET_ACCESS_TOKEN_STARTED;
fetcher_.reset(CreateFetcher( fetcher_.reset(CreateFetcher(
last_fetcher_id_++,
getter_, getter_,
MakeGetAccessTokenUrl(), MakeGetAccessTokenUrl(),
MakeGetAccessTokenBody( MakeGetAccessTokenBody(
...@@ -231,3 +235,8 @@ bool OAuth2AccessTokenFetcher::ParseGetAccessTokenResponse( ...@@ -231,3 +235,8 @@ bool OAuth2AccessTokenFetcher::ParseGetAccessTokenResponse(
return dict->GetString(kAccessTokenKey, access_token) && return dict->GetString(kAccessTokenKey, access_token) &&
dict->GetInteger(kExpiresInKey, expires_in); dict->GetInteger(kExpiresInKey, expires_in);
} }
// static
void OAuth2AccessTokenFetcher::ResetLastFetcherIdForTest() {
last_fetcher_id_ = 0;
}
...@@ -26,6 +26,10 @@ class URLRequestContextGetter; ...@@ -26,6 +26,10 @@ class URLRequestContextGetter;
class URLRequestStatus; class URLRequestStatus;
} }
namespace policy {
class UserPolicySigninServiceTest;
}
// Abstracts the details to get OAuth2 access token token from // Abstracts the details to get OAuth2 access token token from
// OAuth2 refresh token. // OAuth2 refresh token.
// See "Using the Refresh Token" section in: // See "Using the Refresh Token" section in:
...@@ -94,6 +98,9 @@ class OAuth2AccessTokenFetcher : public net::URLFetcherDelegate { ...@@ -94,6 +98,9 @@ class OAuth2AccessTokenFetcher : public net::URLFetcherDelegate {
std::string* access_token, std::string* access_token,
int* expires_in); int* expires_in);
// Resets |last_fetcher_id_| to 0.
static void ResetLastFetcherIdForTest();
// State that is set during construction. // State that is set during construction.
OAuth2AccessTokenConsumer* const consumer_; OAuth2AccessTokenConsumer* const consumer_;
net::URLRequestContextGetter* const getter_; net::URLRequestContextGetter* const getter_;
...@@ -106,7 +113,12 @@ class OAuth2AccessTokenFetcher : public net::URLFetcherDelegate { ...@@ -106,7 +113,12 @@ class OAuth2AccessTokenFetcher : public net::URLFetcherDelegate {
std::string refresh_token_; std::string refresh_token_;
std::vector<std::string> scopes_; std::vector<std::string> scopes_;
// The last fetcher id.
static int last_fetcher_id_;
friend class OAuth2AccessTokenFetcherTest; friend class OAuth2AccessTokenFetcherTest;
friend class OAuth2TokenServiceTest;
friend class policy::UserPolicySigninServiceTest;
FRIEND_TEST_ALL_PREFIXES(OAuth2AccessTokenFetcherTest, FRIEND_TEST_ALL_PREFIXES(OAuth2AccessTokenFetcherTest,
ParseGetAccessTokenResponse); ParseGetAccessTokenResponse);
FRIEND_TEST_ALL_PREFIXES(OAuth2AccessTokenFetcherTest, FRIEND_TEST_ALL_PREFIXES(OAuth2AccessTokenFetcherTest,
......
This diff is collapsed.
...@@ -10,12 +10,16 @@ ...@@ -10,12 +10,16 @@
#include <string> #include <string>
#include "base/basictypes.h" #include "base/basictypes.h"
#include "base/gtest_prod_util.h"
#include "base/memory/scoped_ptr.h" #include "base/memory/scoped_ptr.h"
#include "base/memory/weak_ptr.h" #include "base/memory/weak_ptr.h"
#include "base/observer_list.h" #include "base/observer_list.h"
#include "base/threading/non_thread_safe.h" #include "base/threading/non_thread_safe.h"
#include "base/time/time.h" #include "base/time/time.h"
#include "base/timer/timer.h"
#include "google_apis/gaia/google_service_auth_error.h" #include "google_apis/gaia/google_service_auth_error.h"
#include "google_apis/gaia/oauth2_access_token_consumer.h"
#include "google_apis/gaia/oauth2_access_token_fetcher.h"
namespace net { namespace net {
class URLRequestContextGetter; class URLRequestContextGetter;
...@@ -145,8 +149,23 @@ class OAuth2TokenService : public base::NonThreadSafe { ...@@ -145,8 +149,23 @@ class OAuth2TokenService : public base::NonThreadSafe {
// Return the current number of entries in the cache. // Return the current number of entries in the cache.
int cache_size_for_testing() const; int cache_size_for_testing() const;
void set_max_authorization_token_fetch_retries_for_testing(int max_retries); void set_max_authorization_token_fetch_retries_for_testing(int max_retries);
// Returns the current number of pending fetchers matching given params.
size_t GetNumPendingRequestsForTesting(
const std::string& client_id,
const std::string& refresh_token,
const ScopeSet& scopes) const;
protected: protected:
struct ClientScopeSet {
ClientScopeSet(const std::string& client_id,
const ScopeSet& scopes);
~ClientScopeSet();
bool operator<(const ClientScopeSet& set) const;
std::string client_id;
ScopeSet scopes;
};
// Implements a cancelable |OAuth2TokenService::Request|, which should be // Implements a cancelable |OAuth2TokenService::Request|, which should be
// operated on the UI thread. // operated on the UI thread.
// TODO(davidroche): move this out of header file. // TODO(davidroche): move this out of header file.
...@@ -178,19 +197,20 @@ class OAuth2TokenService : public base::NonThreadSafe { ...@@ -178,19 +197,20 @@ class OAuth2TokenService : public base::NonThreadSafe {
// Add a new entry to the cache. // Add a new entry to the cache.
// Subclasses can override if there are implementation-specific reasons // Subclasses can override if there are implementation-specific reasons
// that an access token should ever not be cached. // that an access token should ever not be cached.
virtual void RegisterCacheEntry(const std::string& refresh_token, virtual void RegisterCacheEntry(const std::string& client_id,
const std::string& refresh_token,
const ScopeSet& scopes, const ScopeSet& scopes,
const std::string& access_token, const std::string& access_token,
const base::Time& expiration_date); const base::Time& expiration_date);
// Returns true if GetCacheEntry would return a valid cache entry for the // Returns true if GetCacheEntry would return a valid cache entry for the
// given scopes. // given scopes.
bool HasCacheEntry(const ScopeSet& scopes); bool HasCacheEntry(const ClientScopeSet& client_scopes);
// Posts a task to fire the Consumer callback with the cached token. Must // Posts a task to fire the Consumer callback with the cached token. Must
// Must only be called if HasCacheEntry() returns true. // Must only be called if HasCacheEntry() returns true.
void StartCacheLookupRequest(RequestImpl* request, void StartCacheLookupRequest(RequestImpl* request,
const ScopeSet& scopes, const ClientScopeSet& client_scopes,
Consumer* consumer); Consumer* consumer);
// Clears the internal token cache. // Clears the internal token cache.
...@@ -208,10 +228,6 @@ class OAuth2TokenService : public base::NonThreadSafe { ...@@ -208,10 +228,6 @@ class OAuth2TokenService : public base::NonThreadSafe {
void FireRefreshTokensLoaded(); void FireRefreshTokensLoaded();
void FireRefreshTokensCleared(); void FireRefreshTokensCleared();
// Derived classes must provide a request context used for fetching access
// tokens with the |StartRequest| method.
virtual net::URLRequestContextGetter* GetRequestContext() = 0;
// Fetches an OAuth token for the specified client/scopes. Virtual so it can // Fetches an OAuth token for the specified client/scopes. Virtual so it can
// be overridden for tests and for platform-specific behavior on Android. // be overridden for tests and for platform-specific behavior on Android.
virtual void FetchOAuth2Token(RequestImpl* request, virtual void FetchOAuth2Token(RequestImpl* request,
...@@ -219,13 +235,32 @@ class OAuth2TokenService : public base::NonThreadSafe { ...@@ -219,13 +235,32 @@ class OAuth2TokenService : public base::NonThreadSafe {
const std::string& client_id, const std::string& client_id,
const std::string& client_secret, const std::string& client_secret,
const ScopeSet& scopes); const ScopeSet& scopes);
private: private:
// Class that fetches an OAuth2 access token for a given set of scopes and
// OAuth2 refresh token.
class Fetcher; class Fetcher;
friend class Fetcher; friend class Fetcher;
// The parameters used to fetch an OAuth2 access token.
struct FetchParameters {
FetchParameters(const std::string& client_id,
const std::string& refresh_token,
const ScopeSet& scopes);
~FetchParameters();
bool operator<(const FetchParameters& params) const;
// OAuth2 client id.
std::string client_id;
// Refresh token used for minting access tokens within this request.
std::string refresh_token;
// URL scopes for the requested access token.
ScopeSet scopes;
};
typedef std::map<FetchParameters, Fetcher*> PendingFetcherMap;
// Derived classes must provide a request context used for fetching access
// tokens with the |StartRequest| method.
virtual net::URLRequestContextGetter* GetRequestContext() = 0;
// Struct that contains the information of an OAuth2 access token. // Struct that contains the information of an OAuth2 access token.
struct CacheEntry { struct CacheEntry {
std::string access_token; std::string access_token;
...@@ -244,14 +279,14 @@ class OAuth2TokenService : public base::NonThreadSafe { ...@@ -244,14 +279,14 @@ class OAuth2TokenService : public base::NonThreadSafe {
// Returns a currently valid OAuth2 access token for the given set of scopes, // Returns a currently valid OAuth2 access token for the given set of scopes,
// or NULL if none have been cached. Note the user of this method should // or NULL if none have been cached. Note the user of this method should
// ensure no entry with the same |scopes| is added before the usage of the // ensure no entry with the same |client_scopes| is added before the usage of
// returned entry is done. // the returned entry is done.
const CacheEntry* GetCacheEntry(const ScopeSet& scopes); const CacheEntry* GetCacheEntry(const ClientScopeSet& client_scopes);
// Removes an access token for the given set of scopes from the cache. // Removes an access token for the given set of scopes from the cache.
// Returns true if the entry was removed, otherwise false. // Returns true if the entry was removed, otherwise false.
bool RemoveCacheEntry(const OAuth2TokenService::ScopeSet& scopes, bool RemoveCacheEntry(const ClientScopeSet& client_scopes,
const std::string& token_to_remove); const std::string& token_to_remove);
...@@ -262,15 +297,12 @@ class OAuth2TokenService : public base::NonThreadSafe { ...@@ -262,15 +297,12 @@ class OAuth2TokenService : public base::NonThreadSafe {
void CancelFetchers(std::vector<Fetcher*> fetchers_to_cancel); void CancelFetchers(std::vector<Fetcher*> fetchers_to_cancel);
// The cache of currently valid tokens. // The cache of currently valid tokens.
typedef std::map<ScopeSet, CacheEntry> TokenCache; typedef std::map<ClientScopeSet, CacheEntry> TokenCache;
TokenCache token_cache_; TokenCache token_cache_;
// The parameters (refresh token and scope set) used to fetch an OAuth2 access
// token.
typedef std::pair<std::string, ScopeSet> FetchParameters;
// A map from fetch parameters to a fetcher that is fetching an OAuth2 access // A map from fetch parameters to a fetcher that is fetching an OAuth2 access
// token using these parameters. // token using these parameters.
std::map<FetchParameters, Fetcher*> pending_fetchers_; PendingFetcherMap pending_fetchers_;
// List of observers to notify when token availability changes. // List of observers to notify when token availability changes.
// Makes sure list is empty on destruction. // Makes sure list is empty on destruction.
...@@ -279,6 +311,11 @@ class OAuth2TokenService : public base::NonThreadSafe { ...@@ -279,6 +311,11 @@ class OAuth2TokenService : public base::NonThreadSafe {
// Maximum number of retries in fetching an OAuth2 access token. // Maximum number of retries in fetching an OAuth2 access token.
static int max_fetch_retry_num_; static int max_fetch_retry_num_;
FRIEND_TEST_ALL_PREFIXES(OAuth2TokenServiceTest, ClientScopeSetOrderTest);
FRIEND_TEST_ALL_PREFIXES(OAuth2TokenServiceTest, FetchParametersOrderTest);
FRIEND_TEST_ALL_PREFIXES(OAuth2TokenServiceTest,
SameScopesRequestedForDifferentClients);
DISALLOW_COPY_AND_ASSIGN(OAuth2TokenService); DISALLOW_COPY_AND_ASSIGN(OAuth2TokenService);
}; };
......
...@@ -275,6 +275,10 @@ void TestURLFetcherFactory::RemoveFetcherFromMap(int id) { ...@@ -275,6 +275,10 @@ void TestURLFetcherFactory::RemoveFetcherFromMap(int id) {
fetchers_.erase(i); fetchers_.erase(i);
} }
size_t TestURLFetcherFactory::GetFetcherCount() const {
return fetchers_.size();
}
void TestURLFetcherFactory::SetDelegateForTests( void TestURLFetcherFactory::SetDelegateForTests(
TestURLFetcherDelegateForTests* delegate_for_tests) { TestURLFetcherDelegateForTests* delegate_for_tests) {
delegate_for_tests_ = delegate_for_tests; delegate_for_tests_ = delegate_for_tests;
......
...@@ -241,6 +241,7 @@ class TestURLFetcherFactory : public URLFetcherFactory, ...@@ -241,6 +241,7 @@ class TestURLFetcherFactory : public URLFetcherFactory,
URLFetcherDelegate* d) OVERRIDE; URLFetcherDelegate* d) OVERRIDE;
TestURLFetcher* GetFetcherByID(int id) const; TestURLFetcher* GetFetcherByID(int id) const;
void RemoveFetcherFromMap(int id); void RemoveFetcherFromMap(int id);
size_t GetFetcherCount() const;
void SetDelegateForTests(TestURLFetcherDelegateForTests* delegate_for_tests); void SetDelegateForTests(TestURLFetcherDelegateForTests* delegate_for_tests);
void set_remove_fetcher_on_delete(bool remove_fetcher_on_delete) { void set_remove_fetcher_on_delete(bool remove_fetcher_on_delete) {
remove_fetcher_on_delete_ = remove_fetcher_on_delete; remove_fetcher_on_delete_ = remove_fetcher_on_delete;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment