Commit 56a76561 authored by Ken MacKay's avatar Ken MacKay Committed by Commit Bot

[Chromecast] Allow SmallMessageSocket to receive into pool-allocated buffers

By receiving into a threadsafe IOBufferPool, we can avoid extra memory
copies for things like audio streaming.

Bug: internal b/127963522
Test: cast_net_unittests
Change-Id: I86da6e7a5e07a7a3273d53485393468b4481ec94
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/1822220
Commit-Queue: Kenneth MacKay <kmackay@chromium.org>
Reviewed-by: default avatarYuchen Liu <yucliu@chromium.org>
Cr-Commit-Position: refs/heads/master@{#700113}
parent b04ec388
...@@ -71,6 +71,7 @@ cast_source_set("small_message_socket") { ...@@ -71,6 +71,7 @@ cast_source_set("small_message_socket") {
] ]
deps = [ deps = [
":io_buffer_pool",
"//base", "//base",
] ]
} }
...@@ -99,10 +100,12 @@ test("cast_net_unittests") { ...@@ -99,10 +100,12 @@ test("cast_net_unittests") {
sources = [ sources = [
"fake_stream_socket_unittest.cc", "fake_stream_socket_unittest.cc",
"io_buffer_pool_unittest.cc", "io_buffer_pool_unittest.cc",
"small_message_socket_unittest.cc",
] ]
deps = [ deps = [
":io_buffer_pool", ":io_buffer_pool",
":small_message_socket",
":test_support", ":test_support",
"//base", "//base",
"//base/test:run_all_unittests", "//base/test:run_all_unittests",
......
...@@ -89,6 +89,10 @@ void FakeStreamSocket::SetPeer(FakeStreamSocket* peer) { ...@@ -89,6 +89,10 @@ void FakeStreamSocket::SetPeer(FakeStreamSocket* peer) {
peer_ = peer; peer_ = peer;
} }
void FakeStreamSocket::SetBadSenderMode(bool bad_sender) {
bad_sender_mode_ = bad_sender;
}
int FakeStreamSocket::Read(net::IOBuffer* buf, int FakeStreamSocket::Read(net::IOBuffer* buf,
int buf_len, int buf_len,
net::CompletionOnceCallback callback) { net::CompletionOnceCallback callback) {
...@@ -105,8 +109,12 @@ int FakeStreamSocket::Write( ...@@ -105,8 +109,12 @@ int FakeStreamSocket::Write(
if (!peer_) { if (!peer_) {
return net::ERR_SOCKET_NOT_CONNECTED; return net::ERR_SOCKET_NOT_CONNECTED;
} }
peer_->buffer_->Write(buf->data(), buf_len); int amount_to_send = buf_len;
return buf_len; if (bad_sender_mode_) {
amount_to_send = std::min(buf_len, buf_len / 2 + 1);
}
peer_->buffer_->Write(buf->data(), amount_to_send);
return amount_to_send;
} }
int FakeStreamSocket::SetReceiveBufferSize(int32_t /* size */) { int FakeStreamSocket::SetReceiveBufferSize(int32_t /* size */) {
......
...@@ -27,6 +27,10 @@ class FakeStreamSocket : public net::StreamSocket { ...@@ -27,6 +27,10 @@ class FakeStreamSocket : public net::StreamSocket {
// Sets the peer for this socket. // Sets the peer for this socket.
void SetPeer(FakeStreamSocket* peer); void SetPeer(FakeStreamSocket* peer);
// Enables/disables "bad sender mode", where Write() will always try to send
// less than the full buffer. Disabled by default.
void SetBadSenderMode(bool bad_sender);
// net::StreamSocket implementation: // net::StreamSocket implementation:
int Read(net::IOBuffer* buf, int Read(net::IOBuffer* buf,
int buf_len, int buf_len,
...@@ -60,6 +64,7 @@ class FakeStreamSocket : public net::StreamSocket { ...@@ -60,6 +64,7 @@ class FakeStreamSocket : public net::StreamSocket {
const std::unique_ptr<SocketBuffer> buffer_; const std::unique_ptr<SocketBuffer> buffer_;
FakeStreamSocket* peer_; FakeStreamSocket* peer_;
net::NetLogWithSource net_log_; net::NetLogWithSource net_log_;
bool bad_sender_mode_ = false;
DISALLOW_COPY_AND_ASSIGN(FakeStreamSocket); DISALLOW_COPY_AND_ASSIGN(FakeStreamSocket);
}; };
......
This diff is collapsed.
...@@ -22,16 +22,29 @@ class Socket; ...@@ -22,16 +22,29 @@ class Socket;
} // namespace net } // namespace net
namespace chromecast { namespace chromecast {
class IOBufferPool;
// Sends and receives small messages (< 64 KB) over a Socket. All methods must // Sends small messages (< 64 KB) over a Socket. All methods must be called on
// be called on the same sequence. Any of the virtual methods can destroy this // the same sequence. Any of the virtual methods can destroy this object if
// object if desired. // desired.
class SmallMessageSocket { class SmallMessageSocket {
public: public:
explicit SmallMessageSocket(std::unique_ptr<net::Socket> socket); explicit SmallMessageSocket(std::unique_ptr<net::Socket> socket);
virtual ~SmallMessageSocket(); virtual ~SmallMessageSocket();
const net::Socket* socket() const { return socket_.get(); } net::Socket* socket() const { return socket_.get(); }
base::SequencedTaskRunner* task_runner() const { return task_runner_.get(); }
IOBufferPool* buffer_pool() const { return buffer_pool_.get(); }
// Adds a |buffer_pool| used to allocate buffers to receive messages into;
// received messages are passed to OnMessageBuffer(). If a message would be
// too big to fit in a pool-provided buffer, a dynamically allocated IOBuffer
// will be used instead for that message.
void UseBufferPool(scoped_refptr<IOBufferPool> buffer_pool);
// Removes the buffer pool; subsequent received messages will be passed to
// OnMessage().
void RemoveBufferPool();
// Prepares a buffer to send a message of the given |message_size|. Returns // Prepares a buffer to send a message of the given |message_size|. Returns
// nullptr if sending is not allowed right now (ie, another send is currently // nullptr if sending is not allowed right now (ie, another send is currently
...@@ -54,10 +67,18 @@ class SmallMessageSocket { ...@@ -54,10 +67,18 @@ class SmallMessageSocket {
// Enables receiving messages from the stream. Messages will be received and // Enables receiving messages from the stream. Messages will be received and
// passed to OnMessage() until either an error occurs, the end of stream is // passed to OnMessage() until either an error occurs, the end of stream is
// reached, or OnMessage() returns false. If OnMessage() returns false, you // reached, or OnMessage() returns false. If OnMessage() returns false, you
// may call ReceiveMessages() to start receiving again. // may call ReceiveMessages() to start receiving again. OnMessage() will not
// be called synchronously from within this method (it always posts a task).
void ReceiveMessages(); void ReceiveMessages();
// Same as ReceiveMessages(), but OnMessage() may be called synchronously.
// This is more efficient because it doesn't post a task to ensure
// asynchronous reads.
void ReceiveMessagesSynchronously();
protected: protected:
class BufferWrapper;
// Called when sending becomes possible again, if a previous attempt to send // Called when sending becomes possible again, if a previous attempt to send
// was rejected. // was rejected.
virtual void OnSendUnblocked() {} virtual void OnSendUnblocked() {}
...@@ -69,31 +90,42 @@ class SmallMessageSocket { ...@@ -69,31 +90,42 @@ class SmallMessageSocket {
// Called when the end of stream has been read. No more data will be received. // Called when the end of stream has been read. No more data will be received.
virtual void OnEndOfStream() {} virtual void OnEndOfStream() {}
// Called when a message has been received. The |data| buffer contains |size| // Called when a message has been received and there is no buffer pool. The
// bytes of data. // |data| buffer contains |size| bytes of data. Return |true| to continue
// reading messages after OnMessage() returns.
virtual bool OnMessage(char* data, int size) = 0; virtual bool OnMessage(char* data, int size) = 0;
private: // Called when a message has been received. The |buffer| contains |size| bytes
class WriteBuffer; // of data, which includes the first 2 bytes which are the size in network
// byte order. Note that these 2 bytes are not included in OnMessage()!
// Return |true| to continue receiving messages.
virtual bool OnMessageBuffer(scoped_refptr<net::IOBuffer> buffer, int size);
private:
void OnWriteComplete(int result); void OnWriteComplete(int result);
bool HandleWriteResult(int result); bool HandleWriteResult(int result);
void PostError(int error); void PostError(int error);
void StartReading();
void Read(); void Read();
void OnReadComplete(int result); void OnReadComplete(int result);
bool HandleReadResult(int result); bool HandleReadResult(int result);
bool HandleCompletedMessages(); bool HandleCompletedMessages();
bool HandleCompletedMessageBuffers();
void ActivateBufferPool(char* current_data, size_t current_size);
std::unique_ptr<net::Socket> socket_; const std::unique_ptr<net::Socket> socket_;
const scoped_refptr<base::SequencedTaskRunner> task_runner_; const scoped_refptr<base::SequencedTaskRunner> task_runner_;
scoped_refptr<net::GrowableIOBuffer> write_storage_; const scoped_refptr<net::GrowableIOBuffer> write_storage_;
scoped_refptr<WriteBuffer> write_buffer_; const scoped_refptr<BufferWrapper> write_buffer_;
bool send_blocked_ = false; bool send_blocked_ = false;
scoped_refptr<net::GrowableIOBuffer> read_buffer_; const scoped_refptr<net::GrowableIOBuffer> read_storage_;
scoped_refptr<IOBufferPool> buffer_pool_;
const scoped_refptr<BufferWrapper> read_buffer_;
bool in_message_ = false;
base::WeakPtrFactory<SmallMessageSocket> weak_factory_; base::WeakPtrFactory<SmallMessageSocket> weak_factory_;
......
// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chromecast/net/small_message_socket.h"
#include <memory>
#include <vector>
#include "base/big_endian.h"
#include "base/test/task_environment.h"
#include "chromecast/net/fake_stream_socket.h"
#include "chromecast/net/io_buffer_pool.h"
#include "net/base/io_buffer.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace chromecast {
namespace {
const size_t kDefaultMessageSize = 256;
const char kIpAddress1[] = "192.168.0.1";
const uint16_t kPort1 = 10001;
const char kIpAddress2[] = "192.168.0.2";
const uint16_t kPort2 = 10002;
net::IPAddress IpLiteralToIpAddress(const std::string& ip_literal) {
net::IPAddress ip_address;
CHECK(ip_address.AssignFromIPLiteral(ip_literal));
return ip_address;
}
void SetData(char* buffer, int size) {
for (int i = 0; i < size; ++i) {
buffer[i] = static_cast<char>(i);
}
}
void CheckData(char* buffer, int size) {
for (int i = 0; i < size; ++i) {
EXPECT_EQ(buffer[i], static_cast<char>(i));
}
}
class TestSocket : public SmallMessageSocket {
public:
explicit TestSocket(std::unique_ptr<net::Socket> socket)
: SmallMessageSocket(std::move(socket)) {}
~TestSocket() override = default;
void UseBufferPool() {
SmallMessageSocket::UseBufferPool(base::MakeRefCounted<IOBufferPool>(
kDefaultMessageSize + sizeof(uint16_t)));
}
void SwapPoolUse(bool swap) { swap_pool_use_ = swap; }
size_t last_message_size() const {
DCHECK(!message_history_.empty());
return message_history_[message_history_.size() - 1];
}
const std::vector<size_t>& message_history() const {
return message_history_;
}
private:
void OnError(int error) override { NOTREACHED(); }
bool OnMessage(char* data, int size) override {
message_history_.push_back(size);
CheckData(data, size);
if (swap_pool_use_) {
UseBufferPool();
}
return true;
}
bool OnMessageBuffer(scoped_refptr<net::IOBuffer> buffer, int size) override {
uint16_t message_size;
base::ReadBigEndian(buffer->data(), &message_size);
DCHECK_EQ(message_size, size - sizeof(uint16_t));
message_history_.push_back(message_size);
CheckData(buffer->data() + sizeof(uint16_t), message_size);
if (swap_pool_use_) {
RemoveBufferPool();
}
return true;
}
std::vector<size_t> message_history_;
bool swap_pool_use_ = false;
};
} // namespace
class SmallMessageSocketTest : public ::testing::Test {
public:
SmallMessageSocketTest() {
auto fake1 = std::make_unique<FakeStreamSocket>(
net::IPEndPoint(IpLiteralToIpAddress(kIpAddress1), kPort1));
auto fake2 = std::make_unique<FakeStreamSocket>(
net::IPEndPoint(IpLiteralToIpAddress(kIpAddress2), kPort2));
fake1->SetPeer(fake2.get());
fake2->SetPeer(fake1.get());
fake1->SetBadSenderMode(true);
fake2->SetBadSenderMode(true);
socket_1_ = std::make_unique<TestSocket>(std::move(fake1));
socket_2_ = std::make_unique<TestSocket>(std::move(fake2));
}
~SmallMessageSocketTest() override = default;
protected:
base::test::TaskEnvironment task_environment_;
std::unique_ptr<TestSocket> socket_1_;
std::unique_ptr<TestSocket> socket_2_;
};
TEST_F(SmallMessageSocketTest, SendAndReceive) {
auto buffer = base::MakeRefCounted<net::IOBuffer>(kDefaultMessageSize +
sizeof(uint16_t));
base::WriteBigEndian(buffer->data(),
static_cast<uint16_t>(kDefaultMessageSize));
SetData(buffer->data() + sizeof(uint16_t), kDefaultMessageSize);
socket_2_->ReceiveMessages();
socket_1_->SendBuffer(std::move(buffer),
kDefaultMessageSize + sizeof(uint16_t));
task_environment_.RunUntilIdle();
EXPECT_EQ(socket_2_->last_message_size(), kDefaultMessageSize);
}
TEST_F(SmallMessageSocketTest, PrepareSendAndReceive) {
char* buffer =
static_cast<char*>(socket_1_->PrepareSend(kDefaultMessageSize));
SetData(buffer, kDefaultMessageSize);
socket_2_->ReceiveMessages();
socket_1_->Send();
task_environment_.RunUntilIdle();
EXPECT_EQ(socket_2_->last_message_size(), kDefaultMessageSize);
}
TEST_F(SmallMessageSocketTest, MultipleMessages) {
char* buffer =
static_cast<char*>(socket_1_->PrepareSend(kDefaultMessageSize));
SetData(buffer, kDefaultMessageSize);
socket_1_->Send();
task_environment_.RunUntilIdle();
buffer =
static_cast<char*>(socket_1_->PrepareSend(kDefaultMessageSize * 2 + 1));
SetData(buffer, kDefaultMessageSize * 2 + 1);
socket_1_->Send();
task_environment_.RunUntilIdle();
buffer = static_cast<char*>(socket_1_->PrepareSend(kDefaultMessageSize - 5));
SetData(buffer, kDefaultMessageSize - 5);
socket_1_->Send();
task_environment_.RunUntilIdle();
socket_2_->ReceiveMessages();
task_environment_.RunUntilIdle();
ASSERT_EQ(socket_2_->message_history().size(), 3u);
EXPECT_EQ(socket_2_->message_history()[0], kDefaultMessageSize);
EXPECT_EQ(socket_2_->message_history()[1], kDefaultMessageSize * 2 + 1);
EXPECT_EQ(socket_2_->message_history()[2], kDefaultMessageSize - 5);
}
TEST_F(SmallMessageSocketTest, BufferSendAndReceive) {
socket_1_->UseBufferPool();
socket_2_->UseBufferPool();
auto buffer = base::MakeRefCounted<net::IOBuffer>(kDefaultMessageSize +
sizeof(uint16_t));
base::WriteBigEndian(buffer->data(),
static_cast<uint16_t>(kDefaultMessageSize));
SetData(buffer->data() + sizeof(uint16_t), kDefaultMessageSize);
socket_2_->ReceiveMessages();
socket_1_->SendBuffer(std::move(buffer),
kDefaultMessageSize + sizeof(uint16_t));
task_environment_.RunUntilIdle();
EXPECT_EQ(socket_2_->last_message_size(), kDefaultMessageSize);
EXPECT_GT(socket_2_->buffer_pool()->NumAllocatedForTesting(), 0u);
EXPECT_GT(socket_2_->buffer_pool()->NumFreeForTesting(), 0u);
}
TEST_F(SmallMessageSocketTest, SendLargerThanPoolBufferSize) {
socket_1_->UseBufferPool();
socket_2_->UseBufferPool();
char* buffer =
static_cast<char*>(socket_1_->PrepareSend(kDefaultMessageSize * 2));
SetData(buffer, kDefaultMessageSize * 2);
socket_2_->ReceiveMessages();
socket_1_->Send();
task_environment_.RunUntilIdle();
EXPECT_EQ(socket_2_->last_message_size(), kDefaultMessageSize * 2);
}
TEST_F(SmallMessageSocketTest, BufferMultipleMessages) {
socket_1_->UseBufferPool();
socket_2_->UseBufferPool();
char* buffer =
static_cast<char*>(socket_1_->PrepareSend(kDefaultMessageSize - 1));
SetData(buffer, kDefaultMessageSize - 1);
socket_1_->Send();
task_environment_.RunUntilIdle();
buffer =
static_cast<char*>(socket_1_->PrepareSend(kDefaultMessageSize * 2 + 1));
SetData(buffer, kDefaultMessageSize * 2 + 1);
socket_1_->Send();
task_environment_.RunUntilIdle();
buffer = static_cast<char*>(socket_1_->PrepareSend(kDefaultMessageSize - 5));
SetData(buffer, kDefaultMessageSize - 5);
socket_1_->Send();
task_environment_.RunUntilIdle();
buffer = static_cast<char*>(socket_1_->PrepareSend(kDefaultMessageSize));
SetData(buffer, kDefaultMessageSize);
socket_1_->Send();
task_environment_.RunUntilIdle();
socket_2_->ReceiveMessages();
task_environment_.RunUntilIdle();
ASSERT_EQ(socket_2_->message_history().size(), 4u);
EXPECT_EQ(socket_2_->message_history()[0], kDefaultMessageSize - 1);
EXPECT_EQ(socket_2_->message_history()[1], kDefaultMessageSize * 2 + 1);
EXPECT_EQ(socket_2_->message_history()[2], kDefaultMessageSize - 5);
EXPECT_EQ(socket_2_->message_history()[3], kDefaultMessageSize);
}
TEST_F(SmallMessageSocketTest, SwapPoolUse) {
socket_2_->SwapPoolUse(true);
char* buffer =
static_cast<char*>(socket_1_->PrepareSend(kDefaultMessageSize * 2 + 1));
SetData(buffer, kDefaultMessageSize * 2 + 1);
socket_1_->Send();
task_environment_.RunUntilIdle();
buffer = static_cast<char*>(socket_1_->PrepareSend(kDefaultMessageSize - 5));
SetData(buffer, kDefaultMessageSize - 5);
socket_1_->Send();
task_environment_.RunUntilIdle();
buffer = static_cast<char*>(socket_1_->PrepareSend(kDefaultMessageSize));
SetData(buffer, kDefaultMessageSize);
socket_1_->Send();
task_environment_.RunUntilIdle();
socket_2_->ReceiveMessages();
task_environment_.RunUntilIdle();
ASSERT_EQ(socket_2_->message_history().size(), 3u);
EXPECT_EQ(socket_2_->message_history()[0], kDefaultMessageSize * 2 + 1);
EXPECT_EQ(socket_2_->message_history()[1], kDefaultMessageSize - 5);
EXPECT_EQ(socket_2_->message_history()[2], kDefaultMessageSize);
}
} // namespace chromecast
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment