Commit ddf6c858 authored by Kehuangli's avatar Kehuangli Committed by Commit Bot

[CaptureService] Let receiver wait for stream info

So recevier doesn't need to get stream info from side channel. And the
delegate of the receiver can directly get the stream info from receiver.

Merge-With: eureka-internal/408117
Bug: internal: 154953397
Test: Unittest.
Change-Id: I5915b297cb3f252b7493b36719670f842a5d8c95
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2352216Reviewed-by: default avatarKenneth MacKay <kmackay@chromium.org>
Commit-Queue: Kehuang Li <kehuangli@chromium.org>
Cr-Commit-Position: refs/heads/master@{#805513}
parent 7f28356d
...@@ -4,6 +4,8 @@ ...@@ -4,6 +4,8 @@
#include "chromecast/media/audio/capture_service/capture_service_receiver.h" #include "chromecast/media/audio/capture_service/capture_service_receiver.h"
#include <cstdint>
#include <memory>
#include <string> #include <string>
#include <utility> #include <utility>
...@@ -42,6 +44,13 @@ class CaptureServiceReceiver::Socket : public SmallMessageSocket::Delegate { ...@@ -42,6 +44,13 @@ class CaptureServiceReceiver::Socket : public SmallMessageSocket::Delegate {
~Socket() override; ~Socket() override;
private: private:
enum class State {
kInit,
kWaitForAck,
kStreaming,
kShutdown,
};
// SmallMessageSocket::Delegate implementation: // SmallMessageSocket::Delegate implementation:
void OnSendUnblocked() override; void OnSendUnblocked() override;
void OnError(int error) override; void OnError(int error) override;
...@@ -50,13 +59,17 @@ class CaptureServiceReceiver::Socket : public SmallMessageSocket::Delegate { ...@@ -50,13 +59,17 @@ class CaptureServiceReceiver::Socket : public SmallMessageSocket::Delegate {
bool SendRequest(); bool SendRequest();
void OnInactivityTimeout(); void OnInactivityTimeout();
bool HandleAudio(int64_t timestamp); void OnInitialStreamInfo(const capture_service::StreamInfo& stream_info);
bool HandleAck(char* data, size_t size);
bool HandleAudio(char* data, size_t size);
void ReportErrorAndStop(); void ReportErrorAndStop();
SmallMessageSocket socket_; SmallMessageSocket socket_;
const capture_service::StreamInfo request_stream_info_; const capture_service::StreamInfo request_stream_info_;
CaptureServiceReceiver::Delegate* const delegate_; CaptureServiceReceiver::Delegate* const delegate_;
bool error_reported_ = false;
State state_ = State::kInit;
}; };
CaptureServiceReceiver::Socket::Socket( CaptureServiceReceiver::Socket::Socket(
...@@ -77,8 +90,9 @@ CaptureServiceReceiver::Socket::Socket( ...@@ -77,8 +90,9 @@ CaptureServiceReceiver::Socket::Socket(
CaptureServiceReceiver::Socket::~Socket() = default; CaptureServiceReceiver::Socket::~Socket() = default;
bool CaptureServiceReceiver::Socket::SendRequest() { bool CaptureServiceReceiver::Socket::SendRequest() {
DCHECK_EQ(state_, State::kInit);
auto request_buffer = capture_service::MakeMessage( auto request_buffer = capture_service::MakeMessage(
capture_service::PacketInfo{capture_service::MessageType::kRequest, capture_service::PacketInfo{capture_service::MessageType::kHandshake,
request_stream_info_, 0 /* timestamp_us */}, request_stream_info_, 0 /* timestamp_us */},
nullptr /* data */, 0 /* data_size */); nullptr /* data */, 0 /* data_size */);
if (!request_buffer) { if (!request_buffer) {
...@@ -89,6 +103,7 @@ bool CaptureServiceReceiver::Socket::SendRequest() { ...@@ -89,6 +103,7 @@ bool CaptureServiceReceiver::Socket::SendRequest() {
"first buffer sent."; "first buffer sent.";
return false; return false;
} }
state_ = State::kWaitForAck;
return true; return true;
} }
...@@ -99,9 +114,9 @@ void CaptureServiceReceiver::Socket::OnSendUnblocked() { ...@@ -99,9 +114,9 @@ void CaptureServiceReceiver::Socket::OnSendUnblocked() {
} }
void CaptureServiceReceiver::Socket::ReportErrorAndStop() { void CaptureServiceReceiver::Socket::ReportErrorAndStop() {
DCHECK(!error_reported_) << "Error should not be reported more than once."; DCHECK_NE(state_, State::kShutdown);
delegate_->OnCaptureError(); delegate_->OnCaptureError();
error_reported_ = true; state_ = State::kShutdown;
} }
void CaptureServiceReceiver::Socket::OnError(int error) { void CaptureServiceReceiver::Socket::OnError(int error) {
...@@ -115,6 +130,45 @@ void CaptureServiceReceiver::Socket::OnEndOfStream() { ...@@ -115,6 +130,45 @@ void CaptureServiceReceiver::Socket::OnEndOfStream() {
} }
bool CaptureServiceReceiver::Socket::OnMessage(char* data, size_t size) { bool CaptureServiceReceiver::Socket::OnMessage(char* data, size_t size) {
uint8_t type = 0;
if (size < sizeof(type)) {
LOG(ERROR) << "Invalid message size: " << size << ".";
return false;
}
memcpy(&type, data, sizeof(type));
capture_service::MessageType message_type =
static_cast<capture_service::MessageType>(type);
if (state_ == State::kWaitForAck &&
message_type == capture_service::MessageType::kHandshake) {
return HandleAck(data, size);
}
if (state_ == State::kStreaming &&
(message_type == capture_service::MessageType::kPcmAudio ||
message_type == capture_service::MessageType::kOpusAudio)) {
return HandleAudio(data, size);
}
LOG(WARNING) << "Receive message with type " << type << " at state "
<< static_cast<int>(state_) << ", ignored.";
return true;
}
bool CaptureServiceReceiver::Socket::HandleAck(char* data, size_t size) {
DCHECK_EQ(state_, State::kWaitForAck);
capture_service::PacketInfo info;
if (!capture_service::ReadHeader(data, size, &info) ||
!delegate_->OnInitialStreamInfo(info.stream_info)) {
ReportErrorAndStop();
return false;
}
state_ = State::kStreaming;
return true;
}
bool CaptureServiceReceiver::Socket::HandleAudio(char* data, size_t size) {
DCHECK_EQ(state_, State::kStreaming);
if (!delegate_->OnCaptureData(data, size)) { if (!delegate_->OnCaptureData(data, size)) {
ReportErrorAndStop(); ReportErrorAndStop();
return false; return false;
......
...@@ -32,6 +32,9 @@ class CaptureServiceReceiver { ...@@ -32,6 +32,9 @@ class CaptureServiceReceiver {
public: public:
virtual ~Delegate() = default; virtual ~Delegate() = default;
virtual bool OnInitialStreamInfo(
const capture_service::StreamInfo& stream_info) = 0;
// Called when more data are received from socket. Return |true| to continue // Called when more data are received from socket. Return |true| to continue
// reading messages after OnCaptureData() returns. // reading messages after OnCaptureData() returns.
virtual bool OnCaptureData(const char* data, size_t size) = 0; virtual bool OnCaptureData(const char* data, size_t size) = 0;
...@@ -75,6 +78,7 @@ class CaptureServiceReceiver { ...@@ -75,6 +78,7 @@ class CaptureServiceReceiver {
const capture_service::StreamInfo request_stream_info_; const capture_service::StreamInfo request_stream_info_;
Delegate* const delegate_; Delegate* const delegate_;
// Socket requires IO thread, and low latency input stream requires high // Socket requires IO thread, and low latency input stream requires high
// thread priority. Therefore, a private thread instead of the IO thread from // thread priority. Therefore, a private thread instead of the IO thread from
// browser thread pool is necessary so as to make sure input stream won't be // browser thread pool is necessary so as to make sure input stream won't be
......
...@@ -34,9 +34,9 @@ constexpr StreamInfo kStreamInfo = ...@@ -34,9 +34,9 @@ constexpr StreamInfo kStreamInfo =
SampleFormat::PLANAR_FLOAT, SampleFormat::PLANAR_FLOAT,
16000, 16000,
160}; 160};
constexpr PacketHeader kRequestPacketHeader = constexpr PacketHeader kHandshakePacketHeader =
PacketHeader{0, PacketHeader{0,
static_cast<uint8_t>(MessageType::kRequest), static_cast<uint8_t>(MessageType::kHandshake),
static_cast<uint8_t>(kStreamInfo.stream_type), static_cast<uint8_t>(kStreamInfo.stream_type),
static_cast<uint8_t>(kStreamInfo.audio_codec), static_cast<uint8_t>(kStreamInfo.audio_codec),
kStreamInfo.num_channels, kStreamInfo.num_channels,
...@@ -71,6 +71,7 @@ class MockCaptureServiceReceiverDelegate ...@@ -71,6 +71,7 @@ class MockCaptureServiceReceiverDelegate
MockCaptureServiceReceiverDelegate() = default; MockCaptureServiceReceiverDelegate() = default;
~MockCaptureServiceReceiverDelegate() override = default; ~MockCaptureServiceReceiverDelegate() override = default;
MOCK_METHOD(bool, OnInitialStreamInfo, (const StreamInfo&), (override));
MOCK_METHOD(bool, OnCaptureData, (const char*, size_t), (override)); MOCK_METHOD(bool, OnCaptureData, (const char*, size_t), (override));
MOCK_METHOD(void, OnCaptureError, (), (override)); MOCK_METHOD(void, OnCaptureError, (), (override));
}; };
...@@ -93,10 +94,10 @@ class CaptureServiceReceiverTest : public ::testing::Test { ...@@ -93,10 +94,10 @@ class CaptureServiceReceiverTest : public ::testing::Test {
TEST_F(CaptureServiceReceiverTest, StartStop) { TEST_F(CaptureServiceReceiverTest, StartStop) {
auto socket1 = std::make_unique<MockStreamSocket>(); auto socket1 = std::make_unique<MockStreamSocket>();
auto socket2 = std::make_unique<MockStreamSocket>(); auto socket2 = std::make_unique<MockStreamSocket>();
EXPECT_CALL(*socket1, Connect(_)).WillOnce(Return(net::OK)); EXPECT_CALL(*socket1, Connect).WillOnce(Return(net::OK));
EXPECT_CALL(*socket1, Write(_, _, _, _)).WillOnce(Return(16)); EXPECT_CALL(*socket1, Write).WillOnce(Return(16));
EXPECT_CALL(*socket1, Read(_, _, _)).WillOnce(Return(net::ERR_IO_PENDING)); EXPECT_CALL(*socket1, Read).WillOnce(Return(net::ERR_IO_PENDING));
EXPECT_CALL(*socket2, Connect(_)).WillOnce(Return(net::OK)); EXPECT_CALL(*socket2, Connect).WillOnce(Return(net::OK));
// Sync. // Sync.
receiver_.StartWithSocket(std::move(socket1)); receiver_.StartWithSocket(std::move(socket1));
...@@ -111,8 +112,8 @@ TEST_F(CaptureServiceReceiverTest, StartStop) { ...@@ -111,8 +112,8 @@ TEST_F(CaptureServiceReceiverTest, StartStop) {
TEST_F(CaptureServiceReceiverTest, ConnectFailed) { TEST_F(CaptureServiceReceiverTest, ConnectFailed) {
auto socket = std::make_unique<MockStreamSocket>(); auto socket = std::make_unique<MockStreamSocket>();
EXPECT_CALL(*socket, Connect(_)).WillOnce(Return(net::ERR_FAILED)); EXPECT_CALL(*socket, Connect).WillOnce(Return(net::ERR_FAILED));
EXPECT_CALL(delegate_, OnCaptureError()); EXPECT_CALL(delegate_, OnCaptureError);
receiver_.StartWithSocket(std::move(socket)); receiver_.StartWithSocket(std::move(socket));
task_environment_.RunUntilIdle(); task_environment_.RunUntilIdle();
...@@ -120,8 +121,8 @@ TEST_F(CaptureServiceReceiverTest, ConnectFailed) { ...@@ -120,8 +121,8 @@ TEST_F(CaptureServiceReceiverTest, ConnectFailed) {
TEST_F(CaptureServiceReceiverTest, ConnectTimeout) { TEST_F(CaptureServiceReceiverTest, ConnectTimeout) {
auto socket = std::make_unique<MockStreamSocket>(); auto socket = std::make_unique<MockStreamSocket>();
EXPECT_CALL(*socket, Connect(_)).WillOnce(Return(net::ERR_IO_PENDING)); EXPECT_CALL(*socket, Connect).WillOnce(Return(net::ERR_IO_PENDING));
EXPECT_CALL(delegate_, OnCaptureError()); EXPECT_CALL(delegate_, OnCaptureError);
receiver_.StartWithSocket(std::move(socket)); receiver_.StartWithSocket(std::move(socket));
task_environment_.FastForwardBy(CaptureServiceReceiver::kConnectTimeout); task_environment_.FastForwardBy(CaptureServiceReceiver::kConnectTimeout);
...@@ -129,8 +130,8 @@ TEST_F(CaptureServiceReceiverTest, ConnectTimeout) { ...@@ -129,8 +130,8 @@ TEST_F(CaptureServiceReceiverTest, ConnectTimeout) {
TEST_F(CaptureServiceReceiverTest, SendRequest) { TEST_F(CaptureServiceReceiverTest, SendRequest) {
auto socket = std::make_unique<MockStreamSocket>(); auto socket = std::make_unique<MockStreamSocket>();
EXPECT_CALL(*socket, Connect(_)).WillOnce(Return(net::OK)); EXPECT_CALL(*socket, Connect).WillOnce(Return(net::OK));
EXPECT_CALL(*socket, Write(_, _, _, _)) EXPECT_CALL(*socket, Write)
.WillOnce(Invoke([](net::IOBuffer* buf, int buf_len, .WillOnce(Invoke([](net::IOBuffer* buf, int buf_len,
net::CompletionOnceCallback, net::CompletionOnceCallback,
const net::NetworkTrafficAnnotationTag&) { const net::NetworkTrafficAnnotationTag&) {
...@@ -140,18 +141,14 @@ TEST_F(CaptureServiceReceiverTest, SendRequest) { ...@@ -140,18 +141,14 @@ TEST_F(CaptureServiceReceiverTest, SendRequest) {
base::ReadBigEndian(data, &size); base::ReadBigEndian(data, &size);
EXPECT_EQ(size, sizeof(PacketHeader) - sizeof(size)); EXPECT_EQ(size, sizeof(PacketHeader) - sizeof(size));
PacketHeader header; PacketHeader header;
memcpy(&header, data, sizeof(PacketHeader)); std::memcpy(&header, data, sizeof(PacketHeader));
EXPECT_EQ(header.message_type, kRequestPacketHeader.message_type); EXPECT_EQ(header.message_type, kHandshakePacketHeader.message_type);
EXPECT_EQ(header.stream_type, kRequestPacketHeader.stream_type); EXPECT_EQ(header.stream_type, kHandshakePacketHeader.stream_type);
EXPECT_EQ(header.codec_or_sample_format, EXPECT_EQ(header.codec_or_sample_format,
kRequestPacketHeader.codec_or_sample_format); kHandshakePacketHeader.codec_or_sample_format);
EXPECT_EQ(header.num_channels, kRequestPacketHeader.num_channels);
EXPECT_EQ(header.sample_rate, kRequestPacketHeader.sample_rate);
EXPECT_EQ(header.timestamp_or_frames,
kRequestPacketHeader.timestamp_or_frames);
return buf_len; return buf_len;
})); }));
EXPECT_CALL(*socket, Read(_, _, _)).WillOnce(Return(net::ERR_IO_PENDING)); EXPECT_CALL(*socket, Read).WillOnce(Return(net::ERR_IO_PENDING));
receiver_.StartWithSocket(std::move(socket)); receiver_.StartWithSocket(std::move(socket));
task_environment_.RunUntilIdle(); task_environment_.RunUntilIdle();
...@@ -163,9 +160,20 @@ TEST_F(CaptureServiceReceiverTest, SendRequest) { ...@@ -163,9 +160,20 @@ TEST_F(CaptureServiceReceiverTest, SendRequest) {
TEST_F(CaptureServiceReceiverTest, ReceivePcmAudioMessage) { TEST_F(CaptureServiceReceiverTest, ReceivePcmAudioMessage) {
auto socket = std::make_unique<MockStreamSocket>(); auto socket = std::make_unique<MockStreamSocket>();
EXPECT_CALL(*socket, Connect(_)).WillOnce(Return(net::OK)); EXPECT_CALL(*socket, Connect).WillOnce(Return(net::OK));
EXPECT_CALL(*socket, Write(_, _, _, _)).WillOnce(Return(16)); EXPECT_CALL(*socket, Write).WillOnce(Return(16));
EXPECT_CALL(*socket, Read(_, _, _)) EXPECT_CALL(*socket, Read)
// Ack message.
.WillOnce(Invoke(
[](net::IOBuffer* buf, int buf_len, net::CompletionOnceCallback) {
int total_size = sizeof(PacketHeader);
EXPECT_GE(buf_len, total_size);
uint16_t size = total_size - sizeof(uint16_t);
PacketHeader header = kHandshakePacketHeader;
FillHeader(buf->data(), size, header);
return total_size;
}))
// Audio message.
.WillOnce(Invoke([](net::IOBuffer* buf, int buf_len, .WillOnce(Invoke([](net::IOBuffer* buf, int buf_len,
net::CompletionOnceCallback) { net::CompletionOnceCallback) {
int total_size = sizeof(PacketHeader) + DataSizeInBytes(kStreamInfo); int total_size = sizeof(PacketHeader) + DataSizeInBytes(kStreamInfo);
...@@ -176,7 +184,8 @@ TEST_F(CaptureServiceReceiverTest, ReceivePcmAudioMessage) { ...@@ -176,7 +184,8 @@ TEST_F(CaptureServiceReceiverTest, ReceivePcmAudioMessage) {
return total_size; // No need to fill audio frames. return total_size; // No need to fill audio frames.
})) }))
.WillOnce(Return(net::ERR_IO_PENDING)); .WillOnce(Return(net::ERR_IO_PENDING));
EXPECT_CALL(delegate_, OnCaptureData(_, _)).WillOnce(Return(true)); EXPECT_CALL(delegate_, OnInitialStreamInfo).WillOnce(Return(true));
EXPECT_CALL(delegate_, OnCaptureData).WillOnce(Return(true));
receiver_.StartWithSocket(std::move(socket)); receiver_.StartWithSocket(std::move(socket));
task_environment_.RunUntilIdle(); task_environment_.RunUntilIdle();
...@@ -186,13 +195,37 @@ TEST_F(CaptureServiceReceiverTest, ReceivePcmAudioMessage) { ...@@ -186,13 +195,37 @@ TEST_F(CaptureServiceReceiverTest, ReceivePcmAudioMessage) {
task_environment_.RunUntilIdle(); task_environment_.RunUntilIdle();
} }
TEST_F(CaptureServiceReceiverTest, ReceiveMetadataMessage) {
auto socket = std::make_unique<MockStreamSocket>();
EXPECT_CALL(*socket, Connect).WillOnce(Return(net::OK));
EXPECT_CALL(*socket, Write).WillOnce(Return(16));
EXPECT_CALL(*socket, Read)
.WillOnce(Invoke(
[](net::IOBuffer* buf, int buf_len, net::CompletionOnceCallback) {
uint16_t size = sizeof(uint8_t) + 1; // MessageType and 1 byte.
int total_size = size + sizeof(size);
EXPECT_GE(buf_len, total_size);
base::WriteBigEndian(buf->data(), size);
uint8_t message_type = static_cast<uint8_t>(MessageType::kMetadata);
std::memcpy(buf->data() + sizeof(size), &message_type,
sizeof(message_type));
return total_size; // No need to fill metadata.
}))
.WillOnce(Return(net::ERR_IO_PENDING));
// Neither OnCaptureError nor OnCaptureData will be called.
EXPECT_CALL(delegate_, OnCaptureError).Times(0);
EXPECT_CALL(delegate_, OnCaptureData).Times(0);
receiver_.StartWithSocket(std::move(socket));
task_environment_.RunUntilIdle();
}
TEST_F(CaptureServiceReceiverTest, ReceiveError) { TEST_F(CaptureServiceReceiverTest, ReceiveError) {
auto socket = std::make_unique<MockStreamSocket>(); auto socket = std::make_unique<MockStreamSocket>();
EXPECT_CALL(*socket, Connect(_)).WillOnce(Return(net::OK)); EXPECT_CALL(*socket, Connect).WillOnce(Return(net::OK));
EXPECT_CALL(*socket, Write(_, _, _, _)).WillOnce(Return(16)); EXPECT_CALL(*socket, Write).WillOnce(Return(16));
EXPECT_CALL(*socket, Read(_, _, _)) EXPECT_CALL(*socket, Read).WillOnce(Return(net::ERR_CONNECTION_RESET));
.WillOnce(Return(net::ERR_CONNECTION_RESET)); EXPECT_CALL(delegate_, OnCaptureError);
EXPECT_CALL(delegate_, OnCaptureError());
receiver_.StartWithSocket(std::move(socket)); receiver_.StartWithSocket(std::move(socket));
task_environment_.RunUntilIdle(); task_environment_.RunUntilIdle();
...@@ -200,10 +233,10 @@ TEST_F(CaptureServiceReceiverTest, ReceiveError) { ...@@ -200,10 +233,10 @@ TEST_F(CaptureServiceReceiverTest, ReceiveError) {
TEST_F(CaptureServiceReceiverTest, ReceiveEosMessage) { TEST_F(CaptureServiceReceiverTest, ReceiveEosMessage) {
auto socket = std::make_unique<MockStreamSocket>(); auto socket = std::make_unique<MockStreamSocket>();
EXPECT_CALL(*socket, Connect(_)).WillOnce(Return(net::OK)); EXPECT_CALL(*socket, Connect).WillOnce(Return(net::OK));
EXPECT_CALL(*socket, Write(_, _, _, _)).WillOnce(Return(16)); EXPECT_CALL(*socket, Write).WillOnce(Return(16));
EXPECT_CALL(*socket, Read(_, _, _)).WillOnce(Return(0)); EXPECT_CALL(*socket, Read).WillOnce(Return(0));
EXPECT_CALL(delegate_, OnCaptureError()); EXPECT_CALL(delegate_, OnCaptureError);
receiver_.StartWithSocket(std::move(socket)); receiver_.StartWithSocket(std::move(socket));
task_environment_.RunUntilIdle(); task_environment_.RunUntilIdle();
......
...@@ -54,10 +54,10 @@ enum class AudioCodec : uint8_t { ...@@ -54,10 +54,10 @@ enum class AudioCodec : uint8_t {
}; };
enum class MessageType : uint8_t { enum class MessageType : uint8_t {
// Request message that has stream header but empty body. It is used by // Handshake message that has stream header but empty body. It is used by
// receiver notifying the stream it is observing, and sender can confirm the // receiver notifying the stream it is observing, and sender can confirm the
// parameters are all correct. // types/codec are supported and send back more detailed parameters.
kRequest = 0, kHandshake = 0,
// PCM audio message that has stream header and audio data in the message // PCM audio message that has stream header and audio data in the message
// body. The audio data will match the parameters in the header. // body. The audio data will match the parameters in the header.
kPcmAudio, kPcmAudio,
...@@ -79,9 +79,9 @@ struct StreamInfo { ...@@ -79,9 +79,9 @@ struct StreamInfo {
}; };
// Info describes the message packet. PacketInfo is only for message types that // Info describes the message packet. PacketInfo is only for message types that
// support packet header, i.e., kRequest and kPcmAudio. |timestamp_us| is about // support packet header, i.e., kHandshake and kPcmAudio. |timestamp_us| is
// when the buffer is captured. If the audio source is from ALSA, i.e., stream // about when the buffer is captured. If the audio source is from ALSA, i.e.,
// type is raw mic, it's the ALSA capture timestamp; otherwise, it may be // stream type is raw mic, it's the ALSA capture timestamp; otherwise, it may be
// shifted based on the samples and sample rate upon raw mic input. // shifted based on the samples and sample rate upon raw mic input.
struct PacketInfo { struct PacketInfo {
MessageType message_type; MessageType message_type;
...@@ -89,6 +89,10 @@ struct PacketInfo { ...@@ -89,6 +89,10 @@ struct PacketInfo {
int64_t timestamp_us = 0; int64_t timestamp_us = 0;
}; };
// Size of a message header. The header can be parsed into PacketInfo with
// methods in message_parsing_utils.h
constexpr size_t kMessageHeaderBytes = 14;
} // namespace capture_service } // namespace capture_service
} // namespace media } // namespace media
} // namespace chromecast } // namespace chromecast
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include "base/big_endian.h" #include "base/big_endian.h"
#include "base/logging.h" #include "base/logging.h"
#include "base/notreached.h" #include "base/notreached.h"
#include "chromecast/media/audio/capture_service/constants.h"
#include "chromecast/media/audio/capture_service/packet_header.h" #include "chromecast/media/audio/capture_service/packet_header.h"
#include "media/base/limits.h" #include "media/base/limits.h"
...@@ -20,9 +21,8 @@ namespace media { ...@@ -20,9 +21,8 @@ namespace media {
namespace capture_service { namespace capture_service {
namespace { namespace {
// Size in bytes of the total/message header. // Size in bytes of the total message header.
constexpr size_t kTotalHeaderBytes = 16; constexpr size_t kTotalHeaderBytes = kMessageHeaderBytes + sizeof(uint16_t);
constexpr size_t kMessageHeaderBytes = kTotalHeaderBytes - sizeof(uint16_t);
static_assert(sizeof(PacketHeader) == kTotalHeaderBytes, static_assert(sizeof(PacketHeader) == kTotalHeaderBytes,
"Invalid packet header size."); "Invalid packet header size.");
...@@ -146,7 +146,7 @@ bool HasPacketHeader(MessageType type) { ...@@ -146,7 +146,7 @@ bool HasPacketHeader(MessageType type) {
// other message type such as kOpusAudio and kMetadata, the packet does not // other message type such as kOpusAudio and kMetadata, the packet does not
// contain the packet header and only contains the message type and serialized // contain the packet header and only contains the message type and serialized
// data. // data.
return type == MessageType::kRequest || type == MessageType::kPcmAudio; return type == MessageType::kHandshake || type == MessageType::kPcmAudio;
} }
} // namespace } // namespace
...@@ -159,10 +159,10 @@ char* PopulateHeader(char* data, size_t size, const PacketInfo& packet_info) { ...@@ -159,10 +159,10 @@ char* PopulateHeader(char* data, size_t size, const PacketInfo& packet_info) {
header.stream_type = static_cast<uint8_t>(stream_info.stream_type); header.stream_type = static_cast<uint8_t>(stream_info.stream_type);
header.num_channels = stream_info.num_channels; header.num_channels = stream_info.num_channels;
header.sample_rate = stream_info.sample_rate; header.sample_rate = stream_info.sample_rate;
// In request message, the header contains a codec field and a // In request/ack message, the header contains a codec field and a
// frames_per_buffer field, while in PCM audio message, it instead contains a // frames_per_buffer field, while in PCM audio message, it instead contains a
// sample format field and a timestamp field. // sample format field and a timestamp field.
if (packet_info.message_type == MessageType::kRequest) { if (packet_info.message_type == MessageType::kHandshake) {
header.codec_or_sample_format = header.codec_or_sample_format =
static_cast<uint8_t>(stream_info.audio_codec); static_cast<uint8_t>(stream_info.audio_codec);
header.timestamp_or_frames = stream_info.frames_per_buffer; header.timestamp_or_frames = stream_info.frames_per_buffer;
...@@ -176,10 +176,7 @@ char* PopulateHeader(char* data, size_t size, const PacketInfo& packet_info) { ...@@ -176,10 +176,7 @@ char* PopulateHeader(char* data, size_t size, const PacketInfo& packet_info) {
base::WriteBigEndian( // Deduct the size of |size| itself. base::WriteBigEndian( // Deduct the size of |size| itself.
data, static_cast<uint16_t>(size - sizeof(uint16_t))); data, static_cast<uint16_t>(size - sizeof(uint16_t)));
DCHECK_EQ(sizeof(header), kTotalHeaderBytes); DCHECK_EQ(sizeof(header), kTotalHeaderBytes);
memcpy(data + sizeof(uint16_t), memcpy(data + sizeof(uint16_t), &header.message_type, kMessageHeaderBytes);
reinterpret_cast<const char*>(&header) +
offsetof(struct PacketHeader, message_type),
kMessageHeaderBytes);
return data + kTotalHeaderBytes; return data + kTotalHeaderBytes;
} }
...@@ -190,12 +187,10 @@ bool ReadHeader(const char* data, size_t size, PacketInfo* packet_info) { ...@@ -190,12 +187,10 @@ bool ReadHeader(const char* data, size_t size, PacketInfo* packet_info) {
return false; return false;
} }
PacketHeader header; PacketHeader header;
memcpy(reinterpret_cast<char*>(&header) + memcpy(&header.message_type, data, kMessageHeaderBytes);
offsetof(struct PacketHeader, message_type),
data, kMessageHeaderBytes);
MessageType message_type = static_cast<MessageType>(header.message_type); MessageType message_type = static_cast<MessageType>(header.message_type);
uint8_t last_codec_or_sample_format = uint8_t last_codec_or_sample_format =
(message_type == MessageType::kRequest) (message_type == MessageType::kHandshake)
? static_cast<uint8_t>(AudioCodec::kLastCodec) ? static_cast<uint8_t>(AudioCodec::kLastCodec)
: static_cast<uint8_t>(SampleFormat::LAST_FORMAT); : static_cast<uint8_t>(SampleFormat::LAST_FORMAT);
if (!HasPacketHeader(message_type) || if (!HasPacketHeader(message_type) ||
...@@ -213,7 +208,7 @@ bool ReadHeader(const char* data, size_t size, PacketInfo* packet_info) { ...@@ -213,7 +208,7 @@ bool ReadHeader(const char* data, size_t size, PacketInfo* packet_info) {
static_cast<StreamType>(header.stream_type); static_cast<StreamType>(header.stream_type);
packet_info->stream_info.num_channels = header.num_channels; packet_info->stream_info.num_channels = header.num_channels;
packet_info->stream_info.sample_rate = header.sample_rate; packet_info->stream_info.sample_rate = header.sample_rate;
if (message_type == MessageType::kRequest) { if (message_type == MessageType::kHandshake) {
packet_info->stream_info.audio_codec = packet_info->stream_info.audio_codec =
static_cast<AudioCodec>(header.codec_or_sample_format); static_cast<AudioCodec>(header.codec_or_sample_format);
packet_info->stream_info.frames_per_buffer = header.timestamp_or_frames; packet_info->stream_info.frames_per_buffer = header.timestamp_or_frames;
...@@ -231,8 +226,9 @@ scoped_refptr<net::IOBufferWithSize> MakeMessage(const PacketInfo& packet_info, ...@@ -231,8 +226,9 @@ scoped_refptr<net::IOBufferWithSize> MakeMessage(const PacketInfo& packet_info,
const char* data, const char* data,
size_t data_size) { size_t data_size) {
if (!HasPacketHeader(packet_info.message_type)) { if (!HasPacketHeader(packet_info.message_type)) {
LOG(ERROR) << "Only kRequest and kPcmAudio message have packet header, use " LOG(ERROR)
"MakeSerializedMessage otherwise."; << "Only kHandshake and kPcmAudio message have packet header, use "
"MakeSerializedMessage otherwise.";
return nullptr; return nullptr;
} }
const size_t total_size = kTotalHeaderBytes + data_size; const size_t total_size = kTotalHeaderBytes + data_size;
......
...@@ -20,16 +20,15 @@ namespace capture_service { ...@@ -20,16 +20,15 @@ namespace capture_service {
// The header of the message consists of <uint8_t message_type> // The header of the message consists of <uint8_t message_type>
// <uint8_t stream_type> <uint8_t audio_codec|sample_format> <uint8_t channels> // <uint8_t stream_type> <uint8_t audio_codec|sample_format> <uint8_t channels>
// <uint16_t sample_rate> <uint64_t frames_per_buffer|timestamp_us>. // <uint16_t sample_rate> <uint64_t frames_per_buffer|timestamp_us>.
// If |message_type| is kRequest, it is a request message that has |audio_codec| // If |message_type| is kHandshake, it is a handshake message that has
// and |frames_per_buffer|, otherwise if |message_type| is kPcmAudio, it's a PCM // |audio_codec| and |frames_per_buffer|, otherwise if |message_type| is
// audio data message that has |sample_format| and |timestamp_us|. // kPcmAudio, it's a PCM audio data message that has |sample_format| and
// Note it cannot be used to read kOpusAudio or kMetadata messages, which don't // |timestamp_us|. Note it cannot be used to read kOpusAudio or kMetadata
// have header besides |message_type| bits. // messages, which don't have header besides |message_type| bits. Note
// Note |packet_info| will be untouched if fails to read header. // |packet_info| will be untouched if fails to read header. Note unsigned
// Note unsigned |timestamp_us| will be converted to signed |timestamp| if // |timestamp_us| will be converted to signed |timestamp| if valid. Note |data|
// valid. // here has been parsed firstly by SmallMessageSocket, and thus doesn't have
// Note |data| here has been parsed firstly by SmallMessageSocket, and // <uint16_t size> bits.
// thus doesn't have <uint16_t size> bits.
bool ReadHeader(const char* data, size_t size, PacketInfo* packet_info); bool ReadHeader(const char* data, size_t size, PacketInfo* packet_info);
// Make a IO buffer for stream message. It will populate the header with // Make a IO buffer for stream message. It will populate the header with
......
...@@ -26,8 +26,8 @@ constexpr StreamInfo kStreamInfo = ...@@ -26,8 +26,8 @@ constexpr StreamInfo kStreamInfo =
SampleFormat::PLANAR_FLOAT, SampleFormat::PLANAR_FLOAT,
16000, 16000,
kFrames}; kFrames};
constexpr PacketInfo kRequestPacketInfo = {MessageType::kRequest, kStreamInfo, constexpr PacketInfo kHandshakePacketInfo = {MessageType::kHandshake,
0}; kStreamInfo, 0};
constexpr PacketInfo kPcmAudioPacketInfo = {MessageType::kPcmAudio, kStreamInfo, constexpr PacketInfo kPcmAudioPacketInfo = {MessageType::kPcmAudio, kStreamInfo,
0}; 0};
...@@ -120,7 +120,7 @@ TEST(MessageParsingUtilsTest, InvalidType) { ...@@ -120,7 +120,7 @@ TEST(MessageParsingUtilsTest, InvalidType) {
size_t data_size = kTotalHeaderBytes / sizeof(float); size_t data_size = kTotalHeaderBytes / sizeof(float);
std::vector<float> data(data_size, 1.0f); std::vector<float> data(data_size, 1.0f);
// Request packet // Request packet
PacketInfo request_packet_info = kRequestPacketInfo; PacketInfo request_packet_info = kHandshakePacketInfo;
PopulateHeader(reinterpret_cast<char*>(data.data()), PopulateHeader(reinterpret_cast<char*>(data.data()),
data.size() * sizeof(float), request_packet_info); data.size() * sizeof(float), request_packet_info);
*(reinterpret_cast<uint8_t*>(data.data()) + *(reinterpret_cast<uint8_t*>(data.data()) +
...@@ -147,7 +147,7 @@ TEST(MessageParsingUtilsTest, InvalidType) { ...@@ -147,7 +147,7 @@ TEST(MessageParsingUtilsTest, InvalidType) {
TEST(MessageParsingUtilsTest, InvalidCodec) { TEST(MessageParsingUtilsTest, InvalidCodec) {
size_t data_size = kTotalHeaderBytes / sizeof(float); size_t data_size = kTotalHeaderBytes / sizeof(float);
std::vector<float> data(data_size, 1.0f); std::vector<float> data(data_size, 1.0f);
PacketInfo packet_info = kRequestPacketInfo; PacketInfo packet_info = kHandshakePacketInfo;
PopulateHeader(reinterpret_cast<char*>(data.data()), PopulateHeader(reinterpret_cast<char*>(data.data()),
data.size() * sizeof(float), packet_info); data.size() * sizeof(float), packet_info);
*(reinterpret_cast<uint8_t*>(data.data()) + *(reinterpret_cast<uint8_t*>(data.data()) +
...@@ -179,7 +179,7 @@ TEST(MessageParsingUtilsTest, InvalidFormat) { ...@@ -179,7 +179,7 @@ TEST(MessageParsingUtilsTest, InvalidFormat) {
TEST(MessageParsingUtilsTest, RequestMessage) { TEST(MessageParsingUtilsTest, RequestMessage) {
size_t data_size = kTotalHeaderBytes / sizeof(float); size_t data_size = kTotalHeaderBytes / sizeof(float);
std::vector<float> data(data_size, 1.0f); std::vector<float> data(data_size, 1.0f);
PacketInfo packet_info = kRequestPacketInfo; PacketInfo packet_info = kHandshakePacketInfo;
PopulateHeader(reinterpret_cast<char*>(data.data()), PopulateHeader(reinterpret_cast<char*>(data.data()),
data.size() * sizeof(float), packet_info); data.size() * sizeof(float), packet_info);
...@@ -188,7 +188,7 @@ TEST(MessageParsingUtilsTest, RequestMessage) { ...@@ -188,7 +188,7 @@ TEST(MessageParsingUtilsTest, RequestMessage) {
ReadHeader(reinterpret_cast<char*>(data.data()) + sizeof(uint16_t), ReadHeader(reinterpret_cast<char*>(data.data()) + sizeof(uint16_t),
data_size * sizeof(float) - sizeof(uint16_t), &info); data_size * sizeof(float) - sizeof(uint16_t), &info);
EXPECT_TRUE(success); EXPECT_TRUE(success);
EXPECT_EQ(info.message_type, kRequestPacketInfo.message_type); EXPECT_EQ(info.message_type, kHandshakePacketInfo.message_type);
EXPECT_EQ(info.stream_info.stream_type, kStreamInfo.stream_type); EXPECT_EQ(info.stream_info.stream_type, kStreamInfo.stream_type);
EXPECT_EQ(info.stream_info.audio_codec, kStreamInfo.audio_codec); EXPECT_EQ(info.stream_info.audio_codec, kStreamInfo.audio_codec);
EXPECT_EQ(info.stream_info.num_channels, kStreamInfo.num_channels); EXPECT_EQ(info.stream_info.num_channels, kStreamInfo.num_channels);
......
...@@ -16,7 +16,7 @@ namespace capture_service { ...@@ -16,7 +16,7 @@ namespace capture_service {
// packet header structure, however, the |size| bits are in big-endian order, // packet header structure, however, the |size| bits are in big-endian order,
// and thus is only for padding purpose in this struct, when all bytes after it // and thus is only for padding purpose in this struct, when all bytes after it
// represent a message header. // represent a message header.
struct PacketHeader { struct __attribute__((__packed__)) PacketHeader {
uint16_t size; uint16_t size;
uint8_t message_type; uint8_t message_type;
uint8_t stream_type; uint8_t stream_type;
......
...@@ -114,6 +114,25 @@ void CastAudioInputStream::SetOutputDeviceForAec( ...@@ -114,6 +114,25 @@ void CastAudioInputStream::SetOutputDeviceForAec(
// Not supported. Do nothing. // Not supported. Do nothing.
} }
bool CastAudioInputStream::OnInitialStreamInfo(
const capture_service::StreamInfo& stream_info) {
const bool is_params_match =
stream_info.stream_type ==
capture_service::StreamType::kSoftwareEchoCancelled &&
stream_info.audio_codec == capture_service::AudioCodec::kPcm &&
stream_info.num_channels == audio_params_.channels() &&
stream_info.sample_rate == audio_params_.sample_rate() &&
stream_info.frames_per_buffer == audio_params_.frames_per_buffer();
LOG_IF(ERROR, !is_params_match)
<< "Got different parameters from sender, sample_rate: "
<< audio_params_.sample_rate() << " Hz -> " << stream_info.sample_rate
<< " Hz, num_channels: " << audio_params_.channels() << " -> "
<< stream_info.num_channels
<< ", frames_per_buffer: " << audio_params_.frames_per_buffer() << " -> "
<< stream_info.frames_per_buffer << ".";
return is_params_match;
}
bool CastAudioInputStream::OnCaptureData(const char* data, size_t size) { bool CastAudioInputStream::OnCaptureData(const char* data, size_t size) {
capture_service::PacketInfo info; capture_service::PacketInfo info;
if (!capture_service::ReadPcmAudioMessage(data, size, &info, if (!capture_service::ReadPcmAudioMessage(data, size, &info,
......
...@@ -48,6 +48,8 @@ class CastAudioInputStream : public ::media::AudioInputStream, ...@@ -48,6 +48,8 @@ class CastAudioInputStream : public ::media::AudioInputStream,
void SetOutputDeviceForAec(const std::string& output_device_id) override; void SetOutputDeviceForAec(const std::string& output_device_id) override;
// CaptureServiceReceiver::Delegate implementation: // CaptureServiceReceiver::Delegate implementation:
bool OnInitialStreamInfo(
const capture_service::StreamInfo& stream_info) override;
bool OnCaptureData(const char* data, size_t size) override; bool OnCaptureData(const char* data, size_t size) override;
void OnCaptureError() override; void OnCaptureError() override;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment