Commit 6852d7d9 authored by sergeyu@chromium.org's avatar sergeyu@chromium.org

Changed MessageReader so that it doesn't read from the socket if there are

other messages being processed. Added unittests for MessageReader.

BUG=None
TEST=Unittests

Review URL: http://codereview.chromium.org/6271004

git-svn-id: svn://svn.chromium.org/chrome/trunk/src@72262 0039d316-1c4b-4281-b951-d872f2087c98
parent a4f4692c
......@@ -105,6 +105,8 @@ class CompoundBufferInputStream
explicit CompoundBufferInputStream(const CompoundBuffer* buffer);
virtual ~CompoundBufferInputStream();
int position() const { return position_; }
// google::protobuf::io::ZeroCopyInputStream interface.
virtual bool Next(const void** data, int* size);
virtual void BackUp(int count);
......
......@@ -40,17 +40,3 @@ message MouseEvent {
optional MouseButton button = 5;
optional bool button_down = 6;
}
// Defines an event message on the event channel.
message Event {
required int32 timestamp = 1; // Client timestamp for event
optional bool dummy = 2; // Is this a dummy event?
optional KeyEvent key = 3;
optional MouseEvent mouse = 4;
}
// Message sent in the event channel.
message EventMessage {
repeated Event event = 1;
}
......@@ -22,3 +22,12 @@ message ControlMessage {
optional BeginSessionRequest begin_session_request = 3;
optional BeginSessionResponse begin_session_response = 4;
}
// Defines an event message on the event channel.
message EventMessage {
required int32 timestamp = 1; // Client timestamp for event
optional bool dummy = 2; // Is this a dummy event?
optional KeyEvent key_event = 3;
optional MouseEvent mouse_event = 4;
}
......@@ -11,7 +11,6 @@
#include "remoting/protocol/client_stub.h"
#include "remoting/protocol/input_stub.h"
#include "remoting/protocol/message_reader.h"
#include "remoting/protocol/ref_counted_message.h"
#include "remoting/protocol/session.h"
namespace remoting {
......@@ -39,18 +38,18 @@ void ClientMessageDispatcher::Initialize(
}
void ClientMessageDispatcher::OnControlMessageReceived(
ControlMessage* message) {
scoped_refptr<RefCountedMessage<ControlMessage> > ref_msg =
new RefCountedMessage<ControlMessage>(message);
ControlMessage* message, Task* done_task) {
// TODO(sergeyu): Add message validation.
if (message->has_notify_resolution()) {
client_stub_->NotifyResolution(
&message->notify_resolution(), NewDeleteTask(ref_msg));
&message->notify_resolution(), done_task);
} else if (message->has_begin_session_response()) {
client_stub_->BeginSessionResponse(
&message->begin_session_response().login_status(),
NewDeleteTask(ref_msg));
&message->begin_session_response().login_status(), done_task);
} else {
NOTREACHED() << "Invalid control message received";
LOG(WARNING) << "Invalid control message received.";
done_task->Run();
delete done_task;
}
}
......
......@@ -40,7 +40,7 @@ class ClientMessageDispatcher {
void Initialize(protocol::Session* session, ClientStub* client_stub);
private:
void OnControlMessageReceived(ControlMessage* message);
void OnControlMessageReceived(ControlMessage* message, Task* done_task);
// MessageReader that runs on the control channel. It runs a loop
// that parses data on the channel and then calls the corresponding handler
......
......@@ -21,7 +21,7 @@ FakeSocket::FakeSocket()
FakeSocket::~FakeSocket() {
}
void FakeSocket::AppendInputData(char* data, int data_size) {
void FakeSocket::AppendInputData(const char* data, int data_size) {
input_data_.insert(input_data_.end(), data, data + data_size);
// Complete pending read if any.
if (read_pending_) {
......@@ -78,7 +78,7 @@ FakeUdpSocket::FakeUdpSocket()
FakeUdpSocket::~FakeUdpSocket() {
}
void FakeUdpSocket::AppendInputPacket(char* data, int data_size) {
void FakeUdpSocket::AppendInputPacket(const char* data, int data_size) {
input_packets_.push_back(std::string());
input_packets_.back().assign(data, data + data_size);
......
......@@ -27,10 +27,11 @@ class FakeSocket : public net::Socket {
FakeSocket();
virtual ~FakeSocket();
const std::string& written_data() { return written_data_; }
const std::string& written_data() const { return written_data_; }
void AppendInputData(char* data, int data_size);
int input_pos() { return input_pos_; }
void AppendInputData(const char* data, int data_size);
int input_pos() const { return input_pos_; }
bool read_pending() const { return read_pending_; }
// net::Socket interface.
virtual int Read(net::IOBuffer* buf, int buf_len,
......@@ -60,12 +61,12 @@ class FakeUdpSocket : public net::Socket {
FakeUdpSocket();
virtual ~FakeUdpSocket();
const std::vector<std::string>& written_packets() {
const std::vector<std::string>& written_packets() const {
return written_packets_;
}
void AppendInputPacket(char* data, int data_size);
int input_pos() { return input_pos_; }
void AppendInputPacket(const char* data, int data_size);
int input_pos() const { return input_pos_; }
// net::Socket interface.
virtual int Read(net::IOBuffer* buf, int buf_len,
......@@ -100,7 +101,7 @@ class FakeSession : public Session {
message_loop_ = message_loop;
}
bool is_closed() { return closed_; }
bool is_closed() const { return closed_; }
virtual void SetStateChangeCallback(StateChangeCallback* callback);
......
......@@ -11,7 +11,6 @@
#include "remoting/protocol/host_stub.h"
#include "remoting/protocol/input_stub.h"
#include "remoting/protocol/message_reader.h"
#include "remoting/protocol/ref_counted_message.h"
#include "remoting/protocol/session.h"
namespace remoting {
......@@ -47,34 +46,32 @@ void HostMessageDispatcher::Initialize(
NewCallback(this, &HostMessageDispatcher::OnControlMessageReceived));
}
void HostMessageDispatcher::OnControlMessageReceived(ControlMessage* message) {
scoped_refptr<RefCountedMessage<ControlMessage> > ref_msg =
new RefCountedMessage<ControlMessage>(message);
void HostMessageDispatcher::OnControlMessageReceived(
ControlMessage* message, Task* done_task) {
// TODO(sergeyu): Add message validation.
if (message->has_suggest_resolution()) {
host_stub_->SuggestResolution(
&message->suggest_resolution(), NewDeleteTask(ref_msg));
host_stub_->SuggestResolution(&message->suggest_resolution(), done_task);
} else if (message->has_begin_session_request()) {
host_stub_->BeginSessionRequest(
&message->begin_session_request().credentials(),
NewDeleteTask(ref_msg));
&message->begin_session_request().credentials(), done_task);
} else {
NOTREACHED() << "Invalid control message received";
LOG(WARNING) << "Invalid control message received.";
done_task->Run();
delete done_task;
}
}
void HostMessageDispatcher::OnEventMessageReceived(
EventMessage* message) {
scoped_refptr<RefCountedMessage<EventMessage> > ref_msg =
new RefCountedMessage<EventMessage>(message);
for (int i = 0; i < message->event_size(); ++i) {
if (message->event(i).has_key()) {
input_stub_->InjectKeyEvent(
&message->event(i).key(), NewDeleteTask(ref_msg));
}
if (message->event(i).has_mouse()) {
input_stub_->InjectMouseEvent(
&message->event(i).mouse(), NewDeleteTask(ref_msg));
}
EventMessage* message, Task* done_task) {
// TODO(sergeyu): Add message validation.
if (message->has_key_event()) {
input_stub_->InjectKeyEvent(&message->key_event(), done_task);
} else if (message->has_mouse_event()) {
input_stub_->InjectMouseEvent(&message->mouse_event(), done_task);
} else {
LOG(WARNING) << "Invalid event message received.";
done_task->Run();
delete done_task;
}
}
......
......@@ -11,12 +11,10 @@
#include "remoting/protocol/message_reader.h"
namespace remoting {
class EventMessage;
namespace protocol {
class ControlMessage;
class EventMessage;
class HostStub;
class InputStub;
class Session;
......@@ -45,11 +43,11 @@ class HostMessageDispatcher {
private:
// This method is called by |control_channel_reader_| when a control
// message is received.
void OnControlMessageReceived(ControlMessage* message);
void OnControlMessageReceived(ControlMessage* message, Task* done_task);
// This method is called by |event_channel_reader_| when a event
// message is received.
void OnEventMessageReceived(EventMessage* message);
void OnEventMessageReceived(EventMessage* message, Task* done_task);
// MessageReader that runs on the control channel. It runs a loop
// that parses data on the channel and then delegates the message to this
......
......@@ -9,6 +9,7 @@
#include "base/task.h"
#include "remoting/proto/event.pb.h"
#include "remoting/proto/internal.pb.h"
#include "remoting/protocol/buffered_socket_writer.h"
#include "remoting/protocol/util.h"
......@@ -27,19 +28,17 @@ InputSender::~InputSender() {
void InputSender::InjectKeyEvent(const KeyEvent* event, Task* done) {
EventMessage message;
Event* evt = message.add_event();
// TODO(hclam): Provide timestamp.
evt->set_timestamp(0);
evt->mutable_key()->CopyFrom(*event);
message.set_timestamp(0);
message.mutable_key_event()->CopyFrom(*event);
buffered_writer_->Write(SerializeAndFrameMessage(message), done);
}
void InputSender::InjectMouseEvent(const MouseEvent* event, Task* done) {
EventMessage message;
Event* evt = message.add_event();
// TODO(hclam): Provide timestamp.
evt->set_timestamp(0);
evt->mutable_mouse()->CopyFrom(*event);
message.set_timestamp(0);
message.mutable_mouse_event()->CopyFrom(*event);
buffered_writer_->Write(SerializeAndFrameMessage(message), done);
}
......
......@@ -25,7 +25,7 @@ void MessageDecoder::AddData(scoped_refptr<net::IOBuffer> data,
buffer_.Append(data, data_size);
}
bool MessageDecoder::GetNextMessage(CompoundBuffer* message_buffer) {
CompoundBuffer* MessageDecoder::GetNextMessage() {
// Determine the payload size. If we already know it then skip this part.
// We may not have enough data to determine the payload size so use a
// utility function to find out.
......@@ -39,14 +39,15 @@ bool MessageDecoder::GetNextMessage(CompoundBuffer* message_buffer) {
// If the next payload size is still not known or we don't have enough
// data for parsing then exit.
if (!next_payload_known_ || buffer_.total_bytes() < next_payload_)
return false;
return NULL;
CompoundBuffer* message_buffer = new CompoundBuffer();
message_buffer->CopyFrom(buffer_, 0, next_payload_);
message_buffer->Lock();
buffer_.CropFront(next_payload_);
next_payload_known_ = false;
return true;
return message_buffer;
}
bool MessageDecoder::GetPayloadSize(int* size) {
......
......@@ -35,10 +35,11 @@ class MessageDecoder {
// its bytes are consumed.
void AddData(scoped_refptr<net::IOBuffer> data, int data_size);
// Get next message from the stream and puts it in
// |message_buffer|. Returns false if there are no complete messages
// yet.
bool GetNextMessage(CompoundBuffer* message_buffer);
// Returns next message from the stream. Ownership of the result is
// passed to the caller. Returns NULL if there are no complete
// messages yet, otherwise returns a buffer that contains one
// message.
CompoundBuffer* GetNextMessage();
private:
// Retrieves the read payload size of the current protocol buffer via |size|.
......
......@@ -6,7 +6,9 @@
#include "base/scoped_ptr.h"
#include "base/stl_util-inl.h"
#include "base/string_number_conversions.h"
#include "remoting/proto/event.pb.h"
#include "remoting/proto/internal.pb.h"
#include "remoting/protocol/message_decoder.h"
#include "remoting/protocol/util.h"
#include "testing/gtest/include/gtest/gtest.h"
......@@ -29,16 +31,13 @@ static void PrepareData(uint8** buffer, int* size) {
// Contains all encoded messages.
std::string encoded_data;
EventMessage msg;
// Then append 10 update sequences to the data.
for (int i = 0; i < 10; ++i) {
Event* event = msg.add_event();
event->set_timestamp(i);
event->mutable_key()->set_keycode(kTestKey + i);
event->mutable_key()->set_pressed((i % 2) != 0);
EventMessage msg;
msg.set_timestamp(i);
msg.mutable_key_event()->set_keycode(kTestKey + i);
msg.mutable_key_event()->set_pressed((i % 2) != 0);
AppendMessage(msg, &encoded_data);
msg.Clear();
}
*size = encoded_data.length();
......@@ -62,25 +61,27 @@ void SimulateReadSequence(const int read_sequence[], int sequence_size) {
// Then feed the protocol decoder using the above generated data and the
// read pattern.
std::list<EventMessage*> message_list;
for (int i = 0; i < size;) {
for (int pos = 0; pos < size;) {
SCOPED_TRACE("Input position: " + base::IntToString(pos));
// First generate the amount to feed the decoder.
int read = std::min(size - i, read_sequence[i % sequence_size]);
int read = std::min(size - pos, read_sequence[pos % sequence_size]);
// And then prepare an IOBuffer for feeding it.
scoped_refptr<net::IOBuffer> buffer(new net::IOBuffer(read));
memcpy(buffer->data(), test_data + i, read);
memcpy(buffer->data(), test_data + pos, read);
decoder.AddData(buffer, read);
while (true) {
CompoundBuffer message;
if (!decoder.GetNextMessage(&message))
scoped_ptr<CompoundBuffer> message(decoder.GetNextMessage());
if (!message.get())
break;
EventMessage* event = new EventMessage();
CompoundBufferInputStream stream(&message);
CompoundBufferInputStream stream(message.get());
ASSERT_TRUE(event->ParseFromZeroCopyStream(&stream));
message_list.push_back(event);
}
i += read;
pos += read;
}
// Then verify the decoded messages.
......@@ -90,15 +91,16 @@ void SimulateReadSequence(const int read_sequence[], int sequence_size) {
for (std::list<EventMessage*>::iterator it =
message_list.begin();
it != message_list.end(); ++it) {
SCOPED_TRACE("Message " + base::IntToString(index));
EventMessage* message = *it;
// Partial update stream.
EXPECT_EQ(message->event_size(), 1);
EXPECT_TRUE(message->event(0).has_key());
EXPECT_TRUE(message->has_key_event());
// TODO(sergeyu): Don't use index here. Instead store the expected values
// in an array.
EXPECT_EQ(kTestKey + index, message->event(0).key().keycode());
EXPECT_EQ((index % 2) != 0, message->event(0).key().pressed());
EXPECT_EQ(kTestKey + index, message->key_event().keycode());
EXPECT_EQ((index % 2) != 0, message->key_event().pressed());
++index;
}
STLDeleteElements(&message_list);
......
......@@ -18,12 +18,16 @@ static const int kReadBufferSize = 4096;
MessageReader::MessageReader()
: socket_(NULL),
message_loop_(NULL),
read_pending_(false),
pending_messages_(0),
closed_(false),
ALLOW_THIS_IN_INITIALIZER_LIST(
read_callback_(this, &MessageReader::OnRead)) {
}
MessageReader::~MessageReader() {
CHECK_EQ(pending_messages_, 0);
}
void MessageReader::Init(net::Socket* socket,
......@@ -31,21 +35,27 @@ void MessageReader::Init(net::Socket* socket,
message_received_callback_.reset(callback);
DCHECK(socket);
socket_ = socket;
message_loop_ = MessageLoop::current();
DoRead();
}
void MessageReader::DoRead() {
while (!closed_) {
DCHECK(!read_pending_);
// Don't try to read again if there is another read pending or we
// have messages that we haven't finished processing yet.
while (!closed_ && !read_pending_ && pending_messages_ == 0) {
read_buffer_ = new net::IOBuffer(kReadBufferSize);
int result = socket_->Read(
read_buffer_, kReadBufferSize, &read_callback_);
HandleReadResult(result);
if (result < 0)
break;
}
}
void MessageReader::OnRead(int result) {
DCHECK(read_pending_);
read_pending_ = false;
if (!closed_) {
HandleReadResult(result);
DoRead();
......@@ -53,12 +63,17 @@ void MessageReader::OnRead(int result) {
}
void MessageReader::HandleReadResult(int result) {
if (closed_)
return;
if (result > 0) {
OnDataReceived(read_buffer_, result);
} else {
if (result == net::ERR_CONNECTION_CLOSED) {
closed_ = true;
} else if (result != net::ERR_IO_PENDING) {
} else if (result == net::ERR_IO_PENDING) {
read_pending_ = true;
} else {
LOG(ERROR) << "Read() returned error " << result;
}
}
......@@ -67,14 +82,42 @@ void MessageReader::HandleReadResult(int result) {
void MessageReader::OnDataReceived(net::IOBuffer* data, int data_size) {
message_decoder_.AddData(data, data_size);
// Get list of all new messages first, and then call the callback
// for all of them.
std::vector<CompoundBuffer*> new_messages;
while (true) {
CompoundBuffer buffer;
if (!message_decoder_.GetNextMessage(&buffer))
CompoundBuffer* buffer = message_decoder_.GetNextMessage();
if (!buffer)
break;
new_messages.push_back(buffer);
}
pending_messages_ += new_messages.size();
message_received_callback_->Run(&buffer);
for (std::vector<CompoundBuffer*>::iterator it = new_messages.begin();
it != new_messages.end(); ++it) {
message_received_callback_->Run(*it, NewRunnableMethod(
this, &MessageReader::OnMessageDone, *it));
}
}
void MessageReader::OnMessageDone(CompoundBuffer* message) {
delete message;
ProcessDoneEvent();
}
void MessageReader::ProcessDoneEvent() {
if (MessageLoop::current() != message_loop_) {
message_loop_->PostTask(FROM_HERE, NewRunnableMethod(
this, &MessageReader::ProcessDoneEvent));
return;
}
pending_messages_--;
DCHECK_GE(pending_messages_, 0);
DoRead(); // Start next read if neccessary.
}
} // namespace protocol
} // namespace remoting
......@@ -13,6 +13,8 @@
#include "remoting/base/compound_buffer.h"
#include "remoting/protocol/message_decoder.h"
class MessageLoop;
namespace net {
class IOBuffer;
class Socket;
......@@ -22,10 +24,23 @@ namespace remoting {
namespace protocol {
// MessageReader reads data from the socket asynchronously and calls
// callback for each message it receives
class MessageReader {
// callback for each message it receives. It stops calling the
// callback as soon as the socket is closed, so the socket should
// always be closed before the callback handler is destroyed.
//
// In order to throttle the stream, MessageReader doesn't try to read
// new data from the socket until all previously received messages are
// processed by the receiver (|done_task| is called for each message).
// It is still possible that the MessageReceivedCallback is called
// twice (so that there is more than one outstanding message),
// e.g. when we the sender sends multiple messages in one TCP packet.
class MessageReader : public base::RefCountedThreadSafe<MessageReader> {
public:
typedef Callback1<CompoundBuffer*>::Type MessageReceivedCallback;
// The callback is given ownership of the second argument
// (|done_task|). The buffer (first argument) is owned by
// MessageReader and is freed when the task specified by the second
// argument is called.
typedef Callback2<CompoundBuffer*, Task*>::Type MessageReceivedCallback;
MessageReader();
virtual ~MessageReader();
......@@ -39,9 +54,23 @@ class MessageReader {
void OnRead(int result);
void HandleReadResult(int result);
void OnDataReceived(net::IOBuffer* data, int data_size);
void OnMessageDone(CompoundBuffer* message);
void ProcessDoneEvent();
net::Socket* socket_;
// The network message loop this object runs on.
MessageLoop* message_loop_;
// Set to true, when we have a socket read pending, and expecting
// OnRead() to be called when new data is received.
bool read_pending_;
// Number of messages that we received, but haven't finished
// processing yet, i.e. |done_task| hasn't been called for these
// messages.
int pending_messages_;
bool closed_;
scoped_refptr<net::IOBuffer> read_buffer_;
net::CompletionCallbackImpl<MessageReader> read_callback_;
......@@ -52,33 +81,46 @@ class MessageReader {
scoped_ptr<MessageReceivedCallback> message_received_callback_;
};
// Version of MessageReader for protocol buffer messages, that parses
// each incoming message.
template <class T>
class ProtobufMessageReader {
public:
typedef typename Callback1<T*>::Type MessageReceivedCallback;
typedef typename Callback2<T*, Task*>::Type MessageReceivedCallback;
ProtobufMessageReader() { };
~ProtobufMessageReader() { };
void Init(net::Socket* socket, MessageReceivedCallback* callback) {
message_received_callback_.reset(callback);
message_reader_.Init(
message_reader_ = new MessageReader();
message_reader_->Init(
socket, NewCallback(this, &ProtobufMessageReader<T>::OnNewData));
}
private:
void OnNewData(CompoundBuffer* buffer) {
void OnNewData(CompoundBuffer* buffer, Task* done_task) {
T* message = new T();
CompoundBufferInputStream stream(buffer);
bool ret = message->ParseFromZeroCopyStream(&stream);
if (!ret) {
LOG(WARNING) << "Received message that is not a valid protocol buffer.";
delete message;
} else {
message_received_callback_->Run(message);
DCHECK_EQ(stream.position(), buffer->total_bytes());
message_received_callback_->Run(
message, NewRunnableFunction(
&ProtobufMessageReader<T>::OnDone, message, done_task));
}
}
MessageReader message_reader_;
static void OnDone(T* message, Task* done_task) {
delete message;
done_task->Run();
delete done_task;
}
scoped_refptr<MessageReader> message_reader_;
scoped_ptr<MessageReceivedCallback> message_received_callback_;
};
......
// Copyright (c) 2010 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include <string>
#include "base/message_loop.h"
#include "net/socket/socket.h"
#include "remoting/protocol/fake_session.h"
#include "remoting/protocol/message_reader.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "third_party/libjingle/source/talk/base/byteorder.h"
using testing::_;
using testing::DoAll;
using testing::Mock;
using testing::SaveArg;
namespace remoting {
namespace protocol {
namespace {
const char kTestMessage1[] = "Message1";
const char kTestMessage2[] = "Message2";
ACTION(CallDoneTask) {
arg1->Run();
delete arg1;
}
}
class MockMessageReceivedCallback {
public:
MOCK_METHOD2(OnMessage, void(CompoundBuffer*, Task*));
};
class MessageReaderTest : public testing::Test {
protected:
virtual void SetUp() {
reader_ = new MessageReader();
}
void InitReader() {
reader_->Init(&socket_, NewCallback(
&callback_, &MockMessageReceivedCallback::OnMessage));
}
void AddMessage(const std::string& message) {
std::string data = std::string(4, ' ') + message;
talk_base::SetBE32(const_cast<char*>(data.data()), message.size());
socket_.AppendInputData(data.data(), data.size());
}
bool CompareResult(CompoundBuffer* buffer, const std::string& expected) {
std::string result(buffer->total_bytes(), ' ');
buffer->CopyTo(const_cast<char*>(result.data()), result.size());
return result == expected;
}
// MessageLoop must be first here, so that is is destroyed the last.
MessageLoop message_loop_;
scoped_refptr<MessageReader> reader_;
FakeSocket socket_;
MockMessageReceivedCallback callback_;
};
// Receive one message and process it with delay
TEST_F(MessageReaderTest, OneMessage_Delay) {
CompoundBuffer* buffer;
Task* done_task;
AddMessage(kTestMessage1);
EXPECT_CALL(callback_, OnMessage(_, _))
.Times(1)
.WillOnce(DoAll(SaveArg<0>(&buffer),
SaveArg<1>(&done_task)));
InitReader();
Mock::VerifyAndClearExpectations(&callback_);
Mock::VerifyAndClearExpectations(&socket_);
EXPECT_TRUE(CompareResult(buffer, kTestMessage1));
// Verify that the reader starts reading again only after we've
// finished processing the previous message.
EXPECT_FALSE(socket_.read_pending());
done_task->Run();
EXPECT_TRUE(socket_.read_pending());
}
// Receive one message and process it instantly.
TEST_F(MessageReaderTest, OneMessage_Instant) {
AddMessage(kTestMessage1);
EXPECT_CALL(callback_, OnMessage(_, _))
.Times(1)
.WillOnce(CallDoneTask());
InitReader();
EXPECT_TRUE(socket_.read_pending());
}
// Receive two messages in one packet.
TEST_F(MessageReaderTest, TwoMessages_Together) {
CompoundBuffer* buffer1;
Task* done_task1;
CompoundBuffer* buffer2;
Task* done_task2;
AddMessage(kTestMessage1);
AddMessage(kTestMessage2);
EXPECT_CALL(callback_, OnMessage(_, _))
.Times(2)
.WillOnce(DoAll(SaveArg<0>(&buffer1),
SaveArg<1>(&done_task1)))
.WillOnce(DoAll(SaveArg<0>(&buffer2),
SaveArg<1>(&done_task2)));
InitReader();
Mock::VerifyAndClearExpectations(&callback_);
Mock::VerifyAndClearExpectations(&socket_);
EXPECT_TRUE(CompareResult(buffer1, kTestMessage1));
EXPECT_TRUE(CompareResult(buffer2, kTestMessage2));
// Verify that the reader starts reading again only after we've
// finished processing the previous message.
EXPECT_FALSE(socket_.read_pending());
done_task1->Run();
EXPECT_FALSE(socket_.read_pending());
done_task2->Run();
EXPECT_TRUE(socket_.read_pending());
}
// Receive two messages in one packet, and process the first one
// instantly.
TEST_F(MessageReaderTest, TwoMessages_Instant) {
CompoundBuffer* buffer2;
Task* done_task2;
AddMessage(kTestMessage1);
AddMessage(kTestMessage2);
EXPECT_CALL(callback_, OnMessage(_, _))
.Times(2)
.WillOnce(CallDoneTask())
.WillOnce(DoAll(SaveArg<0>(&buffer2),
SaveArg<1>(&done_task2)));
InitReader();
Mock::VerifyAndClearExpectations(&callback_);
Mock::VerifyAndClearExpectations(&socket_);
EXPECT_TRUE(CompareResult(buffer2, kTestMessage2));
// Verify that the reader starts reading again only after we've
// finished processing the second message.
EXPECT_FALSE(socket_.read_pending());
done_task2->Run();
EXPECT_TRUE(socket_.read_pending());
}
// Receive two messages in one packet, and process both of them
// instantly.
TEST_F(MessageReaderTest, TwoMessages_Instant2) {
AddMessage(kTestMessage1);
AddMessage(kTestMessage2);
EXPECT_CALL(callback_, OnMessage(_, _))
.Times(2)
.WillOnce(CallDoneTask())
.WillOnce(CallDoneTask());
InitReader();
EXPECT_TRUE(socket_.read_pending());
}
// Receive two messages in separate packets.
TEST_F(MessageReaderTest, TwoMessages_Separately) {
CompoundBuffer* buffer;
Task* done_task;
AddMessage(kTestMessage1);
EXPECT_CALL(callback_, OnMessage(_, _))
.Times(1)
.WillOnce(DoAll(SaveArg<0>(&buffer),
SaveArg<1>(&done_task)));
InitReader();
Mock::VerifyAndClearExpectations(&callback_);
Mock::VerifyAndClearExpectations(&socket_);
EXPECT_TRUE(CompareResult(buffer, kTestMessage1));
// Verify that the reader starts reading again only after we've
// finished processing the previous message.
EXPECT_FALSE(socket_.read_pending());
done_task->Run();
EXPECT_TRUE(socket_.read_pending());
// Write another message and verify that we receive it.
EXPECT_CALL(callback_, OnMessage(_, _))
.Times(1)
.WillOnce(DoAll(SaveArg<0>(&buffer),
SaveArg<1>(&done_task)));
AddMessage(kTestMessage2);
EXPECT_TRUE(CompareResult(buffer, kTestMessage2));
// Verify that the reader starts reading again only after we've
// finished processing the previous message.
EXPECT_FALSE(socket_.read_pending());
done_task->Run();
EXPECT_TRUE(socket_.read_pending());
}
} // namespace protocol
} // namespace remoting
......@@ -25,8 +25,8 @@ void ProtobufVideoReader::Init(protocol::Session* session,
video_stub_ = video_stub;
}
void ProtobufVideoReader::OnNewData(VideoPacket* packet) {
video_stub_->ProcessVideoPacket(packet, new DeleteTask<VideoPacket>(packet));
void ProtobufVideoReader::OnNewData(VideoPacket* packet, Task* done_task) {
video_stub_->ProcessVideoPacket(packet, done_task);
}
} // namespace protocol
......
......@@ -23,7 +23,7 @@ class ProtobufVideoReader : public VideoReader {
virtual void Init(protocol::Session* session, VideoStub* video_stub);
private:
void OnNewData(VideoPacket* packet);
void OnNewData(VideoPacket* packet, Task* done_task);
VideoPacketFormat::Encoding encoding_;
......
// Copyright (c) 2010 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
// This is a wrapper class to help ref-counting a protobuf message.
// This file should only be inclued on host_message_dispatcher.cc and
// client_message_dispatche.cc.
// A single protobuf can contain multiple messages that will be handled by
// different message handlers. We use this wrapper to ensure that the
// protobuf is only deleted after all the handlers have finished executing.
#ifndef REMOTING_PROTOCOL_REF_COUNTED_MESSAGE_H_
#define REMOTING_PROTOCOL_REF_COUNTED_MESSAGE_H_
#include "base/ref_counted.h"
#include "base/task.h"
namespace remoting {
namespace protocol {
template <typename T>
class RefCountedMessage : public base::RefCounted<RefCountedMessage<T> > {
public:
RefCountedMessage(T* message) : message_(message) { }
T* message() { return message_.get(); }
private:
scoped_ptr<T> message_;
};
// Dummy methods to destroy messages.
template <class T>
static void DeleteMessage(scoped_refptr<T> message) { }
template <class T>
static Task* NewDeleteTask(scoped_refptr<T> message) {
return NewRunnableFunction(&DeleteMessage<T>, message);
}
} // namespace protocol
} // namespace remoting
#endif // REMOTING_PROTOCOL_REF_COUNTED_MESSAGE_H_
......@@ -486,6 +486,7 @@
'protocol/fake_session.h',
'protocol/jingle_session_unittest.cc',
'protocol/message_decoder_unittest.cc',
'protocol/message_reader_unittest.cc',
'protocol/mock_objects.h',
'protocol/rtp_video_reader_unittest.cc',
'protocol/rtp_video_writer_unittest.cc',
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment