Commit cc40af91 authored by sergeyu's avatar sergeyu Committed by Commit bot

Simplify data channel creation logic in WebRTC-based protocol

Previously WebrtcDataStreamAdapter was implementing MessageChannelFactory,
which means that it was always creating channels asynchronously and in OPEN
state. Refactored it to create channels in CONNECTING state. Significantly
simplifies channel setup code as it's always synchronous. It's also possible
to reject incoming channels before they are connected and handle channels
rejected by the peer before being fully connected.

Review-Url: https://codereview.chromium.org/2164163002
Cr-Commit-Position: refs/heads/master@{#407046}
parent 3a5d7f96
...@@ -42,7 +42,11 @@ void ChannelDispatcherBase::OnChannelReady( ...@@ -42,7 +42,11 @@ void ChannelDispatcherBase::OnChannelReady(
channel_factory_ = nullptr; channel_factory_ = nullptr;
message_pipe_ = std::move(message_pipe); message_pipe_ = std::move(message_pipe);
message_pipe_->Start(this); message_pipe_->Start(this);
}
void ChannelDispatcherBase::OnMessagePipeOpen() {
DCHECK(!is_connected_);
is_connected_ = true;
event_handler_->OnChannelInitialized(this); event_handler_->OnChannelInitialized(this);
} }
...@@ -52,6 +56,7 @@ void ChannelDispatcherBase::OnMessageReceived( ...@@ -52,6 +56,7 @@ void ChannelDispatcherBase::OnMessageReceived(
} }
void ChannelDispatcherBase::OnMessagePipeClosed() { void ChannelDispatcherBase::OnMessagePipeClosed() {
is_connected_ = false;
event_handler_->OnChannelClosed(this); event_handler_->OnChannelClosed(this);
} }
......
...@@ -57,7 +57,7 @@ class ChannelDispatcherBase : public MessagePipe::EventHandler { ...@@ -57,7 +57,7 @@ class ChannelDispatcherBase : public MessagePipe::EventHandler {
const std::string& channel_name() { return channel_name_; } const std::string& channel_name() { return channel_name_; }
// Returns true if the channel is currently connected. // Returns true if the channel is currently connected.
bool is_connected() { return message_pipe() != nullptr; } bool is_connected() { return is_connected_; }
protected: protected:
explicit ChannelDispatcherBase(const char* channel_name); explicit ChannelDispatcherBase(const char* channel_name);
...@@ -71,12 +71,14 @@ class ChannelDispatcherBase : public MessagePipe::EventHandler { ...@@ -71,12 +71,14 @@ class ChannelDispatcherBase : public MessagePipe::EventHandler {
void OnChannelReady(std::unique_ptr<MessagePipe> message_pipe); void OnChannelReady(std::unique_ptr<MessagePipe> message_pipe);
// MessagePipe::EventHandler interface. // MessagePipe::EventHandler interface.
void OnMessagePipeOpen() override;
void OnMessageReceived(std::unique_ptr<CompoundBuffer> message) override; void OnMessageReceived(std::unique_ptr<CompoundBuffer> message) override;
void OnMessagePipeClosed() override; void OnMessagePipeClosed() override;
std::string channel_name_; std::string channel_name_;
MessageChannelFactory* channel_factory_ = nullptr; MessageChannelFactory* channel_factory_ = nullptr;
EventHandler* event_handler_ = nullptr; EventHandler* event_handler_ = nullptr;
bool is_connected_ = false;
std::unique_ptr<MessagePipe> message_pipe_; std::unique_ptr<MessagePipe> message_pipe_;
......
...@@ -273,7 +273,9 @@ MessagePipeConnectionTester::~MessagePipeConnectionTester() {} ...@@ -273,7 +273,9 @@ MessagePipeConnectionTester::~MessagePipeConnectionTester() {}
void MessagePipeConnectionTester::RunAndCheckResults() { void MessagePipeConnectionTester::RunAndCheckResults() {
host_pipe_->Start(this); host_pipe_->Start(this);
}
void MessagePipeConnectionTester::OnMessagePipeOpen() {
for (int i = 0; i < message_count_; ++i) { for (int i = 0; i < message_count_; ++i) {
std::unique_ptr<VideoPacket> message(new VideoPacket()); std::unique_ptr<VideoPacket> message(new VideoPacket());
message->mutable_data()->resize(message_size_); message->mutable_data()->resize(message_size_);
......
...@@ -121,6 +121,7 @@ class MessagePipeConnectionTester : public MessagePipe::EventHandler { ...@@ -121,6 +121,7 @@ class MessagePipeConnectionTester : public MessagePipe::EventHandler {
protected: protected:
// MessagePipe::EventHandler interface. // MessagePipe::EventHandler interface.
void OnMessagePipeOpen() override;
void OnMessageReceived(std::unique_ptr<CompoundBuffer> message) override; void OnMessageReceived(std::unique_ptr<CompoundBuffer> message) override;
void OnMessagePipeClosed() override; void OnMessagePipeClosed() override;
......
...@@ -26,6 +26,9 @@ class MessagePipe { ...@@ -26,6 +26,9 @@ class MessagePipe {
public: public:
class EventHandler { class EventHandler {
public: public:
// Called when the channel is open.
virtual void OnMessagePipeOpen() = 0;
// Called when a message is received. // Called when a message is received.
virtual void OnMessageReceived(std::unique_ptr<CompoundBuffer> message) = 0; virtual void OnMessageReceived(std::unique_ptr<CompoundBuffer> message) = 0;
...@@ -38,8 +41,9 @@ class MessagePipe { ...@@ -38,8 +41,9 @@ class MessagePipe {
virtual ~MessagePipe() {} virtual ~MessagePipe() {}
// Starts the channel. |event_handler| will be called to notify when a message // Starts the channel. Must be called immediately after MessagePipe is
// is received or the pipe is closed. // created. |event_handler| will be notified when state of the pipe changes or
// when a message is received.
virtual void Start(EventHandler* event_handler) = 0; virtual void Start(EventHandler* event_handler) = 0;
// Sends a message. |done| is called when the message has been sent to the // Sends a message. |done| is called when the message has been sent to the
......
...@@ -42,6 +42,7 @@ void StreamMessagePipeAdapter::Start(EventHandler* event_handler) { ...@@ -42,6 +42,7 @@ void StreamMessagePipeAdapter::Start(EventHandler* event_handler) {
base::Unretained(event_handler)), base::Unretained(event_handler)),
base::Bind(&StreamMessagePipeAdapter::CloseOnError, base::Bind(&StreamMessagePipeAdapter::CloseOnError,
base::Unretained(this))); base::Unretained(this)));
event_handler->OnMessagePipeOpen();
} }
void StreamMessagePipeAdapter::Send(google::protobuf::MessageLite* message, void StreamMessagePipeAdapter::Send(google::protobuf::MessageLite* message,
......
...@@ -155,10 +155,11 @@ void WebrtcConnectionToClient::OnSessionStateChange(Session::State state) { ...@@ -155,10 +155,11 @@ void WebrtcConnectionToClient::OnSessionStateChange(Session::State state) {
} }
void WebrtcConnectionToClient::OnWebrtcTransportConnecting() { void WebrtcConnectionToClient::OnWebrtcTransportConnecting() {
// Create outgoing control channel by initializing |control_dispatcher_|. // Create outgoing control channel. |event_dispatcher_| is initialized later
// |event_dispatcher_| is initialized later because event channel is expected // because event channel is expected to be created by the client.
// to be created by the client. control_dispatcher_->Init(
control_dispatcher_->Init(transport_->outgoing_channel_factory(), this); transport_->CreateOutgoingChannel(control_dispatcher_->channel_name()),
this);
} }
void WebrtcConnectionToClient::OnWebrtcTransportConnected() { void WebrtcConnectionToClient::OnWebrtcTransportConnected() {
......
...@@ -104,7 +104,9 @@ void WebrtcConnectionToHost::OnSessionStateChange(Session::State state) { ...@@ -104,7 +104,9 @@ void WebrtcConnectionToHost::OnSessionStateChange(Session::State state) {
void WebrtcConnectionToHost::OnWebrtcTransportConnecting() { void WebrtcConnectionToHost::OnWebrtcTransportConnecting() {
event_dispatcher_.reset(new ClientEventDispatcher()); event_dispatcher_.reset(new ClientEventDispatcher());
event_dispatcher_->Init(transport_->outgoing_channel_factory(), this); event_dispatcher_->Init(
transport_->CreateOutgoingChannel(event_dispatcher_->channel_name()),
this);
} }
void WebrtcConnectionToHost::OnWebrtcTransportConnected() {} void WebrtcConnectionToHost::OnWebrtcTransportConnected() {}
......
...@@ -21,13 +21,14 @@ ...@@ -21,13 +21,14 @@
namespace remoting { namespace remoting {
namespace protocol { namespace protocol {
class WebrtcDataStreamAdapter::Channel : public MessagePipe, namespace {
public webrtc::DataChannelObserver {
public:
explicit Channel(WebrtcDataStreamAdapter* adapter);
~Channel() override;
void Start(rtc::scoped_refptr<webrtc::DataChannelInterface> channel); class WebrtcDataChannel : public MessagePipe,
public webrtc::DataChannelObserver {
public:
explicit WebrtcDataChannel(
rtc::scoped_refptr<webrtc::DataChannelInterface> channel);
~WebrtcDataChannel() override;
std::string name() { return channel_->label(); } std::string name() { return channel_->label(); }
...@@ -47,52 +48,38 @@ class WebrtcDataStreamAdapter::Channel : public MessagePipe, ...@@ -47,52 +48,38 @@ class WebrtcDataStreamAdapter::Channel : public MessagePipe,
void OnClosed(); void OnClosed();
// |adapter_| owns channels while they are being connected.
WebrtcDataStreamAdapter* adapter_;
rtc::scoped_refptr<webrtc::DataChannelInterface> channel_; rtc::scoped_refptr<webrtc::DataChannelInterface> channel_;
EventHandler* event_handler_ = nullptr; EventHandler* event_handler_ = nullptr;
State state_ = State::CONNECTING; State state_ = State::CONNECTING;
DISALLOW_COPY_AND_ASSIGN(Channel); DISALLOW_COPY_AND_ASSIGN(WebrtcDataChannel);
}; };
WebrtcDataStreamAdapter::Channel::Channel(WebrtcDataStreamAdapter* adapter) WebrtcDataChannel::WebrtcDataChannel(
: adapter_(adapter) {} rtc::scoped_refptr<webrtc::DataChannelInterface> channel)
: channel_(channel) {
channel_->RegisterObserver(this);
DCHECK_EQ(channel_->state(), webrtc::DataChannelInterface::kConnecting);
}
WebrtcDataStreamAdapter::Channel::~Channel() { WebrtcDataChannel::~WebrtcDataChannel() {
if (channel_) { if (channel_) {
channel_->UnregisterObserver(); channel_->UnregisterObserver();
channel_->Close(); channel_->Close();
} }
} }
void WebrtcDataStreamAdapter::Channel::Start( void WebrtcDataChannel::Start(EventHandler* event_handler) {
rtc::scoped_refptr<webrtc::DataChannelInterface> channel) {
DCHECK(!channel_);
channel_ = channel;
channel_->RegisterObserver(this);
if (channel_->state() == webrtc::DataChannelInterface::kOpen) {
OnConnected();
} else {
DCHECK_EQ(channel_->state(), webrtc::DataChannelInterface::kConnecting);
}
}
void WebrtcDataStreamAdapter::Channel::Start(EventHandler* event_handler) {
DCHECK(!event_handler_); DCHECK(!event_handler_);
DCHECK(event_handler); DCHECK(event_handler);
event_handler_ = event_handler; event_handler_ = event_handler;
} }
void WebrtcDataStreamAdapter::Channel::Send( void WebrtcDataChannel::Send(google::protobuf::MessageLite* message,
google::protobuf::MessageLite* message, const base::Closure& done) {
const base::Closure& done) {
DCHECK(state_ == State::OPEN); DCHECK(state_ == State::OPEN);
rtc::CopyOnWriteBuffer buffer; rtc::CopyOnWriteBuffer buffer;
...@@ -110,14 +97,19 @@ void WebrtcDataStreamAdapter::Channel::Send( ...@@ -110,14 +97,19 @@ void WebrtcDataStreamAdapter::Channel::Send(
done.Run(); done.Run();
} }
void WebrtcDataStreamAdapter::Channel::OnStateChange() { void WebrtcDataChannel::OnStateChange() {
switch (channel_->state()) { switch (channel_->state()) {
case webrtc::DataChannelInterface::kOpen: case webrtc::DataChannelInterface::kOpen:
OnConnected(); DCHECK(state_ == State::CONNECTING);
state_ = State::OPEN;
event_handler_->OnMessagePipeOpen();
break; break;
case webrtc::DataChannelInterface::kClosing: case webrtc::DataChannelInterface::kClosing:
OnClosed(); if (state_ != State::CLOSED) {
state_ = State::CLOSED;
event_handler_->OnMessagePipeClosed();
}
break; break;
case webrtc::DataChannelInterface::kConnecting: case webrtc::DataChannelInterface::kConnecting:
...@@ -126,8 +118,7 @@ void WebrtcDataStreamAdapter::Channel::OnStateChange() { ...@@ -126,8 +118,7 @@ void WebrtcDataStreamAdapter::Channel::OnStateChange() {
} }
} }
void WebrtcDataStreamAdapter::Channel::OnMessage( void WebrtcDataChannel::OnMessage(const webrtc::DataBuffer& rtc_buffer) {
const webrtc::DataBuffer& rtc_buffer) {
if (state_ != State::OPEN) { if (state_ != State::OPEN) {
LOG(ERROR) << "Dropping a message received when the channel is not open."; LOG(ERROR) << "Dropping a message received when the channel is not open.";
return; return;
...@@ -140,114 +131,24 @@ void WebrtcDataStreamAdapter::Channel::OnMessage( ...@@ -140,114 +131,24 @@ void WebrtcDataStreamAdapter::Channel::OnMessage(
event_handler_->OnMessageReceived(std::move(buffer)); event_handler_->OnMessageReceived(std::move(buffer));
} }
void WebrtcDataStreamAdapter::Channel::OnConnected() { } // namespace
DCHECK(state_ == State::CONNECTING);
state_ = State::OPEN;
WebrtcDataStreamAdapter* adapter = adapter_;
adapter_ = nullptr;
adapter->OnChannelConnected(this);
}
void WebrtcDataStreamAdapter::Channel::OnClosed() {
switch (state_) {
case State::CONNECTING:
state_ = State::CLOSED;
LOG(WARNING) << "Channel " << channel_->label()
<< " was closed before it's connected.";
adapter_->OnChannelError();
return;
case State::OPEN:
state_ = State::CLOSED;
event_handler_->OnMessagePipeClosed();
return;
case State::CLOSED:
break;
}
}
struct WebrtcDataStreamAdapter::PendingChannel {
PendingChannel(std::unique_ptr<Channel> channel,
const ChannelCreatedCallback& connected_callback)
: channel(std::move(channel)), connected_callback(connected_callback) {}
PendingChannel(PendingChannel&& other)
: channel(std::move(other.channel)),
connected_callback(std::move(other.connected_callback)) {}
PendingChannel& operator=(PendingChannel&& other) {
channel = std::move(other.channel);
connected_callback = std::move(other.connected_callback);
return *this;
}
std::unique_ptr<Channel> channel;
ChannelCreatedCallback connected_callback;
};
WebrtcDataStreamAdapter::WebrtcDataStreamAdapter( WebrtcDataStreamAdapter::WebrtcDataStreamAdapter(
const ErrorCallback& error_callback) rtc::scoped_refptr<webrtc::PeerConnectionInterface> peer_connection)
: error_callback_(error_callback) {} : peer_connection_(peer_connection) {}
WebrtcDataStreamAdapter::~WebrtcDataStreamAdapter() {}
WebrtcDataStreamAdapter::~WebrtcDataStreamAdapter() { std::unique_ptr<MessagePipe> WebrtcDataStreamAdapter::CreateOutgoingChannel(
DCHECK(pending_channels_.empty()); const std::string& name) {
}
void WebrtcDataStreamAdapter::Initialize(
rtc::scoped_refptr<webrtc::PeerConnectionInterface> peer_connection) {
peer_connection_ = peer_connection;
}
void WebrtcDataStreamAdapter::WrapIncomingDataChannel(
rtc::scoped_refptr<webrtc::DataChannelInterface> data_channel,
const ChannelCreatedCallback& callback) {
AddPendingChannel(data_channel, callback);
}
void WebrtcDataStreamAdapter::CreateChannel(
const std::string& name,
const ChannelCreatedCallback& callback) {
webrtc::DataChannelInit config; webrtc::DataChannelInit config;
config.reliable = true; config.reliable = true;
AddPendingChannel(peer_connection_->CreateDataChannel(name, &config), return base::WrapUnique(new WebrtcDataChannel(
callback); peer_connection_->CreateDataChannel(name, &config)));
}
void WebrtcDataStreamAdapter::CancelChannelCreation(const std::string& name) {
auto it = pending_channels_.find(name);
DCHECK(it != pending_channels_.end());
pending_channels_.erase(it);
}
void WebrtcDataStreamAdapter::AddPendingChannel(
rtc::scoped_refptr<webrtc::DataChannelInterface> data_channel,
const ChannelCreatedCallback& callback) {
DCHECK(peer_connection_);
DCHECK(pending_channels_.find(data_channel->label()) ==
pending_channels_.end());
Channel* channel = new Channel(this);
pending_channels_.insert(
std::make_pair(data_channel->label(),
PendingChannel(base::WrapUnique(channel), callback)));
channel->Start(data_channel);
}
void WebrtcDataStreamAdapter::OnChannelConnected(Channel* channel) {
auto it = pending_channels_.find(channel->name());
DCHECK(it != pending_channels_.end());
PendingChannel pending_channel = std::move(it->second);
pending_channels_.erase(it);
// Once the channel is connected its ownership is passed to the
// |connected_callback|.
pending_channel.connected_callback.Run(std::move(pending_channel.channel));
} }
void WebrtcDataStreamAdapter::OnChannelError() { std::unique_ptr<MessagePipe> WebrtcDataStreamAdapter::WrapIncomingDataChannel(
pending_channels_.clear(); rtc::scoped_refptr<webrtc::DataChannelInterface> data_channel) {
error_callback_.Run(CHANNEL_CONNECTION_ERROR); return base::WrapUnique(new WebrtcDataChannel(data_channel));
} }
} // namespace protocol } // namespace protocol
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#ifndef REMOTING_PROTOCOL_WEBRTC_DATA_STREAM_ADAPTER_H_ #ifndef REMOTING_PROTOCOL_WEBRTC_DATA_STREAM_ADAPTER_H_
#define REMOTING_PROTOCOL_WEBRTC_DATA_STREAM_ADAPTER_H_ #define REMOTING_PROTOCOL_WEBRTC_DATA_STREAM_ADAPTER_H_
#include <memory>
#include <string> #include <string>
#include "base/callback.h" #include "base/callback.h"
...@@ -21,50 +22,25 @@ class PeerConnectionInterface; ...@@ -21,50 +22,25 @@ class PeerConnectionInterface;
namespace remoting { namespace remoting {
namespace protocol { namespace protocol {
// WebrtcDataStreamAdapter is a MessageChannelFactory that creates channels that // WebrtcDataStreamAdapter wraps MessagePipe for WebRTC data channels.
// send and receive messages over PeerConnection data channels. class WebrtcDataStreamAdapter {
class WebrtcDataStreamAdapter : public MessageChannelFactory {
public: public:
typedef base::Callback<void(ErrorCode)> ErrorCallback; typedef base::Callback<void(ErrorCode)> ErrorCallback;
explicit WebrtcDataStreamAdapter(const ErrorCallback& error_callback); explicit WebrtcDataStreamAdapter(
~WebrtcDataStreamAdapter() override;
// Initializes the adapter for |peer_connection|. If |outgoing| is set to true
// all channels will be created as outgoing. Otherwise CreateChannel() will
// wait for the other end to create connection.
void Initialize(
rtc::scoped_refptr<webrtc::PeerConnectionInterface> peer_connection); rtc::scoped_refptr<webrtc::PeerConnectionInterface> peer_connection);
~WebrtcDataStreamAdapter();
// Called by WebrtcTransport. // Creates outgoing data channel.
void WrapIncomingDataChannel( std::unique_ptr<MessagePipe> CreateOutgoingChannel(const std::string& name);
rtc::scoped_refptr<webrtc::DataChannelInterface> data_channel,
const ChannelCreatedCallback& callback);
// MessageChannelFactory interface. // Creates incoming data channel.
void CreateChannel(const std::string& name, std::unique_ptr<MessagePipe> WrapIncomingDataChannel(
const ChannelCreatedCallback& callback) override; rtc::scoped_refptr<webrtc::DataChannelInterface> data_channel);
void CancelChannelCreation(const std::string& name) override;
private: private:
class Channel;
friend class Channel;
struct PendingChannel;
void AddPendingChannel(
rtc::scoped_refptr<webrtc::DataChannelInterface> data_channel,
const ChannelCreatedCallback& callback);
void OnChannelConnected(Channel* channel);
void OnChannelError();
ErrorCallback error_callback_;
rtc::scoped_refptr<webrtc::PeerConnectionInterface> peer_connection_; rtc::scoped_refptr<webrtc::PeerConnectionInterface> peer_connection_;
std::map<std::string, PendingChannel> pending_channels_;
DISALLOW_COPY_AND_ASSIGN(WebrtcDataStreamAdapter); DISALLOW_COPY_AND_ASSIGN(WebrtcDataStreamAdapter);
}; };
......
...@@ -152,6 +152,11 @@ WebrtcTransport::WebrtcTransport( ...@@ -152,6 +152,11 @@ WebrtcTransport::WebrtcTransport(
WebrtcTransport::~WebrtcTransport() {} WebrtcTransport::~WebrtcTransport() {}
std::unique_ptr<MessagePipe> WebrtcTransport::CreateOutgoingChannel(
const std::string& name) {
return data_stream_adapter_->CreateOutgoingChannel(name);
}
void WebrtcTransport::Start( void WebrtcTransport::Start(
Authenticator* authenticator, Authenticator* authenticator,
SendTransportInfoCallback send_transport_info_callback) { SendTransportInfoCallback send_transport_info_callback) {
...@@ -192,9 +197,7 @@ void WebrtcTransport::Start( ...@@ -192,9 +197,7 @@ void WebrtcTransport::Start(
peer_connection_ = peer_connection_factory_->CreatePeerConnection( peer_connection_ = peer_connection_factory_->CreatePeerConnection(
rtc_config, &constraints, std::move(port_allocator), nullptr, this); rtc_config, &constraints, std::move(port_allocator), nullptr, this);
data_stream_adapter_.reset(new WebrtcDataStreamAdapter( data_stream_adapter_.reset(new WebrtcDataStreamAdapter(peer_connection_));
base::Bind(&WebrtcTransport::Close, base::Unretained(this))));
data_stream_adapter_->Initialize(peer_connection_);
event_handler_->OnWebrtcTransportConnecting(); event_handler_->OnWebrtcTransportConnecting();
...@@ -434,10 +437,9 @@ void WebrtcTransport::OnRemoveStream( ...@@ -434,10 +437,9 @@ void WebrtcTransport::OnRemoveStream(
void WebrtcTransport::OnDataChannel( void WebrtcTransport::OnDataChannel(
rtc::scoped_refptr<webrtc::DataChannelInterface> data_channel) { rtc::scoped_refptr<webrtc::DataChannelInterface> data_channel) {
DCHECK(thread_checker_.CalledOnValidThread()); DCHECK(thread_checker_.CalledOnValidThread());
data_stream_adapter_->WrapIncomingDataChannel( event_handler_->OnWebrtcTransportIncomingDataChannel(
data_channel, data_channel->label(),
base::Bind(&EventHandler::OnWebrtcTransportIncomingDataChannel, data_stream_adapter_->WrapIncomingDataChannel(data_channel));
base::Unretained(event_handler_), data_channel->label()));
} }
void WebrtcTransport::OnRenegotiationNeeded() { void WebrtcTransport::OnRenegotiationNeeded() {
......
...@@ -29,7 +29,6 @@ namespace remoting { ...@@ -29,7 +29,6 @@ namespace remoting {
namespace protocol { namespace protocol {
class TransportContext; class TransportContext;
class MessageChannelFactory;
class MessagePipe; class MessagePipe;
class WebrtcTransport : public Transport, class WebrtcTransport : public Transport,
...@@ -81,11 +80,10 @@ class WebrtcTransport : public Transport, ...@@ -81,11 +80,10 @@ class WebrtcTransport : public Transport,
return video_encoder_factory_; return video_encoder_factory_;
} }
// Factory for outgoing data channels. Must be used only after the transport // Creates outgoing data channel. The channel is created in CONNECTING state.
// is connected. // The caller must wait for OnMessagePipeOpen() notification before sending
MessageChannelFactory* outgoing_channel_factory() { // any messages.
return data_stream_adapter_.get(); std::unique_ptr<MessagePipe> CreateOutgoingChannel(const std::string& name);
}
// Transport interface. // Transport interface.
void Start(Authenticator* authenticator, void Start(Authenticator* authenticator,
......
...@@ -14,10 +14,11 @@ ...@@ -14,10 +14,11 @@
#include "net/base/io_buffer.h" #include "net/base/io_buffer.h"
#include "net/url_request/url_request_context_getter.h" #include "net/url_request/url_request_context_getter.h"
#include "remoting/base/compound_buffer.h" #include "remoting/base/compound_buffer.h"
#include "remoting/protocol/connection_tester.h" #include "remoting/proto/event.pb.h"
#include "remoting/protocol/fake_authenticator.h" #include "remoting/protocol/fake_authenticator.h"
#include "remoting/protocol/message_channel_factory.h" #include "remoting/protocol/message_channel_factory.h"
#include "remoting/protocol/message_pipe.h" #include "remoting/protocol/message_pipe.h"
#include "remoting/protocol/message_serialization.h"
#include "remoting/protocol/network_settings.h" #include "remoting/protocol/network_settings.h"
#include "remoting/protocol/transport_context.h" #include "remoting/protocol/transport_context.h"
#include "remoting/signaling/fake_signal_strategy.h" #include "remoting/signaling/fake_signal_strategy.h"
...@@ -97,15 +98,32 @@ class TestMessagePipeEventHandler : public MessagePipe::EventHandler { ...@@ -97,15 +98,32 @@ class TestMessagePipeEventHandler : public MessagePipe::EventHandler {
TestMessagePipeEventHandler() {} TestMessagePipeEventHandler() {}
~TestMessagePipeEventHandler() override {} ~TestMessagePipeEventHandler() override {}
void set_open_callback(const base::Closure& callback) {
open_callback_ = callback;
}
void set_message_callback(const base::Closure& callback) {
message_callback_ = callback;
}
void set_closed_callback(const base::Closure& callback) { void set_closed_callback(const base::Closure& callback) {
closed_callback_ = callback; closed_callback_ = callback;
} }
bool is_open() { return is_open_; }
const std::list<std::unique_ptr<CompoundBuffer>>& received_messages() {
return received_messages_;
}
// MessagePipe::EventHandler interface. // MessagePipe::EventHandler interface.
void OnMessagePipeOpen() override {
is_open_ = true;
if (!open_callback_.is_null())
open_callback_.Run();
}
void OnMessageReceived(std::unique_ptr<CompoundBuffer> message) override { void OnMessageReceived(std::unique_ptr<CompoundBuffer> message) override {
NOTREACHED(); received_messages_.push_back(std::move(message));
if (!message_callback_.is_null())
message_callback_.Run();
} }
void OnMessagePipeClosed() override { void OnMessagePipeClosed() override {
if (!closed_callback_.is_null()) { if (!closed_callback_.is_null()) {
closed_callback_.Run(); closed_callback_.Run();
...@@ -115,8 +133,13 @@ class TestMessagePipeEventHandler : public MessagePipe::EventHandler { ...@@ -115,8 +133,13 @@ class TestMessagePipeEventHandler : public MessagePipe::EventHandler {
} }
private: private:
bool is_open_ = false;
base::Closure open_callback_;
base::Closure message_callback_;
base::Closure closed_callback_; base::Closure closed_callback_;
std::list<std::unique_ptr<CompoundBuffer>> received_messages_;
DISALLOW_COPY_AND_ASSIGN(TestMessagePipeEventHandler); DISALLOW_COPY_AND_ASSIGN(TestMessagePipeEventHandler);
}; };
...@@ -220,22 +243,24 @@ class WebrtcTransportTest : public testing::Test { ...@@ -220,22 +243,24 @@ class WebrtcTransportTest : public testing::Test {
} }
void CreateHostDataStream() { void CreateHostDataStream() {
host_transport_->outgoing_channel_factory()->CreateChannel( host_message_pipe_ = host_transport_->CreateOutgoingChannel(kChannelName);
kChannelName, base::Bind(&WebrtcTransportTest::OnHostChannelCreated, host_message_pipe_->Start(&host_message_pipe_event_handler_);
base::Unretained(this))); host_message_pipe_event_handler_.set_open_callback(base::Bind(
&WebrtcTransportTest::OnHostChannelConnected, base::Unretained(this)));
} }
void OnIncomingChannel(const std::string& name, void OnIncomingChannel(const std::string& name,
std::unique_ptr<MessagePipe> pipe) { std::unique_ptr<MessagePipe> pipe) {
EXPECT_EQ(kChannelName, name); EXPECT_EQ(kChannelName, name);
client_message_pipe_ = std::move(pipe); client_message_pipe_ = std::move(pipe);
if (run_loop_ && host_message_pipe_) client_message_pipe_->Start(&client_message_pipe_event_handler_);
if (run_loop_ && host_message_pipe_event_handler_.is_open())
run_loop_->Quit(); run_loop_->Quit();
} }
void OnHostChannelCreated(std::unique_ptr<MessagePipe> pipe) { void OnHostChannelConnected() {
host_message_pipe_ = std::move(pipe); if (run_loop_ && client_message_pipe_event_handler_.is_open())
if (run_loop_ && client_message_pipe_)
run_loop_->Quit(); run_loop_->Quit();
} }
...@@ -283,6 +308,7 @@ class WebrtcTransportTest : public testing::Test { ...@@ -283,6 +308,7 @@ class WebrtcTransportTest : public testing::Test {
std::unique_ptr<FakeAuthenticator> client_authenticator_; std::unique_ptr<FakeAuthenticator> client_authenticator_;
std::unique_ptr<MessagePipe> client_message_pipe_; std::unique_ptr<MessagePipe> client_message_pipe_;
TestMessagePipeEventHandler client_message_pipe_event_handler_;
std::unique_ptr<MessagePipe> host_message_pipe_; std::unique_ptr<MessagePipe> host_message_pipe_;
TestMessagePipeEventHandler host_message_pipe_event_handler_; TestMessagePipeEventHandler host_message_pipe_event_handler_;
...@@ -324,12 +350,20 @@ TEST_F(WebrtcTransportTest, DataStream) { ...@@ -324,12 +350,20 @@ TEST_F(WebrtcTransportTest, DataStream) {
EXPECT_TRUE(client_message_pipe_); EXPECT_TRUE(client_message_pipe_);
EXPECT_TRUE(host_message_pipe_); EXPECT_TRUE(host_message_pipe_);
const int kMessageSize = 1024; TextEvent message;
const int kMessages = 100; message.set_text("Hello");
MessagePipeConnectionTester tester(host_message_pipe_.get(), host_message_pipe_->Send(&message, base::Closure());
client_message_pipe_.get(), kMessageSize,
kMessages); run_loop_.reset(new base::RunLoop());
tester.RunAndCheckResults(); client_message_pipe_event_handler_.set_message_callback(
base::Bind(&base::RunLoop::Quit, base::Unretained(run_loop_.get())));
run_loop_->Run();
ASSERT_EQ(1U, client_message_pipe_event_handler_.received_messages().size());
std::unique_ptr<TextEvent> received_message = ParseMessage<TextEvent>(
client_message_pipe_event_handler_.received_messages().front().get());
EXPECT_EQ(message.text(), received_message->text());
} }
// Verify that data streams can be created after connection has been initiated. // Verify that data streams can be created after connection has been initiated.
...@@ -366,7 +400,6 @@ TEST_F(WebrtcTransportTest, TerminateDataChannel) { ...@@ -366,7 +400,6 @@ TEST_F(WebrtcTransportTest, TerminateDataChannel) {
// Expect that the channel is closed on the host side once the client closes // Expect that the channel is closed on the host side once the client closes
// the channel. // the channel.
host_message_pipe_->Start(&host_message_pipe_event_handler_);
host_message_pipe_event_handler_.set_closed_callback(base::Bind( host_message_pipe_event_handler_.set_closed_callback(base::Bind(
&WebrtcTransportTest::OnHostChannelClosed, base::Unretained(this))); &WebrtcTransportTest::OnHostChannelClosed, base::Unretained(this)));
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment