Commit b8ab3858 authored by David Benjamin's avatar David Benjamin Committed by Commit Bot

Switch SSLServerSocket to CRYPTO_BUFFER.

This also allows us to support asynchronous client certificate
verification, but this CL leaves it as a TODO for now.

Bug: 706445
Change-Id: I792eb91a854bb15a67317d7ea4d04a80ba5ca4da
Reviewed-on: https://chromium-review.googlesource.com/586431Reviewed-by: default avatarSteven Valdez <svaldez@chromium.org>
Reviewed-by: default avatarMatt Mueller <mattm@chromium.org>
Commit-Queue: David Benjamin <davidben@chromium.org>
Cr-Commit-Position: refs/heads/master@{#491886}
parent dac798e2
...@@ -103,6 +103,11 @@ NET_EXPORT bssl::UniquePtr<CRYPTO_BUFFER> CreateCryptoBuffer( ...@@ -103,6 +103,11 @@ NET_EXPORT bssl::UniquePtr<CRYPTO_BUFFER> CreateCryptoBuffer(
NET_EXPORT bssl::UniquePtr<CRYPTO_BUFFER> CreateCryptoBuffer( NET_EXPORT bssl::UniquePtr<CRYPTO_BUFFER> CreateCryptoBuffer(
const char* invalid_data); const char* invalid_data);
// Creates a new X509Certificate from the chain in |buffers|, which must have at
// least one element.
scoped_refptr<X509Certificate> CreateX509CertificateFromBuffers(
STACK_OF(CRYPTO_BUFFER) * buffers);
} // namespace x509_util } // namespace x509_util
} // namespace net } // namespace net
......
...@@ -8,12 +8,14 @@ ...@@ -8,12 +8,14 @@
#include <algorithm> #include <algorithm>
#include <memory> #include <memory>
#include <vector>
#include "base/lazy_instance.h" #include "base/lazy_instance.h"
#include "base/logging.h" #include "base/logging.h"
#include "base/macros.h" #include "base/macros.h"
#include "base/strings/string_piece.h" #include "base/strings/string_piece.h"
#include "base/strings/string_util.h" #include "base/strings/string_util.h"
#include "build/build_config.h"
#include "crypto/ec_private_key.h" #include "crypto/ec_private_key.h"
#include "crypto/openssl_util.h" #include "crypto/openssl_util.h"
#include "crypto/rsa_private_key.h" #include "crypto/rsa_private_key.h"
...@@ -26,6 +28,7 @@ ...@@ -26,6 +28,7 @@
#include "third_party/boringssl/src/include/openssl/digest.h" #include "third_party/boringssl/src/include/openssl/digest.h"
#include "third_party/boringssl/src/include/openssl/mem.h" #include "third_party/boringssl/src/include/openssl/mem.h"
#include "third_party/boringssl/src/include/openssl/pool.h" #include "third_party/boringssl/src/include/openssl/pool.h"
#include "third_party/boringssl/src/include/openssl/stack.h"
namespace net { namespace net {
...@@ -392,6 +395,33 @@ bssl::UniquePtr<CRYPTO_BUFFER> CreateCryptoBuffer( ...@@ -392,6 +395,33 @@ bssl::UniquePtr<CRYPTO_BUFFER> CreateCryptoBuffer(
data.size(), GetBufferPool())); data.size(), GetBufferPool()));
} }
scoped_refptr<X509Certificate> CreateX509CertificateFromBuffers(
STACK_OF(CRYPTO_BUFFER) * buffers) {
if (sk_CRYPTO_BUFFER_num(buffers) == 0) {
NOTREACHED();
return nullptr;
}
#if BUILDFLAG(USE_BYTE_CERTS)
std::vector<CRYPTO_BUFFER*> intermediate_chain;
for (size_t i = 1; i < sk_CRYPTO_BUFFER_num(buffers); ++i)
intermediate_chain.push_back(sk_CRYPTO_BUFFER_value(buffers, i));
return X509Certificate::CreateFromHandle(sk_CRYPTO_BUFFER_value(buffers, 0),
intermediate_chain);
#else
// Convert the certificate chains to a platform certificate handle.
std::vector<base::StringPiece> der_chain;
der_chain.reserve(sk_CRYPTO_BUFFER_num(buffers));
for (size_t i = 0; i < sk_CRYPTO_BUFFER_num(buffers); ++i) {
const CRYPTO_BUFFER* cert = sk_CRYPTO_BUFFER_value(buffers, i);
der_chain.push_back(base::StringPiece(
reinterpret_cast<const char*>(CRYPTO_BUFFER_data(cert)),
CRYPTO_BUFFER_len(cert)));
}
return X509Certificate::CreateFromDERCertChain(der_chain);
#endif
}
} // namespace x509_util } // namespace x509_util
} // namespace net } // namespace net
...@@ -198,44 +198,6 @@ int GetBufferSize(const char* field_trial) { ...@@ -198,44 +198,6 @@ int GetBufferSize(const char* field_trial) {
return buffer_size; return buffer_size;
} }
scoped_refptr<X509Certificate> OSChainFromBuffers(STACK_OF(CRYPTO_BUFFER) *
openssl_chain) {
if (sk_CRYPTO_BUFFER_num(openssl_chain) == 0) {
NOTREACHED();
return nullptr;
}
#if BUILDFLAG(USE_BYTE_CERTS)
std::vector<CRYPTO_BUFFER*> intermediate_chain;
for (size_t i = 1; i < sk_CRYPTO_BUFFER_num(openssl_chain); ++i)
intermediate_chain.push_back(sk_CRYPTO_BUFFER_value(openssl_chain, i));
return X509Certificate::CreateFromHandle(
sk_CRYPTO_BUFFER_value(openssl_chain, 0), intermediate_chain);
#else
// Convert the certificate chains to a platform certificate handle.
std::vector<base::StringPiece> der_chain;
der_chain.reserve(sk_CRYPTO_BUFFER_num(openssl_chain));
for (size_t i = 0; i < sk_CRYPTO_BUFFER_num(openssl_chain); ++i) {
const CRYPTO_BUFFER* cert = sk_CRYPTO_BUFFER_value(openssl_chain, i);
base::StringPiece der;
der_chain.push_back(base::StringPiece(
reinterpret_cast<const char*>(CRYPTO_BUFFER_data(cert)),
CRYPTO_BUFFER_len(cert)));
}
return X509Certificate::CreateFromDERCertChain(der_chain);
#endif
}
#if !defined(OS_IOS) && !BUILDFLAG(USE_BYTE_CERTS)
bssl::UniquePtr<CRYPTO_BUFFER> OSCertHandleToBuffer(
X509Certificate::OSCertHandle os_handle) {
std::string der_encoded;
if (!X509Certificate::GetDEREncoded(os_handle, &der_encoded))
return nullptr;
return x509_util::CreateCryptoBuffer(der_encoded);
}
#endif
std::unique_ptr<base::Value> NetLogSSLAlertCallback( std::unique_ptr<base::Value> NetLogSSLAlertCallback(
const void* bytes, const void* bytes,
size_t len, size_t len,
...@@ -1244,7 +1206,8 @@ int SSLClientSocketImpl::DoChannelIDLookupComplete(int result) { ...@@ -1244,7 +1206,8 @@ int SSLClientSocketImpl::DoChannelIDLookupComplete(int result) {
int SSLClientSocketImpl::DoVerifyCert(int result) { int SSLClientSocketImpl::DoVerifyCert(int result) {
DCHECK(start_cert_verification_time_.is_null()); DCHECK(start_cert_verification_time_.is_null());
server_cert_ = OSChainFromBuffers(SSL_get0_peer_certificates(ssl_.get())); server_cert_ = x509_util::CreateX509CertificateFromBuffers(
SSL_get0_peer_certificates(ssl_.get()));
// OpenSSL decoded the certificate, but the platform certificate // OpenSSL decoded the certificate, but the platform certificate
// implementation could not. This is treated as a fatal SSL-level protocol // implementation could not. This is treated as a fatal SSL-level protocol
...@@ -1659,44 +1622,11 @@ int SSLClientSocketImpl::ClientCertRequestCallback(SSL* ssl) { ...@@ -1659,44 +1622,11 @@ int SSLClientSocketImpl::ClientCertRequestCallback(SSL* ssl) {
return -1; return -1;
} }
#if BUILDFLAG(USE_BYTE_CERTS) if (!SetSSLChainAndKey(ssl_.get(), ssl_config_.client_cert.get(), nullptr,
std::vector<CRYPTO_BUFFER*> chain_raw; &SSLContext::kPrivateKeyMethod)) {
chain_raw.push_back(ssl_config_.client_cert->os_cert_handle());
for (X509Certificate::OSCertHandle cert :
ssl_config_.client_cert->GetIntermediateCertificates()) {
chain_raw.push_back(cert);
}
#else
std::vector<bssl::UniquePtr<CRYPTO_BUFFER>> chain;
std::vector<CRYPTO_BUFFER*> chain_raw;
bssl::UniquePtr<CRYPTO_BUFFER> buf =
OSCertHandleToBuffer(ssl_config_.client_cert->os_cert_handle());
if (!buf) {
LOG(WARNING) << "Failed to import certificate";
OpenSSLPutNetError(FROM_HERE, ERR_SSL_CLIENT_AUTH_CERT_BAD_FORMAT); OpenSSLPutNetError(FROM_HERE, ERR_SSL_CLIENT_AUTH_CERT_BAD_FORMAT);
return -1; return -1;
} }
chain_raw.push_back(buf.get());
chain.push_back(std::move(buf));
for (X509Certificate::OSCertHandle cert :
ssl_config_.client_cert->GetIntermediateCertificates()) {
bssl::UniquePtr<CRYPTO_BUFFER> buf = OSCertHandleToBuffer(cert);
if (!buf) {
LOG(WARNING) << "Failed to import intermediate";
OpenSSLPutNetError(FROM_HERE, ERR_SSL_CLIENT_AUTH_CERT_BAD_FORMAT);
return -1;
}
chain_raw.push_back(buf.get());
chain.push_back(std::move(buf));
}
#endif
if (!SSL_set_chain_and_key(ssl_.get(), chain_raw.data(), chain_raw.size(),
nullptr, &SSLContext::kPrivateKeyMethod)) {
LOG(WARNING) << "Failed to set client certificate";
return -1;
}
std::vector<SSLPrivateKey::Hash> digest_prefs = std::vector<SSLPrivateKey::Hash> digest_prefs =
ssl_config_.client_private_key->GetDigestPreferences(); ssl_config_.client_private_key->GetDigestPreferences();
...@@ -1726,8 +1656,11 @@ int SSLClientSocketImpl::ClientCertRequestCallback(SSL* ssl) { ...@@ -1726,8 +1656,11 @@ int SSLClientSocketImpl::ClientCertRequestCallback(SSL* ssl) {
SSL_set_private_key_digest_prefs(ssl_.get(), digests.data(), SSL_set_private_key_digest_prefs(ssl_.get(), digests.data(),
digests.size()); digests.size());
net_log_.AddEvent(NetLogEventType::SSL_CLIENT_CERT_PROVIDED, net_log_.AddEvent(
NetLog::IntCallback("cert_count", chain_raw.size())); NetLogEventType::SSL_CLIENT_CERT_PROVIDED,
NetLog::IntCallback(
"cert_count",
1 + ssl_config_.client_cert->GetIntermediateCertificates().size()));
return 1; return 1;
} }
#endif // defined(OS_IOS) #endif // defined(OS_IOS)
......
...@@ -23,46 +23,19 @@ ...@@ -23,46 +23,19 @@
#include "net/ssl/ssl_connection_status_flags.h" #include "net/ssl/ssl_connection_status_flags.h"
#include "net/ssl/ssl_info.h" #include "net/ssl/ssl_info.h"
#include "third_party/boringssl/src/include/openssl/err.h" #include "third_party/boringssl/src/include/openssl/err.h"
#include "third_party/boringssl/src/include/openssl/pool.h"
#include "third_party/boringssl/src/include/openssl/ssl.h" #include "third_party/boringssl/src/include/openssl/ssl.h"
#include "third_party/boringssl/src/include/openssl/x509.h"
#define GotoState(s) next_handshake_state_ = s #define GotoState(s) next_handshake_state_ = s
namespace net { namespace net {
namespace { class SSLServerContextImpl::SocketImpl : public SSLServerSocket,
public SocketBIOAdapter::Delegate {
// Creates an X509Certificate out of the concatenation of |cert|, if non-null,
// with |chain|.
scoped_refptr<X509Certificate> CreateX509Certificate(X509* cert,
STACK_OF(X509) * chain) {
std::vector<base::StringPiece> der_chain;
base::StringPiece der_cert;
scoped_refptr<X509Certificate> client_cert;
if (cert) {
if (!x509_util::GetDER(cert, &der_cert))
return nullptr;
der_chain.push_back(der_cert);
}
for (size_t i = 0; i < sk_X509_num(chain); ++i) {
X509* x = sk_X509_value(chain, i);
if (!x509_util::GetDER(x, &der_cert))
return nullptr;
der_chain.push_back(der_cert);
}
return X509Certificate::CreateFromDERCertChain(der_chain);
}
class SSLServerSocketImpl : public SSLServerSocket,
public SocketBIOAdapter::Delegate {
public: public:
// See comments on CreateSSLServerSocket for details of how these SocketImpl(SSLServerContextImpl* context,
// parameters are used. std::unique_ptr<StreamSocket> socket);
SSLServerSocketImpl(std::unique_ptr<StreamSocket> socket, ~SocketImpl() override;
bssl::UniquePtr<SSL> ssl);
~SSLServerSocketImpl() override;
// SSLServerSocket interface. // SSLServerSocket interface.
int Handshake(const CompletionCallback& callback) override; int Handshake(const CompletionCallback& callback) override;
...@@ -102,7 +75,8 @@ class SSLServerSocketImpl : public SSLServerSocket, ...@@ -102,7 +75,8 @@ class SSLServerSocketImpl : public SSLServerSocket,
void ClearConnectionAttempts() override {} void ClearConnectionAttempts() override {}
void AddConnectionAttempts(const ConnectionAttempts& attempts) override {} void AddConnectionAttempts(const ConnectionAttempts& attempts) override {}
int64_t GetTotalReceivedBytes() const override; int64_t GetTotalReceivedBytes() const override;
static int CertVerifyCallback(X509_STORE_CTX* store_ctx, void* arg); static ssl_verify_result_t CertVerifyCallback(SSL* ssl, uint8_t* out_alert);
ssl_verify_result_t CertVerifyCallbackImpl(uint8_t* out_alert);
// SocketBIOAdapter::Delegate implementation. // SocketBIOAdapter::Delegate implementation.
void OnReadReady() override; void OnReadReady() override;
...@@ -128,6 +102,8 @@ class SSLServerSocketImpl : public SSLServerSocket, ...@@ -128,6 +102,8 @@ class SSLServerSocketImpl : public SSLServerSocket,
int Init(); int Init();
void ExtractClientCert(); void ExtractClientCert();
SSLServerContextImpl* context_;
NetLogWithSource net_log_; NetLogWithSource net_log_;
CompletionCallback user_handshake_callback_; CompletionCallback user_handshake_callback_;
...@@ -155,20 +131,23 @@ class SSLServerSocketImpl : public SSLServerSocket, ...@@ -155,20 +131,23 @@ class SSLServerSocketImpl : public SSLServerSocket,
State next_handshake_state_; State next_handshake_state_;
bool completed_handshake_; bool completed_handshake_;
DISALLOW_COPY_AND_ASSIGN(SSLServerSocketImpl); DISALLOW_COPY_AND_ASSIGN(SocketImpl);
}; };
SSLServerSocketImpl::SSLServerSocketImpl( SSLServerContextImpl::SocketImpl::SocketImpl(
std::unique_ptr<StreamSocket> transport_socket, SSLServerContextImpl* context,
bssl::UniquePtr<SSL> ssl) std::unique_ptr<StreamSocket> transport_socket)
: user_read_buf_len_(0), : context_(context),
user_read_buf_len_(0),
user_write_buf_len_(0), user_write_buf_len_(0),
ssl_(std::move(ssl)),
transport_socket_(std::move(transport_socket)), transport_socket_(std::move(transport_socket)),
next_handshake_state_(STATE_NONE), next_handshake_state_(STATE_NONE),
completed_handshake_(false) {} completed_handshake_(false) {
ssl_.reset(SSL_new(context_->ssl_ctx_.get()));
SSL_set_app_data(ssl_.get(), this);
}
SSLServerSocketImpl::~SSLServerSocketImpl() { SSLServerContextImpl::SocketImpl::~SocketImpl() {
if (ssl_) { if (ssl_) {
// Calling SSL_shutdown prevents the session from being marked as // Calling SSL_shutdown prevents the session from being marked as
// unresumable. // unresumable.
...@@ -177,7 +156,8 @@ SSLServerSocketImpl::~SSLServerSocketImpl() { ...@@ -177,7 +156,8 @@ SSLServerSocketImpl::~SSLServerSocketImpl() {
} }
} }
int SSLServerSocketImpl::Handshake(const CompletionCallback& callback) { int SSLServerContextImpl::SocketImpl::Handshake(
const CompletionCallback& callback) {
net_log_.BeginEvent(NetLogEventType::SSL_SERVER_HANDSHAKE); net_log_.BeginEvent(NetLogEventType::SSL_SERVER_HANDSHAKE);
// Set up new ssl object. // Set up new ssl object.
...@@ -204,11 +184,12 @@ int SSLServerSocketImpl::Handshake(const CompletionCallback& callback) { ...@@ -204,11 +184,12 @@ int SSLServerSocketImpl::Handshake(const CompletionCallback& callback) {
return rv > OK ? OK : rv; return rv > OK ? OK : rv;
} }
int SSLServerSocketImpl::ExportKeyingMaterial(const base::StringPiece& label, int SSLServerContextImpl::SocketImpl::ExportKeyingMaterial(
bool has_context, const base::StringPiece& label,
const base::StringPiece& context, bool has_context,
unsigned char* out, const base::StringPiece& context,
unsigned int outlen) { unsigned char* out,
unsigned int outlen) {
if (!IsConnected()) if (!IsConnected())
return ERR_SOCKET_NOT_CONNECTED; return ERR_SOCKET_NOT_CONNECTED;
...@@ -228,9 +209,9 @@ int SSLServerSocketImpl::ExportKeyingMaterial(const base::StringPiece& label, ...@@ -228,9 +209,9 @@ int SSLServerSocketImpl::ExportKeyingMaterial(const base::StringPiece& label,
return OK; return OK;
} }
int SSLServerSocketImpl::Read(IOBuffer* buf, int SSLServerContextImpl::SocketImpl::Read(IOBuffer* buf,
int buf_len, int buf_len,
const CompletionCallback& callback) { const CompletionCallback& callback) {
DCHECK(user_read_callback_.is_null()); DCHECK(user_read_callback_.is_null());
DCHECK(user_handshake_callback_.is_null()); DCHECK(user_handshake_callback_.is_null());
DCHECK(!user_read_buf_); DCHECK(!user_read_buf_);
...@@ -253,9 +234,10 @@ int SSLServerSocketImpl::Read(IOBuffer* buf, ...@@ -253,9 +234,10 @@ int SSLServerSocketImpl::Read(IOBuffer* buf,
return rv; return rv;
} }
int SSLServerSocketImpl::Write(IOBuffer* buf, int SSLServerContextImpl::SocketImpl::Write(
int buf_len, IOBuffer* buf,
const CompletionCallback& callback) { int buf_len,
const CompletionCallback& callback) {
DCHECK(user_write_callback_.is_null()); DCHECK(user_write_callback_.is_null());
DCHECK(!user_write_buf_); DCHECK(!user_write_buf_);
DCHECK(!callback.is_null()); DCHECK(!callback.is_null());
...@@ -274,72 +256,75 @@ int SSLServerSocketImpl::Write(IOBuffer* buf, ...@@ -274,72 +256,75 @@ int SSLServerSocketImpl::Write(IOBuffer* buf,
return rv; return rv;
} }
int SSLServerSocketImpl::SetReceiveBufferSize(int32_t size) { int SSLServerContextImpl::SocketImpl::SetReceiveBufferSize(int32_t size) {
return transport_socket_->SetReceiveBufferSize(size); return transport_socket_->SetReceiveBufferSize(size);
} }
int SSLServerSocketImpl::SetSendBufferSize(int32_t size) { int SSLServerContextImpl::SocketImpl::SetSendBufferSize(int32_t size) {
return transport_socket_->SetSendBufferSize(size); return transport_socket_->SetSendBufferSize(size);
} }
int SSLServerSocketImpl::Connect(const CompletionCallback& callback) { int SSLServerContextImpl::SocketImpl::Connect(
const CompletionCallback& callback) {
NOTIMPLEMENTED(); NOTIMPLEMENTED();
return ERR_NOT_IMPLEMENTED; return ERR_NOT_IMPLEMENTED;
} }
void SSLServerSocketImpl::Disconnect() { void SSLServerContextImpl::SocketImpl::Disconnect() {
transport_socket_->Disconnect(); transport_socket_->Disconnect();
} }
bool SSLServerSocketImpl::IsConnected() const { bool SSLServerContextImpl::SocketImpl::IsConnected() const {
// TODO(wtc): Find out if we should check transport_socket_->IsConnected() // TODO(wtc): Find out if we should check transport_socket_->IsConnected()
// as well. // as well.
return completed_handshake_; return completed_handshake_;
} }
bool SSLServerSocketImpl::IsConnectedAndIdle() const { bool SSLServerContextImpl::SocketImpl::IsConnectedAndIdle() const {
return completed_handshake_ && transport_socket_->IsConnectedAndIdle(); return completed_handshake_ && transport_socket_->IsConnectedAndIdle();
} }
int SSLServerSocketImpl::GetPeerAddress(IPEndPoint* address) const { int SSLServerContextImpl::SocketImpl::GetPeerAddress(
IPEndPoint* address) const {
if (!IsConnected()) if (!IsConnected())
return ERR_SOCKET_NOT_CONNECTED; return ERR_SOCKET_NOT_CONNECTED;
return transport_socket_->GetPeerAddress(address); return transport_socket_->GetPeerAddress(address);
} }
int SSLServerSocketImpl::GetLocalAddress(IPEndPoint* address) const { int SSLServerContextImpl::SocketImpl::GetLocalAddress(
IPEndPoint* address) const {
if (!IsConnected()) if (!IsConnected())
return ERR_SOCKET_NOT_CONNECTED; return ERR_SOCKET_NOT_CONNECTED;
return transport_socket_->GetLocalAddress(address); return transport_socket_->GetLocalAddress(address);
} }
const NetLogWithSource& SSLServerSocketImpl::NetLog() const { const NetLogWithSource& SSLServerContextImpl::SocketImpl::NetLog() const {
return net_log_; return net_log_;
} }
void SSLServerSocketImpl::SetSubresourceSpeculation() { void SSLServerContextImpl::SocketImpl::SetSubresourceSpeculation() {
transport_socket_->SetSubresourceSpeculation(); transport_socket_->SetSubresourceSpeculation();
} }
void SSLServerSocketImpl::SetOmniboxSpeculation() { void SSLServerContextImpl::SocketImpl::SetOmniboxSpeculation() {
transport_socket_->SetOmniboxSpeculation(); transport_socket_->SetOmniboxSpeculation();
} }
bool SSLServerSocketImpl::WasEverUsed() const { bool SSLServerContextImpl::SocketImpl::WasEverUsed() const {
return transport_socket_->WasEverUsed(); return transport_socket_->WasEverUsed();
} }
bool SSLServerSocketImpl::WasAlpnNegotiated() const { bool SSLServerContextImpl::SocketImpl::WasAlpnNegotiated() const {
NOTIMPLEMENTED(); NOTIMPLEMENTED();
return false; return false;
} }
NextProto SSLServerSocketImpl::GetNegotiatedProtocol() const { NextProto SSLServerContextImpl::SocketImpl::GetNegotiatedProtocol() const {
// ALPN is not supported by this class. // ALPN is not supported by this class.
return kProtoUnknown; return kProtoUnknown;
} }
bool SSLServerSocketImpl::GetSSLInfo(SSLInfo* ssl_info) { bool SSLServerContextImpl::SocketImpl::GetSSLInfo(SSLInfo* ssl_info) {
ssl_info->Reset(); ssl_info->Reset();
if (!completed_handshake_) if (!completed_handshake_)
return false; return false;
...@@ -363,15 +348,16 @@ bool SSLServerSocketImpl::GetSSLInfo(SSLInfo* ssl_info) { ...@@ -363,15 +348,16 @@ bool SSLServerSocketImpl::GetSSLInfo(SSLInfo* ssl_info) {
return true; return true;
} }
void SSLServerSocketImpl::GetConnectionAttempts(ConnectionAttempts* out) const { void SSLServerContextImpl::SocketImpl::GetConnectionAttempts(
ConnectionAttempts* out) const {
out->clear(); out->clear();
} }
int64_t SSLServerSocketImpl::GetTotalReceivedBytes() const { int64_t SSLServerContextImpl::SocketImpl::GetTotalReceivedBytes() const {
return transport_socket_->GetTotalReceivedBytes(); return transport_socket_->GetTotalReceivedBytes();
} }
void SSLServerSocketImpl::OnReadReady() { void SSLServerContextImpl::SocketImpl::OnReadReady() {
if (next_handshake_state_ == STATE_HANDSHAKE) { if (next_handshake_state_ == STATE_HANDSHAKE) {
// In handshake phase. The parameter to OnHandshakeIOComplete is unused. // In handshake phase. The parameter to OnHandshakeIOComplete is unused.
OnHandshakeIOComplete(OK); OnHandshakeIOComplete(OK);
...@@ -388,7 +374,7 @@ void SSLServerSocketImpl::OnReadReady() { ...@@ -388,7 +374,7 @@ void SSLServerSocketImpl::OnReadReady() {
DoReadCallback(rv); DoReadCallback(rv);
} }
void SSLServerSocketImpl::OnWriteReady() { void SSLServerContextImpl::SocketImpl::OnWriteReady() {
if (next_handshake_state_ == STATE_HANDSHAKE) { if (next_handshake_state_ == STATE_HANDSHAKE) {
// In handshake phase. The parameter to OnHandshakeIOComplete is unused. // In handshake phase. The parameter to OnHandshakeIOComplete is unused.
OnHandshakeIOComplete(OK); OnHandshakeIOComplete(OK);
...@@ -405,7 +391,7 @@ void SSLServerSocketImpl::OnWriteReady() { ...@@ -405,7 +391,7 @@ void SSLServerSocketImpl::OnWriteReady() {
DoWriteCallback(rv); DoWriteCallback(rv);
} }
void SSLServerSocketImpl::OnHandshakeIOComplete(int result) { void SSLServerContextImpl::SocketImpl::OnHandshakeIOComplete(int result) {
int rv = DoHandshakeLoop(result); int rv = DoHandshakeLoop(result);
if (rv == ERR_IO_PENDING) if (rv == ERR_IO_PENDING)
return; return;
...@@ -415,8 +401,7 @@ void SSLServerSocketImpl::OnHandshakeIOComplete(int result) { ...@@ -415,8 +401,7 @@ void SSLServerSocketImpl::OnHandshakeIOComplete(int result) {
DoHandshakeCallback(rv); DoHandshakeCallback(rv);
} }
int SSLServerContextImpl::SocketImpl::DoPayloadRead() {
int SSLServerSocketImpl::DoPayloadRead() {
DCHECK(completed_handshake_); DCHECK(completed_handshake_);
DCHECK_EQ(STATE_NONE, next_handshake_state_); DCHECK_EQ(STATE_NONE, next_handshake_state_);
DCHECK(user_read_buf_); DCHECK(user_read_buf_);
...@@ -438,7 +423,7 @@ int SSLServerSocketImpl::DoPayloadRead() { ...@@ -438,7 +423,7 @@ int SSLServerSocketImpl::DoPayloadRead() {
return net_error; return net_error;
} }
int SSLServerSocketImpl::DoPayloadWrite() { int SSLServerContextImpl::SocketImpl::DoPayloadWrite() {
DCHECK(completed_handshake_); DCHECK(completed_handshake_);
DCHECK_EQ(STATE_NONE, next_handshake_state_); DCHECK_EQ(STATE_NONE, next_handshake_state_);
DCHECK(user_write_buf_); DCHECK(user_write_buf_);
...@@ -459,7 +444,7 @@ int SSLServerSocketImpl::DoPayloadWrite() { ...@@ -459,7 +444,7 @@ int SSLServerSocketImpl::DoPayloadWrite() {
return net_error; return net_error;
} }
int SSLServerSocketImpl::DoHandshakeLoop(int last_io_result) { int SSLServerContextImpl::SocketImpl::DoHandshakeLoop(int last_io_result) {
int rv = last_io_result; int rv = last_io_result;
do { do {
// Default to STATE_NONE for next state. // Default to STATE_NONE for next state.
...@@ -483,21 +468,17 @@ int SSLServerSocketImpl::DoHandshakeLoop(int last_io_result) { ...@@ -483,21 +468,17 @@ int SSLServerSocketImpl::DoHandshakeLoop(int last_io_result) {
return rv; return rv;
} }
int SSLServerSocketImpl::DoHandshake() { int SSLServerContextImpl::SocketImpl::DoHandshake() {
crypto::OpenSSLErrStackTracer err_tracer(FROM_HERE); crypto::OpenSSLErrStackTracer err_tracer(FROM_HERE);
int net_error = OK; int net_error = OK;
int rv = SSL_do_handshake(ssl_.get()); int rv = SSL_do_handshake(ssl_.get());
if (rv == 1) { if (rv == 1) {
completed_handshake_ = true; completed_handshake_ = true;
// The results of SSL_get_peer_certificate() must be explicitly freed. STACK_OF(CRYPTO_BUFFER)* certs = SSL_get0_peer_certificates(ssl_.get());
bssl::UniquePtr<X509> cert(SSL_get_peer_certificate(ssl_.get())); if (certs) {
if (cert) { client_cert_ = x509_util::CreateX509CertificateFromBuffers(certs);
// The caller does not take ownership of SSL_get_peer_cert_chain's if (!client_cert_)
// results.
STACK_OF(X509)* chain = SSL_get_peer_cert_chain(ssl_.get());
client_cert_ = CreateX509Certificate(cert.get(), chain);
if (!client_cert_.get())
return ERR_SSL_CLIENT_AUTH_CERT_BAD_FORMAT; return ERR_SSL_CLIENT_AUTH_CERT_BAD_FORMAT;
} }
} else { } else {
...@@ -527,12 +508,12 @@ int SSLServerSocketImpl::DoHandshake() { ...@@ -527,12 +508,12 @@ int SSLServerSocketImpl::DoHandshake() {
return net_error; return net_error;
} }
void SSLServerSocketImpl::DoHandshakeCallback(int rv) { void SSLServerContextImpl::SocketImpl::DoHandshakeCallback(int rv) {
DCHECK_NE(rv, ERR_IO_PENDING); DCHECK_NE(rv, ERR_IO_PENDING);
base::ResetAndReturn(&user_handshake_callback_).Run(rv > OK ? OK : rv); base::ResetAndReturn(&user_handshake_callback_).Run(rv > OK ? OK : rv);
} }
void SSLServerSocketImpl::DoReadCallback(int rv) { void SSLServerContextImpl::SocketImpl::DoReadCallback(int rv) {
DCHECK(rv != ERR_IO_PENDING); DCHECK(rv != ERR_IO_PENDING);
DCHECK(!user_read_callback_.is_null()); DCHECK(!user_read_callback_.is_null());
...@@ -541,7 +522,7 @@ void SSLServerSocketImpl::DoReadCallback(int rv) { ...@@ -541,7 +522,7 @@ void SSLServerSocketImpl::DoReadCallback(int rv) {
base::ResetAndReturn(&user_read_callback_).Run(rv); base::ResetAndReturn(&user_read_callback_).Run(rv);
} }
void SSLServerSocketImpl::DoWriteCallback(int rv) { void SSLServerContextImpl::SocketImpl::DoWriteCallback(int rv) {
DCHECK(rv != ERR_IO_PENDING); DCHECK(rv != ERR_IO_PENDING);
DCHECK(!user_write_callback_.is_null()); DCHECK(!user_write_callback_.is_null());
...@@ -550,7 +531,7 @@ void SSLServerSocketImpl::DoWriteCallback(int rv) { ...@@ -550,7 +531,7 @@ void SSLServerSocketImpl::DoWriteCallback(int rv) {
base::ResetAndReturn(&user_write_callback_).Run(rv); base::ResetAndReturn(&user_write_callback_).Run(rv);
} }
int SSLServerSocketImpl::Init() { int SSLServerContextImpl::SocketImpl::Init() {
static const int kBufferSize = 17 * 1024; static const int kBufferSize = 17 * 1024;
crypto::OpenSSLErrStackTracer err_tracer(FROM_HERE); crypto::OpenSSLErrStackTracer err_tracer(FROM_HERE);
...@@ -558,6 +539,14 @@ int SSLServerSocketImpl::Init() { ...@@ -558,6 +539,14 @@ int SSLServerSocketImpl::Init() {
if (!ssl_) if (!ssl_)
return ERR_UNEXPECTED; return ERR_UNEXPECTED;
// Set certificate and private key.
DCHECK(context_->cert_->os_cert_handle());
DCHECK(context_->key_->key());
if (!SetSSLChainAndKey(ssl_.get(), context_->cert_.get(),
context_->key_->key(), nullptr)) {
return ERR_UNEXPECTED;
}
transport_adapter_.reset(new SocketBIOAdapter( transport_adapter_.reset(new SocketBIOAdapter(
transport_socket_.get(), kBufferSize, kBufferSize, this)); transport_socket_.get(), kBufferSize, kBufferSize, this));
BIO* transport_bio = transport_adapter_->bio(); BIO* transport_bio = transport_adapter_->bio();
...@@ -572,37 +561,43 @@ int SSLServerSocketImpl::Init() { ...@@ -572,37 +561,43 @@ int SSLServerSocketImpl::Init() {
} }
// static // static
int SSLServerSocketImpl::CertVerifyCallback(X509_STORE_CTX* store_ctx, ssl_verify_result_t SSLServerContextImpl::SocketImpl::CertVerifyCallback(
void* arg) { SSL* ssl,
ClientCertVerifier* verifier = reinterpret_cast<ClientCertVerifier*>(arg); uint8_t* out_alert) {
SocketImpl* socket = reinterpret_cast<SocketImpl*>(SSL_get_app_data(ssl));
return socket->CertVerifyCallbackImpl(out_alert);
}
ssl_verify_result_t SSLServerContextImpl::SocketImpl::CertVerifyCallbackImpl(
uint8_t* out_alert) {
ClientCertVerifier* verifier =
context_->ssl_server_config_.client_cert_verifier;
// If a verifier was not supplied, all certificates are accepted. // If a verifier was not supplied, all certificates are accepted.
if (!verifier) if (!verifier)
return 1; return ssl_verify_ok;
STACK_OF(X509)* chain = store_ctx->untrusted;
scoped_refptr<X509Certificate> client_cert( scoped_refptr<X509Certificate> client_cert =
CreateX509Certificate(nullptr, chain)); x509_util::CreateX509CertificateFromBuffers(
if (!client_cert.get()) { SSL_get0_peer_certificates(ssl_.get()));
X509_STORE_CTX_set_error(store_ctx, X509_V_ERR_CERT_REJECTED); if (!client_cert) {
return 0; *out_alert = SSL_AD_BAD_CERTIFICATE;
return ssl_verify_invalid;
} }
// Asynchronous completion of Verify is currently not supported.
// http://crbug.com/347402 // TODO(davidben): Support asynchronous verifiers. http://crbug.com/347402
// The API for Verify supports the parts needed for async completion
// but is currently expected to complete synchronously.
std::unique_ptr<ClientCertVerifier::Request> ignore_async; std::unique_ptr<ClientCertVerifier::Request> ignore_async;
int res = int res =
verifier->Verify(client_cert.get(), CompletionCallback(), &ignore_async); verifier->Verify(client_cert.get(), CompletionCallback(), &ignore_async);
DCHECK_NE(res, ERR_IO_PENDING); DCHECK_NE(res, ERR_IO_PENDING);
if (res != OK) { if (res != OK) {
X509_STORE_CTX_set_error(store_ctx, X509_V_ERR_CERT_REJECTED); // TODO(davidben): Map from certificate verification failure to alert.
return 0; *out_alert = SSL_AD_CERTIFICATE_UNKNOWN;
return ssl_verify_invalid;
} }
return 1; return ssl_verify_ok;
} }
} // namespace
std::unique_ptr<SSLServerContext> CreateSSLServerContext( std::unique_ptr<SSLServerContext> CreateSSLServerContext(
X509Certificate* certificate, X509Certificate* certificate,
const crypto::RSAPrivateKey& key, const crypto::RSAPrivateKey& key,
...@@ -620,7 +615,7 @@ SSLServerContextImpl::SSLServerContextImpl( ...@@ -620,7 +615,7 @@ SSLServerContextImpl::SSLServerContextImpl(
key_(key.Copy()) { key_(key.Copy()) {
CHECK(key_); CHECK(key_);
crypto::EnsureOpenSSLInit(); crypto::EnsureOpenSSLInit();
ssl_ctx_.reset(SSL_CTX_new(TLS_method())); ssl_ctx_.reset(SSL_CTX_new(TLS_with_buffers_method()));
SSL_CTX_set_session_cache_mode(ssl_ctx_.get(), SSL_SESS_CACHE_SERVER); SSL_CTX_set_session_cache_mode(ssl_ctx_.get(), SSL_SESS_CACHE_SERVER);
uint8_t session_ctx_id = 0; uint8_t session_ctx_id = 0;
SSL_CTX_set_session_id_context(ssl_ctx_.get(), &session_ctx_id, SSL_CTX_set_session_id_context(ssl_ctx_.get(), &session_ctx_id,
...@@ -635,38 +630,13 @@ SSLServerContextImpl::SSLServerContextImpl( ...@@ -635,38 +630,13 @@ SSLServerContextImpl::SSLServerContextImpl(
// Fall-through // Fall-through
case SSLServerConfig::ClientCertType::OPTIONAL_CLIENT_CERT: case SSLServerConfig::ClientCertType::OPTIONAL_CLIENT_CERT:
verify_mode |= SSL_VERIFY_PEER; verify_mode |= SSL_VERIFY_PEER;
SSL_CTX_set_verify(ssl_ctx_.get(), verify_mode, nullptr); SSL_CTX_set_custom_verify(ssl_ctx_.get(), verify_mode,
SSL_CTX_set_cert_verify_callback(ssl_ctx_.get(), SocketImpl::CertVerifyCallback);
SSLServerSocketImpl::CertVerifyCallback,
ssl_server_config_.client_cert_verifier);
break; break;
case SSLServerConfig::ClientCertType::NO_CLIENT_CERT: case SSLServerConfig::ClientCertType::NO_CLIENT_CERT:
break; break;
} }
// Set certificate and private key.
DCHECK(cert_->os_cert_handle());
DCHECK(key_->key());
#if BUILDFLAG(USE_BYTE_CERTS)
// On success, SSL_CTX_set_chain_and_key acquires a reference to
// |cert_->os_cert_handle()| and |key_->key()|.
CRYPTO_BUFFER* cert_buffers[] = {cert_->os_cert_handle()};
CHECK(SSL_CTX_set_chain_and_key(ssl_ctx_.get(), cert_buffers,
arraysize(cert_buffers), key_->key(),
nullptr /* privkey_method */));
#elif defined(USE_OPENSSL_CERTS)
CHECK(SSL_CTX_use_certificate(ssl_ctx_.get(), cert_->os_cert_handle()));
CHECK(SSL_CTX_use_PrivateKey(ssl_ctx_.get(), key_->key()));
#else
std::string der_string;
CHECK(X509Certificate::GetDEREncoded(cert_->os_cert_handle(), &der_string));
CHECK(SSL_CTX_use_certificate_ASN1(
ssl_ctx_.get(), der_string.length(),
reinterpret_cast<const unsigned char*>(der_string.data())));
// On success, SSL_CTX_use_PrivateKey acquires a reference to |key_->key()|.
CHECK(SSL_CTX_use_PrivateKey(ssl_ctx_.get(), key_->key()));
#endif // USE_OPENSSL_CERTS && !USE_BYTE_CERTS
DCHECK_LT(SSL3_VERSION, ssl_server_config_.version_min); DCHECK_LT(SSL3_VERSION, ssl_server_config_.version_min);
DCHECK_LT(SSL3_VERSION, ssl_server_config_.version_max); DCHECK_LT(SSL3_VERSION, ssl_server_config_.version_max);
CHECK(SSL_CTX_set_min_proto_version(ssl_ctx_.get(), CHECK(SSL_CTX_set_min_proto_version(ssl_ctx_.get(),
...@@ -713,16 +683,12 @@ SSLServerContextImpl::SSLServerContextImpl( ...@@ -713,16 +683,12 @@ SSLServerContextImpl::SSLServerContextImpl(
if (ssl_server_config_.client_cert_type != if (ssl_server_config_.client_cert_type !=
SSLServerConfig::ClientCertType::NO_CLIENT_CERT && SSLServerConfig::ClientCertType::NO_CLIENT_CERT &&
!ssl_server_config_.cert_authorities_.empty()) { !ssl_server_config_.cert_authorities_.empty()) {
bssl::UniquePtr<STACK_OF(X509_NAME)> stack(sk_X509_NAME_new_null()); bssl::UniquePtr<STACK_OF(CRYPTO_BUFFER)> stack(sk_CRYPTO_BUFFER_new_null());
for (const auto& authority : ssl_server_config_.cert_authorities_) { for (const auto& authority : ssl_server_config_.cert_authorities_) {
const uint8_t* name = reinterpret_cast<const uint8_t*>(authority.c_str()); sk_CRYPTO_BUFFER_push(stack.get(),
const uint8_t* name_start = name; x509_util::CreateCryptoBuffer(authority).release());
bssl::UniquePtr<X509_NAME> subj(
d2i_X509_NAME(nullptr, &name, authority.length()));
CHECK(subj && name == name_start + authority.length());
sk_X509_NAME_push(stack.get(), subj.release());
} }
SSL_CTX_set_client_CA_list(ssl_ctx_.get(), stack.release()); SSL_CTX_set0_client_CAs(ssl_ctx_.get(), stack.release());
} }
} }
...@@ -730,9 +696,8 @@ SSLServerContextImpl::~SSLServerContextImpl() {} ...@@ -730,9 +696,8 @@ SSLServerContextImpl::~SSLServerContextImpl() {}
std::unique_ptr<SSLServerSocket> SSLServerContextImpl::CreateSSLServerSocket( std::unique_ptr<SSLServerSocket> SSLServerContextImpl::CreateSSLServerSocket(
std::unique_ptr<StreamSocket> socket) { std::unique_ptr<StreamSocket> socket) {
bssl::UniquePtr<SSL> ssl(SSL_new(ssl_ctx_.get()));
return std::unique_ptr<SSLServerSocket>( return std::unique_ptr<SSLServerSocket>(
new SSLServerSocketImpl(std::move(socket), std::move(ssl))); new SocketImpl(this, std::move(socket)));
} }
} // namespace net } // namespace net
...@@ -29,6 +29,8 @@ class SSLServerContextImpl : public SSLServerContext { ...@@ -29,6 +29,8 @@ class SSLServerContextImpl : public SSLServerContext {
std::unique_ptr<StreamSocket> socket) override; std::unique_ptr<StreamSocket> socket) override;
private: private:
class SocketImpl;
bssl::UniquePtr<SSL_CTX> ssl_ctx_; bssl::UniquePtr<SSL_CTX> ssl_ctx_;
// Options for the SSL socket. // Options for the SSL socket.
......
...@@ -12,8 +12,10 @@ ...@@ -12,8 +12,10 @@
#include "base/location.h" #include "base/location.h"
#include "base/logging.h" #include "base/logging.h"
#include "base/values.h" #include "base/values.h"
#include "build/build_config.h"
#include "crypto/openssl_util.h" #include "crypto/openssl_util.h"
#include "net/base/net_errors.h" #include "net/base/net_errors.h"
#include "net/cert/x509_util.h"
#include "net/ssl/ssl_connection_status_flags.h" #include "net/ssl/ssl_connection_status_flags.h"
#include "third_party/boringssl/src/include/openssl/err.h" #include "third_party/boringssl/src/include/openssl/err.h"
#include "third_party/boringssl/src/include/openssl/ssl.h" #include "third_party/boringssl/src/include/openssl/ssl.h"
...@@ -138,6 +140,16 @@ std::unique_ptr<base::Value> NetLogOpenSSLErrorCallback( ...@@ -138,6 +140,16 @@ std::unique_ptr<base::Value> NetLogOpenSSLErrorCallback(
return std::move(dict); return std::move(dict);
} }
#if !BUILDFLAG(USE_BYTE_CERTS)
bssl::UniquePtr<CRYPTO_BUFFER> OSCertHandleToBuffer(
X509Certificate::OSCertHandle os_handle) {
std::string der_encoded;
if (!X509Certificate::GetDEREncoded(os_handle, &der_encoded))
return nullptr;
return x509_util::CreateCryptoBuffer(der_encoded);
}
#endif
} // namespace } // namespace
void OpenSSLPutNetError(const tracked_objects::Location& location, int err) { void OpenSSLPutNetError(const tracked_objects::Location& location, int err) {
...@@ -224,4 +236,48 @@ int GetNetSSLVersion(SSL* ssl) { ...@@ -224,4 +236,48 @@ int GetNetSSLVersion(SSL* ssl) {
} }
} }
bool SetSSLChainAndKey(SSL* ssl,
X509Certificate* cert,
EVP_PKEY* pkey,
const SSL_PRIVATE_KEY_METHOD* custom_key) {
#if BUILDFLAG(USE_BYTE_CERTS)
std::vector<CRYPTO_BUFFER*> chain_raw;
chain_raw.push_back(cert->os_cert_handle());
for (X509Certificate::OSCertHandle handle :
cert->GetIntermediateCertificates()) {
chain_raw.push_back(handle);
}
#else
std::vector<bssl::UniquePtr<CRYPTO_BUFFER>> chain;
std::vector<CRYPTO_BUFFER*> chain_raw;
bssl::UniquePtr<CRYPTO_BUFFER> buf =
OSCertHandleToBuffer(cert->os_cert_handle());
if (!buf) {
LOG(WARNING) << "Failed to import certificate";
return false;
}
chain_raw.push_back(buf.get());
chain.push_back(std::move(buf));
for (X509Certificate::OSCertHandle handle :
cert->GetIntermediateCertificates()) {
bssl::UniquePtr<CRYPTO_BUFFER> buf = OSCertHandleToBuffer(handle);
if (!buf) {
LOG(WARNING) << "Failed to import intermediate";
return false;
}
chain_raw.push_back(buf.get());
chain.push_back(std::move(buf));
}
#endif
if (!SSL_set_chain_and_key(ssl, chain_raw.data(), chain_raw.size(), pkey,
custom_key)) {
LOG(WARNING) << "Failed to set client certificate";
return false;
}
return true;
}
} // namespace net } // namespace net
...@@ -78,6 +78,13 @@ NetLogParametersCallback CreateNetLogOpenSSLErrorCallback( ...@@ -78,6 +78,13 @@ NetLogParametersCallback CreateNetLogOpenSSLErrorCallback(
// this SSL connection. // this SSL connection.
int GetNetSSLVersion(SSL* ssl); int GetNetSSLVersion(SSL* ssl);
// Configures |ssl| to send the specified certificate and either |pkey| or
// |custom_key|. This is a wrapper over |SSL_set_chain_and_key|.
bool SetSSLChainAndKey(SSL* ssl,
X509Certificate* cert,
EVP_PKEY* pkey,
const SSL_PRIVATE_KEY_METHOD* custom_key);
} // namespace net } // namespace net
#endif // NET_SSL_OPENSSL_SSL_UTIL_H_ #endif // NET_SSL_OPENSSL_SSL_UTIL_H_
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment