Commit dfcc8927 authored by sergeyu@chromium.org's avatar sergeyu@chromium.org

Implement ChannelMultiplexer.

ChannelMultiplexer allows multiple logical channels to share a 
single underlying transport channel.

BUG=137135

Review URL: https://chromiumcodereview.appspot.com/10830046

git-svn-id: svn://svn.chromium.org/chrome/trunk/src@150484 0039d316-1c4b-4281-b951-d872f2087c98
parent d7f7f753
......@@ -4,6 +4,7 @@
#include "remoting/host/server_log_entry.h"
#include "base/logging.h"
#include "base/sys_info.h"
#include "remoting/base/constants.h"
#include "remoting/protocol/session.h"
......
......@@ -15,6 +15,7 @@
'control.proto',
'event.proto',
'internal.proto',
'mux.proto',
'video.proto',
],
'variables': {
......
// Copyright (c) 2012 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//
// Protocol for the mux channel that multiplexes multiple channels.
syntax = "proto2";
option optimize_for = LITE_RUNTIME;
package remoting.protocol;
message MultiplexPacket {
// Channel ID. Each peer choses this value when it sends first packet to
// the other peer. It unique identified channel this packet belongs to.
// Channel ID is direction-specific, i.e. each channel has two IDs
// assigned to it: one for receiving and one for sending.
optional int32 channel_id = 1;
// Channel name. The name is used to identify channels before channel ID
// is assigned in the first message. This value must be included only
// in the first packet for a given channel. All other packets must be
// identified using channel ID.
optional string channel_name = 2;
optional bytes data = 3;
}
......@@ -55,7 +55,9 @@ bool BufferedSocketWriterBase::Write(
buffer_size_ += data->size();
DoWrite();
return true;
// DoWrite() may trigger OnWriteError() to be called.
return !closed_;
}
void BufferedSocketWriterBase::DoWrite() {
......
// Copyright (c) 2012 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef REMOTING_PROTOCOL_CHANNEL_FACTORY_H_
#define REMOTING_PROTOCOL_CHANNEL_FACTORY_H_
#include "base/callback.h"
#include "base/memory/scoped_ptr.h"
#include "base/threading/non_thread_safe.h"
namespace net {
class Socket;
class StreamSocket;
} // namespace net
namespace remoting {
namespace protocol {
class ChannelFactory : public base::NonThreadSafe {
public:
// TODO(sergeyu): Specify connection error code when channel
// connection fails.
typedef base::Callback<void(scoped_ptr<net::StreamSocket>)>
StreamChannelCallback;
typedef base::Callback<void(scoped_ptr<net::Socket>)>
DatagramChannelCallback;
ChannelFactory() {}
// Creates new channels for this connection. The specified callback is called
// when then new channel is created and connected. The callback is called with
// NULL if connection failed for any reason. Callback may be called
// synchronously, before the call returns. All channels must be destroyed
// before the factory is destroyed and CancelChannelCreation() must be called
// to cancel creation of channels for which the |callback| hasn't been called
// yet.
virtual void CreateStreamChannel(
const std::string& name, const StreamChannelCallback& callback) = 0;
virtual void CreateDatagramChannel(
const std::string& name, const DatagramChannelCallback& callback) = 0;
// Cancels a pending CreateStreamChannel() or CreateDatagramChannel()
// operation for the named channel. If the channel creation already
// completed then canceling it has no effect. When shutting down
// this method must be called for each channel pending creation.
virtual void CancelChannelCreation(const std::string& name) = 0;
protected:
virtual ~ChannelFactory() {}
private:
DISALLOW_COPY_AND_ASSIGN(ChannelFactory);
};
} // namespace protocol
} // namespace remoting
#endif // REMOTING_PROTOCOL_CHANNEL_FACTORY_H_
// Copyright (c) 2012 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "remoting/protocol/channel_multiplexer.h"
#include <string.h>
#include "base/bind.h"
#include "base/callback.h"
#include "base/location.h"
#include "base/stl_util.h"
#include "net/base/net_errors.h"
#include "net/socket/stream_socket.h"
#include "remoting/protocol/util.h"
namespace remoting {
namespace protocol {
namespace {
const int kChannelIdUnknown = -1;
const int kMaxPacketSize = 1024;
class PendingPacket {
public:
PendingPacket(scoped_ptr<MultiplexPacket> packet,
const base::Closure& done_task)
: packet(packet.Pass()),
done_task(done_task),
pos(0U) {
}
~PendingPacket() {
done_task.Run();
}
bool is_empty() { return pos >= packet->data().size(); }
int Read(char* buffer, size_t size) {
size = std::min(size, packet->data().size() - pos);
memcpy(buffer, packet->data().data() + pos, size);
pos += size;
return size;
}
private:
scoped_ptr<MultiplexPacket> packet;
base::Closure done_task;
size_t pos;
DISALLOW_COPY_AND_ASSIGN(PendingPacket);
};
} // namespace
const char ChannelMultiplexer::kMuxChannelName[] = "mux";
struct ChannelMultiplexer::PendingChannel {
PendingChannel(const std::string& name,
const StreamChannelCallback& callback)
: name(name), callback(callback) {
}
std::string name;
StreamChannelCallback callback;
};
class ChannelMultiplexer::MuxChannel {
public:
MuxChannel(ChannelMultiplexer* multiplexer, const std::string& name,
int send_id);
~MuxChannel();
const std::string& name() { return name_; }
int receive_id() { return receive_id_; }
void set_receive_id(int id) { receive_id_ = id; }
// Called by ChannelMultiplexer.
scoped_ptr<net::StreamSocket> CreateSocket();
void OnIncomingPacket(scoped_ptr<MultiplexPacket> packet,
const base::Closure& done_task);
void OnWriteFailed();
// Called by MuxSocket.
void OnSocketDestroyed();
bool DoWrite(scoped_ptr<MultiplexPacket> packet,
const base::Closure& done_task);
int DoRead(net::IOBuffer* buffer, int buffer_len);
private:
ChannelMultiplexer* multiplexer_;
std::string name_;
int send_id_;
bool id_sent_;
int receive_id_;
MuxSocket* socket_;
std::list<PendingPacket*> pending_packets_;
DISALLOW_COPY_AND_ASSIGN(MuxChannel);
};
class ChannelMultiplexer::MuxSocket : public net::StreamSocket,
public base::NonThreadSafe,
public base::SupportsWeakPtr<MuxSocket> {
public:
MuxSocket(MuxChannel* channel);
~MuxSocket();
void OnWriteComplete();
void OnWriteFailed();
void OnPacketReceived();
// net::StreamSocket interface.
virtual int Read(net::IOBuffer* buffer, int buffer_len,
const net::CompletionCallback& callback) OVERRIDE;
virtual int Write(net::IOBuffer* buffer, int buffer_len,
const net::CompletionCallback& callback) OVERRIDE;
virtual bool SetReceiveBufferSize(int32 size) OVERRIDE {
NOTIMPLEMENTED();
return false;
}
virtual bool SetSendBufferSize(int32 size) OVERRIDE {
NOTIMPLEMENTED();
return false;
}
virtual int Connect(const net::CompletionCallback& callback) OVERRIDE {
NOTIMPLEMENTED();
return net::ERR_FAILED;
}
virtual void Disconnect() OVERRIDE {
NOTIMPLEMENTED();
}
virtual bool IsConnected() const OVERRIDE {
NOTIMPLEMENTED();
return true;
}
virtual bool IsConnectedAndIdle() const OVERRIDE {
NOTIMPLEMENTED();
return false;
}
virtual int GetPeerAddress(net::IPEndPoint* address) const OVERRIDE {
NOTIMPLEMENTED();
return net::ERR_FAILED;
}
virtual int GetLocalAddress(net::IPEndPoint* address) const OVERRIDE {
NOTIMPLEMENTED();
return net::ERR_FAILED;
}
virtual const net::BoundNetLog& NetLog() const OVERRIDE {
NOTIMPLEMENTED();
return net_log_;
}
virtual void SetSubresourceSpeculation() OVERRIDE {
NOTIMPLEMENTED();
}
virtual void SetOmniboxSpeculation() OVERRIDE {
NOTIMPLEMENTED();
}
virtual bool WasEverUsed() const OVERRIDE {
return true;
}
virtual bool UsingTCPFastOpen() const OVERRIDE {
return false;
}
virtual int64 NumBytesRead() const OVERRIDE {
NOTIMPLEMENTED();
return 0;
}
virtual base::TimeDelta GetConnectTimeMicros() const OVERRIDE {
NOTIMPLEMENTED();
return base::TimeDelta();
}
virtual bool WasNpnNegotiated() const OVERRIDE {
return false;
}
virtual net::NextProto GetNegotiatedProtocol() const OVERRIDE {
return net::kProtoUnknown;
}
virtual bool GetSSLInfo(net::SSLInfo* ssl_info) OVERRIDE {
NOTIMPLEMENTED();
return false;
}
private:
MuxChannel* channel_;
net::CompletionCallback read_callback_;
scoped_refptr<net::IOBuffer> read_buffer_;
int read_buffer_size_;
bool write_pending_;
int write_result_;
net::CompletionCallback write_callback_;
net::BoundNetLog net_log_;
DISALLOW_COPY_AND_ASSIGN(MuxSocket);
};
ChannelMultiplexer::MuxChannel::MuxChannel(
ChannelMultiplexer* multiplexer,
const std::string& name,
int send_id)
: multiplexer_(multiplexer),
name_(name),
send_id_(send_id),
id_sent_(false),
receive_id_(kChannelIdUnknown),
socket_(NULL) {
}
ChannelMultiplexer::MuxChannel::~MuxChannel() {
// Socket must be destroyed before the channel.
DCHECK(!socket_);
STLDeleteElements(&pending_packets_);
}
scoped_ptr<net::StreamSocket> ChannelMultiplexer::MuxChannel::CreateSocket() {
DCHECK(!socket_); // Can't create more than one socket per channel.
scoped_ptr<MuxSocket> result(new MuxSocket(this));
socket_ = result.get();
return result.PassAs<net::StreamSocket>();
}
void ChannelMultiplexer::MuxChannel::OnIncomingPacket(
scoped_ptr<MultiplexPacket> packet,
const base::Closure& done_task) {
DCHECK_EQ(packet->channel_id(), receive_id_);
if (packet->data().size() > 0) {
pending_packets_.push_back(new PendingPacket(packet.Pass(), done_task));
if (socket_) {
// Notify the socket that we have more data.
socket_->OnPacketReceived();
}
}
}
void ChannelMultiplexer::MuxChannel::OnWriteFailed() {
if (socket_)
socket_->OnWriteFailed();
}
void ChannelMultiplexer::MuxChannel::OnSocketDestroyed() {
DCHECK(socket_);
socket_ = NULL;
}
bool ChannelMultiplexer::MuxChannel::DoWrite(
scoped_ptr<MultiplexPacket> packet,
const base::Closure& done_task) {
packet->set_channel_id(send_id_);
if (!id_sent_) {
packet->set_channel_name(name_);
id_sent_ = true;
}
return multiplexer_->DoWrite(packet.Pass(), done_task);
}
int ChannelMultiplexer::MuxChannel::DoRead(net::IOBuffer* buffer,
int buffer_len) {
int pos = 0;
while (buffer_len > 0 && !pending_packets_.empty()) {
DCHECK(!pending_packets_.front()->is_empty());
int result = pending_packets_.front()->Read(
buffer->data() + pos, buffer_len);
DCHECK_LE(result, buffer_len);
pos += result;
buffer_len -= pos;
if (pending_packets_.front()->is_empty()) {
delete pending_packets_.front();
pending_packets_.erase(pending_packets_.begin());
}
}
return pos;
}
ChannelMultiplexer::MuxSocket::MuxSocket(MuxChannel* channel)
: channel_(channel),
read_buffer_size_(0),
write_pending_(false),
write_result_(0) {
}
ChannelMultiplexer::MuxSocket::~MuxSocket() {
channel_->OnSocketDestroyed();
}
int ChannelMultiplexer::MuxSocket::Read(
net::IOBuffer* buffer, int buffer_len,
const net::CompletionCallback& callback) {
DCHECK(CalledOnValidThread());
DCHECK(read_callback_.is_null());
int result = channel_->DoRead(buffer, buffer_len);
if (result == 0) {
read_buffer_ = buffer;
read_buffer_size_ = buffer_len;
read_callback_ = callback;
return net::ERR_IO_PENDING;
}
return result;
}
int ChannelMultiplexer::MuxSocket::Write(
net::IOBuffer* buffer, int buffer_len,
const net::CompletionCallback& callback) {
DCHECK(CalledOnValidThread());
scoped_ptr<MultiplexPacket> packet(new MultiplexPacket());
size_t size = std::min(kMaxPacketSize, buffer_len);
packet->mutable_data()->assign(buffer->data(), size);
write_pending_ = true;
bool result = channel_->DoWrite(packet.Pass(), base::Bind(
&ChannelMultiplexer::MuxSocket::OnWriteComplete, AsWeakPtr()));
if (!result) {
// Cannot complete the write, e.g. if the connection has been terminated.
return net::ERR_FAILED;
}
// OnWriteComplete() might be called above synchronously.
if (write_pending_) {
DCHECK(write_callback_.is_null());
write_callback_ = callback;
write_result_ = size;
return net::ERR_IO_PENDING;
}
return size;
}
void ChannelMultiplexer::MuxSocket::OnWriteComplete() {
write_pending_ = false;
if (!write_callback_.is_null()) {
net::CompletionCallback cb;
std::swap(cb, write_callback_);
cb.Run(write_result_);
}
}
void ChannelMultiplexer::MuxSocket::OnWriteFailed() {
if (!write_callback_.is_null()) {
net::CompletionCallback cb;
std::swap(cb, write_callback_);
cb.Run(net::ERR_FAILED);
}
}
void ChannelMultiplexer::MuxSocket::OnPacketReceived() {
if (!read_callback_.is_null()) {
int result = channel_->DoRead(read_buffer_, read_buffer_size_);
read_buffer_ = NULL;
DCHECK_GT(result, 0);
net::CompletionCallback cb;
std::swap(cb, read_callback_);
cb.Run(result);
}
}
ChannelMultiplexer::ChannelMultiplexer(ChannelFactory* factory,
const std::string& base_channel_name)
: base_channel_factory_(factory),
base_channel_name_(base_channel_name),
next_channel_id_(0),
destroyed_flag_(NULL) {
factory->CreateStreamChannel(
base_channel_name,
base::Bind(&ChannelMultiplexer::OnBaseChannelReady,
base::Unretained(this)));
}
ChannelMultiplexer::~ChannelMultiplexer() {
DCHECK(pending_channels_.empty());
STLDeleteValues(&channels_);
// Cancel creation of the base channel if it hasn't finished.
if (base_channel_factory_)
base_channel_factory_->CancelChannelCreation(base_channel_name_);
if (destroyed_flag_)
*destroyed_flag_ = true;
}
void ChannelMultiplexer::CreateStreamChannel(
const std::string& name,
const StreamChannelCallback& callback) {
if (base_channel_.get()) {
// Already have |base_channel_|. Create new multiplexed channel
// synchronously.
callback.Run(GetOrCreateChannel(name)->CreateSocket());
} else if (!base_channel_.get() && !base_channel_factory_) {
// Fail synchronously if we failed to create |base_channel_|.
callback.Run(scoped_ptr<net::StreamSocket>());
} else {
// Still waiting for the |base_channel_|.
pending_channels_.push_back(PendingChannel(name, callback));
}
}
void ChannelMultiplexer::CreateDatagramChannel(
const std::string& name,
const DatagramChannelCallback& callback) {
NOTIMPLEMENTED();
callback.Run(scoped_ptr<net::Socket>());
}
void ChannelMultiplexer::CancelChannelCreation(const std::string& name) {
for (std::list<PendingChannel>::iterator it = pending_channels_.begin();
it != pending_channels_.end(); ++it) {
if (it->name == name) {
pending_channels_.erase(it);
return;
}
}
}
void ChannelMultiplexer::OnBaseChannelReady(
scoped_ptr<net::StreamSocket> socket) {
base_channel_factory_ = NULL;
base_channel_ = socket.Pass();
if (!base_channel_.get()) {
// Notify all callers that we can't create any channels.
for (std::list<PendingChannel>::iterator it = pending_channels_.begin();
it != pending_channels_.end(); ++it) {
it->callback.Run(scoped_ptr<net::StreamSocket>());
}
pending_channels_.clear();
return;
}
// Initialize reader and writer.
reader_.Init(base_channel_.get(),
base::Bind(&ChannelMultiplexer::OnIncomingPacket,
base::Unretained(this)));
writer_.Init(base_channel_.get(),
base::Bind(&ChannelMultiplexer::OnWriteFailed,
base::Unretained(this)));
// Now create all pending channels.
for (std::list<PendingChannel>::iterator it = pending_channels_.begin();
it != pending_channels_.end(); ++it) {
it->callback.Run(GetOrCreateChannel(it->name)->CreateSocket());
}
pending_channels_.clear();
}
ChannelMultiplexer::MuxChannel* ChannelMultiplexer::GetOrCreateChannel(
const std::string& name) {
// Check if we already have a channel with the requested name.
std::map<std::string, MuxChannel*>::iterator it = channels_.find(name);
if (it != channels_.end())
return it->second;
// Create a new channel if we haven't found existing one.
MuxChannel* channel = new MuxChannel(this, name, next_channel_id_);
++next_channel_id_;
channels_[channel->name()] = channel;
return channel;
}
void ChannelMultiplexer::OnWriteFailed(int error) {
bool destroyed = false;
destroyed_flag_ = &destroyed;
for (std::map<std::string, MuxChannel*>::iterator it = channels_.begin();
it != channels_.end(); ++it) {
it->second->OnWriteFailed();
if (destroyed)
return;
}
destroyed_flag_ = NULL;
}
void ChannelMultiplexer::OnIncomingPacket(scoped_ptr<MultiplexPacket> packet,
const base::Closure& done_task) {
if (!packet->has_channel_id()) {
LOG(ERROR) << "Received packet without channel_id.";
done_task.Run();
return;
}
int receive_id = packet->channel_id();
MuxChannel* channel = NULL;
std::map<int, MuxChannel*>::iterator it =
channels_by_receive_id_.find(receive_id);
if (it != channels_by_receive_id_.end()) {
channel = it->second;
} else {
// This is a new |channel_id| we haven't seen before. Look it up by name.
if (!packet->has_channel_name()) {
LOG(ERROR) << "Received packet with unknown channel_id and "
"without channel_name.";
done_task.Run();
return;
}
channel = GetOrCreateChannel(packet->channel_name());
channel->set_receive_id(receive_id);
channels_by_receive_id_[receive_id] = channel;
}
channel->OnIncomingPacket(packet.Pass(), done_task);
}
bool ChannelMultiplexer::DoWrite(scoped_ptr<MultiplexPacket> packet,
const base::Closure& done_task) {
return writer_.Write(SerializeAndFrameMessage(*packet), done_task);
}
} // namespace protocol
} // namespace remoting
// Copyright (c) 2012 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef REMOTING_PROTOCOL_CHANNEL_MULTIPLEXER_H_
#define REMOTING_PROTOCOL_CHANNEL_MULTIPLEXER_H_
#include "remoting/proto/mux.pb.h"
#include "remoting/protocol/buffered_socket_writer.h"
#include "remoting/protocol/channel_factory.h"
#include "remoting/protocol/message_reader.h"
namespace remoting {
namespace protocol {
class ChannelMultiplexer : public ChannelFactory {
public:
static const char kMuxChannelName[];
// |factory| is used to create the channel upon which to multiplex.
ChannelMultiplexer(ChannelFactory* factory,
const std::string& base_channel_name);
virtual ~ChannelMultiplexer();
// ChannelFactory interface.
virtual void CreateStreamChannel(
const std::string& name,
const StreamChannelCallback& callback) OVERRIDE;
virtual void CreateDatagramChannel(
const std::string& name,
const DatagramChannelCallback& callback) OVERRIDE;
virtual void CancelChannelCreation(const std::string& name) OVERRIDE;
private:
struct PendingChannel;
class MuxChannel;
class MuxSocket;
friend class MuxChannel;
// Callback for |base_channel_| creation.
void OnBaseChannelReady(scoped_ptr<net::StreamSocket> socket);
// Helper method used to create channels.
MuxChannel* GetOrCreateChannel(const std::string& name);
// Callbacks for |writer_| and |reader_|.
void OnWriteFailed(int error);
void OnIncomingPacket(scoped_ptr<MultiplexPacket> packet,
const base::Closure& done_task);
// Called by MuxChannel.
bool DoWrite(scoped_ptr<MultiplexPacket> packet,
const base::Closure& done_task);
// Factory used to create |base_channel_|. Set to NULL once creation is
// finished or failed.
ChannelFactory* base_channel_factory_;
// Name of the underlying channel.
std::string base_channel_name_;
// The channel over which to multiplex.
scoped_ptr<net::StreamSocket> base_channel_;
// List of requested channels while we are waiting for |base_channel_|.
std::list<PendingChannel> pending_channels_;
int next_channel_id_;
std::map<std::string, MuxChannel*> channels_;
// Channels are added to |channels_by_receive_id_| only after we receive
// receive_id from the remote peer.
std::map<int, MuxChannel*> channels_by_receive_id_;
BufferedSocketWriter writer_;
ProtobufMessageReader<MultiplexPacket> reader_;
// Flag used by OnWriteFailed() to detect when the multiplexer is destroyed.
bool* destroyed_flag_;
DISALLOW_COPY_AND_ASSIGN(ChannelMultiplexer);
};
} // namespace protocol
} // namespace remoting
#endif // REMOTING_PROTOCOL_CHANNEL_MULTIPLEXER_H_
// Copyright (c) 2012 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "remoting/protocol/channel_multiplexer.h"
#include "base/bind.h"
#include "base/message_loop.h"
#include "net/base/net_errors.h"
#include "net/socket/socket.h"
#include "net/socket/stream_socket.h"
#include "remoting/base/constants.h"
#include "remoting/protocol/connection_tester.h"
#include "remoting/protocol/fake_session.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
using testing::_;
using testing::AtMost;
using testing::InvokeWithoutArgs;
namespace remoting {
namespace protocol {
namespace {
const int kMessageSize = 1024;
const int kMessages = 100;
const char kMuxChannelName[] = "mux";
void QuitCurrentThread() {
MessageLoop::current()->PostTask(FROM_HERE, MessageLoop::QuitClosure());
}
class MockSocketCallback {
public:
MOCK_METHOD1(OnDone, void(int result));
};
} // namespace
class ChannelMultiplexerTest : public testing::Test {
public:
void DeleteAll() {
host_socket1_.reset();
host_socket2_.reset();
client_socket1_.reset();
client_socket2_.reset();
host_mux_.reset();
client_mux_.reset();
}
protected:
virtual void SetUp() OVERRIDE {
// Create pair of multiplexers and connect them to each other.
host_mux_.reset(new ChannelMultiplexer(&host_session_, kMuxChannelName));
client_mux_.reset(new ChannelMultiplexer(&client_session_,
kMuxChannelName));
FakeSocket* host_socket =
host_session_.GetStreamChannel(ChannelMultiplexer::kMuxChannelName);
FakeSocket* client_socket =
client_session_.GetStreamChannel(ChannelMultiplexer::kMuxChannelName);
host_socket->PairWith(client_socket);
// Make writes asynchronous in one direction.
host_socket->set_async_write(true);
}
void CreateChannel(const std::string& name,
scoped_ptr<net::StreamSocket>* host_socket,
scoped_ptr<net::StreamSocket>* client_socket) {
int counter = 2;
host_mux_->CreateStreamChannel(name, base::Bind(
&ChannelMultiplexerTest::OnChannelConnected, base::Unretained(this),
host_socket, &counter));
client_mux_->CreateStreamChannel(name, base::Bind(
&ChannelMultiplexerTest::OnChannelConnected, base::Unretained(this),
client_socket, &counter));
message_loop_.Run();
EXPECT_TRUE(host_socket->get());
EXPECT_TRUE(client_socket->get());
}
void OnChannelConnected(
scoped_ptr<net::StreamSocket>* storage,
int* counter,
scoped_ptr<net::StreamSocket> socket) {
*storage = socket.Pass();
--(*counter);
EXPECT_GE(*counter, 0);
if (*counter == 0)
QuitCurrentThread();
}
scoped_refptr<net::IOBufferWithSize> CreateTestBuffer(int size) {
scoped_refptr<net::IOBufferWithSize> result =
new net::IOBufferWithSize(size);
for (int i = 0; i< size; ++i) {
result->data()[i] = rand() % 256;
}
return result;
}
MessageLoop message_loop_;
FakeSession host_session_;
FakeSession client_session_;
scoped_ptr<ChannelMultiplexer> host_mux_;
scoped_ptr<ChannelMultiplexer> client_mux_;
scoped_ptr<net::StreamSocket> host_socket1_;
scoped_ptr<net::StreamSocket> client_socket1_;
scoped_ptr<net::StreamSocket> host_socket2_;
scoped_ptr<net::StreamSocket> client_socket2_;
};
TEST_F(ChannelMultiplexerTest, OneChannel) {
scoped_ptr<net::StreamSocket> host_socket;
scoped_ptr<net::StreamSocket> client_socket;
ASSERT_NO_FATAL_FAILURE(CreateChannel("test", &host_socket, &client_socket));
StreamConnectionTester tester(host_socket.get(), client_socket.get(),
kMessageSize, kMessages);
tester.Start();
message_loop_.Run();
tester.CheckResults();
}
TEST_F(ChannelMultiplexerTest, TwoChannels) {
scoped_ptr<net::StreamSocket> host_socket1_;
scoped_ptr<net::StreamSocket> client_socket1_;
ASSERT_NO_FATAL_FAILURE(
CreateChannel("test", &host_socket1_, &client_socket1_));
scoped_ptr<net::StreamSocket> host_socket2_;
scoped_ptr<net::StreamSocket> client_socket2_;
ASSERT_NO_FATAL_FAILURE(
CreateChannel("ch2", &host_socket2_, &client_socket2_));
StreamConnectionTester tester1(host_socket1_.get(), client_socket1_.get(),
kMessageSize, kMessages);
StreamConnectionTester tester2(host_socket2_.get(), client_socket2_.get(),
kMessageSize, kMessages);
tester1.Start();
tester2.Start();
while (!tester1.done() || !tester2.done()) {
message_loop_.Run();
}
tester1.CheckResults();
tester2.CheckResults();
}
// Four channels, two in each direction
TEST_F(ChannelMultiplexerTest, FourChannels) {
scoped_ptr<net::StreamSocket> host_socket1_;
scoped_ptr<net::StreamSocket> client_socket1_;
ASSERT_NO_FATAL_FAILURE(
CreateChannel("test", &host_socket1_, &client_socket1_));
scoped_ptr<net::StreamSocket> host_socket2_;
scoped_ptr<net::StreamSocket> client_socket2_;
ASSERT_NO_FATAL_FAILURE(
CreateChannel("ch2", &host_socket2_, &client_socket2_));
scoped_ptr<net::StreamSocket> host_socket3;
scoped_ptr<net::StreamSocket> client_socket3;
ASSERT_NO_FATAL_FAILURE(
CreateChannel("test3", &host_socket3, &client_socket3));
scoped_ptr<net::StreamSocket> host_socket4;
scoped_ptr<net::StreamSocket> client_socket4;
ASSERT_NO_FATAL_FAILURE(
CreateChannel("ch4", &host_socket4, &client_socket4));
StreamConnectionTester tester1(host_socket1_.get(), client_socket1_.get(),
kMessageSize, kMessages);
StreamConnectionTester tester2(host_socket2_.get(), client_socket2_.get(),
kMessageSize, kMessages);
StreamConnectionTester tester3(client_socket3.get(), host_socket3.get(),
kMessageSize, kMessages);
StreamConnectionTester tester4(client_socket4.get(), host_socket4.get(),
kMessageSize, kMessages);
tester1.Start();
tester2.Start();
tester3.Start();
tester4.Start();
while (!tester1.done() || !tester2.done() ||
!tester3.done() || !tester4.done()) {
message_loop_.Run();
}
tester1.CheckResults();
tester2.CheckResults();
tester3.CheckResults();
tester4.CheckResults();
}
TEST_F(ChannelMultiplexerTest, SyncFail) {
scoped_ptr<net::StreamSocket> host_socket1_;
scoped_ptr<net::StreamSocket> client_socket1_;
ASSERT_NO_FATAL_FAILURE(
CreateChannel("test", &host_socket1_, &client_socket1_));
scoped_ptr<net::StreamSocket> host_socket2_;
scoped_ptr<net::StreamSocket> client_socket2_;
ASSERT_NO_FATAL_FAILURE(
CreateChannel("ch2", &host_socket2_, &client_socket2_));
host_session_.GetStreamChannel(kMuxChannelName)->
set_next_write_error(net::ERR_FAILED);
host_session_.GetStreamChannel(kMuxChannelName)->
set_async_write(false);
scoped_refptr<net::IOBufferWithSize> buf = CreateTestBuffer(100);
MockSocketCallback cb1;
MockSocketCallback cb2;
EXPECT_CALL(cb1, OnDone(_))
.Times(0);
EXPECT_CALL(cb2, OnDone(_))
.Times(0);
EXPECT_EQ(net::ERR_FAILED, host_socket1_->Write(buf, buf->size(), base::Bind(
&MockSocketCallback::OnDone, base::Unretained(&cb1))));
EXPECT_EQ(net::ERR_FAILED, host_socket2_->Write(buf, buf->size(), base::Bind(
&MockSocketCallback::OnDone, base::Unretained(&cb2))));
message_loop_.RunAllPending();
}
TEST_F(ChannelMultiplexerTest, AsyncFail) {
ASSERT_NO_FATAL_FAILURE(
CreateChannel("test", &host_socket1_, &client_socket1_));
ASSERT_NO_FATAL_FAILURE(
CreateChannel("ch2", &host_socket2_, &client_socket2_));
host_session_.GetStreamChannel(kMuxChannelName)->
set_next_write_error(net::ERR_FAILED);
host_session_.GetStreamChannel(kMuxChannelName)->
set_async_write(true);
scoped_refptr<net::IOBufferWithSize> buf = CreateTestBuffer(100);
MockSocketCallback cb1;
MockSocketCallback cb2;
EXPECT_CALL(cb1, OnDone(net::ERR_FAILED));
EXPECT_CALL(cb2, OnDone(net::ERR_FAILED));
EXPECT_EQ(net::ERR_IO_PENDING,
host_socket1_->Write(buf, buf->size(), base::Bind(
&MockSocketCallback::OnDone, base::Unretained(&cb1))));
EXPECT_EQ(net::ERR_IO_PENDING,
host_socket2_->Write(buf, buf->size(), base::Bind(
&MockSocketCallback::OnDone, base::Unretained(&cb2))));
message_loop_.RunAllPending();
}
TEST_F(ChannelMultiplexerTest, DeleteWhenFailed) {
ASSERT_NO_FATAL_FAILURE(
CreateChannel("test", &host_socket1_, &client_socket1_));
ASSERT_NO_FATAL_FAILURE(
CreateChannel("ch2", &host_socket2_, &client_socket2_));
host_session_.GetStreamChannel(kMuxChannelName)->
set_next_write_error(net::ERR_FAILED);
host_session_.GetStreamChannel(kMuxChannelName)->
set_async_write(true);
scoped_refptr<net::IOBufferWithSize> buf = CreateTestBuffer(100);
MockSocketCallback cb1;
MockSocketCallback cb2;
EXPECT_CALL(cb1, OnDone(net::ERR_FAILED))
.Times(AtMost(1))
.WillOnce(InvokeWithoutArgs(this, &ChannelMultiplexerTest::DeleteAll));
EXPECT_CALL(cb2, OnDone(net::ERR_FAILED))
.Times(AtMost(1))
.WillOnce(InvokeWithoutArgs(this, &ChannelMultiplexerTest::DeleteAll));
EXPECT_EQ(net::ERR_IO_PENDING,
host_socket1_->Write(buf, buf->size(), base::Bind(
&MockSocketCallback::OnDone, base::Unretained(&cb1))));
EXPECT_EQ(net::ERR_IO_PENDING,
host_socket2_->Write(buf, buf->size(), base::Bind(
&MockSocketCallback::OnDone, base::Unretained(&cb2))));
message_loop_.RunAllPending();
// Check that the sockets were destroyed.
EXPECT_FALSE(host_mux_.get());
}
} // namespace protocol
} // namespace remoting
......@@ -34,6 +34,7 @@ class StreamConnectionTester {
~StreamConnectionTester();
void Start();
bool done() { return done_; }
void CheckResults();
protected:
......
......@@ -7,16 +7,12 @@
#include <string>
#include "base/callback.h"
#include "base/threading/non_thread_safe.h"
#include "remoting/protocol/buffered_socket_writer.h"
#include "remoting/protocol/channel_factory.h"
#include "remoting/protocol/errors.h"
#include "remoting/protocol/session_config.h"
namespace net {
class IPEndPoint;
class Socket;
class StreamSocket;
} // namespace net
namespace remoting {
......@@ -27,7 +23,7 @@ struct TransportRoute;
// Generic interface for Chromotocol connection used by both client and host.
// Provides access to the connection channels, but doesn't depend on the
// protocol used for each channel.
class Session : public base::NonThreadSafe {
class Session : public ChannelFactory {
public:
enum State {
// Created, but not connecting yet.
......@@ -74,12 +70,6 @@ class Session : public base::NonThreadSafe {
bool ready) {}
};
// TODO(sergeyu): Specify connection error code when channel
// connection fails.
typedef base::Callback<void(scoped_ptr<net::StreamSocket>)>
StreamChannelCallback;
typedef base::Callback<void(scoped_ptr<net::Socket>)>
DatagramChannelCallback;
Session() {}
virtual ~Session() {}
......@@ -91,23 +81,6 @@ class Session : public base::NonThreadSafe {
// Returns error code for a failed session.
virtual ErrorCode error() = 0;
// Creates new channels for this connection. The specified callback
// is called when then new channel is created and connected. The
// callback is called with NULL if connection failed for any reason.
// All channels must be destroyed before the session is
// destroyed. Can be called only when in CONNECTING, CONNECTED or
// AUTHENTICATED states.
virtual void CreateStreamChannel(
const std::string& name, const StreamChannelCallback& callback) = 0;
virtual void CreateDatagramChannel(
const std::string& name, const DatagramChannelCallback& callback) = 0;
// Cancels a pending CreateStreamChannel() or CreateDatagramChannel()
// operation for the named channel. If the channel creation already
// completed then cancelling it has no effect. When shutting down
// this method must be called for each channel pending creation.
virtual void CancelChannelCreation(const std::string& name) = 0;
// JID of the other side.
virtual const std::string& jid() = 0;
......
......@@ -1616,6 +1616,8 @@
'protocol/channel_authenticator.h',
'protocol/channel_dispatcher_base.cc',
'protocol/channel_dispatcher_base.h',
'protocol/channel_multiplexer.cc',
'protocol/channel_multiplexer.h',
'protocol/client_control_dispatcher.cc',
'protocol/client_control_dispatcher.h',
'protocol/client_event_dispatcher.cc',
......@@ -1623,11 +1625,11 @@
'protocol/client_stub.h',
'protocol/clipboard_echo_filter.cc',
'protocol/clipboard_echo_filter.h',
'protocol/clipboard_filter.h',
'protocol/clipboard_filter.cc',
'protocol/clipboard_filter.h',
'protocol/clipboard_stub.h',
'protocol/clipboard_thread_proxy.cc',
'protocol/clipboard_thread_proxy.h',
'protocol/clipboard_stub.h',
'protocol/connection_to_client.cc',
'protocol/connection_to_client.h',
'protocol/connection_to_host.cc',
......@@ -1802,6 +1804,7 @@
'protocol/authenticator_test_base.cc',
'protocol/authenticator_test_base.h',
'protocol/buffered_socket_writer_unittest.cc',
'protocol/channel_multiplexer_unittest.cc',
'protocol/clipboard_echo_filter_unittest.cc',
'protocol/connection_tester.cc',
'protocol/connection_tester.h',
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment