Commit 665861e0 authored by Ryan Ki Sing Chung's avatar Ryan Ki Sing Chung Committed by Commit Bot

Allow using SSLPrivateKey as the interface for SSLServerSocket's key.

ECDHE ciphers would be required since decrypt is not available.

Bug: 794762
Change-Id: Ie855f1026460d1f2ffa62dcdec6870c5525cdd86
Reviewed-on: https://chromium-review.googlesource.com/826324
Commit-Queue: Ryan Chung <ryanchung@chromium.org>
Reviewed-by: default avatarDavid Benjamin <davidben@chromium.org>
Cr-Commit-Position: refs/heads/master@{#524480}
parent fc3004ac
......@@ -30,6 +30,7 @@ class RSAPrivateKey;
namespace net {
struct SSLServerConfig;
class SSLPrivateKey;
class X509Certificate;
// A server socket that uses SSL as the transport layer.
......@@ -70,6 +71,11 @@ NET_EXPORT std::unique_ptr<SSLServerContext> CreateSSLServerContext(
const crypto::RSAPrivateKey& key,
const SSLServerConfig& ssl_config);
NET_EXPORT std::unique_ptr<SSLServerContext> CreateSSLServerContext(
X509Certificate* certificate,
scoped_refptr<SSLPrivateKey> key,
const SSLServerConfig& ssl_config);
} // namespace net
#endif // NET_SOCKET_SSL_SERVER_SOCKET_H_
......@@ -7,7 +7,9 @@
#include <utility>
#include "base/callback_helpers.h"
#include "base/lazy_instance.h"
#include "base/logging.h"
#include "base/memory/weak_ptr.h"
#include "base/strings/string_util.h"
#include "crypto/openssl_util.h"
#include "crypto/rsa_private_key.h"
......@@ -30,6 +32,34 @@
namespace net {
namespace {
// This constant can be any non-negative/non-zero value (eg: it does not
// overlap with any value of the net::Error range, including net::OK).
const int kNoPendingResult = 1;
class SocketDataIndex {
public:
static SocketDataIndex* GetInstance();
SocketDataIndex() {
ssl_socket_data_index_ = SSL_get_ex_new_index(0, 0, 0, 0, 0);
}
// This is the index used with SSL_get_ex_data to retrieve the owner
// SSLServerSocketImpl object from an SSL instance.
int ssl_socket_data_index_;
};
base::LazyInstance<SocketDataIndex>::Leaky g_ssl_socket_data_index_ =
LAZY_INSTANCE_INITIALIZER;
// static
SocketDataIndex* SocketDataIndex::GetInstance() {
return g_ssl_socket_data_index_.Pointer();
}
} // namespace
class SSLServerContextImpl::SocketImpl : public SSLServerSocket,
public SocketBIOAdapter::Delegate {
public:
......@@ -81,6 +111,36 @@ class SSLServerContextImpl::SocketImpl : public SSLServerSocket,
static ssl_verify_result_t CertVerifyCallback(SSL* ssl, uint8_t* out_alert);
ssl_verify_result_t CertVerifyCallbackImpl(uint8_t* out_alert);
static const SSL_PRIVATE_KEY_METHOD kPrivateKeyMethod;
static ssl_private_key_result_t PrivateKeySignCallback(SSL* ssl,
uint8_t* out,
size_t* out_len,
size_t max_out,
uint16_t algorithm,
const uint8_t* in,
size_t in_len);
static ssl_private_key_result_t PrivateKeyDecryptCallback(SSL* ssl,
uint8_t* out,
size_t* out_len,
size_t max_out,
const uint8_t* in,
size_t in_len);
static ssl_private_key_result_t PrivateKeyCompleteCallback(SSL* ssl,
uint8_t* out,
size_t* out_len,
size_t max_out);
ssl_private_key_result_t PrivateKeySignCallback(uint8_t* out,
size_t* out_len,
size_t max_out,
uint16_t algorithm,
const uint8_t* in,
size_t in_len);
ssl_private_key_result_t PrivateKeyCompleteCallback(uint8_t* out,
size_t* out_len,
size_t max_out);
void OnPrivateKeyComplete(Error error, const std::vector<uint8_t>& signature);
// SocketBIOAdapter::Delegate implementation.
void OnReadReady() override;
void OnWriteReady() override;
......@@ -113,6 +173,10 @@ class SSLServerContextImpl::SocketImpl : public SSLServerSocket,
CompletionCallback user_read_callback_;
CompletionCallback user_write_callback_;
// SSLPrivateKey signature.
int signature_result_;
std::vector<uint8_t> signature_;
// Used by Read function.
scoped_refptr<IOBuffer> user_read_buf_;
int user_read_buf_len_;
......@@ -134,6 +198,8 @@ class SSLServerContextImpl::SocketImpl : public SSLServerSocket,
State next_handshake_state_;
bool completed_handshake_;
base::WeakPtrFactory<SocketImpl> weak_factory_;
DISALLOW_COPY_AND_ASSIGN(SocketImpl);
};
......@@ -141,11 +207,13 @@ SSLServerContextImpl::SocketImpl::SocketImpl(
SSLServerContextImpl* context,
std::unique_ptr<StreamSocket> transport_socket)
: context_(context),
signature_result_(kNoPendingResult),
user_read_buf_len_(0),
user_write_buf_len_(0),
transport_socket_(std::move(transport_socket)),
next_handshake_state_(STATE_NONE),
completed_handshake_(false) {
completed_handshake_(false),
weak_factory_(this) {
ssl_.reset(SSL_new(context_->ssl_ctx_.get()));
SSL_set_app_data(ssl_.get(), this);
}
......@@ -159,6 +227,108 @@ SSLServerContextImpl::SocketImpl::~SocketImpl() {
}
}
// static
const SSL_PRIVATE_KEY_METHOD
SSLServerContextImpl::SocketImpl::kPrivateKeyMethod = {
&SSLServerContextImpl::SocketImpl::PrivateKeySignCallback,
&SSLServerContextImpl::SocketImpl::PrivateKeyDecryptCallback,
&SSLServerContextImpl::SocketImpl::PrivateKeyCompleteCallback,
};
// static
ssl_private_key_result_t
SSLServerContextImpl::SocketImpl::PrivateKeySignCallback(SSL* ssl,
uint8_t* out,
size_t* out_len,
size_t max_out,
uint16_t algorithm,
const uint8_t* in,
size_t in_len) {
DCHECK(ssl);
SSLServerContextImpl::SocketImpl* socket =
static_cast<SSLServerContextImpl::SocketImpl*>(SSL_get_ex_data(
ssl, SocketDataIndex::GetInstance()->ssl_socket_data_index_));
DCHECK(socket);
return socket->PrivateKeySignCallback(out, out_len, max_out, algorithm, in,
in_len);
}
// static
ssl_private_key_result_t
SSLServerContextImpl::SocketImpl::PrivateKeyDecryptCallback(SSL* ssl,
uint8_t* out,
size_t* out_len,
size_t max_out,
const uint8_t* in,
size_t in_len) {
// Decrypt is not supported.
return ssl_private_key_failure;
}
// static
ssl_private_key_result_t
SSLServerContextImpl::SocketImpl::PrivateKeyCompleteCallback(SSL* ssl,
uint8_t* out,
size_t* out_len,
size_t max_out) {
DCHECK(ssl);
SSLServerContextImpl::SocketImpl* socket =
static_cast<SSLServerContextImpl::SocketImpl*>(SSL_get_ex_data(
ssl, SocketDataIndex::GetInstance()->ssl_socket_data_index_));
DCHECK(socket);
return socket->PrivateKeyCompleteCallback(out, out_len, max_out);
}
ssl_private_key_result_t
SSLServerContextImpl::SocketImpl::PrivateKeySignCallback(uint8_t* out,
size_t* out_len,
size_t max_out,
uint16_t algorithm,
const uint8_t* in,
size_t in_len) {
DCHECK(context_);
DCHECK(context_->private_key_);
signature_result_ = ERR_IO_PENDING;
context_->private_key_->Sign(
algorithm, base::make_span(in, in_len),
base::BindRepeating(
&SSLServerContextImpl::SocketImpl::OnPrivateKeyComplete,
weak_factory_.GetWeakPtr()));
return ssl_private_key_retry;
}
ssl_private_key_result_t
SSLServerContextImpl::SocketImpl::PrivateKeyCompleteCallback(uint8_t* out,
size_t* out_len,
size_t max_out) {
if (signature_result_ == ERR_IO_PENDING)
return ssl_private_key_retry;
if (signature_result_ != OK) {
OpenSSLPutNetError(FROM_HERE, signature_result_);
return ssl_private_key_failure;
}
if (signature_.size() > max_out) {
OpenSSLPutNetError(FROM_HERE, ERR_SSL_CLIENT_AUTH_SIGNATURE_FAILED);
return ssl_private_key_failure;
}
memcpy(out, signature_.data(), signature_.size());
*out_len = signature_.size();
signature_.clear();
return ssl_private_key_success;
}
void SSLServerContextImpl::SocketImpl::OnPrivateKeyComplete(
Error error,
const std::vector<uint8_t>& signature) {
DCHECK_EQ(ERR_IO_PENDING, signature_result_);
DCHECK(signature_.empty());
signature_result_ = error;
if (signature_result_ == OK)
signature_ = signature;
DoHandshakeLoop(ERR_IO_PENDING);
}
int SSLServerContextImpl::SocketImpl::Handshake(
const CompletionCallback& callback) {
net_log_.BeginEvent(NetLogEventType::SSL_SERVER_HANDSHAKE);
......@@ -480,7 +650,6 @@ int SSLServerContextImpl::SocketImpl::DoHandshake() {
crypto::OpenSSLErrStackTracer err_tracer(FROM_HERE);
int net_error = OK;
int rv = SSL_do_handshake(ssl_.get());
if (rv == 1) {
completed_handshake_ = true;
STACK_OF(CRYPTO_BUFFER)* certs = SSL_get0_peer_certificates(ssl_.get());
......@@ -491,6 +660,13 @@ int SSLServerContextImpl::SocketImpl::DoHandshake() {
}
} else {
int ssl_error = SSL_get_error(ssl_.get(), rv);
if (ssl_error == SSL_ERROR_WANT_PRIVATE_KEY_OPERATION) {
DCHECK(context_->private_key_);
GotoState(STATE_HANDSHAKE);
return ERR_IO_PENDING;
}
OpenSSLErrorInfo error_info;
net_error = MapOpenSSLErrorWithDetails(ssl_error, err_tracer, &error_info);
......@@ -544,15 +720,30 @@ int SSLServerContextImpl::SocketImpl::Init() {
crypto::OpenSSLErrStackTracer err_tracer(FROM_HERE);
if (!ssl_)
if (!ssl_ ||
!SSL_set_ex_data(ssl_.get(),
SocketDataIndex::GetInstance()->ssl_socket_data_index_,
this))
return ERR_UNEXPECTED;
// Set certificate and private key.
DCHECK(context_->cert_->cert_buffer());
DCHECK(context_->key_->key());
if (!SetSSLChainAndKey(ssl_.get(), context_->cert_.get(),
context_->key_->key(), nullptr)) {
return ERR_UNEXPECTED;
if (context_->key_) {
DCHECK(context_->cert_->cert_buffer());
DCHECK(context_->key_->key());
if (!SetSSLChainAndKey(ssl_.get(), context_->cert_.get(),
context_->key_->key(), nullptr)) {
return ERR_UNEXPECTED;
}
} else {
DCHECK(context_->private_key_);
if (!SetSSLChainAndKey(ssl_.get(), context_->cert_.get(), nullptr,
&kPrivateKeyMethod)) {
return ERR_UNEXPECTED;
}
std::vector<uint16_t> preferences =
context_->private_key_->GetAlgorithmPreferences();
SSL_set_signing_algorithm_prefs(ssl_.get(), preferences.data(),
preferences.size());
}
transport_adapter_.reset(new SocketBIOAdapter(
......@@ -610,8 +801,26 @@ std::unique_ptr<SSLServerContext> CreateSSLServerContext(
X509Certificate* certificate,
const crypto::RSAPrivateKey& key,
const SSLServerConfig& ssl_server_config) {
return std::unique_ptr<SSLServerContext>(
new SSLServerContextImpl(certificate, key, ssl_server_config));
return std::make_unique<SSLServerContextImpl>(certificate, key,
ssl_server_config);
}
std::unique_ptr<SSLServerContext> CreateSSLServerContext(
X509Certificate* certificate,
scoped_refptr<SSLPrivateKey> key,
const SSLServerConfig& ssl_config) {
return std::make_unique<SSLServerContextImpl>(certificate, key, ssl_config);
}
SSLServerContextImpl::SSLServerContextImpl(
X509Certificate* certificate,
scoped_refptr<net::SSLPrivateKey> key,
const SSLServerConfig& ssl_server_config)
: ssl_server_config_(ssl_server_config),
cert_(certificate),
private_key_(key) {
CHECK(private_key_);
Init();
}
SSLServerContextImpl::SSLServerContextImpl(
......@@ -622,6 +831,10 @@ SSLServerContextImpl::SSLServerContextImpl(
cert_(certificate),
key_(key.Copy()) {
CHECK(key_);
Init();
}
void SSLServerContextImpl::Init() {
crypto::EnsureOpenSSLInit();
ssl_ctx_.reset(SSL_CTX_new(TLS_with_buffers_method()));
SSL_CTX_set_session_cache_mode(ssl_ctx_.get(), SSL_SESS_CACHE_SERVER);
......@@ -674,7 +887,8 @@ SSLServerContextImpl::SSLServerContextImpl(
// as the handshake hash.
std::string command("DEFAULT:!SHA256:!SHA384:!AESGCM+AES256:!aPSK");
if (ssl_server_config_.require_ecdhe)
// SSLPrivateKey only supports ECDHE-based ciphers because it lacks decrypt.
if (ssl_server_config_.require_ecdhe || (!key_ && private_key_))
command.append(":!kRSA");
// Remove any disabled ciphers.
......@@ -704,8 +918,7 @@ SSLServerContextImpl::~SSLServerContextImpl() = default;
std::unique_ptr<SSLServerSocket> SSLServerContextImpl::CreateSSLServerSocket(
std::unique_ptr<StreamSocket> socket) {
return std::unique_ptr<SSLServerSocket>(
new SocketImpl(this, std::move(socket)));
return std::make_unique<SocketImpl>(this, std::move(socket));
}
} // namespace net
......@@ -23,6 +23,9 @@ class SSLServerContextImpl : public SSLServerContext {
SSLServerContextImpl(X509Certificate* certificate,
const crypto::RSAPrivateKey& key,
const SSLServerConfig& ssl_server_config);
SSLServerContextImpl(X509Certificate* certificate,
scoped_refptr<SSLPrivateKey> key,
const SSLServerConfig& ssl_server_config);
~SSLServerContextImpl() override;
std::unique_ptr<SSLServerSocket> CreateSSLServerSocket(
......@@ -31,6 +34,8 @@ class SSLServerContextImpl : public SSLServerContext {
private:
class SocketImpl;
void Init();
bssl::UniquePtr<SSL_CTX> ssl_ctx_;
// Options for the SSL socket.
......@@ -40,7 +45,9 @@ class SSLServerContextImpl : public SSLServerContext {
scoped_refptr<X509Certificate> cert_;
// Private key used by the server.
// Only one representation should be set at any time.
std::unique_ptr<crypto::RSAPrivateKey> key_;
const scoped_refptr<SSLPrivateKey> private_key_;
};
} // namespace net
......
......@@ -373,6 +373,13 @@ class SSLServerSocketTest : public PlatformTest {
server_private_key_ = ReadTestKey("unittest.key.bin");
ASSERT_TRUE(server_private_key_);
std::unique_ptr<crypto::RSAPrivateKey> key =
ReadTestKey("unittest.key.bin");
ASSERT_TRUE(key);
EVP_PKEY_up_ref(key->key());
server_ssl_private_key_ =
WrapOpenSSLPrivateKey(bssl::UniquePtr<EVP_PKEY>(key->key()));
client_ssl_config_.false_start_enabled = false;
client_ssl_config_.channel_id_enabled = false;
......@@ -392,6 +399,16 @@ class SSLServerSocketTest : public PlatformTest {
server_cert_.get(), *server_private_key_, server_ssl_config_);
}
void CreateContextSSLPrivateKey() {
client_socket_.reset();
server_socket_.reset();
channel_1_.reset();
channel_2_.reset();
server_context_.reset();
server_context_ = CreateSSLServerContext(
server_cert_.get(), server_ssl_private_key_, server_ssl_config_);
}
void CreateSockets() {
client_socket_.reset();
server_socket_.reset();
......@@ -489,6 +506,7 @@ class SSLServerSocketTest : public PlatformTest {
std::unique_ptr<MockCTPolicyEnforcer> ct_policy_enforcer_;
std::unique_ptr<SSLServerContext> server_context_;
std::unique_ptr<crypto::RSAPrivateKey> server_private_key_;
scoped_refptr<SSLPrivateKey> server_ssl_private_key_;
scoped_refptr<X509Certificate> server_cert_;
};
......@@ -1098,4 +1116,80 @@ TEST_F(SSLServerSocketTest, RequireEcdheFlag) {
ASSERT_THAT(server_ret, IsError(ERR_SSL_VERSION_OR_CIPHER_MISMATCH));
}
// This test executes Connect() on SSLClientSocket and Handshake() on
// SSLServerSocket to make sure handshaking between the two sockets is
// completed successfully. The server key is represented by SSLPrivateKey.
TEST_F(SSLServerSocketTest, HandshakeServerSSLPrivateKey) {
ASSERT_NO_FATAL_FAILURE(CreateContextSSLPrivateKey());
ASSERT_NO_FATAL_FAILURE(CreateSockets());
TestCompletionCallback handshake_callback;
int server_ret = server_socket_->Handshake(handshake_callback.callback());
TestCompletionCallback connect_callback;
int client_ret = client_socket_->Connect(connect_callback.callback());
client_ret = connect_callback.GetResult(client_ret);
server_ret = handshake_callback.GetResult(server_ret);
ASSERT_THAT(client_ret, IsOk());
ASSERT_THAT(server_ret, IsOk());
// Make sure the cert status is expected.
SSLInfo ssl_info;
ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info));
EXPECT_EQ(CERT_STATUS_AUTHORITY_INVALID, ssl_info.cert_status);
// The default cipher suite should be ECDHE and an AEAD.
uint16_t cipher_suite =
SSLConnectionStatusToCipherSuite(ssl_info.connection_status);
const char* key_exchange;
const char* cipher;
const char* mac;
bool is_aead;
bool is_tls13;
SSLCipherSuiteToStrings(&key_exchange, &cipher, &mac, &is_aead, &is_tls13,
cipher_suite);
EXPECT_TRUE(is_aead);
ASSERT_FALSE(is_tls13);
EXPECT_STREQ("ECDHE_RSA", key_exchange);
}
// Verifies that non-ECDHE ciphers are disabled when using SSLPrivateKey as the
// server key.
TEST_F(SSLServerSocketTest, HandshakeServerSSLPrivateKeyRequireEcdhe) {
// Disable all ECDHE suites on the client side.
uint16_t kEcdheCiphers[] = {
0xc007, // ECDHE_ECDSA_WITH_RC4_128_SHA
0xc009, // ECDHE_ECDSA_WITH_AES_128_CBC_SHA
0xc00a, // ECDHE_ECDSA_WITH_AES_256_CBC_SHA
0xc011, // ECDHE_RSA_WITH_RC4_128_SHA
0xc013, // ECDHE_RSA_WITH_AES_128_CBC_SHA
0xc014, // ECDHE_RSA_WITH_AES_256_CBC_SHA
0xc02b, // ECDHE_ECDSA_WITH_AES_128_GCM_SHA256
0xc02f, // ECDHE_RSA_WITH_AES_128_GCM_SHA256
0xcca8, // ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256
0xcca9, // ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256
};
client_ssl_config_.disabled_cipher_suites.assign(
kEcdheCiphers, kEcdheCiphers + arraysize(kEcdheCiphers));
// TLS 1.3 always works with SSLPrivateKey.
client_ssl_config_.version_max = SSL_PROTOCOL_VERSION_TLS1_2;
ASSERT_NO_FATAL_FAILURE(CreateContextSSLPrivateKey());
ASSERT_NO_FATAL_FAILURE(CreateSockets());
TestCompletionCallback connect_callback;
int client_ret = client_socket_->Connect(connect_callback.callback());
TestCompletionCallback handshake_callback;
int server_ret = server_socket_->Handshake(handshake_callback.callback());
client_ret = connect_callback.GetResult(client_ret);
server_ret = handshake_callback.GetResult(server_ret);
ASSERT_THAT(client_ret, IsError(ERR_SSL_VERSION_OR_CIPHER_MISMATCH));
ASSERT_THAT(server_ret, IsError(ERR_SSL_VERSION_OR_CIPHER_MISMATCH));
}
} // namespace net
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment