Commit 553040d9 authored by Patrick Noland's avatar Patrick Noland Committed by Commit Bot

[Feed] Fix auth header format

Also adds token invalidation when we get a 401 and a test for header
values being set correctly.

Bug: 869132
Change-Id: I61f963bf17a89f58d10583c51eb5234daa263639
Reviewed-on: https://chromium-review.googlesource.com/1155717Reviewed-by: default avatarFilip Gorski <fgorski@chromium.org>
Commit-Queue: Patrick Noland <pnoland@chromium.org>
Cr-Commit-Position: refs/heads/master@{#579204}
parent 2a29004a
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include "base/metrics/histogram_functions.h" #include "base/metrics/histogram_functions.h"
#include "base/metrics/histogram_macros.h" #include "base/metrics/histogram_macros.h"
#include "base/strings/stringprintf.h"
#include "base/values.h" #include "base/values.h"
#include "components/variations/net/variations_http_headers.h" #include "components/variations/net/variations_http_headers.h"
#include "google_apis/gaia/google_service_auth_error.h" #include "google_apis/gaia/google_service_auth_error.h"
...@@ -31,6 +32,8 @@ namespace { ...@@ -31,6 +32,8 @@ namespace {
constexpr char kApiKeyQueryParam[] = "key"; constexpr char kApiKeyQueryParam[] = "key";
constexpr char kAuthenticationScope[] = constexpr char kAuthenticationScope[] =
"https://www.googleapis.com/auth/googlenow"; "https://www.googleapis.com/auth/googlenow";
constexpr char kAuthorizationRequestHeaderFormat[] = "Bearer %s";
constexpr char kContentEncoding[] = "Content-Encoding"; constexpr char kContentEncoding[] = "Content-Encoding";
constexpr char kContentType[] = "application/octet-stream"; constexpr char kContentType[] = "application/octet-stream";
constexpr char kGzip[] = "gzip"; constexpr char kGzip[] = "gzip";
...@@ -59,15 +62,15 @@ class NetworkFetch { ...@@ -59,15 +62,15 @@ class NetworkFetch {
void StartAccessTokenFetch(); void StartAccessTokenFetch();
void AccessTokenFetchFinished(GoogleServiceAuthError error, void AccessTokenFetchFinished(GoogleServiceAuthError error,
identity::AccessTokenInfo access_token_info); identity::AccessTokenInfo access_token_info);
void StartLoader(const std::string& access_token); void StartLoader();
std::unique_ptr<network::SimpleURLLoader> MakeLoader( std::unique_ptr<network::SimpleURLLoader> MakeLoader();
const std::string& access_token);
net::HttpRequestHeaders MakeHeaders(const std::string& auth_header) const; net::HttpRequestHeaders MakeHeaders(const std::string& auth_header) const;
void PopulateRequestBody(network::SimpleURLLoader* loader); void PopulateRequestBody(network::SimpleURLLoader* loader);
void OnSimpleLoaderComplete(std::unique_ptr<std::string> response); void OnSimpleLoaderComplete(std::unique_ptr<std::string> response);
const GURL url_; const GURL url_;
const std::string request_type_; const std::string request_type_;
std::string access_token_;
const std::vector<uint8_t> request_body_; const std::vector<uint8_t> request_body_;
IdentityManager* const identity_manager_; IdentityManager* const identity_manager_;
std::unique_ptr<identity::PrimaryAccountAccessTokenFetcher> token_fetcher_; std::unique_ptr<identity::PrimaryAccountAccessTokenFetcher> token_fetcher_;
...@@ -96,7 +99,7 @@ void NetworkFetch::Start(FeedNetworkingHost::ResponseCallback done_callback) { ...@@ -96,7 +99,7 @@ void NetworkFetch::Start(FeedNetworkingHost::ResponseCallback done_callback) {
done_callback_ = std::move(done_callback); done_callback_ = std::move(done_callback);
if (!identity_manager_->HasPrimaryAccount()) { if (!identity_manager_->HasPrimaryAccount()) {
StartLoader(std::string()); StartLoader();
return; return;
} }
...@@ -119,19 +122,24 @@ void NetworkFetch::AccessTokenFetchFinished( ...@@ -119,19 +122,24 @@ void NetworkFetch::AccessTokenFetchFinished(
identity::AccessTokenInfo access_token_info) { identity::AccessTokenInfo access_token_info) {
UMA_HISTOGRAM_ENUMERATION("ContentSuggestions.Feed.TokenFetchStatus", UMA_HISTOGRAM_ENUMERATION("ContentSuggestions.Feed.TokenFetchStatus",
error.state(), GoogleServiceAuthError::NUM_STATES); error.state(), GoogleServiceAuthError::NUM_STATES);
StartLoader(access_token_info.token); access_token_ = access_token_info.token;
StartLoader();
} }
void NetworkFetch::StartLoader(const std::string& access_token) { void NetworkFetch::StartLoader() {
simple_loader_ = MakeLoader(access_token); simple_loader_ = MakeLoader();
simple_loader_->DownloadToStringOfUnboundedSizeUntilCrashAndDie( simple_loader_->DownloadToStringOfUnboundedSizeUntilCrashAndDie(
loader_factory_, base::BindOnce(&NetworkFetch::OnSimpleLoaderComplete, loader_factory_, base::BindOnce(&NetworkFetch::OnSimpleLoaderComplete,
base::Unretained(this))); base::Unretained(this)));
} }
std::unique_ptr<network::SimpleURLLoader> NetworkFetch::MakeLoader( std::unique_ptr<network::SimpleURLLoader> NetworkFetch::MakeLoader() {
const std::string& access_token) { std::string auth_header =
net::HttpRequestHeaders headers = MakeHeaders(access_token); access_token_.empty()
? std::string()
: base::StringPrintf(kAuthorizationRequestHeaderFormat,
access_token_.c_str());
net::HttpRequestHeaders headers = MakeHeaders(auth_header);
// TODO(pnoland): Add data use measurement once it's supported for simple // TODO(pnoland): Add data use measurement once it's supported for simple
// url loader. // url loader.
net::NetworkTrafficAnnotationTag traffic_annotation = net::NetworkTrafficAnnotationTag traffic_annotation =
...@@ -160,7 +168,7 @@ std::unique_ptr<network::SimpleURLLoader> NetworkFetch::MakeLoader( ...@@ -160,7 +168,7 @@ std::unique_ptr<network::SimpleURLLoader> NetworkFetch::MakeLoader(
} }
})"); })");
GURL url(url_); GURL url(url_);
if (access_token.empty() && !api_key_.empty()) if (access_token_.empty() && !api_key_.empty())
url = net::AppendQueryParameter(url_, kApiKeyQueryParam, api_key_); url = net::AppendQueryParameter(url_, kApiKeyQueryParam, api_key_);
auto resource_request = std::make_unique<network::ResourceRequest>(); auto resource_request = std::make_unique<network::ResourceRequest>();
...@@ -223,6 +231,14 @@ void NetworkFetch::OnSimpleLoaderComplete( ...@@ -223,6 +231,14 @@ void NetworkFetch::OnSimpleLoaderComplete(
if (response) { if (response) {
status_code = simple_loader_->ResponseInfo()->headers->response_code(); status_code = simple_loader_->ResponseInfo()->headers->response_code();
if (status_code == net::HTTP_UNAUTHORIZED) {
OAuth2TokenService::ScopeSet scopes{kAuthenticationScope};
std::string account_id =
identity_manager_->GetPrimaryAccountInfo().account_id;
identity_manager_->RemoveAccessTokenFromCache(account_id, scopes,
access_token_);
}
const uint8_t* begin = reinterpret_cast<const uint8_t*>(response->data()); const uint8_t* begin = reinterpret_cast<const uint8_t*>(response->data());
const uint8_t* end = begin + response->size(); const uint8_t* end = begin + response->size();
response_body.assign(begin, end); response_body.assign(begin, end);
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include "base/strings/string_number_conversions.h" #include "base/strings/string_number_conversions.h"
#include "base/strings/string_split.h" #include "base/strings/string_split.h"
#include "base/test/bind_test_util.h"
#include "base/test/metrics/histogram_tester.h" #include "base/test/metrics/histogram_tester.h"
#include "base/test/test_mock_time_task_runner.h" #include "base/test/test_mock_time_task_runner.h"
#include "net/http/http_response_headers.h" #include "net/http/http_response_headers.h"
...@@ -127,6 +128,10 @@ class FeedNetworkingHostTest : public testing::Test { ...@@ -127,6 +128,10 @@ class FeedNetworkingHostTest : public testing::Test {
response_string); response_string);
} }
network::TestURLLoaderFactory* test_factory() {
return &test_factory_;
}
private: private:
scoped_refptr<base::TestMockTimeTaskRunner> mock_task_runner_; scoped_refptr<base::TestMockTimeTaskRunner> mock_task_runner_;
identity::IdentityTestEnvironment identity_test_env_; identity::IdentityTestEnvironment identity_test_env_;
...@@ -222,8 +227,30 @@ TEST_F(FeedNetworkingHostTest, ShouldReportNonProtocolErrorCodes) { ...@@ -222,8 +227,30 @@ TEST_F(FeedNetworkingHostTest, ShouldReportNonProtocolErrorCodes) {
} }
} }
// TODO(pnoland): Add a test that verifies request headers TEST_F(FeedNetworkingHostTest, ShouldSetHeadersCorrectly) {
// specify gzip. MockResponseDoneCallback done_callback;
net::HttpRequestHeaders headers;
base::RunLoop interceptor_run_loop;
base::HistogramTester histogram_tester;
test_factory()->SetInterceptor(
base::BindLambdaForTesting([&](const network::ResourceRequest& request) {
headers = request.headers;
interceptor_run_loop.Quit();
}));
SendRequestAndRespond("http://foobar.com/feed", "POST", "", "",
net::HTTP_OK, network::URLLoaderCompletionStatus(),
&done_callback);
std::string content_encoding;
std::string authorization;
EXPECT_TRUE(headers.GetHeader("content-encoding", &content_encoding));
EXPECT_TRUE(headers.GetHeader("Authorization", &authorization));
EXPECT_EQ(content_encoding, "gzip");
EXPECT_EQ(authorization, "Bearer access_token");
}
TEST_F(FeedNetworkingHostTest, ShouldReportSizeHistograms) { TEST_F(FeedNetworkingHostTest, ShouldReportSizeHistograms) {
std::string uncompressed_request_string(2048, 'a'); std::string uncompressed_request_string(2048, 'a');
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment