Commit f31a5ce1 authored by kmarshall's avatar kmarshall Committed by Commit bot

Create new class "CastTransport", which encapsulates the message read and write event loops.

They currently target a revised stub interface for CastSocket, but not a full implementation.
That change is witheld in a separate Git client so this CL can remain at a manageable size.

Thanks!
BUG=396345
R=mfoltz@chromium.org
CC=wez@chromium.org

Review URL: https://codereview.chromium.org/555283002

Cr-Commit-Position: refs/heads/master@{#296993}
parent b6fc7567
......@@ -91,6 +91,8 @@ source_set("browser") {
"api/cast_channel/cast_message_util.h",
"api/cast_channel/cast_socket.cc",
"api/cast_channel/cast_socket.h",
"api/cast_channel/cast_transport.cc",
"api/cast_channel/cast_transport.h",
"api/cast_channel/logger.cc",
"api/cast_channel/logger.h",
"api/cast_channel/logger_util.cc",
......
......@@ -131,27 +131,14 @@ CastChannelAPI::GetFactoryInstance() {
return g_factory.Pointer();
}
scoped_ptr<CastSocket> CastChannelAPI::CreateCastSocket(
const std::string& extension_id, const net::IPEndPoint& ip_endpoint,
ChannelAuthType channel_auth, const base::TimeDelta& timeout) {
if (socket_for_test_.get()) {
return socket_for_test_.Pass();
} else {
return scoped_ptr<CastSocket>(
new CastSocket(extension_id,
ip_endpoint,
channel_auth,
this,
ExtensionsBrowserClient::Get()->GetNetLog(),
timeout,
logger_));
}
}
void CastChannelAPI::SetSocketForTest(scoped_ptr<CastSocket> socket_for_test) {
socket_for_test_ = socket_for_test.Pass();
}
scoped_ptr<cast_channel::CastSocket> CastChannelAPI::GetSocketForTest() {
return socket_for_test_.Pass();
}
void CastChannelAPI::OnError(const CastSocket* socket,
cast_channel::ChannelError error_state,
const cast_channel::LastErrors& last_errors) {
......@@ -367,13 +354,19 @@ bool CastChannelOpenFunction::Prepare() {
void CastChannelOpenFunction::AsyncWorkStart() {
DCHECK(api_);
DCHECK(ip_endpoint_.get());
scoped_ptr<CastSocket> socket = api_->CreateCastSocket(
extension_->id(),
*ip_endpoint_,
channel_auth_,
base::TimeDelta::FromMilliseconds(connect_info_->timeout.get()
? *connect_info_->timeout
: kDefaultConnectTimeoutMillis));
scoped_ptr<CastSocket> socket = api_->GetSocketForTest();
if (!socket.get()) {
socket.reset(new CastSocket(
extension_->id(),
*ip_endpoint_,
channel_auth_,
api_,
ExtensionsBrowserClient::Get()->GetNetLog(),
base::TimeDelta::FromMilliseconds(connect_info_->timeout.get()
? *connect_info_->timeout
: kDefaultConnectTimeoutMillis),
api_->GetLogger()));
}
new_channel_id_ = AddSocket(socket.release());
CastSocket* new_socket = GetSocket(new_channel_id_);
api_->GetLogger()->LogNewSocketEvent(*new_socket);
......@@ -424,11 +417,15 @@ bool CastChannelSendFunction::Prepare() {
}
void CastChannelSendFunction::AsyncWorkStart() {
CastSocket* socket = GetSocketOrCompleteWithError(
params_->channel.channel_id);
if (socket)
socket->SendMessage(params_->message,
base::Bind(&CastChannelSendFunction::OnSend, this));
CastSocket* socket = GetSocket(params_->channel.channel_id);
if (!socket) {
SetResultFromError(params_->channel.channel_id,
cast_channel::CHANNEL_ERROR_INVALID_CHANNEL_ID);
AsyncWorkCompleted();
return;
}
socket->SendMessage(params_->message,
base::Bind(&CastChannelSendFunction::OnSend, this));
}
void CastChannelSendFunction::OnSend(int result) {
......@@ -455,10 +452,14 @@ bool CastChannelCloseFunction::Prepare() {
}
void CastChannelCloseFunction::AsyncWorkStart() {
CastSocket* socket = GetSocketOrCompleteWithError(
params_->channel.channel_id);
if (socket)
CastSocket* socket = GetSocket(params_->channel.channel_id);
if (!socket) {
SetResultFromError(params_->channel.channel_id,
cast_channel::CHANNEL_ERROR_INVALID_CHANNEL_ID);
AsyncWorkCompleted();
} else {
socket->Close(base::Bind(&CastChannelCloseFunction::OnClose, this));
}
}
void CastChannelCloseFunction::OnClose(int result) {
......
......@@ -48,17 +48,10 @@ class CastChannelAPI : public BrowserContextKeyedAPI,
// BrowserContextKeyedAPI implementation.
static BrowserContextKeyedAPIFactory<CastChannelAPI>* GetFactoryInstance();
// Returns a new CastSocket that connects to |ip_endpoint| with authentication
// |channel_auth| and is to be owned by |extension_id|.
scoped_ptr<cast_channel::CastSocket> CreateCastSocket(
const std::string& extension_id,
const net::IPEndPoint& ip_endpoint,
cast_channel::ChannelAuthType channel_auth,
const base::TimeDelta& timeout);
// Returns a pointer to the Logger member variable.
// TODO(imcheng): Consider whether it is possible for this class to own the
// CastSockets and make this class the sole owner of Logger. Alternatively,
// CastSockets and make this class the sole owner of Logger.
// Alternatively,
// consider making Logger not ref-counted by passing a weak
// reference of Logger to the CastSockets instead.
scoped_refptr<cast_channel::Logger> GetLogger();
......@@ -67,6 +60,10 @@ class CastChannelAPI : public BrowserContextKeyedAPI,
// testing.
void SetSocketForTest(scoped_ptr<cast_channel::CastSocket> socket_for_test);
// Returns a test CastSocket instance, if it is defined.
// Otherwise returns a scoped_ptr with a NULL ptr value.
scoped_ptr<cast_channel::CastSocket> GetSocketForTest();
private:
friend class BrowserContextKeyedAPIFactory<CastChannelAPI>;
friend class ::CastChannelAPITest;
......
......@@ -62,6 +62,17 @@ bool MessageInfoToCastMessage(const MessageInfo& message,
return message_proto->IsInitialized();
}
bool IsCastMessageValid(const CastMessage& message_proto) {
if (message_proto.namespace_().empty() || message_proto.source_id().empty() ||
message_proto.destination_id().empty()) {
return false;
}
return (message_proto.payload_type() == CastMessage_PayloadType_STRING &&
message_proto.has_payload_utf8()) ||
(message_proto.payload_type() == CastMessage_PayloadType_BINARY &&
message_proto.has_payload_binary());
}
bool CastMessageToMessageInfo(const CastMessage& message_proto,
MessageInfo* message) {
DCHECK(message);
......
......@@ -19,6 +19,9 @@ struct MessageInfo;
bool MessageInfoToCastMessage(const MessageInfo& message,
CastMessage* message_proto);
// Checks if the contents of |message_proto| are semantically valid.
bool IsCastMessageValid(const CastMessage& message_proto);
// Fills |message| from |message_proto| and returns true on success.
bool CastMessageToMessageInfo(const CastMessage& message_proto,
MessageInfo* message);
......
......@@ -209,9 +209,8 @@ void CastSocket::Connect(const net::CompletionCallback& callback) {
void CastSocket::PostTaskToStartConnectLoop(int result) {
DCHECK(CalledOnValidThread());
DCHECK(connect_loop_callback_.IsCancelled());
connect_loop_callback_.Reset(base::Bind(&CastSocket::DoConnectLoop,
base::Unretained(this),
result));
connect_loop_callback_.Reset(
base::Bind(&CastSocket::DoConnectLoop, base::Unretained(this), result));
base::MessageLoop::current()->PostTask(FROM_HERE,
connect_loop_callback_.callback());
}
......@@ -647,8 +646,8 @@ int CastSocket::DoWriteError(int result) {
void CastSocket::PostTaskToStartReadLoop() {
DCHECK(CalledOnValidThread());
DCHECK(read_loop_callback_.IsCancelled());
read_loop_callback_.Reset(base::Bind(&CastSocket::StartReadLoop,
base::Unretained(this)));
read_loop_callback_.Reset(
base::Bind(&CastSocket::StartReadLoop, base::Unretained(this)));
base::MessageLoop::current()->PostTask(FROM_HERE,
read_loop_callback_.callback());
}
......@@ -862,7 +861,8 @@ void CastSocket::SetWriteState(proto::WriteState write_state) {
}
CastSocket::WriteRequest::WriteRequest(const net::CompletionCallback& callback)
: callback(callback) { }
: callback(callback) {
}
bool CastSocket::WriteRequest::SetContent(const CastMessage& message_proto) {
DCHECK(!io_buffer.get());
......@@ -876,7 +876,9 @@ bool CastSocket::WriteRequest::SetContent(const CastMessage& message_proto) {
return true;
}
CastSocket::WriteRequest::~WriteRequest() { }
CastSocket::WriteRequest::~WriteRequest() {
}
} // namespace cast_channel
} // namespace core_api
} // namespace extensions
......
......@@ -47,11 +47,12 @@ class MessageFramer;
//
// NOTE: Not called "CastChannel" to reduce confusion with the generated API
// code.
// TODO(kmarshall): Inherit from CastSocket and rename to CastSocketImpl.
class CastSocket : public ApiResource {
public:
// Object to be informed of incoming messages and errors. The CastSocket that
// owns the delegate must not be deleted by it, only by the ApiResourceManager
// or in the callback to Close().
// Object to be informed of incoming messages and errors. The CastSocket
// that owns the delegate must not be deleted by it, only by the
// ApiResourceManager or in the callback to Close().
class Delegate {
public:
// An error occurred on the channel. |last_errors| contains the last errors
......
......@@ -80,8 +80,8 @@ class MockCastSocketDelegate : public CastSocket::Delegate {
void(const CastSocket* socket,
ChannelError error,
const LastErrors& last_errors));
MOCK_METHOD2(OnMessage, void(const CastSocket* socket,
const MessageInfo& message));
MOCK_METHOD2(OnMessage,
void(const CastSocket* socket, const MessageInfo& message));
};
class MockTCPSocket : public net::TCPClientSocket {
......@@ -204,8 +204,7 @@ class TestCastSocket : public CastSocket {
return msg.length() - MessageFramer::MessageHeader::header_size();
}
virtual ~TestCastSocket() {
}
virtual ~TestCastSocket() {}
// Helpers to set mock results for various operations.
void SetupTcp1Connect(net::IoMode mode, int result) {
......
// Copyright 2014 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "extensions/browser/api/cast_channel/cast_transport.h"
#include <string>
#include "base/bind.h"
#include "base/format_macros.h"
#include "base/numerics/safe_conversions.h"
#include "base/strings/stringprintf.h"
#include "extensions/browser/api/cast_channel/cast_framer.h"
#include "extensions/browser/api/cast_channel/cast_message_util.h"
#include "extensions/browser/api/cast_channel/logger.h"
#include "extensions/browser/api/cast_channel/logger_util.h"
#include "extensions/common/api/cast_channel/cast_channel.pb.h"
#include "net/base/net_errors.h"
#define VLOG_WITH_CONNECTION(level) \
VLOG(level) << "[" << socket_->ip_endpoint().ToString() \
<< ", auth=" << socket_->channel_auth() << "] "
namespace extensions {
namespace core_api {
namespace cast_channel {
CastTransport::CastTransport(CastSocketInterface* socket,
Delegate* read_delegate,
scoped_refptr<Logger> logger)
: socket_(socket),
read_delegate_(read_delegate),
write_state_(WRITE_STATE_NONE),
read_state_(READ_STATE_NONE),
logger_(logger) {
DCHECK(socket);
DCHECK(read_delegate);
// Buffer is reused across messages to minimize unnecessary buffer
// [re]allocations.
read_buffer_ = new net::GrowableIOBuffer();
read_buffer_->SetCapacity(MessageFramer::MessageHeader::max_message_size());
framer_.reset(new MessageFramer(read_buffer_));
}
CastTransport::~CastTransport() {
DCHECK(thread_checker_.CalledOnValidThread());
FlushWriteQueue();
}
// static
proto::ReadState CastTransport::ReadStateToProto(
CastTransport::ReadState state) {
switch (state) {
case CastTransport::READ_STATE_NONE:
return proto::READ_STATE_NONE;
case CastTransport::READ_STATE_READ:
return proto::READ_STATE_READ;
case CastTransport::READ_STATE_READ_COMPLETE:
return proto::READ_STATE_READ_COMPLETE;
case CastTransport::READ_STATE_DO_CALLBACK:
return proto::READ_STATE_DO_CALLBACK;
case CastTransport::READ_STATE_ERROR:
return proto::READ_STATE_ERROR;
default:
NOTREACHED();
return proto::READ_STATE_NONE;
}
}
// static
proto::WriteState CastTransport::WriteStateToProto(
CastTransport::WriteState state) {
switch (state) {
case CastTransport::WRITE_STATE_NONE:
return proto::WRITE_STATE_NONE;
case CastTransport::WRITE_STATE_WRITE:
return proto::WRITE_STATE_WRITE;
case CastTransport::WRITE_STATE_WRITE_COMPLETE:
return proto::WRITE_STATE_WRITE_COMPLETE;
case CastTransport::WRITE_STATE_DO_CALLBACK:
return proto::WRITE_STATE_DO_CALLBACK;
case CastTransport::WRITE_STATE_ERROR:
return proto::WRITE_STATE_ERROR;
default:
NOTREACHED();
return proto::WRITE_STATE_NONE;
}
}
// static
proto::ErrorState CastTransport::ErrorStateToProto(ChannelError state) {
switch (state) {
case CHANNEL_ERROR_NONE:
return proto::CHANNEL_ERROR_NONE;
case CHANNEL_ERROR_CHANNEL_NOT_OPEN:
return proto::CHANNEL_ERROR_CHANNEL_NOT_OPEN;
case CHANNEL_ERROR_AUTHENTICATION_ERROR:
return proto::CHANNEL_ERROR_AUTHENTICATION_ERROR;
case CHANNEL_ERROR_CONNECT_ERROR:
return proto::CHANNEL_ERROR_CONNECT_ERROR;
case CHANNEL_ERROR_SOCKET_ERROR:
return proto::CHANNEL_ERROR_SOCKET_ERROR;
case CHANNEL_ERROR_TRANSPORT_ERROR:
return proto::CHANNEL_ERROR_TRANSPORT_ERROR;
case CHANNEL_ERROR_INVALID_MESSAGE:
return proto::CHANNEL_ERROR_INVALID_MESSAGE;
case CHANNEL_ERROR_INVALID_CHANNEL_ID:
return proto::CHANNEL_ERROR_INVALID_CHANNEL_ID;
case CHANNEL_ERROR_CONNECT_TIMEOUT:
return proto::CHANNEL_ERROR_CONNECT_TIMEOUT;
case CHANNEL_ERROR_UNKNOWN:
return proto::CHANNEL_ERROR_UNKNOWN;
default:
NOTREACHED();
return proto::CHANNEL_ERROR_NONE;
}
}
void CastTransport::FlushWriteQueue() {
for (; !write_queue_.empty(); write_queue_.pop()) {
net::CompletionCallback& callback = write_queue_.front().callback;
callback.Run(net::ERR_FAILED);
callback.Reset();
}
}
void CastTransport::SendMessage(const CastMessage& message,
const net::CompletionCallback& callback) {
DCHECK(thread_checker_.CalledOnValidThread());
std::string serialized_message;
if (!MessageFramer::Serialize(message, &serialized_message)) {
logger_->LogSocketEventForMessage(socket_->id(),
proto::SEND_MESSAGE_FAILED,
message.namespace_(),
"Error when serializing message.");
callback.Run(net::ERR_FAILED);
return;
}
WriteRequest write_request(
message.namespace_(), serialized_message, callback);
write_queue_.push(write_request);
logger_->LogSocketEventForMessage(
socket_->id(),
proto::MESSAGE_ENQUEUED,
message.namespace_(),
base::StringPrintf("Queue size: %" PRIuS, write_queue_.size()));
if (write_state_ == WRITE_STATE_NONE) {
SetWriteState(WRITE_STATE_WRITE);
OnWriteResult(net::OK);
}
}
CastTransport::WriteRequest::WriteRequest(
const std::string& namespace_,
const std::string& payload,
const net::CompletionCallback& callback)
: message_namespace(namespace_), callback(callback) {
VLOG(2) << "WriteRequest size: " << payload.size();
io_buffer = new net::DrainableIOBuffer(new net::StringIOBuffer(payload),
payload.size());
}
CastTransport::WriteRequest::~WriteRequest() {
}
void CastTransport::SetReadState(ReadState read_state) {
if (read_state_ != read_state) {
read_state_ = read_state;
logger_->LogSocketReadState(socket_->id(), ReadStateToProto(read_state_));
}
}
void CastTransport::SetWriteState(WriteState write_state) {
if (write_state_ != write_state) {
write_state_ = write_state;
logger_->LogSocketWriteState(socket_->id(),
WriteStateToProto(write_state_));
}
}
void CastTransport::SetErrorState(ChannelError error_state) {
if (error_state_ != error_state) {
error_state_ = error_state;
logger_->LogSocketErrorState(socket_->id(),
ErrorStateToProto(error_state_));
}
}
void CastTransport::OnWriteResult(int result) {
DCHECK(thread_checker_.CalledOnValidThread());
VLOG_WITH_CONNECTION(1) << "OnWriteResult queue size: "
<< write_queue_.size();
if (write_queue_.empty()) {
SetWriteState(WRITE_STATE_NONE);
return;
}
// Network operations can either finish synchronously or asynchronously.
// This method executes the state machine transitions in a loop so that
// write state transitions happen even when network operations finish
// synchronously.
int rv = result;
do {
WriteState state = write_state_;
write_state_ = WRITE_STATE_NONE;
switch (state) {
case WRITE_STATE_WRITE:
rv = DoWrite();
break;
case WRITE_STATE_WRITE_COMPLETE:
rv = DoWriteComplete(rv);
break;
case WRITE_STATE_DO_CALLBACK:
rv = DoWriteCallback();
break;
case WRITE_STATE_ERROR:
rv = DoWriteError(rv);
break;
default:
NOTREACHED() << "BUG in write flow. Unknown state: " << state;
break;
}
} while (!write_queue_.empty() && rv != net::ERR_IO_PENDING &&
write_state_ != WRITE_STATE_NONE);
// No state change occurred in do-while loop above. This means state has
// transitioned to NONE.
if (write_state_ == WRITE_STATE_NONE) {
logger_->LogSocketWriteState(socket_->id(),
WriteStateToProto(write_state_));
}
// If write loop is done because the queue is empty then set write
// state to NONE
if (write_queue_.empty()) {
SetWriteState(WRITE_STATE_NONE);
}
// Write loop is done - if the result is ERR_FAILED then close with error.
if (rv == net::ERR_FAILED) {
DCHECK_NE(CHANNEL_ERROR_NONE, error_state_);
socket_->CloseWithError(error_state_);
FlushWriteQueue();
}
}
int CastTransport::DoWrite() {
DCHECK(!write_queue_.empty());
WriteRequest& request = write_queue_.front();
VLOG_WITH_CONNECTION(2) << "WriteData byte_count = "
<< request.io_buffer->size() << " bytes_written "
<< request.io_buffer->BytesConsumed();
SetWriteState(WRITE_STATE_WRITE_COMPLETE);
int rv = socket_->Write(
request.io_buffer.get(),
request.io_buffer->BytesRemaining(),
base::Bind(&CastTransport::OnWriteResult, base::Unretained(this)));
logger_->LogSocketEventWithRv(socket_->id(), proto::SOCKET_WRITE, rv);
return rv;
}
int CastTransport::DoWriteComplete(int result) {
VLOG_WITH_CONNECTION(2) << "DoWriteComplete result=" << result;
DCHECK(!write_queue_.empty());
if (result <= 0) { // NOTE that 0 also indicates an error
SetErrorState(CHANNEL_ERROR_TRANSPORT_ERROR);
SetWriteState(WRITE_STATE_ERROR);
return result == 0 ? net::ERR_FAILED : result;
}
// Some bytes were successfully written
WriteRequest& request = write_queue_.front();
scoped_refptr<net::DrainableIOBuffer> io_buffer = request.io_buffer;
io_buffer->DidConsume(result);
if (io_buffer->BytesRemaining() == 0) { // Message fully sent
SetWriteState(WRITE_STATE_DO_CALLBACK);
} else {
SetWriteState(WRITE_STATE_WRITE);
}
return net::OK;
}
int CastTransport::DoWriteCallback() {
VLOG_WITH_CONNECTION(2) << "DoWriteCallback";
DCHECK(!write_queue_.empty());
SetWriteState(WRITE_STATE_WRITE);
WriteRequest& request = write_queue_.front();
int bytes_consumed = request.io_buffer->BytesConsumed();
logger_->LogSocketEventForMessage(
socket_->id(),
proto::MESSAGE_WRITTEN,
request.message_namespace,
base::StringPrintf("Bytes: %d", bytes_consumed));
request.callback.Run(net::OK);
write_queue_.pop();
return net::OK;
}
int CastTransport::DoWriteError(int result) {
VLOG_WITH_CONNECTION(2) << "DoWriteError result=" << result;
DCHECK_NE(CHANNEL_ERROR_NONE, error_state_);
DCHECK_LT(result, 0);
return net::ERR_FAILED;
}
void CastTransport::StartReadLoop() {
DCHECK(thread_checker_.CalledOnValidThread());
// Read loop would have already been started if read state is not NONE
if (read_state_ == READ_STATE_NONE) {
SetReadState(READ_STATE_READ);
OnReadResult(net::OK);
}
}
void CastTransport::OnReadResult(int result) {
DCHECK(thread_checker_.CalledOnValidThread());
// Network operations can either finish synchronously or asynchronously.
// This method executes the state machine transitions in a loop so that
// write state transitions happen even when network operations finish
// synchronously.
int rv = result;
do {
ReadState state = read_state_;
read_state_ = READ_STATE_NONE;
switch (state) {
case READ_STATE_READ:
rv = DoRead();
break;
case READ_STATE_READ_COMPLETE:
rv = DoReadComplete(rv);
break;
case READ_STATE_DO_CALLBACK:
rv = DoReadCallback();
break;
case READ_STATE_ERROR:
rv = DoReadError(rv);
DCHECK_EQ(read_state_, READ_STATE_NONE);
break;
default:
NOTREACHED() << "BUG in read flow. Unknown state: " << state;
break;
}
} while (rv != net::ERR_IO_PENDING && read_state_ != READ_STATE_NONE);
// No state change occurred in do-while loop above. This means state has
// transitioned to NONE.
if (read_state_ == READ_STATE_NONE) {
logger_->LogSocketReadState(socket_->id(), ReadStateToProto(read_state_));
}
if (rv == net::ERR_FAILED) {
DCHECK_NE(CHANNEL_ERROR_NONE, error_state_);
socket_->CloseWithError(error_state_);
FlushWriteQueue();
read_delegate_->OnError(
socket_, error_state_, logger_->GetLastErrors(socket_->id()));
}
}
int CastTransport::DoRead() {
VLOG_WITH_CONNECTION(2) << "DoRead";
SetReadState(READ_STATE_READ_COMPLETE);
// Determine how many bytes need to be read.
size_t num_bytes_to_read = framer_->BytesRequested();
// Read up to num_bytes_to_read into |current_read_buffer_|.
int rv = socket_->Read(
read_buffer_.get(),
base::checked_cast<uint32>(num_bytes_to_read),
base::Bind(&CastTransport::OnReadResult, base::Unretained(this)));
return rv;
}
int CastTransport::DoReadComplete(int result) {
VLOG_WITH_CONNECTION(2) << "DoReadComplete result = " << result;
if (result <= 0) {
SetErrorState(CHANNEL_ERROR_TRANSPORT_ERROR);
SetReadState(READ_STATE_ERROR);
return result == 0 ? net::ERR_FAILED : result;
}
size_t message_size;
DCHECK(current_message_.get() == NULL);
current_message_ = framer_->Ingest(result, &message_size, &error_state_);
if (current_message_.get()) {
DCHECK_EQ(error_state_, CHANNEL_ERROR_NONE);
DCHECK_GT(message_size, static_cast<size_t>(0));
logger_->LogSocketEventForMessage(
socket_->id(),
proto::MESSAGE_READ,
current_message_->namespace_(),
base::StringPrintf("Message size: %u",
static_cast<uint32>(message_size)));
SetReadState(READ_STATE_DO_CALLBACK);
} else if (error_state_ != CHANNEL_ERROR_NONE) {
DCHECK(current_message_.get() == NULL);
SetErrorState(CHANNEL_ERROR_INVALID_MESSAGE);
SetReadState(READ_STATE_ERROR);
} else {
DCHECK(current_message_.get() == NULL);
SetReadState(READ_STATE_READ);
}
return net::OK;
}
int CastTransport::DoReadCallback() {
VLOG_WITH_CONNECTION(2) << "DoReadCallback";
SetReadState(READ_STATE_READ);
if (!IsCastMessageValid(*current_message_)) {
SetReadState(READ_STATE_ERROR);
SetErrorState(CHANNEL_ERROR_INVALID_MESSAGE);
return net::ERR_INVALID_RESPONSE;
}
logger_->LogSocketEventForMessage(socket_->id(),
proto::NOTIFY_ON_MESSAGE,
current_message_->namespace_(),
std::string());
read_delegate_->OnMessage(socket_, *current_message_);
current_message_.reset();
return net::OK;
}
int CastTransport::DoReadError(int result) {
VLOG_WITH_CONNECTION(2) << "DoReadError";
DCHECK_NE(CHANNEL_ERROR_NONE, error_state_);
DCHECK_LE(result, 0);
return net::ERR_FAILED;
}
} // namespace cast_channel
} // namespace core_api
} // namespace extensions
// Copyright 2014 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef EXTENSIONS_BROWSER_API_CAST_CHANNEL_CAST_TRANSPORT_H_
#define EXTENSIONS_BROWSER_API_CAST_CHANNEL_CAST_TRANSPORT_H_
#include <queue>
#include <string>
#include "base/memory/ref_counted.h"
#include "base/threading/thread_checker.h"
#include "extensions/browser/api/cast_channel/logger.h"
#include "extensions/common/api/cast_channel.h"
#include "net/base/completion_callback.h"
namespace net {
class DrainableIOBuffer;
class IPEndPoint;
class IOBuffer;
class DrainableIOBuffer;
class GrowableIOBuffer;
} // namespace net
namespace extensions {
namespace core_api {
namespace cast_channel {
class CastMessage;
struct LastErrors;
class Logger;
class MessageFramer;
// TODO(kmarshall): Migrate CastSocket to new interface.
// Redirect references to CastSocket in logger.h to this interface once
// the interface is promoted to cast_socket.h.
class CastSocketInterface {
public:
CastSocketInterface() {}
virtual ~CastSocketInterface() {}
// Writes at least one, and up to |size| bytes to the socket.
// Returns net::ERR_IO_PENDING if the operation will complete
// asynchronously, in which case |callback| will be invoked
// on completion.
// Asynchronous writes are cancleled if the CastSocket is deleted.
// All values <= zero indicate an error.
virtual int Write(net::IOBuffer* buffer,
size_t size,
const net::CompletionCallback& callback) = 0;
// Reads at least one, and up to |size| bytes from the socket.
// Returns net::ERR_IO_PENDING if the operation will complete
// asynchronously, in which case |callback| will be invoked
// on completion.
// All values <= zero indicate an error.
virtual int Read(net::IOBuffer* buf,
int buf_len,
const net::CompletionCallback& callback) = 0;
virtual void CloseWithError(ChannelError error) = 0;
virtual const net::IPEndPoint& ip_endpoint() const = 0;
virtual ChannelAuthType channel_auth() const = 0;
virtual int id() const = 0;
};
// Manager class for reading and writing messages to/from a CastSocket.
class CastTransport {
public:
// Object to be informed of incoming messages and errors.
class Delegate {
public:
// An error occurred on the channel. |last_errors| contains the last errors
// logged for the channel from the implementation.
virtual void OnError(const CastSocketInterface* socket,
ChannelError error_state,
const LastErrors& last_errors) = 0;
// A message was received on the channel.
virtual void OnMessage(const CastSocketInterface* socket,
const CastMessage& message) = 0;
protected:
virtual ~Delegate() {}
};
// Adds a CastMessage read/write layer to a socket.
// Message read events are propagated to the owner via |read_delegate|.
// The CastTransport object should be deleted prior to the
// underlying socket being deleted.
CastTransport(CastSocketInterface* socket,
Delegate* read_delegate,
scoped_refptr<Logger> logger);
virtual ~CastTransport();
// Sends a CastMessage to |socket_|.
// |message|: The message to send.
// |callback|: Callback to be invoked when the write operation has finished.
void SendMessage(const CastMessage& message,
const net::CompletionCallback& callback);
// Starts reading messages from |socket_|.
void StartReadLoop();
private:
// Internal write states.
enum WriteState {
WRITE_STATE_NONE,
WRITE_STATE_WRITE,
WRITE_STATE_WRITE_COMPLETE,
WRITE_STATE_DO_CALLBACK,
WRITE_STATE_ERROR,
};
// Internal read states.
enum ReadState {
READ_STATE_NONE,
READ_STATE_READ,
READ_STATE_READ_COMPLETE,
READ_STATE_DO_CALLBACK,
READ_STATE_ERROR,
};
// Holds a message to be written to the socket. |callback| is invoked when the
// message is fully written or an error occurrs.
struct WriteRequest {
explicit WriteRequest(const std::string& namespace_,
const std::string& payload,
const net::CompletionCallback& callback);
~WriteRequest();
// Namespace of the serialized message.
std::string message_namespace;
// Write completion callback, invoked when the operation has completed or
// failed.
net::CompletionCallback callback;
// Buffer with outgoing data.
scoped_refptr<net::DrainableIOBuffer> io_buffer;
};
static proto::ReadState ReadStateToProto(CastTransport::ReadState state);
static proto::WriteState WriteStateToProto(CastTransport::WriteState state);
static proto::ErrorState ErrorStateToProto(ChannelError state);
// Terminates all in-flight write callbacks with error code ERR_FAILED.
void FlushWriteQueue();
// Main method that performs write flow state transitions.
void OnWriteResult(int result);
// Each of the below Do* method is executed in the corresponding
// write state. For example when write state is WRITE_STATE_WRITE_COMPLETE
// DowriteComplete is called, and so on.
int DoWrite();
int DoWriteComplete(int result);
int DoWriteCallback();
int DoWriteError(int result);
// Main method that performs write flow state transitions.
void OnReadResult(int result);
// Each of the below Do* method is executed in the corresponding
// write state. For example when write state is READ_STATE_READ_COMPLETE
// DoReadComplete is called, and so on.
int DoRead();
int DoReadComplete(int result);
int DoReadCallback();
int DoReadError(int result);
void SetReadState(ReadState read_state);
void SetWriteState(WriteState write_state);
void SetErrorState(ChannelError error_state);
// Queue of pending writes. The message at the front of the queue is the one
// being written.
std::queue<WriteRequest> write_queue_;
// Buffer used for read operations. Reused for every read.
scoped_refptr<net::GrowableIOBuffer> read_buffer_;
// Constructs and parses the wire representation of message frames.
scoped_ptr<MessageFramer> framer_;
// Last message received on the socket.
scoped_ptr<CastMessage> current_message_;
// Socket used for I/O operations.
CastSocketInterface* const socket_;
// Methods for communicating message receipt and error status to client code.
Delegate* const read_delegate_;
// Write flow state machine state.
WriteState write_state_;
// Read flow state machine state.
ReadState read_state_;
// Most recent error that occurred during read or write operation, if any.
ChannelError error_state_;
scoped_refptr<Logger> logger_;
base::ThreadChecker thread_checker_;
DISALLOW_COPY_AND_ASSIGN(CastTransport);
};
} // namespace cast_channel
} // namespace core_api
} // namespace extensions
#endif // EXTENSIONS_BROWSER_API_CAST_CHANNEL_CAST_TRANSPORT_H_
// Copyright 2014 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "extensions/browser/api/cast_channel/cast_transport.h"
#include <stddef.h>
#include <queue>
#include "base/test/simple_test_tick_clock.h"
#include "extensions/browser/api/cast_channel/cast_framer.h"
#include "extensions/browser/api/cast_channel/cast_transport.h"
#include "extensions/browser/api/cast_channel/logger.h"
#include "extensions/browser/api/cast_channel/logger_util.h"
#include "extensions/common/api/cast_channel/cast_channel.pb.h"
#include "net/base/capturing_net_log.h"
#include "net/base/completion_callback.h"
#include "net/base/net_errors.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
using testing::_;
using testing::DoAll;
using testing::InSequence;
using testing::Invoke;
using testing::NotNull;
using testing::Return;
using testing::WithArg;
namespace extensions {
namespace core_api {
namespace cast_channel {
namespace {
// Mockable placeholder for write completion events.
class CompleteHandler {
public:
CompleteHandler() {}
MOCK_METHOD1(Complete, void(int result));
private:
DISALLOW_COPY_AND_ASSIGN(CompleteHandler);
};
// Creates a CastMessage proto with the bare minimum required fields set.
CastMessage CreateCastMessage() {
CastMessage output;
output.set_protocol_version(CastMessage::CASTV2_1_0);
output.set_namespace_("x");
output.set_source_id("source");
output.set_destination_id("destination");
output.set_payload_type(CastMessage::STRING);
output.set_payload_utf8("payload");
return output;
}
// FIFO queue of completion callbacks. Outstanding write operations are
// Push()ed into the queue. Callback completion is simulated by invoking
// Pop() in the same order as Push().
class CompletionQueue {
public:
CompletionQueue() {}
~CompletionQueue() { CHECK_EQ(0u, cb_queue_.size()); }
// Enqueues a pending completion callback.
void Push(const net::CompletionCallback& cb) { cb_queue_.push(cb); }
// Runs the next callback and removes it from the queue.
void Pop(int rv) {
CHECK_GT(cb_queue_.size(), 0u);
cb_queue_.front().Run(rv);
cb_queue_.pop();
}
private:
std::queue<net::CompletionCallback> cb_queue_;
DISALLOW_COPY_AND_ASSIGN(CompletionQueue);
};
// GMock action that reads data from an IOBuffer and writes it to a string
// variable.
//
// buf_idx (template parameter 0): 0-based index of the net::IOBuffer
// in the function mock arg list.
// size_idx (template parameter 1): 0-based index of the byte count arg.
// str: pointer to the string which will receive data from the buffer.
ACTION_TEMPLATE(ReadBufferToString,
HAS_2_TEMPLATE_PARAMS(int, buf_idx, int, size_idx),
AND_1_VALUE_PARAMS(str)) {
str->assign(testing::get<buf_idx>(args)->data(),
testing::get<size_idx>(args));
}
// GMock action that writes data from a string to an IOBuffer.
//
// buf_idx (template parameter 0): 0-based index of the IOBuffer arg.
// str: the string containing data to be written to the IOBuffer.
ACTION_TEMPLATE(FillBufferFromString,
HAS_1_TEMPLATE_PARAMS(int, buf_idx),
AND_1_VALUE_PARAMS(str)) {
memcpy(testing::get<buf_idx>(args)->data(), str.data(), str.size());
}
// GMock action that enqueues a write completion callback in a queue.
//
// buf_idx (template parameter 0): 0-based index of the CompletionCallback.
// completion_queue: a pointer to the CompletionQueue.
ACTION_TEMPLATE(EnqueueCallback,
HAS_1_TEMPLATE_PARAMS(int, cb_idx),
AND_1_VALUE_PARAMS(completion_queue)) {
completion_queue->Push(testing::get<cb_idx>(args));
}
// Checks if two proto messages are the same.
// From
// third_party/cacheinvalidation/overrides/google/cacheinvalidation/deps/gmock.h
MATCHER_P(EqualsProto, message, "") {
std::string expected_serialized, actual_serialized;
message.SerializeToString(&expected_serialized);
arg.SerializeToString(&actual_serialized);
return expected_serialized == actual_serialized;
}
} // namespace
class MockCastTransportDelegate : public CastTransport::Delegate {
public:
MOCK_METHOD3(OnError,
void(const CastSocketInterface* socket,
ChannelError error,
const LastErrors& last_errors));
MOCK_METHOD2(OnMessage,
void(const CastSocketInterface* socket,
const CastMessage& message));
};
class MockCastSocket : public CastSocketInterface {
public:
MockCastSocket() {
net::IPAddressNumber number;
number.push_back(192);
number.push_back(0);
number.push_back(0);
number.push_back(1);
ip_ = net::IPEndPoint(number, 8009);
}
virtual ~MockCastSocket() {}
// The IP endpoint for the destination of the channel.
virtual const net::IPEndPoint& ip_endpoint() const OVERRIDE { return ip_; }
// The authentication level requested for the channel.
virtual ChannelAuthType channel_auth() const OVERRIDE {
return CHANNEL_AUTH_TYPE_SSL_VERIFIED;
}
virtual int id() const OVERRIDE { return 1; }
MOCK_METHOD3(Write,
int(net::IOBuffer* buffer,
size_t size,
const net::CompletionCallback& callback));
MOCK_METHOD3(Read,
int(net::IOBuffer* buf,
int buf_len,
const net::CompletionCallback& callback));
MOCK_METHOD1(CloseWithError, void(ChannelError error));
protected:
virtual void CloseInternal() {}
private:
net::IPEndPoint ip_;
net::CapturingNetLog capturing_net_log_;
};
class CastTransportTest : public testing::Test {
public:
CastTransportTest()
: logger_(new Logger(
scoped_ptr<base::TickClock>(new base::SimpleTestTickClock),
base::TimeTicks())) {
transport_.reset(new CastTransport(&mock_socket_, &delegate_, logger_));
}
virtual ~CastTransportTest() {}
protected:
MockCastTransportDelegate delegate_;
MockCastSocket mock_socket_;
scoped_refptr<Logger> logger_;
scoped_ptr<CastTransport> transport_;
};
// ----------------------------------------------------------------------------
// Asynchronous write tests
TEST_F(CastTransportTest, TestFullWriteAsync) {
CompletionQueue socket_cbs;
CompleteHandler write_handler;
std::string output;
CastMessage message = CreateCastMessage();
std::string serialized_message;
EXPECT_TRUE(MessageFramer::Serialize(message, &serialized_message));
EXPECT_CALL(mock_socket_, Write(NotNull(), serialized_message.size(), _))
.WillOnce(DoAll(ReadBufferToString<0, 1>(&output),
EnqueueCallback<2>(&socket_cbs),
Return(net::ERR_IO_PENDING)));
EXPECT_CALL(write_handler, Complete(net::OK));
transport_->SendMessage(
message,
base::Bind(&CompleteHandler::Complete, base::Unretained(&write_handler)));
socket_cbs.Pop(serialized_message.size());
EXPECT_EQ(serialized_message, output);
}
TEST_F(CastTransportTest, TestPartialWritesAsync) {
InSequence seq;
CompletionQueue socket_cbs;
CompleteHandler write_handler;
std::string output;
CastMessage message = CreateCastMessage();
std::string serialized_message;
EXPECT_TRUE(MessageFramer::Serialize(message, &serialized_message));
// Only one byte is written.
EXPECT_CALL(mock_socket_, Write(NotNull(), serialized_message.size(), _))
.WillOnce(DoAll(ReadBufferToString<0, 1>(&output),
EnqueueCallback<2>(&socket_cbs),
Return(net::ERR_IO_PENDING)));
// Remainder of bytes are written.
EXPECT_CALL(mock_socket_, Write(NotNull(), serialized_message.size() - 1, _))
.WillOnce(DoAll(ReadBufferToString<0, 1>(&output),
EnqueueCallback<2>(&socket_cbs),
Return(net::ERR_IO_PENDING)));
transport_->SendMessage(
message,
base::Bind(&CompleteHandler::Complete, base::Unretained(&write_handler)));
EXPECT_EQ(serialized_message, output);
socket_cbs.Pop(1);
EXPECT_CALL(write_handler, Complete(net::OK));
socket_cbs.Pop(serialized_message.size() - 1);
EXPECT_EQ(serialized_message.substr(1, serialized_message.size() - 1),
output);
}
TEST_F(CastTransportTest, TestWriteFailureAsync) {
CompletionQueue socket_cbs;
CompleteHandler write_handler;
CastMessage message = CreateCastMessage();
EXPECT_CALL(mock_socket_, Write(NotNull(), _, _)).WillOnce(
DoAll(EnqueueCallback<2>(&socket_cbs), Return(net::ERR_IO_PENDING)));
EXPECT_CALL(write_handler, Complete(net::ERR_FAILED));
transport_->SendMessage(
message,
base::Bind(&CompleteHandler::Complete, base::Unretained(&write_handler)));
socket_cbs.Pop(net::ERR_CONNECTION_RESET);
}
// ----------------------------------------------------------------------------
// Synchronous write tests
TEST_F(CastTransportTest, TestFullWriteSync) {
CompleteHandler write_handler;
std::string output;
CastMessage message = CreateCastMessage();
std::string serialized_message;
EXPECT_TRUE(MessageFramer::Serialize(message, &serialized_message));
EXPECT_CALL(mock_socket_, Write(NotNull(), serialized_message.size(), _))
.WillOnce(DoAll(ReadBufferToString<0, 1>(&output),
Return(serialized_message.size())));
EXPECT_CALL(write_handler, Complete(net::OK));
transport_->SendMessage(
message,
base::Bind(&CompleteHandler::Complete, base::Unretained(&write_handler)));
EXPECT_EQ(serialized_message, output);
}
TEST_F(CastTransportTest, TestPartialWritesSync) {
InSequence seq;
CompleteHandler write_handler;
std::string output;
CastMessage message = CreateCastMessage();
std::string serialized_message;
EXPECT_TRUE(MessageFramer::Serialize(message, &serialized_message));
// Only one byte is written.
EXPECT_CALL(mock_socket_, Write(NotNull(), serialized_message.size(), _))
.WillOnce(DoAll(ReadBufferToString<0, 1>(&output), Return(1)));
// Remainder of bytes are written.
EXPECT_CALL(mock_socket_, Write(NotNull(), serialized_message.size() - 1, _))
.WillOnce(DoAll(ReadBufferToString<0, 1>(&output),
Return(serialized_message.size() - 1)));
EXPECT_CALL(write_handler, Complete(net::OK));
transport_->SendMessage(
message,
base::Bind(&CompleteHandler::Complete, base::Unretained(&write_handler)));
EXPECT_EQ(serialized_message.substr(1, serialized_message.size() - 1),
output);
}
TEST_F(CastTransportTest, TestWriteFailureSync) {
CompleteHandler write_handler;
CastMessage message = CreateCastMessage();
EXPECT_CALL(mock_socket_, Write(NotNull(), _, _))
.WillOnce(Return(net::ERR_CONNECTION_RESET));
EXPECT_CALL(write_handler, Complete(net::ERR_FAILED));
transport_->SendMessage(
message,
base::Bind(&CompleteHandler::Complete, base::Unretained(&write_handler)));
}
// ----------------------------------------------------------------------------
// Asynchronous read tests
TEST_F(CastTransportTest, TestFullReadAsync) {
CompletionQueue socket_cbs;
CastMessage message = CreateCastMessage();
std::string serialized_message;
EXPECT_TRUE(MessageFramer::Serialize(message, &serialized_message));
// Read bytes [0, 3].
EXPECT_CALL(mock_socket_,
Read(NotNull(), MessageFramer::MessageHeader::header_size(), _))
.WillOnce(DoAll(FillBufferFromString<0>(serialized_message),
EnqueueCallback<2>(&socket_cbs),
Return(net::ERR_IO_PENDING)))
.RetiresOnSaturation();
// Read bytes [4, n].
EXPECT_CALL(mock_socket_,
Read(NotNull(),
serialized_message.size() -
MessageFramer::MessageHeader::header_size(),
_))
.WillOnce(DoAll(FillBufferFromString<0>(serialized_message.substr(
MessageFramer::MessageHeader::header_size(),
serialized_message.size() -
MessageFramer::MessageHeader::header_size())),
EnqueueCallback<2>(&socket_cbs),
Return(net::ERR_IO_PENDING)))
.RetiresOnSaturation();
EXPECT_CALL(delegate_, OnMessage(&mock_socket_, EqualsProto(message)));
transport_->StartReadLoop();
socket_cbs.Pop(MessageFramer::MessageHeader::header_size());
EXPECT_CALL(mock_socket_,
Read(NotNull(), MessageFramer::MessageHeader::header_size(), _))
.WillOnce(Return(net::ERR_IO_PENDING));
socket_cbs.Pop(serialized_message.size() -
MessageFramer::MessageHeader::header_size());
}
TEST_F(CastTransportTest, TestPartialReadAsync) {
CompletionQueue socket_cbs;
CastMessage message = CreateCastMessage();
std::string serialized_message;
EXPECT_TRUE(MessageFramer::Serialize(message, &serialized_message));
// Read bytes [0, 3].
EXPECT_CALL(mock_socket_,
Read(NotNull(), MessageFramer::MessageHeader::header_size(), _))
.WillOnce(DoAll(FillBufferFromString<0>(serialized_message),
EnqueueCallback<2>(&socket_cbs),
Return(net::ERR_IO_PENDING)))
.RetiresOnSaturation();
// Read bytes [4, n-1].
EXPECT_CALL(mock_socket_,
Read(NotNull(),
serialized_message.size() -
MessageFramer::MessageHeader::header_size(),
_))
.WillOnce(DoAll(FillBufferFromString<0>(serialized_message.substr(
MessageFramer::MessageHeader::header_size(),
serialized_message.size() -
MessageFramer::MessageHeader::header_size() - 1)),
EnqueueCallback<2>(&socket_cbs),
Return(net::ERR_IO_PENDING)))
.RetiresOnSaturation();
// Read final byte.
EXPECT_CALL(mock_socket_, Read(NotNull(), 1, _))
.WillOnce(DoAll(FillBufferFromString<0>(serialized_message.substr(
serialized_message.size() - 1, 1)),
EnqueueCallback<2>(&socket_cbs),
Return(net::ERR_IO_PENDING)))
.RetiresOnSaturation();
EXPECT_CALL(delegate_, OnMessage(&mock_socket_, EqualsProto(message)));
transport_->StartReadLoop();
socket_cbs.Pop(MessageFramer::MessageHeader::header_size());
EXPECT_CALL(mock_socket_,
Read(NotNull(), MessageFramer::MessageHeader::header_size(), _))
.WillOnce(Return(net::ERR_IO_PENDING));
socket_cbs.Pop(serialized_message.size() -
MessageFramer::MessageHeader::header_size() - 1);
socket_cbs.Pop(1);
}
TEST_F(CastTransportTest, TestReadErrorInHeaderAsync) {
CompletionQueue socket_cbs;
CastMessage message = CreateCastMessage();
std::string serialized_message;
EXPECT_TRUE(MessageFramer::Serialize(message, &serialized_message));
// Read bytes [0, 3].
EXPECT_CALL(mock_socket_,
Read(NotNull(), MessageFramer::MessageHeader::header_size(), _))
.WillOnce(DoAll(FillBufferFromString<0>(serialized_message),
EnqueueCallback<2>(&socket_cbs),
Return(net::ERR_IO_PENDING)))
.RetiresOnSaturation();
EXPECT_CALL(delegate_,
OnError(&mock_socket_, CHANNEL_ERROR_TRANSPORT_ERROR, _));
EXPECT_CALL(mock_socket_, CloseWithError(CHANNEL_ERROR_TRANSPORT_ERROR));
transport_->StartReadLoop();
// Header read failure.
socket_cbs.Pop(net::ERR_CONNECTION_RESET);
}
TEST_F(CastTransportTest, TestReadErrorInBodyAsync) {
CompletionQueue socket_cbs;
CastMessage message = CreateCastMessage();
std::string serialized_message;
EXPECT_TRUE(MessageFramer::Serialize(message, &serialized_message));
// Read bytes [0, 3].
EXPECT_CALL(mock_socket_,
Read(NotNull(), MessageFramer::MessageHeader::header_size(), _))
.WillOnce(DoAll(FillBufferFromString<0>(serialized_message),
EnqueueCallback<2>(&socket_cbs),
Return(net::ERR_IO_PENDING)))
.RetiresOnSaturation();
// Read bytes [4, n-1].
EXPECT_CALL(mock_socket_,
Read(NotNull(),
serialized_message.size() -
MessageFramer::MessageHeader::header_size(),
_))
.WillOnce(DoAll(FillBufferFromString<0>(serialized_message.substr(
MessageFramer::MessageHeader::header_size(),
serialized_message.size() -
MessageFramer::MessageHeader::header_size() - 1)),
EnqueueCallback<2>(&socket_cbs),
Return(net::ERR_IO_PENDING)))
.RetiresOnSaturation();
EXPECT_CALL(delegate_,
OnError(&mock_socket_, CHANNEL_ERROR_TRANSPORT_ERROR, _));
transport_->StartReadLoop();
// Header read is OK.
socket_cbs.Pop(MessageFramer::MessageHeader::header_size());
EXPECT_CALL(mock_socket_, CloseWithError(CHANNEL_ERROR_TRANSPORT_ERROR));
// Body read fails.
socket_cbs.Pop(net::ERR_CONNECTION_RESET);
}
TEST_F(CastTransportTest, TestReadCorruptedMessageAsync) {
CompletionQueue socket_cbs;
CastMessage message = CreateCastMessage();
std::string serialized_message;
EXPECT_TRUE(MessageFramer::Serialize(message, &serialized_message));
// Corrupt the serialized message body(set it to X's).
for (size_t i = MessageFramer::MessageHeader::header_size();
i < serialized_message.size();
++i) {
serialized_message[i] = 'x';
}
// Read bytes [0, 3].
EXPECT_CALL(mock_socket_,
Read(NotNull(), MessageFramer::MessageHeader::header_size(), _))
.WillOnce(DoAll(FillBufferFromString<0>(serialized_message),
EnqueueCallback<2>(&socket_cbs),
Return(net::ERR_IO_PENDING)))
.RetiresOnSaturation();
// Read bytes [4, n].
EXPECT_CALL(mock_socket_,
Read(NotNull(),
serialized_message.size() -
MessageFramer::MessageHeader::header_size(),
_))
.WillOnce(DoAll(FillBufferFromString<0>(serialized_message.substr(
MessageFramer::MessageHeader::header_size(),
serialized_message.size() -
MessageFramer::MessageHeader::header_size() - 1)),
EnqueueCallback<2>(&socket_cbs),
Return(net::ERR_IO_PENDING)))
.RetiresOnSaturation();
EXPECT_CALL(delegate_,
OnError(&mock_socket_, CHANNEL_ERROR_INVALID_MESSAGE, _));
transport_->StartReadLoop();
socket_cbs.Pop(MessageFramer::MessageHeader::header_size());
EXPECT_CALL(mock_socket_, CloseWithError(CHANNEL_ERROR_INVALID_MESSAGE));
socket_cbs.Pop(serialized_message.size() -
MessageFramer::MessageHeader::header_size());
}
// ----------------------------------------------------------------------------
// Synchronous read tests
TEST_F(CastTransportTest, TestFullReadSync) {
InSequence s;
CastMessage message = CreateCastMessage();
std::string serialized_message;
EXPECT_TRUE(MessageFramer::Serialize(message, &serialized_message));
// Read bytes [0, 3].
EXPECT_CALL(mock_socket_,
Read(NotNull(), MessageFramer::MessageHeader::header_size(), _))
.WillOnce(DoAll(FillBufferFromString<0>(serialized_message),
Return(MessageFramer::MessageHeader::header_size())))
.RetiresOnSaturation();
// Read bytes [4, n].
EXPECT_CALL(mock_socket_,
Read(NotNull(),
serialized_message.size() -
MessageFramer::MessageHeader::header_size(),
_))
.WillOnce(DoAll(FillBufferFromString<0>(serialized_message.substr(
MessageFramer::MessageHeader::header_size(),
serialized_message.size() -
MessageFramer::MessageHeader::header_size())),
Return(serialized_message.size() -
MessageFramer::MessageHeader::header_size())))
.RetiresOnSaturation();
EXPECT_CALL(delegate_, OnMessage(&mock_socket_, EqualsProto(message)));
EXPECT_CALL(mock_socket_,
Read(NotNull(), MessageFramer::MessageHeader::header_size(), _))
.WillOnce(Return(net::ERR_IO_PENDING));
transport_->StartReadLoop();
}
TEST_F(CastTransportTest, TestPartialReadSync) {
InSequence s;
CastMessage message = CreateCastMessage();
std::string serialized_message;
EXPECT_TRUE(MessageFramer::Serialize(message, &serialized_message));
// Read bytes [0, 3].
EXPECT_CALL(mock_socket_,
Read(NotNull(), MessageFramer::MessageHeader::header_size(), _))
.WillOnce(DoAll(FillBufferFromString<0>(serialized_message),
Return(MessageFramer::MessageHeader::header_size())))
.RetiresOnSaturation();
// Read bytes [4, n-1].
EXPECT_CALL(mock_socket_,
Read(NotNull(),
serialized_message.size() -
MessageFramer::MessageHeader::header_size(),
_))
.WillOnce(DoAll(FillBufferFromString<0>(serialized_message.substr(
MessageFramer::MessageHeader::header_size(),
serialized_message.size() -
MessageFramer::MessageHeader::header_size() - 1)),
Return(serialized_message.size() -
MessageFramer::MessageHeader::header_size() - 1)))
.RetiresOnSaturation();
// Read final byte.
EXPECT_CALL(mock_socket_, Read(NotNull(), 1, _))
.WillOnce(DoAll(FillBufferFromString<0>(serialized_message.substr(
serialized_message.size() - 1, 1)),
Return(1)))
.RetiresOnSaturation();
EXPECT_CALL(delegate_, OnMessage(&mock_socket_, EqualsProto(message)));
EXPECT_CALL(mock_socket_,
Read(NotNull(), MessageFramer::MessageHeader::header_size(), _))
.WillOnce(Return(net::ERR_IO_PENDING));
transport_->StartReadLoop();
}
TEST_F(CastTransportTest, TestReadErrorInHeaderSync) {
CastMessage message = CreateCastMessage();
std::string serialized_message;
EXPECT_TRUE(MessageFramer::Serialize(message, &serialized_message));
// Read bytes [0, 3].
EXPECT_CALL(mock_socket_,
Read(NotNull(), MessageFramer::MessageHeader::header_size(), _))
.WillOnce(DoAll(FillBufferFromString<0>(serialized_message),
Return(net::ERR_CONNECTION_RESET)))
.RetiresOnSaturation();
EXPECT_CALL(delegate_,
OnError(&mock_socket_, CHANNEL_ERROR_TRANSPORT_ERROR, _));
EXPECT_CALL(mock_socket_, CloseWithError(CHANNEL_ERROR_TRANSPORT_ERROR));
transport_->StartReadLoop();
}
TEST_F(CastTransportTest, TestReadErrorInBodySync) {
CastMessage message = CreateCastMessage();
std::string serialized_message;
EXPECT_TRUE(MessageFramer::Serialize(message, &serialized_message));
// Read bytes [0, 3].
EXPECT_CALL(mock_socket_,
Read(NotNull(), MessageFramer::MessageHeader::header_size(), _))
.WillOnce(DoAll(FillBufferFromString<0>(serialized_message),
Return(MessageFramer::MessageHeader::header_size())))
.RetiresOnSaturation();
// Read bytes [4, n-1].
EXPECT_CALL(mock_socket_,
Read(NotNull(),
serialized_message.size() -
MessageFramer::MessageHeader::header_size(),
_))
.WillOnce(DoAll(FillBufferFromString<0>(serialized_message.substr(
MessageFramer::MessageHeader::header_size(),
serialized_message.size() -
MessageFramer::MessageHeader::header_size() - 1)),
Return(net::ERR_CONNECTION_RESET)))
.RetiresOnSaturation();
EXPECT_CALL(delegate_,
OnError(&mock_socket_, CHANNEL_ERROR_TRANSPORT_ERROR, _));
EXPECT_CALL(mock_socket_, CloseWithError(CHANNEL_ERROR_TRANSPORT_ERROR));
transport_->StartReadLoop();
}
TEST_F(CastTransportTest, TestReadCorruptedMessageSync) {
CastMessage message = CreateCastMessage();
std::string serialized_message;
EXPECT_TRUE(MessageFramer::Serialize(message, &serialized_message));
// Corrupt the serialized message body(set it to X's).
for (size_t i = MessageFramer::MessageHeader::header_size();
i < serialized_message.size();
++i) {
serialized_message[i] = 'x';
}
// Read bytes [0, 3].
EXPECT_CALL(mock_socket_,
Read(NotNull(), MessageFramer::MessageHeader::header_size(), _))
.WillOnce(DoAll(FillBufferFromString<0>(serialized_message),
Return(MessageFramer::MessageHeader::header_size())))
.RetiresOnSaturation();
// Read bytes [4, n].
EXPECT_CALL(mock_socket_,
Read(NotNull(),
serialized_message.size() -
MessageFramer::MessageHeader::header_size(),
_))
.WillOnce(DoAll(FillBufferFromString<0>(serialized_message.substr(
MessageFramer::MessageHeader::header_size(),
serialized_message.size() -
MessageFramer::MessageHeader::header_size() - 1)),
Return(serialized_message.size() -
MessageFramer::MessageHeader::header_size())))
.RetiresOnSaturation();
EXPECT_CALL(delegate_,
OnError(&mock_socket_, CHANNEL_ERROR_INVALID_MESSAGE, _));
EXPECT_CALL(mock_socket_, CloseWithError(CHANNEL_ERROR_INVALID_MESSAGE));
transport_->StartReadLoop();
}
} // namespace cast_channel
} // namespace core_api
} // namespace extensions
......@@ -361,6 +361,10 @@
'browser/api/cast_channel/cast_message_util.h',
'browser/api/cast_channel/cast_socket.cc',
'browser/api/cast_channel/cast_socket.h',
'browser/api/cast_channel/cast_framer.cc',
'browser/api/cast_channel/cast_framer.h',
'browser/api/cast_channel/cast_transport.h',
'browser/api/cast_channel/cast_transport.cc',
'browser/api/cast_channel/logger.cc',
'browser/api/cast_channel/logger.h',
'browser/api/cast_channel/logger_util.cc',
......@@ -1072,6 +1076,7 @@
'browser/api/cast_channel/cast_channel_api_unittest.cc',
'browser/api/cast_channel/cast_framer_unittest.cc',
'browser/api/cast_channel/cast_socket_unittest.cc',
'browser/api/cast_channel/cast_transport_unittest.cc',
'browser/api/cast_channel/logger_unittest.cc',
'browser/api/declarative/declarative_rule_unittest.cc',
'browser/api/declarative/deduping_factory_unittest.cc',
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment