Commit f247c6a6 authored by Ken MacKay's avatar Ken MacKay Committed by Commit Bot

[Chromecast] Use PostTask for in-process mixer connections

Thereby making it more efficient.

Merge-With: eureka-internal/326819
Bug: internal b/127963522
Change-Id: Idc7db2c86d54c3b14293f407744c51ac6e9d40e1
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/1869128Reviewed-by: default avatarYuchen Liu <yucliu@chromium.org>
Commit-Queue: Kenneth MacKay <kmackay@chromium.org>
Cr-Commit-Position: refs/heads/master@{#708065}
parent 1d829a40
...@@ -44,7 +44,7 @@ ...@@ -44,7 +44,7 @@
namespace chromecast { namespace chromecast {
namespace media { namespace media {
class CaptureServiceReceiver::Socket : public SmallMessageSocket { class CaptureServiceReceiver::Socket : public SmallMessageSocket::Delegate {
public: public:
Socket(std::unique_ptr<net::StreamSocket> socket, int channels); Socket(std::unique_ptr<net::StreamSocket> socket, int channels);
~Socket() override; ~Socket() override;
...@@ -52,7 +52,7 @@ class CaptureServiceReceiver::Socket : public SmallMessageSocket { ...@@ -52,7 +52,7 @@ class CaptureServiceReceiver::Socket : public SmallMessageSocket {
void Start(::media::AudioInputStream::AudioInputCallback* input_callback); void Start(::media::AudioInputStream::AudioInputCallback* input_callback);
private: private:
// SmallMessageSocket implementation: // SmallMessageSocket::Delegate implementation:
void OnError(int error) override; void OnError(int error) override;
void OnEndOfStream() override; void OnEndOfStream() override;
bool OnMessage(char* data, int size) override; bool OnMessage(char* data, int size) override;
...@@ -61,6 +61,8 @@ class CaptureServiceReceiver::Socket : public SmallMessageSocket { ...@@ -61,6 +61,8 @@ class CaptureServiceReceiver::Socket : public SmallMessageSocket {
bool HandleAudio(std::unique_ptr<::media::AudioBus> audio, int64_t timestamp); bool HandleAudio(std::unique_ptr<::media::AudioBus> audio, int64_t timestamp);
void ReportErrorAndStop(); void ReportErrorAndStop();
SmallMessageSocket socket_;
// Number of audio capture channels that audio manager defines. // Number of audio capture channels that audio manager defines.
const int channels_; const int channels_;
...@@ -72,7 +74,7 @@ class CaptureServiceReceiver::Socket : public SmallMessageSocket { ...@@ -72,7 +74,7 @@ class CaptureServiceReceiver::Socket : public SmallMessageSocket {
CaptureServiceReceiver::Socket::Socket( CaptureServiceReceiver::Socket::Socket(
std::unique_ptr<net::StreamSocket> socket, std::unique_ptr<net::StreamSocket> socket,
int channels) int channels)
: SmallMessageSocket(std::move(socket)), : socket_(this, std::move(socket)),
channels_(channels), channels_(channels),
input_callback_(nullptr) { input_callback_(nullptr) {
DCHECK_GT(channels_, 0); DCHECK_GT(channels_, 0);
...@@ -84,7 +86,7 @@ CaptureServiceReceiver::Socket::~Socket() = default; ...@@ -84,7 +86,7 @@ CaptureServiceReceiver::Socket::~Socket() = default;
void CaptureServiceReceiver::Socket::Start( void CaptureServiceReceiver::Socket::Start(
::media::AudioInputStream::AudioInputCallback* input_callback) { ::media::AudioInputStream::AudioInputCallback* input_callback) {
input_callback_ = input_callback; input_callback_ = input_callback;
ReceiveMessages(); socket_.ReceiveMessages();
} }
void CaptureServiceReceiver::Socket::ReportErrorAndStop() { void CaptureServiceReceiver::Socket::ReportErrorAndStop() {
......
...@@ -24,6 +24,7 @@ cast_source_set("common") { ...@@ -24,6 +24,7 @@ cast_source_set("common") {
deps = [ deps = [
"//base", "//base",
"//chromecast/net:io_buffer_pool",
"//chromecast/public", "//chromecast/public",
"//chromecast/public/media", "//chromecast/public/media",
"//net", "//net",
......
...@@ -102,9 +102,9 @@ void ControlConnection::SetStreamCountCallback(StreamCountCallback callback) { ...@@ -102,9 +102,9 @@ void ControlConnection::SetStreamCountCallback(StreamCountCallback callback) {
} }
} }
void ControlConnection::OnConnected(std::unique_ptr<net::StreamSocket> socket) { void ControlConnection::OnConnected(std::unique_ptr<MixerSocket> socket) {
socket_ = std::make_unique<MixerSocket>(std::move(socket), this); socket_ = std::move(socket);
socket_->ReceiveMessages(); socket_->SetDelegate(this);
for (const auto& item : volume_limit_) { for (const auto& item : volume_limit_) {
Generic message; Generic message;
......
...@@ -15,10 +15,6 @@ ...@@ -15,10 +15,6 @@
#include "chromecast/media/audio/mixer_service/mixer_socket.h" #include "chromecast/media/audio/mixer_service/mixer_socket.h"
#include "chromecast/public/volume_control.h" #include "chromecast/public/volume_control.h"
namespace net {
class StreamSocket;
} // namespace net
namespace chromecast { namespace chromecast {
namespace media { namespace media {
namespace mixer_service { namespace mixer_service {
...@@ -71,7 +67,7 @@ class ControlConnection : public MixerConnection, public MixerSocket::Delegate { ...@@ -71,7 +67,7 @@ class ControlConnection : public MixerConnection, public MixerSocket::Delegate {
private: private:
// MixerConnection implementation: // MixerConnection implementation:
void OnConnected(std::unique_ptr<net::StreamSocket> socket) override; void OnConnected(std::unique_ptr<MixerSocket> socket) override;
void OnConnectionError() override; void OnConnectionError() override;
// MixerSocket::Delegate implementation: // MixerSocket::Delegate implementation:
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "chromecast/base/chromecast_switches.h" #include "chromecast/base/chromecast_switches.h"
#include "chromecast/media/audio/audio_buildflags.h" #include "chromecast/media/audio/audio_buildflags.h"
#include "chromecast/media/audio/mixer_service/constants.h" #include "chromecast/media/audio/mixer_service/constants.h"
#include "chromecast/media/audio/mixer_service/mixer_socket.h"
#include "net/base/address_list.h" #include "net/base/address_list.h"
#include "net/base/ip_address.h" #include "net/base/ip_address.h"
#include "net/base/ip_endpoint.h" #include "net/base/ip_endpoint.h"
...@@ -40,12 +41,22 @@ constexpr base::TimeDelta kConnectTimeout = base::TimeDelta::FromSeconds(1); ...@@ -40,12 +41,22 @@ constexpr base::TimeDelta kConnectTimeout = base::TimeDelta::FromSeconds(1);
} // namespace } // namespace
std::unique_ptr<MixerSocket> CreateLocalMixerServiceConnection()
__attribute__((__weak__));
MixerConnection::MixerConnection() : weak_factory_(this) {} MixerConnection::MixerConnection() : weak_factory_(this) {}
MixerConnection::~MixerConnection() = default; MixerConnection::~MixerConnection() = default;
void MixerConnection::Connect() { void MixerConnection::Connect() {
DCHECK(!connecting_socket_); DCHECK(!connecting_socket_);
if (CreateLocalMixerServiceConnection) {
auto socket = CreateLocalMixerServiceConnection();
if (socket) {
OnConnected(std::move(socket));
return;
}
}
#if BUILDFLAG(USE_UNIX_SOCKETS) #if BUILDFLAG(USE_UNIX_SOCKETS)
const base::CommandLine* command_line = const base::CommandLine* command_line =
...@@ -87,7 +98,8 @@ void MixerConnection::ConnectCallback(int result) { ...@@ -87,7 +98,8 @@ void MixerConnection::ConnectCallback(int result) {
LOG_IF(INFO, !log_timeout_) << "Now connected to mixer service"; LOG_IF(INFO, !log_timeout_) << "Now connected to mixer service";
log_connection_failure_ = true; log_connection_failure_ = true;
log_timeout_ = true; log_timeout_ = true;
OnConnected(std::move(connecting_socket_)); auto socket = std::make_unique<MixerSocket>(std::move(connecting_socket_));
OnConnected(std::move(socket));
return; return;
} }
......
...@@ -19,6 +19,7 @@ class StreamSocket; ...@@ -19,6 +19,7 @@ class StreamSocket;
namespace chromecast { namespace chromecast {
namespace media { namespace media {
namespace mixer_service { namespace mixer_service {
class MixerSocket;
// Base class for connecting to the mixer service. // Base class for connecting to the mixer service.
class MixerConnection { class MixerConnection {
...@@ -32,7 +33,7 @@ class MixerConnection { ...@@ -32,7 +33,7 @@ class MixerConnection {
protected: protected:
// Called when a connection is established to the mixer service. // Called when a connection is established to the mixer service.
virtual void OnConnected(std::unique_ptr<net::StreamSocket> socket) = 0; virtual void OnConnected(std::unique_ptr<MixerSocket> socket) = 0;
private: private:
void ConnectCallback(int result); void ConnectCallback(int result);
......
...@@ -9,9 +9,13 @@ ...@@ -9,9 +9,13 @@
#include <utility> #include <utility>
#include "base/big_endian.h" #include "base/big_endian.h"
#include "base/bind.h"
#include "base/location.h"
#include "base/logging.h" #include "base/logging.h"
#include "base/sequenced_task_runner.h"
#include "chromecast/media/audio/mixer_service/constants.h" #include "chromecast/media/audio/mixer_service/constants.h"
#include "chromecast/media/audio/mixer_service/mixer_service.pb.h" #include "chromecast/media/audio/mixer_service/mixer_service.pb.h"
#include "chromecast/net/io_buffer_pool.h"
#include "net/base/io_buffer.h" #include "net/base/io_buffer.h"
#include "net/socket/stream_socket.h" #include "net/socket/stream_socket.h"
...@@ -74,17 +78,47 @@ bool MixerSocket::Delegate::HandleAudioBuffer( ...@@ -74,17 +78,47 @@ bool MixerSocket::Delegate::HandleAudioBuffer(
constexpr size_t MixerSocket::kAudioHeaderSize; constexpr size_t MixerSocket::kAudioHeaderSize;
constexpr size_t MixerSocket::kAudioMessageHeaderSize; constexpr size_t MixerSocket::kAudioMessageHeaderSize;
MixerSocket::MixerSocket(std::unique_ptr<net::StreamSocket> socket, MixerSocket::MixerSocket(std::unique_ptr<net::StreamSocket> socket)
Delegate* delegate) : socket_(std::make_unique<SmallMessageSocket>(this, std::move(socket))) {}
: SmallMessageSocket(std::move(socket)), delegate_(delegate) {
DCHECK(delegate_);
}
MixerSocket::~MixerSocket() = default; MixerSocket::MixerSocket() = default;
MixerSocket::~MixerSocket() {
if (counterpart_task_runner_) {
counterpart_task_runner_->PostTask(
FROM_HERE,
base::BindOnce(&MixerSocket::OnEndOfStream, local_counterpart_));
}
}
void MixerSocket::SetDelegate(Delegate* delegate) { void MixerSocket::SetDelegate(Delegate* delegate) {
DCHECK(delegate); DCHECK(delegate);
bool had_delegate = (delegate_ != nullptr);
delegate_ = delegate; delegate_ = delegate;
if (socket_ && !had_delegate) {
socket_->ReceiveMessages();
}
}
void MixerSocket::SetLocalCounterpart(
base::WeakPtr<MixerSocket> local_counterpart,
scoped_refptr<base::SequencedTaskRunner> counterpart_task_runner) {
local_counterpart_ = std::move(local_counterpart);
counterpart_task_runner_ = std::move(counterpart_task_runner);
}
base::WeakPtr<MixerSocket> MixerSocket::GetWeakPtr() {
return weak_factory_.GetWeakPtr();
}
void MixerSocket::UseBufferPool(scoped_refptr<IOBufferPool> buffer_pool) {
DCHECK(buffer_pool);
DCHECK(buffer_pool->threadsafe());
buffer_pool_ = std::move(buffer_pool);
if (socket_) {
socket_->UseBufferPool(buffer_pool_);
}
} }
// static // static
...@@ -120,10 +154,7 @@ void MixerSocket::SendPreparedAudioBuffer( ...@@ -120,10 +154,7 @@ void MixerSocket::SendPreparedAudioBuffer(
uint16_t payload_size; uint16_t payload_size;
base::ReadBigEndian(audio_buffer->data(), &payload_size); base::ReadBigEndian(audio_buffer->data(), &payload_size);
DCHECK_GE(payload_size, kAudioHeaderSize); DCHECK_GE(payload_size, kAudioHeaderSize);
if (!SmallMessageSocket::SendBuffer(audio_buffer.get(), SendBuffer(std::move(audio_buffer), sizeof(uint16_t) + payload_size);
sizeof(uint16_t) + payload_size)) {
write_queue_.push(std::move(audio_buffer));
}
} }
void MixerSocket::SendProto(const google::protobuf::MessageLite& message) { void MixerSocket::SendProto(const google::protobuf::MessageLite& message) {
...@@ -133,16 +164,20 @@ void MixerSocket::SendProto(const google::protobuf::MessageLite& message) { ...@@ -133,16 +164,20 @@ void MixerSocket::SendProto(const google::protobuf::MessageLite& message) {
int total_size = int total_size =
sizeof(type) + sizeof(padding_bytes) + message_size + padding_bytes; sizeof(type) + sizeof(padding_bytes) + message_size + padding_bytes;
scoped_refptr<net::IOBufferWithSize> storage;
void* buffer = PrepareSend(total_size); scoped_refptr<net::IOBuffer> buffer;
char* ptr; char* ptr = (socket_ ? static_cast<char*>(socket_->PrepareSend(total_size))
if (buffer) { : nullptr);
ptr = reinterpret_cast<char*>(buffer); if (!ptr) {
} else { if (buffer_pool_ &&
storage = base::MakeRefCounted<net::IOBufferWithSize>(sizeof(uint16_t) + buffer_pool_->buffer_size() >= sizeof(uint16_t) + total_size) {
total_size); buffer = buffer_pool_->GetBuffer();
}
ptr = storage->data(); if (!buffer) {
buffer =
base::MakeRefCounted<net::IOBuffer>(sizeof(uint16_t) + total_size);
}
ptr = buffer->data();
base::WriteBigEndian(ptr, static_cast<uint16_t>(total_size)); base::WriteBigEndian(ptr, static_cast<uint16_t>(total_size));
ptr += sizeof(uint16_t); ptr += sizeof(uint16_t);
} }
...@@ -155,19 +190,34 @@ void MixerSocket::SendProto(const google::protobuf::MessageLite& message) { ...@@ -155,19 +190,34 @@ void MixerSocket::SendProto(const google::protobuf::MessageLite& message) {
ptr += message_size; ptr += message_size;
memset(ptr, 0, padding_bytes); memset(ptr, 0, padding_bytes);
if (buffer) { if (!buffer) {
Send(); socket_->Send();
return;
} }
if (storage) { SendBuffer(std::move(buffer), sizeof(uint16_t) + total_size);
write_queue_.push(std::move(storage)); }
void MixerSocket::SendBuffer(scoped_refptr<net::IOBuffer> buffer,
int buffer_size) {
if (counterpart_task_runner_) {
counterpart_task_runner_->PostTask(
FROM_HERE,
base::BindOnce(base::IgnoreResult(&MixerSocket::OnMessageBuffer),
local_counterpart_, std::move(buffer), buffer_size));
return;
}
DCHECK(socket_);
if (!socket_->SendBuffer(buffer, buffer_size)) {
write_queue_.push(std::move(buffer));
} }
} }
void MixerSocket::OnSendUnblocked() { void MixerSocket::OnSendUnblocked() {
DCHECK(socket_);
while (!write_queue_.empty()) { while (!write_queue_.empty()) {
uint16_t message_size; uint16_t message_size;
base::ReadBigEndian(write_queue_.front()->data(), &message_size); base::ReadBigEndian(write_queue_.front()->data(), &message_size);
if (!SmallMessageSocket::SendBuffer(write_queue_.front().get(), if (!socket_->SendBuffer(write_queue_.front().get(),
sizeof(uint16_t) + message_size)) { sizeof(uint16_t) + message_size)) {
return; return;
} }
...@@ -175,12 +225,18 @@ void MixerSocket::OnSendUnblocked() { ...@@ -175,12 +225,18 @@ void MixerSocket::OnSendUnblocked() {
} }
} }
void MixerSocket::ReceiveMoreMessages() {
socket_->ReceiveMessagesSynchronously();
}
void MixerSocket::OnError(int error) { void MixerSocket::OnError(int error) {
LOG(ERROR) << "Socket error from " << this << ": " << error; LOG(ERROR) << "Socket error from " << this << ": " << error;
DCHECK(delegate_);
delegate_->OnConnectionError(); delegate_->OnConnectionError();
} }
void MixerSocket::OnEndOfStream() { void MixerSocket::OnEndOfStream() {
DCHECK(delegate_);
delegate_->OnConnectionError(); delegate_->OnConnectionError();
} }
......
...@@ -11,8 +11,13 @@ ...@@ -11,8 +11,13 @@
#include "base/macros.h" #include "base/macros.h"
#include "base/memory/scoped_refptr.h" #include "base/memory/scoped_refptr.h"
#include "base/memory/weak_ptr.h"
#include "chromecast/net/small_message_socket.h" #include "chromecast/net/small_message_socket.h"
namespace base {
class SequencedTaskRunner;
} // namespace base
namespace google { namespace google {
namespace protobuf { namespace protobuf {
class MessageLite; class MessageLite;
...@@ -25,13 +30,15 @@ class StreamSocket; ...@@ -25,13 +30,15 @@ class StreamSocket;
} // namespace net } // namespace net
namespace chromecast { namespace chromecast {
class IOBufferPool;
namespace media { namespace media {
namespace mixer_service { namespace mixer_service {
class Generic; class Generic;
// Base class for sending and receiving messages to/from the mixer service. // Base class for sending and receiving messages to/from the mixer service.
// Not thread-safe; all usage of a given instance must be on the same sequence. // Not thread-safe; all usage of a given instance must be on the same sequence.
class MixerSocket : public SmallMessageSocket { class MixerSocket : public SmallMessageSocket::Delegate {
public: public:
class Delegate { class Delegate {
public: public:
...@@ -64,12 +71,18 @@ class MixerSocket : public SmallMessageSocket { ...@@ -64,12 +71,18 @@ class MixerSocket : public SmallMessageSocket {
virtual ~Delegate() = default; virtual ~Delegate() = default;
}; };
MixerSocket(std::unique_ptr<net::StreamSocket> socket, Delegate* delegate); explicit MixerSocket(std::unique_ptr<net::StreamSocket> socket);
~MixerSocket() override; ~MixerSocket() override;
// Changes the delegate. // Sets/changes the delegate. Must be called immediately after creation
// (ie, synchronously on the same sequence).
void SetDelegate(Delegate* delegate); void SetDelegate(Delegate* delegate);
// Adds a |buffer_pool| used to allocate buffers to receive messages into,
// and for sending protos. If the pool-allocated buffers are too small for a
// given message, a normal IOBuffer will be dynamically allocated instead.
void UseBufferPool(scoped_refptr<IOBufferPool> buffer_pool);
// 16-bit type and 64-bit timestamp. // 16-bit type and 64-bit timestamp.
static constexpr size_t kAudioHeaderSize = sizeof(int16_t) + sizeof(int64_t); static constexpr size_t kAudioHeaderSize = sizeof(int16_t) + sizeof(int64_t);
// Includes additional 16-bit size field for SmallMessageSocket. // Includes additional 16-bit size field for SmallMessageSocket.
...@@ -96,8 +109,24 @@ class MixerSocket : public SmallMessageSocket { ...@@ -96,8 +109,24 @@ class MixerSocket : public SmallMessageSocket {
// Sends an arbitrary protobuf across the connection. // Sends an arbitrary protobuf across the connection.
void SendProto(const google::protobuf::MessageLite& message); void SendProto(const google::protobuf::MessageLite& message);
// Resumes receiving messages. Delegate calls may be called synchronously
// from within this method.
void ReceiveMoreMessages();
private: private:
// SmallMessageSocket implementation: friend class Receiver;
// Used by Receiver to create in-process mixer connections.
MixerSocket();
void SetLocalCounterpart(
base::WeakPtr<MixerSocket> local_counterpart,
scoped_refptr<base::SequencedTaskRunner> counterpart_task_runner);
base::WeakPtr<MixerSocket> GetWeakPtr();
void SendBuffer(scoped_refptr<net::IOBuffer> buffer, int buffer_size);
// SmallMessageSocket::Delegate implementation:
void OnSendUnblocked() override; void OnSendUnblocked() override;
void OnError(int error) override; void OnError(int error) override;
void OnEndOfStream() override; void OnEndOfStream() override;
...@@ -110,10 +139,17 @@ class MixerSocket : public SmallMessageSocket { ...@@ -110,10 +139,17 @@ class MixerSocket : public SmallMessageSocket {
char* data, char* data,
int size); int size);
Delegate* delegate_; Delegate* delegate_ = nullptr;
const std::unique_ptr<SmallMessageSocket> socket_;
scoped_refptr<IOBufferPool> buffer_pool_;
std::queue<scoped_refptr<net::IOBuffer>> write_queue_; std::queue<scoped_refptr<net::IOBuffer>> write_queue_;
base::WeakPtr<MixerSocket> local_counterpart_;
scoped_refptr<base::SequencedTaskRunner> counterpart_task_runner_;
base::WeakPtrFactory<MixerSocket> weak_factory_{this};
DISALLOW_COPY_AND_ASSIGN(MixerSocket); DISALLOW_COPY_AND_ASSIGN(MixerSocket);
}; };
......
...@@ -125,10 +125,9 @@ void OutputStreamConnection::Resume() { ...@@ -125,10 +125,9 @@ void OutputStreamConnection::Resume() {
} }
} }
void OutputStreamConnection::OnConnected( void OutputStreamConnection::OnConnected(std::unique_ptr<MixerSocket> socket) {
std::unique_ptr<net::StreamSocket> socket) { socket_ = std::move(socket);
socket_ = std::make_unique<MixerSocket>(std::move(socket), this); socket_->SetDelegate(this);
socket_->ReceiveMessages();
Generic message; Generic message;
*(message.mutable_output_stream_params()) = params_; *(message.mutable_output_stream_params()) = params_;
......
...@@ -15,10 +15,6 @@ ...@@ -15,10 +15,6 @@
#include "chromecast/media/audio/mixer_service/mixer_socket.h" #include "chromecast/media/audio/mixer_service/mixer_socket.h"
#include "net/base/io_buffer.h" #include "net/base/io_buffer.h"
namespace net {
class StreamSocket;
} // namespace net
namespace chromecast { namespace chromecast {
namespace media { namespace media {
namespace mixer_service { namespace mixer_service {
...@@ -90,7 +86,7 @@ class OutputStreamConnection : public MixerConnection, ...@@ -90,7 +86,7 @@ class OutputStreamConnection : public MixerConnection,
private: private:
// MixerConnection implementation: // MixerConnection implementation:
void OnConnected(std::unique_ptr<net::StreamSocket> socket) override; void OnConnected(std::unique_ptr<MixerSocket> socket) override;
void OnConnectionError() override; void OnConnectionError() override;
// MixerSocket::Delegate implementation: // MixerSocket::Delegate implementation:
......
...@@ -7,8 +7,14 @@ ...@@ -7,8 +7,14 @@
#include <string> #include <string>
#include <utility> #include <utility>
#include "base/bind.h"
#include "base/command_line.h" #include "base/command_line.h"
#include "base/location.h"
#include "base/logging.h" #include "base/logging.h"
#include "base/no_destructor.h"
#include "base/sequenced_task_runner.h"
#include "base/synchronization/lock.h"
#include "base/threading/sequenced_task_runner_handle.h"
#include "chromecast/base/chromecast_switches.h" #include "chromecast/base/chromecast_switches.h"
#include "chromecast/media/audio/mixer_service/constants.h" #include "chromecast/media/audio/mixer_service/constants.h"
#include "chromecast/media/audio/mixer_service/mixer_socket.h" #include "chromecast/media/audio/mixer_service/mixer_socket.h"
...@@ -32,15 +38,54 @@ std::string GetEndpoint() { ...@@ -32,15 +38,54 @@ std::string GetEndpoint() {
return path; return path;
} }
class LocalReceiverInstance {
public:
LocalReceiverInstance() = default;
void SetInstance(Receiver* receiver) {
base::AutoLock lock(lock_);
receiver_ = receiver;
}
void RemoveInstance(Receiver* receiver) {
base::AutoLock lock(lock_);
if (receiver_ == receiver) {
receiver_ = nullptr;
}
}
std::unique_ptr<MixerSocket> CreateLocalSocket() {
base::AutoLock lock(lock_);
if (receiver_) {
return receiver_->LocalConnect();
}
return nullptr;
}
private:
DISALLOW_COPY_AND_ASSIGN(LocalReceiverInstance);
base::Lock lock_;
Receiver* receiver_ = nullptr;
};
LocalReceiverInstance* GetLocalReceiver() {
static base::NoDestructor<LocalReceiverInstance> instance;
return instance.get();
}
} // namespace } // namespace
std::unique_ptr<MixerSocket> CreateLocalMixerServiceConnection() {
return GetLocalReceiver()->CreateLocalSocket();
}
class Receiver::InitialSocket : public MixerSocket::Delegate { class Receiver::InitialSocket : public MixerSocket::Delegate {
public: public:
InitialSocket(Receiver* receiver, std::unique_ptr<net::StreamSocket> socket) InitialSocket(Receiver* receiver, std::unique_ptr<MixerSocket> socket)
: receiver_(receiver), : receiver_(receiver), socket_(std::move(socket)) {
socket_(std::make_unique<MixerSocket>(std::move(socket), this)) {
DCHECK(receiver_); DCHECK(receiver_);
socket_->ReceiveMessages(); socket_->SetDelegate(this);
} }
~InitialSocket() override = default; ~InitialSocket() override = default;
...@@ -79,20 +124,49 @@ class Receiver::InitialSocket : public MixerSocket::Delegate { ...@@ -79,20 +124,49 @@ class Receiver::InitialSocket : public MixerSocket::Delegate {
}; };
Receiver::Receiver() Receiver::Receiver()
: socket_service_( : task_runner_(base::SequencedTaskRunnerHandle::Get()),
socket_service_(
GetEndpoint(), GetEndpoint(),
GetSwitchValueNonNegativeInt(switches::kMixerServiceEndpoint, GetSwitchValueNonNegativeInt(switches::kMixerServiceEndpoint,
mixer_service::kDefaultTcpPort), mixer_service::kDefaultTcpPort),
kMaxAcceptLoop, kMaxAcceptLoop,
this) { this),
weak_factory_(this) {
socket_service_.Accept(); socket_service_.Accept();
GetLocalReceiver()->SetInstance(this);
} }
Receiver::~Receiver() = default; Receiver::~Receiver() {
GetLocalReceiver()->RemoveInstance(this);
}
std::unique_ptr<MixerSocket> Receiver::LocalConnect() {
std::unique_ptr<MixerSocket> receiver_socket(new MixerSocket);
std::unique_ptr<MixerSocket> caller_socket(new MixerSocket);
receiver_socket->SetLocalCounterpart(caller_socket->GetWeakPtr(),
base::SequencedTaskRunnerHandle::Get());
caller_socket->SetLocalCounterpart(receiver_socket->GetWeakPtr(),
task_runner_);
task_runner_->PostTask(
FROM_HERE,
base::BindOnce(&Receiver::HandleLocalConnection,
weak_factory_.GetWeakPtr(), std::move(receiver_socket)));
return caller_socket;
}
void Receiver::HandleAcceptedSocket(std::unique_ptr<net::StreamSocket> socket) { void Receiver::HandleAcceptedSocket(std::unique_ptr<net::StreamSocket> socket) {
auto initial_socket = AddInitialSocket(std::make_unique<InitialSocket>(
std::make_unique<InitialSocket>(this, std::move(socket)); this, std::make_unique<MixerSocket>(std::move(socket))));
}
void Receiver::HandleLocalConnection(std::unique_ptr<MixerSocket> socket) {
AddInitialSocket(std::make_unique<InitialSocket>(this, std::move(socket)));
}
void Receiver::AddInitialSocket(std::unique_ptr<InitialSocket> initial_socket) {
InitialSocket* ptr = initial_socket.get(); InitialSocket* ptr = initial_socket.get();
initial_sockets_[ptr] = std::move(initial_socket); initial_sockets_[ptr] = std::move(initial_socket);
} }
......
...@@ -9,9 +9,15 @@ ...@@ -9,9 +9,15 @@
#include "base/containers/flat_map.h" #include "base/containers/flat_map.h"
#include "base/macros.h" #include "base/macros.h"
#include "base/memory/scoped_refptr.h"
#include "base/memory/weak_ptr.h"
#include "chromecast/media/audio/mixer_service/audio_socket_service.h" #include "chromecast/media/audio/mixer_service/audio_socket_service.h"
#include "chromecast/media/audio/mixer_service/mixer_service.pb.h" #include "chromecast/media/audio/mixer_service/mixer_service.pb.h"
namespace base {
class SequencedTaskRunner;
} // namespace base
namespace chromecast { namespace chromecast {
namespace media { namespace media {
namespace mixer_service { namespace mixer_service {
...@@ -31,19 +37,31 @@ class Receiver : public AudioSocketService::Delegate { ...@@ -31,19 +37,31 @@ class Receiver : public AudioSocketService::Delegate {
virtual void CreateControlConnection(std::unique_ptr<MixerSocket> socket, virtual void CreateControlConnection(std::unique_ptr<MixerSocket> socket,
const Generic& message) = 0; const Generic& message) = 0;
// Creates a local (in-process) connection to this receiver. May be called on
// any thread; the returned MixerSocket can only be used on the calling
// thread. The returned socket must have its delegate set immediately.
std::unique_ptr<MixerSocket> LocalConnect();
private: private:
class InitialSocket; class InitialSocket;
void RemoveInitialSocket(InitialSocket* socket);
// AudioSocketService::Delegate implementation: // AudioSocketService::Delegate implementation:
void HandleAcceptedSocket(std::unique_ptr<net::StreamSocket> socket) override; void HandleAcceptedSocket(std::unique_ptr<net::StreamSocket> socket) override;
void HandleLocalConnection(std::unique_ptr<MixerSocket> socket);
void AddInitialSocket(std::unique_ptr<InitialSocket> initial_socket);
void RemoveInitialSocket(InitialSocket* socket);
const scoped_refptr<base::SequencedTaskRunner> task_runner_;
AudioSocketService socket_service_; AudioSocketService socket_service_;
base::flat_map<InitialSocket*, std::unique_ptr<InitialSocket>> base::flat_map<InitialSocket*, std::unique_ptr<InitialSocket>>
initial_sockets_; initial_sockets_;
base::WeakPtrFactory<Receiver> weak_factory_;
DISALLOW_COPY_AND_ASSIGN(Receiver); DISALLOW_COPY_AND_ASSIGN(Receiver);
}; };
......
...@@ -119,7 +119,7 @@ class ReceiverCma::Stream : public MixerSocket::Delegate, ...@@ -119,7 +119,7 @@ class ReceiverCma::Stream : public MixerSocket::Delegate,
last_send_time_ = base::TimeTicks::Now(); last_send_time_ = base::TimeTicks::Now();
} }
socket_->ReceiveMessages(); socket_->ReceiveMoreMessages();
} }
void PlayedEos() override { void PlayedEos() override {
......
...@@ -226,6 +226,7 @@ IOBufferPool::IOBufferPool(size_t buffer_size, ...@@ -226,6 +226,7 @@ IOBufferPool::IOBufferPool(size_t buffer_size,
bool threadsafe) bool threadsafe)
: buffer_size_(buffer_size), : buffer_size_(buffer_size),
max_buffers_(max_buffers), max_buffers_(max_buffers),
threadsafe_(threadsafe),
internal_(new Internal(buffer_size, max_buffers, threadsafe)) {} internal_(new Internal(buffer_size, max_buffers, threadsafe)) {}
IOBufferPool::IOBufferPool(size_t buffer_size) IOBufferPool::IOBufferPool(size_t buffer_size)
......
...@@ -30,6 +30,7 @@ class IOBufferPool : public base::RefCountedThreadSafe<IOBufferPool> { ...@@ -30,6 +30,7 @@ class IOBufferPool : public base::RefCountedThreadSafe<IOBufferPool> {
size_t buffer_size() const { return buffer_size_; } size_t buffer_size() const { return buffer_size_; }
size_t max_buffers() const { return max_buffers_; } size_t max_buffers() const { return max_buffers_; }
bool threadsafe() const { return threadsafe_; }
// Ensures that at least |num_buffers| are allocated. If |num_buffers| is // Ensures that at least |num_buffers| are allocated. If |num_buffers| is
// greater than |max_buffers|, makes sure that |max_buffers| buffers have been // greater than |max_buffers|, makes sure that |max_buffers| buffers have been
...@@ -56,6 +57,7 @@ class IOBufferPool : public base::RefCountedThreadSafe<IOBufferPool> { ...@@ -56,6 +57,7 @@ class IOBufferPool : public base::RefCountedThreadSafe<IOBufferPool> {
const size_t buffer_size_; const size_t buffer_size_;
const size_t max_buffers_; const size_t max_buffers_;
const bool threadsafe_;
Internal* internal_; // Manages its own lifetime. Internal* internal_; // Manages its own lifetime.
DISALLOW_COPY_AND_ASSIGN(IOBufferPool); DISALLOW_COPY_AND_ASSIGN(IOBufferPool);
......
...@@ -75,14 +75,17 @@ class SmallMessageSocket::BufferWrapper : public ::net::IOBuffer { ...@@ -75,14 +75,17 @@ class SmallMessageSocket::BufferWrapper : public ::net::IOBuffer {
size_t used_ = 0; size_t used_ = 0;
}; };
SmallMessageSocket::SmallMessageSocket(std::unique_ptr<net::Socket> socket) SmallMessageSocket::SmallMessageSocket(Delegate* delegate,
: socket_(std::move(socket)), std::unique_ptr<net::Socket> socket)
: delegate_(delegate),
socket_(std::move(socket)),
task_runner_(base::SequencedTaskRunnerHandle::Get()), task_runner_(base::SequencedTaskRunnerHandle::Get()),
write_storage_(base::MakeRefCounted<net::GrowableIOBuffer>()), write_storage_(base::MakeRefCounted<net::GrowableIOBuffer>()),
write_buffer_(base::MakeRefCounted<BufferWrapper>()), write_buffer_(base::MakeRefCounted<BufferWrapper>()),
read_storage_(base::MakeRefCounted<net::GrowableIOBuffer>()), read_storage_(base::MakeRefCounted<net::GrowableIOBuffer>()),
read_buffer_(base::MakeRefCounted<BufferWrapper>()), read_buffer_(base::MakeRefCounted<BufferWrapper>()),
weak_factory_(this) { weak_factory_(this) {
DCHECK(delegate_);
write_storage_->SetCapacity(kDefaultBufferSize); write_storage_->SetCapacity(kDefaultBufferSize);
read_storage_->SetCapacity(kDefaultBufferSize); read_storage_->SetCapacity(kDefaultBufferSize);
} }
...@@ -115,6 +118,7 @@ void SmallMessageSocket::ActivateBufferPool(char* current_data, ...@@ -115,6 +118,7 @@ void SmallMessageSocket::ActivateBufferPool(char* current_data,
size_t new_buffer_size; size_t new_buffer_size;
if (current_size <= buffer_pool_->buffer_size()) { if (current_size <= buffer_pool_->buffer_size()) {
new_buffer = buffer_pool_->GetBuffer(); new_buffer = buffer_pool_->GetBuffer();
CHECK(new_buffer);
new_buffer_size = buffer_pool_->buffer_size(); new_buffer_size = buffer_pool_->buffer_size();
} else { } else {
new_buffer = base::MakeRefCounted<::net::IOBuffer>(current_size * 2); new_buffer = base::MakeRefCounted<::net::IOBuffer>(current_size * 2);
...@@ -199,7 +203,12 @@ bool SmallMessageSocket::HandleWriteResult(int result) { ...@@ -199,7 +203,12 @@ bool SmallMessageSocket::HandleWriteResult(int result) {
return false; return false;
} }
if (result <= 0) { if (result <= 0) {
PostError(result); // Post a task rather than just calling OnError(), to avoid calling
// OnError()
// synchronously.
task_runner_->PostTask(FROM_HERE,
base::BindOnce(&SmallMessageSocket::OnError,
weak_factory_.GetWeakPtr(), result));
return false; return false;
} }
...@@ -211,17 +220,13 @@ bool SmallMessageSocket::HandleWriteResult(int result) { ...@@ -211,17 +220,13 @@ bool SmallMessageSocket::HandleWriteResult(int result) {
write_buffer_->ClearUnderlyingBuffer(); write_buffer_->ClearUnderlyingBuffer();
if (send_blocked_) { if (send_blocked_) {
send_blocked_ = false; send_blocked_ = false;
OnSendUnblocked(); delegate_->OnSendUnblocked();
} }
return false; return false;
} }
void SmallMessageSocket::PostError(int error) { void SmallMessageSocket::OnError(int error) {
// Post a task rather than just calling OnError(), to avoid calling OnError() delegate_->OnError(error);
// synchronously.
task_runner_->PostTask(FROM_HERE,
base::BindOnce(&SmallMessageSocket::OnError,
weak_factory_.GetWeakPtr(), error));
} }
void SmallMessageSocket::ReceiveMessages() { void SmallMessageSocket::ReceiveMessages() {
...@@ -281,12 +286,12 @@ bool SmallMessageSocket::HandleReadResult(int result) { ...@@ -281,12 +286,12 @@ bool SmallMessageSocket::HandleReadResult(int result) {
} }
if (result == 0 || result == net::ERR_CONNECTION_CLOSED) { if (result == 0 || result == net::ERR_CONNECTION_CLOSED) {
OnEndOfStream(); delegate_->OnEndOfStream();
return false; return false;
} }
if (result < 0) { if (result < 0) {
OnError(result); delegate_->OnError(result);
return false; return false;
} }
...@@ -325,7 +330,8 @@ bool SmallMessageSocket::HandleCompletedMessages() { ...@@ -325,7 +330,8 @@ bool SmallMessageSocket::HandleCompletedMessages() {
// Take a weak pointer in case OnMessage() causes this to be deleted. // Take a weak pointer in case OnMessage() causes this to be deleted.
auto self = weak_factory_.GetWeakPtr(); auto self = weak_factory_.GetWeakPtr();
in_message_ = true; in_message_ = true;
keep_reading = OnMessage(start_ptr + sizeof(uint16_t), message_size); keep_reading =
delegate_->OnMessage(start_ptr + sizeof(uint16_t), message_size);
if (!self) { if (!self) {
return false; return false;
} }
...@@ -372,6 +378,7 @@ bool SmallMessageSocket::HandleCompletedMessageBuffers() { ...@@ -372,6 +378,7 @@ bool SmallMessageSocket::HandleCompletedMessageBuffers() {
auto old_buffer = read_buffer_->TakeUnderlyingBuffer(); auto old_buffer = read_buffer_->TakeUnderlyingBuffer();
auto new_buffer = buffer_pool_->GetBuffer(); auto new_buffer = buffer_pool_->GetBuffer();
CHECK(new_buffer);
size_t new_buffer_size = buffer_pool_->buffer_size(); size_t new_buffer_size = buffer_pool_->buffer_size();
size_t extra_size = bytes_read - required_size; size_t extra_size = bytes_read - required_size;
if (extra_size > 0) { if (extra_size > 0) {
...@@ -388,7 +395,8 @@ bool SmallMessageSocket::HandleCompletedMessageBuffers() { ...@@ -388,7 +395,8 @@ bool SmallMessageSocket::HandleCompletedMessageBuffers() {
// Take a weak pointer in case OnMessageBuffer() causes this to be deleted. // Take a weak pointer in case OnMessageBuffer() causes this to be deleted.
auto self = weak_factory_.GetWeakPtr(); auto self = weak_factory_.GetWeakPtr();
bool keep_reading = OnMessageBuffer(std::move(old_buffer), required_size); bool keep_reading =
delegate_->OnMessageBuffer(std::move(old_buffer), required_size);
if (!self || !keep_reading) { if (!self || !keep_reading) {
return false; return false;
} }
...@@ -401,7 +409,8 @@ bool SmallMessageSocket::HandleCompletedMessageBuffers() { ...@@ -401,7 +409,8 @@ bool SmallMessageSocket::HandleCompletedMessageBuffers() {
return true; return true;
} }
bool SmallMessageSocket::OnMessageBuffer(scoped_refptr<net::IOBuffer> buffer, bool SmallMessageSocket::Delegate::OnMessageBuffer(
scoped_refptr<net::IOBuffer> buffer,
int size) { int size) {
return OnMessage(buffer->data() + sizeof(uint16_t), size - sizeof(uint16_t)); return OnMessage(buffer->data() + sizeof(uint16_t), size - sizeof(uint16_t));
} }
......
...@@ -29,7 +29,36 @@ class IOBufferPool; ...@@ -29,7 +29,36 @@ class IOBufferPool;
// desired. // desired.
class SmallMessageSocket { class SmallMessageSocket {
public: public:
explicit SmallMessageSocket(std::unique_ptr<net::Socket> socket); class Delegate {
public:
// Called when sending becomes possible again, if a previous attempt to send
// was rejected.
virtual void OnSendUnblocked() {}
// Called when an unrecoverable error occurs while sending or receiving. Is
// only called asynchronously.
virtual void OnError(int error) {}
// Called when the end of stream has been read. No more data will be
// received.
virtual void OnEndOfStream() {}
// Called when a message has been received and there is no buffer pool. The
// |data| buffer contains |size| bytes of data. Return |true| to continue
// reading messages after OnMessage() returns.
virtual bool OnMessage(char* data, int size) = 0;
// Called when a message has been received. The |buffer| contains |size|
// bytes of data, which includes the first 2 bytes which are the size in
// network byte order. Note that these 2 bytes are not included in
// OnMessage()! Return |true| to continue receiving messages.
virtual bool OnMessageBuffer(scoped_refptr<net::IOBuffer> buffer, int size);
protected:
virtual ~Delegate() = default;
};
SmallMessageSocket(Delegate* delegate, std::unique_ptr<net::Socket> socket);
virtual ~SmallMessageSocket(); virtual ~SmallMessageSocket();
net::Socket* socket() const { return socket_.get(); } net::Socket* socket() const { return socket_.get(); }
...@@ -76,35 +105,12 @@ class SmallMessageSocket { ...@@ -76,35 +105,12 @@ class SmallMessageSocket {
// asynchronous reads. // asynchronous reads.
void ReceiveMessagesSynchronously(); void ReceiveMessagesSynchronously();
protected: private:
class BufferWrapper; class BufferWrapper;
// Called when sending becomes possible again, if a previous attempt to send
// was rejected.
virtual void OnSendUnblocked() {}
// Called when an unrecoverable error occurs while sending or receiving. Is
// only called asynchronously.
virtual void OnError(int error) {}
// Called when the end of stream has been read. No more data will be received.
virtual void OnEndOfStream() {}
// Called when a message has been received and there is no buffer pool. The
// |data| buffer contains |size| bytes of data. Return |true| to continue
// reading messages after OnMessage() returns.
virtual bool OnMessage(char* data, int size) = 0;
// Called when a message has been received. The |buffer| contains |size| bytes
// of data, which includes the first 2 bytes which are the size in network
// byte order. Note that these 2 bytes are not included in OnMessage()!
// Return |true| to continue receiving messages.
virtual bool OnMessageBuffer(scoped_refptr<net::IOBuffer> buffer, int size);
private:
void OnWriteComplete(int result); void OnWriteComplete(int result);
bool HandleWriteResult(int result); bool HandleWriteResult(int result);
void PostError(int error); void OnError(int error);
void Read(); void Read();
void OnReadComplete(int result); void OnReadComplete(int result);
...@@ -113,6 +119,7 @@ class SmallMessageSocket { ...@@ -113,6 +119,7 @@ class SmallMessageSocket {
bool HandleCompletedMessageBuffers(); bool HandleCompletedMessageBuffers();
void ActivateBufferPool(char* current_data, size_t current_size); void ActivateBufferPool(char* current_data, size_t current_size);
Delegate* const delegate_;
const std::unique_ptr<net::Socket> socket_; const std::unique_ptr<net::Socket> socket_;
const scoped_refptr<base::SequencedTaskRunner> task_runner_; const scoped_refptr<base::SequencedTaskRunner> task_runner_;
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include "chromecast/net/small_message_socket.h" #include "chromecast/net/small_message_socket.h"
#include <memory> #include <memory>
#include <string>
#include <vector> #include <vector>
#include "base/big_endian.h" #include "base/big_endian.h"
...@@ -43,19 +44,29 @@ void CheckData(char* buffer, int size) { ...@@ -43,19 +44,29 @@ void CheckData(char* buffer, int size) {
} }
} }
class TestSocket : public SmallMessageSocket { class TestSocket : public SmallMessageSocket::Delegate {
public: public:
explicit TestSocket(std::unique_ptr<net::Socket> socket) explicit TestSocket(std::unique_ptr<net::Socket> socket)
: SmallMessageSocket(std::move(socket)) {} : socket_(this, std::move(socket)) {}
~TestSocket() override = default; ~TestSocket() override = default;
void UseBufferPool() { void UseBufferPool() {
SmallMessageSocket::UseBufferPool(base::MakeRefCounted<IOBufferPool>( buffer_pool_ = base::MakeRefCounted<IOBufferPool>(kDefaultMessageSize +
kDefaultMessageSize + sizeof(uint16_t))); sizeof(uint16_t));
socket_.UseBufferPool(buffer_pool_);
} }
void SwapPoolUse(bool swap) { swap_pool_use_ = swap; } void SwapPoolUse(bool swap) { swap_pool_use_ = swap; }
void* PrepareSend(int message_size) {
return socket_.PrepareSend(message_size);
}
void Send() { socket_.Send(); }
bool SendBuffer(scoped_refptr<net::IOBuffer> data, int size) {
return socket_.SendBuffer(std::move(data), size);
}
void ReceiveMessages() { socket_.ReceiveMessages(); }
size_t last_message_size() const { size_t last_message_size() const {
DCHECK(!message_history_.empty()); DCHECK(!message_history_.empty());
return message_history_[message_history_.size() - 1]; return message_history_[message_history_.size() - 1];
...@@ -65,6 +76,8 @@ class TestSocket : public SmallMessageSocket { ...@@ -65,6 +76,8 @@ class TestSocket : public SmallMessageSocket {
return message_history_; return message_history_;
} }
IOBufferPool* buffer_pool() const { return buffer_pool_.get(); }
private: private:
void OnError(int error) override { NOTREACHED(); } void OnError(int error) override { NOTREACHED(); }
...@@ -84,12 +97,15 @@ class TestSocket : public SmallMessageSocket { ...@@ -84,12 +97,15 @@ class TestSocket : public SmallMessageSocket {
message_history_.push_back(message_size); message_history_.push_back(message_size);
CheckData(buffer->data() + sizeof(uint16_t), message_size); CheckData(buffer->data() + sizeof(uint16_t), message_size);
if (swap_pool_use_) { if (swap_pool_use_) {
RemoveBufferPool(); socket_.RemoveBufferPool();
buffer_pool_ = nullptr;
} }
return true; return true;
} }
SmallMessageSocket socket_;
std::vector<size_t> message_history_; std::vector<size_t> message_history_;
scoped_refptr<IOBufferPool> buffer_pool_;
bool swap_pool_use_ = false; bool swap_pool_use_ = false;
}; };
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment