Commit c80a1994 authored by sergeyu@chromium.org's avatar sergeyu@chromium.org

Refactor ChannelAuthenticator so that it can be used with Authenticator.

BUG=None
TEST=None


Review URL: http://codereview.chromium.org/8527018

git-svn-id: svn://svn.chromium.org/chrome/trunk/src@110051 0039d316-1c4b-4281-b951-d872f2087c98
parent b6cb7702
...@@ -9,8 +9,7 @@ ...@@ -9,8 +9,7 @@
#include "crypto/hmac.h" #include "crypto/hmac.h"
#include "net/base/io_buffer.h" #include "net/base/io_buffer.h"
#include "net/base/net_errors.h" #include "net/base/net_errors.h"
#include "net/socket/ssl_client_socket.h" #include "net/socket/ssl_socket.h"
#include "net/socket/ssl_server_socket.h"
#include "net/socket/stream_socket.h" #include "net/socket/stream_socket.h"
namespace remoting { namespace remoting {
...@@ -46,8 +45,10 @@ bool GetAuthBytes(const std::string& shared_secret, ...@@ -46,8 +45,10 @@ bool GetAuthBytes(const std::string& shared_secret,
} // namespace } // namespace
HostChannelAuthenticator::HostChannelAuthenticator(net::SSLServerSocket* socket) HostChannelAuthenticator::HostChannelAuthenticator(
: socket_(socket), const std::string& shared_secret)
: shared_secret_(shared_secret),
socket_(NULL),
ALLOW_THIS_IN_INITIALIZER_LIST(auth_read_callback_( ALLOW_THIS_IN_INITIALIZER_LIST(auth_read_callback_(
this, &HostChannelAuthenticator::OnAuthBytesRead)) { this, &HostChannelAuthenticator::OnAuthBytesRead)) {
} }
...@@ -55,10 +56,11 @@ HostChannelAuthenticator::HostChannelAuthenticator(net::SSLServerSocket* socket) ...@@ -55,10 +56,11 @@ HostChannelAuthenticator::HostChannelAuthenticator(net::SSLServerSocket* socket)
HostChannelAuthenticator::~HostChannelAuthenticator() { HostChannelAuthenticator::~HostChannelAuthenticator() {
} }
void HostChannelAuthenticator::Authenticate(const std::string& shared_secret, void HostChannelAuthenticator::Authenticate(net::SSLSocket* socket,
const DoneCallback& done_callback) { const DoneCallback& done_callback) {
DCHECK(CalledOnValidThread()); DCHECK(CalledOnValidThread());
socket_ = socket;
done_callback_ = done_callback; done_callback_ = done_callback;
unsigned char key_material[kAuthDigestLength]; unsigned char key_material[kAuthDigestLength];
...@@ -70,7 +72,7 @@ void HostChannelAuthenticator::Authenticate(const std::string& shared_secret, ...@@ -70,7 +72,7 @@ void HostChannelAuthenticator::Authenticate(const std::string& shared_secret,
return; return;
} }
if (!GetAuthBytes(shared_secret, if (!GetAuthBytes(shared_secret_,
std::string(key_material, key_material + kAuthDigestLength), std::string(key_material, key_material + kAuthDigestLength),
&auth_bytes_)) { &auth_bytes_)) {
done_callback.Run(FAILURE); done_callback.Run(FAILURE);
...@@ -139,8 +141,9 @@ bool HostChannelAuthenticator::VerifyAuthBytes( ...@@ -139,8 +141,9 @@ bool HostChannelAuthenticator::VerifyAuthBytes(
} }
ClientChannelAuthenticator::ClientChannelAuthenticator( ClientChannelAuthenticator::ClientChannelAuthenticator(
net::SSLClientSocket* socket) const std::string& shared_secret)
: socket_(socket), : shared_secret_(shared_secret),
socket_(NULL),
ALLOW_THIS_IN_INITIALIZER_LIST(auth_write_callback_( ALLOW_THIS_IN_INITIALIZER_LIST(auth_write_callback_(
this, &ClientChannelAuthenticator::OnAuthBytesWritten)) { this, &ClientChannelAuthenticator::OnAuthBytesWritten)) {
} }
...@@ -149,10 +152,11 @@ ClientChannelAuthenticator::~ClientChannelAuthenticator() { ...@@ -149,10 +152,11 @@ ClientChannelAuthenticator::~ClientChannelAuthenticator() {
} }
void ClientChannelAuthenticator::Authenticate( void ClientChannelAuthenticator::Authenticate(
const std::string& shared_secret, net::SSLSocket* socket,
const DoneCallback& done_callback) { const DoneCallback& done_callback) {
DCHECK(CalledOnValidThread()); DCHECK(CalledOnValidThread());
socket_ = socket;
done_callback_ = done_callback; done_callback_ = done_callback;
unsigned char key_material[kAuthDigestLength]; unsigned char key_material[kAuthDigestLength];
...@@ -165,7 +169,7 @@ void ClientChannelAuthenticator::Authenticate( ...@@ -165,7 +169,7 @@ void ClientChannelAuthenticator::Authenticate(
} }
std::string auth_bytes; std::string auth_bytes;
if (!GetAuthBytes(shared_secret, if (!GetAuthBytes(shared_secret_,
std::string(key_material, key_material + kAuthDigestLength), std::string(key_material, key_material + kAuthDigestLength),
&auth_bytes)) { &auth_bytes)) {
done_callback.Run(FAILURE); done_callback.Run(FAILURE);
......
...@@ -15,8 +15,7 @@ ...@@ -15,8 +15,7 @@
namespace net { namespace net {
class DrainableIOBuffer; class DrainableIOBuffer;
class GrowableIOBuffer; class GrowableIOBuffer;
class SSLClientSocket; class SSLSocket;
class SSLServerSocket;
} // namespace net } // namespace net
namespace remoting { namespace remoting {
...@@ -38,7 +37,7 @@ class ChannelAuthenticator : public base::NonThreadSafe { ...@@ -38,7 +37,7 @@ class ChannelAuthenticator : public base::NonThreadSafe {
// when authentication is finished. Caller retains ownership of // when authentication is finished. Caller retains ownership of
// |socket|. |shared_secret| is a shared secret that we use to // |socket|. |shared_secret| is a shared secret that we use to
// authenticate the channel. // authenticate the channel.
virtual void Authenticate(const std::string& shared_secret, virtual void Authenticate(net::SSLSocket* socket,
const DoneCallback& done_callback) = 0; const DoneCallback& done_callback) = 0;
private: private:
...@@ -47,11 +46,11 @@ class ChannelAuthenticator : public base::NonThreadSafe { ...@@ -47,11 +46,11 @@ class ChannelAuthenticator : public base::NonThreadSafe {
class HostChannelAuthenticator : public ChannelAuthenticator { class HostChannelAuthenticator : public ChannelAuthenticator {
public: public:
HostChannelAuthenticator(net::SSLServerSocket* socket); HostChannelAuthenticator(const std::string& shared_secret);
virtual ~HostChannelAuthenticator(); virtual ~HostChannelAuthenticator();
// ChannelAuthenticator overrides. // ChannelAuthenticator overrides.
virtual void Authenticate(const std::string& shared_secret, virtual void Authenticate(net::SSLSocket* socket,
const DoneCallback& done_callback) OVERRIDE; const DoneCallback& done_callback) OVERRIDE;
private: private:
...@@ -60,8 +59,9 @@ class HostChannelAuthenticator : public ChannelAuthenticator { ...@@ -60,8 +59,9 @@ class HostChannelAuthenticator : public ChannelAuthenticator {
bool HandleAuthBytesRead(int result); bool HandleAuthBytesRead(int result);
bool VerifyAuthBytes(const std::string& received_auth_bytes); bool VerifyAuthBytes(const std::string& received_auth_bytes);
std::string shared_secret_;
std::string auth_bytes_; std::string auth_bytes_;
net::SSLServerSocket* socket_; net::SSLSocket* socket_;
DoneCallback done_callback_; DoneCallback done_callback_;
scoped_refptr<net::GrowableIOBuffer> auth_read_buf_; scoped_refptr<net::GrowableIOBuffer> auth_read_buf_;
...@@ -73,11 +73,11 @@ class HostChannelAuthenticator : public ChannelAuthenticator { ...@@ -73,11 +73,11 @@ class HostChannelAuthenticator : public ChannelAuthenticator {
class ClientChannelAuthenticator : public ChannelAuthenticator { class ClientChannelAuthenticator : public ChannelAuthenticator {
public: public:
ClientChannelAuthenticator(net::SSLClientSocket* socket); ClientChannelAuthenticator(const std::string& shared_secret);
virtual ~ClientChannelAuthenticator(); virtual ~ClientChannelAuthenticator();
// ChannelAuthenticator overrides. // ChannelAuthenticator overrides.
virtual void Authenticate(const std::string& shared_secret, virtual void Authenticate(net::SSLSocket* socket,
const DoneCallback& done_callback); const DoneCallback& done_callback);
private: private:
...@@ -85,12 +85,14 @@ class ClientChannelAuthenticator : public ChannelAuthenticator { ...@@ -85,12 +85,14 @@ class ClientChannelAuthenticator : public ChannelAuthenticator {
void OnAuthBytesWritten(int result); void OnAuthBytesWritten(int result);
bool HandleAuthBytesWritten(int result); bool HandleAuthBytesWritten(int result);
net::SSLClientSocket* socket_; std::string shared_secret_;
net::SSLSocket* socket_;
DoneCallback done_callback_; DoneCallback done_callback_;
scoped_refptr<net::DrainableIOBuffer> auth_write_buf_; scoped_refptr<net::DrainableIOBuffer> auth_write_buf_;
net::OldCompletionCallbackImpl<ClientChannelAuthenticator> auth_write_callback_; net::OldCompletionCallbackImpl<ClientChannelAuthenticator>
auth_write_callback_;
DISALLOW_COPY_AND_ASSIGN(ClientChannelAuthenticator); DISALLOW_COPY_AND_ASSIGN(ClientChannelAuthenticator);
}; };
......
...@@ -239,7 +239,7 @@ class JingleSessionTest : public testing::Test { ...@@ -239,7 +239,7 @@ class JingleSessionTest : public testing::Test {
EXPECT_CALL(host_connection_callback_, EXPECT_CALL(host_connection_callback_,
OnStateChange(Session::CONNECTED_CHANNELS)) OnStateChange(Session::CONNECTED_CHANNELS))
.Times(AtMost(1)); .Times(AtMost(1));
// Expect that the connection will be closed eventually. // Expect that the connection will fail.
EXPECT_CALL(host_connection_callback_, EXPECT_CALL(host_connection_callback_,
OnStateChange(Session::FAILED)) OnStateChange(Session::FAILED))
.Times(1) .Times(1)
......
...@@ -72,8 +72,6 @@ JingleStreamConnector::JingleStreamConnector( ...@@ -72,8 +72,6 @@ JingleStreamConnector::JingleStreamConnector(
initiator_(false), initiator_(false),
local_private_key_(NULL), local_private_key_(NULL),
raw_channel_(NULL), raw_channel_(NULL),
ssl_client_socket_(NULL),
ssl_server_socket_(NULL),
ALLOW_THIS_IN_INITIALIZER_LIST(tcp_connect_callback_( ALLOW_THIS_IN_INITIALIZER_LIST(tcp_connect_callback_(
this, &JingleStreamConnector::OnTCPConnect)), this, &JingleStreamConnector::OnTCPConnect)),
ALLOW_THIS_IN_INITIALIZER_LIST(ssl_connect_callback_( ALLOW_THIS_IN_INITIALIZER_LIST(ssl_connect_callback_(
...@@ -122,8 +120,8 @@ bool JingleStreamConnector::EstablishTCPConnection(net::Socket* socket) { ...@@ -122,8 +120,8 @@ bool JingleStreamConnector::EstablishTCPConnection(net::Socket* socket) {
adapter->SetReceiveBufferSize(kTcpReceiveBufferSize); adapter->SetReceiveBufferSize(kTcpReceiveBufferSize);
adapter->SetSendBufferSize(kTcpSendBufferSize); adapter->SetSendBufferSize(kTcpSendBufferSize);
socket_.reset(adapter); tcp_socket_.reset(adapter);
int result = socket_->Connect(&tcp_connect_callback_); int result = tcp_socket_->Connect(&tcp_connect_callback_);
if (result == net::ERR_IO_PENDING) { if (result == net::ERR_IO_PENDING) {
return true; return true;
} else if (result == net::OK) { } else if (result == net::OK) {
...@@ -135,18 +133,18 @@ bool JingleStreamConnector::EstablishTCPConnection(net::Socket* socket) { ...@@ -135,18 +133,18 @@ bool JingleStreamConnector::EstablishTCPConnection(net::Socket* socket) {
} }
bool JingleStreamConnector::EstablishSSLConnection() { bool JingleStreamConnector::EstablishSSLConnection() {
DCHECK(socket_->IsConnected()); DCHECK(tcp_socket_->IsConnected());
int result; int result;
if (initiator_) { if (initiator_) {
cert_verifier_.reset(new net::CertVerifier()); cert_verifier_.reset(new net::CertVerifier());
// Create client SSL socket. // Create client SSL socket.
ssl_client_socket_ = CreateSSLClientSocket( net::SSLClientSocket* socket = CreateSSLClientSocket(
socket_.release(), remote_cert_, cert_verifier_.get()); tcp_socket_.release(), remote_cert_, cert_verifier_.get());
socket_.reset(ssl_client_socket_); socket_.reset(socket);
result = ssl_client_socket_->Connect(&ssl_connect_callback_); result = socket->Connect(&ssl_connect_callback_);
} else { } else {
scoped_refptr<net::X509Certificate> cert = scoped_refptr<net::X509Certificate> cert =
net::X509Certificate::CreateFromBytes( net::X509Certificate::CreateFromBytes(
...@@ -158,11 +156,11 @@ bool JingleStreamConnector::EstablishSSLConnection() { ...@@ -158,11 +156,11 @@ bool JingleStreamConnector::EstablishSSLConnection() {
// Create server SSL socket. // Create server SSL socket.
net::SSLConfig ssl_config; net::SSLConfig ssl_config;
ssl_server_socket_ = net::CreateSSLServerSocket( net::SSLServerSocket* socket = net::CreateSSLServerSocket(
socket_.release(), cert, local_private_key_, ssl_config); tcp_socket_.release(), cert, local_private_key_, ssl_config);
socket_.reset(ssl_server_socket_); socket_.reset(socket);
result = ssl_server_socket_->Handshake(&ssl_connect_callback_); result = socket->Handshake(&ssl_connect_callback_);
} }
if (result == net::ERR_IO_PENDING) { if (result == net::ERR_IO_PENDING) {
...@@ -205,15 +203,14 @@ void JingleStreamConnector::OnSSLConnect(int result) { ...@@ -205,15 +203,14 @@ void JingleStreamConnector::OnSSLConnect(int result) {
void JingleStreamConnector::AuthenticateChannel() { void JingleStreamConnector::AuthenticateChannel() {
if (initiator_) { if (initiator_) {
authenticator_.reset(new ClientChannelAuthenticator(ssl_client_socket_)); authenticator_.reset(
new ClientChannelAuthenticator(session_->shared_secret()));
} else { } else {
authenticator_.reset(new HostChannelAuthenticator(ssl_server_socket_)); authenticator_.reset(
new HostChannelAuthenticator(session_->shared_secret()));
} }
authenticator_->Authenticate(socket_.get(), base::Bind(
authenticator_->Authenticate( &JingleStreamConnector::OnAuthenticationDone, base::Unretained(this)));
session_->shared_secret(),
base::Bind(&JingleStreamConnector::OnAuthenticationDone,
base::Unretained(this)));
} }
void JingleStreamConnector::OnAuthenticationDone( void JingleStreamConnector::OnAuthenticationDone(
......
...@@ -20,8 +20,7 @@ class TransportChannel; ...@@ -20,8 +20,7 @@ class TransportChannel;
namespace net { namespace net {
class CertVerifier; class CertVerifier;
class StreamSocket; class StreamSocket;
class SSLClientSocket; class SSLSocket;
class SSLServerSocket;
} // namespace net } // namespace net
namespace remoting { namespace remoting {
...@@ -74,11 +73,8 @@ class JingleStreamConnector : public JingleChannelConnector { ...@@ -74,11 +73,8 @@ class JingleStreamConnector : public JingleChannelConnector {
crypto::RSAPrivateKey* local_private_key_; crypto::RSAPrivateKey* local_private_key_;
cricket::TransportChannel* raw_channel_; cricket::TransportChannel* raw_channel_;
scoped_ptr<net::StreamSocket> socket_; scoped_ptr<net::StreamSocket> tcp_socket_;
scoped_ptr<net::SSLSocket> socket_;
// TODO(wez): Ugly up-casts needed so we can fetch SSL keying material.
net::SSLClientSocket* ssl_client_socket_;
net::SSLServerSocket* ssl_server_socket_;
// Used to verify the certificate received in SSLClientSocket. // Used to verify the certificate received in SSLClientSocket.
scoped_ptr<net::CertVerifier> cert_verifier_; scoped_ptr<net::CertVerifier> cert_verifier_;
......
...@@ -236,11 +236,10 @@ void PepperStreamChannel::OnSSLConnect(int result) { ...@@ -236,11 +236,10 @@ void PepperStreamChannel::OnSSLConnect(int result) {
void PepperStreamChannel::AuthenticateChannel() { void PepperStreamChannel::AuthenticateChannel() {
DCHECK(CalledOnValidThread()); DCHECK(CalledOnValidThread());
authenticator_.reset(new ClientChannelAuthenticator(ssl_client_socket_)); authenticator_.reset(
authenticator_->Authenticate( new ClientChannelAuthenticator(session_->shared_secret()));
session_->shared_secret(), authenticator_->Authenticate(ssl_client_socket_, base::Bind(
base::Bind(&PepperStreamChannel::OnAuthenticationDone, &PepperStreamChannel::OnAuthenticationDone, base::Unretained(this)));
base::Unretained(this)));
} }
void PepperStreamChannel::OnAuthenticationDone( void PepperStreamChannel::OnAuthenticationDone(
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment