Commit 698d3431 authored by John Williams's avatar John Williams Committed by Commit Bot

[Cast MRP] Refactored to streamline unit tests.

This patch as a number of changes to streamline unit tests; the main
on is that it names ActivityRecord::SetOrUpdateSession() a nonvirtual
method and delegates subclass-specific functionality to the new
OnSessionSet() and OnSessionUpdated() methods.

Change-Id: I3ce0c0ced968d3e3945ecd215dfe9a25959a49a1
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2278138
Commit-Queue: John Williams <jrw@chromium.org>
Reviewed-by: default avatarTakumi Fujimoto <takumif@chromium.org>
Cr-Commit-Position: refs/heads/master@{#786504}
parent 8dfe34d4
...@@ -73,10 +73,10 @@ void ActivityRecord::SetOrUpdateSession(const CastSession& session, ...@@ -73,10 +73,10 @@ void ActivityRecord::SetOrUpdateSession(const CastSession& session,
sink_ = sink; sink_ = sink;
if (session_id_) { if (session_id_) {
DCHECK_EQ(*session_id_, session.session_id()); DCHECK_EQ(*session_id_, session.session_id());
OnSessionUpdated(session, hash_token);
} else { } else {
session_id_ = session.session_id(); session_id_ = session.session_id();
if (on_session_set_) OnSessionSet(session);
std::move(on_session_set_).Run();
} }
} }
...@@ -175,6 +175,9 @@ void ActivityRecord::HandleLeaveSession(const std::string& client_id) { ...@@ -175,6 +175,9 @@ void ActivityRecord::HandleLeaveSession(const std::string& client_id) {
} }
} }
void ActivityRecord::OnSessionUpdated(const CastSession& session,
const std::string& hash_token) {}
CastSessionClientFactoryForTest* ActivityRecord::client_factory_for_test_ = CastSessionClientFactoryForTest* ActivityRecord::client_factory_for_test_ =
nullptr; nullptr;
......
...@@ -74,9 +74,9 @@ class ActivityRecord { ...@@ -74,9 +74,9 @@ class ActivityRecord {
// //
// The |hash_token| parameter is used for hashing receiver IDs in messages // The |hash_token| parameter is used for hashing receiver IDs in messages
// sent to the Cast SDK, and |sink| is the sink associated with |session|. // sent to the Cast SDK, and |sink| is the sink associated with |session|.
virtual void SetOrUpdateSession(const CastSession& session, void SetOrUpdateSession(const CastSession& session,
const MediaSinkInternal& sink, const MediaSinkInternal& sink,
const std::string& hash_token); const std::string& hash_token);
virtual void SendStopSessionMessageToClients(const std::string& hash_token); virtual void SendStopSessionMessageToClients(const std::string& hash_token);
...@@ -144,20 +144,25 @@ class ActivityRecord { ...@@ -144,20 +144,25 @@ class ActivityRecord {
client_factory_for_test_ = factory; client_factory_for_test_ = factory;
} }
void SetSessionAndSinkForTest(const CastSession& session, void SetSessionIdForTest(const std::string& session_id) {
const MediaSinkInternal& sink, session_id_ = session_id;
const std::string& hash_code) {
session_id_ = session.session_id();
sink_ = sink;
} }
protected: protected:
using ClientMap = using ClientMap =
base::flat_map<std::string, std::unique_ptr<CastSessionClient>>; base::flat_map<std::string, std::unique_ptr<CastSessionClient>>;
// Gets the session associated with this activity. May return nullptr. // Gets the session based on its ID. May return null.
CastSession* GetSession() const; CastSession* GetSession() const;
// Called after the session has been set by SetOrUpdateSession. The |session|
// parameters are somewhat redundant because the same information is available
// using the GetSession() method, but passing the parameter avoids some
// unnecessary lookups and eliminates the need to a null check.
virtual void OnSessionSet(const CastSession& session) = 0;
virtual void OnSessionUpdated(const CastSession& session,
const std::string& hash_token);
CastSessionClient* GetClient(const std::string& client_id) { CastSessionClient* GetClient(const std::string& client_id) {
auto it = connected_clients_.find(client_id); auto it = connected_clients_.find(client_id);
return it == connected_clients_.end() ? nullptr : it->second.get(); return it == connected_clients_.end() ? nullptr : it->second.get();
...@@ -169,9 +174,6 @@ class ActivityRecord { ...@@ -169,9 +174,6 @@ class ActivityRecord {
std::string app_id_; std::string app_id_;
base::Optional<int> mirroring_tab_id_; base::Optional<int> mirroring_tab_id_;
// Called when a session is initially set from SetOrUpdateSession().
base::OnceCallback<void()> on_session_set_;
// TODO(https://crbug.com/809249): Consider wrapping CastMessageHandler with // TODO(https://crbug.com/809249): Consider wrapping CastMessageHandler with
// known parameters (sink, client ID, session transport ID) and passing them // known parameters (sink, client ID, session transport ID) and passing them
// to objects that need to send messages to the receiver. // to objects that need to send messages to the receiver.
......
...@@ -161,8 +161,8 @@ void CastActivityManager::LaunchSessionParsed( ...@@ -161,8 +161,8 @@ void CastActivityManager::LaunchSessionParsed(
activity_it->second->route().media_route_id(); activity_it->second->route().media_route_id();
// We cannot launch the new session in the TerminateSession() callback // We cannot launch the new session in the TerminateSession() callback
// because if we create a session there, then it may get deleted when // because if we create a session there, then it may get deleted when
// OnSessionRemoved() is called to notify that the previous session was // OnSessionRemoved() is called to notify that the previous session
// removed on the receiver. // was removed on the receiver.
TerminateSession(existing_route_id, base::DoNothing()); TerminateSession(existing_route_id, base::DoNothing());
// The new session will be launched when OnSessionRemoved() is called for // The new session will be launched when OnSessionRemoved() is called for
// the old session. // the old session.
...@@ -487,8 +487,7 @@ ActivityRecord* CastActivityManager::AddMirroringActivityRecord( ...@@ -487,8 +487,7 @@ ActivityRecord* CastActivityManager::AddMirroringActivityRecord(
: std::make_unique<MirroringActivityRecord>( : std::make_unique<MirroringActivityRecord>(
route, app_id, message_handler_, session_tracker_, tab_id, route, app_id, message_handler_, session_tracker_, tab_id,
cast_data, std::move(on_stop)); cast_data, std::move(on_stop));
if (route.is_local()) activity->CreateMojoBindings(media_router_);
activity->CreateMojoBindings(media_router_);
auto* const activity_ptr = activity.get(); auto* const activity_ptr = activity.get();
activities_.emplace(route.media_route_id(), std::move(activity)); activities_.emplace(route.media_route_id(), std::move(activity));
return activity_ptr; return activity_ptr;
......
...@@ -29,15 +29,16 @@ CastActivityRecord::CastActivityRecord( ...@@ -29,15 +29,16 @@ CastActivityRecord::CastActivityRecord(
CastActivityRecord::~CastActivityRecord() = default; CastActivityRecord::~CastActivityRecord() = default;
void CastActivityRecord::SetOrUpdateSession(const CastSession& session, void CastActivityRecord::OnSessionSet(const CastSession& session) {
const MediaSinkInternal& sink, if (media_controller_)
const std::string& hash_token) { media_controller_->SetSession(session);
bool had_session_id = session_id_.has_value(); }
ActivityRecord::SetOrUpdateSession(session, sink, hash_token);
if (had_session_id) { void CastActivityRecord::OnSessionUpdated(const CastSession& session,
for (auto& client : connected_clients_) const std::string& hash_token) {
client.second->SendMessageToClient( for (auto& client : connected_clients_) {
CreateUpdateSessionMessage(session, client.first, sink, hash_token)); client.second->SendMessageToClient(
CreateUpdateSessionMessage(session, client.first, sink_, hash_token));
} }
if (media_controller_) if (media_controller_)
media_controller_->SetSession(session); media_controller_->SetSession(session);
......
...@@ -40,10 +40,6 @@ class CastActivityRecord : public ActivityRecord { ...@@ -40,10 +40,6 @@ class CastActivityRecord : public ActivityRecord {
CastSessionTracker* session_tracker); CastSessionTracker* session_tracker);
~CastActivityRecord() override; ~CastActivityRecord() override;
// ActivityRecord implementation
void SetOrUpdateSession(const CastSession& session,
const MediaSinkInternal& sink,
const std::string& hash_token) override;
void SendMediaStatusToClients(const base::Value& media_status, void SendMediaStatusToClients(const base::Value& media_status,
base::Optional<int> request_id) override; base::Optional<int> request_id) override;
void OnAppMessage(const cast::channel::CastMessage& message) override; void OnAppMessage(const cast::channel::CastMessage& message) override;
...@@ -64,6 +60,9 @@ class CastActivityRecord : public ActivityRecord { ...@@ -64,6 +60,9 @@ class CastActivityRecord : public ActivityRecord {
bool HasJoinableClient(AutoJoinPolicy policy, bool HasJoinableClient(AutoJoinPolicy policy,
const url::Origin& origin, const url::Origin& origin,
int tab_id) const; int tab_id) const;
void OnSessionSet(const CastSession& session) override;
void OnSessionUpdated(const CastSession& session,
const std::string& hash_token) override;
private: private:
friend class CastSessionClientImpl; friend class CastSessionClientImpl;
......
...@@ -71,24 +71,21 @@ class MockPresentationConnection : public blink::mojom::PresentationConnection { ...@@ -71,24 +71,21 @@ class MockPresentationConnection : public blink::mojom::PresentationConnection {
} // namespace } // namespace
#define EXPECT_ERROR_LOG(matcher) \
if (DLOG_IS_ON(ERROR)) { \
EXPECT_CALL(log_, Log(logging::LOG_ERROR, _, _, _, matcher)) \
.WillOnce(Return(true)); /* suppress logging */ \
}
class CastSessionClientImplTest : public testing::Test { class CastSessionClientImplTest : public testing::Test {
public: public:
CastSessionClientImplTest() { activity_.set_session_id("theSessionId"); } CastSessionClientImplTest() { activity_.SetSessionIdForTest("theSessionId"); }
~CastSessionClientImplTest() override { RunUntilIdle(); } ~CastSessionClientImplTest() override { RunUntilIdle(); }
protected: protected:
void RunUntilIdle() { task_environment_.RunUntilIdle(); } void RunUntilIdle() { task_environment_.RunUntilIdle(); }
template <typename T>
void ExpectErrorLog(const T& matcher) {
if (DLOG_IS_ON(ERROR)) {
EXPECT_CALL(log_, Log(logging::LOG_ERROR, _, _, _,
matcher))
.WillOnce(Return(true)); // suppress logging
}
}
content::BrowserTaskEnvironment task_environment_; content::BrowserTaskEnvironment task_environment_;
data_decoder::test::InProcessDataDecoder in_process_data_decoder_; data_decoder::test::InProcessDataDecoder in_process_data_decoder_;
cast_channel::MockCastSocketService socket_service_{ cast_channel::MockCastSocketService socket_service_{
...@@ -111,7 +108,7 @@ class CastSessionClientImplTest : public testing::Test { ...@@ -111,7 +108,7 @@ class CastSessionClientImplTest : public testing::Test {
TEST_F(CastSessionClientImplTest, OnInvalidJson) { TEST_F(CastSessionClientImplTest, OnInvalidJson) {
// TODO(crbug.com/905002): Check UMA calls instead of logging (here and // TODO(crbug.com/905002): Check UMA calls instead of logging (here and
// below). // below).
ExpectErrorLog(HasSubstr("Failed to parse Cast client message")); EXPECT_ERROR_LOG(HasSubstr("Failed to parse Cast client message"));
log_.StartCapturingLogs(); log_.StartCapturingLogs();
client_->OnMessage( client_->OnMessage(
...@@ -119,8 +116,8 @@ TEST_F(CastSessionClientImplTest, OnInvalidJson) { ...@@ -119,8 +116,8 @@ TEST_F(CastSessionClientImplTest, OnInvalidJson) {
} }
TEST_F(CastSessionClientImplTest, OnInvalidMessage) { TEST_F(CastSessionClientImplTest, OnInvalidMessage) {
ExpectErrorLog(AllOf(HasSubstr("Failed to parse Cast client message"), EXPECT_ERROR_LOG(AllOf(HasSubstr("Failed to parse Cast client message"),
HasSubstr("Not a Cast message"))); HasSubstr("Not a Cast message")));
log_.StartCapturingLogs(); log_.StartCapturingLogs();
client_->OnMessage( client_->OnMessage(
...@@ -128,9 +125,9 @@ TEST_F(CastSessionClientImplTest, OnInvalidMessage) { ...@@ -128,9 +125,9 @@ TEST_F(CastSessionClientImplTest, OnInvalidMessage) {
} }
TEST_F(CastSessionClientImplTest, OnMessageWrongClientId) { TEST_F(CastSessionClientImplTest, OnMessageWrongClientId) {
ExpectErrorLog(AllOf(HasSubstr("Client ID mismatch"), EXPECT_ERROR_LOG(AllOf(HasSubstr("Client ID mismatch"),
HasSubstr("theClientId"), HasSubstr("theClientId"),
HasSubstr("theWrongClientId"))); HasSubstr("theWrongClientId")));
log_.StartCapturingLogs(); log_.StartCapturingLogs();
client_->OnMessage( client_->OnMessage(
...@@ -145,9 +142,9 @@ TEST_F(CastSessionClientImplTest, OnMessageWrongClientId) { ...@@ -145,9 +142,9 @@ TEST_F(CastSessionClientImplTest, OnMessageWrongClientId) {
} }
TEST_F(CastSessionClientImplTest, OnMessageWrongSessionId) { TEST_F(CastSessionClientImplTest, OnMessageWrongSessionId) {
ExpectErrorLog(AllOf(HasSubstr("Session ID mismatch"), EXPECT_ERROR_LOG(AllOf(HasSubstr("Session ID mismatch"),
HasSubstr("theSessionId"), HasSubstr("theSessionId"),
HasSubstr("theWrongSessionId"))); HasSubstr("theWrongSessionId")));
log_.StartCapturingLogs(); log_.StartCapturingLogs();
client_->OnMessage( client_->OnMessage(
......
...@@ -142,7 +142,8 @@ MirroringActivityRecord::~MirroringActivityRecord() { ...@@ -142,7 +142,8 @@ MirroringActivityRecord::~MirroringActivityRecord() {
void MirroringActivityRecord::CreateMojoBindings( void MirroringActivityRecord::CreateMojoBindings(
mojom::MediaRouter* media_router) { mojom::MediaRouter* media_router) {
DCHECK(mirroring_type_); if (!mirroring_type_)
return;
// Get a reference to the mirroring service host. // Get a reference to the mirroring service host.
switch (*mirroring_type_) { switch (*mirroring_type_) {
...@@ -165,28 +166,9 @@ void MirroringActivityRecord::CreateMojoBindings( ...@@ -165,28 +166,9 @@ void MirroringActivityRecord::CreateMojoBindings(
break; break;
} }
auto cast_source = CastMediaSource::FromMediaSource(route_.media_source()); DCHECK(!channel_to_service_receiver_);
DCHECK(cast_source); channel_to_service_receiver_ =
channel_to_service_.BindNewPipeAndPassReceiver();
// Derive session type from capabilities and media source.
const bool has_audio = (cast_data_.capabilities &
static_cast<uint8_t>(cast_channel::AUDIO_OUT)) != 0 &&
cast_source->allow_audio_capture();
const bool has_video = (cast_data_.capabilities &
static_cast<uint8_t>(cast_channel::VIDEO_OUT)) != 0;
DCHECK(has_audio || has_video);
const SessionType session_type =
has_audio && has_video
? SessionType::AUDIO_AND_VIDEO
: has_audio ? SessionType::AUDIO_ONLY : SessionType::VIDEO_ONLY;
// Arrange to start mirroring once the session is set.
on_session_set_ = base::BindOnce(
&MirroringActivityRecord::StartMirroring, base::Unretained(this),
SessionParameters::New(session_type, cast_data_.ip_endpoint.address(),
cast_data_.model_name,
cast_source->target_playout_delay()),
channel_to_service_.BindNewPipeAndPassReceiver());
} }
void MirroringActivityRecord::OnError(SessionError error) { void MirroringActivityRecord::OnError(SessionError error) {
...@@ -293,9 +275,25 @@ void MirroringActivityRecord::HandleParseJsonResult( ...@@ -293,9 +275,25 @@ void MirroringActivityRecord::HandleParseJsonResult(
message_handler_->SendCastMessage(cast_data_.cast_channel_id, cast_message); message_handler_->SendCastMessage(cast_data_.cast_channel_id, cast_message);
} }
void MirroringActivityRecord::StartMirroring( void MirroringActivityRecord::OnSessionSet(const CastSession& session) {
mirroring::mojom::SessionParametersPtr session_params, if (!mirroring_type_)
mojo::PendingReceiver<CastMessageChannel> channel_to_service) { return;
auto cast_source = CastMediaSource::FromMediaSource(route_.media_source());
DCHECK(cast_source);
// Derive session type from capabilities and media source.
const bool has_audio = (cast_data_.capabilities &
static_cast<uint8_t>(cast_channel::AUDIO_OUT)) != 0 &&
cast_source->allow_audio_capture();
const bool has_video = (cast_data_.capabilities &
static_cast<uint8_t>(cast_channel::VIDEO_OUT)) != 0;
DCHECK(has_audio || has_video);
const SessionType session_type =
has_audio && has_video
? SessionType::AUDIO_AND_VIDEO
: has_audio ? SessionType::AUDIO_ONLY : SessionType::VIDEO_ONLY;
will_start_mirroring_timestamp_ = base::Time::Now(); will_start_mirroring_timestamp_ = base::Time::Now();
// Bind Mojo receivers for the interfaces this object implements. // Bind Mojo receivers for the interfaces this object implements.
...@@ -304,8 +302,15 @@ void MirroringActivityRecord::StartMirroring( ...@@ -304,8 +302,15 @@ void MirroringActivityRecord::StartMirroring(
mojo::PendingRemote<mirroring::mojom::CastMessageChannel> channel_remote; mojo::PendingRemote<mirroring::mojom::CastMessageChannel> channel_remote;
channel_receiver_.Bind(channel_remote.InitWithNewPipeAndPassReceiver()); channel_receiver_.Bind(channel_remote.InitWithNewPipeAndPassReceiver());
host_->Start(std::move(session_params), std::move(observer_remote), // If this fails, it's probably because CreateMojoBindings() hasn't been
std::move(channel_remote), std::move(channel_to_service)); // called.
DCHECK(channel_to_service_receiver_);
host_->Start(SessionParameters::New(
session_type, cast_data_.ip_endpoint.address(),
cast_data_.model_name, cast_source->target_playout_delay()),
std::move(observer_remote), std::move(channel_remote),
std::move(channel_to_service_receiver_));
} }
void MirroringActivityRecord::StopMirroring() { void MirroringActivityRecord::StopMirroring() {
......
...@@ -66,6 +66,7 @@ class MirroringActivityRecord : public ActivityRecord, ...@@ -66,6 +66,7 @@ class MirroringActivityRecord : public ActivityRecord,
void OnInternalMessage(const cast_channel::InternalMessage& message) override; void OnInternalMessage(const cast_channel::InternalMessage& message) override;
protected: protected:
void OnSessionSet(const CastSession& session) override;
void CreateMediaController( void CreateMediaController(
mojo::PendingReceiver<mojom::MediaController> media_controller, mojo::PendingReceiver<mojom::MediaController> media_controller,
mojo::PendingRemote<mojom::MediaStatusObserver> observer) override; mojo::PendingRemote<mojom::MediaStatusObserver> observer) override;
...@@ -74,9 +75,6 @@ class MirroringActivityRecord : public ActivityRecord, ...@@ -74,9 +75,6 @@ class MirroringActivityRecord : public ActivityRecord,
void HandleParseJsonResult(const std::string& route_id, void HandleParseJsonResult(const std::string& route_id,
data_decoder::DataDecoder::ValueOrError result); data_decoder::DataDecoder::ValueOrError result);
void StartMirroring(
mirroring::mojom::SessionParametersPtr session_params,
mojo::PendingReceiver<CastMessageChannel> channel_to_service);
void StopMirroring(); void StopMirroring();
mojo::Remote<mirroring::mojom::MirroringServiceHost> host_; mojo::Remote<mirroring::mojom::MirroringServiceHost> host_;
...@@ -84,6 +82,11 @@ class MirroringActivityRecord : public ActivityRecord, ...@@ -84,6 +82,11 @@ class MirroringActivityRecord : public ActivityRecord,
// Sends Cast messages from the mirroring receiver to the mirroring service. // Sends Cast messages from the mirroring receiver to the mirroring service.
mojo::Remote<mirroring::mojom::CastMessageChannel> channel_to_service_; mojo::Remote<mirroring::mojom::CastMessageChannel> channel_to_service_;
// Only used to store pending CastMessageChannel receiver while waiting for
// OnSessionSet() to be called.
mojo::PendingReceiver<mirroring::mojom::CastMessageChannel>
channel_to_service_receiver_;
mojo::Receiver<mirroring::mojom::SessionObserver> observer_receiver_{this}; mojo::Receiver<mirroring::mojom::SessionObserver> observer_receiver_{this};
// To handle Cast messages from the mirroring service to the mirroring // To handle Cast messages from the mirroring service to the mirroring
......
...@@ -21,12 +21,6 @@ class MockCastActivityRecord : public CastActivityRecord { ...@@ -21,12 +21,6 @@ class MockCastActivityRecord : public CastActivityRecord {
MockCastActivityRecord(const MediaRoute& route, const std::string& app_id); MockCastActivityRecord(const MediaRoute& route, const std::string& app_id);
~MockCastActivityRecord() override; ~MockCastActivityRecord() override;
void set_session_id(const std::string& new_id) {
if (!session_id_)
session_id_ = new_id;
ASSERT_EQ(session_id_, new_id);
}
MOCK_METHOD1(SendAppMessageToReceiver, MOCK_METHOD1(SendAppMessageToReceiver,
cast_channel::Result(const CastInternalMessage& cast_message)); cast_channel::Result(const CastInternalMessage& cast_message));
MOCK_METHOD1(SendMediaRequestToReceiver, MOCK_METHOD1(SendMediaRequestToReceiver,
...@@ -47,10 +41,9 @@ class MockCastActivityRecord : public CastActivityRecord { ...@@ -47,10 +41,9 @@ class MockCastActivityRecord : public CastActivityRecord {
const url::Origin& origin, const url::Origin& origin,
int tab_id)); int tab_id));
MOCK_METHOD1(RemoveClient, void(const std::string& client_id)); MOCK_METHOD1(RemoveClient, void(const std::string& client_id));
MOCK_METHOD3(SetOrUpdateSession, MOCK_METHOD1(OnSessionSet, void(const CastSession& session));
void(const CastSession& session, MOCK_METHOD2(OnSessionUpdated,
const MediaSinkInternal& sink, void(const CastSession& session, const std::string& hash_token));
const std::string& hash_token));
MOCK_METHOD2(SendMessageToClient, MOCK_METHOD2(SendMessageToClient,
void(const std::string& client_id, void(const std::string& client_id,
blink::mojom::PresentationConnectionMessagePtr message)); blink::mojom::PresentationConnectionMessagePtr message));
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment