Commit cffd7f90 authored by wtc@chromium.org's avatar wtc@chromium.org

Change the lifespan of SSlConnectJobMessengers so that they are created

only when needed, and deleted as soon as they are no longer necessary.

Add methods to SSLClientSocketPool that are passed to the SSLConnectJob
and SSLConnectJobMessenger as callbacks. These allow the SSLConnectJob
to tell the SSLClientSocketPool to create a messenger for the job when
appropriate, and the SSLConnectJobMessenger to tell the
SSLCLientSocketPool to remove a messenger when appropriate. An
SSLConnectJob will now only create an SSLConnectJobMessenger if its
socket's session is not already in the session cache. The messenger
will then ask to be removed when there are no remaining pending or
connecting sockets in the messenger.

Written by Mackenzie Shelley <mshelley@chromium.org>
Original review URL: https://codereview.chromium.org/384873002/

R=rsleevi@chromium.org
TBR=mek@chromium.org
BUG=398967

Review URL: https://codereview.chromium.org/476313004

Cr-Commit-Position: refs/heads/master@{#291192}
git-svn-id: svn://svn.chromium.org/chrome/trunk/src@291192 0039d316-1c4b-4281-b951-d872f2087c98
parent be21aaa4
...@@ -60,6 +60,7 @@ class MockSSLClientSocket : public net::SSLClientSocket { ...@@ -60,6 +60,7 @@ class MockSSLClientSocket : public net::SSLClientSocket {
unsigned char*, unsigned char*,
unsigned int)); unsigned int));
MOCK_METHOD1(GetTLSUniqueChannelBinding, int(std::string*)); MOCK_METHOD1(GetTLSUniqueChannelBinding, int(std::string*));
MOCK_CONST_METHOD0(GetSessionCacheKey, std::string());
MOCK_CONST_METHOD0(InSessionCache, bool()); MOCK_CONST_METHOD0(InSessionCache, bool());
MOCK_METHOD1(SetHandshakeCompletionCallback, void(const base::Closure&)); MOCK_METHOD1(SetHandshakeCompletionCallback, void(const base::Closure&));
MOCK_METHOD1(GetSSLCertRequestInfo, void(net::SSLCertRequestInfo*)); MOCK_METHOD1(GetSSLCertRequestInfo, void(net::SSLCertRequestInfo*));
......
...@@ -764,6 +764,11 @@ const BoundNetLog& MockClientSocket::NetLog() const { ...@@ -764,6 +764,11 @@ const BoundNetLog& MockClientSocket::NetLog() const {
return net_log_; return net_log_;
} }
std::string MockClientSocket::GetSessionCacheKey() const {
NOTIMPLEMENTED();
return std::string();
}
bool MockClientSocket::InSessionCache() const { bool MockClientSocket::InSessionCache() const {
NOTIMPLEMENTED(); NOTIMPLEMENTED();
return false; return false;
...@@ -1322,6 +1327,7 @@ MockSSLClientSocket::MockSSLClientSocket( ...@@ -1322,6 +1327,7 @@ MockSSLClientSocket::MockSSLClientSocket(
// tests. // tests.
transport_socket->socket()->NetLog()), transport_socket->socket()->NetLog()),
transport_(transport_socket.Pass()), transport_(transport_socket.Pass()),
host_port_pair_(host_port_pair),
data_(data), data_(data),
is_npn_state_set_(false), is_npn_state_set_(false),
new_npn_value_(false), new_npn_value_(false),
...@@ -1389,6 +1395,12 @@ bool MockSSLClientSocket::GetSSLInfo(SSLInfo* ssl_info) { ...@@ -1389,6 +1395,12 @@ bool MockSSLClientSocket::GetSSLInfo(SSLInfo* ssl_info) {
return true; return true;
} }
std::string MockSSLClientSocket::GetSessionCacheKey() const {
// For the purposes of these tests, |host_and_port| will serve as the
// cache key.
return host_port_pair_.ToString();
}
bool MockSSLClientSocket::InSessionCache() const { bool MockSSLClientSocket::InSessionCache() const {
return data_->is_in_session_cache; return data_->is_in_session_cache;
} }
......
...@@ -703,6 +703,7 @@ class MockClientSocket : public SSLClientSocket { ...@@ -703,6 +703,7 @@ class MockClientSocket : public SSLClientSocket {
virtual void SetOmniboxSpeculation() OVERRIDE {} virtual void SetOmniboxSpeculation() OVERRIDE {}
// SSLClientSocket implementation. // SSLClientSocket implementation.
virtual std::string GetSessionCacheKey() const OVERRIDE;
virtual bool InSessionCache() const OVERRIDE; virtual bool InSessionCache() const OVERRIDE;
virtual void SetHandshakeCompletionCallback(const base::Closure& cb) OVERRIDE; virtual void SetHandshakeCompletionCallback(const base::Closure& cb) OVERRIDE;
virtual void GetSSLCertRequestInfo(SSLCertRequestInfo* cert_request_info) virtual void GetSSLCertRequestInfo(SSLCertRequestInfo* cert_request_info)
...@@ -964,6 +965,7 @@ class MockSSLClientSocket : public MockClientSocket, public AsyncSocket { ...@@ -964,6 +965,7 @@ class MockSSLClientSocket : public MockClientSocket, public AsyncSocket {
virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE; virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE;
// SSLClientSocket implementation. // SSLClientSocket implementation.
virtual std::string GetSessionCacheKey() const OVERRIDE;
virtual bool InSessionCache() const OVERRIDE; virtual bool InSessionCache() const OVERRIDE;
virtual void SetHandshakeCompletionCallback(const base::Closure& cb) OVERRIDE; virtual void SetHandshakeCompletionCallback(const base::Closure& cb) OVERRIDE;
virtual void GetSSLCertRequestInfo(SSLCertRequestInfo* cert_request_info) virtual void GetSSLCertRequestInfo(SSLCertRequestInfo* cert_request_info)
...@@ -1003,6 +1005,7 @@ class MockSSLClientSocket : public MockClientSocket, public AsyncSocket { ...@@ -1003,6 +1005,7 @@ class MockSSLClientSocket : public MockClientSocket, public AsyncSocket {
int DoSSLConnectComplete(int result); int DoSSLConnectComplete(int result);
scoped_ptr<ClientSocketHandle> transport_; scoped_ptr<ClientSocketHandle> transport_;
HostPortPair host_port_pair_;
SSLSocketDataProvider* data_; SSLSocketDataProvider* data_;
bool is_npn_state_set_; bool is_npn_state_set_;
bool new_npn_value_; bool new_npn_value_;
......
...@@ -89,16 +89,6 @@ NextProto SSLClientSocket::GetNegotiatedProtocol() const { ...@@ -89,16 +89,6 @@ NextProto SSLClientSocket::GetNegotiatedProtocol() const {
return protocol_negotiated_; return protocol_negotiated_;
} }
// static
std::string SSLClientSocket::CreateSessionCacheKey(
const HostPortPair& host_and_port,
const std::string& ssl_session_cache_shard) {
std::string result = host_and_port.ToString();
result.append("/");
result.append(ssl_session_cache_shard);
return result;
}
bool SSLClientSocket::IgnoreCertError(int error, int load_flags) { bool SSLClientSocket::IgnoreCertError(int error, int load_flags) {
if (error == OK || load_flags & LOAD_IGNORE_ALL_CERT_ERRORS) if (error == OK || load_flags & LOAD_IGNORE_ALL_CERT_ERRORS)
return true; return true;
......
...@@ -83,15 +83,8 @@ class NET_EXPORT SSLClientSocket : public SSLSocket { ...@@ -83,15 +83,8 @@ class NET_EXPORT SSLClientSocket : public SSLSocket {
virtual bool WasNpnNegotiated() const OVERRIDE; virtual bool WasNpnNegotiated() const OVERRIDE;
virtual NextProto GetNegotiatedProtocol() const OVERRIDE; virtual NextProto GetNegotiatedProtocol() const OVERRIDE;
// Formats a unique key for the SSL session cache. This method // Computes a unique key string for the SSL session cache.
// is necessary so that all classes create cache keys in a consistent virtual std::string GetSessionCacheKey() const = 0;
// manner.
// TODO(mshelley) This method will be deleted in an upcoming CL when
// it will no longer be necessary to generate a cache key outside of
// an SSLClientSocket.
static std::string CreateSessionCacheKey(
const HostPortPair& host_and_port,
const std::string& ssl_session_cache_shard);
// Returns true if there is a cache entry in the SSL session cache // Returns true if there is a cache entry in the SSL session cache
// for the cache key of the SSL socket. // for the cache key of the SSL socket.
......
...@@ -2834,6 +2834,11 @@ bool SSLClientSocketNSS::GetSSLInfo(SSLInfo* ssl_info) { ...@@ -2834,6 +2834,11 @@ bool SSLClientSocketNSS::GetSSLInfo(SSLInfo* ssl_info) {
return true; return true;
} }
std::string SSLClientSocketNSS::GetSessionCacheKey() const {
NOTIMPLEMENTED();
return std::string();
}
bool SSLClientSocketNSS::InSessionCache() const { bool SSLClientSocketNSS::InSessionCache() const {
// For now, always return true so that SSLConnectJobs are never held back. // For now, always return true so that SSLConnectJobs are never held back.
return true; return true;
......
...@@ -68,6 +68,7 @@ class SSLClientSocketNSS : public SSLClientSocket { ...@@ -68,6 +68,7 @@ class SSLClientSocketNSS : public SSLClientSocket {
virtual ~SSLClientSocketNSS(); virtual ~SSLClientSocketNSS();
// SSLClientSocket implementation. // SSLClientSocket implementation.
virtual std::string GetSessionCacheKey() const OVERRIDE;
virtual bool InSessionCache() const OVERRIDE; virtual bool InSessionCache() const OVERRIDE;
virtual void SetHandshakeCompletionCallback( virtual void SetHandshakeCompletionCallback(
const base::Closure& callback) OVERRIDE; const base::Closure& callback) OVERRIDE;
......
...@@ -369,6 +369,13 @@ SSLClientSocketOpenSSL::~SSLClientSocketOpenSSL() { ...@@ -369,6 +369,13 @@ SSLClientSocketOpenSSL::~SSLClientSocketOpenSSL() {
Disconnect(); Disconnect();
} }
std::string SSLClientSocketOpenSSL::GetSessionCacheKey() const {
std::string result = host_and_port_.ToString();
result.append("/");
result.append(ssl_session_cache_shard_);
return result;
}
bool SSLClientSocketOpenSSL::InSessionCache() const { bool SSLClientSocketOpenSSL::InSessionCache() const {
SSLContext* context = SSLContext::GetInstance(); SSLContext* context = SSLContext::GetInstance();
std::string cache_key = GetSessionCacheKey(); std::string cache_key = GetSessionCacheKey();
...@@ -840,10 +847,6 @@ void SSLClientSocketOpenSSL::DoWriteCallback(int rv) { ...@@ -840,10 +847,6 @@ void SSLClientSocketOpenSSL::DoWriteCallback(int rv) {
base::ResetAndReturn(&user_write_callback_).Run(rv); base::ResetAndReturn(&user_write_callback_).Run(rv);
} }
std::string SSLClientSocketOpenSSL::GetSessionCacheKey() const {
return CreateSessionCacheKey(host_and_port_, ssl_session_cache_shard_);
}
void SSLClientSocketOpenSSL::OnHandshakeCompletion() { void SSLClientSocketOpenSSL::OnHandshakeCompletion() {
if (!handshake_completion_callback_.is_null()) if (!handshake_completion_callback_.is_null())
base::ResetAndReturn(&handshake_completion_callback_).Run(); base::ResetAndReturn(&handshake_completion_callback_).Run();
......
...@@ -57,6 +57,7 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { ...@@ -57,6 +57,7 @@ class SSLClientSocketOpenSSL : public SSLClientSocket {
} }
// SSLClientSocket implementation. // SSLClientSocket implementation.
virtual std::string GetSessionCacheKey() const OVERRIDE;
virtual bool InSessionCache() const OVERRIDE; virtual bool InSessionCache() const OVERRIDE;
virtual void SetHandshakeCompletionCallback( virtual void SetHandshakeCompletionCallback(
const base::Closure& callback) OVERRIDE; const base::Closure& callback) OVERRIDE;
...@@ -110,8 +111,6 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { ...@@ -110,8 +111,6 @@ class SSLClientSocketOpenSSL : public SSLClientSocket {
void DoReadCallback(int result); void DoReadCallback(int result);
void DoWriteCallback(int result); void DoWriteCallback(int result);
// Compute a unique key string for the SSL session cache.
std::string GetSessionCacheKey() const;
void OnHandshakeCompletion(); void OnHandshakeCompletion();
bool DoTransportIO(); bool DoTransportIO();
......
...@@ -104,7 +104,10 @@ SSLConnectJobMessenger::SocketAndCallback::SocketAndCallback( ...@@ -104,7 +104,10 @@ SSLConnectJobMessenger::SocketAndCallback::SocketAndCallback(
SSLConnectJobMessenger::SocketAndCallback::~SocketAndCallback() { SSLConnectJobMessenger::SocketAndCallback::~SocketAndCallback() {
} }
SSLConnectJobMessenger::SSLConnectJobMessenger() : weak_factory_(this) { SSLConnectJobMessenger::SSLConnectJobMessenger(
const base::Closure& messenger_finished_callback)
: messenger_finished_callback_(messenger_finished_callback),
weak_factory_(this) {
} }
SSLConnectJobMessenger::~SSLConnectJobMessenger() { SSLConnectJobMessenger::~SSLConnectJobMessenger() {
...@@ -125,9 +128,8 @@ void SSLConnectJobMessenger::RemovePendingSocket(SSLClientSocket* ssl_socket) { ...@@ -125,9 +128,8 @@ void SSLConnectJobMessenger::RemovePendingSocket(SSLClientSocket* ssl_socket) {
} }
bool SSLConnectJobMessenger::CanProceed(SSLClientSocket* ssl_socket) { bool SSLConnectJobMessenger::CanProceed(SSLClientSocket* ssl_socket) {
// If the session is in the session cache, or there are no connecting // If there are no connecting sockets, allow the connection to proceed.
// sockets, allow the connection to proceed. return connecting_sockets_.empty();
return ssl_socket->InSessionCache() || connecting_sockets_.empty();
} }
void SSLConnectJobMessenger::MonitorConnectionResult( void SSLConnectJobMessenger::MonitorConnectionResult(
...@@ -149,6 +151,8 @@ void SSLConnectJobMessenger::OnSSLHandshakeCompleted() { ...@@ -149,6 +151,8 @@ void SSLConnectJobMessenger::OnSSLHandshakeCompleted() {
connecting_sockets_.clear(); connecting_sockets_.clear();
SSLPendingSocketsAndCallbacks temp_list; SSLPendingSocketsAndCallbacks temp_list;
temp_list.swap(pending_sockets_and_callbacks_); temp_list.swap(pending_sockets_and_callbacks_);
base::Closure messenger_finished_callback = messenger_finished_callback_;
messenger_finished_callback.Run();
RunAllCallbacks(temp_list); RunAllCallbacks(temp_list);
} }
...@@ -175,7 +179,7 @@ SSLConnectJob::SSLConnectJob(const std::string& group_name, ...@@ -175,7 +179,7 @@ SSLConnectJob::SSLConnectJob(const std::string& group_name,
ClientSocketFactory* client_socket_factory, ClientSocketFactory* client_socket_factory,
HostResolver* host_resolver, HostResolver* host_resolver,
const SSLClientSocketContext& context, const SSLClientSocketContext& context,
SSLConnectJobMessenger* messenger, const GetMessengerCallback& get_messenger_callback,
Delegate* delegate, Delegate* delegate,
NetLog* net_log) NetLog* net_log)
: ConnectJob(group_name, : ConnectJob(group_name,
...@@ -198,7 +202,8 @@ SSLConnectJob::SSLConnectJob(const std::string& group_name, ...@@ -198,7 +202,8 @@ SSLConnectJob::SSLConnectJob(const std::string& group_name,
: context.ssl_session_cache_shard)), : context.ssl_session_cache_shard)),
io_callback_( io_callback_(
base::Bind(&SSLConnectJob::OnIOComplete, base::Unretained(this))), base::Bind(&SSLConnectJob::OnIOComplete, base::Unretained(this))),
messenger_(messenger), messenger_(NULL),
get_messenger_callback_(get_messenger_callback),
weak_factory_(this) { weak_factory_(this) {
} }
...@@ -402,23 +407,31 @@ int SSLConnectJob::DoCreateSSLSocket() { ...@@ -402,23 +407,31 @@ int SSLConnectJob::DoCreateSSLSocket() {
params_->host_and_port(), params_->host_and_port(),
params_->ssl_config(), params_->ssl_config(),
context_); context_);
if (!ssl_socket_->InSessionCache())
messenger_ = get_messenger_callback_.Run(ssl_socket_->GetSessionCacheKey());
return OK; return OK;
} }
int SSLConnectJob::DoCheckForResume() { int SSLConnectJob::DoCheckForResume() {
next_state_ = STATE_SSL_CONNECT; next_state_ = STATE_SSL_CONNECT;
if (!messenger_) if (!messenger_)
return OK; return OK;
// TODO(mshelley): Remove duplicate InSessionCache() calls.
if (messenger_->CanProceed(ssl_socket_.get())) { if (messenger_->CanProceed(ssl_socket_.get())) {
if (!ssl_socket_->InSessionCache()) messenger_->MonitorConnectionResult(ssl_socket_.get());
messenger_->MonitorConnectionResult(ssl_socket_.get()); // The SSLConnectJob no longer needs access to the messenger after this
// point.
messenger_ = NULL;
return OK; return OK;
} }
messenger_->AddPendingSocket(ssl_socket_.get(), messenger_->AddPendingSocket(ssl_socket_.get(),
base::Bind(&SSLConnectJob::ResumeSSLConnection, base::Bind(&SSLConnectJob::ResumeSSLConnection,
weak_factory_.GetWeakPtr())); weak_factory_.GetWeakPtr()));
return ERR_IO_PENDING; return ERR_IO_PENDING;
} }
...@@ -556,6 +569,7 @@ int SSLConnectJob::DoSSLConnectComplete(int result) { ...@@ -556,6 +569,7 @@ int SSLConnectJob::DoSSLConnectComplete(int result) {
void SSLConnectJob::ResumeSSLConnection() { void SSLConnectJob::ResumeSSLConnection() {
DCHECK_EQ(next_state_, STATE_SSL_CONNECT); DCHECK_EQ(next_state_, STATE_SSL_CONNECT);
messenger_ = NULL;
OnIOComplete(OK); OnIOComplete(OK);
} }
...@@ -585,7 +599,7 @@ SSLClientSocketPool::SSLConnectJobFactory::SSLConnectJobFactory( ...@@ -585,7 +599,7 @@ SSLClientSocketPool::SSLConnectJobFactory::SSLConnectJobFactory(
ClientSocketFactory* client_socket_factory, ClientSocketFactory* client_socket_factory,
HostResolver* host_resolver, HostResolver* host_resolver,
const SSLClientSocketContext& context, const SSLClientSocketContext& context,
bool enable_ssl_connect_job_waiting, const SSLConnectJob::GetMessengerCallback& get_messenger_callback,
NetLog* net_log) NetLog* net_log)
: transport_pool_(transport_pool), : transport_pool_(transport_pool),
socks_pool_(socks_pool), socks_pool_(socks_pool),
...@@ -593,9 +607,8 @@ SSLClientSocketPool::SSLConnectJobFactory::SSLConnectJobFactory( ...@@ -593,9 +607,8 @@ SSLClientSocketPool::SSLConnectJobFactory::SSLConnectJobFactory(
client_socket_factory_(client_socket_factory), client_socket_factory_(client_socket_factory),
host_resolver_(host_resolver), host_resolver_(host_resolver),
context_(context), context_(context),
enable_ssl_connect_job_waiting_(enable_ssl_connect_job_waiting), get_messenger_callback_(get_messenger_callback),
net_log_(net_log), net_log_(net_log) {
messenger_map_(new MessengerMap) {
base::TimeDelta max_transport_timeout = base::TimeDelta(); base::TimeDelta max_transport_timeout = base::TimeDelta();
base::TimeDelta pool_timeout; base::TimeDelta pool_timeout;
if (transport_pool_) if (transport_pool_)
...@@ -615,7 +628,6 @@ SSLClientSocketPool::SSLConnectJobFactory::SSLConnectJobFactory( ...@@ -615,7 +628,6 @@ SSLClientSocketPool::SSLConnectJobFactory::SSLConnectJobFactory(
} }
SSLClientSocketPool::SSLConnectJobFactory::~SSLConnectJobFactory() { SSLClientSocketPool::SSLConnectJobFactory::~SSLConnectJobFactory() {
STLDeleteValues(messenger_map_.get());
} }
SSLClientSocketPool::SSLClientSocketPool( SSLClientSocketPool::SSLClientSocketPool(
...@@ -655,9 +667,12 @@ SSLClientSocketPool::SSLClientSocketPool( ...@@ -655,9 +667,12 @@ SSLClientSocketPool::SSLClientSocketPool(
transport_security_state, transport_security_state,
cert_transparency_verifier, cert_transparency_verifier,
ssl_session_cache_shard), ssl_session_cache_shard),
enable_ssl_connect_job_waiting, base::Bind(
&SSLClientSocketPool::GetOrCreateSSLConnectJobMessenger,
base::Unretained(this)),
net_log)), net_log)),
ssl_config_service_(ssl_config_service) { ssl_config_service_(ssl_config_service),
enable_ssl_connect_job_waiting_(enable_ssl_connect_job_waiting) {
if (ssl_config_service_.get()) if (ssl_config_service_.get())
ssl_config_service_->AddObserver(this); ssl_config_service_->AddObserver(this);
if (transport_pool_) if (transport_pool_)
...@@ -669,28 +684,16 @@ SSLClientSocketPool::SSLClientSocketPool( ...@@ -669,28 +684,16 @@ SSLClientSocketPool::SSLClientSocketPool(
} }
SSLClientSocketPool::~SSLClientSocketPool() { SSLClientSocketPool::~SSLClientSocketPool() {
STLDeleteContainerPairSecondPointers(messenger_map_.begin(),
messenger_map_.end());
if (ssl_config_service_.get()) if (ssl_config_service_.get())
ssl_config_service_->RemoveObserver(this); ssl_config_service_->RemoveObserver(this);
} }
scoped_ptr<ConnectJob> scoped_ptr<ConnectJob> SSLClientSocketPool::SSLConnectJobFactory::NewConnectJob(
SSLClientSocketPool::SSLConnectJobFactory::NewConnectJob(
const std::string& group_name, const std::string& group_name,
const PoolBase::Request& request, const PoolBase::Request& request,
ConnectJob::Delegate* delegate) const { ConnectJob::Delegate* delegate) const {
SSLConnectJobMessenger* messenger = NULL;
if (enable_ssl_connect_job_waiting_) {
std::string cache_key = SSLClientSocket::CreateSessionCacheKey(
request.params()->host_and_port(), context_.ssl_session_cache_shard);
MessengerMap::const_iterator it = messenger_map_->find(cache_key);
if (it == messenger_map_->end()) {
std::pair<MessengerMap::iterator, bool> iter = messenger_map_->insert(
MessengerMap::value_type(cache_key, new SSLConnectJobMessenger()));
it = iter.first;
}
messenger = it->second;
}
return scoped_ptr<ConnectJob>(new SSLConnectJob(group_name, return scoped_ptr<ConnectJob>(new SSLConnectJob(group_name,
request.priority(), request.priority(),
request.params(), request.params(),
...@@ -701,13 +704,13 @@ SSLClientSocketPool::SSLConnectJobFactory::NewConnectJob( ...@@ -701,13 +704,13 @@ SSLClientSocketPool::SSLConnectJobFactory::NewConnectJob(
client_socket_factory_, client_socket_factory_,
host_resolver_, host_resolver_,
context_, context_,
messenger, get_messenger_callback_,
delegate, delegate,
net_log_)); net_log_));
} }
base::TimeDelta base::TimeDelta SSLClientSocketPool::SSLConnectJobFactory::ConnectionTimeout()
SSLClientSocketPool::SSLConnectJobFactory::ConnectionTimeout() const { const {
return timeout_; return timeout_;
} }
...@@ -822,6 +825,32 @@ bool SSLClientSocketPool::CloseOneIdleConnection() { ...@@ -822,6 +825,32 @@ bool SSLClientSocketPool::CloseOneIdleConnection() {
return base_.CloseOneIdleConnectionInHigherLayeredPool(); return base_.CloseOneIdleConnectionInHigherLayeredPool();
} }
SSLConnectJobMessenger* SSLClientSocketPool::GetOrCreateSSLConnectJobMessenger(
const std::string& cache_key) {
if (!enable_ssl_connect_job_waiting_)
return NULL;
MessengerMap::const_iterator it = messenger_map_.find(cache_key);
if (it == messenger_map_.end()) {
std::pair<MessengerMap::iterator, bool> iter =
messenger_map_.insert(MessengerMap::value_type(
cache_key,
new SSLConnectJobMessenger(
base::Bind(&SSLClientSocketPool::DeleteSSLConnectJobMessenger,
base::Unretained(this),
cache_key))));
it = iter.first;
}
return it->second;
}
void SSLClientSocketPool::DeleteSSLConnectJobMessenger(
const std::string& cache_key) {
MessengerMap::iterator it = messenger_map_.find(cache_key);
CHECK(it != messenger_map_.end());
delete it->second;
messenger_map_.erase(it);
}
void SSLClientSocketPool::OnSSLConfigChanged() { void SSLClientSocketPool::OnSSLConfigChanged() {
FlushWithError(ERR_NETWORK_CHANGED); FlushWithError(ERR_NETWORK_CHANGED);
} }
......
...@@ -115,7 +115,11 @@ class SSLConnectJobMessenger { ...@@ -115,7 +115,11 @@ class SSLConnectJobMessenger {
typedef std::vector<SocketAndCallback> SSLPendingSocketsAndCallbacks; typedef std::vector<SocketAndCallback> SSLPendingSocketsAndCallbacks;
SSLConnectJobMessenger(); // |messenger_finished_callback| is run when a connection monitored by the
// SSLConnectJobMessenger has completed and we are finished with the
// SSLConnectJobMessenger.
explicit SSLConnectJobMessenger(
const base::Closure& messenger_finished_callback);
~SSLConnectJobMessenger(); ~SSLConnectJobMessenger();
// Removes |socket| from the set of sockets being monitored. This // Removes |socket| from the set of sockets being monitored. This
...@@ -151,18 +155,30 @@ class SSLConnectJobMessenger { ...@@ -151,18 +155,30 @@ class SSLConnectJobMessenger {
void RunAllCallbacks( void RunAllCallbacks(
const SSLPendingSocketsAndCallbacks& pending_socket_and_callbacks); const SSLPendingSocketsAndCallbacks& pending_socket_and_callbacks);
base::WeakPtrFactory<SSLConnectJobMessenger> weak_factory_;
SSLPendingSocketsAndCallbacks pending_sockets_and_callbacks_; SSLPendingSocketsAndCallbacks pending_sockets_and_callbacks_;
// Note: this field is a vector to allow for future design changes. Currently, // Note: this field is a vector to allow for future design changes. Currently,
// this vector should only ever have one entry. // this vector should only ever have one entry.
std::vector<SSLClientSocket*> connecting_sockets_; std::vector<SSLClientSocket*> connecting_sockets_;
base::Closure messenger_finished_callback_;
base::WeakPtrFactory<SSLConnectJobMessenger> weak_factory_;
}; };
// SSLConnectJob handles the SSL handshake after setting up the underlying // SSLConnectJob handles the SSL handshake after setting up the underlying
// connection as specified in the params. // connection as specified in the params.
class SSLConnectJob : public ConnectJob { class SSLConnectJob : public ConnectJob {
public: public:
// Callback to allow the SSLConnectJob to obtain an SSLConnectJobMessenger to
// coordinate connecting. The SSLConnectJob will supply a unique identifer
// (ex: the SSL session cache key), with the expectation that the same
// Messenger will be returned for all such ConnectJobs.
//
// Note: It will only be called for situations where the SSL session cache
// does not already have a candidate session to resume.
typedef base::Callback<SSLConnectJobMessenger*(const std::string&)>
GetMessengerCallback;
// Note: the SSLConnectJob does not own |messenger| so it must outlive the // Note: the SSLConnectJob does not own |messenger| so it must outlive the
// job. // job.
SSLConnectJob(const std::string& group_name, SSLConnectJob(const std::string& group_name,
...@@ -175,7 +191,7 @@ class SSLConnectJob : public ConnectJob { ...@@ -175,7 +191,7 @@ class SSLConnectJob : public ConnectJob {
ClientSocketFactory* client_socket_factory, ClientSocketFactory* client_socket_factory,
HostResolver* host_resolver, HostResolver* host_resolver,
const SSLClientSocketContext& context, const SSLClientSocketContext& context,
SSLConnectJobMessenger* messenger, const GetMessengerCallback& get_messenger_callback,
Delegate* delegate, Delegate* delegate,
NetLog* net_log); NetLog* net_log);
virtual ~SSLConnectJob(); virtual ~SSLConnectJob();
...@@ -245,6 +261,8 @@ class SSLConnectJob : public ConnectJob { ...@@ -245,6 +261,8 @@ class SSLConnectJob : public ConnectJob {
SSLConnectJobMessenger* messenger_; SSLConnectJobMessenger* messenger_;
HttpResponseInfo error_response_info_; HttpResponseInfo error_response_info_;
GetMessengerCallback get_messenger_callback_;
base::WeakPtrFactory<SSLConnectJob> weak_factory_; base::WeakPtrFactory<SSLConnectJob> weak_factory_;
DISALLOW_COPY_AND_ASSIGN(SSLConnectJob); DISALLOW_COPY_AND_ASSIGN(SSLConnectJob);
...@@ -330,8 +348,16 @@ class NET_EXPORT_PRIVATE SSLClientSocketPool ...@@ -330,8 +348,16 @@ class NET_EXPORT_PRIVATE SSLClientSocketPool
// HigherLayeredPool implementation. // HigherLayeredPool implementation.
virtual bool CloseOneIdleConnection() OVERRIDE; virtual bool CloseOneIdleConnection() OVERRIDE;
// Gets the SSLConnectJobMessenger for the given ssl session |cache_key|. If
// none exits, it creates one and stores it in |messenger_map_|.
SSLConnectJobMessenger* GetOrCreateSSLConnectJobMessenger(
const std::string& cache_key);
void DeleteSSLConnectJobMessenger(const std::string& cache_key);
private: private:
typedef ClientSocketPoolBase<SSLSocketParams> PoolBase; typedef ClientSocketPoolBase<SSLSocketParams> PoolBase;
// Maps SSLConnectJob cache keys to SSLConnectJobMessenger objects.
typedef std::map<std::string, SSLConnectJobMessenger*> MessengerMap;
// SSLConfigService::Observer implementation. // SSLConfigService::Observer implementation.
...@@ -341,14 +367,15 @@ class NET_EXPORT_PRIVATE SSLClientSocketPool ...@@ -341,14 +367,15 @@ class NET_EXPORT_PRIVATE SSLClientSocketPool
class SSLConnectJobFactory : public PoolBase::ConnectJobFactory { class SSLConnectJobFactory : public PoolBase::ConnectJobFactory {
public: public:
SSLConnectJobFactory(TransportClientSocketPool* transport_pool, SSLConnectJobFactory(
SOCKSClientSocketPool* socks_pool, TransportClientSocketPool* transport_pool,
HttpProxyClientSocketPool* http_proxy_pool, SOCKSClientSocketPool* socks_pool,
ClientSocketFactory* client_socket_factory, HttpProxyClientSocketPool* http_proxy_pool,
HostResolver* host_resolver, ClientSocketFactory* client_socket_factory,
const SSLClientSocketContext& context, HostResolver* host_resolver,
bool enable_ssl_connect_job_waiting, const SSLClientSocketContext& context,
NetLog* net_log); const SSLConnectJob::GetMessengerCallback& get_messenger_callback,
NetLog* net_log);
virtual ~SSLConnectJobFactory(); virtual ~SSLConnectJobFactory();
...@@ -361,9 +388,6 @@ class NET_EXPORT_PRIVATE SSLClientSocketPool ...@@ -361,9 +388,6 @@ class NET_EXPORT_PRIVATE SSLClientSocketPool
virtual base::TimeDelta ConnectionTimeout() const OVERRIDE; virtual base::TimeDelta ConnectionTimeout() const OVERRIDE;
private: private:
// Maps SSLConnectJob cache keys to SSLConnectJobMessenger objects.
typedef std::map<std::string, SSLConnectJobMessenger*> MessengerMap;
TransportClientSocketPool* const transport_pool_; TransportClientSocketPool* const transport_pool_;
SOCKSClientSocketPool* const socks_pool_; SOCKSClientSocketPool* const socks_pool_;
HttpProxyClientSocketPool* const http_proxy_pool_; HttpProxyClientSocketPool* const http_proxy_pool_;
...@@ -371,13 +395,8 @@ class NET_EXPORT_PRIVATE SSLClientSocketPool ...@@ -371,13 +395,8 @@ class NET_EXPORT_PRIVATE SSLClientSocketPool
HostResolver* const host_resolver_; HostResolver* const host_resolver_;
const SSLClientSocketContext context_; const SSLClientSocketContext context_;
base::TimeDelta timeout_; base::TimeDelta timeout_;
bool enable_ssl_connect_job_waiting_; SSLConnectJob::GetMessengerCallback get_messenger_callback_;
NetLog* net_log_; NetLog* net_log_;
// |messenger_map_| is currently a pointer so that an element can be
// added to it inside of the const method NewConnectJob. In the future,
// elements will be added in a different method.
// TODO(mshelley) Change this to a non-pointer.
scoped_ptr<MessengerMap> messenger_map_;
DISALLOW_COPY_AND_ASSIGN(SSLConnectJobFactory); DISALLOW_COPY_AND_ASSIGN(SSLConnectJobFactory);
}; };
...@@ -387,6 +406,8 @@ class NET_EXPORT_PRIVATE SSLClientSocketPool ...@@ -387,6 +406,8 @@ class NET_EXPORT_PRIVATE SSLClientSocketPool
HttpProxyClientSocketPool* const http_proxy_pool_; HttpProxyClientSocketPool* const http_proxy_pool_;
PoolBase base_; PoolBase base_;
const scoped_refptr<SSLConfigService> ssl_config_service_; const scoped_refptr<SSLConfigService> ssl_config_service_;
MessengerMap messenger_map_;
bool enable_ssl_connect_job_waiting_;
DISALLOW_COPY_AND_ASSIGN(SSLClientSocketPool); DISALLOW_COPY_AND_ASSIGN(SSLClientSocketPool);
}; };
......
...@@ -392,7 +392,7 @@ int TransportConnectJob::ConnectInternal() { ...@@ -392,7 +392,7 @@ int TransportConnectJob::ConnectInternal() {
} }
scoped_ptr<ConnectJob> scoped_ptr<ConnectJob>
TransportClientSocketPool::TransportConnectJobFactory::NewConnectJob( TransportClientSocketPool::TransportConnectJobFactory::NewConnectJob(
const std::string& group_name, const std::string& group_name,
const PoolBase::Request& request, const PoolBase::Request& request,
ConnectJob::Delegate* delegate) const { ConnectJob::Delegate* delegate) const {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment