Commit 53df5c3a authored by Derek Cheng's avatar Derek Cheng Committed by Commit Bot

[MediaRouter] Remove DIAL sink if the device is Cast-enabled.

Certain Cast-enabled devices also advertise via SSDP for the purpose of
improving discovey reliability. However, they also (incorrectly) claim
to support DIAL apps, only for the DIAL app launch to fail.

We currently prevent such cases of Cast devices showing up as a
duplicated (incorrect) DIAL sink by relying on a hardcoded list of
model names to determine whether a DIAL sink is "discovery only",
i.e., DIAL sink queries should not be performed on them.

Unfortunately the list is not exhaustive and it would be challenging
if not impossible to obtain and maintain an acccurate list, as seen
from the recent regression. This patch takes a different approach, by
noting that if a Cast sink can be derived from a DIAL sink (or if it
already exists), then the DIAL sink can be removed from
DialMediaSinkServiceImpl as it has no further use by the DIAL MRP.
This ensures that no further DIAL queries will be issued to that sink,
and that the sink will not show up on a sink query.

Change-Id: Icdd3fc35baf6e3898537d38e1b059fc5c2714007
Bug: 869112
Reviewed-on: https://chromium-review.googlesource.com/1155258
Commit-Queue: Derek Cheng <imcheng@chromium.org>
Reviewed-by: default avatarmark a. foltz <mfoltz@chromium.org>
Cr-Commit-Position: refs/heads/master@{#579256}
parent 5a103d4a
...@@ -185,6 +185,16 @@ MediaSink::Id CastMediaSinkServiceImpl::GetCastSinkIdFromDial( ...@@ -185,6 +185,16 @@ MediaSink::Id CastMediaSinkServiceImpl::GetCastSinkIdFromDial(
return "cast:" + dial_sink_id.substr(5); return "cast:" + dial_sink_id.substr(5);
} }
// static
MediaSink::Id CastMediaSinkServiceImpl::GetDialSinkIdFromCast(
const MediaSink::Id& cast_sink_id) {
DCHECK_EQ("cast:", cast_sink_id.substr(0, 5))
<< "unexpected Cast sink id " << cast_sink_id;
// Replace the "cast:" prefix with "dial:".
return "dial:" + cast_sink_id.substr(5);
}
CastMediaSinkServiceImpl::CastMediaSinkServiceImpl( CastMediaSinkServiceImpl::CastMediaSinkServiceImpl(
const OnSinksDiscoveredCallback& callback, const OnSinksDiscoveredCallback& callback,
cast_channel::CastSocketService* cast_socket_service, cast_channel::CastSocketService* cast_socket_service,
...@@ -632,6 +642,13 @@ void CastMediaSinkServiceImpl::OnChannelOpenSucceeded( ...@@ -632,6 +642,13 @@ void CastMediaSinkServiceImpl::OnChannelOpenSucceeded(
if (old_sink_it != sinks.end()) if (old_sink_it != sinks.end())
RemoveSink(old_sink_it->second); RemoveSink(old_sink_it->second);
// Certain classes of Cast sinks support advertising via SSDP but do not
// properly implement the rest of the DIAL protocol. If we successfully open
// a Cast channel to a device that came from DIAL, remove it from
// |dial_media_sink_service_|. This ensures the device shows up as a Cast sink
// only.
dial_media_sink_service_->RemoveSinkById(GetDialSinkIdFromCast(sink_id));
} }
void CastMediaSinkServiceImpl::OnChannelOpenFailed( void CastMediaSinkServiceImpl::OnChannelOpenFailed(
...@@ -674,6 +691,9 @@ void CastMediaSinkServiceImpl::TryConnectDialDiscoveredSink( ...@@ -674,6 +691,9 @@ void CastMediaSinkServiceImpl::TryConnectDialDiscoveredSink(
DVLOG(2) << "Sink discovered by mDNS, skip adding [name]: " DVLOG(2) << "Sink discovered by mDNS, skip adding [name]: "
<< sink.sink().name(); << sink.sink().name();
metrics_.RecordCastSinkDiscoverySource(SinkSource::kMdnsDial); metrics_.RecordCastSinkDiscoverySource(SinkSource::kMdnsDial);
// Sink is a Cast device; remove from |dial_media_sink_service_| to prevent
// duplicates.
dial_media_sink_service_->RemoveSink(dial_sink);
return; return;
} }
......
...@@ -44,8 +44,9 @@ class CastMediaSinkServiceImpl : public MediaSinkServiceBase, ...@@ -44,8 +44,9 @@ class CastMediaSinkServiceImpl : public MediaSinkServiceBase,
// before we can say confidently that it is unlikely to be a Cast device. // before we can say confidently that it is unlikely to be a Cast device.
static constexpr int kMaxDialSinkFailureCount = 10; static constexpr int kMaxDialSinkFailureCount = 10;
// Returns a Cast MediaSink ID from a DIAL MediaSink ID |dial_sink_id|. // Returns a Cast MediaSink ID from a DIAL MediaSink ID, and vice versa.
static MediaSink::Id GetCastSinkIdFromDial(const MediaSink::Id& dial_sink_id); static MediaSink::Id GetCastSinkIdFromDial(const MediaSink::Id& dial_sink_id);
static MediaSink::Id GetDialSinkIdFromCast(const MediaSink::Id& cast_sink_id);
// |callback|: Callback passed to MediaSinkServiceBase. // |callback|: Callback passed to MediaSinkServiceBase.
// |observer|: Observer to invoke on sink updates. Can be nullptr. // |observer|: Observer to invoke on sink updates. Can be nullptr.
......
...@@ -587,11 +587,6 @@ TEST_F(CastMediaSinkServiceImplTest, TestOnChannelErrorNoRetryForMissingSink) { ...@@ -587,11 +587,6 @@ TEST_F(CastMediaSinkServiceImplTest, TestOnChannelErrorNoRetryForMissingSink) {
EXPECT_CALL(socket, ready_state()) EXPECT_CALL(socket, ready_state())
.WillOnce(Return(cast_channel::ReadyState::CLOSED)); .WillOnce(Return(cast_channel::ReadyState::CLOSED));
// There is no existing cast sink.
/* XXX
media_sink_service_impl_.pending_for_open_ip_endpoints_.clear();
media_sink_service_impl_.current_sinks_.clear();
*/
media_sink_service_impl_.OnError( media_sink_service_impl_.OnError(
socket, cast_channel::ChannelError::CHANNEL_NOT_OPEN); socket, cast_channel::ChannelError::CHANNEL_NOT_OPEN);
...@@ -602,6 +597,10 @@ TEST_F(CastMediaSinkServiceImplTest, TestOnChannelErrorNoRetryForMissingSink) { ...@@ -602,6 +597,10 @@ TEST_F(CastMediaSinkServiceImplTest, TestOnChannelErrorNoRetryForMissingSink) {
} }
TEST_F(CastMediaSinkServiceImplTest, TestOnSinkAddedOrUpdated) { TEST_F(CastMediaSinkServiceImplTest, TestOnSinkAddedOrUpdated) {
// Make sure |media_sink_service_impl_| adds itself as an observer to
// |dial_media_sink_service_|.
media_sink_service_impl_.Start();
MediaSinkInternal dial_sink1 = CreateDialSink(1); MediaSinkInternal dial_sink1 = CreateDialSink(1);
MediaSinkInternal dial_sink2 = CreateDialSink(2); MediaSinkInternal dial_sink2 = CreateDialSink(2);
net::IPEndPoint ip_endpoint1(dial_sink1.dial_data().ip_address, net::IPEndPoint ip_endpoint1(dial_sink1.dial_data().ip_address,
...@@ -618,22 +617,22 @@ TEST_F(CastMediaSinkServiceImplTest, TestOnSinkAddedOrUpdated) { ...@@ -618,22 +617,22 @@ TEST_F(CastMediaSinkServiceImplTest, TestOnSinkAddedOrUpdated) {
// Channel 1, 2 opened. // Channel 1, 2 opened.
EXPECT_CALL(*mock_cast_socket_service_, EXPECT_CALL(*mock_cast_socket_service_,
OpenSocketInternal(ip_endpoint1, _, _)) OpenSocketInternal(ip_endpoint1, _, _))
.WillOnce(WithArgs<2>(Invoke( .WillOnce(WithArgs<2>(
[&](const base::Callback<void(cast_channel::CastSocket * socket)>& [&socket1](
callback) { std::move(callback).Run(&socket1); }))); const base::Callback<void(cast_channel::CastSocket * socket)>&
callback) { callback.Run(&socket1); }));
EXPECT_CALL(*mock_cast_socket_service_, EXPECT_CALL(*mock_cast_socket_service_,
OpenSocketInternal(ip_endpoint2, _, _)) OpenSocketInternal(ip_endpoint2, _, _))
.WillOnce(WithArgs<2>(Invoke( .WillOnce(WithArgs<2>(
[&](const base::Callback<void(cast_channel::CastSocket * socket)>& [&socket2](
callback) { std::move(callback).Run(&socket2); }))); const base::Callback<void(cast_channel::CastSocket * socket)>&
callback) { callback.Run(&socket2); }));
// Invoke CastSocketService::OpenSocket on the IO thread.
media_sink_service_impl_.OnSinkAddedOrUpdated(dial_sink1);
base::RunLoop().RunUntilIdle();
// Invoke CastSocketService::OpenSocket on the IO thread. // Add DIAL sinks to |dial_media_sink_service_|, which in turn notifies
media_sink_service_impl_.OnSinkAddedOrUpdated(dial_sink2); // |media_sink_service_impl_| via the Observer interface.
base::RunLoop().RunUntilIdle(); dial_media_sink_service_.AddOrUpdateSink(dial_sink1);
dial_media_sink_service_.AddOrUpdateSink(dial_sink2);
EXPECT_TRUE(dial_media_sink_service_.timer()->IsRunning());
// Verify sink content. // Verify sink content.
const auto& sinks = media_sink_service_impl_.GetSinks(); const auto& sinks = media_sink_service_impl_.GetSinks();
...@@ -648,6 +647,9 @@ TEST_F(CastMediaSinkServiceImplTest, TestOnSinkAddedOrUpdated) { ...@@ -648,6 +647,9 @@ TEST_F(CastMediaSinkServiceImplTest, TestOnSinkAddedOrUpdated) {
CastMediaSinkServiceImpl::GetCastSinkIdFromDial(dial_sink2.sink().id())); CastMediaSinkServiceImpl::GetCastSinkIdFromDial(dial_sink2.sink().id()));
ASSERT_TRUE(sink); ASSERT_TRUE(sink);
EXPECT_EQ(SinkIconType::CAST_AUDIO, sink->sink().icon_type()); EXPECT_EQ(SinkIconType::CAST_AUDIO, sink->sink().icon_type());
// The sinks are removed from |dial_media_sink_service_|.
EXPECT_TRUE(dial_media_sink_service_.GetSinks().empty());
} }
TEST_F(CastMediaSinkServiceImplTest, TEST_F(CastMediaSinkServiceImplTest,
......
...@@ -25,6 +25,10 @@ MediaSinkInternal::MediaSinkInternal(const MediaSinkInternal& other) { ...@@ -25,6 +25,10 @@ MediaSinkInternal::MediaSinkInternal(const MediaSinkInternal& other) {
InternalCopyConstructFrom(other); InternalCopyConstructFrom(other);
} }
MediaSinkInternal::MediaSinkInternal(MediaSinkInternal&& other) noexcept {
InternalMoveConstructFrom(std::move(other));
}
MediaSinkInternal::~MediaSinkInternal() { MediaSinkInternal::~MediaSinkInternal() {
InternalCleanup(); InternalCleanup();
} }
...@@ -38,6 +42,15 @@ MediaSinkInternal& MediaSinkInternal::operator=( ...@@ -38,6 +42,15 @@ MediaSinkInternal& MediaSinkInternal::operator=(
return *this; return *this;
} }
MediaSinkInternal& MediaSinkInternal::operator=(
MediaSinkInternal&& other) noexcept {
if (this != &other) {
InternalCleanup();
InternalMoveConstructFrom(std::move(other));
}
return *this;
}
bool MediaSinkInternal::operator==(const MediaSinkInternal& other) const { bool MediaSinkInternal::operator==(const MediaSinkInternal& other) const {
if (sink_type_ != other.sink_type_) if (sink_type_ != other.sink_type_)
return false; return false;
...@@ -142,6 +155,23 @@ void MediaSinkInternal::InternalCopyConstructFrom( ...@@ -142,6 +155,23 @@ void MediaSinkInternal::InternalCopyConstructFrom(
NOTREACHED(); NOTREACHED();
} }
void MediaSinkInternal::InternalMoveConstructFrom(MediaSinkInternal&& other) {
sink_ = std::move(other.sink_);
sink_type_ = other.sink_type_;
switch (sink_type_) {
case SinkType::DIAL:
new (&dial_data_) DialSinkExtraData(std::move(other.dial_data_));
return;
case SinkType::CAST:
new (&cast_data_) CastSinkExtraData(std::move(other.cast_data_));
return;
case SinkType::GENERIC:
return;
}
NOTREACHED();
}
void MediaSinkInternal::InternalCleanup() { void MediaSinkInternal::InternalCleanup() {
switch (sink_type_) { switch (sink_type_) {
case SinkType::DIAL: case SinkType::DIAL:
...@@ -158,6 +188,8 @@ void MediaSinkInternal::InternalCleanup() { ...@@ -158,6 +188,8 @@ void MediaSinkInternal::InternalCleanup() {
DialSinkExtraData::DialSinkExtraData() = default; DialSinkExtraData::DialSinkExtraData() = default;
DialSinkExtraData::DialSinkExtraData(const DialSinkExtraData& other) = default; DialSinkExtraData::DialSinkExtraData(const DialSinkExtraData& other) = default;
DialSinkExtraData::DialSinkExtraData(DialSinkExtraData&& other) noexcept =
default;
DialSinkExtraData::~DialSinkExtraData() = default; DialSinkExtraData::~DialSinkExtraData() = default;
bool DialSinkExtraData::operator==(const DialSinkExtraData& other) const { bool DialSinkExtraData::operator==(const DialSinkExtraData& other) const {
...@@ -167,6 +199,8 @@ bool DialSinkExtraData::operator==(const DialSinkExtraData& other) const { ...@@ -167,6 +199,8 @@ bool DialSinkExtraData::operator==(const DialSinkExtraData& other) const {
CastSinkExtraData::CastSinkExtraData() = default; CastSinkExtraData::CastSinkExtraData() = default;
CastSinkExtraData::CastSinkExtraData(const CastSinkExtraData& other) = default; CastSinkExtraData::CastSinkExtraData(const CastSinkExtraData& other) = default;
CastSinkExtraData::CastSinkExtraData(CastSinkExtraData&& other) noexcept =
default;
CastSinkExtraData::~CastSinkExtraData() = default; CastSinkExtraData::~CastSinkExtraData() = default;
bool CastSinkExtraData::operator==(const CastSinkExtraData& other) const { bool CastSinkExtraData::operator==(const CastSinkExtraData& other) const {
......
...@@ -26,6 +26,7 @@ struct DialSinkExtraData { ...@@ -26,6 +26,7 @@ struct DialSinkExtraData {
DialSinkExtraData(); DialSinkExtraData();
DialSinkExtraData(const DialSinkExtraData& other); DialSinkExtraData(const DialSinkExtraData& other);
DialSinkExtraData(DialSinkExtraData&& other) noexcept;
~DialSinkExtraData(); ~DialSinkExtraData();
bool operator==(const DialSinkExtraData& other) const; bool operator==(const DialSinkExtraData& other) const;
...@@ -54,6 +55,7 @@ struct CastSinkExtraData { ...@@ -54,6 +55,7 @@ struct CastSinkExtraData {
CastSinkExtraData(); CastSinkExtraData();
CastSinkExtraData(const CastSinkExtraData& other); CastSinkExtraData(const CastSinkExtraData& other);
CastSinkExtraData(CastSinkExtraData&& other) noexcept;
~CastSinkExtraData(); ~CastSinkExtraData();
bool operator==(const CastSinkExtraData& other) const; bool operator==(const CastSinkExtraData& other) const;
...@@ -73,10 +75,12 @@ class MediaSinkInternal { ...@@ -73,10 +75,12 @@ class MediaSinkInternal {
// Used to push instance of this class into vector. // Used to push instance of this class into vector.
MediaSinkInternal(const MediaSinkInternal& other); MediaSinkInternal(const MediaSinkInternal& other);
MediaSinkInternal(MediaSinkInternal&& other) noexcept;
~MediaSinkInternal(); ~MediaSinkInternal();
MediaSinkInternal& operator=(const MediaSinkInternal& other); MediaSinkInternal& operator=(const MediaSinkInternal& other);
MediaSinkInternal& operator=(MediaSinkInternal&& other) noexcept;
bool operator==(const MediaSinkInternal& other) const; bool operator==(const MediaSinkInternal& other) const;
bool operator!=(const MediaSinkInternal& other) const; bool operator!=(const MediaSinkInternal& other) const;
// Sorted by sink id. // Sorted by sink id.
...@@ -109,6 +113,7 @@ class MediaSinkInternal { ...@@ -109,6 +113,7 @@ class MediaSinkInternal {
private: private:
void InternalCopyConstructFrom(const MediaSinkInternal& other); void InternalCopyConstructFrom(const MediaSinkInternal& other);
void InternalMoveConstructFrom(MediaSinkInternal&& other);
void InternalCleanup(); void InternalCleanup();
enum class SinkType { GENERIC, DIAL, CAST }; enum class SinkType { GENERIC, DIAL, CAST };
......
...@@ -72,6 +72,21 @@ void MediaSinkServiceBase::RemoveSink(const MediaSinkInternal& sink) { ...@@ -72,6 +72,21 @@ void MediaSinkServiceBase::RemoveSink(const MediaSinkInternal& sink) {
StartTimer(); StartTimer();
} }
void MediaSinkServiceBase::RemoveSinkById(const MediaSink::Id& sink_id) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
auto it = sinks_.find(sink_id);
if (it == sinks_.end())
return;
// Make a copy of the sink to avoid potential use-after-free.
MediaSinkInternal sink = it->second;
sinks_.erase(it);
for (auto& observer : observers_)
observer.OnSinkRemoved(sink);
StartTimer();
}
void MediaSinkServiceBase::SetTimerForTest( void MediaSinkServiceBase::SetTimerForTest(
std::unique_ptr<base::OneShotTimer> timer) { std::unique_ptr<base::OneShotTimer> timer) {
discovery_timer_ = std::move(timer); discovery_timer_ = std::move(timer);
......
...@@ -63,6 +63,7 @@ class MediaSinkServiceBase { ...@@ -63,6 +63,7 @@ class MediaSinkServiceBase {
// Also invokes |StartTimer()|. // Also invokes |StartTimer()|.
void AddOrUpdateSink(const MediaSinkInternal& sink); void AddOrUpdateSink(const MediaSinkInternal& sink);
void RemoveSink(const MediaSinkInternal& sink); void RemoveSink(const MediaSinkInternal& sink);
void RemoveSinkById(const MediaSink::Id& sink_id);
const base::flat_map<MediaSink::Id, MediaSinkInternal>& GetSinks() const; const base::flat_map<MediaSink::Id, MediaSinkInternal>& GetSinks() const;
const MediaSinkInternal* GetSinkById(const MediaSink::Id& sink_id) const; const MediaSinkInternal* GetSinkById(const MediaSink::Id& sink_id) const;
......
...@@ -19,10 +19,12 @@ MediaSink::MediaSink(const MediaSink::Id& sink_id, ...@@ -19,10 +19,12 @@ MediaSink::MediaSink(const MediaSink::Id& sink_id,
provider_id_(provider_id) {} provider_id_(provider_id) {}
MediaSink::MediaSink(const MediaSink& other) = default; MediaSink::MediaSink(const MediaSink& other) = default;
MediaSink::MediaSink(MediaSink&& other) noexcept = default;
MediaSink::MediaSink() = default;
MediaSink::~MediaSink() = default;
MediaSink::MediaSink() {} MediaSink& MediaSink::operator=(const MediaSink& other) = default;
MediaSink& MediaSink::operator=(MediaSink&& other) noexcept = default;
MediaSink::~MediaSink() {}
bool MediaSink::Equals(const MediaSink& other) const { bool MediaSink::Equals(const MediaSink& other) const {
return sink_id_ == other.sink_id_; return sink_id_ == other.sink_id_;
......
...@@ -49,10 +49,13 @@ class MediaSink { ...@@ -49,10 +49,13 @@ class MediaSink {
SinkIconType icon_type, SinkIconType icon_type,
MediaRouteProviderId provider_id = MediaRouteProviderId::UNKNOWN); MediaRouteProviderId provider_id = MediaRouteProviderId::UNKNOWN);
MediaSink(const MediaSink& other); MediaSink(const MediaSink& other);
MediaSink(MediaSink&& other) noexcept;
MediaSink(); MediaSink();
~MediaSink(); ~MediaSink();
MediaSink& operator=(const MediaSink& other);
MediaSink& operator=(MediaSink&& other) noexcept;
void set_sink_id(const MediaSink::Id& sink_id) { sink_id_ = sink_id; } void set_sink_id(const MediaSink::Id& sink_id) { sink_id_ = sink_id; }
const MediaSink::Id& id() const { return sink_id_; } const MediaSink::Id& id() const { return sink_id_; }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment