Commit 9bbf3290 authored by davidben's avatar davidben Committed by Commit bot

Pass in a non-null CertVerifier into SSLClientSocket.

SslHmacChannelAuthenticator passes in a null one which crashes but the
IsAllowedBadCert check, as well as inconsistent ability to use X509Certificate
in the sandbox masks the issue most of the time.

This also fixes FakeStreamSocket to propogate disconnects to the peer, which is
needed to add a test for this case. (If SSLClientSocket doesn't like a
certificate, it just ceremoniously disconnects the connection right after the
handshake.) This test crashed before this CL outside the sandbox. (Inside the
sandbox, it's possible that it worked on some platforms due to the sandbox
breaking net::X509Certificate. I didn't do a survey.)

BUG=none

Review URL: https://codereview.chromium.org/1080593003

Cr-Commit-Position: refs/heads/master@{#326886}
parent 9ba360de
...@@ -2444,6 +2444,8 @@ SSLClientSocketNSS::SSLClientSocketNSS( ...@@ -2444,6 +2444,8 @@ SSLClientSocketNSS::SSLClientSocketNSS(
transport_security_state_(context.transport_security_state), transport_security_state_(context.transport_security_state),
policy_enforcer_(context.cert_policy_enforcer), policy_enforcer_(context.cert_policy_enforcer),
valid_thread_id_(base::kInvalidThreadId) { valid_thread_id_(base::kInvalidThreadId) {
DCHECK(cert_verifier_);
EnterFunction(""); EnterFunction("");
InitCore(); InitCore();
LeaveFunction(""); LeaveFunction("");
......
...@@ -389,6 +389,7 @@ SSLClientSocketOpenSSL::SSLClientSocketOpenSSL( ...@@ -389,6 +389,7 @@ SSLClientSocketOpenSSL::SSLClientSocketOpenSSL(
policy_enforcer_(context.cert_policy_enforcer), policy_enforcer_(context.cert_policy_enforcer),
net_log_(transport_->socket()->NetLog()), net_log_(transport_->socket()->NetLog()),
weak_factory_(this) { weak_factory_(this) {
DCHECK(cert_verifier_);
} }
SSLClientSocketOpenSSL::~SSLClientSocketOpenSSL() { SSLClientSocketOpenSSL::~SSLClientSocketOpenSSL() {
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include "remoting/protocol/fake_stream_socket.h" #include "remoting/protocol/fake_stream_socket.h"
#include "base/bind.h" #include "base/bind.h"
#include "base/callback_helpers.h"
#include "base/single_thread_task_runner.h" #include "base/single_thread_task_runner.h"
#include "base/thread_task_runner_handle.h" #include "base/thread_task_runner_handle.h"
#include "net/base/address_list.h" #include "net/base/address_list.h"
...@@ -45,9 +46,17 @@ void FakeStreamSocket::AppendInputData(const std::string& data) { ...@@ -45,9 +46,17 @@ void FakeStreamSocket::AppendInputData(const std::string& data) {
input_pos_ += result; input_pos_ += result;
read_buffer_ = nullptr; read_buffer_ = nullptr;
net::CompletionCallback callback = read_callback_; base::ResetAndReturn(&read_callback_).Run(result);
read_callback_.Reset(); }
callback.Run(result); }
void FakeStreamSocket::AppendReadError(int error) {
EXPECT_TRUE(task_runner_->BelongsToCurrentThread());
// Complete pending read if any.
if (!read_callback_.is_null()) {
base::ResetAndReturn(&read_callback_).Run(error);
} else {
next_read_error_ = error;
} }
} }
...@@ -65,18 +74,16 @@ int FakeStreamSocket::Read(net::IOBuffer* buf, int buf_len, ...@@ -65,18 +74,16 @@ int FakeStreamSocket::Read(net::IOBuffer* buf, int buf_len,
const net::CompletionCallback& callback) { const net::CompletionCallback& callback) {
EXPECT_TRUE(task_runner_->BelongsToCurrentThread()); EXPECT_TRUE(task_runner_->BelongsToCurrentThread());
if (next_read_error_ != net::OK) {
int r = next_read_error_;
next_read_error_ = net::OK;
return r;
}
if (input_pos_ < static_cast<int>(input_data_.size())) { if (input_pos_ < static_cast<int>(input_data_.size())) {
int result = std::min(buf_len, int result = std::min(buf_len,
static_cast<int>(input_data_.size()) - input_pos_); static_cast<int>(input_data_.size()) - input_pos_);
memcpy(buf->data(), &(*input_data_.begin()) + input_pos_, result); memcpy(buf->data(), &(*input_data_.begin()) + input_pos_, result);
input_pos_ += result; input_pos_ += result;
return result; return result;
} else if (next_read_error_ != net::OK) {
int r = next_read_error_;
next_read_error_ = net::OK;
return r;
} else { } else {
read_buffer_ = buf; read_buffer_ = buf;
read_buffer_size_ = buf_len; read_buffer_size_ = buf_len;
...@@ -159,6 +166,14 @@ int FakeStreamSocket::Connect(const net::CompletionCallback& callback) { ...@@ -159,6 +166,14 @@ int FakeStreamSocket::Connect(const net::CompletionCallback& callback) {
void FakeStreamSocket::Disconnect() { void FakeStreamSocket::Disconnect() {
EXPECT_TRUE(task_runner_->BelongsToCurrentThread()); EXPECT_TRUE(task_runner_->BelongsToCurrentThread());
if (peer_socket_.get()) {
task_runner_->PostTask(
FROM_HERE,
base::Bind(&FakeStreamSocket::AppendReadError,
peer_socket_,
net::ERR_CONNECTION_CLOSED));
}
peer_socket_.reset(); peer_socket_.reset();
} }
......
...@@ -45,14 +45,17 @@ class FakeStreamSocket : public net::StreamSocket { ...@@ -45,14 +45,17 @@ class FakeStreamSocket : public net::StreamSocket {
// Enables asynchronous Write(). // Enables asynchronous Write().
void set_async_write(bool async_write) { async_write_ = async_write; } void set_async_write(bool async_write) { async_write_ = async_write; }
// Set error codes for the next Read() and Write() calls. Once returned the // Set error codes for the next Write() call. Once returned the
// values are automatically reset to net::OK . // value is automatically reset to net::OK .
void set_next_read_error(int error) { next_read_error_ = error; }
void set_next_write_error(int error) { next_write_error_ = error; } void set_next_write_error(int error) { next_write_error_ = error; }
// Appends |data| to the read buffer. // Appends |data| to the read buffer.
void AppendInputData(const std::string& data); void AppendInputData(const std::string& data);
// Causes Read() to fail with |error| once the read buffer is exhausted. If
// there is a currently pending Read, it is interrupted.
void AppendReadError(int error);
// Pairs the socket with |peer_socket|. Deleting either of the paired sockets // Pairs the socket with |peer_socket|. Deleting either of the paired sockets
// unpairs them. // unpairs them.
void PairWith(FakeStreamSocket* peer_socket); void PairWith(FakeStreamSocket* peer_socket);
......
...@@ -279,7 +279,7 @@ TEST_F(MessageReaderTest, TwoMessages_Separately) { ...@@ -279,7 +279,7 @@ TEST_F(MessageReaderTest, TwoMessages_Separately) {
// Read() returns error. // Read() returns error.
TEST_F(MessageReaderTest, ReadError) { TEST_F(MessageReaderTest, ReadError) {
socket_.set_next_read_error(net::ERR_FAILED); socket_.AppendReadError(net::ERR_FAILED);
// Add a message. It should never be read after the error above. // Add a message. It should never be read after the error above.
AddMessage(kTestMessage1); AddMessage(kTestMessage1);
......
...@@ -6,10 +6,13 @@ ...@@ -6,10 +6,13 @@
#include "base/bind.h" #include "base/bind.h"
#include "base/bind_helpers.h" #include "base/bind_helpers.h"
#include "base/logging.h"
#include "crypto/secure_util.h" #include "crypto/secure_util.h"
#include "net/base/host_port_pair.h" #include "net/base/host_port_pair.h"
#include "net/base/io_buffer.h" #include "net/base/io_buffer.h"
#include "net/base/net_errors.h" #include "net/base/net_errors.h"
#include "net/cert/cert_status_flags.h"
#include "net/cert/cert_verifier.h"
#include "net/cert/x509_certificate.h" #include "net/cert/x509_certificate.h"
#include "net/http/transport_security_state.h" #include "net/http/transport_security_state.h"
#include "net/socket/client_socket_factory.h" #include "net/socket/client_socket_factory.h"
...@@ -24,6 +27,34 @@ ...@@ -24,6 +27,34 @@
namespace remoting { namespace remoting {
namespace protocol { namespace protocol {
namespace {
// A CertVerifier which rejects every certificate.
class FailingCertVerifier : public net::CertVerifier {
public:
FailingCertVerifier() {}
~FailingCertVerifier() override {}
int Verify(net::X509Certificate* cert,
const std::string& hostname,
int flags,
net::CRLSet* crl_set,
net::CertVerifyResult* verify_result,
const net::CompletionCallback& callback,
RequestHandle* out_req,
const net::BoundNetLog& net_log) override {
verify_result->verified_cert = cert;
verify_result->cert_status = net::CERT_STATUS_INVALID;
return net::ERR_CERT_INVALID;
}
void CancelRequest(RequestHandle req) override {
NOTIMPLEMENTED();
}
};
} // namespace
// static // static
scoped_ptr<SslHmacChannelAuthenticator> scoped_ptr<SslHmacChannelAuthenticator>
SslHmacChannelAuthenticator::CreateForClient( SslHmacChannelAuthenticator::CreateForClient(
...@@ -95,6 +126,7 @@ void SslHmacChannelAuthenticator::SecureAndAuthenticate( ...@@ -95,6 +126,7 @@ void SslHmacChannelAuthenticator::SecureAndAuthenticate(
#endif #endif
} else { } else {
transport_security_state_.reset(new net::TransportSecurityState); transport_security_state_.reset(new net::TransportSecurityState);
cert_verifier_.reset(new FailingCertVerifier);
net::SSLConfig::CertAndStatus cert_and_status; net::SSLConfig::CertAndStatus cert_and_status;
cert_and_status.cert_status = net::CERT_STATUS_AUTHORITY_INVALID; cert_and_status.cert_status = net::CERT_STATUS_AUTHORITY_INVALID;
...@@ -112,6 +144,7 @@ void SslHmacChannelAuthenticator::SecureAndAuthenticate( ...@@ -112,6 +144,7 @@ void SslHmacChannelAuthenticator::SecureAndAuthenticate(
net::HostPortPair host_and_port(kSslFakeHostName, 0); net::HostPortPair host_and_port(kSslFakeHostName, 0);
net::SSLClientSocketContext context; net::SSLClientSocketContext context;
context.transport_security_state = transport_security_state_.get(); context.transport_security_state = transport_security_state_.get();
context.cert_verifier = cert_verifier_.get();
scoped_ptr<net::ClientSocketHandle> socket_handle( scoped_ptr<net::ClientSocketHandle> socket_handle(
new net::ClientSocketHandle); new net::ClientSocketHandle);
socket_handle->SetSocket(socket.Pass()); socket_handle->SetSocket(socket.Pass());
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "remoting/protocol/channel_authenticator.h" #include "remoting/protocol/channel_authenticator.h"
namespace net { namespace net {
class CertVerifier;
class DrainableIOBuffer; class DrainableIOBuffer;
class GrowableIOBuffer; class GrowableIOBuffer;
class SSLSocket; class SSLSocket;
...@@ -89,6 +90,7 @@ class SslHmacChannelAuthenticator : public ChannelAuthenticator, ...@@ -89,6 +90,7 @@ class SslHmacChannelAuthenticator : public ChannelAuthenticator,
// Used in the CLIENT mode only. // Used in the CLIENT mode only.
std::string remote_cert_; std::string remote_cert_;
scoped_ptr<net::TransportSecurityState> transport_security_state_; scoped_ptr<net::TransportSecurityState> transport_security_state_;
scoped_ptr<net::CertVerifier> cert_verifier_;
scoped_ptr<net::SSLSocket> socket_; scoped_ptr<net::SSLSocket> socket_;
DoneCallback done_callback_; DoneCallback done_callback_;
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "crypto/rsa_private_key.h" #include "crypto/rsa_private_key.h"
#include "net/base/net_errors.h" #include "net/base/net_errors.h"
#include "net/base/test_data_directory.h" #include "net/base/test_data_directory.h"
#include "net/test/cert_test_util.h"
#include "remoting/base/rsa_key_pair.h" #include "remoting/base/rsa_key_pair.h"
#include "remoting/protocol/connection_tester.h" #include "remoting/protocol/connection_tester.h"
#include "remoting/protocol/fake_session.h" #include "remoting/protocol/fake_session.h"
...@@ -68,7 +69,7 @@ class SslHmacChannelAuthenticatorTest : public testing::Test { ...@@ -68,7 +69,7 @@ class SslHmacChannelAuthenticatorTest : public testing::Test {
ASSERT_TRUE(key_pair_.get()); ASSERT_TRUE(key_pair_.get());
} }
void RunChannelAuth(bool expected_fail) { void RunChannelAuth(int expected_client_error, int expected_host_error) {
client_fake_socket_.reset(new FakeStreamSocket()); client_fake_socket_.reset(new FakeStreamSocket());
host_fake_socket_.reset(new FakeStreamSocket()); host_fake_socket_.reset(new FakeStreamSocket());
client_fake_socket_->PairWith(host_fake_socket_.get()); client_fake_socket_->PairWith(host_fake_socket_.get());
...@@ -87,14 +88,18 @@ class SslHmacChannelAuthenticatorTest : public testing::Test { ...@@ -87,14 +88,18 @@ class SslHmacChannelAuthenticatorTest : public testing::Test {
// callback. // callback.
int callback_counter = 2; int callback_counter = 2;
if (expected_fail) { if (expected_client_error != net::OK) {
EXPECT_CALL(client_callback_, OnDone(net::ERR_FAILED, nullptr)) EXPECT_CALL(client_callback_, OnDone(expected_client_error, nullptr))
.WillOnce(QuitThreadOnCounter(&callback_counter));
EXPECT_CALL(host_callback_, OnDone(net::ERR_FAILED, nullptr))
.WillOnce(QuitThreadOnCounter(&callback_counter)); .WillOnce(QuitThreadOnCounter(&callback_counter));
} else { } else {
EXPECT_CALL(client_callback_, OnDone(net::OK, NotNull())) EXPECT_CALL(client_callback_, OnDone(net::OK, NotNull()))
.WillOnce(QuitThreadOnCounter(&callback_counter)); .WillOnce(QuitThreadOnCounter(&callback_counter));
}
if (expected_host_error != net::OK) {
EXPECT_CALL(host_callback_, OnDone(expected_host_error, nullptr))
.WillOnce(QuitThreadOnCounter(&callback_counter));
} else {
EXPECT_CALL(host_callback_, OnDone(net::OK, NotNull())) EXPECT_CALL(host_callback_, OnDone(net::OK, NotNull()))
.WillOnce(QuitThreadOnCounter(&callback_counter)); .WillOnce(QuitThreadOnCounter(&callback_counter));
} }
...@@ -149,7 +154,7 @@ TEST_F(SslHmacChannelAuthenticatorTest, SuccessfulAuth) { ...@@ -149,7 +154,7 @@ TEST_F(SslHmacChannelAuthenticatorTest, SuccessfulAuth) {
host_auth_ = SslHmacChannelAuthenticator::CreateForHost( host_auth_ = SslHmacChannelAuthenticator::CreateForHost(
host_cert_, key_pair_, kTestSharedSecret); host_cert_, key_pair_, kTestSharedSecret);
RunChannelAuth(false); RunChannelAuth(net::OK, net::OK);
ASSERT_TRUE(client_socket_.get() != nullptr); ASSERT_TRUE(client_socket_.get() != nullptr);
ASSERT_TRUE(host_socket_.get() != nullptr); ASSERT_TRUE(host_socket_.get() != nullptr);
...@@ -169,7 +174,26 @@ TEST_F(SslHmacChannelAuthenticatorTest, InvalidChannelSecret) { ...@@ -169,7 +174,26 @@ TEST_F(SslHmacChannelAuthenticatorTest, InvalidChannelSecret) {
host_auth_ = SslHmacChannelAuthenticator::CreateForHost( host_auth_ = SslHmacChannelAuthenticator::CreateForHost(
host_cert_, key_pair_, kTestSharedSecret); host_cert_, key_pair_, kTestSharedSecret);
RunChannelAuth(true); RunChannelAuth(net::ERR_FAILED, net::ERR_FAILED);
ASSERT_TRUE(host_socket_.get() == nullptr);
}
// Verify that channels cannot be using invalid certificate.
TEST_F(SslHmacChannelAuthenticatorTest, InvalidCertificate) {
// Import a second certificate for the client to expect.
scoped_refptr<net::X509Certificate> host_cert2(
net::ImportCertFromFile(net::GetTestCertsDirectory(), "ok_cert.pem"));
std::string host_cert2_der;
ASSERT_TRUE(net::X509Certificate::GetDEREncoded(host_cert2->os_cert_handle(),
&host_cert2_der));
client_auth_ = SslHmacChannelAuthenticator::CreateForClient(
host_cert2_der, kTestSharedSecret);
host_auth_ = SslHmacChannelAuthenticator::CreateForHost(
host_cert_, key_pair_, kTestSharedSecret);
RunChannelAuth(net::ERR_CERT_INVALID, net::ERR_CONNECTION_CLOSED);
ASSERT_TRUE(host_socket_.get() == nullptr); ASSERT_TRUE(host_socket_.get() == nullptr);
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment