Commit 166710e0 authored by Matt Menke's avatar Matt Menke Committed by Commit Bot

Make SocksConnectJob pass NetworkIsolationKey to the HostResolver.

Also pass the NIK to SocksConnectJob through its parameters.

This only affects the DNS resolution of the destination's hostname, not
the SOCKS4 proxy's hostname, which continues to not use the NIK.

Bug: 997049
Change-Id: If0cc74bb28d82f4f6929ccb12f15691e22c40589
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/1898389
Commit-Queue: Matt Menke <mmenke@chromium.org>
Reviewed-by: default avatarEric Orth <ericorth@chromium.org>
Cr-Commit-Position: refs/heads/master@{#712889}
parent 8215f1b5
...@@ -399,10 +399,9 @@ void MockHostResolverBase::DetachRequest(size_t id) { ...@@ -399,10 +399,9 @@ void MockHostResolverBase::DetachRequest(size_t id) {
requests_.erase(it); requests_.erase(it);
} }
MockHostResolverBase::RequestImpl* MockHostResolverBase::request(size_t id) { const std::string& MockHostResolverBase::request_host(size_t id) {
RequestMap::iterator request = requests_.find(id); DCHECK(request(id));
DCHECK(request != requests_.end()); return request(id)->request_host().host();
return (*request).second;
} }
RequestPriority MockHostResolverBase::request_priority(size_t id) { RequestPriority MockHostResolverBase::request_priority(size_t id) {
...@@ -464,6 +463,12 @@ void MockHostResolverBase::TriggerMdnsListeners( ...@@ -464,6 +463,12 @@ void MockHostResolverBase::TriggerMdnsListeners(
} }
} }
MockHostResolverBase::RequestImpl* MockHostResolverBase::request(size_t id) {
RequestMap::iterator request = requests_.find(id);
DCHECK(request != requests_.end());
return (*request).second;
}
// start id from 1 to distinguish from NULL RequestHandle // start id from 1 to distinguish from NULL RequestHandle
MockHostResolverBase::MockHostResolverBase(bool use_caching, MockHostResolverBase::MockHostResolverBase(bool use_caching,
int cache_invalidation_num) int cache_invalidation_num)
......
...@@ -156,8 +156,8 @@ class MockHostResolverBase ...@@ -156,8 +156,8 @@ class MockHostResolverBase
// Detach cancelled request. // Detach cancelled request.
void DetachRequest(size_t id); void DetachRequest(size_t id);
// Returns the request with the given id. // Returns the hostname of the request with the given id.
RequestImpl* request(size_t id); const std::string& request_host(size_t id);
// Returns the priority of the request with the given id. // Returns the priority of the request with the given id.
RequestPriority request_priority(size_t id); RequestPriority request_priority(size_t id);
...@@ -229,6 +229,9 @@ class MockHostResolverBase ...@@ -229,6 +229,9 @@ class MockHostResolverBase
typedef std::map<size_t, RequestImpl*> RequestMap; typedef std::map<size_t, RequestImpl*> RequestMap;
// Returns the request with the given id.
RequestImpl* request(size_t id);
// If > 0, |cache_invalidation_num| is the number of times a cached entry can // If > 0, |cache_invalidation_num| is the number of times a cached entry can
// be read before it invalidates itself. Useful to force cache expiration // be read before it invalidates itself. Useful to force cache expiration
// scenarios. // scenarios.
......
...@@ -146,7 +146,7 @@ std::unique_ptr<ConnectJob> ConnectJob::CreateConnectJob( ...@@ -146,7 +146,7 @@ std::unique_ptr<ConnectJob> ConnectJob::CreateConnectJob(
socks_params = base::MakeRefCounted<SOCKSSocketParams>( socks_params = base::MakeRefCounted<SOCKSSocketParams>(
std::move(proxy_tcp_params), std::move(proxy_tcp_params),
proxy_server.scheme() == ProxyServer::SCHEME_SOCKS5, endpoint, proxy_server.scheme() == ProxyServer::SCHEME_SOCKS5, endpoint,
*proxy_annotation_tag); network_isolation_key, *proxy_annotation_tag);
} }
} }
......
...@@ -62,6 +62,7 @@ static_assert(sizeof(SOCKS4ServerResponse) == kReadHeaderSize, ...@@ -62,6 +62,7 @@ static_assert(sizeof(SOCKS4ServerResponse) == kReadHeaderSize,
SOCKSClientSocket::SOCKSClientSocket( SOCKSClientSocket::SOCKSClientSocket(
std::unique_ptr<StreamSocket> transport_socket, std::unique_ptr<StreamSocket> transport_socket,
const HostPortPair& destination, const HostPortPair& destination,
const NetworkIsolationKey& network_isolation_key,
RequestPriority priority, RequestPriority priority,
HostResolver* host_resolver, HostResolver* host_resolver,
bool disable_secure_dns, bool disable_secure_dns,
...@@ -75,6 +76,7 @@ SOCKSClientSocket::SOCKSClientSocket( ...@@ -75,6 +76,7 @@ SOCKSClientSocket::SOCKSClientSocket(
host_resolver_(host_resolver), host_resolver_(host_resolver),
disable_secure_dns_(disable_secure_dns), disable_secure_dns_(disable_secure_dns),
destination_(destination), destination_(destination),
network_isolation_key_(network_isolation_key),
priority_(priority), priority_(priority),
net_log_(transport_socket_->NetLog()), net_log_(transport_socket_->NetLog()),
traffic_annotation_(traffic_annotation) {} traffic_annotation_(traffic_annotation) {}
...@@ -309,8 +311,8 @@ int SOCKSClientSocket::DoResolveHost() { ...@@ -309,8 +311,8 @@ int SOCKSClientSocket::DoResolveHost() {
parameters.initial_priority = priority_; parameters.initial_priority = priority_;
if (disable_secure_dns_) if (disable_secure_dns_)
parameters.secure_dns_mode_override = DnsConfig::SecureDnsMode::OFF; parameters.secure_dns_mode_override = DnsConfig::SecureDnsMode::OFF;
resolve_host_request_ = resolve_host_request_ = host_resolver_->CreateRequest(
host_resolver_->CreateRequest(destination_, net_log_, parameters); destination_, network_isolation_key_, net_log_, parameters);
return resolve_host_request_->Start( return resolve_host_request_->Start(
base::BindOnce(&SOCKSClientSocket::OnIOComplete, base::Unretained(this))); base::BindOnce(&SOCKSClientSocket::OnIOComplete, base::Unretained(this)));
......
...@@ -31,8 +31,10 @@ class NET_EXPORT_PRIVATE SOCKSClientSocket : public StreamSocket { ...@@ -31,8 +31,10 @@ class NET_EXPORT_PRIVATE SOCKSClientSocket : public StreamSocket {
public: public:
// |destination| contains the hostname and port to which the socket above will // |destination| contains the hostname and port to which the socket above will
// communicate to via the socks layer. For testing the referrer is optional. // communicate to via the socks layer. For testing the referrer is optional.
// |network_isolation_key| is used for host resolution.
SOCKSClientSocket(std::unique_ptr<StreamSocket> transport_socket, SOCKSClientSocket(std::unique_ptr<StreamSocket> transport_socket,
const HostPortPair& destination, const HostPortPair& destination,
const NetworkIsolationKey& network_isolation_key,
RequestPriority priority, RequestPriority priority,
HostResolver* host_resolver, HostResolver* host_resolver,
bool disable_secure_dns, bool disable_secure_dns,
...@@ -140,6 +142,7 @@ class NET_EXPORT_PRIVATE SOCKSClientSocket : public StreamSocket { ...@@ -140,6 +142,7 @@ class NET_EXPORT_PRIVATE SOCKSClientSocket : public StreamSocket {
bool disable_secure_dns_; bool disable_secure_dns_;
std::unique_ptr<HostResolver::ResolveHostRequest> resolve_host_request_; std::unique_ptr<HostResolver::ResolveHostRequest> resolve_host_request_;
const HostPortPair destination_; const HostPortPair destination_;
const NetworkIsolationKey network_isolation_key_;
RequestPriority priority_; RequestPriority priority_;
NetLogWithSource net_log_; NetLogWithSource net_log_;
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include "base/logging.h" #include "base/logging.h"
#include "net/base/address_list.h" #include "net/base/address_list.h"
#include "net/base/net_errors.h" #include "net/base/net_errors.h"
#include "net/base/network_isolation_key.h"
#include "net/base/test_completion_callback.h" #include "net/base/test_completion_callback.h"
#include "net/dns/host_resolver.h" #include "net/dns/host_resolver.h"
#include "net/dns/mock_host_resolver.h" #include "net/dns/mock_host_resolver.h"
...@@ -56,7 +57,7 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { ...@@ -56,7 +57,7 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) {
net::SOCKSClientSocket socket( net::SOCKSClientSocket socket(
std::move(fuzzed_socket), net::HostPortPair("foo", 80), std::move(fuzzed_socket), net::HostPortPair("foo", 80),
net::DEFAULT_PRIORITY, &mock_host_resolver, net::NetworkIsolationKey(), net::DEFAULT_PRIORITY, &mock_host_resolver,
false /* disable_secure_dns */, TRAFFIC_ANNOTATION_FOR_TESTS); false /* disable_secure_dns */, TRAFFIC_ANNOTATION_FOR_TESTS);
int result = socket.Connect(callback.callback()); int result = socket.Connect(callback.callback());
callback.GetResult(result); callback.GetResult(result);
......
...@@ -94,8 +94,8 @@ std::unique_ptr<SOCKSClientSocket> SOCKSClientSocketTest::BuildMockSocket( ...@@ -94,8 +94,8 @@ std::unique_ptr<SOCKSClientSocket> SOCKSClientSocketTest::BuildMockSocket(
// non-owning pointer to it. // non-owning pointer to it.
tcp_sock_ = socket.get(); tcp_sock_ = socket.get();
return std::make_unique<SOCKSClientSocket>( return std::make_unique<SOCKSClientSocket>(
std::move(socket), HostPortPair(hostname, port), DEFAULT_PRIORITY, std::move(socket), HostPortPair(hostname, port), NetworkIsolationKey(),
host_resolver, false /* disable_secure_dns */, DEFAULT_PRIORITY, host_resolver, false /* disable_secure_dns */,
TRAFFIC_ANNOTATION_FOR_TESTS); TRAFFIC_ANNOTATION_FOR_TESTS);
} }
...@@ -435,8 +435,9 @@ TEST_F(SOCKSClientSocketTest, Tag) { ...@@ -435,8 +435,9 @@ TEST_F(SOCKSClientSocketTest, Tag) {
// non-owning pointer to it. // non-owning pointer to it.
MockHostResolver host_resolver; MockHostResolver host_resolver;
SOCKSClientSocket socket(std::unique_ptr<StreamSocket>(tagging_sock), SOCKSClientSocket socket(std::unique_ptr<StreamSocket>(tagging_sock),
HostPortPair("localhost", 80), DEFAULT_PRIORITY, HostPortPair("localhost", 80), NetworkIsolationKey(),
&host_resolver, false /* disable_secure_dns */, DEFAULT_PRIORITY, &host_resolver,
false /* disable_secure_dns */,
TRAFFIC_ANNOTATION_FOR_TESTS); TRAFFIC_ANNOTATION_FOR_TESTS);
EXPECT_EQ(tagging_sock->tag(), SocketTag()); EXPECT_EQ(tagging_sock->tag(), SocketTag());
...@@ -454,8 +455,8 @@ TEST_F(SOCKSClientSocketTest, SetDisableSecureDns) { ...@@ -454,8 +455,8 @@ TEST_F(SOCKSClientSocketTest, SetDisableSecureDns) {
MockHostResolver host_resolver; MockHostResolver host_resolver;
SOCKSClientSocket socket( SOCKSClientSocket socket(
std::make_unique<MockTCPClientSocket>(address_list_, &log, &data), std::make_unique<MockTCPClientSocket>(address_list_, &log, &data),
HostPortPair("localhost", 80), DEFAULT_PRIORITY, &host_resolver, HostPortPair("localhost", 80), NetworkIsolationKey(), DEFAULT_PRIORITY,
disable_secure_dns, TRAFFIC_ANNOTATION_FOR_TESTS); &host_resolver, disable_secure_dns, TRAFFIC_ANNOTATION_FOR_TESTS);
EXPECT_EQ(ERR_IO_PENDING, socket.Connect(callback_.callback())); EXPECT_EQ(ERR_IO_PENDING, socket.Connect(callback_.callback()));
EXPECT_EQ(disable_secure_dns, EXPECT_EQ(disable_secure_dns,
......
...@@ -26,10 +26,12 @@ SOCKSSocketParams::SOCKSSocketParams( ...@@ -26,10 +26,12 @@ SOCKSSocketParams::SOCKSSocketParams(
scoped_refptr<TransportSocketParams> proxy_server_params, scoped_refptr<TransportSocketParams> proxy_server_params,
bool socks_v5, bool socks_v5,
const HostPortPair& host_port_pair, const HostPortPair& host_port_pair,
const NetworkIsolationKey& network_isolation_key,
const NetworkTrafficAnnotationTag& traffic_annotation) const NetworkTrafficAnnotationTag& traffic_annotation)
: transport_params_(std::move(proxy_server_params)), : transport_params_(std::move(proxy_server_params)),
destination_(host_port_pair), destination_(host_port_pair),
socks_v5_(socks_v5), socks_v5_(socks_v5),
network_isolation_key_(network_isolation_key),
traffic_annotation_(traffic_annotation) {} traffic_annotation_(traffic_annotation) {}
SOCKSSocketParams::~SOCKSSocketParams() = default; SOCKSSocketParams::~SOCKSSocketParams() = default;
...@@ -165,7 +167,7 @@ int SOCKSConnectJob::DoSOCKSConnect() { ...@@ -165,7 +167,7 @@ int SOCKSConnectJob::DoSOCKSConnect() {
} else { } else {
socket_.reset(new SOCKSClientSocket( socket_.reset(new SOCKSClientSocket(
transport_connect_job_->PassSocket(), socks_params_->destination(), transport_connect_job_->PassSocket(), socks_params_->destination(),
priority(), host_resolver(), socks_params_->network_isolation_key(), priority(), host_resolver(),
socks_params_->transport_params()->disable_secure_dns(), socks_params_->transport_params()->disable_secure_dns(),
socks_params_->traffic_annotation())); socks_params_->traffic_annotation()));
} }
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "net/base/completion_once_callback.h" #include "net/base/completion_once_callback.h"
#include "net/base/host_port_pair.h" #include "net/base/host_port_pair.h"
#include "net/base/net_export.h" #include "net/base/net_export.h"
#include "net/base/network_isolation_key.h"
#include "net/base/request_priority.h" #include "net/base/request_priority.h"
#include "net/socket/connect_job.h" #include "net/socket/connect_job.h"
#include "net/traffic_annotation/network_traffic_annotation.h" #include "net/traffic_annotation/network_traffic_annotation.h"
...@@ -30,6 +31,7 @@ class NET_EXPORT_PRIVATE SOCKSSocketParams ...@@ -30,6 +31,7 @@ class NET_EXPORT_PRIVATE SOCKSSocketParams
SOCKSSocketParams(scoped_refptr<TransportSocketParams> proxy_server_params, SOCKSSocketParams(scoped_refptr<TransportSocketParams> proxy_server_params,
bool socks_v5, bool socks_v5,
const HostPortPair& host_port_pair, const HostPortPair& host_port_pair,
const NetworkIsolationKey& network_isolation_key,
const NetworkTrafficAnnotationTag& traffic_annotation); const NetworkTrafficAnnotationTag& traffic_annotation);
const scoped_refptr<TransportSocketParams>& transport_params() const { const scoped_refptr<TransportSocketParams>& transport_params() const {
...@@ -37,6 +39,9 @@ class NET_EXPORT_PRIVATE SOCKSSocketParams ...@@ -37,6 +39,9 @@ class NET_EXPORT_PRIVATE SOCKSSocketParams
} }
const HostPortPair& destination() const { return destination_; } const HostPortPair& destination() const { return destination_; }
bool is_socks_v5() const { return socks_v5_; } bool is_socks_v5() const { return socks_v5_; }
const NetworkIsolationKey& network_isolation_key() {
return network_isolation_key_;
}
const NetworkTrafficAnnotationTag traffic_annotation() { const NetworkTrafficAnnotationTag traffic_annotation() {
return traffic_annotation_; return traffic_annotation_;
...@@ -51,6 +56,7 @@ class NET_EXPORT_PRIVATE SOCKSSocketParams ...@@ -51,6 +56,7 @@ class NET_EXPORT_PRIVATE SOCKSSocketParams
// This is the HTTP destination. // This is the HTTP destination.
const HostPortPair destination_; const HostPortPair destination_;
const bool socks_v5_; const bool socks_v5_;
const NetworkIsolationKey network_isolation_key_;
NetworkTrafficAnnotationTag traffic_annotation_; NetworkTrafficAnnotationTag traffic_annotation_;
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "net/base/load_timing_info.h" #include "net/base/load_timing_info.h"
#include "net/base/load_timing_info_test_util.h" #include "net/base/load_timing_info_test_util.h"
#include "net/base/net_errors.h" #include "net/base/net_errors.h"
#include "net/base/network_isolation_key.h"
#include "net/dns/mock_host_resolver.h" #include "net/dns/mock_host_resolver.h"
#include "net/log/net_log.h" #include "net/log/net_log.h"
#include "net/socket/client_socket_factory.h" #include "net/socket/client_socket_factory.h"
...@@ -75,7 +76,7 @@ class SOCKSConnectJobTest : public testing::Test, public WithTaskEnvironment { ...@@ -75,7 +76,7 @@ class SOCKSConnectJobTest : public testing::Test, public WithTaskEnvironment {
socks_version == SOCKSVersion::V4 socks_version == SOCKSVersion::V4
? HostPortPair(kSOCKS4TestHost, kSOCKS4TestPort) ? HostPortPair(kSOCKS4TestHost, kSOCKS4TestPort)
: HostPortPair(kSOCKS5TestHost, kSOCKS5TestPort), : HostPortPair(kSOCKS5TestHost, kSOCKS5TestPort),
TRAFFIC_ANNOTATION_FOR_TESTS); NetworkIsolationKey(), TRAFFIC_ANNOTATION_FOR_TESTS);
} }
protected: protected:
......
...@@ -95,6 +95,7 @@ class SSLConnectJobTest : public WithTaskEnvironment, public testing::Test { ...@@ -95,6 +95,7 @@ class SSLConnectJobTest : public WithTaskEnvironment, public testing::Test {
new SOCKSSocketParams(proxy_transport_socket_params_, new SOCKSSocketParams(proxy_transport_socket_params_,
true, true,
HostPortPair("sockshost", 443), HostPortPair("sockshost", 443),
NetworkIsolationKey(),
TRAFFIC_ANNOTATION_FOR_TESTS)), TRAFFIC_ANNOTATION_FOR_TESTS)),
http_proxy_socket_params_( http_proxy_socket_params_(
new HttpProxySocketParams(proxy_transport_socket_params_, new HttpProxySocketParams(proxy_transport_socket_params_,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment