Commit cdc002cd authored by Kartik Hegde's avatar Kartik Hegde Committed by Commit Bot

network_diagnostics: Add TLS upgrade parameter

Some network diagnostics need to simply validate a successful TCP
connection. Add a construction parameter to make the TLS upgrade
optional.

BUG=b/172994051
TEST=(1) unit_tests --gtest_filter=TlsProberWithFakeNetworkContextTest.*
(2) unit_tests --gtest_filter=TlsProberWithRealNetworkContextTest.*

Change-Id: I51a0f40ea5727cec8c84975fcacca64d245d2609
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2530460
Commit-Queue: Kartik Hegde <khegde@chromium.org>
Reviewed-by: default avatarSteven Bennetts <stevenjb@chromium.org>
Reviewed-by: default avatarMatt Menke <mmenke@chromium.org>
Cr-Commit-Position: refs/heads/master@{#827644}
parent 285191ce
...@@ -83,13 +83,15 @@ void FakeNetworkContext::SetTCPConnectCode( ...@@ -83,13 +83,15 @@ void FakeNetworkContext::SetTCPConnectCode(
base::Optional<net::Error>& tcp_connect_code) { base::Optional<net::Error>& tcp_connect_code) {
if (tcp_connect_code.has_value()) { if (tcp_connect_code.has_value()) {
tcp_connect_code_ = tcp_connect_code.value(); tcp_connect_code_ = tcp_connect_code.value();
fake_tcp_connected_socket_ = std::make_unique<FakeTCPConnectedSocket>();
} }
} }
void FakeNetworkContext::SetTLSUpgradeCode( void FakeNetworkContext::SetTLSUpgradeCode(
base::Optional<net::Error>& tls_upgrade_code) { base::Optional<net::Error>& tls_upgrade_code) {
if (tls_upgrade_code.has_value()) { if (tls_upgrade_code.has_value()) {
fake_tcp_connected_socket_ = std::make_unique<FakeTCPConnectedSocket>(); DCHECK(fake_tcp_connected_socket_);
fake_tcp_connected_socket_->set_tls_upgrade_code(tls_upgrade_code.value()); fake_tcp_connected_socket_->set_tls_upgrade_code(tls_upgrade_code.value());
} }
} }
......
...@@ -25,13 +25,6 @@ HostResolver::ResolutionResult::ResolutionResult( ...@@ -25,13 +25,6 @@ HostResolver::ResolutionResult::ResolutionResult(
HostResolver::ResolutionResult::~ResolutionResult() = default; HostResolver::ResolutionResult::~ResolutionResult() = default;
HostResolver::HostResolver(const GURL& url,
network::mojom::NetworkContext* network_context,
OnResolutionComplete callback)
: HostResolver(net::HostPortPair::FromURL(url),
network_context,
std::move(callback)) {}
HostResolver::HostResolver(const net::HostPortPair& host_port_pair, HostResolver::HostResolver(const net::HostPortPair& host_port_pair,
network::mojom::NetworkContext* network_context, network::mojom::NetworkContext* network_context,
OnResolutionComplete callback) OnResolutionComplete callback)
......
...@@ -33,12 +33,6 @@ class HostResolver : public network::ResolveHostClientBase { ...@@ -33,12 +33,6 @@ class HostResolver : public network::ResolveHostClientBase {
}; };
using OnResolutionComplete = base::OnceCallback<void(ResolutionResult&)>; using OnResolutionComplete = base::OnceCallback<void(ResolutionResult&)>;
// Performs the DNS resolution of a specified |url|. Note that |callback|
// will not be called until construction is complete.
HostResolver(const GURL& url,
network::mojom::NetworkContext* network_context,
OnResolutionComplete callback);
// Performs the DNS resolution of a specified |host_port_pair|. Note that // Performs the DNS resolution of a specified |host_port_pair|. Note that
// |callback| will not be called until construction is complete. // |callback| will not be called until construction is complete.
HostResolver(const net::HostPortPair& host_port_pair, HostResolver(const net::HostPortPair& host_port_pair,
......
...@@ -38,7 +38,6 @@ class HostResolverTest : public ::testing::Test { ...@@ -38,7 +38,6 @@ class HostResolverTest : public ::testing::Test {
protected: protected:
const net::HostPortPair kFakeHostPortPair = const net::HostPortPair kFakeHostPortPair =
net::HostPortPair::FromString("fake_stun_server.com:80"); net::HostPortPair::FromString("fake_stun_server.com:80");
const GURL kFakeUrl{"https://www.FAKE_HOST_NAME.com:1234/"};
const net::IPEndPoint kFakeIPAddress{ const net::IPEndPoint kFakeIPAddress{
net::IPEndPoint(net::IPAddress::IPv4Localhost(), /*port=*/1234)}; net::IPEndPoint(net::IPAddress::IPv4Localhost(), /*port=*/1234)};
std::unique_ptr<HostResolver> host_resolver_; std::unique_ptr<HostResolver> host_resolver_;
...@@ -48,39 +47,7 @@ class HostResolverTest : public ::testing::Test { ...@@ -48,39 +47,7 @@ class HostResolverTest : public ::testing::Test {
FakeNetworkContext fake_network_context_; FakeNetworkContext fake_network_context_;
}; };
TEST_F(HostResolverTest, TestSuccessfulResolutionWithUrl) { TEST_F(HostResolverTest, TestSuccessfulResolution) {
auto address_list = net::AddressList(kFakeIPAddress);
auto fake_dns_result = std::make_unique<FakeHostResolver::DnsResult>(
net::OK, net::ResolveErrorInfo(net::OK), address_list);
InitializeNetworkContext(std::move(fake_dns_result));
HostResolver::ResolutionResult resolution_result{
net::ERR_FAILED, net::ResolveErrorInfo(net::OK), base::nullopt};
base::RunLoop run_loop;
host_resolver_ = std::make_unique<HostResolver>(
kFakeUrl, fake_network_context(),
base::BindOnce(
[](HostResolver::ResolutionResult* resolution_result,
base::OnceClosure quit_closure,
HostResolver::ResolutionResult& res_result) {
resolution_result->result = res_result.result;
resolution_result->resolve_error_info =
res_result.resolve_error_info;
resolution_result->resolved_addresses =
res_result.resolved_addresses;
std::move(quit_closure).Run();
},
&resolution_result, run_loop.QuitClosure()));
run_loop.Run();
EXPECT_EQ(resolution_result.result, net::OK);
EXPECT_EQ(resolution_result.resolve_error_info,
net::ResolveErrorInfo(net::OK));
EXPECT_EQ(resolution_result.resolved_addresses.value().size(), 1);
EXPECT_EQ(resolution_result.resolved_addresses.value().front(),
address_list.front());
}
TEST_F(HostResolverTest, TestSuccessfulResolutionWithHostPortPair) {
auto address_list = net::AddressList(kFakeIPAddress); auto address_list = net::AddressList(kFakeIPAddress);
auto fake_dns_result = std::make_unique<FakeHostResolver::DnsResult>( auto fake_dns_result = std::make_unique<FakeHostResolver::DnsResult>(
net::OK, net::ResolveErrorInfo(net::OK), address_list); net::OK, net::ResolveErrorInfo(net::OK), address_list);
......
...@@ -111,7 +111,9 @@ void HttpsFirewallRoutine::ProbeNextUrl() { ...@@ -111,7 +111,9 @@ void HttpsFirewallRoutine::ProbeNextUrl() {
void HttpsFirewallRoutine::AttemptProbe(const GURL& url) { void HttpsFirewallRoutine::AttemptProbe(const GURL& url) {
// Store the instance of TlsProber. // Store the instance of TlsProber.
tls_prober_ = tls_prober_getter_callback_.Run( tls_prober_ = tls_prober_getter_callback_.Run(
base::BindRepeating(&HttpsFirewallRoutine::GetNetworkContext), url, base::BindRepeating(&HttpsFirewallRoutine::GetNetworkContext),
net::HostPortPair::FromURL(url),
/*negotiate_tls=*/true,
base::BindOnce(&HttpsFirewallRoutine::OnProbeComplete, weak_ptr(), url)); base::BindOnce(&HttpsFirewallRoutine::OnProbeComplete, weak_ptr(), url));
} }
...@@ -150,9 +152,11 @@ network::mojom::NetworkContext* HttpsFirewallRoutine::GetNetworkContext() { ...@@ -150,9 +152,11 @@ network::mojom::NetworkContext* HttpsFirewallRoutine::GetNetworkContext() {
std::unique_ptr<TlsProber> HttpsFirewallRoutine::CreateAndExecuteTlsProber( std::unique_ptr<TlsProber> HttpsFirewallRoutine::CreateAndExecuteTlsProber(
TlsProber::NetworkContextGetter network_context_getter, TlsProber::NetworkContextGetter network_context_getter,
const GURL& url, net::HostPortPair host_port_pair,
bool negotiate_tls,
TlsProber::TlsProbeCompleteCallback callback) { TlsProber::TlsProbeCompleteCallback callback) {
return std::make_unique<TlsProber>(std::move(network_context_getter), url, return std::make_unique<TlsProber>(std::move(network_context_getter),
host_port_pair, negotiate_tls,
std::move(callback)); std::move(callback));
} }
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
#include "base/memory/weak_ptr.h" #include "base/memory/weak_ptr.h"
#include "chrome/browser/chromeos/net/network_diagnostics/network_diagnostics_routine.h" #include "chrome/browser/chromeos/net/network_diagnostics/network_diagnostics_routine.h"
#include "chrome/browser/chromeos/net/network_diagnostics/tls_prober.h" #include "chrome/browser/chromeos/net/network_diagnostics/tls_prober.h"
#include "url/gurl.h" #include "net/base/host_port_pair.h"
namespace network { namespace network {
namespace mojom { namespace mojom {
...@@ -34,7 +34,8 @@ class HttpsFirewallRoutine : public NetworkDiagnosticsRoutine { ...@@ -34,7 +34,8 @@ class HttpsFirewallRoutine : public NetworkDiagnosticsRoutine {
using TlsProberGetterCallback = using TlsProberGetterCallback =
base::RepeatingCallback<std::unique_ptr<TlsProber>( base::RepeatingCallback<std::unique_ptr<TlsProber>(
TlsProber::NetworkContextGetter network_context_getter, TlsProber::NetworkContextGetter network_context_getter,
const GURL& url, net::HostPortPair host_port_pair,
bool negotiate_tls,
TlsProber::TlsProbeCompleteCallback callback)>; TlsProber::TlsProbeCompleteCallback callback)>;
HttpsFirewallRoutine(); HttpsFirewallRoutine();
...@@ -74,7 +75,8 @@ class HttpsFirewallRoutine : public NetworkDiagnosticsRoutine { ...@@ -74,7 +75,8 @@ class HttpsFirewallRoutine : public NetworkDiagnosticsRoutine {
// Creates an instance of TlsProber. // Creates an instance of TlsProber.
static std::unique_ptr<TlsProber> CreateAndExecuteTlsProber( static std::unique_ptr<TlsProber> CreateAndExecuteTlsProber(
TlsProber::NetworkContextGetter network_context_getter, TlsProber::NetworkContextGetter network_context_getter,
const GURL& url, net::HostPortPair host_port_pair,
bool negotiate_tls,
TlsProber::TlsProbeCompleteCallback callback); TlsProber::TlsProbeCompleteCallback callback);
// Returns the weak pointer to |this|. // Returns the weak pointer to |this|.
......
...@@ -117,7 +117,8 @@ class HttpsFirewallRoutineTest : public ::testing::Test { ...@@ -117,7 +117,8 @@ class HttpsFirewallRoutineTest : public ::testing::Test {
std::unique_ptr<TlsProber> CreateAndExecuteTlsProber( std::unique_ptr<TlsProber> CreateAndExecuteTlsProber(
TlsProber::NetworkContextGetter network_context_getter, TlsProber::NetworkContextGetter network_context_getter,
const GURL& url, net::HostPortPair host_port_pair,
bool negotiate_tls,
TlsProber::TlsProbeCompleteCallback callback) { TlsProber::TlsProbeCompleteCallback callback) {
DCHECK(fake_probe_results_.size() > 0); DCHECK(fake_probe_results_.size() > 0);
......
...@@ -55,26 +55,29 @@ net::NetworkTrafficAnnotationTag GetTrafficAnnotationTag() { ...@@ -55,26 +55,29 @@ net::NetworkTrafficAnnotationTag GetTrafficAnnotationTag() {
} // namespace } // namespace
TlsProber::TlsProber(NetworkContextGetter network_context_getter, TlsProber::TlsProber(NetworkContextGetter network_context_getter,
const GURL& url, net::HostPortPair host_port_pair,
bool negotiate_tls,
TlsProbeCompleteCallback callback) TlsProbeCompleteCallback callback)
: network_context_getter_(std::move(network_context_getter)), : network_context_getter_(std::move(network_context_getter)),
url_(url), host_port_pair_(host_port_pair),
negotiate_tls_(negotiate_tls),
callback_(std::move(callback)) { callback_(std::move(callback)) {
DCHECK_CURRENTLY_ON(content::BrowserThread::UI); DCHECK_CURRENTLY_ON(content::BrowserThread::UI);
DCHECK(callback_); DCHECK(callback_);
DCHECK(url_.is_valid()); DCHECK(!host_port_pair_.IsEmpty());
network::mojom::NetworkContext* network_context = network::mojom::NetworkContext* network_context =
network_context_getter_.Run(); network_context_getter_.Run();
DCHECK(network_context); DCHECK(network_context);
host_resolver_ = std::make_unique<HostResolver>( host_resolver_ = std::make_unique<HostResolver>(
net::HostPortPair::FromURL(url_), network_context, host_port_pair, network_context,
base::BindOnce(&TlsProber::OnHostResolutionComplete, base::BindOnce(&TlsProber::OnHostResolutionComplete,
weak_factory_.GetWeakPtr())); weak_factory_.GetWeakPtr()));
} }
TlsProber::TlsProber() = default; TlsProber::TlsProber()
: network_context_getter_(base::NullCallback()), negotiate_tls_(false) {}
TlsProber::~TlsProber() = default; TlsProber::~TlsProber() = default;
...@@ -126,6 +129,11 @@ void TlsProber::OnConnectComplete( ...@@ -126,6 +129,11 @@ void TlsProber::OnConnectComplete(
OnDone(result, ProbeExitEnum::kTcpConnectionFailure); OnDone(result, ProbeExitEnum::kTcpConnectionFailure);
return; return;
} }
if (!negotiate_tls_) {
OnDone(result, ProbeExitEnum::kSuccess);
return;
}
DCHECK(peer_addr.has_value()); DCHECK(peer_addr.has_value());
auto pending_receiver = auto pending_receiver =
...@@ -138,7 +146,7 @@ void TlsProber::OnConnectComplete( ...@@ -138,7 +146,7 @@ void TlsProber::OnConnectComplete(
tls_client_socket_remote_.set_disconnect_handler( tls_client_socket_remote_.set_disconnect_handler(
base::BindOnce(&TlsProber::OnDisconnect, weak_factory_.GetWeakPtr())); base::BindOnce(&TlsProber::OnDisconnect, weak_factory_.GetWeakPtr()));
tcp_connected_socket_remote_->UpgradeToTLS( tcp_connected_socket_remote_->UpgradeToTLS(
net::HostPortPair::FromURL(url_), host_port_pair_,
/*options=*/nullptr, /*options=*/nullptr,
net::MutableNetworkTrafficAnnotationTag(GetTrafficAnnotationTag()), net::MutableNetworkTrafficAnnotationTag(GetTrafficAnnotationTag()),
std::move(pending_receiver), std::move(pending_receiver),
......
...@@ -14,17 +14,17 @@ ...@@ -14,17 +14,17 @@
#include "chrome/browser/chromeos/net/network_diagnostics/host_resolver.h" #include "chrome/browser/chromeos/net/network_diagnostics/host_resolver.h"
#include "mojo/public/cpp/bindings/remote.h" #include "mojo/public/cpp/bindings/remote.h"
#include "net/base/completion_once_callback.h" #include "net/base/completion_once_callback.h"
#include "net/base/host_port_pair.h"
#include "services/network/public/mojom/network_context.mojom.h" #include "services/network/public/mojom/network_context.mojom.h"
#include "services/network/public/mojom/tcp_socket.mojom.h" #include "services/network/public/mojom/tcp_socket.mojom.h"
#include "services/network/public/mojom/tls_socket.mojom.h" #include "services/network/public/mojom/tls_socket.mojom.h"
#include "url/gurl.h"
namespace chromeos { namespace chromeos {
namespace network_diagnostics { namespace network_diagnostics {
// Uses a TLS socket to determine whether a secure socket connection to a host // Uses either a TCP or TLS socket to determine whether a socket connection to a
// can be established. No read or write functionality is exposed by this socket. // host can be established. No read or write functionality is exposed by this
// Used by network diagnostics routines. // socket. Used by network diagnostics routines.
class TlsProber { class TlsProber {
public: public:
// Lists the ways a prober may end. The callback passed into the prober's // Lists the ways a prober may end. The callback passed into the prober's
...@@ -43,14 +43,16 @@ class TlsProber { ...@@ -43,14 +43,16 @@ class TlsProber {
using TlsProbeCompleteCallback = using TlsProbeCompleteCallback =
base::OnceCallback<void(int result, ProbeExitEnum probe_exit_enum)>; base::OnceCallback<void(int result, ProbeExitEnum probe_exit_enum)>;
// Establishes a TCP connection (with underlying TLS support) to |url|. Note // Establishes a TCP connection to |host_port_pair|. If |negotiate_tls| is
// that the constructor will not invoke |callback|, which is passed into // true, the underlying TCP socket upgrades to include TLS support. Note that
// TlsProber during construction. This ensures the TlsProber instance is // the constructor will not invoke |callback|, which is passed into TlsProber
// constructed before |callback| is invoked. The TlsProber must be created on // during construction. This ensures the TlsProber instance is constructed
// the UI thread and will invoke |callback| on the UI thread. // before |callback| is invoked. The TlsProber must be created on the UI
// thread and will invoke |callback| on the UI thread.
// |network_context_getter| will be invoked on the UI thread. // |network_context_getter| will be invoked on the UI thread.
TlsProber(NetworkContextGetter network_context_getter, TlsProber(NetworkContextGetter network_context_getter,
const GURL& url, net::HostPortPair host_port_pair,
bool negotiate_tls,
TlsProbeCompleteCallback callback); TlsProbeCompleteCallback callback);
TlsProber(const TlsProber&) = delete; TlsProber(const TlsProber&) = delete;
TlsProber& operator=(const TlsProber&) = delete; TlsProber& operator=(const TlsProber&) = delete;
...@@ -97,9 +99,11 @@ class TlsProber { ...@@ -97,9 +99,11 @@ class TlsProber {
void OnDone(int result, ProbeExitEnum probe_exit_enum); void OnDone(int result, ProbeExitEnum probe_exit_enum);
// Gets the active profile-specific network context. // Gets the active profile-specific network context.
NetworkContextGetter network_context_getter_; const NetworkContextGetter network_context_getter_;
// URL containing the hostname and port. // Contains the hostname and port.
GURL url_; const net::HostPortPair host_port_pair_;
// Indicates whether TLS support must be added to the underlying socket.
const bool negotiate_tls_;
// Host resolver used for DNS lookup. // Host resolver used for DNS lookup.
std::unique_ptr<HostResolver> host_resolver_; std::unique_ptr<HostResolver> host_resolver_;
// Holds socket if socket was connected via TCP. // Holds socket if socket was connected via TCP.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment