Commit 56a76561 authored by Ken MacKay's avatar Ken MacKay Committed by Commit Bot

[Chromecast] Allow SmallMessageSocket to receive into pool-allocated buffers

By receiving into a threadsafe IOBufferPool, we can avoid extra memory
copies for things like audio streaming.

Bug: internal b/127963522
Test: cast_net_unittests
Change-Id: I86da6e7a5e07a7a3273d53485393468b4481ec94
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/1822220
Commit-Queue: Kenneth MacKay <kmackay@chromium.org>
Reviewed-by: default avatarYuchen Liu <yucliu@chromium.org>
Cr-Commit-Position: refs/heads/master@{#700113}
parent b04ec388
...@@ -71,6 +71,7 @@ cast_source_set("small_message_socket") { ...@@ -71,6 +71,7 @@ cast_source_set("small_message_socket") {
] ]
deps = [ deps = [
":io_buffer_pool",
"//base", "//base",
] ]
} }
...@@ -99,10 +100,12 @@ test("cast_net_unittests") { ...@@ -99,10 +100,12 @@ test("cast_net_unittests") {
sources = [ sources = [
"fake_stream_socket_unittest.cc", "fake_stream_socket_unittest.cc",
"io_buffer_pool_unittest.cc", "io_buffer_pool_unittest.cc",
"small_message_socket_unittest.cc",
] ]
deps = [ deps = [
":io_buffer_pool", ":io_buffer_pool",
":small_message_socket",
":test_support", ":test_support",
"//base", "//base",
"//base/test:run_all_unittests", "//base/test:run_all_unittests",
......
...@@ -89,6 +89,10 @@ void FakeStreamSocket::SetPeer(FakeStreamSocket* peer) { ...@@ -89,6 +89,10 @@ void FakeStreamSocket::SetPeer(FakeStreamSocket* peer) {
peer_ = peer; peer_ = peer;
} }
void FakeStreamSocket::SetBadSenderMode(bool bad_sender) {
bad_sender_mode_ = bad_sender;
}
int FakeStreamSocket::Read(net::IOBuffer* buf, int FakeStreamSocket::Read(net::IOBuffer* buf,
int buf_len, int buf_len,
net::CompletionOnceCallback callback) { net::CompletionOnceCallback callback) {
...@@ -105,8 +109,12 @@ int FakeStreamSocket::Write( ...@@ -105,8 +109,12 @@ int FakeStreamSocket::Write(
if (!peer_) { if (!peer_) {
return net::ERR_SOCKET_NOT_CONNECTED; return net::ERR_SOCKET_NOT_CONNECTED;
} }
peer_->buffer_->Write(buf->data(), buf_len); int amount_to_send = buf_len;
return buf_len; if (bad_sender_mode_) {
amount_to_send = std::min(buf_len, buf_len / 2 + 1);
}
peer_->buffer_->Write(buf->data(), amount_to_send);
return amount_to_send;
} }
int FakeStreamSocket::SetReceiveBufferSize(int32_t /* size */) { int FakeStreamSocket::SetReceiveBufferSize(int32_t /* size */) {
......
...@@ -27,6 +27,10 @@ class FakeStreamSocket : public net::StreamSocket { ...@@ -27,6 +27,10 @@ class FakeStreamSocket : public net::StreamSocket {
// Sets the peer for this socket. // Sets the peer for this socket.
void SetPeer(FakeStreamSocket* peer); void SetPeer(FakeStreamSocket* peer);
// Enables/disables "bad sender mode", where Write() will always try to send
// less than the full buffer. Disabled by default.
void SetBadSenderMode(bool bad_sender);
// net::StreamSocket implementation: // net::StreamSocket implementation:
int Read(net::IOBuffer* buf, int Read(net::IOBuffer* buf,
int buf_len, int buf_len,
...@@ -60,6 +64,7 @@ class FakeStreamSocket : public net::StreamSocket { ...@@ -60,6 +64,7 @@ class FakeStreamSocket : public net::StreamSocket {
const std::unique_ptr<SocketBuffer> buffer_; const std::unique_ptr<SocketBuffer> buffer_;
FakeStreamSocket* peer_; FakeStreamSocket* peer_;
net::NetLogWithSource net_log_; net::NetLogWithSource net_log_;
bool bad_sender_mode_ = false;
DISALLOW_COPY_AND_ASSIGN(FakeStreamSocket); DISALLOW_COPY_AND_ASSIGN(FakeStreamSocket);
}; };
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "base/logging.h" #include "base/logging.h"
#include "base/sequenced_task_runner.h" #include "base/sequenced_task_runner.h"
#include "base/threading/sequenced_task_runner_handle.h" #include "base/threading/sequenced_task_runner_handle.h"
#include "chromecast/net/io_buffer_pool.h"
#include "net/base/io_buffer.h" #include "net/base/io_buffer.h"
#include "net/base/net_errors.h" #include "net/base/net_errors.h"
#include "net/socket/socket.h" #include "net/socket/socket.h"
...@@ -33,7 +34,7 @@ const int kDefaultBufferSize = 2048; ...@@ -33,7 +34,7 @@ const int kDefaultBufferSize = 2048;
} // namespace } // namespace
class SmallMessageSocket::WriteBuffer : public ::net::IOBuffer { class SmallMessageSocket::BufferWrapper : public ::net::IOBuffer {
public: public:
void SetUnderlyingBuffer(scoped_refptr<IOBuffer> base, size_t size) { void SetUnderlyingBuffer(scoped_refptr<IOBuffer> base, size_t size) {
base_ = std::move(base); base_ = std::move(base);
...@@ -42,40 +43,111 @@ class SmallMessageSocket::WriteBuffer : public ::net::IOBuffer { ...@@ -42,40 +43,111 @@ class SmallMessageSocket::WriteBuffer : public ::net::IOBuffer {
data_ = base_->data(); data_ = base_->data();
} }
scoped_refptr<IOBuffer> TakeUnderlyingBuffer() { return std::move(base_); }
void ClearUnderlyingBuffer() {
data_ = nullptr;
base_.reset();
}
void DidConsume(size_t bytes) { void DidConsume(size_t bytes) {
used_ += bytes; used_ += bytes;
data_ = base_->data() + used_; data_ = base_->data() + used_;
} }
size_t BytesRemaining() { return size_ - used_; } char* StartOfBuffer() const {
DCHECK(base_);
return base_->data();
}
size_t size() const { return size_; }
size_t used() const { return used_; }
size_t remaining() const {
DCHECK_GE(size_, used_);
return size_ - used_;
}
private: private:
~WriteBuffer() override { data_ = nullptr; } ~BufferWrapper() override { data_ = nullptr; }
scoped_refptr<IOBuffer> base_; scoped_refptr<IOBuffer> base_;
size_t size_; size_t size_ = 0;
size_t used_; size_t used_ = 0;
}; };
SmallMessageSocket::SmallMessageSocket(std::unique_ptr<net::Socket> socket) SmallMessageSocket::SmallMessageSocket(std::unique_ptr<net::Socket> socket)
: socket_(std::move(socket)), : socket_(std::move(socket)),
task_runner_(base::SequencedTaskRunnerHandle::Get()), task_runner_(base::SequencedTaskRunnerHandle::Get()),
write_buffer_(base::MakeRefCounted<WriteBuffer>()), write_storage_(base::MakeRefCounted<net::GrowableIOBuffer>()),
weak_factory_(this) {} write_buffer_(base::MakeRefCounted<BufferWrapper>()),
read_storage_(base::MakeRefCounted<net::GrowableIOBuffer>()),
read_buffer_(base::MakeRefCounted<BufferWrapper>()),
weak_factory_(this) {
write_storage_->SetCapacity(kDefaultBufferSize);
read_storage_->SetCapacity(kDefaultBufferSize);
}
SmallMessageSocket::~SmallMessageSocket() = default; SmallMessageSocket::~SmallMessageSocket() = default;
void SmallMessageSocket::UseBufferPool(
scoped_refptr<IOBufferPool> buffer_pool) {
DCHECK(buffer_pool);
if (buffer_pool_) {
// Replace existing buffer pool. No need to copy data out of existing buffer
// since it will remain valid until we are done using it.
buffer_pool_ = std::move(buffer_pool);
return;
}
buffer_pool_ = std::move(buffer_pool);
if (!in_message_) {
ActivateBufferPool(read_storage_->StartOfBuffer(), read_storage_->offset());
}
}
void SmallMessageSocket::ActivateBufferPool(char* current_data,
size_t current_size) {
// Copy any already-read data into a new buffer for pool-based operation.
DCHECK(buffer_pool_);
DCHECK(!in_message_);
scoped_refptr<::net::IOBuffer> new_buffer;
size_t new_buffer_size;
if (current_size <= buffer_pool_->buffer_size()) {
new_buffer = buffer_pool_->GetBuffer();
new_buffer_size = buffer_pool_->buffer_size();
} else {
new_buffer = base::MakeRefCounted<::net::IOBuffer>(current_size * 2);
new_buffer_size = current_size * 2;
}
memcpy(new_buffer->data(), current_data, current_size);
read_buffer_->SetUnderlyingBuffer(std::move(new_buffer), new_buffer_size);
read_buffer_->DidConsume(current_size);
}
void SmallMessageSocket::RemoveBufferPool() {
if (!buffer_pool_) {
return;
}
if (static_cast<size_t>(read_storage_->capacity()) < read_buffer_->used()) {
read_storage_->SetCapacity(read_buffer_->used());
}
memcpy(read_storage_->StartOfBuffer(), read_buffer_->StartOfBuffer(),
read_buffer_->used());
read_storage_->set_offset(read_buffer_->used());
buffer_pool_.reset();
}
void* SmallMessageSocket::PrepareSend(int message_size) { void* SmallMessageSocket::PrepareSend(int message_size) {
DCHECK_LE(message_size, std::numeric_limits<uint16_t>::max()); DCHECK_LE(message_size, std::numeric_limits<uint16_t>::max());
if (write_buffer_->BytesRemaining()) { if (write_buffer_->remaining()) {
send_blocked_ = true; send_blocked_ = true;
return nullptr; return nullptr;
} }
if (!write_storage_) {
write_storage_ = base::MakeRefCounted<net::GrowableIOBuffer>();
}
write_storage_->set_offset(0); write_storage_->set_offset(0);
const int total_size = sizeof(uint16_t) + message_size; const int total_size = sizeof(uint16_t) + message_size;
if (write_storage_->capacity() < total_size) { if (write_storage_->capacity() < total_size) {
...@@ -90,7 +162,7 @@ void* SmallMessageSocket::PrepareSend(int message_size) { ...@@ -90,7 +162,7 @@ void* SmallMessageSocket::PrepareSend(int message_size) {
bool SmallMessageSocket::SendBuffer(scoped_refptr<net::IOBuffer> data, bool SmallMessageSocket::SendBuffer(scoped_refptr<net::IOBuffer> data,
int size) { int size) {
if (write_buffer_->BytesRemaining()) { if (write_buffer_->remaining()) {
send_blocked_ = true; send_blocked_ = true;
return false; return false;
} }
...@@ -103,7 +175,7 @@ bool SmallMessageSocket::SendBuffer(scoped_refptr<net::IOBuffer> data, ...@@ -103,7 +175,7 @@ bool SmallMessageSocket::SendBuffer(scoped_refptr<net::IOBuffer> data,
void SmallMessageSocket::Send() { void SmallMessageSocket::Send() {
for (int i = 0; i < kMaxIOLoop; ++i) { for (int i = 0; i < kMaxIOLoop; ++i) {
int result = int result =
socket_->Write(write_buffer_.get(), write_buffer_->BytesRemaining(), socket_->Write(write_buffer_.get(), write_buffer_->remaining(),
base::BindOnce(&SmallMessageSocket::OnWriteComplete, base::BindOnce(&SmallMessageSocket::OnWriteComplete,
base::Unretained(this)), base::Unretained(this)),
MISSING_TRAFFIC_ANNOTATION); MISSING_TRAFFIC_ANNOTATION);
...@@ -132,10 +204,11 @@ bool SmallMessageSocket::HandleWriteResult(int result) { ...@@ -132,10 +204,11 @@ bool SmallMessageSocket::HandleWriteResult(int result) {
} }
write_buffer_->DidConsume(result); write_buffer_->DidConsume(result);
if (write_buffer_->BytesRemaining()) { if (write_buffer_->remaining()) {
return true; return true;
} }
write_buffer_->ClearUnderlyingBuffer();
if (send_blocked_) { if (send_blocked_) {
send_blocked_ = false; send_blocked_ = false;
OnSendUnblocked(); OnSendUnblocked();
...@@ -152,19 +225,17 @@ void SmallMessageSocket::PostError(int error) { ...@@ -152,19 +225,17 @@ void SmallMessageSocket::PostError(int error) {
} }
void SmallMessageSocket::ReceiveMessages() { void SmallMessageSocket::ReceiveMessages() {
if (!read_buffer_) {
read_buffer_ = base::MakeRefCounted<net::GrowableIOBuffer>();
read_buffer_->SetCapacity(kDefaultBufferSize);
}
// Post a task rather than just calling Read(), to avoid calling delegate // Post a task rather than just calling Read(), to avoid calling delegate
// methods from within this method. // methods from within this method.
task_runner_->PostTask(FROM_HERE, task_runner()->PostTask(
base::BindOnce(&SmallMessageSocket::StartReading, FROM_HERE,
base::BindOnce(&SmallMessageSocket::ReceiveMessagesSynchronously,
weak_factory_.GetWeakPtr())); weak_factory_.GetWeakPtr()));
} }
void SmallMessageSocket::StartReading() { void SmallMessageSocket::ReceiveMessagesSynchronously() {
if (HandleCompletedMessages()) { if ((buffer_pool_ && HandleCompletedMessageBuffers()) ||
(!buffer_pool_ && HandleCompletedMessages())) {
Read(); Read();
} }
} }
...@@ -174,8 +245,17 @@ void SmallMessageSocket::Read() { ...@@ -174,8 +245,17 @@ void SmallMessageSocket::Read() {
// This improves average packet receive delay as compared to always posting a // This improves average packet receive delay as compared to always posting a
// new task for each call to Read(). // new task for each call to Read().
for (int i = 0; i < kMaxIOLoop; ++i) { for (int i = 0; i < kMaxIOLoop; ++i) {
net::IOBuffer* buffer;
int size;
if (buffer_pool_) {
buffer = read_buffer_.get();
size = read_buffer_->remaining();
} else {
buffer = read_storage_.get();
size = read_storage_->RemainingCapacity();
}
int read_result = int read_result =
socket_->Read(read_buffer_.get(), read_buffer_->RemainingCapacity(), socket()->Read(buffer, size,
base::BindOnce(&SmallMessageSocket::OnReadComplete, base::BindOnce(&SmallMessageSocket::OnReadComplete,
base::Unretained(this))); base::Unretained(this)));
...@@ -184,8 +264,9 @@ void SmallMessageSocket::Read() { ...@@ -184,8 +264,9 @@ void SmallMessageSocket::Read() {
} }
} }
task_runner_->PostTask(FROM_HERE, base::BindOnce(&SmallMessageSocket::Read, task_runner()->PostTask(
weak_factory_.GetWeakPtr())); FROM_HERE,
base::BindOnce(&SmallMessageSocket::Read, weak_factory_.GetWeakPtr()));
} }
void SmallMessageSocket::OnReadComplete(int result) { void SmallMessageSocket::OnReadComplete(int result) {
...@@ -209,51 +290,120 @@ bool SmallMessageSocket::HandleReadResult(int result) { ...@@ -209,51 +290,120 @@ bool SmallMessageSocket::HandleReadResult(int result) {
return false; return false;
} }
read_buffer_->set_offset(read_buffer_->offset() + result); if (buffer_pool_) {
read_buffer_->DidConsume(result);
return HandleCompletedMessageBuffers();
} else {
read_storage_->set_offset(read_storage_->offset() + result);
return HandleCompletedMessages(); return HandleCompletedMessages();
}
} }
bool SmallMessageSocket::HandleCompletedMessages() { bool SmallMessageSocket::HandleCompletedMessages() {
size_t total_size = read_buffer_->offset(); DCHECK(!buffer_pool_);
char* start_ptr = read_buffer_->StartOfBuffer();
bool keep_reading = true; bool keep_reading = true;
size_t bytes_read = read_storage_->offset();
while (total_size >= sizeof(uint16_t)) { char* start_ptr = read_storage_->StartOfBuffer();
while (bytes_read >= sizeof(uint16_t) && keep_reading) {
uint16_t message_size; uint16_t message_size;
base::ReadBigEndian(start_ptr, &message_size); base::ReadBigEndian(start_ptr, &message_size);
if (static_cast<size_t>(read_buffer_->capacity()) < size_t required_size = sizeof(uint16_t) + message_size;
sizeof(uint16_t) + message_size) { if (static_cast<size_t>(read_storage_->capacity()) < required_size) {
int position = start_ptr - read_buffer_->StartOfBuffer(); if (start_ptr != read_storage_->StartOfBuffer()) {
read_buffer_->SetCapacity(sizeof(uint16_t) + message_size); memmove(read_storage_->StartOfBuffer(), start_ptr, bytes_read);
start_ptr = read_buffer_->StartOfBuffer() + position; read_storage_->set_offset(bytes_read);
}
read_storage_->SetCapacity(required_size);
return true;
} }
if (total_size < sizeof(uint16_t) + message_size) { if (bytes_read < required_size) {
break; // Haven't received the full message yet. break; // Haven't received the full message yet.
} }
// Take a weak pointer in case OnMessage() causes this to be deleted. // Take a weak pointer in case OnMessage() causes this to be deleted.
auto self = weak_factory_.GetWeakPtr(); auto self = weak_factory_.GetWeakPtr();
in_message_ = true;
keep_reading = OnMessage(start_ptr + sizeof(uint16_t), message_size); keep_reading = OnMessage(start_ptr + sizeof(uint16_t), message_size);
if (!self) { if (!self) {
return false; return false;
} }
in_message_ = false;
total_size -= sizeof(uint16_t) + message_size; start_ptr += required_size;
start_ptr += sizeof(uint16_t) + message_size; bytes_read -= required_size;
if (!keep_reading) { if (buffer_pool_) {
break; // A buffer pool was added within OnMessage().
ActivateBufferPool(start_ptr, bytes_read);
return (keep_reading ? HandleCompletedMessageBuffers() : false);
} }
} }
if (start_ptr != read_buffer_->StartOfBuffer()) { if (start_ptr != read_storage_->StartOfBuffer()) {
memmove(read_buffer_->StartOfBuffer(), start_ptr, total_size); memmove(read_storage_->StartOfBuffer(), start_ptr, bytes_read);
read_buffer_->set_offset(total_size); read_storage_->set_offset(bytes_read);
} }
return keep_reading; return keep_reading;
} }
bool SmallMessageSocket::HandleCompletedMessageBuffers() {
DCHECK(buffer_pool_);
size_t bytes_read;
while ((bytes_read = read_buffer_->used()) >= sizeof(uint16_t)) {
uint16_t message_size;
base::ReadBigEndian(read_buffer_->StartOfBuffer(), &message_size);
size_t required_size = sizeof(uint16_t) + message_size;
if (read_buffer_->size() < required_size) {
// Current buffer is not big enough.
auto new_buffer = base::MakeRefCounted<::net::IOBuffer>(required_size);
memcpy(new_buffer->data(), read_buffer_->StartOfBuffer(), bytes_read);
read_buffer_->SetUnderlyingBuffer(std::move(new_buffer), required_size);
read_buffer_->DidConsume(bytes_read);
return true;
}
if (bytes_read < required_size) {
break; // Haven't received the full message yet.
}
auto old_buffer = read_buffer_->TakeUnderlyingBuffer();
auto new_buffer = buffer_pool_->GetBuffer();
size_t new_buffer_size = buffer_pool_->buffer_size();
size_t extra_size = bytes_read - required_size;
if (extra_size > 0) {
// Copy extra data to new buffer.
if (extra_size > buffer_pool_->buffer_size()) {
new_buffer = base::MakeRefCounted<::net::IOBuffer>(extra_size);
new_buffer_size = extra_size;
}
memcpy(new_buffer->data(), old_buffer->data() + required_size,
extra_size);
}
read_buffer_->SetUnderlyingBuffer(std::move(new_buffer), new_buffer_size);
read_buffer_->DidConsume(extra_size);
// Take a weak pointer in case OnMessageBuffer() causes this to be deleted.
auto self = weak_factory_.GetWeakPtr();
bool keep_reading = OnMessageBuffer(std::move(old_buffer), required_size);
if (!self || !keep_reading) {
return false;
}
if (!buffer_pool_) {
// The buffer pool was removed within OnMessageBuffer().
return HandleCompletedMessages();
}
}
return true;
}
bool SmallMessageSocket::OnMessageBuffer(scoped_refptr<net::IOBuffer> buffer,
int size) {
return OnMessage(buffer->data() + sizeof(uint16_t), size - sizeof(uint16_t));
}
} // namespace chromecast } // namespace chromecast
...@@ -22,16 +22,29 @@ class Socket; ...@@ -22,16 +22,29 @@ class Socket;
} // namespace net } // namespace net
namespace chromecast { namespace chromecast {
class IOBufferPool;
// Sends and receives small messages (< 64 KB) over a Socket. All methods must // Sends small messages (< 64 KB) over a Socket. All methods must be called on
// be called on the same sequence. Any of the virtual methods can destroy this // the same sequence. Any of the virtual methods can destroy this object if
// object if desired. // desired.
class SmallMessageSocket { class SmallMessageSocket {
public: public:
explicit SmallMessageSocket(std::unique_ptr<net::Socket> socket); explicit SmallMessageSocket(std::unique_ptr<net::Socket> socket);
virtual ~SmallMessageSocket(); virtual ~SmallMessageSocket();
const net::Socket* socket() const { return socket_.get(); } net::Socket* socket() const { return socket_.get(); }
base::SequencedTaskRunner* task_runner() const { return task_runner_.get(); }
IOBufferPool* buffer_pool() const { return buffer_pool_.get(); }
// Adds a |buffer_pool| used to allocate buffers to receive messages into;
// received messages are passed to OnMessageBuffer(). If a message would be
// too big to fit in a pool-provided buffer, a dynamically allocated IOBuffer
// will be used instead for that message.
void UseBufferPool(scoped_refptr<IOBufferPool> buffer_pool);
// Removes the buffer pool; subsequent received messages will be passed to
// OnMessage().
void RemoveBufferPool();
// Prepares a buffer to send a message of the given |message_size|. Returns // Prepares a buffer to send a message of the given |message_size|. Returns
// nullptr if sending is not allowed right now (ie, another send is currently // nullptr if sending is not allowed right now (ie, another send is currently
...@@ -54,10 +67,18 @@ class SmallMessageSocket { ...@@ -54,10 +67,18 @@ class SmallMessageSocket {
// Enables receiving messages from the stream. Messages will be received and // Enables receiving messages from the stream. Messages will be received and
// passed to OnMessage() until either an error occurs, the end of stream is // passed to OnMessage() until either an error occurs, the end of stream is
// reached, or OnMessage() returns false. If OnMessage() returns false, you // reached, or OnMessage() returns false. If OnMessage() returns false, you
// may call ReceiveMessages() to start receiving again. // may call ReceiveMessages() to start receiving again. OnMessage() will not
// be called synchronously from within this method (it always posts a task).
void ReceiveMessages(); void ReceiveMessages();
// Same as ReceiveMessages(), but OnMessage() may be called synchronously.
// This is more efficient because it doesn't post a task to ensure
// asynchronous reads.
void ReceiveMessagesSynchronously();
protected: protected:
class BufferWrapper;
// Called when sending becomes possible again, if a previous attempt to send // Called when sending becomes possible again, if a previous attempt to send
// was rejected. // was rejected.
virtual void OnSendUnblocked() {} virtual void OnSendUnblocked() {}
...@@ -69,31 +90,42 @@ class SmallMessageSocket { ...@@ -69,31 +90,42 @@ class SmallMessageSocket {
// Called when the end of stream has been read. No more data will be received. // Called when the end of stream has been read. No more data will be received.
virtual void OnEndOfStream() {} virtual void OnEndOfStream() {}
// Called when a message has been received. The |data| buffer contains |size| // Called when a message has been received and there is no buffer pool. The
// bytes of data. // |data| buffer contains |size| bytes of data. Return |true| to continue
// reading messages after OnMessage() returns.
virtual bool OnMessage(char* data, int size) = 0; virtual bool OnMessage(char* data, int size) = 0;
private: // Called when a message has been received. The |buffer| contains |size| bytes
class WriteBuffer; // of data, which includes the first 2 bytes which are the size in network
// byte order. Note that these 2 bytes are not included in OnMessage()!
// Return |true| to continue receiving messages.
virtual bool OnMessageBuffer(scoped_refptr<net::IOBuffer> buffer, int size);
private:
void OnWriteComplete(int result); void OnWriteComplete(int result);
bool HandleWriteResult(int result); bool HandleWriteResult(int result);
void PostError(int error); void PostError(int error);
void StartReading();
void Read(); void Read();
void OnReadComplete(int result); void OnReadComplete(int result);
bool HandleReadResult(int result); bool HandleReadResult(int result);
bool HandleCompletedMessages(); bool HandleCompletedMessages();
bool HandleCompletedMessageBuffers();
void ActivateBufferPool(char* current_data, size_t current_size);
std::unique_ptr<net::Socket> socket_; const std::unique_ptr<net::Socket> socket_;
const scoped_refptr<base::SequencedTaskRunner> task_runner_; const scoped_refptr<base::SequencedTaskRunner> task_runner_;
scoped_refptr<net::GrowableIOBuffer> write_storage_; const scoped_refptr<net::GrowableIOBuffer> write_storage_;
scoped_refptr<WriteBuffer> write_buffer_; const scoped_refptr<BufferWrapper> write_buffer_;
bool send_blocked_ = false; bool send_blocked_ = false;
scoped_refptr<net::GrowableIOBuffer> read_buffer_; const scoped_refptr<net::GrowableIOBuffer> read_storage_;
scoped_refptr<IOBufferPool> buffer_pool_;
const scoped_refptr<BufferWrapper> read_buffer_;
bool in_message_ = false;
base::WeakPtrFactory<SmallMessageSocket> weak_factory_; base::WeakPtrFactory<SmallMessageSocket> weak_factory_;
......
// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chromecast/net/small_message_socket.h"
#include <memory>
#include <vector>
#include "base/big_endian.h"
#include "base/test/task_environment.h"
#include "chromecast/net/fake_stream_socket.h"
#include "chromecast/net/io_buffer_pool.h"
#include "net/base/io_buffer.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace chromecast {
namespace {
const size_t kDefaultMessageSize = 256;
const char kIpAddress1[] = "192.168.0.1";
const uint16_t kPort1 = 10001;
const char kIpAddress2[] = "192.168.0.2";
const uint16_t kPort2 = 10002;
net::IPAddress IpLiteralToIpAddress(const std::string& ip_literal) {
net::IPAddress ip_address;
CHECK(ip_address.AssignFromIPLiteral(ip_literal));
return ip_address;
}
void SetData(char* buffer, int size) {
for (int i = 0; i < size; ++i) {
buffer[i] = static_cast<char>(i);
}
}
void CheckData(char* buffer, int size) {
for (int i = 0; i < size; ++i) {
EXPECT_EQ(buffer[i], static_cast<char>(i));
}
}
class TestSocket : public SmallMessageSocket {
public:
explicit TestSocket(std::unique_ptr<net::Socket> socket)
: SmallMessageSocket(std::move(socket)) {}
~TestSocket() override = default;
void UseBufferPool() {
SmallMessageSocket::UseBufferPool(base::MakeRefCounted<IOBufferPool>(
kDefaultMessageSize + sizeof(uint16_t)));
}
void SwapPoolUse(bool swap) { swap_pool_use_ = swap; }
size_t last_message_size() const {
DCHECK(!message_history_.empty());
return message_history_[message_history_.size() - 1];
}
const std::vector<size_t>& message_history() const {
return message_history_;
}
private:
void OnError(int error) override { NOTREACHED(); }
bool OnMessage(char* data, int size) override {
message_history_.push_back(size);
CheckData(data, size);
if (swap_pool_use_) {
UseBufferPool();
}
return true;
}
bool OnMessageBuffer(scoped_refptr<net::IOBuffer> buffer, int size) override {
uint16_t message_size;
base::ReadBigEndian(buffer->data(), &message_size);
DCHECK_EQ(message_size, size - sizeof(uint16_t));
message_history_.push_back(message_size);
CheckData(buffer->data() + sizeof(uint16_t), message_size);
if (swap_pool_use_) {
RemoveBufferPool();
}
return true;
}
std::vector<size_t> message_history_;
bool swap_pool_use_ = false;
};
} // namespace
class SmallMessageSocketTest : public ::testing::Test {
public:
SmallMessageSocketTest() {
auto fake1 = std::make_unique<FakeStreamSocket>(
net::IPEndPoint(IpLiteralToIpAddress(kIpAddress1), kPort1));
auto fake2 = std::make_unique<FakeStreamSocket>(
net::IPEndPoint(IpLiteralToIpAddress(kIpAddress2), kPort2));
fake1->SetPeer(fake2.get());
fake2->SetPeer(fake1.get());
fake1->SetBadSenderMode(true);
fake2->SetBadSenderMode(true);
socket_1_ = std::make_unique<TestSocket>(std::move(fake1));
socket_2_ = std::make_unique<TestSocket>(std::move(fake2));
}
~SmallMessageSocketTest() override = default;
protected:
base::test::TaskEnvironment task_environment_;
std::unique_ptr<TestSocket> socket_1_;
std::unique_ptr<TestSocket> socket_2_;
};
TEST_F(SmallMessageSocketTest, SendAndReceive) {
auto buffer = base::MakeRefCounted<net::IOBuffer>(kDefaultMessageSize +
sizeof(uint16_t));
base::WriteBigEndian(buffer->data(),
static_cast<uint16_t>(kDefaultMessageSize));
SetData(buffer->data() + sizeof(uint16_t), kDefaultMessageSize);
socket_2_->ReceiveMessages();
socket_1_->SendBuffer(std::move(buffer),
kDefaultMessageSize + sizeof(uint16_t));
task_environment_.RunUntilIdle();
EXPECT_EQ(socket_2_->last_message_size(), kDefaultMessageSize);
}
TEST_F(SmallMessageSocketTest, PrepareSendAndReceive) {
char* buffer =
static_cast<char*>(socket_1_->PrepareSend(kDefaultMessageSize));
SetData(buffer, kDefaultMessageSize);
socket_2_->ReceiveMessages();
socket_1_->Send();
task_environment_.RunUntilIdle();
EXPECT_EQ(socket_2_->last_message_size(), kDefaultMessageSize);
}
TEST_F(SmallMessageSocketTest, MultipleMessages) {
char* buffer =
static_cast<char*>(socket_1_->PrepareSend(kDefaultMessageSize));
SetData(buffer, kDefaultMessageSize);
socket_1_->Send();
task_environment_.RunUntilIdle();
buffer =
static_cast<char*>(socket_1_->PrepareSend(kDefaultMessageSize * 2 + 1));
SetData(buffer, kDefaultMessageSize * 2 + 1);
socket_1_->Send();
task_environment_.RunUntilIdle();
buffer = static_cast<char*>(socket_1_->PrepareSend(kDefaultMessageSize - 5));
SetData(buffer, kDefaultMessageSize - 5);
socket_1_->Send();
task_environment_.RunUntilIdle();
socket_2_->ReceiveMessages();
task_environment_.RunUntilIdle();
ASSERT_EQ(socket_2_->message_history().size(), 3u);
EXPECT_EQ(socket_2_->message_history()[0], kDefaultMessageSize);
EXPECT_EQ(socket_2_->message_history()[1], kDefaultMessageSize * 2 + 1);
EXPECT_EQ(socket_2_->message_history()[2], kDefaultMessageSize - 5);
}
TEST_F(SmallMessageSocketTest, BufferSendAndReceive) {
socket_1_->UseBufferPool();
socket_2_->UseBufferPool();
auto buffer = base::MakeRefCounted<net::IOBuffer>(kDefaultMessageSize +
sizeof(uint16_t));
base::WriteBigEndian(buffer->data(),
static_cast<uint16_t>(kDefaultMessageSize));
SetData(buffer->data() + sizeof(uint16_t), kDefaultMessageSize);
socket_2_->ReceiveMessages();
socket_1_->SendBuffer(std::move(buffer),
kDefaultMessageSize + sizeof(uint16_t));
task_environment_.RunUntilIdle();
EXPECT_EQ(socket_2_->last_message_size(), kDefaultMessageSize);
EXPECT_GT(socket_2_->buffer_pool()->NumAllocatedForTesting(), 0u);
EXPECT_GT(socket_2_->buffer_pool()->NumFreeForTesting(), 0u);
}
TEST_F(SmallMessageSocketTest, SendLargerThanPoolBufferSize) {
socket_1_->UseBufferPool();
socket_2_->UseBufferPool();
char* buffer =
static_cast<char*>(socket_1_->PrepareSend(kDefaultMessageSize * 2));
SetData(buffer, kDefaultMessageSize * 2);
socket_2_->ReceiveMessages();
socket_1_->Send();
task_environment_.RunUntilIdle();
EXPECT_EQ(socket_2_->last_message_size(), kDefaultMessageSize * 2);
}
TEST_F(SmallMessageSocketTest, BufferMultipleMessages) {
socket_1_->UseBufferPool();
socket_2_->UseBufferPool();
char* buffer =
static_cast<char*>(socket_1_->PrepareSend(kDefaultMessageSize - 1));
SetData(buffer, kDefaultMessageSize - 1);
socket_1_->Send();
task_environment_.RunUntilIdle();
buffer =
static_cast<char*>(socket_1_->PrepareSend(kDefaultMessageSize * 2 + 1));
SetData(buffer, kDefaultMessageSize * 2 + 1);
socket_1_->Send();
task_environment_.RunUntilIdle();
buffer = static_cast<char*>(socket_1_->PrepareSend(kDefaultMessageSize - 5));
SetData(buffer, kDefaultMessageSize - 5);
socket_1_->Send();
task_environment_.RunUntilIdle();
buffer = static_cast<char*>(socket_1_->PrepareSend(kDefaultMessageSize));
SetData(buffer, kDefaultMessageSize);
socket_1_->Send();
task_environment_.RunUntilIdle();
socket_2_->ReceiveMessages();
task_environment_.RunUntilIdle();
ASSERT_EQ(socket_2_->message_history().size(), 4u);
EXPECT_EQ(socket_2_->message_history()[0], kDefaultMessageSize - 1);
EXPECT_EQ(socket_2_->message_history()[1], kDefaultMessageSize * 2 + 1);
EXPECT_EQ(socket_2_->message_history()[2], kDefaultMessageSize - 5);
EXPECT_EQ(socket_2_->message_history()[3], kDefaultMessageSize);
}
TEST_F(SmallMessageSocketTest, SwapPoolUse) {
socket_2_->SwapPoolUse(true);
char* buffer =
static_cast<char*>(socket_1_->PrepareSend(kDefaultMessageSize * 2 + 1));
SetData(buffer, kDefaultMessageSize * 2 + 1);
socket_1_->Send();
task_environment_.RunUntilIdle();
buffer = static_cast<char*>(socket_1_->PrepareSend(kDefaultMessageSize - 5));
SetData(buffer, kDefaultMessageSize - 5);
socket_1_->Send();
task_environment_.RunUntilIdle();
buffer = static_cast<char*>(socket_1_->PrepareSend(kDefaultMessageSize));
SetData(buffer, kDefaultMessageSize);
socket_1_->Send();
task_environment_.RunUntilIdle();
socket_2_->ReceiveMessages();
task_environment_.RunUntilIdle();
ASSERT_EQ(socket_2_->message_history().size(), 3u);
EXPECT_EQ(socket_2_->message_history()[0], kDefaultMessageSize * 2 + 1);
EXPECT_EQ(socket_2_->message_history()[1], kDefaultMessageSize - 5);
EXPECT_EQ(socket_2_->message_history()[2], kDefaultMessageSize);
}
} // namespace chromecast
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment