Commit 57651a87 authored by sergeyu@chromium.org's avatar sergeyu@chromium.org

Separate Authenticator and Session unittests.

Previously JingleSession unit tests were using real authenticators. Here
I changed them to always use FakeAuthenticator and added new tests for channel 
authentication in v1_authenticator_unittest.cc . Also, to make new tests pass,
fixed session-terminate message handling in JingleSession to return correct 
error code.

BUG=105214


Review URL: http://codereview.chromium.org/8743023

git-svn-id: svn://svn.chromium.org/chrome/trunk/src@114178 0039d316-1c4b-4281-b951-d872f2087c98
parent f7e3fb85
// Copyright (c) 2011 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "remoting/protocol/connection_tester.h"
#include "base/bind.h"
#include "base/message_loop.h"
#include "net/base/io_buffer.h"
#include "net/base/net_errors.h"
#include "net/socket/stream_socket.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace remoting {
namespace protocol {
StreamConnectionTester::StreamConnectionTester(net::StreamSocket* client_socket,
net::StreamSocket* host_socket,
int message_size,
int message_count)
: message_loop_(MessageLoop::current()),
host_socket_(host_socket),
client_socket_(client_socket),
message_size_(message_size),
message_count_(message_count),
test_data_size_(message_size * message_count),
done_(false),
write_errors_(0),
read_errors_(0) {
}
StreamConnectionTester::~StreamConnectionTester() {
}
void StreamConnectionTester::Start() {
InitBuffers();
DoRead();
DoWrite();
}
void StreamConnectionTester::CheckResults() {
EXPECT_EQ(0, write_errors_);
EXPECT_EQ(0, read_errors_);
ASSERT_EQ(test_data_size_, input_buffer_->offset());
output_buffer_->SetOffset(0);
ASSERT_EQ(test_data_size_, output_buffer_->size());
EXPECT_EQ(0, memcmp(output_buffer_->data(),
input_buffer_->StartOfBuffer(), test_data_size_));
}
void StreamConnectionTester::Done() {
done_ = true;
message_loop_->PostTask(FROM_HERE, MessageLoop::QuitClosure());
}
void StreamConnectionTester::InitBuffers() {
output_buffer_ = new net::DrainableIOBuffer(
new net::IOBuffer(test_data_size_), test_data_size_);
input_buffer_ = new net::GrowableIOBuffer();
}
void StreamConnectionTester::DoWrite() {
int result = 1;
while (result > 0) {
if (output_buffer_->BytesRemaining() == 0)
break;
int bytes_to_write = std::min(output_buffer_->BytesRemaining(),
message_size_);
result = client_socket_->Write(
output_buffer_, bytes_to_write,
base::Bind(&StreamConnectionTester::OnWritten, base::Unretained(this)));
HandleWriteResult(result);
}
}
void StreamConnectionTester::OnWritten(int result) {
HandleWriteResult(result);
DoWrite();
}
void StreamConnectionTester::HandleWriteResult(int result) {
if (result <= 0 && result != net::ERR_IO_PENDING) {
LOG(ERROR) << "Received error " << result << " when trying to write";
write_errors_++;
Done();
} else if (result > 0) {
output_buffer_->DidConsume(result);
}
}
void StreamConnectionTester::DoRead() {
int result = 1;
while (result > 0) {
input_buffer_->SetCapacity(input_buffer_->offset() + message_size_);
result = host_socket_->Read(
input_buffer_, message_size_,
base::Bind(&StreamConnectionTester::OnRead, base::Unretained(this)));
HandleReadResult(result);
};
}
void StreamConnectionTester::OnRead(int result) {
HandleReadResult(result);
if (!done_)
DoRead(); // Don't try to read again when we are done reading.
}
void StreamConnectionTester::HandleReadResult(int result) {
if (result <= 0 && result != net::ERR_IO_PENDING) {
LOG(ERROR) << "Received error " << result << " when trying to read";
read_errors_++;
Done();
} else if (result > 0) {
// Allocate memory for the next read.
input_buffer_->set_offset(input_buffer_->offset() + result);
if (input_buffer_->offset() == test_data_size_)
Done();
}
}
DatagramConnectionTester::DatagramConnectionTester(net::Socket* client_socket,
net::Socket* host_socket,
int message_size,
int message_count,
int delay_ms)
: message_loop_(MessageLoop::current()),
host_socket_(host_socket),
client_socket_(client_socket),
message_size_(message_size),
message_count_(message_count),
delay_ms_(delay_ms),
done_(false),
write_errors_(0),
read_errors_(0),
packets_sent_(0),
packets_received_(0),
bad_packets_received_(0) {
sent_packets_.resize(message_count_);
}
DatagramConnectionTester::~DatagramConnectionTester() {
}
void DatagramConnectionTester::Start() {
DoRead();
DoWrite();
}
void DatagramConnectionTester::CheckResults() {
EXPECT_EQ(0, write_errors_);
EXPECT_EQ(0, read_errors_);
EXPECT_EQ(0, bad_packets_received_);
// Verify that we've received at least one packet.
EXPECT_GT(packets_received_, 0);
LOG(INFO) << "Received " << packets_received_ << " packets out of "
<< message_count_;
}
void DatagramConnectionTester::Done() {
done_ = true;
message_loop_->PostTask(FROM_HERE, MessageLoop::QuitClosure());
}
void DatagramConnectionTester::DoWrite() {
if (packets_sent_ >= message_count_) {
Done();
return;
}
scoped_refptr<net::IOBuffer> packet(new net::IOBuffer(message_size_));
memset(packet->data(), 123, message_size_);
sent_packets_[packets_sent_] = packet;
// Put index of this packet in the beginning of the packet body.
memcpy(packet->data(), &packets_sent_, sizeof(packets_sent_));
int result = client_socket_->Write(
packet, message_size_,
base::Bind(&DatagramConnectionTester::OnWritten, base::Unretained(this)));
HandleWriteResult(result);
}
void DatagramConnectionTester::OnWritten(int result) {
HandleWriteResult(result);
}
void DatagramConnectionTester::HandleWriteResult(int result) {
if (result <= 0 && result != net::ERR_IO_PENDING) {
LOG(ERROR) << "Received error " << result << " when trying to write";
write_errors_++;
Done();
} else if (result > 0) {
EXPECT_EQ(message_size_, result);
packets_sent_++;
message_loop_->PostDelayedTask(FROM_HERE, base::Bind(
&DatagramConnectionTester::DoWrite, base::Unretained(this)), delay_ms_);
}
}
void DatagramConnectionTester::DoRead() {
int result = 1;
while (result > 0) {
int kReadSize = message_size_ * 2;
read_buffer_ = new net::IOBuffer(kReadSize);
result = host_socket_->Read(
read_buffer_, kReadSize,
base::Bind(&DatagramConnectionTester::OnRead, base::Unretained(this)));
HandleReadResult(result);
};
}
void DatagramConnectionTester::OnRead(int result) {
HandleReadResult(result);
DoRead();
}
void DatagramConnectionTester::HandleReadResult(int result) {
if (result <= 0 && result != net::ERR_IO_PENDING) {
// Error will be received after the socket is closed.
LOG(ERROR) << "Received error " << result << " when trying to read";
read_errors_++;
Done();
} else if (result > 0) {
packets_received_++;
if (message_size_ != result) {
// Invalid packet size;
bad_packets_received_++;
} else {
// Validate packet body.
int packet_id;
memcpy(&packet_id, read_buffer_->data(), sizeof(packet_id));
if (packet_id < 0 || packet_id >= message_count_) {
bad_packets_received_++;
} else {
if (memcmp(read_buffer_->data(), sent_packets_[packet_id]->data(),
message_size_) != 0)
bad_packets_received_++;
}
}
}
}
} // namespace protocol
} // namespace remoting
// Copyright (c) 2011 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef REMOTING_PROTOCOL_CONNECTION_TESTER_H_
#define REMOTING_PROTOCOL_CONNECTION_TESTER_H_
#include <vector>
#include "base/memory/ref_counted.h"
class MessageLoop;
namespace net {
class DrainableIOBuffer;
class GrowableIOBuffer;
class IOBuffer;
class Socket;
class StreamSocket;
} // namespace net
namespace remoting {
namespace protocol {
// This class is used by unit tests to verify that a connection
// between two sockets works properly, i.e. data is delivered from one
// end to the other.
class StreamConnectionTester {
public:
StreamConnectionTester(net::StreamSocket* client_socket,
net::StreamSocket* host_socket,
int message_size,
int message_count);
~StreamConnectionTester();
void Start();
void CheckResults();
protected:
void Done();
void InitBuffers();
void DoWrite();
void OnWritten(int result);
void HandleWriteResult(int result);
void DoRead();
void OnRead(int result);
void HandleReadResult(int result);
private:
MessageLoop* message_loop_;
net::StreamSocket* host_socket_;
net::StreamSocket* client_socket_;
int message_size_;
int message_count_;
int test_data_size_;
bool done_;
scoped_refptr<net::DrainableIOBuffer> output_buffer_;
scoped_refptr<net::GrowableIOBuffer> input_buffer_;
int write_errors_;
int read_errors_;
};
class DatagramConnectionTester {
public:
DatagramConnectionTester(net::Socket* client_socket,
net::Socket* host_socket,
int message_size,
int message_count,
int delay_ms);
~DatagramConnectionTester() ;
void Start();
void CheckResults();
private:
void Done();
void DoWrite();
void OnWritten(int result);
void HandleWriteResult(int result);
void DoRead();
void OnRead(int result);
void HandleReadResult(int result);
MessageLoop* message_loop_;
net::Socket* host_socket_;
net::Socket* client_socket_;
int message_size_;
int message_count_;
int delay_ms_;
bool done_;
std::vector<scoped_refptr<net::IOBuffer> > sent_packets_;
scoped_refptr<net::IOBuffer> read_buffer_;
int write_errors_;
int read_errors_;
int packets_sent_;
int packets_received_;
int bad_packets_received_;
};
} // namespace protocol
} // namespace remoting
#endif // REMOTING_PROTOCOL_CONNECTION_TESTER_H_
...@@ -199,6 +199,7 @@ void ConnectionToHost::OnSessionStateChange( ...@@ -199,6 +199,7 @@ void ConnectionToHost::OnSessionStateChange(
CloseOnError(INCOMPATIBLE_PROTOCOL); CloseOnError(INCOMPATIBLE_PROTOCOL);
break; break;
case Session::CHANNEL_CONNECTION_ERROR: case Session::CHANNEL_CONNECTION_ERROR:
case Session::UNKNOWN_ERROR:
CloseOnError(NETWORK_FAILURE); CloseOnError(NETWORK_FAILURE);
break; break;
case Session::OK: case Session::OK:
......
// Copyright (c) 2011 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "remoting/protocol/fake_authenticator.h"
#include "base/message_loop.h"
#include "base/string_number_conversions.h"
#include "net/socket/stream_socket.h"
#include "remoting/base/constants.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "third_party/libjingle/source/talk/xmllite/xmlelement.h"
namespace remoting {
namespace protocol {
FakeChannelAuthenticator::FakeChannelAuthenticator(bool accept, bool async)
: accept_(accept),
async_(async),
ALLOW_THIS_IN_INITIALIZER_LIST(weak_factory_(this)) {
}
FakeChannelAuthenticator::~FakeChannelAuthenticator() {
}
void FakeChannelAuthenticator::SecureAndAuthenticate(
net::StreamSocket* socket, const DoneCallback& done_callback) {
net::Error error;
if (accept_) {
error = net::OK;
} else {
error = net::ERR_FAILED;
delete socket;
socket = NULL;
}
if (async_) {
MessageLoop::current()->PostTask(FROM_HERE, base::Bind(
&FakeChannelAuthenticator::CallCallback, weak_factory_.GetWeakPtr(),
done_callback, error, socket));
} else {
done_callback.Run(error, socket);
}
}
void FakeChannelAuthenticator::CallCallback(
const DoneCallback& done_callback,
net::Error error,
net::StreamSocket* socket) {
done_callback.Run(error, socket);
}
FakeAuthenticator::FakeAuthenticator(
Type type, int round_trips, Action action, bool async)
: type_(type),
round_trips_(round_trips),
action_(action),
async_(async),
messages_(0) {
}
FakeAuthenticator::~FakeAuthenticator() {
}
Authenticator::State FakeAuthenticator::state() const{
EXPECT_LE(messages_, round_trips_ * 2);
if (messages_ >= round_trips_ * 2) {
if (action_ == REJECT) {
return REJECTED;
} else {
return ACCEPTED;
}
}
// Don't send the last message if this is a host that wants to
// reject a connection.
if (messages_ == round_trips_ * 2 - 1 &&
type_ == HOST && action_ == REJECT) {
return REJECTED;
}
// We are not done yet. process next message.
if ((messages_ % 2 == 0 && type_ == CLIENT) ||
(messages_ % 2 == 1 && type_ == HOST)) {
return MESSAGE_READY;
} else {
return WAITING_MESSAGE;
}
}
void FakeAuthenticator::ProcessMessage(const buzz::XmlElement* message) {
EXPECT_EQ(WAITING_MESSAGE, state());
std::string id =
message->TextNamed(buzz::QName(kChromotingXmlNamespace, "id"));
EXPECT_EQ(id, base::IntToString(messages_));
++messages_;
}
buzz::XmlElement* FakeAuthenticator::GetNextMessage() {
EXPECT_EQ(MESSAGE_READY, state());
buzz::XmlElement* result = new buzz::XmlElement(
buzz::QName(kChromotingXmlNamespace, "authentication"));
buzz::XmlElement* id = new buzz::XmlElement(
buzz::QName(kChromotingXmlNamespace, "id"));
id->AddText(base::IntToString(messages_));
result->AddElement(id);
++messages_;
return result;
}
ChannelAuthenticator*
FakeAuthenticator::CreateChannelAuthenticator() const {
EXPECT_EQ(ACCEPTED, state());
return new FakeChannelAuthenticator(action_ != REJECT_CHANNEL, async_);
}
FakeHostAuthenticatorFactory::FakeHostAuthenticatorFactory(
int round_trips, FakeAuthenticator::Action action, bool async)
: round_trips_(round_trips),
action_(action), async_(async) {
}
FakeHostAuthenticatorFactory::~FakeHostAuthenticatorFactory() {
}
Authenticator* FakeHostAuthenticatorFactory::CreateAuthenticator(
const std::string& remote_jid,
const buzz::XmlElement* first_message) {
return new FakeAuthenticator(FakeAuthenticator::HOST, round_trips_,
action_, async_);
}
} // namespace protocol
} // namespace remoting
// Copyright (c) 2011 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef REMOTING_PROTOCOL_FAKE_AUTHENTICATOR_H_
#define REMOTING_PROTOCOL_FAKE_AUTHENTICATOR_H_
#include "base/memory/weak_ptr.h"
#include "remoting/protocol/authenticator.h"
#include "remoting/protocol/channel_authenticator.h"
namespace remoting {
namespace protocol {
class FakeChannelAuthenticator : public ChannelAuthenticator {
public:
FakeChannelAuthenticator(bool accept, bool async);
virtual ~FakeChannelAuthenticator();
// ChannelAuthenticator interface.
virtual void SecureAndAuthenticate(
net::StreamSocket* socket, const DoneCallback& done_callback) OVERRIDE;
private:
void CallCallback(
const DoneCallback& done_callback,
net::Error error,
net::StreamSocket* socket);
bool accept_;
bool async_;
base::WeakPtrFactory<FakeChannelAuthenticator> weak_factory_;
DISALLOW_COPY_AND_ASSIGN(FakeChannelAuthenticator);
};
class FakeAuthenticator : public Authenticator {
public:
enum Type {
HOST,
CLIENT,
};
enum Action {
ACCEPT,
REJECT,
REJECT_CHANNEL
};
FakeAuthenticator(Type type, int round_trips, Action action, bool async);
virtual ~FakeAuthenticator();
// Authenticator interface.
virtual State state() const OVERRIDE;
virtual void ProcessMessage(const buzz::XmlElement* message) OVERRIDE;
virtual buzz::XmlElement* GetNextMessage() OVERRIDE;
virtual ChannelAuthenticator* CreateChannelAuthenticator() const OVERRIDE;
protected:
Type type_;
int round_trips_;
Action action_;
bool async_;
// Total number of messages that have been processed.
int messages_;
DISALLOW_COPY_AND_ASSIGN(FakeAuthenticator);
};
class FakeHostAuthenticatorFactory : public AuthenticatorFactory {
public:
FakeHostAuthenticatorFactory(
int round_trips, FakeAuthenticator::Action action, bool async);
virtual ~FakeHostAuthenticatorFactory();
// AuthenticatorFactory interface.
virtual Authenticator* CreateAuthenticator(
const std::string& remote_jid,
const buzz::XmlElement* first_message) OVERRIDE;
private:
int round_trips_;
FakeAuthenticator::Action action_;
bool async_;
DISALLOW_COPY_AND_ASSIGN(FakeHostAuthenticatorFactory);
};
} // namespace protocol
} // namespace remoting
#endif // REMOTING_PROTOCOL_FAKE_AUTHENTICATOR_H_
...@@ -4,9 +4,12 @@ ...@@ -4,9 +4,12 @@
#include "remoting/protocol/fake_session.h" #include "remoting/protocol/fake_session.h"
#include "base/bind.h"
#include "base/message_loop.h" #include "base/message_loop.h"
#include "net/base/address_list.h"
#include "net/base/io_buffer.h" #include "net/base/io_buffer.h"
#include "net/base/net_errors.h" #include "net/base/net_errors.h"
#include "net/base/net_util.h"
#include "testing/gtest/include/gtest/gtest.h" #include "testing/gtest/include/gtest/gtest.h"
namespace remoting { namespace remoting {
...@@ -16,17 +19,19 @@ const char kTestJid[] = "host1@gmail.com/chromoting123"; ...@@ -16,17 +19,19 @@ const char kTestJid[] = "host1@gmail.com/chromoting123";
FakeSocket::FakeSocket() FakeSocket::FakeSocket()
: read_pending_(false), : read_pending_(false),
read_buffer_size_(0),
input_pos_(0), input_pos_(0),
message_loop_(MessageLoop::current()) { message_loop_(MessageLoop::current()),
ALLOW_THIS_IN_INITIALIZER_LIST(weak_factory_(this)) {
} }
FakeSocket::~FakeSocket() { FakeSocket::~FakeSocket() {
EXPECT_EQ(message_loop_, MessageLoop::current()); EXPECT_EQ(message_loop_, MessageLoop::current());
} }
void FakeSocket::AppendInputData(const char* data, int data_size) { void FakeSocket::AppendInputData(const std::vector<char>& data) {
EXPECT_EQ(message_loop_, MessageLoop::current()); EXPECT_EQ(message_loop_, MessageLoop::current());
input_data_.insert(input_data_.end(), data, data + data_size); input_data_.insert(input_data_.end(), data.begin(), data.end());
// Complete pending read if any. // Complete pending read if any.
if (read_pending_) { if (read_pending_) {
read_pending_ = false; read_pending_ = false;
...@@ -36,11 +41,17 @@ void FakeSocket::AppendInputData(const char* data, int data_size) { ...@@ -36,11 +41,17 @@ void FakeSocket::AppendInputData(const char* data, int data_size) {
memcpy(read_buffer_->data(), memcpy(read_buffer_->data(),
&(*input_data_.begin()) + input_pos_, result); &(*input_data_.begin()) + input_pos_, result);
input_pos_ += result; input_pos_ += result;
read_callback_.Run(result);
read_buffer_ = NULL; read_buffer_ = NULL;
read_callback_.Run(result);
} }
} }
void FakeSocket::PairWith(FakeSocket* peer_socket) {
EXPECT_EQ(message_loop_, MessageLoop::current());
peer_socket_ = peer_socket->weak_factory_.GetWeakPtr();
peer_socket->peer_socket_ = weak_factory_.GetWeakPtr();
}
int FakeSocket::Read(net::IOBuffer* buf, int buf_len, int FakeSocket::Read(net::IOBuffer* buf, int buf_len,
const net::CompletionCallback& callback) { const net::CompletionCallback& callback) {
EXPECT_EQ(message_loop_, MessageLoop::current()); EXPECT_EQ(message_loop_, MessageLoop::current());
...@@ -64,6 +75,13 @@ int FakeSocket::Write(net::IOBuffer* buf, int buf_len, ...@@ -64,6 +75,13 @@ int FakeSocket::Write(net::IOBuffer* buf, int buf_len,
EXPECT_EQ(message_loop_, MessageLoop::current()); EXPECT_EQ(message_loop_, MessageLoop::current());
written_data_.insert(written_data_.end(), written_data_.insert(written_data_.end(),
buf->data(), buf->data() + buf_len); buf->data(), buf->data() + buf_len);
if (peer_socket_) {
message_loop_->PostTask(FROM_HERE, base::Bind(
&FakeSocket::AppendInputData, peer_socket_,
std::vector<char>(buf->data(), buf->data() + buf_len)));
}
return buf_len; return buf_len;
} }
...@@ -82,7 +100,7 @@ int FakeSocket::Connect(const net::CompletionCallback& callback) { ...@@ -82,7 +100,7 @@ int FakeSocket::Connect(const net::CompletionCallback& callback) {
} }
void FakeSocket::Disconnect() { void FakeSocket::Disconnect() {
NOTIMPLEMENTED(); peer_socket_.reset();
} }
bool FakeSocket::IsConnected() const { bool FakeSocket::IsConnected() const {
...@@ -95,10 +113,11 @@ bool FakeSocket::IsConnectedAndIdle() const { ...@@ -95,10 +113,11 @@ bool FakeSocket::IsConnectedAndIdle() const {
return false; return false;
} }
int FakeSocket::GetPeerAddress( int FakeSocket::GetPeerAddress(net::AddressList* address) const {
net::AddressList* address) const { net::IPAddressNumber ip;
NOTIMPLEMENTED(); ip.resize(net::kIPv4AddressSize);
return net::ERR_FAILED; *address = net::AddressList::CreateFromIPAddress(ip, 0);
return net::OK;
} }
int FakeSocket::GetLocalAddress( int FakeSocket::GetLocalAddress(
......
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include <vector> #include <vector>
#include "base/memory/scoped_ptr.h" #include "base/memory/scoped_ptr.h"
#include "base/memory/weak_ptr.h"
#include "net/base/completion_callback.h" #include "net/base/completion_callback.h"
#include "net/socket/socket.h" #include "net/socket/socket.h"
#include "net/socket/stream_socket.h" #include "net/socket/stream_socket.h"
...@@ -27,6 +28,11 @@ extern const char kTestJid[]; ...@@ -27,6 +28,11 @@ extern const char kTestJid[];
// Read() reads data from another buffer that can be set with AppendInputData(). // Read() reads data from another buffer that can be set with AppendInputData().
// Pending reads are supported, so if there is a pending read AppendInputData() // Pending reads are supported, so if there is a pending read AppendInputData()
// calls the read callback. // calls the read callback.
//
// Two fake sockets can be connected to each other using the
// PairWith() method, e.g.: a->PairWith(b). After this all data
// written to |a| can be read from |b| and vica versa. Two connected
// sockets |a| and |b| must be created and used on the same thread.
class FakeSocket : public net::StreamSocket { class FakeSocket : public net::StreamSocket {
public: public:
FakeSocket(); FakeSocket();
...@@ -34,7 +40,8 @@ class FakeSocket : public net::StreamSocket { ...@@ -34,7 +40,8 @@ class FakeSocket : public net::StreamSocket {
const std::string& written_data() const { return written_data_; } const std::string& written_data() const { return written_data_; }
void AppendInputData(const char* data, int data_size); void AppendInputData(const std::vector<char>& data);
void PairWith(FakeSocket* peer_socket);
int input_pos() const { return input_pos_; } int input_pos() const { return input_pos_; }
bool read_pending() const { return read_pending_; } bool read_pending() const { return read_pending_; }
...@@ -47,7 +54,7 @@ class FakeSocket : public net::StreamSocket { ...@@ -47,7 +54,7 @@ class FakeSocket : public net::StreamSocket {
virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; virtual bool SetReceiveBufferSize(int32 size) OVERRIDE;
virtual bool SetSendBufferSize(int32 size) OVERRIDE; virtual bool SetSendBufferSize(int32 size) OVERRIDE;
// net::StreamSocket implementation. // net::StreamSocket interface.
virtual int Connect(const net::CompletionCallback& callback) OVERRIDE; virtual int Connect(const net::CompletionCallback& callback) OVERRIDE;
virtual void Disconnect() OVERRIDE; virtual void Disconnect() OVERRIDE;
virtual bool IsConnected() const OVERRIDE; virtual bool IsConnected() const OVERRIDE;
...@@ -67,6 +74,7 @@ class FakeSocket : public net::StreamSocket { ...@@ -67,6 +74,7 @@ class FakeSocket : public net::StreamSocket {
scoped_refptr<net::IOBuffer> read_buffer_; scoped_refptr<net::IOBuffer> read_buffer_;
int read_buffer_size_; int read_buffer_size_;
net::CompletionCallback read_callback_; net::CompletionCallback read_callback_;
base::WeakPtr<FakeSocket> peer_socket_;
std::string written_data_; std::string written_data_;
std::string input_data_; std::string input_data_;
...@@ -75,6 +83,7 @@ class FakeSocket : public net::StreamSocket { ...@@ -75,6 +83,7 @@ class FakeSocket : public net::StreamSocket {
net::BoundNetLog net_log_; net::BoundNetLog net_log_;
MessageLoop* message_loop_; MessageLoop* message_loop_;
base::WeakPtrFactory<FakeSocket> weak_factory_;
DISALLOW_COPY_AND_ASSIGN(FakeSocket); DISALLOW_COPY_AND_ASSIGN(FakeSocket);
}; };
......
...@@ -45,6 +45,8 @@ JingleSession::JingleSession( ...@@ -45,6 +45,8 @@ JingleSession::JingleSession(
jid_ = cricket_session_->remote_name(); jid_ = cricket_session_->remote_name();
cricket_session_->SignalState.connect(this, &JingleSession::OnSessionState); cricket_session_->SignalState.connect(this, &JingleSession::OnSessionState);
cricket_session_->SignalError.connect(this, &JingleSession::OnSessionError); cricket_session_->SignalError.connect(this, &JingleSession::OnSessionError);
cricket_session_->SignalReceivedTerminateReason.connect(
this, &JingleSession::OnTerminateReason);
} }
JingleSession::~JingleSession() { JingleSession::~JingleSession() {
...@@ -244,6 +246,11 @@ void JingleSession::OnSessionError( ...@@ -244,6 +246,11 @@ void JingleSession::OnSessionError(
} }
} }
void JingleSession::OnTerminateReason(cricket::Session* session,
const std::string& reason) {
terminate_reason_ = reason;
}
void JingleSession::OnInitiate() { void JingleSession::OnInitiate() {
DCHECK(CalledOnValidThread()); DCHECK(CalledOnValidThread());
jid_ = cricket_session_->remote_name(); jid_ = cricket_session_->remote_name();
...@@ -332,7 +339,16 @@ void JingleSession::OnAccept() { ...@@ -332,7 +339,16 @@ void JingleSession::OnAccept() {
void JingleSession::OnTerminate() { void JingleSession::OnTerminate() {
DCHECK(CalledOnValidThread()); DCHECK(CalledOnValidThread());
CloseInternal(net::ERR_CONNECTION_ABORTED, OK);
if (terminate_reason_ == "success") {
CloseInternal(net::ERR_CONNECTION_ABORTED, OK);
} else if (terminate_reason_ == "decline") {
CloseInternal(net::ERR_CONNECTION_ABORTED, AUTHENTICATION_FAILED);
} else if (terminate_reason_ == "incompatible-protocol") {
CloseInternal(net::ERR_CONNECTION_ABORTED, INCOMPATIBLE_PROTOCOL);
} else {
CloseInternal(net::ERR_CONNECTION_ABORTED, UNKNOWN_ERROR);
}
} }
void JingleSession::AcceptConnection() { void JingleSession::AcceptConnection() {
......
...@@ -80,6 +80,8 @@ class JingleSession : public protocol::Session, ...@@ -80,6 +80,8 @@ class JingleSession : public protocol::Session,
// Used for Session.SignalError sigslot. // Used for Session.SignalError sigslot.
void OnSessionError(cricket::BaseSession* session, void OnSessionError(cricket::BaseSession* session,
cricket::BaseSession::Error error); cricket::BaseSession::Error error);
// Used for Session.SignalReceivedTerminateReason sigslot.
void OnTerminateReason(cricket::Session* session, const std::string& reason);
void OnInitiate(); void OnInitiate();
void OnAccept(); void OnAccept();
...@@ -135,6 +137,11 @@ class JingleSession : public protocol::Session, ...@@ -135,6 +137,11 @@ class JingleSession : public protocol::Session,
// Channels that are currently being connected. // Channels that are currently being connected.
ChannelConnectorsMap channel_connectors_; ChannelConnectorsMap channel_connectors_;
// Termination reason. Needs to be stored because
// SignalReceivedTerminateReason handler is not allowed to destroy
// the object.
std::string terminate_reason_;
ScopedRunnableMethodFactory<JingleSession> task_factory_; ScopedRunnableMethodFactory<JingleSession> task_factory_;
DISALLOW_COPY_AND_ASSIGN(JingleSession); DISALLOW_COPY_AND_ASSIGN(JingleSession);
......
...@@ -68,7 +68,7 @@ class MessageReaderTest : public testing::Test { ...@@ -68,7 +68,7 @@ class MessageReaderTest : public testing::Test {
std::string data = std::string(4, ' ') + message; std::string data = std::string(4, ' ') + message;
talk_base::SetBE32(const_cast<char*>(data.data()), message.size()); talk_base::SetBE32(const_cast<char*>(data.data()), message.size());
socket_.AppendInputData(data.data(), data.size()); socket_.AppendInputData(std::vector<char>(data.begin(), data.end()));
} }
bool CompareResult(CompoundBuffer* buffer, const std::string& expected) { bool CompareResult(CompoundBuffer* buffer, const std::string& expected) {
......
...@@ -60,6 +60,7 @@ class Session : public base::NonThreadSafe { ...@@ -60,6 +60,7 @@ class Session : public base::NonThreadSafe {
INCOMPATIBLE_PROTOCOL, INCOMPATIBLE_PROTOCOL,
AUTHENTICATION_FAILED, AUTHENTICATION_FAILED,
CHANNEL_CONNECTION_ERROR, CHANNEL_CONNECTION_ERROR,
UNKNOWN_ERROR,
}; };
// State change callbacks are called after session state has // State change callbacks are called after session state has
......
...@@ -2,24 +2,47 @@ ...@@ -2,24 +2,47 @@
// Use of this source code is governed by a BSD-style license that can be // Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file. // found in the LICENSE file.
#include "remoting/protocol/v1_authenticator.h"
#include "base/bind.h"
#include "base/file_path.h" #include "base/file_path.h"
#include "base/file_util.h" #include "base/file_util.h"
#include "base/message_loop.h"
#include "base/path_service.h" #include "base/path_service.h"
#include "crypto/rsa_private_key.h" #include "crypto/rsa_private_key.h"
#include "remoting/protocol/v1_authenticator.h" #include "net/base/net_errors.h"
#include "remoting/protocol/authenticator.h"
#include "remoting/protocol/channel_authenticator.h"
#include "remoting/protocol/connection_tester.h"
#include "remoting/protocol/fake_session.h"
#include "remoting/protocol/v1_client_channel_authenticator.h"
#include "remoting/protocol/v1_host_channel_authenticator.h"
#include "testing/gmock/include/gmock/gmock.h" #include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h" #include "testing/gtest/include/gtest/gtest.h"
#include "third_party/libjingle/source/talk/xmllite/xmlelement.h" #include "third_party/libjingle/source/talk/xmllite/xmlelement.h"
using testing::_;
using testing::DeleteArg;
using testing::SaveArg;
namespace remoting { namespace remoting {
namespace protocol { namespace protocol {
namespace { namespace {
const char kHostJid[] = "host1@gmail.com/123";
const int kMessageSize = 100;
const int kMessages = 1;
const char kClientJid[] = "host2@gmail.com/321"; const char kClientJid[] = "host2@gmail.com/321";
const char kTestSharedSecret[] = "1234-1234-5678"; const char kTestSharedSecret[] = "1234-1234-5678";
const char kTestSharedSecretBad[] = "0000-0000-0001"; const char kTestSharedSecretBad[] = "0000-0000-0001";
class MockChannelDoneCallback {
public:
MOCK_METHOD2(OnDone, void(net::Error error, net::StreamSocket* socket));
};
} // namespace } // namespace
class V1AuthenticatorTest : public testing::Test { class V1AuthenticatorTest : public testing::Test {
...@@ -30,8 +53,7 @@ class V1AuthenticatorTest : public testing::Test { ...@@ -30,8 +53,7 @@ class V1AuthenticatorTest : public testing::Test {
} }
protected: protected:
void InitAuthenticators(const std::string& client_secret, virtual void SetUp() OVERRIDE {
const std::string& host_secret) {
FilePath certs_dir; FilePath certs_dir;
PathService::Get(base::DIR_SOURCE_ROOT, &certs_dir); PathService::Get(base::DIR_SOURCE_ROOT, &certs_dir);
certs_dir = certs_dir.AppendASCII("net"); certs_dir = certs_dir.AppendASCII("net");
...@@ -40,8 +62,7 @@ class V1AuthenticatorTest : public testing::Test { ...@@ -40,8 +62,7 @@ class V1AuthenticatorTest : public testing::Test {
certs_dir = certs_dir.AppendASCII("certificates"); certs_dir = certs_dir.AppendASCII("certificates");
FilePath cert_path = certs_dir.AppendASCII("unittest.selfsigned.der"); FilePath cert_path = certs_dir.AppendASCII("unittest.selfsigned.der");
std::string cert_der; ASSERT_TRUE(file_util::ReadFileToString(cert_path, &host_cert_));
ASSERT_TRUE(file_util::ReadFileToString(cert_path, &cert_der));
FilePath key_path = certs_dir.AppendASCII("unittest.key.bin"); FilePath key_path = certs_dir.AppendASCII("unittest.key.bin");
std::string key_string; std::string key_string;
...@@ -52,9 +73,12 @@ class V1AuthenticatorTest : public testing::Test { ...@@ -52,9 +73,12 @@ class V1AuthenticatorTest : public testing::Test {
key_string.length())); key_string.length()));
private_key_.reset( private_key_.reset(
crypto::RSAPrivateKey::CreateFromPrivateKeyInfo(key_vector)); crypto::RSAPrivateKey::CreateFromPrivateKeyInfo(key_vector));
}
void InitAuthenticators(const std::string& client_secret,
const std::string& host_secret) {
host_.reset(new V1HostAuthenticator( host_.reset(new V1HostAuthenticator(
cert_der, private_key_.get(), host_secret, kClientJid)); host_cert_, private_key_.get(), host_secret, kClientJid));
client_.reset(new V1ClientAuthenticator(kClientJid, client_secret)); client_.reset(new V1ClientAuthenticator(kClientJid, client_secret));
} }
...@@ -91,9 +115,53 @@ class V1AuthenticatorTest : public testing::Test { ...@@ -91,9 +115,53 @@ class V1AuthenticatorTest : public testing::Test {
host_->state() != Authenticator::REJECTED); host_->state() != Authenticator::REJECTED);
} }
void RunChannelAuth(bool expected_fail) {
client_fake_socket_.reset(new FakeSocket());
host_fake_socket_.reset(new FakeSocket());
client_fake_socket_->PairWith(host_fake_socket_.get());
client_auth_->SecureAndAuthenticate(
client_fake_socket_.release(),
base::Bind(&MockChannelDoneCallback::OnDone,
base::Unretained(&client_callback_)));
host_auth_->SecureAndAuthenticate(
host_fake_socket_.release(),
base::Bind(&MockChannelDoneCallback::OnDone,
base::Unretained(&host_callback_)));
net::StreamSocket* client_socket = NULL;
net::StreamSocket* host_socket = NULL;
EXPECT_CALL(client_callback_, OnDone(net::OK, _))
.WillOnce(SaveArg<1>(&client_socket));
if (expected_fail) {
EXPECT_CALL(host_callback_, OnDone(net::ERR_FAILED, NULL));
} else {
EXPECT_CALL(host_callback_, OnDone(net::OK, _))
.WillOnce(SaveArg<1>(&host_socket));
}
message_loop_.RunAllPending();
client_socket_.reset(client_socket);
host_socket_.reset(host_socket);
}
MessageLoop message_loop_;
scoped_ptr<crypto::RSAPrivateKey> private_key_; scoped_ptr<crypto::RSAPrivateKey> private_key_;
std::string host_cert_;
scoped_ptr<V1HostAuthenticator> host_; scoped_ptr<V1HostAuthenticator> host_;
scoped_ptr<V1ClientAuthenticator> client_; scoped_ptr<V1ClientAuthenticator> client_;
scoped_ptr<FakeSocket> client_fake_socket_;
scoped_ptr<FakeSocket> host_fake_socket_;
scoped_ptr<ChannelAuthenticator> client_auth_;
scoped_ptr<ChannelAuthenticator> host_auth_;
MockChannelDoneCallback client_callback_;
MockChannelDoneCallback host_callback_;
scoped_ptr<net::StreamSocket> client_socket_;
scoped_ptr<net::StreamSocket> host_socket_;
DISALLOW_COPY_AND_ASSIGN(V1AuthenticatorTest); DISALLOW_COPY_AND_ASSIGN(V1AuthenticatorTest);
}; };
...@@ -106,8 +174,23 @@ TEST_F(V1AuthenticatorTest, SuccessfulAuth) { ...@@ -106,8 +174,23 @@ TEST_F(V1AuthenticatorTest, SuccessfulAuth) {
} }
ASSERT_EQ(Authenticator::ACCEPTED, host_->state()); ASSERT_EQ(Authenticator::ACCEPTED, host_->state());
ASSERT_EQ(Authenticator::ACCEPTED, client_->state()); ASSERT_EQ(Authenticator::ACCEPTED, client_->state());
client_auth_.reset(client_->CreateChannelAuthenticator());
host_auth_.reset(host_->CreateChannelAuthenticator());
RunChannelAuth(false);
EXPECT_TRUE(client_socket_.get() != NULL);
EXPECT_TRUE(host_socket_.get() != NULL);
StreamConnectionTester tester(host_socket_.get(), client_socket_.get(),
kMessageSize, kMessages);
tester.Start();
message_loop_.Run();
tester.CheckResults();
} }
// Verify that connection is rejected when secrets don't match.
TEST_F(V1AuthenticatorTest, InvalidSecret) { TEST_F(V1AuthenticatorTest, InvalidSecret) {
{ {
SCOPED_TRACE("RunAuthExchange"); SCOPED_TRACE("RunAuthExchange");
...@@ -117,5 +200,17 @@ TEST_F(V1AuthenticatorTest, InvalidSecret) { ...@@ -117,5 +200,17 @@ TEST_F(V1AuthenticatorTest, InvalidSecret) {
ASSERT_EQ(Authenticator::REJECTED, host_->state()); ASSERT_EQ(Authenticator::REJECTED, host_->state());
} }
// Verify that channels cannot be using invalid shared secret.
TEST_F(V1AuthenticatorTest, InvalidChannelSecret) {
client_auth_.reset(new V1ClientChannelAuthenticator(
host_cert_, kTestSharedSecretBad));
host_auth_.reset(new V1HostChannelAuthenticator(
host_cert_, private_key_.get(),kTestSharedSecret));
RunChannelAuth(true);
EXPECT_TRUE(host_socket_.get() == NULL);
}
} // namespace protocol } // namespace protocol
} // namespace remoting } // namespace remoting
...@@ -25,7 +25,7 @@ namespace remoting { ...@@ -25,7 +25,7 @@ namespace remoting {
namespace protocol { namespace protocol {
class V1ClientChannelAuthenticator : public ChannelAuthenticator, class V1ClientChannelAuthenticator : public ChannelAuthenticator,
public base::NonThreadSafe { public base::NonThreadSafe {
public: public:
V1ClientChannelAuthenticator(const std::string& host_cert, V1ClientChannelAuthenticator(const std::string& host_cert,
const std::string& shared_secret); const std::string& shared_secret);
......
...@@ -25,13 +25,13 @@ namespace remoting { ...@@ -25,13 +25,13 @@ namespace remoting {
namespace protocol { namespace protocol {
class V1HostChannelAuthenticator : public ChannelAuthenticator, class V1HostChannelAuthenticator : public ChannelAuthenticator,
public base::NonThreadSafe { public base::NonThreadSafe {
public: public:
// Caller retains ownership of |local_private_key|. It must exist // Caller retains ownership of |local_private_key|. It must exist
// while this object exists. // while this object exists.
V1HostChannelAuthenticator(const std::string& local_cert, V1HostChannelAuthenticator(const std::string& local_cert,
crypto::RSAPrivateKey* local_private_key, crypto::RSAPrivateKey* local_private_key,
const std::string& shared_secret); const std::string& shared_secret);
virtual ~V1HostChannelAuthenticator(); virtual ~V1HostChannelAuthenticator();
// ChannelAuthenticator interface. // ChannelAuthenticator interface.
......
...@@ -926,7 +926,11 @@ ...@@ -926,7 +926,11 @@
'jingle_glue/jingle_thread_unittest.cc', 'jingle_glue/jingle_thread_unittest.cc',
'jingle_glue/mock_objects.cc', 'jingle_glue/mock_objects.cc',
'jingle_glue/mock_objects.h', 'jingle_glue/mock_objects.h',
'protocol/connection_tester.cc',
'protocol/connection_tester.h',
'protocol/connection_to_client_unittest.cc', 'protocol/connection_to_client_unittest.cc',
'protocol/fake_authenticator.cc',
'protocol/fake_authenticator.h',
'protocol/fake_session.cc', 'protocol/fake_session.cc',
'protocol/fake_session.h', 'protocol/fake_session.h',
'protocol/jingle_messages_unittest.cc', 'protocol/jingle_messages_unittest.cc',
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment