Commit 5e56b770 authored by Kehuang Li's avatar Kehuang Li Committed by Commit Bot

[Chromecast] Fix PacketHeader.sample_rate overflow

We used to assume 16bits is sufficient, but somehow it's not always the
case. In this cl, let handshake message and pcm audio message no longer
share the same packet header. I.e., let them have their own header
struct and read/make methods.

Besides, add more unittest to cover packet header poluate/parse.

Merge-With: eureka-internal/453374
Bug: internal: 168457620
Test: Unittest.
Change-Id: I6e10fe9d58974646029ba222229e2f6d4df8acd5
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2413110
Commit-Queue: Kehuang Li <kehuangli@chromium.org>
Reviewed-by: default avatarKenneth MacKay <kmackay@chromium.org>
Reviewed-by: default avatarYuchen Liu <yucliu@chromium.org>
Cr-Commit-Position: refs/heads/master@{#809715}
parent 2c47e581
...@@ -108,6 +108,7 @@ cast_source_set("audio") { ...@@ -108,6 +108,7 @@ cast_source_set("audio") {
"//chromecast/base", "//chromecast/base",
"//chromecast/common/mojom", "//chromecast/common/mojom",
"//chromecast/media/api", "//chromecast/media/api",
"//chromecast/media/audio/capture_service:common",
"//chromecast/media/audio/capture_service:receiver", "//chromecast/media/audio/capture_service:receiver",
"//chromecast/media/audio/capture_service:utils", "//chromecast/media/audio/capture_service:utils",
"//chromecast/media/audio/mixer_service:common", "//chromecast/media/audio/mixer_service:common",
......
...@@ -91,10 +91,8 @@ CaptureServiceReceiver::Socket::~Socket() = default; ...@@ -91,10 +91,8 @@ CaptureServiceReceiver::Socket::~Socket() = default;
bool CaptureServiceReceiver::Socket::SendRequest() { bool CaptureServiceReceiver::Socket::SendRequest() {
DCHECK_EQ(state_, State::kInit); DCHECK_EQ(state_, State::kInit);
auto request_buffer = capture_service::MakeMessage( auto request_buffer =
capture_service::PacketInfo{capture_service::MessageType::kHandshake, capture_service::MakeHandshakeMessage(request_stream_info_);
request_stream_info_, 0 /* timestamp_us */},
nullptr /* data */, 0 /* data_size */);
if (!request_buffer) { if (!request_buffer) {
return false; return false;
} }
...@@ -157,9 +155,9 @@ bool CaptureServiceReceiver::Socket::OnMessage(char* data, size_t size) { ...@@ -157,9 +155,9 @@ bool CaptureServiceReceiver::Socket::OnMessage(char* data, size_t size) {
bool CaptureServiceReceiver::Socket::HandleAck(char* data, size_t size) { bool CaptureServiceReceiver::Socket::HandleAck(char* data, size_t size) {
DCHECK_EQ(state_, State::kWaitForAck); DCHECK_EQ(state_, State::kWaitForAck);
capture_service::PacketInfo info; capture_service::StreamInfo info;
if (!capture_service::ReadHeader(data, size, &info) || if (!capture_service::ReadHandshakeMessage(data, size, &info) ||
!delegate_->OnInitialStreamInfo(info.stream_info)) { !delegate_->OnInitialStreamInfo(info)) {
ReportErrorAndStop(); ReportErrorAndStop();
return false; return false;
} }
......
...@@ -34,30 +34,18 @@ constexpr StreamInfo kStreamInfo = ...@@ -34,30 +34,18 @@ constexpr StreamInfo kStreamInfo =
SampleFormat::PLANAR_FLOAT, SampleFormat::PLANAR_FLOAT,
16000, 16000,
160}; 160};
constexpr PacketHeader kHandshakePacketHeader = constexpr HandshakePacket kHandshakePacket =
PacketHeader{0, HandshakePacket{0,
static_cast<uint8_t>(MessageType::kHandshake), static_cast<uint8_t>(MessageType::kHandshake),
static_cast<uint8_t>(kStreamInfo.stream_type), static_cast<uint8_t>(kStreamInfo.stream_type),
static_cast<uint8_t>(kStreamInfo.audio_codec), static_cast<uint8_t>(kStreamInfo.audio_codec),
kStreamInfo.num_channels,
kStreamInfo.sample_rate,
kStreamInfo.frames_per_buffer};
constexpr PacketHeader kPcmAudioPacketHeader =
PacketHeader{0,
static_cast<uint8_t>(MessageType::kPcmAudio),
static_cast<uint8_t>(kStreamInfo.stream_type),
static_cast<uint8_t>(kStreamInfo.sample_format), static_cast<uint8_t>(kStreamInfo.sample_format),
kStreamInfo.num_channels, kStreamInfo.num_channels,
kStreamInfo.sample_rate, kStreamInfo.frames_per_buffer,
0}; kStreamInfo.sample_rate};
constexpr PcmPacketHeader kPcmAudioPacketHeader =
void FillHeader(char* buf, uint16_t size, const PacketHeader& header) { PcmPacketHeader{0, static_cast<uint8_t>(MessageType::kPcmAudio),
base::WriteBigEndian(buf, size); static_cast<uint8_t>(kStreamInfo.stream_type), 0};
memcpy(buf + sizeof(size),
reinterpret_cast<const char*>(&header) +
offsetof(struct PacketHeader, message_type),
sizeof(header) - offsetof(struct PacketHeader, message_type));
}
class MockStreamSocket : public chromecast::MockStreamSocket { class MockStreamSocket : public chromecast::MockStreamSocket {
public: public:
...@@ -95,7 +83,7 @@ TEST_F(CaptureServiceReceiverTest, StartStop) { ...@@ -95,7 +83,7 @@ TEST_F(CaptureServiceReceiverTest, StartStop) {
auto socket1 = std::make_unique<MockStreamSocket>(); auto socket1 = std::make_unique<MockStreamSocket>();
auto socket2 = std::make_unique<MockStreamSocket>(); auto socket2 = std::make_unique<MockStreamSocket>();
EXPECT_CALL(*socket1, Connect).WillOnce(Return(net::OK)); EXPECT_CALL(*socket1, Connect).WillOnce(Return(net::OK));
EXPECT_CALL(*socket1, Write).WillOnce(Return(16)); EXPECT_CALL(*socket1, Write).WillOnce(Return(sizeof(HandshakePacket)));
EXPECT_CALL(*socket1, Read).WillOnce(Return(net::ERR_IO_PENDING)); EXPECT_CALL(*socket1, Read).WillOnce(Return(net::ERR_IO_PENDING));
EXPECT_CALL(*socket2, Connect).WillOnce(Return(net::OK)); EXPECT_CALL(*socket2, Connect).WillOnce(Return(net::OK));
...@@ -135,17 +123,19 @@ TEST_F(CaptureServiceReceiverTest, SendRequest) { ...@@ -135,17 +123,19 @@ TEST_F(CaptureServiceReceiverTest, SendRequest) {
.WillOnce(Invoke([](net::IOBuffer* buf, int buf_len, .WillOnce(Invoke([](net::IOBuffer* buf, int buf_len,
net::CompletionOnceCallback, net::CompletionOnceCallback,
const net::NetworkTrafficAnnotationTag&) { const net::NetworkTrafficAnnotationTag&) {
EXPECT_EQ(buf_len, static_cast<int>(sizeof(PacketHeader))); EXPECT_EQ(buf_len, static_cast<int>(sizeof(HandshakePacket)));
const char* data = buf->data(); const char* data = buf->data();
uint16_t size; uint16_t size;
base::ReadBigEndian(data, &size); base::ReadBigEndian(data, &size);
EXPECT_EQ(size, sizeof(PacketHeader) - sizeof(size)); EXPECT_EQ(size, sizeof(HandshakePacket) - sizeof(size));
PacketHeader header; HandshakePacket packet;
std::memcpy(&header, data, sizeof(PacketHeader)); std::memcpy(&packet, data, sizeof(HandshakePacket));
EXPECT_EQ(header.message_type, kHandshakePacketHeader.message_type); EXPECT_EQ(packet.message_type, kHandshakePacket.message_type);
EXPECT_EQ(header.stream_type, kHandshakePacketHeader.stream_type); EXPECT_EQ(packet.stream_type, kHandshakePacket.stream_type);
EXPECT_EQ(header.codec_or_sample_format, EXPECT_EQ(packet.audio_codec, kHandshakePacket.audio_codec);
kHandshakePacketHeader.codec_or_sample_format); EXPECT_EQ(packet.num_channels, kHandshakePacket.num_channels);
EXPECT_EQ(packet.num_frames, kHandshakePacket.num_frames);
EXPECT_EQ(packet.sample_rate, kHandshakePacket.sample_rate);
return buf_len; return buf_len;
})); }));
EXPECT_CALL(*socket, Read).WillOnce(Return(net::ERR_IO_PENDING)); EXPECT_CALL(*socket, Read).WillOnce(Return(net::ERR_IO_PENDING));
...@@ -161,26 +151,24 @@ TEST_F(CaptureServiceReceiverTest, SendRequest) { ...@@ -161,26 +151,24 @@ TEST_F(CaptureServiceReceiverTest, SendRequest) {
TEST_F(CaptureServiceReceiverTest, ReceivePcmAudioMessage) { TEST_F(CaptureServiceReceiverTest, ReceivePcmAudioMessage) {
auto socket = std::make_unique<MockStreamSocket>(); auto socket = std::make_unique<MockStreamSocket>();
EXPECT_CALL(*socket, Connect).WillOnce(Return(net::OK)); EXPECT_CALL(*socket, Connect).WillOnce(Return(net::OK));
EXPECT_CALL(*socket, Write).WillOnce(Return(16)); EXPECT_CALL(*socket, Write).WillOnce(Return(sizeof(HandshakePacket)));
EXPECT_CALL(*socket, Read) EXPECT_CALL(*socket, Read)
// Ack message. // Ack message.
.WillOnce(Invoke( .WillOnce(Invoke(
[](net::IOBuffer* buf, int buf_len, net::CompletionOnceCallback) { [](net::IOBuffer* buf, int buf_len, net::CompletionOnceCallback) {
int total_size = sizeof(PacketHeader); int total_size = sizeof(HandshakePacket);
EXPECT_GE(buf_len, total_size); EXPECT_GE(buf_len, total_size);
uint16_t size = total_size - sizeof(uint16_t); FillBuffer(buf->data(), total_size, &kHandshakePacket.message_type,
PacketHeader header = kHandshakePacketHeader; sizeof(HandshakePacket) - sizeof(uint16_t));
FillHeader(buf->data(), size, header);
return total_size; return total_size;
})) }))
// Audio message. // Audio message.
.WillOnce(Invoke([](net::IOBuffer* buf, int buf_len, .WillOnce(Invoke([](net::IOBuffer* buf, int buf_len,
net::CompletionOnceCallback) { net::CompletionOnceCallback) {
int total_size = sizeof(PacketHeader) + DataSizeInBytes(kStreamInfo); int total_size = sizeof(PcmPacketHeader) + DataSizeInBytes(kStreamInfo);
EXPECT_GE(buf_len, total_size); EXPECT_GE(buf_len, total_size);
uint16_t size = total_size - sizeof(uint16_t); FillBuffer(buf->data(), total_size, &kPcmAudioPacketHeader.message_type,
PacketHeader header = kPcmAudioPacketHeader; sizeof(PcmPacketHeader) - sizeof(uint16_t));
FillHeader(buf->data(), size, header);
return total_size; // No need to fill audio frames. return total_size; // No need to fill audio frames.
})) }))
.WillOnce(Return(net::ERR_IO_PENDING)); .WillOnce(Return(net::ERR_IO_PENDING));
...@@ -198,7 +186,7 @@ TEST_F(CaptureServiceReceiverTest, ReceivePcmAudioMessage) { ...@@ -198,7 +186,7 @@ TEST_F(CaptureServiceReceiverTest, ReceivePcmAudioMessage) {
TEST_F(CaptureServiceReceiverTest, ReceiveMetadataMessage) { TEST_F(CaptureServiceReceiverTest, ReceiveMetadataMessage) {
auto socket = std::make_unique<MockStreamSocket>(); auto socket = std::make_unique<MockStreamSocket>();
EXPECT_CALL(*socket, Connect).WillOnce(Return(net::OK)); EXPECT_CALL(*socket, Connect).WillOnce(Return(net::OK));
EXPECT_CALL(*socket, Write).WillOnce(Return(16)); EXPECT_CALL(*socket, Write).WillOnce(Return(sizeof(HandshakePacket)));
EXPECT_CALL(*socket, Read) EXPECT_CALL(*socket, Read)
.WillOnce(Invoke( .WillOnce(Invoke(
[](net::IOBuffer* buf, int buf_len, net::CompletionOnceCallback) { [](net::IOBuffer* buf, int buf_len, net::CompletionOnceCallback) {
...@@ -223,7 +211,7 @@ TEST_F(CaptureServiceReceiverTest, ReceiveMetadataMessage) { ...@@ -223,7 +211,7 @@ TEST_F(CaptureServiceReceiverTest, ReceiveMetadataMessage) {
TEST_F(CaptureServiceReceiverTest, ReceiveError) { TEST_F(CaptureServiceReceiverTest, ReceiveError) {
auto socket = std::make_unique<MockStreamSocket>(); auto socket = std::make_unique<MockStreamSocket>();
EXPECT_CALL(*socket, Connect).WillOnce(Return(net::OK)); EXPECT_CALL(*socket, Connect).WillOnce(Return(net::OK));
EXPECT_CALL(*socket, Write).WillOnce(Return(16)); EXPECT_CALL(*socket, Write).WillOnce(Return(sizeof(HandshakePacket)));
EXPECT_CALL(*socket, Read).WillOnce(Return(net::ERR_CONNECTION_RESET)); EXPECT_CALL(*socket, Read).WillOnce(Return(net::ERR_CONNECTION_RESET));
EXPECT_CALL(delegate_, OnCaptureError); EXPECT_CALL(delegate_, OnCaptureError);
...@@ -234,7 +222,7 @@ TEST_F(CaptureServiceReceiverTest, ReceiveError) { ...@@ -234,7 +222,7 @@ TEST_F(CaptureServiceReceiverTest, ReceiveError) {
TEST_F(CaptureServiceReceiverTest, ReceiveEosMessage) { TEST_F(CaptureServiceReceiverTest, ReceiveEosMessage) {
auto socket = std::make_unique<MockStreamSocket>(); auto socket = std::make_unique<MockStreamSocket>();
EXPECT_CALL(*socket, Connect).WillOnce(Return(net::OK)); EXPECT_CALL(*socket, Connect).WillOnce(Return(net::OK));
EXPECT_CALL(*socket, Write).WillOnce(Return(16)); EXPECT_CALL(*socket, Write).WillOnce(Return(sizeof(HandshakePacket)));
EXPECT_CALL(*socket, Read).WillOnce(Return(0)); EXPECT_CALL(*socket, Read).WillOnce(Return(0));
EXPECT_CALL(delegate_, OnCaptureError); EXPECT_CALL(delegate_, OnCaptureError);
......
...@@ -78,20 +78,8 @@ struct StreamInfo { ...@@ -78,20 +78,8 @@ struct StreamInfo {
int frames_per_buffer = 0; int frames_per_buffer = 0;
}; };
// Info describes the message packet. PacketInfo is only for message types that // Size of a PCM audio message header.
// support packet header, i.e., kHandshake and kPcmAudio. |timestamp_us| is constexpr size_t kPcmAudioHeaderBytes = 10;
// about when the buffer is captured. If the audio source is from ALSA, i.e.,
// stream type is raw mic, it's the ALSA capture timestamp; otherwise, it may be
// shifted based on the samples and sample rate upon raw mic input.
struct PacketInfo {
MessageType message_type;
StreamInfo stream_info;
int64_t timestamp_us = 0;
};
// Size of a message header. The header can be parsed into PacketInfo with
// methods in message_parsing_utils.h
constexpr size_t kMessageHeaderBytes = 14;
} // namespace capture_service } // namespace capture_service
} // namespace media } // namespace media
......
...@@ -14,8 +14,8 @@ struct Environment { ...@@ -14,8 +14,8 @@ struct Environment {
extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) {
static Environment env; static Environment env;
PacketInfo info; StreamInfo info;
chromecast::media::capture_service::ReadHeader( chromecast::media::capture_service::ReadHandshakeMessage(
reinterpret_cast<const char*>(data), size, &info); reinterpret_cast<const char*>(data), size, &info);
return 0; return 0;
} }
...@@ -10,6 +10,8 @@ ...@@ -10,6 +10,8 @@
#include <limits> #include <limits>
#include "base/big_endian.h" #include "base/big_endian.h"
#include "base/check.h"
#include "base/check_op.h"
#include "base/logging.h" #include "base/logging.h"
#include "base/notreached.h" #include "base/notreached.h"
#include "chromecast/media/audio/capture_service/constants.h" #include "chromecast/media/audio/capture_service/constants.h"
...@@ -21,12 +23,18 @@ namespace media { ...@@ -21,12 +23,18 @@ namespace media {
namespace capture_service { namespace capture_service {
namespace { namespace {
// Size in bytes of the total message header. // Size in bytes of the header part of a handshake message.
constexpr size_t kTotalHeaderBytes = kMessageHeaderBytes + sizeof(uint16_t); constexpr size_t kHandshakeHeaderBytes =
sizeof(HandshakePacket) - sizeof(uint16_t);
static_assert(sizeof(PacketHeader) == kTotalHeaderBytes, static_assert(kPcmAudioHeaderBytes ==
"Invalid packet header size."); sizeof(PcmPacketHeader) - sizeof(uint16_t),
static_assert(offsetof(struct PacketHeader, message_type) == sizeof(uint16_t), "Invalid message header size.");
static_assert(offsetof(struct PcmPacketHeader, message_type) ==
sizeof(uint16_t),
"Invalid message header offset.");
static_assert(offsetof(struct HandshakePacket, message_type) ==
sizeof(uint16_t),
"Invalid message header offset."); "Invalid message header offset.");
// Check if audio data is properly aligned and has valid frame size. Return the // Check if audio data is properly aligned and has valid frame size. Return the
...@@ -141,110 +149,97 @@ bool ConvertData(int channels, ...@@ -141,110 +149,97 @@ bool ConvertData(int channels,
return false; return false;
} }
bool HasPacketHeader(MessageType type) {
// Packet header is only for the messages generated from packet info. For
// other message type such as kOpusAudio and kMetadata, the packet does not
// contain the packet header and only contains the message type and serialized
// data.
return type == MessageType::kHandshake || type == MessageType::kPcmAudio;
}
} // namespace } // namespace
char* PopulateHeader(char* data, size_t size, const PacketInfo& packet_info) { void FillBuffer(char* buf,
DCHECK(HasPacketHeader(packet_info.message_type)); size_t buf_size,
const StreamInfo& stream_info = packet_info.stream_info; const void* data,
PacketHeader header; size_t data_size) {
header.message_type = static_cast<uint8_t>(packet_info.message_type); DCHECK_LE(data_size, buf_size - sizeof(uint16_t));
header.stream_type = static_cast<uint8_t>(stream_info.stream_type);
header.num_channels = stream_info.num_channels;
header.sample_rate = stream_info.sample_rate;
// In request/ack message, the header contains a codec field and a
// frames_per_buffer field, while in PCM audio message, it instead contains a
// sample format field and a timestamp field.
if (packet_info.message_type == MessageType::kHandshake) {
header.codec_or_sample_format =
static_cast<uint8_t>(stream_info.audio_codec);
header.timestamp_or_frames = stream_info.frames_per_buffer;
} else if (packet_info.message_type == MessageType::kPcmAudio) {
header.codec_or_sample_format =
static_cast<uint8_t>(stream_info.sample_format);
header.timestamp_or_frames = packet_info.timestamp_us;
} else {
NOTREACHED();
}
base::WriteBigEndian( // Deduct the size of |size| itself. base::WriteBigEndian( // Deduct the size of |size| itself.
data, static_cast<uint16_t>(size - sizeof(uint16_t))); buf, static_cast<uint16_t>(buf_size - sizeof(uint16_t)));
DCHECK_EQ(sizeof(header), kTotalHeaderBytes); memcpy(buf + sizeof(uint16_t), data, data_size);
memcpy(data + sizeof(uint16_t), &header.message_type, kMessageHeaderBytes); }
return data + kTotalHeaderBytes;
char* PopulatePcmAudioHeader(char* data,
size_t size,
StreamType stream_type,
int64_t timestamp_us) {
PcmPacketHeader header;
header.message_type = static_cast<uint8_t>(MessageType::kPcmAudio);
header.stream_type = static_cast<uint8_t>(stream_type);
header.timestamp_us = timestamp_us;
FillBuffer(data, size, &header.message_type,
sizeof(header) - sizeof(uint16_t));
return data + sizeof(header);
}
void PopulateHandshakeMessage(char* data,
size_t size,
const StreamInfo& stream_info) {
HandshakePacket packet;
packet.message_type = static_cast<uint8_t>(MessageType::kHandshake);
packet.stream_type = static_cast<uint8_t>(stream_info.stream_type);
packet.audio_codec = static_cast<uint8_t>(stream_info.audio_codec);
packet.sample_format = static_cast<uint8_t>(stream_info.sample_format);
packet.num_channels = stream_info.num_channels;
packet.num_frames = stream_info.frames_per_buffer;
packet.sample_rate = stream_info.sample_rate;
FillBuffer(data, size, &packet.message_type,
sizeof(packet) - sizeof(uint16_t));
} }
bool ReadHeader(const char* data, size_t size, PacketInfo* packet_info) { bool ReadPcmAudioHeader(const char* data,
DCHECK(packet_info); size_t size,
if (size < kMessageHeaderBytes) { const StreamInfo& stream_info,
LOG(ERROR) << "Message doesn't have a complete header."; int64_t* timestamp_us) {
DCHECK(timestamp_us);
if (size < kPcmAudioHeaderBytes) {
LOG(ERROR) << "Message doesn't have a complete header: " << size << " v.s. "
<< kPcmAudioHeaderBytes << ".";
return false; return false;
} }
PacketHeader header; PcmPacketHeader header;
memcpy(&header.message_type, data, kMessageHeaderBytes); memcpy(&header.message_type, data, kPcmAudioHeaderBytes);
MessageType message_type = static_cast<MessageType>(header.message_type); if (static_cast<MessageType>(header.message_type) != MessageType::kPcmAudio) {
uint8_t last_codec_or_sample_format = LOG(ERROR) << "Message type mismatch.";
(message_type == MessageType::kHandshake)
? static_cast<uint8_t>(AudioCodec::kLastCodec)
: static_cast<uint8_t>(SampleFormat::LAST_FORMAT);
if (!HasPacketHeader(message_type) ||
header.stream_type > static_cast<uint8_t>(StreamType::kLastType) ||
header.codec_or_sample_format > last_codec_or_sample_format) {
LOG(ERROR) << "Invalid message header.";
return false; return false;
} }
if (header.num_channels > ::media::limits::kMaxChannels) { if (static_cast<StreamType>(header.stream_type) != stream_info.stream_type) {
LOG(ERROR) << "Invalid number of channels: " << header.num_channels; LOG(ERROR) << "Stream type mistach.";
return false; return false;
} }
packet_info->message_type = message_type; *timestamp_us = header.timestamp_us;
packet_info->stream_info.stream_type =
static_cast<StreamType>(header.stream_type);
packet_info->stream_info.num_channels = header.num_channels;
packet_info->stream_info.sample_rate = header.sample_rate;
if (message_type == MessageType::kHandshake) {
packet_info->stream_info.audio_codec =
static_cast<AudioCodec>(header.codec_or_sample_format);
packet_info->stream_info.frames_per_buffer = header.timestamp_or_frames;
} else if (message_type == MessageType::kPcmAudio) {
packet_info->stream_info.sample_format =
static_cast<SampleFormat>(header.codec_or_sample_format);
packet_info->timestamp_us = header.timestamp_or_frames;
} else {
NOTREACHED();
}
return true; return true;
} }
scoped_refptr<net::IOBufferWithSize> MakeMessage(const PacketInfo& packet_info, scoped_refptr<net::IOBufferWithSize> MakePcmAudioMessage(StreamType stream_type,
int64_t timestamp_us,
const char* data, const char* data,
size_t data_size) { size_t data_size) {
if (!HasPacketHeader(packet_info.message_type)) { const size_t total_size = sizeof(PcmPacketHeader) + data_size;
LOG(ERROR)
<< "Only kHandshake and kPcmAudio message have packet header, use "
"MakeSerializedMessage otherwise.";
return nullptr;
}
const size_t total_size = kTotalHeaderBytes + data_size;
DCHECK_LE(total_size, std::numeric_limits<uint16_t>::max()); DCHECK_LE(total_size, std::numeric_limits<uint16_t>::max());
auto io_buffer = base::MakeRefCounted<net::IOBufferWithSize>(total_size); auto io_buffer = base::MakeRefCounted<net::IOBufferWithSize>(total_size);
char* ptr = PopulateHeader(io_buffer->data(), io_buffer->size(), packet_info); char* ptr = PopulatePcmAudioHeader(io_buffer->data(), io_buffer->size(),
stream_type, timestamp_us);
if (!ptr) { if (!ptr) {
return nullptr; return nullptr;
} }
if (packet_info.message_type == MessageType::kPcmAudio && data_size > 0) { if (data_size > 0) {
DCHECK(data); DCHECK(data);
std::copy(data, data + data_size, ptr); std::copy(data, data + data_size, ptr);
} }
return io_buffer; return io_buffer;
} }
scoped_refptr<net::IOBufferWithSize> MakeHandshakeMessage(
const StreamInfo& stream_info) {
auto io_buffer =
base::MakeRefCounted<net::IOBufferWithSize>(sizeof(HandshakePacket));
PopulateHandshakeMessage(io_buffer->data(), io_buffer->size(), stream_info);
return io_buffer;
}
scoped_refptr<net::IOBufferWithSize> MakeSerializedMessage( scoped_refptr<net::IOBufferWithSize> MakeSerializedMessage(
MessageType message_type, MessageType message_type,
const char* data, const char* data,
...@@ -278,22 +273,51 @@ bool ReadDataToAudioBus(const StreamInfo& stream_info, ...@@ -278,22 +273,51 @@ bool ReadDataToAudioBus(const StreamInfo& stream_info,
DCHECK(audio_bus); DCHECK(audio_bus);
DCHECK_EQ(stream_info.num_channels, audio_bus->channels()); DCHECK_EQ(stream_info.num_channels, audio_bus->channels());
return ConvertData(stream_info.num_channels, stream_info.sample_format, return ConvertData(stream_info.num_channels, stream_info.sample_format,
data + kMessageHeaderBytes, size - kMessageHeaderBytes, data + kPcmAudioHeaderBytes, size - kPcmAudioHeaderBytes,
audio_bus); audio_bus);
} }
bool ReadPcmAudioMessage(const char* data, bool ReadPcmAudioMessage(const char* data,
size_t size, size_t size,
PacketInfo* packet_info, const StreamInfo& stream_info,
int64_t* timestamp_us,
::media::AudioBus* audio_bus) { ::media::AudioBus* audio_bus) {
if (!ReadHeader(data, size, packet_info)) { if (!ReadPcmAudioHeader(data, size, stream_info, timestamp_us)) {
return false;
}
return ReadDataToAudioBus(stream_info, data, size, audio_bus);
}
bool ReadHandshakeMessage(const char* data,
size_t size,
StreamInfo* stream_info) {
DCHECK(stream_info);
if (size != kHandshakeHeaderBytes) {
LOG(ERROR) << "Message doesn't have a complete handshake packet: " << size
<< " v.s. " << kHandshakeHeaderBytes << ".";
return false; return false;
} }
if (packet_info->message_type != MessageType::kPcmAudio) { HandshakePacket packet;
LOG(WARNING) << "Received non-pcm-audio message."; memcpy(&packet.message_type, data, kHandshakeHeaderBytes);
MessageType message_type = static_cast<MessageType>(packet.message_type);
if (message_type != MessageType::kHandshake ||
packet.stream_type > static_cast<uint8_t>(StreamType::kLastType) ||
packet.audio_codec > static_cast<uint8_t>(AudioCodec::kLastCodec) ||
packet.sample_format > static_cast<uint8_t>(SampleFormat::LAST_FORMAT)) {
LOG(ERROR) << "Invalid message header.";
return false; return false;
} }
return ReadDataToAudioBus(packet_info->stream_info, data, size, audio_bus); if (packet.num_channels > ::media::limits::kMaxChannels) {
LOG(ERROR) << "Invalid number of channels: " << packet.num_channels;
return false;
}
stream_info->stream_type = static_cast<StreamType>(packet.stream_type);
stream_info->audio_codec = static_cast<AudioCodec>(packet.audio_codec);
stream_info->sample_format = static_cast<SampleFormat>(packet.sample_format);
stream_info->num_channels = packet.num_channels;
stream_info->frames_per_buffer = packet.num_frames;
stream_info->sample_rate = packet.sample_rate;
return true;
} }
size_t DataSizeInBytes(const StreamInfo& stream_info) { size_t DataSizeInBytes(const StreamInfo& stream_info) {
......
...@@ -16,32 +16,34 @@ namespace chromecast { ...@@ -16,32 +16,34 @@ namespace chromecast {
namespace media { namespace media {
namespace capture_service { namespace capture_service {
// Read message header to |packet_info|, and return whether success. // Read message header, check if it matches |stream_info|, retrieve timestamp,
// The header of the message consists of <uint8_t message_type> // and return whether success.
// <uint8_t stream_type> <uint8_t audio_codec|sample_format> <uint8_t channels> // Note |data| here has been parsed firstly by SmallMessageSocket, and thus
// <uint16_t sample_rate> <uint64_t frames_per_buffer|timestamp_us>. // doesn't have <uint16_t size> bits.
// If |message_type| is kHandshake, it is a handshake message that has bool ReadPcmAudioHeader(const char* data,
// |audio_codec| and |frames_per_buffer|, otherwise if |message_type| is size_t size,
// kPcmAudio, it's a PCM audio data message that has |sample_format| and const StreamInfo& stream_info,
// |timestamp_us|. Note it cannot be used to read kOpusAudio or kMetadata int64_t* timestamp_us);
// messages, which don't have header besides |message_type| bits. Note
// |packet_info| will be untouched if fails to read header. Note unsigned // Make a IO buffer for stream message. It will populate the header and copy
// |timestamp_us| will be converted to signed |timestamp| if valid. Note |data| // |data| into the message if packet has audio and |data| is not null. The
// here has been parsed firstly by SmallMessageSocket, and thus doesn't have // returned buffer will have a length of |data_size| + header size. Return
// <uint16_t size> bits. // nullptr if fails. Caller must guarantee the memory of |data| has at least
bool ReadHeader(const char* data, size_t size, PacketInfo* packet_info); // |data_size| when has audio.
// Make a IO buffer for stream message. It will populate the header with
// |packet_info|, and copy |data| into the message if packet has audio and
// |data| is not null. The returned buffer will have a length of |data_size| +
// header size. Return nullptr if fails. Caller must guarantee the memory of
// |data| has at least |data_size| when has audio.
// Note buffer will be sent with SmallMessageSocket, and thus contains a uint16 // Note buffer will be sent with SmallMessageSocket, and thus contains a uint16
// size field in the very first. // size field in the very first.
scoped_refptr<net::IOBufferWithSize> MakeMessage(const PacketInfo& packet_info, scoped_refptr<net::IOBufferWithSize> MakePcmAudioMessage(StreamType stream_type,
int64_t timestamp_us,
const char* data, const char* data,
size_t data_size); size_t data_size);
// Make a IO buffer for handshake message. It will populate the header with
// |stream_info|. Return nullptr if fails.
// Note buffer will be sent with SmallMessageSocket, and thus contains a uint16
// size field in the very first.
scoped_refptr<net::IOBufferWithSize> MakeHandshakeMessage(
const StreamInfo& stream_info);
// Make a IO buffer for serialized message. It will populate message size and // Make a IO buffer for serialized message. It will populate message size and
// type fields, and copy |data| into the message. The returned buffer will have // type fields, and copy |data| into the message. The returned buffer will have
// a length of |data_size| + sizeof(uint8_t message_type) + sizeof(uint16_t // a length of |data_size| + sizeof(uint8_t message_type) + sizeof(uint16_t
...@@ -60,21 +62,43 @@ bool ReadDataToAudioBus(const StreamInfo& stream_info, ...@@ -60,21 +62,43 @@ bool ReadDataToAudioBus(const StreamInfo& stream_info,
size_t size, size_t size,
::media::AudioBus* audio_bus); ::media::AudioBus* audio_bus);
// Read the header part of the PCM audio message to packet info and the audio // Read the PCM audio message and copy the audio data to audio bus, as well as
// data part to audio bus, and return whether success. This will run // the timestamp. Return whether success. This will run ReadPcmAudioHeader() and
// ReadHeader() and ReadDataToAudioBus() in the underlying implementation. // ReadDataToAudioBus() in the underlying implementation.
bool ReadPcmAudioMessage(const char* data, bool ReadPcmAudioMessage(const char* data,
size_t size, size_t size,
PacketInfo* packet_info, const StreamInfo& stream_info,
int64_t* timestamp_us,
::media::AudioBus* audio_bus); ::media::AudioBus* audio_bus);
// Populate header of the message, including the SmallMessageSocket size bits. // Read the handshake message to |stream_info|, and return true on success.
// Note this is used by unittest, user should use MakeMessage directly. bool ReadHandshakeMessage(const char* data,
char* PopulateHeader(char* data, size_t size, const PacketInfo& stream_info); size_t size,
StreamInfo* stream_info);
// Return the expected size of the data of a stream message with |stream_info|. // Return the expected size of the data of a stream message with |stream_info|.
size_t DataSizeInBytes(const StreamInfo& stream_info); size_t DataSizeInBytes(const StreamInfo& stream_info);
// Following methods are exposed for unittests:
// Write |buf_size|, in big-endian order, to |buf|, and fill |data| to |buf|
// afterward.
void FillBuffer(char* buf, size_t buf_size, const void* data, size_t data_size);
// Populate header of the PCM audio message, including the SmallMessageSocket
// size bits.
// Note this is used by unittest, user should use MakePcmAudioMessage directly.
char* PopulatePcmAudioHeader(char* data,
size_t size,
StreamType stream_type,
int64_t timestamp_us);
// Populate the handshake message, including the SmallMessageSocket size bits.
// Note this is used by unittest, user should use MakeHandshakeMessage directly.
void PopulateHandshakeMessage(char* data,
size_t size,
const StreamInfo& stream_info);
} // namespace capture_service } // namespace capture_service
} // namespace media } // namespace media
} // namespace chromecast } // namespace chromecast
......
...@@ -16,7 +16,6 @@ namespace media { ...@@ -16,7 +16,6 @@ namespace media {
namespace capture_service { namespace capture_service {
namespace { namespace {
constexpr size_t kTotalHeaderBytes = 16;
constexpr size_t kFrames = 10; constexpr size_t kFrames = 10;
constexpr size_t kChannels = 2; constexpr size_t kChannels = 2;
constexpr StreamInfo kStreamInfo = constexpr StreamInfo kStreamInfo =
...@@ -26,37 +25,72 @@ constexpr StreamInfo kStreamInfo = ...@@ -26,37 +25,72 @@ constexpr StreamInfo kStreamInfo =
SampleFormat::PLANAR_FLOAT, SampleFormat::PLANAR_FLOAT,
16000, 16000,
kFrames}; kFrames};
constexpr PacketInfo kHandshakePacketInfo = {MessageType::kHandshake,
kStreamInfo, 0}; class PacketHeaderTest
constexpr PacketInfo kPcmAudioPacketInfo = {MessageType::kPcmAudio, kStreamInfo, : public testing::TestWithParam<
0}; std::tuple<StreamType, AudioCodec, int, SampleFormat, int, int>> {
protected:
StreamInfo GetStreamInfo() {
StreamInfo info;
info.stream_type = std::get<0>(GetParam());
info.audio_codec = std::get<1>(GetParam());
info.num_channels = std::get<2>(GetParam());
info.sample_format = std::get<3>(GetParam());
info.sample_rate = std::get<4>(GetParam());
info.frames_per_buffer = std::get<5>(GetParam());
return info;
}
};
TEST_P(PacketHeaderTest, HandshakeMessage) {
std::vector<char> data(sizeof(HandshakePacket), 0);
StreamInfo stream_info = GetStreamInfo();
PopulateHandshakeMessage(data.data(), data.size(), stream_info);
StreamInfo info_out;
bool success =
ReadHandshakeMessage(data.data() + sizeof(uint16_t),
data.size() - sizeof(uint16_t), &info_out);
EXPECT_TRUE(success);
EXPECT_EQ(info_out.stream_type, stream_info.stream_type);
EXPECT_EQ(info_out.audio_codec, stream_info.audio_codec);
EXPECT_EQ(info_out.sample_format, stream_info.sample_format);
EXPECT_EQ(info_out.num_channels, stream_info.num_channels);
EXPECT_EQ(info_out.sample_rate, stream_info.sample_rate);
EXPECT_EQ(info_out.frames_per_buffer, stream_info.frames_per_buffer);
}
TEST(MessageParsingUtilsTest, PcmAudioMessage) {
size_t data_size = sizeof(PcmPacketHeader) / sizeof(float);
std::vector<float> data(data_size, 1.0f);
int64_t timestamp_us = 100;
PopulatePcmAudioHeader(reinterpret_cast<char*>(data.data()),
data.size() * sizeof(float), kStreamInfo.stream_type,
timestamp_us);
int64_t timestamp_out = 0;
bool success = ReadPcmAudioHeader(
reinterpret_cast<char*>(data.data()) + sizeof(uint16_t),
data_size * sizeof(float) - sizeof(uint16_t), kStreamInfo,
&timestamp_out);
EXPECT_TRUE(success);
EXPECT_EQ(timestamp_out, timestamp_us);
}
TEST(MessageParsingUtilsTest, ValidPlanarFloat) { TEST(MessageParsingUtilsTest, ValidPlanarFloat) {
size_t data_size = kTotalHeaderBytes / sizeof(float) + kFrames * kChannels; size_t data_size =
sizeof(PcmPacketHeader) / sizeof(float) + kFrames * kChannels;
std::vector<float> data(data_size, .0f); std::vector<float> data(data_size, .0f);
PopulateHeader(reinterpret_cast<char*>(data.data()), PopulatePcmAudioHeader(reinterpret_cast<char*>(data.data()),
data.size() * sizeof(float), kPcmAudioPacketInfo); data.size() * sizeof(float), kStreamInfo.stream_type,
0);
// Fill the last k frames, i.e., the second channel, with 0.5f. // Fill the last k frames, i.e., the second channel, with 0.5f.
for (size_t i = data_size - kFrames; i < data_size; i++) { for (size_t i = data_size - kFrames; i < data_size; i++) {
data[i] = .5f; data[i] = .5f;
} }
// Audio header.
PacketInfo info;
bool success =
ReadHeader(reinterpret_cast<char*>(data.data()) + sizeof(uint16_t),
data_size * sizeof(float) - sizeof(uint16_t), &info);
EXPECT_TRUE(success);
EXPECT_EQ(info.message_type, kPcmAudioPacketInfo.message_type);
EXPECT_EQ(info.stream_info.stream_type, kStreamInfo.stream_type);
EXPECT_EQ(info.stream_info.num_channels, kStreamInfo.num_channels);
EXPECT_EQ(info.stream_info.sample_format, kStreamInfo.sample_format);
EXPECT_EQ(info.stream_info.sample_rate, kStreamInfo.sample_rate);
EXPECT_EQ(info.timestamp_us, kPcmAudioPacketInfo.timestamp_us);
// Audio data.
auto audio_bus = ::media::AudioBus::Create(kChannels, kFrames); auto audio_bus = ::media::AudioBus::Create(kChannels, kFrames);
success = ReadDataToAudioBus( bool success = ReadDataToAudioBus(
kStreamInfo, reinterpret_cast<char*>(data.data()) + sizeof(uint16_t), kStreamInfo, reinterpret_cast<char*>(data.data()) + sizeof(uint16_t),
data_size * sizeof(float) - sizeof(uint16_t), audio_bus.get()); data_size * sizeof(float) - sizeof(uint16_t), audio_bus.get());
EXPECT_TRUE(success); EXPECT_TRUE(success);
...@@ -67,22 +101,23 @@ TEST(MessageParsingUtilsTest, ValidPlanarFloat) { ...@@ -67,22 +101,23 @@ TEST(MessageParsingUtilsTest, ValidPlanarFloat) {
} }
TEST(MessageParsingUtilsTest, ValidInterleavedInt16) { TEST(MessageParsingUtilsTest, ValidInterleavedInt16) {
size_t data_size = kTotalHeaderBytes / sizeof(int16_t) + kFrames * kChannels; size_t data_size =
sizeof(PcmPacketHeader) / sizeof(int16_t) + kFrames * kChannels;
std::vector<int16_t> data(data_size, std::numeric_limits<int16_t>::max()); std::vector<int16_t> data(data_size, std::numeric_limits<int16_t>::max());
PacketInfo packet_info = kPcmAudioPacketInfo; PopulatePcmAudioHeader(reinterpret_cast<char*>(data.data()),
packet_info.stream_info.sample_format = SampleFormat::INTERLEAVED_INT16; data.size() * sizeof(int16_t), kStreamInfo.stream_type,
PopulateHeader(reinterpret_cast<char*>(data.data()), 0);
data.size() * sizeof(int16_t), packet_info);
// Fill the second channel with min(). // Fill the second channel with min().
for (size_t i = kTotalHeaderBytes / sizeof(int16_t) + 1; i < data_size; for (size_t i = sizeof(PcmPacketHeader) / sizeof(int16_t) + 1; i < data_size;
i += 2) { i += 2) {
data[i] = std::numeric_limits<int16_t>::min(); data[i] = std::numeric_limits<int16_t>::min();
} }
auto audio_bus = ::media::AudioBus::Create(kChannels, kFrames); auto audio_bus = ::media::AudioBus::Create(kChannels, kFrames);
StreamInfo stream_info = kStreamInfo;
stream_info.sample_format = SampleFormat::INTERLEAVED_INT16;
bool success = ReadDataToAudioBus( bool success = ReadDataToAudioBus(
packet_info.stream_info, stream_info, reinterpret_cast<char*>(data.data()) + sizeof(uint16_t),
reinterpret_cast<char*>(data.data()) + sizeof(uint16_t),
data_size * sizeof(int16_t) - sizeof(uint16_t), audio_bus.get()); data_size * sizeof(int16_t) - sizeof(uint16_t), audio_bus.get());
EXPECT_TRUE(success); EXPECT_TRUE(success);
for (size_t f = 0; f < kFrames; f++) { for (size_t f = 0; f < kFrames; f++) {
...@@ -92,22 +127,23 @@ TEST(MessageParsingUtilsTest, ValidInterleavedInt16) { ...@@ -92,22 +127,23 @@ TEST(MessageParsingUtilsTest, ValidInterleavedInt16) {
} }
TEST(MessageParsingUtilsTest, ValidInterleavedInt32) { TEST(MessageParsingUtilsTest, ValidInterleavedInt32) {
size_t data_size = kTotalHeaderBytes / sizeof(int32_t) + kFrames * kChannels; size_t data_size =
sizeof(PcmPacketHeader) / sizeof(int32_t) + kFrames * kChannels;
std::vector<int32_t> data(data_size, std::numeric_limits<int32_t>::min()); std::vector<int32_t> data(data_size, std::numeric_limits<int32_t>::min());
PacketInfo packet_info = kPcmAudioPacketInfo; PopulatePcmAudioHeader(reinterpret_cast<char*>(data.data()),
packet_info.stream_info.sample_format = SampleFormat::INTERLEAVED_INT32; data.size() * sizeof(int32_t), kStreamInfo.stream_type,
PopulateHeader(reinterpret_cast<char*>(data.data()), 0);
data.size() * sizeof(int32_t), packet_info);
// Fill the second channel with max(). // Fill the second channel with max().
for (size_t i = kTotalHeaderBytes / sizeof(int32_t) + 1; i < data_size; for (size_t i = sizeof(PcmPacketHeader) / sizeof(int32_t) + 1; i < data_size;
i += 2) { i += 2) {
data[i] = std::numeric_limits<int32_t>::max(); data[i] = std::numeric_limits<int32_t>::max();
} }
auto audio_bus = ::media::AudioBus::Create(kChannels, kFrames); auto audio_bus = ::media::AudioBus::Create(kChannels, kFrames);
StreamInfo stream_info = kStreamInfo;
stream_info.sample_format = SampleFormat::INTERLEAVED_INT32;
bool success = ReadDataToAudioBus( bool success = ReadDataToAudioBus(
packet_info.stream_info, stream_info, reinterpret_cast<char*>(data.data()) + sizeof(uint16_t),
reinterpret_cast<char*>(data.data()) + sizeof(uint16_t),
data_size * sizeof(int32_t) - sizeof(uint16_t), audio_bus.get()); data_size * sizeof(int32_t) - sizeof(uint16_t), audio_bus.get());
EXPECT_TRUE(success); EXPECT_TRUE(success);
for (size_t f = 0; f < kFrames; f++) { for (size_t f = 0; f < kFrames; f++) {
...@@ -116,92 +152,68 @@ TEST(MessageParsingUtilsTest, ValidInterleavedInt32) { ...@@ -116,92 +152,68 @@ TEST(MessageParsingUtilsTest, ValidInterleavedInt32) {
} }
} }
TEST(MessageParsingUtilsTest, InvalidType) { TEST(MessageParsingUtilsTest, InvalidTypeHandshake) {
size_t data_size = kTotalHeaderBytes / sizeof(float); std::vector<char> data(sizeof(HandshakePacket), 0);
std::vector<float> data(data_size, 1.0f); StreamInfo stream_info = kStreamInfo;
// Request packet PopulateHandshakeMessage(data.data(), data.size(), stream_info);
PacketInfo request_packet_info = kHandshakePacketInfo;
PopulateHeader(reinterpret_cast<char*>(data.data()),
data.size() * sizeof(float), request_packet_info);
*(reinterpret_cast<uint8_t*>(data.data()) + *(reinterpret_cast<uint8_t*>(data.data()) +
offsetof(struct PacketHeader, stream_type)) = offsetof(struct HandshakePacket, stream_type)) =
static_cast<uint8_t>(StreamType::kLastType) + 1; static_cast<uint8_t>(StreamType::kLastType) + 1;
bool success = ReadHeader( bool success =
reinterpret_cast<char*>(data.data()) + sizeof(uint16_t), ReadHandshakeMessage(data.data() + sizeof(uint16_t),
data_size * sizeof(float) - sizeof(uint16_t), &request_packet_info); data.size() - sizeof(uint16_t), &stream_info);
EXPECT_FALSE(success); EXPECT_FALSE(success);
}
// PCM audio packet TEST(MessageParsingUtilsTest, InvalidTypePcmAudio) {
PacketInfo pcm_audio_packet_info = kPcmAudioPacketInfo; size_t data_size = sizeof(PcmPacketHeader) / sizeof(float);
PopulateHeader(reinterpret_cast<char*>(data.data()), std::vector<float> data(data_size, 1.0f);
data.size() * sizeof(float), pcm_audio_packet_info); PopulatePcmAudioHeader(reinterpret_cast<char*>(data.data()),
data.size() * sizeof(float), kStreamInfo.stream_type,
0);
*(reinterpret_cast<uint8_t*>(data.data()) + *(reinterpret_cast<uint8_t*>(data.data()) +
offsetof(struct PacketHeader, stream_type)) = offsetof(struct PcmPacketHeader, stream_type)) =
static_cast<uint8_t>(StreamType::kLastType) + 1; static_cast<uint8_t>(StreamType::kLastType) + 1;
success = ReadHeader(reinterpret_cast<char*>(data.data()) + sizeof(uint16_t), int64_t timestamp_us;
data_size * sizeof(float) - sizeof(uint16_t), bool success = ReadPcmAudioHeader(
&pcm_audio_packet_info); reinterpret_cast<char*>(data.data()) + sizeof(uint16_t),
data_size * sizeof(float) - sizeof(uint16_t), kStreamInfo, &timestamp_us);
EXPECT_FALSE(success); EXPECT_FALSE(success);
} }
TEST(MessageParsingUtilsTest, InvalidCodec) { TEST(MessageParsingUtilsTest, InvalidCodec) {
size_t data_size = kTotalHeaderBytes / sizeof(float); std::vector<char> data(sizeof(HandshakePacket), 0);
std::vector<float> data(data_size, 1.0f); StreamInfo stream_info = kStreamInfo;
PacketInfo packet_info = kHandshakePacketInfo; PopulateHandshakeMessage(data.data(), data.size(), stream_info);
PopulateHeader(reinterpret_cast<char*>(data.data()),
data.size() * sizeof(float), packet_info);
*(reinterpret_cast<uint8_t*>(data.data()) + *(reinterpret_cast<uint8_t*>(data.data()) +
offsetof(struct PacketHeader, codec_or_sample_format)) = offsetof(struct HandshakePacket, audio_codec)) =
static_cast<uint8_t>(AudioCodec::kLastCodec) + 1; static_cast<uint8_t>(AudioCodec::kLastCodec) + 1;
bool success = bool success =
ReadHeader(reinterpret_cast<char*>(data.data()) + sizeof(uint16_t), ReadHandshakeMessage(data.data() + sizeof(uint16_t),
data_size * sizeof(float) - sizeof(uint16_t), &packet_info); data.size() - sizeof(uint16_t), &stream_info);
EXPECT_FALSE(success); EXPECT_FALSE(success);
} }
TEST(MessageParsingUtilsTest, InvalidFormat) { TEST(MessageParsingUtilsTest, InvalidFormat) {
size_t data_size = kTotalHeaderBytes / sizeof(float); std::vector<char> data(sizeof(HandshakePacket), 0);
std::vector<float> data(data_size, 1.0f); StreamInfo stream_info = kStreamInfo;
PacketInfo packet_info = kPcmAudioPacketInfo; PopulateHandshakeMessage(data.data(), data.size(), stream_info);
PopulateHeader(reinterpret_cast<char*>(data.data()),
data.size() * sizeof(float), packet_info);
*(reinterpret_cast<uint8_t*>(data.data()) + *(reinterpret_cast<uint8_t*>(data.data()) +
offsetof(struct PacketHeader, codec_or_sample_format)) = offsetof(struct HandshakePacket, sample_format)) =
static_cast<uint8_t>(SampleFormat::LAST_FORMAT) + 1; static_cast<uint8_t>(SampleFormat::LAST_FORMAT) + 1;
bool success = bool success =
ReadHeader(reinterpret_cast<char*>(data.data()) + sizeof(uint16_t), ReadHandshakeMessage(data.data() + sizeof(uint16_t),
data_size * sizeof(float) - sizeof(uint16_t), &packet_info); data.size() - sizeof(uint16_t), &stream_info);
EXPECT_FALSE(success); EXPECT_FALSE(success);
} }
TEST(MessageParsingUtilsTest, RequestMessage) {
size_t data_size = kTotalHeaderBytes / sizeof(float);
std::vector<float> data(data_size, 1.0f);
PacketInfo packet_info = kHandshakePacketInfo;
PopulateHeader(reinterpret_cast<char*>(data.data()),
data.size() * sizeof(float), packet_info);
PacketInfo info;
bool success =
ReadHeader(reinterpret_cast<char*>(data.data()) + sizeof(uint16_t),
data_size * sizeof(float) - sizeof(uint16_t), &info);
EXPECT_TRUE(success);
EXPECT_EQ(info.message_type, kHandshakePacketInfo.message_type);
EXPECT_EQ(info.stream_info.stream_type, kStreamInfo.stream_type);
EXPECT_EQ(info.stream_info.audio_codec, kStreamInfo.audio_codec);
EXPECT_EQ(info.stream_info.num_channels, kStreamInfo.num_channels);
EXPECT_EQ(info.stream_info.sample_rate, kStreamInfo.sample_rate);
EXPECT_EQ(info.stream_info.frames_per_buffer, kStreamInfo.frames_per_buffer);
}
TEST(MessageParsingUtilsTest, InvalidDataLength) { TEST(MessageParsingUtilsTest, InvalidDataLength) {
size_t data_size = size_t data_size =
kTotalHeaderBytes / sizeof(float) + kFrames * kChannels + 1; sizeof(PcmPacketHeader) / sizeof(float) + kFrames * kChannels + 1;
std::vector<float> data(data_size, 1.0f); std::vector<float> data(data_size, 1.0f);
PopulateHeader(reinterpret_cast<char*>(data.data()), PopulatePcmAudioHeader(reinterpret_cast<char*>(data.data()),
data.size() * sizeof(float), kPcmAudioPacketInfo); data.size() * sizeof(float), kStreamInfo.stream_type,
0);
auto audio_bus = ::media::AudioBus::Create(kChannels, kFrames); auto audio_bus = ::media::AudioBus::Create(kChannels, kFrames);
bool success = ReadDataToAudioBus( bool success = ReadDataToAudioBus(
...@@ -212,10 +224,11 @@ TEST(MessageParsingUtilsTest, InvalidDataLength) { ...@@ -212,10 +224,11 @@ TEST(MessageParsingUtilsTest, InvalidDataLength) {
TEST(MessageParsingUtilsTest, NotAlignedData) { TEST(MessageParsingUtilsTest, NotAlignedData) {
size_t data_size = size_t data_size =
kTotalHeaderBytes / sizeof(float) + kFrames * kChannels + 1; sizeof(PcmPacketHeader) / sizeof(float) + kFrames * kChannels + 1;
std::vector<float> data(data_size, 1.0f); std::vector<float> data(data_size, 1.0f);
PopulateHeader(reinterpret_cast<char*>(data.data()) + 1, PopulatePcmAudioHeader(reinterpret_cast<char*>(data.data()) + 1,
data.size() * sizeof(float) - 1, kPcmAudioPacketInfo); data.size() * sizeof(float) - 1,
kStreamInfo.stream_type, 0);
auto audio_bus = ::media::AudioBus::Create(kChannels, kFrames); auto audio_bus = ::media::AudioBus::Create(kChannels, kFrames);
bool success = ReadDataToAudioBus( bool success = ReadDataToAudioBus(
...@@ -224,6 +237,18 @@ TEST(MessageParsingUtilsTest, NotAlignedData) { ...@@ -224,6 +237,18 @@ TEST(MessageParsingUtilsTest, NotAlignedData) {
EXPECT_FALSE(success); EXPECT_FALSE(success);
} }
INSTANTIATE_TEST_SUITE_P(
MessageParsingUtilsTest,
PacketHeaderTest,
testing::Combine(testing::Values(StreamType::kMicRaw,
StreamType::kHardwareEchoRescaled),
testing::Values(AudioCodec::kPcm, AudioCodec::kOpus),
testing::Values(1, 8),
testing::Values(SampleFormat::INTERLEAVED_INT16,
SampleFormat::PLANAR_FLOAT),
testing::Values(16000, 96000),
testing::Values(0, 32761)));
} // namespace } // namespace
} // namespace capture_service } // namespace capture_service
} // namespace media } // namespace media
......
...@@ -11,19 +11,29 @@ namespace chromecast { ...@@ -11,19 +11,29 @@ namespace chromecast {
namespace media { namespace media {
namespace capture_service { namespace capture_service {
// Memory block of a packet header. Changes to it need to make sure about the // Memory block of a PCM audio packet header. Changes to it need to ensure the
// memory alignment to avoid extra paddings being inserted. It reflects real // size is a multiple of 4 bytes. It reflects real packet header structure,
// packet header structure, however, the |size| bits are in big-endian order, // however, the |size| bits are in big-endian order, and thus is only for
// and thus is only for padding purpose in this struct, when all bytes after it // padding purpose in this struct, when all bytes after it represent a message
// represent a message header. // header.
struct __attribute__((__packed__)) PacketHeader { struct __attribute__((__packed__)) PcmPacketHeader {
uint16_t size; uint16_t size;
uint8_t message_type; uint8_t message_type;
uint8_t stream_type; uint8_t stream_type;
uint8_t codec_or_sample_format; int64_t timestamp_us;
};
// Memory block of a handshake packet. There is no size restriction for this
// structure.
struct __attribute__((__packed__)) HandshakePacket {
uint16_t size;
uint8_t message_type;
uint8_t stream_type;
uint8_t audio_codec;
uint8_t sample_format;
uint8_t num_channels; uint8_t num_channels;
uint16_t sample_rate; uint16_t num_frames;
int64_t timestamp_or_frames; uint32_t sample_rate;
}; };
} // namespace capture_service } // namespace capture_service
......
...@@ -48,14 +48,14 @@ bool CastAudioInputStream::Open() { ...@@ -48,14 +48,14 @@ bool CastAudioInputStream::Open() {
audio_bus_ = ::media::AudioBus::Create(audio_params_.channels(), audio_bus_ = ::media::AudioBus::Create(audio_params_.channels(),
audio_params_.frames_per_buffer()); audio_params_.frames_per_buffer());
capture_service_receiver_ = std::make_unique<CaptureServiceReceiver>( stream_info_ = capture_service::StreamInfo{
capture_service::StreamInfo{
capture_service::StreamType::kSoftwareEchoCancelled, capture_service::StreamType::kSoftwareEchoCancelled,
capture_service::AudioCodec::kPcm, audio_params_.channels(), capture_service::AudioCodec::kPcm, audio_params_.channels(),
// Format doesn't matter in the request. // Format doesn't matter in the request.
capture_service::SampleFormat::LAST_FORMAT, capture_service::SampleFormat::LAST_FORMAT, audio_params_.sample_rate(),
audio_params_.sample_rate(), audio_params_.frames_per_buffer()}, audio_params_.frames_per_buffer()};
this); capture_service_receiver_ =
std::make_unique<CaptureServiceReceiver>(stream_info_, this);
return true; return true;
} }
...@@ -117,33 +117,40 @@ void CastAudioInputStream::SetOutputDeviceForAec( ...@@ -117,33 +117,40 @@ void CastAudioInputStream::SetOutputDeviceForAec(
bool CastAudioInputStream::OnInitialStreamInfo( bool CastAudioInputStream::OnInitialStreamInfo(
const capture_service::StreamInfo& stream_info) { const capture_service::StreamInfo& stream_info) {
const bool is_params_match = const bool is_params_match =
stream_info.stream_type == stream_info.stream_type == stream_info_.stream_type &&
capture_service::StreamType::kSoftwareEchoCancelled && stream_info.audio_codec == stream_info_.audio_codec &&
stream_info.audio_codec == capture_service::AudioCodec::kPcm && stream_info.num_channels == stream_info_.num_channels &&
stream_info.num_channels == audio_params_.channels() && stream_info.sample_rate == stream_info_.sample_rate &&
stream_info.sample_rate == audio_params_.sample_rate() && stream_info.frames_per_buffer == stream_info_.frames_per_buffer;
stream_info.frames_per_buffer == audio_params_.frames_per_buffer();
LOG_IF(ERROR, !is_params_match) LOG_IF(ERROR, !is_params_match)
<< "Got different parameters from sender, sample_rate: " << "Got different parameters from sender, stream_type: "
<< audio_params_.sample_rate() << " Hz -> " << stream_info.sample_rate << static_cast<int>(stream_info_.stream_type) << " -> "
<< " Hz, num_channels: " << audio_params_.channels() << " -> " << static_cast<int>(stream_info.stream_type)
<< ", audio_codec: " << static_cast<int>(stream_info_.audio_codec)
<< " -> " << static_cast<int>(stream_info.audio_codec)
<< ", sample_rate: " << stream_info_.sample_rate << " Hz -> "
<< stream_info.sample_rate
<< " Hz, num_channels: " << stream_info_.num_channels << " -> "
<< stream_info.num_channels << stream_info.num_channels
<< ", frames_per_buffer: " << audio_params_.frames_per_buffer() << " -> " << ", frames_per_buffer: " << stream_info_.frames_per_buffer << " -> "
<< stream_info.frames_per_buffer << "."; << stream_info.frames_per_buffer << ".";
stream_info_.sample_format = stream_info.sample_format;
LOG(INFO) << "Set sample_format: "
<< static_cast<int>(stream_info.sample_format);
return is_params_match; return is_params_match;
} }
bool CastAudioInputStream::OnCaptureData(const char* data, size_t size) { bool CastAudioInputStream::OnCaptureData(const char* data, size_t size) {
capture_service::PacketInfo info; int64_t timestamp_us;
if (!capture_service::ReadPcmAudioMessage(data, size, &info, if (!capture_service::ReadPcmAudioMessage(data, size, stream_info_,
audio_bus_.get())) { &timestamp_us, audio_bus_.get())) {
return false; return false;
} }
DCHECK(input_callback_); DCHECK(input_callback_);
input_callback_->OnData( input_callback_->OnData(
audio_bus_.get(), audio_bus_.get(),
base::TimeTicks() + base::TimeDelta::FromMicroseconds(info.timestamp_us), base::TimeTicks() + base::TimeDelta::FromMicroseconds(timestamp_us),
/* volume */ 1.0); /* volume */ 1.0);
return true; return true;
} }
......
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include "base/threading/thread_checker.h" #include "base/threading/thread_checker.h"
#include "chromecast/media/audio/capture_service/capture_service_receiver.h" #include "chromecast/media/audio/capture_service/capture_service_receiver.h"
#include "chromecast/media/audio/capture_service/constants.h"
#include "media/audio/audio_io.h" #include "media/audio/audio_io.h"
#include "media/base/audio_bus.h" #include "media/base/audio_bus.h"
#include "media/base/audio_parameters.h" #include "media/base/audio_parameters.h"
...@@ -58,6 +59,7 @@ class CastAudioInputStream : public ::media::AudioInputStream, ...@@ -58,6 +59,7 @@ class CastAudioInputStream : public ::media::AudioInputStream,
// may be null, if |this| is not created by audio manager, e.g., in unit test. // may be null, if |this| is not created by audio manager, e.g., in unit test.
::media::AudioManagerBase* const audio_manager_; ::media::AudioManagerBase* const audio_manager_;
const ::media::AudioParameters audio_params_; const ::media::AudioParameters audio_params_;
capture_service::StreamInfo stream_info_;
std::unique_ptr<CaptureServiceReceiver> capture_service_receiver_; std::unique_ptr<CaptureServiceReceiver> capture_service_receiver_;
AudioInputCallback* input_callback_ = nullptr; AudioInputCallback* input_callback_ = nullptr;
std::unique_ptr<::media::AudioBus> audio_bus_; std::unique_ptr<::media::AudioBus> audio_bus_;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment