Commit 56d85f95 authored by Derek Cheng's avatar Derek Cheng Committed by Commit Bot

[Cast channel] Validate IP address from mDNS / cast channel requests.

A valid Cast device address must be private. This is checked using
the IPAddress::IsReserved() method, similar to DIAL's device description
service.

The check is performed in several entry points (some are redundant as
extra safety net):
- DnsSdRegistry, when it receives an device advertisement from mDNS
- CastSocketService, before it opens socket
- CastSocketServiceImpl::OpenChannel
- CastChannelOpenFunction (entry point for chrome.cast.channel.open)

Bug: 786109
Change-Id: Iaad91834cd4149fd345b2ada4e2704a0e158ba49
Reviewed-on: https://chromium-review.googlesource.com/792650
Commit-Queue: Derek Cheng <imcheng@chromium.org>
Reviewed-by: default avatarmark a. foltz <mfoltz@chromium.org>
Cr-Commit-Position: refs/heads/master@{#519857}
parent d6b68e0c
......@@ -15,6 +15,7 @@
#include "chrome/common/media_router/discovery/media_sink_service.h"
#include "chrome/common/media_router/media_sink.h"
#include "components/cast_channel/cast_channel_enum.h"
#include "components/cast_channel/cast_channel_util.h"
#include "components/cast_channel/cast_socket_service.h"
#include "components/cast_channel/logger.h"
#include "components/net_log/chrome_net_log.h"
......@@ -413,6 +414,9 @@ void CastMediaSinkServiceImpl::OpenChannel(
SinkSource sink_source) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
if (!cast_channel::IsValidCastIPAddress(ip_endpoint.address()))
return;
// Erase the entry from |dial_sink_failure_count_| since the device is now
// known to be a Cast device.
if (sink_source != SinkSource::kDial)
......
......@@ -121,7 +121,6 @@ class CastMediaSinkServiceImplTest : public ::testing::Test {
.WillOnce(Invoke(
[socket](const auto& ip_endpoint, auto* net_log, auto open_cb) {
std::move(open_cb).Run(socket);
return socket->id();
}));
}
......@@ -346,7 +345,6 @@ TEST_F(CastMediaSinkServiceImplTest, TestOpenChannelFails) {
.WillRepeatedly(
Invoke([&](const auto& ip_endpoint1, auto* net_log, auto open_cb) {
std::move(open_cb).Run(&socket);
return socket.id();
}));
media_sink_service_impl_.OpenChannel(
ip_endpoint, cast_sink, nullptr,
......@@ -484,7 +482,6 @@ TEST_F(CastMediaSinkServiceImplTest, TestOnChannelErrorMayRetryForCastSink) {
.WillRepeatedly(
Invoke([&](const auto& ip_endpoint, auto* net_log, auto open_cb) {
std::move(open_cb).Run(&socket);
return socket.id();
}));
media_sink_service_impl_.OnError(
......@@ -534,18 +531,14 @@ TEST_F(CastMediaSinkServiceImplTest, TestOnDialSinkAdded) {
// Channel 1, 2 opened.
EXPECT_CALL(*mock_cast_socket_service_,
OpenSocketInternal(ip_endpoint1, _, _))
.WillOnce(DoAll(
WithArgs<2>(Invoke(
[&](const base::Callback<void(cast_channel::CastSocket * socket)>&
callback) { std::move(callback).Run(&socket1); })),
Return(1)));
.WillOnce(WithArgs<2>(Invoke(
[&](const base::Callback<void(cast_channel::CastSocket * socket)>&
callback) { std::move(callback).Run(&socket1); })));
EXPECT_CALL(*mock_cast_socket_service_,
OpenSocketInternal(ip_endpoint2, _, _))
.WillOnce(DoAll(
WithArgs<2>(Invoke(
[&](const base::Callback<void(cast_channel::CastSocket * socket)>&
callback) { std::move(callback).Run(&socket2); })),
Return(2)));
.WillOnce(WithArgs<2>(Invoke(
[&](const base::Callback<void(cast_channel::CastSocket * socket)>&
callback) { std::move(callback).Run(&socket2); })));
// Invoke CastSocketService::OpenSocket on the IO thread.
media_sink_service_impl_.OnDialSinkAdded(dial_sink1);
......@@ -570,12 +563,10 @@ TEST_F(CastMediaSinkServiceImplTest, TestOnDialSinkAddedSkipsIfNonCastDevice) {
EXPECT_CALL(*mock_cast_socket_service_,
OpenSocketInternal(ip_endpoint1, _, _))
.Times(1)
.WillOnce(DoAll(
WithArgs<2>(Invoke(
[&socket1](
const base::Callback<void(cast_channel::CastSocket * socket)>&
callback) { std::move(callback).Run(&socket1); })),
Return(1)));
.WillOnce(WithArgs<2>(Invoke(
[&socket1](
const base::Callback<void(cast_channel::CastSocket * socket)>&
callback) { std::move(callback).Run(&socket1); })));
media_sink_service_impl_.OnDialSinkAdded(dial_sink1);
// We don't trigger retries, thus each iteration will only increment the
......@@ -585,13 +576,10 @@ TEST_F(CastMediaSinkServiceImplTest, TestOnDialSinkAddedSkipsIfNonCastDevice) {
EXPECT_CALL(*mock_cast_socket_service_,
OpenSocketInternal(ip_endpoint1, _, _))
.Times(1)
.WillOnce(DoAll(
WithArgs<2>(Invoke(
[&socket1](const base::Callback<void(cast_channel::CastSocket *
socket)>& callback) {
std::move(callback).Run(&socket1);
})),
Return(1)));
.WillOnce(WithArgs<2>(Invoke(
[&socket1](
const base::Callback<void(cast_channel::CastSocket * socket)>&
callback) { std::move(callback).Run(&socket1); })));
media_sink_service_impl_.OnDialSinkAdded(dial_sink1);
}
......
......@@ -11,6 +11,7 @@
#include "chrome/browser/local_discovery/service_discovery_shared_client.h"
#include "chrome/browser/media/router/discovery/mdns/dns_sd_device_lister.h"
#include "chrome/common/features.h"
#include "components/cast_channel/cast_channel_util.h"
using local_discovery::ServiceDiscoveryClient;
using local_discovery::ServiceDiscoverySharedClient;
......@@ -194,7 +195,12 @@ void DnsSdRegistry::ServiceChanged(const std::string& service_type,
VLOG(1) << "ServiceChanged: service_type: " << service_type
<< ", known: " << IsRegistered(service_type)
<< ", service: " << service.service_name << ", added: " << added;
if (!IsRegistered(service_type)) {
if (!IsRegistered(service_type))
return;
net::IPAddress ip_address;
if (!cast_channel::IsValidCastIPAddressString(service.ip_address)) {
VLOG(1) << "Invalid IP address: " << service.ip_address;
return;
}
......@@ -202,9 +208,8 @@ void DnsSdRegistry::ServiceChanged(const std::string& service_type,
service_data_map_[service_type]->UpdateService(added, service);
VLOG(1) << "ServiceChanged: is_updated: " << is_updated;
if (is_updated) {
if (is_updated)
DispatchApiEvent(service_type);
}
}
void DnsSdRegistry::ServiceRemoved(const std::string& service_type,
......@@ -213,9 +218,8 @@ void DnsSdRegistry::ServiceRemoved(const std::string& service_type,
VLOG(1) << "ServiceRemoved: service_type: " << service_type
<< ", known: " << IsRegistered(service_type)
<< ", service: " << service_name;
if (!IsRegistered(service_type)) {
if (!IsRegistered(service_type))
return;
}
bool is_removed =
service_data_map_[service_type]->RemoveService(service_name);
......@@ -229,9 +233,8 @@ void DnsSdRegistry::ServicesFlushed(const std::string& service_type) {
DCHECK(thread_checker_.CalledOnValidThread());
VLOG(1) << "ServicesFlushed: service_type: " << service_type
<< ", known: " << IsRegistered(service_type);
if (!IsRegistered(service_type)) {
if (!IsRegistered(service_type))
return;
}
bool is_cleared = service_data_map_[service_type]->ClearServices();
VLOG(1) << "ServicesFlushed: is_cleared: " << is_cleared;
......
......@@ -8,6 +8,8 @@
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
using testing::_;
namespace media_router {
class MockDnsSdDeviceLister : public DnsSdDeviceLister {
......@@ -152,6 +154,29 @@ TEST_F(DnsSdRegistryTest, AddAndUpdate) {
registry_->GetDelegate()->ServiceChanged(service_type, false, service);
}
TEST_F(DnsSdRegistryTest, AddServiceWithInvalidIPAddress) {
const std::string service_type = "_testing._tcp.local";
const std::string ip_address1 = "invalid";
// |ip_address2| is not a private IP address and is therefore invalid.
const std::string ip_address2 = "111.111.111.111";
DnsSdService service;
service.service_name = "_myDevice." + service_type;
DnsSdRegistry::DnsSdServiceList service_list;
EXPECT_CALL(observer_, OnDnsSdEvent(_, _)).Times(1);
registry_->RegisterDnsSdListener(service_type);
EXPECT_TRUE(testing::Mock::VerifyAndClearExpectations(&observer_));
EXPECT_CALL(observer_, OnDnsSdEvent(_, _)).Times(0);
service.ip_address = ip_address1;
registry_->GetDelegate()->ServiceChanged(service_type, true, service);
service.ip_address = ip_address2;
registry_->GetDelegate()->ServiceChanged(service_type, false, service);
}
// Tests registering a listener and receiving an added and removed event.
TEST_F(DnsSdRegistryTest, AddAndRemove) {
const std::string service_type = "_testing._tcp.local";
......@@ -184,8 +209,8 @@ TEST_F(DnsSdRegistryTest, AddMultipleServices) {
service.ip_address = "192.168.0.100";
DnsSdService service2;
service.service_name = "_myDevice2." + service_type;
service.ip_address = "192.168.0.101";
service2.service_name = "_myDevice2." + service_type;
service2.ip_address = "192.168.0.101";
DnsSdRegistry::DnsSdServiceList service_list;
EXPECT_CALL(observer_, OnDnsSdEvent(service_type, service_list));
......@@ -209,8 +234,8 @@ TEST_F(DnsSdRegistryTest, FlushCache) {
service.ip_address = "192.168.0.100";
DnsSdService service2;
service.service_name = "_myDevice2." + service_type;
service.ip_address = "192.168.0.101";
service2.service_name = "_myDevice2." + service_type;
service2.ip_address = "192.168.0.101";
DnsSdRegistry::DnsSdServiceList service_list;
EXPECT_CALL(observer_, OnDnsSdEvent(service_type, service_list));
......
......@@ -8,6 +8,8 @@ static_library("cast_channel") {
"cast_auth_util.h",
"cast_channel_enum.cc",
"cast_channel_enum.h",
"cast_channel_util.cc",
"cast_channel_util.h",
"cast_framer.cc",
"cast_framer.h",
"cast_message_util.cc",
......
// Copyright 2017 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/cast_channel/cast_channel_util.h"
namespace cast_channel {
bool IsValidCastIPAddress(const net::IPAddress& ip_address) {
// A valid Cast IP address must be private.
return ip_address.IsReserved();
}
bool IsValidCastIPAddressString(const std::string& ip_address_string) {
net::IPAddress ip_address;
return ip_address.AssignFromIPLiteral(ip_address_string) &&
IsValidCastIPAddress(ip_address);
}
} // namespace cast_channel
// Copyright 2017 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef COMPONENTS_CAST_CHANNEL_CAST_CHANNEL_UTIL_H_
#define COMPONENTS_CAST_CHANNEL_CAST_CHANNEL_UTIL_H_
#include "net/base/ip_address.h"
namespace cast_channel {
// Returns true if |ip_address| represents a valid IP address of a Cast device.
bool IsValidCastIPAddress(const net::IPAddress& ip_address);
// Similar to above, but takes a std::string as input.
bool IsValidCastIPAddressString(const std::string& ip_address_string);
} // namespace cast_channel
#endif // COMPONENTS_CAST_CHANNEL_CAST_CHANNEL_UTIL_H_
......@@ -5,6 +5,7 @@
#include "components/cast_channel/cast_socket_service.h"
#include "base/memory/ptr_util.h"
#include "components/cast_channel/cast_channel_util.h"
#include "components/cast_channel/cast_socket.h"
#include "components/cast_channel/logger.h"
#include "content/public/browser/browser_thread.h"
......@@ -75,11 +76,14 @@ CastSocket* CastSocketService::GetSocket(
return it == sockets_.end() ? nullptr : it->second.get();
}
int CastSocketService::OpenSocket(const CastSocketOpenParams& open_params,
CastSocket::OnOpenCallback open_cb) {
void CastSocketService::OpenSocket(const CastSocketOpenParams& open_params,
CastSocket::OnOpenCallback open_cb) {
DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
auto* socket = GetSocket(open_params.ip_endpoint);
const net::IPEndPoint& ip_endpoint = open_params.ip_endpoint;
CHECK(IsValidCastIPAddress(ip_endpoint.address()));
auto* socket = GetSocket(ip_endpoint);
if (!socket) {
// If cast socket does not exist.
if (socket_for_test_) {
......@@ -94,8 +98,6 @@ int CastSocketService::OpenSocket(const CastSocketOpenParams& open_params,
socket->AddObserver(&observer);
socket->Connect(std::move(open_cb));
return socket->id();
}
void CastSocketService::AddObserver(CastSocket::Observer* observer) {
......
......@@ -28,10 +28,6 @@ class CastSocketService {
// Returns a pointer to the Logger member variable.
scoped_refptr<cast_channel::Logger> GetLogger();
// Adds |socket| to |sockets_| and returns raw pointer of |socket|. Takes
// ownership of |socket|.
CastSocket* AddSocket(std::unique_ptr<CastSocket> socket);
// Removes the CastSocket corresponding to |channel_id| from the
// CastSocketRegistry. Returns nullptr if no such CastSocket exists.
std::unique_ptr<CastSocket> RemoveSocket(int channel_id);
......@@ -42,14 +38,15 @@ class CastSocketService {
CastSocket* GetSocket(const net::IPEndPoint& ip_endpoint) const;
// Opens cast socket with |ip_endpoint| and invokes |open_cb| when opening
// Opens cast socket with |open_params| and invokes |open_cb| when opening
// operation finishes. If cast socket with |ip_endpoint| already exists,
// invoke |open_cb| directly with existing socket's channel ID.
// Parameters:
// invoke |open_cb| directly with the existing socket.
// It is the caller's responsibility to ensure |open_params.ip_address| is
// a valid private IP address as determined by |IsValidCastIPAddress()|.
// |open_params|: Parameters necessary to open a Cast channel.
// |open_cb|: OnOpenCallback invoked when cast socket is opened.
virtual int OpenSocket(const CastSocketOpenParams& open_params,
CastSocket::OnOpenCallback open_cb);
virtual void OpenSocket(const CastSocketOpenParams& open_params,
CastSocket::OnOpenCallback open_cb);
// Adds |observer| to socket service. When socket service opens cast socket,
// it passes |observer| to opened socket.
......@@ -71,6 +68,10 @@ class CastSocketService {
CastSocketService();
virtual ~CastSocketService();
// Adds |socket| to |sockets_| and returns raw pointer of |socket|. Takes
// ownership of |socket|.
CastSocket* AddSocket(std::unique_ptr<CastSocket> socket);
// Used to generate CastSocket id.
static int last_channel_id_;
......
......@@ -22,9 +22,7 @@ namespace cast_channel {
class CastSocketServiceTest : public testing::Test {
public:
CastSocketServiceTest()
: thread_bundle_(content::TestBrowserThreadBundle::IO_MAINLOOP),
cast_socket_service_(new CastSocketService()) {}
CastSocketServiceTest() : cast_socket_service_(new CastSocketService()) {}
CastSocket* AddSocket(std::unique_ptr<CastSocket> socket) {
return cast_socket_service_->AddSocket(std::move(socket));
......@@ -33,7 +31,6 @@ class CastSocketServiceTest : public testing::Test {
void TearDown() override { cast_socket_service_ = nullptr; }
protected:
content::TestBrowserThreadBundle thread_bundle_;
std::unique_ptr<CastSocketService> cast_socket_service_;
base::MockCallback<CastSocket::OnOpenCallback> mock_on_open_callback_;
MockCastSocketObserver mock_observer_;
......
......@@ -70,19 +70,18 @@ class MockCastSocketService : public CastSocketService {
MockCastSocketService();
~MockCastSocketService() override;
int OpenSocket(const CastSocketOpenParams& open_params,
CastSocket::OnOpenCallback open_cb) override {
void OpenSocket(const CastSocketOpenParams& open_params,
CastSocket::OnOpenCallback open_cb) override {
// Unit test should not call |open_cb| more than once. Just use
// base::AdaptCallbackForRepeating to pass |open_cb| to a mock method.
return OpenSocketInternal(
open_params.ip_endpoint, open_params.net_log,
base::AdaptCallbackForRepeating(std::move(open_cb)));
OpenSocketInternal(open_params.ip_endpoint, open_params.net_log,
base::AdaptCallbackForRepeating(std::move(open_cb)));
}
MOCK_METHOD3(OpenSocketInternal,
int(const net::IPEndPoint& ip_endpoint,
net::NetLog* net_log,
const base::Callback<void(CastSocket*)>& open_cb));
void(const net::IPEndPoint& ip_endpoint,
net::NetLog* net_log,
const base::Callback<void(CastSocket*)>& open_cb));
MOCK_CONST_METHOD1(GetSocket, CastSocket*(int channel_id));
};
......
......@@ -19,6 +19,7 @@
#include "base/strings/string_number_conversions.h"
#include "base/values.h"
#include "components/cast_channel/cast_channel_enum.h"
#include "components/cast_channel/cast_channel_util.h"
#include "components/cast_channel/cast_message_util.h"
#include "components/cast_channel/cast_socket.h"
#include "components/cast_channel/cast_socket_service.h"
......@@ -110,8 +111,7 @@ bool IsValidConnectInfoPort(const ConnectInfo& connect_info) {
}
bool IsValidConnectInfoIpAddress(const ConnectInfo& connect_info) {
net::IPAddress ip_address;
return ip_address.AssignFromIPLiteral(connect_info.ip_address);
return cast_channel::IsValidCastIPAddressString(connect_info.ip_address);
}
} // namespace
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment