Commit f247c6a6 authored by Ken MacKay's avatar Ken MacKay Committed by Commit Bot

[Chromecast] Use PostTask for in-process mixer connections

Thereby making it more efficient.

Merge-With: eureka-internal/326819
Bug: internal b/127963522
Change-Id: Idc7db2c86d54c3b14293f407744c51ac6e9d40e1
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/1869128Reviewed-by: default avatarYuchen Liu <yucliu@chromium.org>
Commit-Queue: Kenneth MacKay <kmackay@chromium.org>
Cr-Commit-Position: refs/heads/master@{#708065}
parent 1d829a40
......@@ -44,7 +44,7 @@
namespace chromecast {
namespace media {
class CaptureServiceReceiver::Socket : public SmallMessageSocket {
class CaptureServiceReceiver::Socket : public SmallMessageSocket::Delegate {
public:
Socket(std::unique_ptr<net::StreamSocket> socket, int channels);
~Socket() override;
......@@ -52,7 +52,7 @@ class CaptureServiceReceiver::Socket : public SmallMessageSocket {
void Start(::media::AudioInputStream::AudioInputCallback* input_callback);
private:
// SmallMessageSocket implementation:
// SmallMessageSocket::Delegate implementation:
void OnError(int error) override;
void OnEndOfStream() override;
bool OnMessage(char* data, int size) override;
......@@ -61,6 +61,8 @@ class CaptureServiceReceiver::Socket : public SmallMessageSocket {
bool HandleAudio(std::unique_ptr<::media::AudioBus> audio, int64_t timestamp);
void ReportErrorAndStop();
SmallMessageSocket socket_;
// Number of audio capture channels that audio manager defines.
const int channels_;
......@@ -72,7 +74,7 @@ class CaptureServiceReceiver::Socket : public SmallMessageSocket {
CaptureServiceReceiver::Socket::Socket(
std::unique_ptr<net::StreamSocket> socket,
int channels)
: SmallMessageSocket(std::move(socket)),
: socket_(this, std::move(socket)),
channels_(channels),
input_callback_(nullptr) {
DCHECK_GT(channels_, 0);
......@@ -84,7 +86,7 @@ CaptureServiceReceiver::Socket::~Socket() = default;
void CaptureServiceReceiver::Socket::Start(
::media::AudioInputStream::AudioInputCallback* input_callback) {
input_callback_ = input_callback;
ReceiveMessages();
socket_.ReceiveMessages();
}
void CaptureServiceReceiver::Socket::ReportErrorAndStop() {
......
......@@ -24,6 +24,7 @@ cast_source_set("common") {
deps = [
"//base",
"//chromecast/net:io_buffer_pool",
"//chromecast/public",
"//chromecast/public/media",
"//net",
......
......@@ -102,9 +102,9 @@ void ControlConnection::SetStreamCountCallback(StreamCountCallback callback) {
}
}
void ControlConnection::OnConnected(std::unique_ptr<net::StreamSocket> socket) {
socket_ = std::make_unique<MixerSocket>(std::move(socket), this);
socket_->ReceiveMessages();
void ControlConnection::OnConnected(std::unique_ptr<MixerSocket> socket) {
socket_ = std::move(socket);
socket_->SetDelegate(this);
for (const auto& item : volume_limit_) {
Generic message;
......
......@@ -15,10 +15,6 @@
#include "chromecast/media/audio/mixer_service/mixer_socket.h"
#include "chromecast/public/volume_control.h"
namespace net {
class StreamSocket;
} // namespace net
namespace chromecast {
namespace media {
namespace mixer_service {
......@@ -71,7 +67,7 @@ class ControlConnection : public MixerConnection, public MixerSocket::Delegate {
private:
// MixerConnection implementation:
void OnConnected(std::unique_ptr<net::StreamSocket> socket) override;
void OnConnected(std::unique_ptr<MixerSocket> socket) override;
void OnConnectionError() override;
// MixerSocket::Delegate implementation:
......
......@@ -17,6 +17,7 @@
#include "chromecast/base/chromecast_switches.h"
#include "chromecast/media/audio/audio_buildflags.h"
#include "chromecast/media/audio/mixer_service/constants.h"
#include "chromecast/media/audio/mixer_service/mixer_socket.h"
#include "net/base/address_list.h"
#include "net/base/ip_address.h"
#include "net/base/ip_endpoint.h"
......@@ -40,12 +41,22 @@ constexpr base::TimeDelta kConnectTimeout = base::TimeDelta::FromSeconds(1);
} // namespace
std::unique_ptr<MixerSocket> CreateLocalMixerServiceConnection()
__attribute__((__weak__));
MixerConnection::MixerConnection() : weak_factory_(this) {}
MixerConnection::~MixerConnection() = default;
void MixerConnection::Connect() {
DCHECK(!connecting_socket_);
if (CreateLocalMixerServiceConnection) {
auto socket = CreateLocalMixerServiceConnection();
if (socket) {
OnConnected(std::move(socket));
return;
}
}
#if BUILDFLAG(USE_UNIX_SOCKETS)
const base::CommandLine* command_line =
......@@ -87,7 +98,8 @@ void MixerConnection::ConnectCallback(int result) {
LOG_IF(INFO, !log_timeout_) << "Now connected to mixer service";
log_connection_failure_ = true;
log_timeout_ = true;
OnConnected(std::move(connecting_socket_));
auto socket = std::make_unique<MixerSocket>(std::move(connecting_socket_));
OnConnected(std::move(socket));
return;
}
......
......@@ -19,6 +19,7 @@ class StreamSocket;
namespace chromecast {
namespace media {
namespace mixer_service {
class MixerSocket;
// Base class for connecting to the mixer service.
class MixerConnection {
......@@ -32,7 +33,7 @@ class MixerConnection {
protected:
// Called when a connection is established to the mixer service.
virtual void OnConnected(std::unique_ptr<net::StreamSocket> socket) = 0;
virtual void OnConnected(std::unique_ptr<MixerSocket> socket) = 0;
private:
void ConnectCallback(int result);
......
......@@ -9,9 +9,13 @@
#include <utility>
#include "base/big_endian.h"
#include "base/bind.h"
#include "base/location.h"
#include "base/logging.h"
#include "base/sequenced_task_runner.h"
#include "chromecast/media/audio/mixer_service/constants.h"
#include "chromecast/media/audio/mixer_service/mixer_service.pb.h"
#include "chromecast/net/io_buffer_pool.h"
#include "net/base/io_buffer.h"
#include "net/socket/stream_socket.h"
......@@ -74,17 +78,47 @@ bool MixerSocket::Delegate::HandleAudioBuffer(
constexpr size_t MixerSocket::kAudioHeaderSize;
constexpr size_t MixerSocket::kAudioMessageHeaderSize;
MixerSocket::MixerSocket(std::unique_ptr<net::StreamSocket> socket,
Delegate* delegate)
: SmallMessageSocket(std::move(socket)), delegate_(delegate) {
DCHECK(delegate_);
}
MixerSocket::MixerSocket(std::unique_ptr<net::StreamSocket> socket)
: socket_(std::make_unique<SmallMessageSocket>(this, std::move(socket))) {}
MixerSocket::~MixerSocket() = default;
MixerSocket::MixerSocket() = default;
MixerSocket::~MixerSocket() {
if (counterpart_task_runner_) {
counterpart_task_runner_->PostTask(
FROM_HERE,
base::BindOnce(&MixerSocket::OnEndOfStream, local_counterpart_));
}
}
void MixerSocket::SetDelegate(Delegate* delegate) {
DCHECK(delegate);
bool had_delegate = (delegate_ != nullptr);
delegate_ = delegate;
if (socket_ && !had_delegate) {
socket_->ReceiveMessages();
}
}
void MixerSocket::SetLocalCounterpart(
base::WeakPtr<MixerSocket> local_counterpart,
scoped_refptr<base::SequencedTaskRunner> counterpart_task_runner) {
local_counterpart_ = std::move(local_counterpart);
counterpart_task_runner_ = std::move(counterpart_task_runner);
}
base::WeakPtr<MixerSocket> MixerSocket::GetWeakPtr() {
return weak_factory_.GetWeakPtr();
}
void MixerSocket::UseBufferPool(scoped_refptr<IOBufferPool> buffer_pool) {
DCHECK(buffer_pool);
DCHECK(buffer_pool->threadsafe());
buffer_pool_ = std::move(buffer_pool);
if (socket_) {
socket_->UseBufferPool(buffer_pool_);
}
}
// static
......@@ -120,10 +154,7 @@ void MixerSocket::SendPreparedAudioBuffer(
uint16_t payload_size;
base::ReadBigEndian(audio_buffer->data(), &payload_size);
DCHECK_GE(payload_size, kAudioHeaderSize);
if (!SmallMessageSocket::SendBuffer(audio_buffer.get(),
sizeof(uint16_t) + payload_size)) {
write_queue_.push(std::move(audio_buffer));
}
SendBuffer(std::move(audio_buffer), sizeof(uint16_t) + payload_size);
}
void MixerSocket::SendProto(const google::protobuf::MessageLite& message) {
......@@ -133,16 +164,20 @@ void MixerSocket::SendProto(const google::protobuf::MessageLite& message) {
int total_size =
sizeof(type) + sizeof(padding_bytes) + message_size + padding_bytes;
scoped_refptr<net::IOBufferWithSize> storage;
void* buffer = PrepareSend(total_size);
char* ptr;
if (buffer) {
ptr = reinterpret_cast<char*>(buffer);
} else {
storage = base::MakeRefCounted<net::IOBufferWithSize>(sizeof(uint16_t) +
total_size);
ptr = storage->data();
scoped_refptr<net::IOBuffer> buffer;
char* ptr = (socket_ ? static_cast<char*>(socket_->PrepareSend(total_size))
: nullptr);
if (!ptr) {
if (buffer_pool_ &&
buffer_pool_->buffer_size() >= sizeof(uint16_t) + total_size) {
buffer = buffer_pool_->GetBuffer();
}
if (!buffer) {
buffer =
base::MakeRefCounted<net::IOBuffer>(sizeof(uint16_t) + total_size);
}
ptr = buffer->data();
base::WriteBigEndian(ptr, static_cast<uint16_t>(total_size));
ptr += sizeof(uint16_t);
}
......@@ -155,32 +190,53 @@ void MixerSocket::SendProto(const google::protobuf::MessageLite& message) {
ptr += message_size;
memset(ptr, 0, padding_bytes);
if (buffer) {
Send();
if (!buffer) {
socket_->Send();
return;
}
SendBuffer(std::move(buffer), sizeof(uint16_t) + total_size);
}
void MixerSocket::SendBuffer(scoped_refptr<net::IOBuffer> buffer,
int buffer_size) {
if (counterpart_task_runner_) {
counterpart_task_runner_->PostTask(
FROM_HERE,
base::BindOnce(base::IgnoreResult(&MixerSocket::OnMessageBuffer),
local_counterpart_, std::move(buffer), buffer_size));
return;
}
if (storage) {
write_queue_.push(std::move(storage));
DCHECK(socket_);
if (!socket_->SendBuffer(buffer, buffer_size)) {
write_queue_.push(std::move(buffer));
}
}
void MixerSocket::OnSendUnblocked() {
DCHECK(socket_);
while (!write_queue_.empty()) {
uint16_t message_size;
base::ReadBigEndian(write_queue_.front()->data(), &message_size);
if (!SmallMessageSocket::SendBuffer(write_queue_.front().get(),
sizeof(uint16_t) + message_size)) {
if (!socket_->SendBuffer(write_queue_.front().get(),
sizeof(uint16_t) + message_size)) {
return;
}
write_queue_.pop();
}
}
void MixerSocket::ReceiveMoreMessages() {
socket_->ReceiveMessagesSynchronously();
}
void MixerSocket::OnError(int error) {
LOG(ERROR) << "Socket error from " << this << ": " << error;
DCHECK(delegate_);
delegate_->OnConnectionError();
}
void MixerSocket::OnEndOfStream() {
DCHECK(delegate_);
delegate_->OnConnectionError();
}
......
......@@ -11,8 +11,13 @@
#include "base/macros.h"
#include "base/memory/scoped_refptr.h"
#include "base/memory/weak_ptr.h"
#include "chromecast/net/small_message_socket.h"
namespace base {
class SequencedTaskRunner;
} // namespace base
namespace google {
namespace protobuf {
class MessageLite;
......@@ -25,13 +30,15 @@ class StreamSocket;
} // namespace net
namespace chromecast {
class IOBufferPool;
namespace media {
namespace mixer_service {
class Generic;
// Base class for sending and receiving messages to/from the mixer service.
// Not thread-safe; all usage of a given instance must be on the same sequence.
class MixerSocket : public SmallMessageSocket {
class MixerSocket : public SmallMessageSocket::Delegate {
public:
class Delegate {
public:
......@@ -64,12 +71,18 @@ class MixerSocket : public SmallMessageSocket {
virtual ~Delegate() = default;
};
MixerSocket(std::unique_ptr<net::StreamSocket> socket, Delegate* delegate);
explicit MixerSocket(std::unique_ptr<net::StreamSocket> socket);
~MixerSocket() override;
// Changes the delegate.
// Sets/changes the delegate. Must be called immediately after creation
// (ie, synchronously on the same sequence).
void SetDelegate(Delegate* delegate);
// Adds a |buffer_pool| used to allocate buffers to receive messages into,
// and for sending protos. If the pool-allocated buffers are too small for a
// given message, a normal IOBuffer will be dynamically allocated instead.
void UseBufferPool(scoped_refptr<IOBufferPool> buffer_pool);
// 16-bit type and 64-bit timestamp.
static constexpr size_t kAudioHeaderSize = sizeof(int16_t) + sizeof(int64_t);
// Includes additional 16-bit size field for SmallMessageSocket.
......@@ -96,8 +109,24 @@ class MixerSocket : public SmallMessageSocket {
// Sends an arbitrary protobuf across the connection.
void SendProto(const google::protobuf::MessageLite& message);
// Resumes receiving messages. Delegate calls may be called synchronously
// from within this method.
void ReceiveMoreMessages();
private:
// SmallMessageSocket implementation:
friend class Receiver;
// Used by Receiver to create in-process mixer connections.
MixerSocket();
void SetLocalCounterpart(
base::WeakPtr<MixerSocket> local_counterpart,
scoped_refptr<base::SequencedTaskRunner> counterpart_task_runner);
base::WeakPtr<MixerSocket> GetWeakPtr();
void SendBuffer(scoped_refptr<net::IOBuffer> buffer, int buffer_size);
// SmallMessageSocket::Delegate implementation:
void OnSendUnblocked() override;
void OnError(int error) override;
void OnEndOfStream() override;
......@@ -110,10 +139,17 @@ class MixerSocket : public SmallMessageSocket {
char* data,
int size);
Delegate* delegate_;
Delegate* delegate_ = nullptr;
const std::unique_ptr<SmallMessageSocket> socket_;
scoped_refptr<IOBufferPool> buffer_pool_;
std::queue<scoped_refptr<net::IOBuffer>> write_queue_;
base::WeakPtr<MixerSocket> local_counterpart_;
scoped_refptr<base::SequencedTaskRunner> counterpart_task_runner_;
base::WeakPtrFactory<MixerSocket> weak_factory_{this};
DISALLOW_COPY_AND_ASSIGN(MixerSocket);
};
......
......@@ -125,10 +125,9 @@ void OutputStreamConnection::Resume() {
}
}
void OutputStreamConnection::OnConnected(
std::unique_ptr<net::StreamSocket> socket) {
socket_ = std::make_unique<MixerSocket>(std::move(socket), this);
socket_->ReceiveMessages();
void OutputStreamConnection::OnConnected(std::unique_ptr<MixerSocket> socket) {
socket_ = std::move(socket);
socket_->SetDelegate(this);
Generic message;
*(message.mutable_output_stream_params()) = params_;
......
......@@ -15,10 +15,6 @@
#include "chromecast/media/audio/mixer_service/mixer_socket.h"
#include "net/base/io_buffer.h"
namespace net {
class StreamSocket;
} // namespace net
namespace chromecast {
namespace media {
namespace mixer_service {
......@@ -90,7 +86,7 @@ class OutputStreamConnection : public MixerConnection,
private:
// MixerConnection implementation:
void OnConnected(std::unique_ptr<net::StreamSocket> socket) override;
void OnConnected(std::unique_ptr<MixerSocket> socket) override;
void OnConnectionError() override;
// MixerSocket::Delegate implementation:
......
......@@ -7,8 +7,14 @@
#include <string>
#include <utility>
#include "base/bind.h"
#include "base/command_line.h"
#include "base/location.h"
#include "base/logging.h"
#include "base/no_destructor.h"
#include "base/sequenced_task_runner.h"
#include "base/synchronization/lock.h"
#include "base/threading/sequenced_task_runner_handle.h"
#include "chromecast/base/chromecast_switches.h"
#include "chromecast/media/audio/mixer_service/constants.h"
#include "chromecast/media/audio/mixer_service/mixer_socket.h"
......@@ -32,15 +38,54 @@ std::string GetEndpoint() {
return path;
}
class LocalReceiverInstance {
public:
LocalReceiverInstance() = default;
void SetInstance(Receiver* receiver) {
base::AutoLock lock(lock_);
receiver_ = receiver;
}
void RemoveInstance(Receiver* receiver) {
base::AutoLock lock(lock_);
if (receiver_ == receiver) {
receiver_ = nullptr;
}
}
std::unique_ptr<MixerSocket> CreateLocalSocket() {
base::AutoLock lock(lock_);
if (receiver_) {
return receiver_->LocalConnect();
}
return nullptr;
}
private:
DISALLOW_COPY_AND_ASSIGN(LocalReceiverInstance);
base::Lock lock_;
Receiver* receiver_ = nullptr;
};
LocalReceiverInstance* GetLocalReceiver() {
static base::NoDestructor<LocalReceiverInstance> instance;
return instance.get();
}
} // namespace
std::unique_ptr<MixerSocket> CreateLocalMixerServiceConnection() {
return GetLocalReceiver()->CreateLocalSocket();
}
class Receiver::InitialSocket : public MixerSocket::Delegate {
public:
InitialSocket(Receiver* receiver, std::unique_ptr<net::StreamSocket> socket)
: receiver_(receiver),
socket_(std::make_unique<MixerSocket>(std::move(socket), this)) {
InitialSocket(Receiver* receiver, std::unique_ptr<MixerSocket> socket)
: receiver_(receiver), socket_(std::move(socket)) {
DCHECK(receiver_);
socket_->ReceiveMessages();
socket_->SetDelegate(this);
}
~InitialSocket() override = default;
......@@ -79,20 +124,49 @@ class Receiver::InitialSocket : public MixerSocket::Delegate {
};
Receiver::Receiver()
: socket_service_(
: task_runner_(base::SequencedTaskRunnerHandle::Get()),
socket_service_(
GetEndpoint(),
GetSwitchValueNonNegativeInt(switches::kMixerServiceEndpoint,
mixer_service::kDefaultTcpPort),
kMaxAcceptLoop,
this) {
this),
weak_factory_(this) {
socket_service_.Accept();
GetLocalReceiver()->SetInstance(this);
}
Receiver::~Receiver() = default;
Receiver::~Receiver() {
GetLocalReceiver()->RemoveInstance(this);
}
std::unique_ptr<MixerSocket> Receiver::LocalConnect() {
std::unique_ptr<MixerSocket> receiver_socket(new MixerSocket);
std::unique_ptr<MixerSocket> caller_socket(new MixerSocket);
receiver_socket->SetLocalCounterpart(caller_socket->GetWeakPtr(),
base::SequencedTaskRunnerHandle::Get());
caller_socket->SetLocalCounterpart(receiver_socket->GetWeakPtr(),
task_runner_);
task_runner_->PostTask(
FROM_HERE,
base::BindOnce(&Receiver::HandleLocalConnection,
weak_factory_.GetWeakPtr(), std::move(receiver_socket)));
return caller_socket;
}
void Receiver::HandleAcceptedSocket(std::unique_ptr<net::StreamSocket> socket) {
auto initial_socket =
std::make_unique<InitialSocket>(this, std::move(socket));
AddInitialSocket(std::make_unique<InitialSocket>(
this, std::make_unique<MixerSocket>(std::move(socket))));
}
void Receiver::HandleLocalConnection(std::unique_ptr<MixerSocket> socket) {
AddInitialSocket(std::make_unique<InitialSocket>(this, std::move(socket)));
}
void Receiver::AddInitialSocket(std::unique_ptr<InitialSocket> initial_socket) {
InitialSocket* ptr = initial_socket.get();
initial_sockets_[ptr] = std::move(initial_socket);
}
......
......@@ -9,9 +9,15 @@
#include "base/containers/flat_map.h"
#include "base/macros.h"
#include "base/memory/scoped_refptr.h"
#include "base/memory/weak_ptr.h"
#include "chromecast/media/audio/mixer_service/audio_socket_service.h"
#include "chromecast/media/audio/mixer_service/mixer_service.pb.h"
namespace base {
class SequencedTaskRunner;
} // namespace base
namespace chromecast {
namespace media {
namespace mixer_service {
......@@ -31,19 +37,31 @@ class Receiver : public AudioSocketService::Delegate {
virtual void CreateControlConnection(std::unique_ptr<MixerSocket> socket,
const Generic& message) = 0;
// Creates a local (in-process) connection to this receiver. May be called on
// any thread; the returned MixerSocket can only be used on the calling
// thread. The returned socket must have its delegate set immediately.
std::unique_ptr<MixerSocket> LocalConnect();
private:
class InitialSocket;
void RemoveInitialSocket(InitialSocket* socket);
// AudioSocketService::Delegate implementation:
void HandleAcceptedSocket(std::unique_ptr<net::StreamSocket> socket) override;
void HandleLocalConnection(std::unique_ptr<MixerSocket> socket);
void AddInitialSocket(std::unique_ptr<InitialSocket> initial_socket);
void RemoveInitialSocket(InitialSocket* socket);
const scoped_refptr<base::SequencedTaskRunner> task_runner_;
AudioSocketService socket_service_;
base::flat_map<InitialSocket*, std::unique_ptr<InitialSocket>>
initial_sockets_;
base::WeakPtrFactory<Receiver> weak_factory_;
DISALLOW_COPY_AND_ASSIGN(Receiver);
};
......
......@@ -119,7 +119,7 @@ class ReceiverCma::Stream : public MixerSocket::Delegate,
last_send_time_ = base::TimeTicks::Now();
}
socket_->ReceiveMessages();
socket_->ReceiveMoreMessages();
}
void PlayedEos() override {
......
......@@ -226,6 +226,7 @@ IOBufferPool::IOBufferPool(size_t buffer_size,
bool threadsafe)
: buffer_size_(buffer_size),
max_buffers_(max_buffers),
threadsafe_(threadsafe),
internal_(new Internal(buffer_size, max_buffers, threadsafe)) {}
IOBufferPool::IOBufferPool(size_t buffer_size)
......
......@@ -30,6 +30,7 @@ class IOBufferPool : public base::RefCountedThreadSafe<IOBufferPool> {
size_t buffer_size() const { return buffer_size_; }
size_t max_buffers() const { return max_buffers_; }
bool threadsafe() const { return threadsafe_; }
// Ensures that at least |num_buffers| are allocated. If |num_buffers| is
// greater than |max_buffers|, makes sure that |max_buffers| buffers have been
......@@ -56,6 +57,7 @@ class IOBufferPool : public base::RefCountedThreadSafe<IOBufferPool> {
const size_t buffer_size_;
const size_t max_buffers_;
const bool threadsafe_;
Internal* internal_; // Manages its own lifetime.
DISALLOW_COPY_AND_ASSIGN(IOBufferPool);
......
......@@ -75,14 +75,17 @@ class SmallMessageSocket::BufferWrapper : public ::net::IOBuffer {
size_t used_ = 0;
};
SmallMessageSocket::SmallMessageSocket(std::unique_ptr<net::Socket> socket)
: socket_(std::move(socket)),
SmallMessageSocket::SmallMessageSocket(Delegate* delegate,
std::unique_ptr<net::Socket> socket)
: delegate_(delegate),
socket_(std::move(socket)),
task_runner_(base::SequencedTaskRunnerHandle::Get()),
write_storage_(base::MakeRefCounted<net::GrowableIOBuffer>()),
write_buffer_(base::MakeRefCounted<BufferWrapper>()),
read_storage_(base::MakeRefCounted<net::GrowableIOBuffer>()),
read_buffer_(base::MakeRefCounted<BufferWrapper>()),
weak_factory_(this) {
DCHECK(delegate_);
write_storage_->SetCapacity(kDefaultBufferSize);
read_storage_->SetCapacity(kDefaultBufferSize);
}
......@@ -115,6 +118,7 @@ void SmallMessageSocket::ActivateBufferPool(char* current_data,
size_t new_buffer_size;
if (current_size <= buffer_pool_->buffer_size()) {
new_buffer = buffer_pool_->GetBuffer();
CHECK(new_buffer);
new_buffer_size = buffer_pool_->buffer_size();
} else {
new_buffer = base::MakeRefCounted<::net::IOBuffer>(current_size * 2);
......@@ -199,7 +203,12 @@ bool SmallMessageSocket::HandleWriteResult(int result) {
return false;
}
if (result <= 0) {
PostError(result);
// Post a task rather than just calling OnError(), to avoid calling
// OnError()
// synchronously.
task_runner_->PostTask(FROM_HERE,
base::BindOnce(&SmallMessageSocket::OnError,
weak_factory_.GetWeakPtr(), result));
return false;
}
......@@ -211,17 +220,13 @@ bool SmallMessageSocket::HandleWriteResult(int result) {
write_buffer_->ClearUnderlyingBuffer();
if (send_blocked_) {
send_blocked_ = false;
OnSendUnblocked();
delegate_->OnSendUnblocked();
}
return false;
}
void SmallMessageSocket::PostError(int error) {
// Post a task rather than just calling OnError(), to avoid calling OnError()
// synchronously.
task_runner_->PostTask(FROM_HERE,
base::BindOnce(&SmallMessageSocket::OnError,
weak_factory_.GetWeakPtr(), error));
void SmallMessageSocket::OnError(int error) {
delegate_->OnError(error);
}
void SmallMessageSocket::ReceiveMessages() {
......@@ -281,12 +286,12 @@ bool SmallMessageSocket::HandleReadResult(int result) {
}
if (result == 0 || result == net::ERR_CONNECTION_CLOSED) {
OnEndOfStream();
delegate_->OnEndOfStream();
return false;
}
if (result < 0) {
OnError(result);
delegate_->OnError(result);
return false;
}
......@@ -325,7 +330,8 @@ bool SmallMessageSocket::HandleCompletedMessages() {
// Take a weak pointer in case OnMessage() causes this to be deleted.
auto self = weak_factory_.GetWeakPtr();
in_message_ = true;
keep_reading = OnMessage(start_ptr + sizeof(uint16_t), message_size);
keep_reading =
delegate_->OnMessage(start_ptr + sizeof(uint16_t), message_size);
if (!self) {
return false;
}
......@@ -372,6 +378,7 @@ bool SmallMessageSocket::HandleCompletedMessageBuffers() {
auto old_buffer = read_buffer_->TakeUnderlyingBuffer();
auto new_buffer = buffer_pool_->GetBuffer();
CHECK(new_buffer);
size_t new_buffer_size = buffer_pool_->buffer_size();
size_t extra_size = bytes_read - required_size;
if (extra_size > 0) {
......@@ -388,7 +395,8 @@ bool SmallMessageSocket::HandleCompletedMessageBuffers() {
// Take a weak pointer in case OnMessageBuffer() causes this to be deleted.
auto self = weak_factory_.GetWeakPtr();
bool keep_reading = OnMessageBuffer(std::move(old_buffer), required_size);
bool keep_reading =
delegate_->OnMessageBuffer(std::move(old_buffer), required_size);
if (!self || !keep_reading) {
return false;
}
......@@ -401,8 +409,9 @@ bool SmallMessageSocket::HandleCompletedMessageBuffers() {
return true;
}
bool SmallMessageSocket::OnMessageBuffer(scoped_refptr<net::IOBuffer> buffer,
int size) {
bool SmallMessageSocket::Delegate::OnMessageBuffer(
scoped_refptr<net::IOBuffer> buffer,
int size) {
return OnMessage(buffer->data() + sizeof(uint16_t), size - sizeof(uint16_t));
}
......
......@@ -29,7 +29,36 @@ class IOBufferPool;
// desired.
class SmallMessageSocket {
public:
explicit SmallMessageSocket(std::unique_ptr<net::Socket> socket);
class Delegate {
public:
// Called when sending becomes possible again, if a previous attempt to send
// was rejected.
virtual void OnSendUnblocked() {}
// Called when an unrecoverable error occurs while sending or receiving. Is
// only called asynchronously.
virtual void OnError(int error) {}
// Called when the end of stream has been read. No more data will be
// received.
virtual void OnEndOfStream() {}
// Called when a message has been received and there is no buffer pool. The
// |data| buffer contains |size| bytes of data. Return |true| to continue
// reading messages after OnMessage() returns.
virtual bool OnMessage(char* data, int size) = 0;
// Called when a message has been received. The |buffer| contains |size|
// bytes of data, which includes the first 2 bytes which are the size in
// network byte order. Note that these 2 bytes are not included in
// OnMessage()! Return |true| to continue receiving messages.
virtual bool OnMessageBuffer(scoped_refptr<net::IOBuffer> buffer, int size);
protected:
virtual ~Delegate() = default;
};
SmallMessageSocket(Delegate* delegate, std::unique_ptr<net::Socket> socket);
virtual ~SmallMessageSocket();
net::Socket* socket() const { return socket_.get(); }
......@@ -76,35 +105,12 @@ class SmallMessageSocket {
// asynchronous reads.
void ReceiveMessagesSynchronously();
protected:
private:
class BufferWrapper;
// Called when sending becomes possible again, if a previous attempt to send
// was rejected.
virtual void OnSendUnblocked() {}
// Called when an unrecoverable error occurs while sending or receiving. Is
// only called asynchronously.
virtual void OnError(int error) {}
// Called when the end of stream has been read. No more data will be received.
virtual void OnEndOfStream() {}
// Called when a message has been received and there is no buffer pool. The
// |data| buffer contains |size| bytes of data. Return |true| to continue
// reading messages after OnMessage() returns.
virtual bool OnMessage(char* data, int size) = 0;
// Called when a message has been received. The |buffer| contains |size| bytes
// of data, which includes the first 2 bytes which are the size in network
// byte order. Note that these 2 bytes are not included in OnMessage()!
// Return |true| to continue receiving messages.
virtual bool OnMessageBuffer(scoped_refptr<net::IOBuffer> buffer, int size);
private:
void OnWriteComplete(int result);
bool HandleWriteResult(int result);
void PostError(int error);
void OnError(int error);
void Read();
void OnReadComplete(int result);
......@@ -113,6 +119,7 @@ class SmallMessageSocket {
bool HandleCompletedMessageBuffers();
void ActivateBufferPool(char* current_data, size_t current_size);
Delegate* const delegate_;
const std::unique_ptr<net::Socket> socket_;
const scoped_refptr<base::SequencedTaskRunner> task_runner_;
......
......@@ -5,6 +5,7 @@
#include "chromecast/net/small_message_socket.h"
#include <memory>
#include <string>
#include <vector>
#include "base/big_endian.h"
......@@ -43,19 +44,29 @@ void CheckData(char* buffer, int size) {
}
}
class TestSocket : public SmallMessageSocket {
class TestSocket : public SmallMessageSocket::Delegate {
public:
explicit TestSocket(std::unique_ptr<net::Socket> socket)
: SmallMessageSocket(std::move(socket)) {}
: socket_(this, std::move(socket)) {}
~TestSocket() override = default;
void UseBufferPool() {
SmallMessageSocket::UseBufferPool(base::MakeRefCounted<IOBufferPool>(
kDefaultMessageSize + sizeof(uint16_t)));
buffer_pool_ = base::MakeRefCounted<IOBufferPool>(kDefaultMessageSize +
sizeof(uint16_t));
socket_.UseBufferPool(buffer_pool_);
}
void SwapPoolUse(bool swap) { swap_pool_use_ = swap; }
void* PrepareSend(int message_size) {
return socket_.PrepareSend(message_size);
}
void Send() { socket_.Send(); }
bool SendBuffer(scoped_refptr<net::IOBuffer> data, int size) {
return socket_.SendBuffer(std::move(data), size);
}
void ReceiveMessages() { socket_.ReceiveMessages(); }
size_t last_message_size() const {
DCHECK(!message_history_.empty());
return message_history_[message_history_.size() - 1];
......@@ -65,6 +76,8 @@ class TestSocket : public SmallMessageSocket {
return message_history_;
}
IOBufferPool* buffer_pool() const { return buffer_pool_.get(); }
private:
void OnError(int error) override { NOTREACHED(); }
......@@ -84,12 +97,15 @@ class TestSocket : public SmallMessageSocket {
message_history_.push_back(message_size);
CheckData(buffer->data() + sizeof(uint16_t), message_size);
if (swap_pool_use_) {
RemoveBufferPool();
socket_.RemoveBufferPool();
buffer_pool_ = nullptr;
}
return true;
}
SmallMessageSocket socket_;
std::vector<size_t> message_history_;
scoped_refptr<IOBufferPool> buffer_pool_;
bool swap_pool_use_ = false;
};
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment