Commit 612af1b6 authored by Lily Houghton's avatar Lily Houghton Committed by Commit Bot

Move predictor HSTS redirection to NetworkService

This CL moves the HSTS redirection call in the predictor's preconnect
method into the NetworkService (creating an internal analogue to
GetHSTSRedirectOnIOThread).

Bug: 821027
Cq-Include-Trybots: luci.chromium.try:linux_mojo
Change-Id: I5070de29aa219d2c038634c584b15c2a714a128f
Reviewed-on: https://chromium-review.googlesource.com/1152082Reviewed-by: default avatarMatt Menke <mmenke@chromium.org>
Commit-Queue: Lily Houghton <lilyhoughton@chromium.org>
Cr-Commit-Position: refs/heads/master@{#579510}
parent 5124b6d3
...@@ -740,14 +740,11 @@ enum SubresourceValue { ...@@ -740,14 +740,11 @@ enum SubresourceValue {
}; };
void Predictor::PreconnectUrlOnIOThread( void Predictor::PreconnectUrlOnIOThread(
const GURL& original_url, const GURL& url,
const GURL& site_for_cookies, const GURL& site_for_cookies,
UrlInfo::ResolutionMotivation motivation, UrlInfo::ResolutionMotivation motivation,
bool allow_credentials, bool allow_credentials,
int count) { int count) {
// Skip the HSTS redirect.
GURL url = GetHSTSRedirectOnIOThread(original_url);
// TODO(csharrison): The observer should only be notified after the null check // TODO(csharrison): The observer should only be notified after the null check
// for the URLRequestContextGetter. The predictor tests should be fixed to // for the URLRequestContextGetter. The predictor tests should be fixed to
// allow for this, as they currently expect a callback with no getter. // allow for this, as they currently expect a callback with no getter.
......
...@@ -384,87 +384,6 @@ class TestPredictorObserver : public PredictorObserver { ...@@ -384,87 +384,6 @@ class TestPredictorObserver : public PredictorObserver {
std::vector<GURL> preconnected_urls_; std::vector<GURL> preconnected_urls_;
}; };
// Tests that preconnects apply the HSTS list.
TEST_F(PredictorTest, HSTSRedirect) {
const GURL kHttpUrl("http://example.com");
const GURL kHttpsUrl("https://example.com");
const base::Time expiry =
base::Time::Now() + base::TimeDelta::FromSeconds(1000);
net::TransportSecurityState state;
state.AddHSTS(kHttpUrl.host(), expiry, false);
Predictor predictor(true);
TestPredictorObserver observer;
predictor.SetObserver(&observer);
predictor.SetTransportSecurityState(&state);
predictor.PreconnectUrl(kHttpUrl, GURL(), UrlInfo::OMNIBOX_MOTIVATED, true,
2);
ASSERT_EQ(1u, observer.preconnected_urls_.size());
EXPECT_EQ(kHttpsUrl, observer.preconnected_urls_[0]);
predictor.Shutdown();
}
// Tests that preconnecting a URL on the HSTS list preconnects the subresources
// for the SSL version.
TEST_F(PredictorTest, HSTSRedirectSubresources) {
const GURL kHttpUrl("http://example.com");
const GURL kHttpsUrl("https://example.com");
const GURL kSubresourceUrl("https://images.example.com");
const double kUseRate = 23.4;
const base::Time expiry =
base::Time::Now() + base::TimeDelta::FromSeconds(1000);
net::TransportSecurityState state;
state.AddHSTS(kHttpUrl.host(), expiry, false);
SimplePredictor predictor(true);
TestPredictorObserver observer;
predictor.SetObserver(&observer);
predictor.SetTransportSecurityState(&state);
std::unique_ptr<base::ListValue> referral_list(NewEmptySerializationList());
AddToSerializedList(
kHttpsUrl, kSubresourceUrl, kUseRate, referral_list.get());
predictor.DeserializeReferrers(*referral_list.get());
predictor.PreconnectUrlAndSubresources(kHttpUrl, GURL());
ASSERT_EQ(2u, observer.preconnected_urls_.size());
EXPECT_EQ(kHttpsUrl, observer.preconnected_urls_[0]);
EXPECT_EQ(kSubresourceUrl, observer.preconnected_urls_[1]);
predictor.Shutdown();
}
TEST_F(PredictorTest, HSTSRedirectLearnedSubresource) {
const GURL kHttpUrl("http://example.com");
const GURL kHttpsUrl("https://example.com");
const GURL kSubresourceUrl("https://images.example.com");
const base::Time expiry =
base::Time::Now() + base::TimeDelta::FromSeconds(1000);
net::TransportSecurityState state;
state.AddHSTS(kHttpUrl.host(), expiry, false);
SimplePredictor predictor(true);
TestPredictorObserver observer;
predictor.SetObserver(&observer);
predictor.SetTransportSecurityState(&state);
// Note that the predictor would also learn the HSTS redirect from kHttpUrl to
// kHttpsUrl during the navigation.
predictor.LearnFromNavigation(kHttpUrl, kSubresourceUrl);
predictor.PreconnectUrlAndSubresources(kHttpUrl, GURL());
ASSERT_EQ(2u, observer.preconnected_urls_.size());
EXPECT_EQ(kHttpsUrl, observer.preconnected_urls_[0]);
EXPECT_EQ(kSubresourceUrl, observer.preconnected_urls_[1]);
predictor.Shutdown();
}
TEST_F(PredictorTest, NoProxyService) { TEST_F(PredictorTest, NoProxyService) {
// Don't actually try to resolve names. // Don't actually try to resolve names.
Predictor::set_max_parallel_resolves(0); Predictor::set_max_parallel_resolves(0);
......
...@@ -801,9 +801,11 @@ void NetworkContext::SetFailingHttpTransactionForTesting( ...@@ -801,9 +801,11 @@ void NetworkContext::SetFailingHttpTransactionForTesting(
} }
void NetworkContext::PreconnectSockets(uint32_t num_streams, void NetworkContext::PreconnectSockets(uint32_t num_streams,
const GURL& url, const GURL& original_url,
int32_t load_flags, int32_t load_flags,
bool privacy_mode_enabled) { bool privacy_mode_enabled) {
GURL url = GetHSTSRedirect(original_url);
// |PreconnectSockets| may receive arguments from the renderer, which is not // |PreconnectSockets| may receive arguments from the renderer, which is not
// guaranteed to validate them. // guaranteed to validate them.
if (num_streams == 0) if (num_streams == 0)
...@@ -1223,4 +1225,20 @@ URLRequestContextOwner NetworkContext::MakeURLRequestContext() { ...@@ -1223,4 +1225,20 @@ URLRequestContextOwner NetworkContext::MakeURLRequestContext() {
return result; return result;
} }
GURL NetworkContext::GetHSTSRedirect(const GURL& original_url) {
// TODO(lilyhoughton) This needs to be gotten rid of once explicit
// construction with a URLRequestContext is no longer supported.
if (!url_request_context_->transport_security_state() ||
!original_url.SchemeIs("http") ||
!url_request_context_->transport_security_state()->ShouldUpgradeToSSL(
original_url.host())) {
return original_url;
}
url::Replacements<char> replacements;
const char kNewScheme[] = "https";
replacements.SetScheme(kNewScheme, url::Component(0, strlen(kNewScheme)));
return original_url.ReplaceComponents(replacements);
}
} // namespace network } // namespace network
...@@ -241,6 +241,8 @@ class COMPONENT_EXPORT(NETWORK_SERVICE) NetworkContext ...@@ -241,6 +241,8 @@ class COMPONENT_EXPORT(NETWORK_SERVICE) NetworkContext
URLRequestContextOwner MakeURLRequestContext(); URLRequestContextOwner MakeURLRequestContext();
GURL GetHSTSRedirect(const GURL& original_url);
NetworkService* const network_service_; NetworkService* const network_service_;
std::unique_ptr<ResourceScheduler> resource_scheduler_; std::unique_ptr<ResourceScheduler> resource_scheduler_;
......
...@@ -59,6 +59,7 @@ ...@@ -59,6 +59,7 @@
#include "net/proxy_resolution/proxy_config.h" #include "net/proxy_resolution/proxy_config.h"
#include "net/proxy_resolution/proxy_info.h" #include "net/proxy_resolution/proxy_info.h"
#include "net/proxy_resolution/proxy_resolution_service.h" #include "net/proxy_resolution/proxy_resolution_service.h"
#include "net/socket/ssl_client_socket_pool.h"
#include "net/socket/transport_client_socket_pool.h" #include "net/socket/transport_client_socket_pool.h"
#include "net/ssl/channel_id_service.h" #include "net/ssl/channel_id_service.h"
#include "net/ssl/channel_id_store.h" #include "net/ssl/channel_id_store.h"
...@@ -209,6 +210,39 @@ class NetworkContextTest : public testing::Test, ...@@ -209,6 +210,39 @@ class NetworkContextTest : public testing::Test,
return value; return value;
} }
// Looks up a value with the given name from the NetworkContext's
// SSLSocketPool info dictionary.
int GetSSLSocketPoolInfo(NetworkContext* context, base::StringPiece name) {
int value;
context->url_request_context()
->http_transaction_factory()
->GetSession()
->GetSSLSocketPool(
net::HttpNetworkSession::SocketPoolType::NORMAL_SOCKET_POOL)
->GetInfoAsValue("", "", false)
->GetInteger(name, &value);
return value;
}
int GetSocketCount(NetworkContext* network_context) {
return GetSocketPoolInfo(network_context, "idle_socket_count") +
GetSocketPoolInfo(network_context, "connecting_socket_count") +
GetSocketPoolInfo(network_context, "handed_out_socket_count");
}
int GetSSLSocketCount(NetworkContext* network_context) {
return GetSSLSocketPoolInfo(network_context, "idle_socket_count") +
GetSSLSocketPoolInfo(network_context, "connecting_socket_count") +
GetSSLSocketPoolInfo(network_context, "handed_out_socket_count");
}
GURL GetHttpUrlFromHttps(const GURL& https_url) {
url::Replacements<char> replacements;
const char http[] = "http";
replacements.SetScheme(http, url::Component(0, strlen(http)));
return https_url.ReplaceComponents(replacements);
}
protected: protected:
base::test::ScopedTaskEnvironment scoped_task_environment_; base::test::ScopedTaskEnvironment scoped_task_environment_;
std::unique_ptr<NetworkService> network_service_; std::unique_ptr<NetworkService> network_service_;
...@@ -2499,7 +2533,9 @@ class ConnectionListener ...@@ -2499,7 +2533,9 @@ class ConnectionListener
: public net::test_server::EmbeddedTestServerConnectionListener { : public net::test_server::EmbeddedTestServerConnectionListener {
public: public:
ConnectionListener() ConnectionListener()
: task_runner_(base::ThreadTaskRunnerHandle::Get()), : total_sockets_seen_(0),
total_sockets_waited_for_(0),
task_runner_(base::ThreadTaskRunnerHandle::Get()),
num_accepted_connections_needed_(0), num_accepted_connections_needed_(0),
num_accepted_connections_loop_(nullptr) {} num_accepted_connections_loop_(nullptr) {}
...@@ -2513,6 +2549,7 @@ class ConnectionListener ...@@ -2513,6 +2549,7 @@ class ConnectionListener
EXPECT_TRUE(sockets_.find(socket) == sockets_.end()); EXPECT_TRUE(sockets_.find(socket) == sockets_.end());
sockets_[socket] = SOCKET_ACCEPTED; sockets_[socket] = SOCKET_ACCEPTED;
total_sockets_seen_++;
CheckAccepted(); CheckAccepted();
} }
...@@ -2529,7 +2566,7 @@ class ConnectionListener ...@@ -2529,7 +2566,7 @@ class ConnectionListener
base::RunLoop run_loop; base::RunLoop run_loop;
{ {
base::AutoLock lock(lock_); base::AutoLock lock(lock_);
EXPECT_GE(num_connections, sockets_.size()); EXPECT_GE(num_connections, sockets_.size() - total_sockets_waited_for_);
num_accepted_connections_loop_ = &run_loop; num_accepted_connections_loop_ = &run_loop;
num_accepted_connections_needed_ = num_connections; num_accepted_connections_needed_ = num_connections;
CheckAccepted(); CheckAccepted();
...@@ -2541,7 +2578,8 @@ class ConnectionListener ...@@ -2541,7 +2578,8 @@ class ConnectionListener
// Grab the mutex again and make sure that the number of accepted sockets is // Grab the mutex again and make sure that the number of accepted sockets is
// indeed |num_connections|. // indeed |num_connections|.
base::AutoLock lock(lock_); base::AutoLock lock(lock_);
EXPECT_EQ(num_connections, sockets_.size()); total_sockets_waited_for_ += num_connections;
EXPECT_EQ(total_sockets_seen_, total_sockets_waited_for_);
} }
// Helper function to stop the waiting for sockets to be accepted for // Helper function to stop the waiting for sockets to be accepted for
...@@ -2555,7 +2593,8 @@ class ConnectionListener ...@@ -2555,7 +2593,8 @@ class ConnectionListener
DCHECK(num_accepted_connections_loop_ || DCHECK(num_accepted_connections_loop_ ||
num_accepted_connections_needed_ == 0); num_accepted_connections_needed_ == 0);
if (!num_accepted_connections_loop_ || if (!num_accepted_connections_loop_ ||
num_accepted_connections_needed_ != sockets_.size()) { num_accepted_connections_needed_ !=
sockets_.size() - total_sockets_waited_for_) {
return; return;
} }
...@@ -2580,6 +2619,9 @@ class ConnectionListener ...@@ -2580,6 +2619,9 @@ class ConnectionListener
return address.port(); return address.port();
} }
int total_sockets_seen_;
int total_sockets_waited_for_;
enum SocketStatus { SOCKET_ACCEPTED, SOCKET_READ_FROM }; enum SocketStatus { SOCKET_ACCEPTED, SOCKET_READ_FROM };
scoped_refptr<base::SingleThreadTaskRunner> task_runner_; scoped_refptr<base::SingleThreadTaskRunner> task_runner_;
...@@ -2613,6 +2655,39 @@ TEST_F(NetworkContextTest, PreconnectOne) { ...@@ -2613,6 +2655,39 @@ TEST_F(NetworkContextTest, PreconnectOne) {
connection_listener.WaitForAcceptedConnections(1u); connection_listener.WaitForAcceptedConnections(1u);
} }
TEST_F(NetworkContextTest, PreconnectHSTS) {
std::unique_ptr<NetworkContext> network_context =
CreateContextWithParams(CreateContextParams());
ConnectionListener connection_listener;
net::EmbeddedTestServer test_server(net::EmbeddedTestServer::TYPE_HTTPS);
test_server.SetConnectionListener(&connection_listener);
ASSERT_TRUE(test_server.Start());
const GURL server_http_url = GetHttpUrlFromHttps(test_server.base_url());
network_context->PreconnectSockets(1, server_http_url, net::LOAD_NORMAL,
true);
connection_listener.WaitForAcceptedConnections(1u);
int num_sockets = GetSocketCount(network_context.get());
EXPECT_EQ(num_sockets, 1);
int num_ssl_sockets = GetSSLSocketCount(network_context.get());
EXPECT_EQ(num_ssl_sockets, 0);
const base::Time expiry =
base::Time::Now() + base::TimeDelta::FromSeconds(1000);
network_context->url_request_context()->transport_security_state()->AddHSTS(
server_http_url.host(), expiry, false);
network_context->PreconnectSockets(1, server_http_url, net::LOAD_NORMAL,
true);
connection_listener.WaitForAcceptedConnections(1u);
num_sockets = GetSocketCount(network_context.get());
EXPECT_EQ(num_sockets, 2);
num_ssl_sockets = GetSSLSocketCount(network_context.get());
EXPECT_EQ(num_ssl_sockets, 1);
}
TEST_F(NetworkContextTest, PreconnectZero) { TEST_F(NetworkContextTest, PreconnectZero) {
std::unique_ptr<NetworkContext> network_context = std::unique_ptr<NetworkContext> network_context =
CreateContextWithParams(CreateContextParams()); CreateContextWithParams(CreateContextParams());
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment