Commit 612af1b6 authored by Lily Houghton's avatar Lily Houghton Committed by Commit Bot

Move predictor HSTS redirection to NetworkService

This CL moves the HSTS redirection call in the predictor's preconnect
method into the NetworkService (creating an internal analogue to
GetHSTSRedirectOnIOThread).

Bug: 821027
Cq-Include-Trybots: luci.chromium.try:linux_mojo
Change-Id: I5070de29aa219d2c038634c584b15c2a714a128f
Reviewed-on: https://chromium-review.googlesource.com/1152082Reviewed-by: default avatarMatt Menke <mmenke@chromium.org>
Commit-Queue: Lily Houghton <lilyhoughton@chromium.org>
Cr-Commit-Position: refs/heads/master@{#579510}
parent 5124b6d3
......@@ -740,14 +740,11 @@ enum SubresourceValue {
};
void Predictor::PreconnectUrlOnIOThread(
const GURL& original_url,
const GURL& url,
const GURL& site_for_cookies,
UrlInfo::ResolutionMotivation motivation,
bool allow_credentials,
int count) {
// Skip the HSTS redirect.
GURL url = GetHSTSRedirectOnIOThread(original_url);
// TODO(csharrison): The observer should only be notified after the null check
// for the URLRequestContextGetter. The predictor tests should be fixed to
// allow for this, as they currently expect a callback with no getter.
......
......@@ -384,87 +384,6 @@ class TestPredictorObserver : public PredictorObserver {
std::vector<GURL> preconnected_urls_;
};
// Tests that preconnects apply the HSTS list.
TEST_F(PredictorTest, HSTSRedirect) {
const GURL kHttpUrl("http://example.com");
const GURL kHttpsUrl("https://example.com");
const base::Time expiry =
base::Time::Now() + base::TimeDelta::FromSeconds(1000);
net::TransportSecurityState state;
state.AddHSTS(kHttpUrl.host(), expiry, false);
Predictor predictor(true);
TestPredictorObserver observer;
predictor.SetObserver(&observer);
predictor.SetTransportSecurityState(&state);
predictor.PreconnectUrl(kHttpUrl, GURL(), UrlInfo::OMNIBOX_MOTIVATED, true,
2);
ASSERT_EQ(1u, observer.preconnected_urls_.size());
EXPECT_EQ(kHttpsUrl, observer.preconnected_urls_[0]);
predictor.Shutdown();
}
// Tests that preconnecting a URL on the HSTS list preconnects the subresources
// for the SSL version.
TEST_F(PredictorTest, HSTSRedirectSubresources) {
const GURL kHttpUrl("http://example.com");
const GURL kHttpsUrl("https://example.com");
const GURL kSubresourceUrl("https://images.example.com");
const double kUseRate = 23.4;
const base::Time expiry =
base::Time::Now() + base::TimeDelta::FromSeconds(1000);
net::TransportSecurityState state;
state.AddHSTS(kHttpUrl.host(), expiry, false);
SimplePredictor predictor(true);
TestPredictorObserver observer;
predictor.SetObserver(&observer);
predictor.SetTransportSecurityState(&state);
std::unique_ptr<base::ListValue> referral_list(NewEmptySerializationList());
AddToSerializedList(
kHttpsUrl, kSubresourceUrl, kUseRate, referral_list.get());
predictor.DeserializeReferrers(*referral_list.get());
predictor.PreconnectUrlAndSubresources(kHttpUrl, GURL());
ASSERT_EQ(2u, observer.preconnected_urls_.size());
EXPECT_EQ(kHttpsUrl, observer.preconnected_urls_[0]);
EXPECT_EQ(kSubresourceUrl, observer.preconnected_urls_[1]);
predictor.Shutdown();
}
TEST_F(PredictorTest, HSTSRedirectLearnedSubresource) {
const GURL kHttpUrl("http://example.com");
const GURL kHttpsUrl("https://example.com");
const GURL kSubresourceUrl("https://images.example.com");
const base::Time expiry =
base::Time::Now() + base::TimeDelta::FromSeconds(1000);
net::TransportSecurityState state;
state.AddHSTS(kHttpUrl.host(), expiry, false);
SimplePredictor predictor(true);
TestPredictorObserver observer;
predictor.SetObserver(&observer);
predictor.SetTransportSecurityState(&state);
// Note that the predictor would also learn the HSTS redirect from kHttpUrl to
// kHttpsUrl during the navigation.
predictor.LearnFromNavigation(kHttpUrl, kSubresourceUrl);
predictor.PreconnectUrlAndSubresources(kHttpUrl, GURL());
ASSERT_EQ(2u, observer.preconnected_urls_.size());
EXPECT_EQ(kHttpsUrl, observer.preconnected_urls_[0]);
EXPECT_EQ(kSubresourceUrl, observer.preconnected_urls_[1]);
predictor.Shutdown();
}
TEST_F(PredictorTest, NoProxyService) {
// Don't actually try to resolve names.
Predictor::set_max_parallel_resolves(0);
......
......@@ -801,9 +801,11 @@ void NetworkContext::SetFailingHttpTransactionForTesting(
}
void NetworkContext::PreconnectSockets(uint32_t num_streams,
const GURL& url,
const GURL& original_url,
int32_t load_flags,
bool privacy_mode_enabled) {
GURL url = GetHSTSRedirect(original_url);
// |PreconnectSockets| may receive arguments from the renderer, which is not
// guaranteed to validate them.
if (num_streams == 0)
......@@ -1223,4 +1225,20 @@ URLRequestContextOwner NetworkContext::MakeURLRequestContext() {
return result;
}
GURL NetworkContext::GetHSTSRedirect(const GURL& original_url) {
// TODO(lilyhoughton) This needs to be gotten rid of once explicit
// construction with a URLRequestContext is no longer supported.
if (!url_request_context_->transport_security_state() ||
!original_url.SchemeIs("http") ||
!url_request_context_->transport_security_state()->ShouldUpgradeToSSL(
original_url.host())) {
return original_url;
}
url::Replacements<char> replacements;
const char kNewScheme[] = "https";
replacements.SetScheme(kNewScheme, url::Component(0, strlen(kNewScheme)));
return original_url.ReplaceComponents(replacements);
}
} // namespace network
......@@ -241,6 +241,8 @@ class COMPONENT_EXPORT(NETWORK_SERVICE) NetworkContext
URLRequestContextOwner MakeURLRequestContext();
GURL GetHSTSRedirect(const GURL& original_url);
NetworkService* const network_service_;
std::unique_ptr<ResourceScheduler> resource_scheduler_;
......
......@@ -59,6 +59,7 @@
#include "net/proxy_resolution/proxy_config.h"
#include "net/proxy_resolution/proxy_info.h"
#include "net/proxy_resolution/proxy_resolution_service.h"
#include "net/socket/ssl_client_socket_pool.h"
#include "net/socket/transport_client_socket_pool.h"
#include "net/ssl/channel_id_service.h"
#include "net/ssl/channel_id_store.h"
......@@ -209,6 +210,39 @@ class NetworkContextTest : public testing::Test,
return value;
}
// Looks up a value with the given name from the NetworkContext's
// SSLSocketPool info dictionary.
int GetSSLSocketPoolInfo(NetworkContext* context, base::StringPiece name) {
int value;
context->url_request_context()
->http_transaction_factory()
->GetSession()
->GetSSLSocketPool(
net::HttpNetworkSession::SocketPoolType::NORMAL_SOCKET_POOL)
->GetInfoAsValue("", "", false)
->GetInteger(name, &value);
return value;
}
int GetSocketCount(NetworkContext* network_context) {
return GetSocketPoolInfo(network_context, "idle_socket_count") +
GetSocketPoolInfo(network_context, "connecting_socket_count") +
GetSocketPoolInfo(network_context, "handed_out_socket_count");
}
int GetSSLSocketCount(NetworkContext* network_context) {
return GetSSLSocketPoolInfo(network_context, "idle_socket_count") +
GetSSLSocketPoolInfo(network_context, "connecting_socket_count") +
GetSSLSocketPoolInfo(network_context, "handed_out_socket_count");
}
GURL GetHttpUrlFromHttps(const GURL& https_url) {
url::Replacements<char> replacements;
const char http[] = "http";
replacements.SetScheme(http, url::Component(0, strlen(http)));
return https_url.ReplaceComponents(replacements);
}
protected:
base::test::ScopedTaskEnvironment scoped_task_environment_;
std::unique_ptr<NetworkService> network_service_;
......@@ -2499,7 +2533,9 @@ class ConnectionListener
: public net::test_server::EmbeddedTestServerConnectionListener {
public:
ConnectionListener()
: task_runner_(base::ThreadTaskRunnerHandle::Get()),
: total_sockets_seen_(0),
total_sockets_waited_for_(0),
task_runner_(base::ThreadTaskRunnerHandle::Get()),
num_accepted_connections_needed_(0),
num_accepted_connections_loop_(nullptr) {}
......@@ -2513,6 +2549,7 @@ class ConnectionListener
EXPECT_TRUE(sockets_.find(socket) == sockets_.end());
sockets_[socket] = SOCKET_ACCEPTED;
total_sockets_seen_++;
CheckAccepted();
}
......@@ -2529,7 +2566,7 @@ class ConnectionListener
base::RunLoop run_loop;
{
base::AutoLock lock(lock_);
EXPECT_GE(num_connections, sockets_.size());
EXPECT_GE(num_connections, sockets_.size() - total_sockets_waited_for_);
num_accepted_connections_loop_ = &run_loop;
num_accepted_connections_needed_ = num_connections;
CheckAccepted();
......@@ -2541,7 +2578,8 @@ class ConnectionListener
// Grab the mutex again and make sure that the number of accepted sockets is
// indeed |num_connections|.
base::AutoLock lock(lock_);
EXPECT_EQ(num_connections, sockets_.size());
total_sockets_waited_for_ += num_connections;
EXPECT_EQ(total_sockets_seen_, total_sockets_waited_for_);
}
// Helper function to stop the waiting for sockets to be accepted for
......@@ -2555,7 +2593,8 @@ class ConnectionListener
DCHECK(num_accepted_connections_loop_ ||
num_accepted_connections_needed_ == 0);
if (!num_accepted_connections_loop_ ||
num_accepted_connections_needed_ != sockets_.size()) {
num_accepted_connections_needed_ !=
sockets_.size() - total_sockets_waited_for_) {
return;
}
......@@ -2580,6 +2619,9 @@ class ConnectionListener
return address.port();
}
int total_sockets_seen_;
int total_sockets_waited_for_;
enum SocketStatus { SOCKET_ACCEPTED, SOCKET_READ_FROM };
scoped_refptr<base::SingleThreadTaskRunner> task_runner_;
......@@ -2613,6 +2655,39 @@ TEST_F(NetworkContextTest, PreconnectOne) {
connection_listener.WaitForAcceptedConnections(1u);
}
TEST_F(NetworkContextTest, PreconnectHSTS) {
std::unique_ptr<NetworkContext> network_context =
CreateContextWithParams(CreateContextParams());
ConnectionListener connection_listener;
net::EmbeddedTestServer test_server(net::EmbeddedTestServer::TYPE_HTTPS);
test_server.SetConnectionListener(&connection_listener);
ASSERT_TRUE(test_server.Start());
const GURL server_http_url = GetHttpUrlFromHttps(test_server.base_url());
network_context->PreconnectSockets(1, server_http_url, net::LOAD_NORMAL,
true);
connection_listener.WaitForAcceptedConnections(1u);
int num_sockets = GetSocketCount(network_context.get());
EXPECT_EQ(num_sockets, 1);
int num_ssl_sockets = GetSSLSocketCount(network_context.get());
EXPECT_EQ(num_ssl_sockets, 0);
const base::Time expiry =
base::Time::Now() + base::TimeDelta::FromSeconds(1000);
network_context->url_request_context()->transport_security_state()->AddHSTS(
server_http_url.host(), expiry, false);
network_context->PreconnectSockets(1, server_http_url, net::LOAD_NORMAL,
true);
connection_listener.WaitForAcceptedConnections(1u);
num_sockets = GetSocketCount(network_context.get());
EXPECT_EQ(num_sockets, 2);
num_ssl_sockets = GetSSLSocketCount(network_context.get());
EXPECT_EQ(num_ssl_sockets, 1);
}
TEST_F(NetworkContextTest, PreconnectZero) {
std::unique_ptr<NetworkContext> network_context =
CreateContextWithParams(CreateContextParams());
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment