Commit 8364900b authored by mark a. foltz's avatar mark a. foltz Committed by Commit Bot

[Media Router] Close Cast virtual connections.

This patch closes a Cast virtual connection between a client and a
session in two cases:

1. The corresponding PresentationConnection is closed by the renderer.
2. The client sends a "leave session" message.

Case #2 could be implemented by #1, but for now, they are both
supported for backwards compatibility.

Bug: 1059053
Change-Id: I4bfaf5fd4a49adc16980f9ef12b2502c308c1770
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2099606Reviewed-by: default avatarTakumi Fujimoto <takumif@chromium.org>
Commit-Queue: mark a. foltz <mfoltz@chromium.org>
Cr-Commit-Position: refs/heads/master@{#751157}
parent 9019a8dd
...@@ -144,6 +144,14 @@ void ActivityRecord::StopSessionOnReceiver( ...@@ -144,6 +144,14 @@ void ActivityRecord::StopSessionOnReceiver(
std::move(callback)); std::move(callback));
} }
void ActivityRecord::CloseConnectionOnReceiver(const std::string& client_id) {
CastSession* session = GetSession();
if (!session)
return;
message_handler_->CloseConnection(cast_channel_id(), client_id,
session->transport_id());
}
void ActivityRecord::HandleLeaveSession(const std::string& client_id) { void ActivityRecord::HandleLeaveSession(const std::string& client_id) {
auto client_it = connected_clients_.find(client_id); auto client_it = connected_clients_.find(client_id);
CHECK(client_it != connected_clients_.end()); CHECK(client_it != connected_clients_.end());
......
...@@ -116,6 +116,10 @@ class ActivityRecord { ...@@ -116,6 +116,10 @@ class ActivityRecord {
virtual void StopSessionOnReceiver(const std::string& client_id, virtual void StopSessionOnReceiver(const std::string& client_id,
cast_channel::ResultCallback callback); cast_channel::ResultCallback callback);
// Closes any virtual connection between |client_id| and this session on the
// receiver.
virtual void CloseConnectionOnReceiver(const std::string& client_id);
// Called when the client given by |client_id| requests to leave the session. // Called when the client given by |client_id| requests to leave the session.
// This will also cause all clients within the session with matching origin // This will also cause all clients within the session with matching origin
// and/or tab ID to leave (i.e., their presentation connections will be // and/or tab ID to leave (i.e., their presentation connections will be
......
...@@ -455,4 +455,13 @@ TEST_F(CastActivityRecordTest, OnAppMessageAllClients) { ...@@ -455,4 +455,13 @@ TEST_F(CastActivityRecordTest, OnAppMessageAllClients) {
record_->OnAppMessage(message); record_->OnAppMessage(message);
} }
TEST_F(CastActivityRecordTest, CloseConnectionOnReceiver) {
SetUpSession();
AddMockClient("theClientId1");
EXPECT_CALL(message_handler_, CloseConnection(kChannelId, "theClientId1",
session_->transport_id()));
record_->CloseConnectionOnReceiver("theClientId1");
}
} // namespace media_router } // namespace media_router
...@@ -109,9 +109,7 @@ void CastSessionClientImpl::OnMessage( ...@@ -109,9 +109,7 @@ void CastSessionClientImpl::OnMessage(
} }
void CastSessionClientImpl::DidClose(PresentationConnectionCloseReason reason) { void CastSessionClientImpl::DidClose(PresentationConnectionCloseReason reason) {
// TODO(https://crbug.com/809249): Implement close connection with this activity_->CloseConnectionOnReceiver(client_id());
// method once we make sure Blink calls this on navigation and on
// PresentationConnection::close().
} }
void CastSessionClientImpl::SendErrorCodeToClient( void CastSessionClientImpl::SendErrorCodeToClient(
...@@ -250,8 +248,8 @@ void CastSessionClientImpl::CloseConnection( ...@@ -250,8 +248,8 @@ void CastSessionClientImpl::CloseConnection(
PresentationConnectionCloseReason close_reason) { PresentationConnectionCloseReason close_reason) {
if (connection_remote_) if (connection_remote_)
connection_remote_->DidClose(close_reason); connection_remote_->DidClose(close_reason);
TearDownPresentationConnection(); TearDownPresentationConnection();
activity_->CloseConnectionOnReceiver(client_id());
} }
void CastSessionClientImpl::TerminateConnection() { void CastSessionClientImpl::TerminateConnection() {
......
...@@ -33,6 +33,7 @@ ...@@ -33,6 +33,7 @@
using base::test::IsJson; using base::test::IsJson;
using base::test::ParseJson; using base::test::ParseJson;
using blink::mojom::PresentationConnectionCloseReason;
using testing::_; using testing::_;
using testing::AllOf; using testing::AllOf;
using testing::AnyNumber; using testing::AnyNumber;
...@@ -269,4 +270,14 @@ TEST_F(CastSessionClientImplTest, SendStopSessionCommandToReceiver) { ...@@ -269,4 +270,14 @@ TEST_F(CastSessionClientImplTest, SendStopSessionCommandToReceiver) {
})")); })"));
} }
TEST_F(CastSessionClientImplTest, CloseConnection) {
EXPECT_CALL(activity_, CloseConnectionOnReceiver("theClientId"));
client_->CloseConnection(PresentationConnectionCloseReason::CLOSED);
}
TEST_F(CastSessionClientImplTest, DidCloseConnection) {
EXPECT_CALL(activity_, CloseConnectionOnReceiver("theClientId"));
client_->DidClose(PresentationConnectionCloseReason::WENT_AWAY);
}
} // namespace media_router } // namespace media_router
...@@ -37,6 +37,7 @@ class MockCastActivityRecord : public CastActivityRecord { ...@@ -37,6 +37,7 @@ class MockCastActivityRecord : public CastActivityRecord {
MOCK_METHOD2(StopSessionOnReceiver, MOCK_METHOD2(StopSessionOnReceiver,
void(const std::string& client_id, void(const std::string& client_id,
cast_channel::ResultCallback callback)); cast_channel::ResultCallback callback));
MOCK_METHOD1(CloseConnectionOnReceiver, void(const std::string& client_id));
MOCK_METHOD1(SendStopSessionMessageToClients, MOCK_METHOD1(SendStopSessionMessageToClients,
void(const std::string& hash_token)); void(const std::string& hash_token));
MOCK_METHOD1(HandleLeaveSession, void(const std::string& client_id)); MOCK_METHOD1(HandleLeaveSession, void(const std::string& client_id));
......
...@@ -94,6 +94,33 @@ void CastMessageHandler::EnsureConnection(int channel_id, ...@@ -94,6 +94,33 @@ void CastMessageHandler::EnsureConnection(int channel_id,
DoEnsureConnection(socket, source_id, destination_id); DoEnsureConnection(socket, source_id, destination_id);
} }
void CastMessageHandler::CloseConnection(int channel_id,
const std::string& source_id,
const std::string& destination_id) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
CastSocket* socket = socket_service_->GetSocket(channel_id);
if (!socket) {
return;
}
VirtualConnection connection(socket->id(), source_id, destination_id);
if (virtual_connections_.find(connection) == virtual_connections_.end())
return;
VLOG(1) << "Closing VC for channel: " << connection.channel_id
<< ", source: " << connection.source_id
<< ", dest: " << connection.destination_id;
socket->transport()->SendMessage(
CreateVirtualConnectionClose(connection.source_id,
connection.destination_id),
base::BindOnce(&CastMessageHandler::OnMessageSent,
weak_ptr_factory_.GetWeakPtr()));
// Assume the virtual connection close will succeed. Eventually the receiver
// will remove the connection even if it doesn't.
virtual_connections_.erase(connection);
}
CastMessageHandler::PendingRequests* CastMessageHandler::PendingRequests*
CastMessageHandler::GetOrCreatePendingRequests(int channel_id) { CastMessageHandler::GetOrCreatePendingRequests(int channel_id) {
CastMessageHandler::PendingRequests* requests = nullptr; CastMessageHandler::PendingRequests* requests = nullptr;
...@@ -300,9 +327,6 @@ void CastMessageHandler::OnError(const CastSocket& socket, ...@@ -300,9 +327,6 @@ void CastMessageHandler::OnError(const CastSocket& socket,
void CastMessageHandler::OnMessage(const CastSocket& socket, void CastMessageHandler::OnMessage(const CastSocket& socket,
const CastMessage& message) { const CastMessage& message) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
DVLOG(2) << __func__ << ", channel_id: " << socket.id()
<< ", message: " << message;
// TODO(jrw): Splitting internal messages into a separate code path with a // TODO(jrw): Splitting internal messages into a separate code path with a
// separate data type is pretty questionable, because it causes duplicated // separate data type is pretty questionable, because it causes duplicated
// code paths in the downstream logic (manifested as separate OnAppMessage and // code paths in the downstream logic (manifested as separate OnAppMessage and
...@@ -310,6 +334,8 @@ void CastMessageHandler::OnMessage(const CastSocket& socket, ...@@ -310,6 +334,8 @@ void CastMessageHandler::OnMessage(const CastSocket& socket,
if (IsCastInternalNamespace(message.namespace_())) { if (IsCastInternalNamespace(message.namespace_())) {
if (message.payload_type() == if (message.payload_type() ==
cast::channel::CastMessage_PayloadType_STRING) { cast::channel::CastMessage_PayloadType_STRING) {
VLOG(1) << __func__ << ": channel_id: " << socket.id()
<< ", message: " << message;
parse_json_.Run( parse_json_.Run(
message.payload_utf8(), message.payload_utf8(),
base::BindOnce(&CastMessageHandler::HandleCastInternalMessage, base::BindOnce(&CastMessageHandler::HandleCastInternalMessage,
...@@ -390,6 +416,8 @@ void CastMessageHandler::SendCastMessageToSocket(CastSocket* socket, ...@@ -390,6 +416,8 @@ void CastMessageHandler::SendCastMessageToSocket(CastSocket* socket,
// A virtual connection must be opened to the receiver before other messages // A virtual connection must be opened to the receiver before other messages
// can be sent. // can be sent.
DoEnsureConnection(socket, message.source_id(), message.destination_id()); DoEnsureConnection(socket, message.source_id(), message.destination_id());
VLOG(1) << __func__ << ": channel_id: " << socket->id()
<< ", message: " << message;
socket->transport()->SendMessage( socket->transport()->SendMessage(
message, base::BindOnce(&CastMessageHandler::OnMessageSent, message, base::BindOnce(&CastMessageHandler::OnMessageSent,
weak_ptr_factory_.GetWeakPtr())); weak_ptr_factory_.GetWeakPtr()));
...@@ -404,9 +432,9 @@ void CastMessageHandler::DoEnsureConnection(CastSocket* socket, ...@@ -404,9 +432,9 @@ void CastMessageHandler::DoEnsureConnection(CastSocket* socket,
if (virtual_connections_.find(connection) != virtual_connections_.end()) if (virtual_connections_.find(connection) != virtual_connections_.end())
return; return;
DVLOG(1) << "Creating VC for channel: " << connection.channel_id VLOG(1) << "Creating VC for channel: " << connection.channel_id
<< ", source: " << connection.source_id << ", source: " << connection.source_id
<< ", dest: " << connection.destination_id; << ", dest: " << connection.destination_id;
CastMessage virtual_connection_request = CreateVirtualConnectionRequest( CastMessage virtual_connection_request = CreateVirtualConnectionRequest(
connection.source_id, connection.destination_id, connection.source_id, connection.destination_id,
connection.destination_id == kPlatformReceiverId connection.destination_id == kPlatformReceiverId
......
...@@ -153,6 +153,13 @@ class CastMessageHandler : public CastSocket::Observer { ...@@ -153,6 +153,13 @@ class CastMessageHandler : public CastSocket::Observer {
const std::string& source_id, const std::string& source_id,
const std::string& destination_id); const std::string& destination_id);
// Closes any virtual connection on (|source_id|, |destination_id|) on the
// device given by |channel_id|, sending a virtual connection close request to
// the device if necessary.
virtual void CloseConnection(int channel_id,
const std::string& source_id,
const std::string& destination_id);
// Sends an app availability for |app_id| to the device given by |socket|. // Sends an app availability for |app_id| to the device given by |socket|.
// |callback| is always invoked asynchronously, and will be invoked when a // |callback| is always invoked asynchronously, and will be invoked when a
// response is received, or if the request timed out. No-ops if there is // response is received, or if the request timed out. No-ops if there is
......
...@@ -277,6 +277,20 @@ TEST_F(CastMessageHandlerTest, EnsureConnection) { ...@@ -277,6 +277,20 @@ TEST_F(CastMessageHandlerTest, EnsureConnection) {
handler_.EnsureConnection(channel_id_, kSourceId, kDestinationId); handler_.EnsureConnection(channel_id_, kSourceId, kDestinationId);
} }
TEST_F(CastMessageHandlerTest, CloseConnection) {
ExpectEnsureConnection();
handler_.EnsureConnection(channel_id_, kSourceId, kDestinationId);
EXPECT_CALL(
*transport_,
SendMessage(HasMessageType(CastMessageType::kCloseConnection), _));
handler_.CloseConnection(channel_id_, kSourceId, kDestinationId);
// Re-open virtual connection should cause CONNECT message to be sent.
ExpectEnsureConnection();
handler_.EnsureConnection(channel_id_, kSourceId, kDestinationId);
}
TEST_F(CastMessageHandlerTest, CloseConnectionFromReceiver) { TEST_F(CastMessageHandlerTest, CloseConnectionFromReceiver) {
ExpectEnsureConnection(); ExpectEnsureConnection();
handler_.EnsureConnection(channel_id_, kSourceId, kDestinationId); handler_.EnsureConnection(channel_id_, kSourceId, kDestinationId);
......
...@@ -107,6 +107,10 @@ constexpr int kVirtualConnectSdkType = 2; ...@@ -107,6 +107,10 @@ constexpr int kVirtualConnectSdkType = 2;
// stands for CONNECTION_TYPE_LOCAL, which is the only type used in Chrome. // stands for CONNECTION_TYPE_LOCAL, which is the only type used in Chrome.
constexpr int kVirtualConnectTypeLocal = 1; constexpr int kVirtualConnectTypeLocal = 1;
// The reason code passed to the virtual connection CLOSE message indicating
// that the connection has been gracefully closed by the sender.
constexpr int kVirtualConnectionClosedByPeer = 5;
void FillCommonCastMessageFields(CastMessage* message, void FillCommonCastMessageFields(CastMessage* message,
const std::string& source_id, const std::string& source_id,
const std::string& destination_id, const std::string& destination_id,
...@@ -349,6 +353,18 @@ CastMessage CreateVirtualConnectionRequest( ...@@ -349,6 +353,18 @@ CastMessage CreateVirtualConnectionRequest(
destination_id); destination_id);
} }
CastMessage CreateVirtualConnectionClose(const std::string& source_id,
const std::string& destination_id) {
Value dict(Value::Type::DICTIONARY);
dict.SetKey(
"type",
Value(
EnumToString<CastMessageType, CastMessageType::kCloseConnection>()));
dict.SetKey("reasonCode", Value(kVirtualConnectionClosedByPeer));
return CreateCastMessage(kConnectionNamespace, dict, source_id,
destination_id);
}
CastMessage CreateGetAppAvailabilityRequest(const std::string& source_id, CastMessage CreateGetAppAvailabilityRequest(const std::string& source_id,
int request_id, int request_id,
const std::string& app_id) { const std::string& app_id) {
......
...@@ -202,6 +202,9 @@ CastMessage CreateVirtualConnectionRequest( ...@@ -202,6 +202,9 @@ CastMessage CreateVirtualConnectionRequest(
const std::string& user_agent, const std::string& user_agent,
const std::string& browser_version); const std::string& browser_version);
CastMessage CreateVirtualConnectionClose(const std::string& source_id,
const std::string& destination_id);
// Creates an app availability request for |app_id| from |source_id| with // Creates an app availability request for |app_id| from |source_id| with
// ID |request_id|. // ID |request_id|.
// TODO(imcheng): May not need |source_id|, just use sender-0? // TODO(imcheng): May not need |source_id|, just use sender-0?
......
...@@ -88,6 +88,23 @@ TEST(CastMessageUtilTest, CreateStopRequest) { ...@@ -88,6 +88,23 @@ TEST(CastMessageUtilTest, CreateStopRequest) {
EXPECT_THAT(message.payload_utf8(), IsJson(expected_message)); EXPECT_THAT(message.payload_utf8(), IsJson(expected_message));
} }
TEST(CastMessageUtilTest, CreateVirtualConnectionClose) {
std::string expected_message = R"(
{
"type": "CLOSE",
"reasonCode": 5
}
)";
CastMessage message =
CreateVirtualConnectionClose("sourceId", "destinationId");
ASSERT_TRUE(IsCastMessageValid(message));
EXPECT_EQ(message.source_id(), "sourceId");
EXPECT_EQ(message.destination_id(), "destinationId");
EXPECT_EQ(message.namespace_(), kConnectionNamespace);
EXPECT_THAT(message.payload_utf8(), IsJson(expected_message));
}
TEST(CastMessageUtilTest, CreateReceiverStatusRequest) { TEST(CastMessageUtilTest, CreateReceiverStatusRequest) {
std::string expected_message = R"( std::string expected_message = R"(
{ {
......
...@@ -167,6 +167,8 @@ class MockCastMessageHandler : public CastMessageHandler { ...@@ -167,6 +167,8 @@ class MockCastMessageHandler : public CastMessageHandler {
MOCK_METHOD3(EnsureConnection, MOCK_METHOD3(EnsureConnection,
void(int, const std::string&, const std::string&)); void(int, const std::string&, const std::string&));
MOCK_METHOD3(CloseConnection,
void(int, const std::string&, const std::string&));
MOCK_METHOD3(RequestAppAvailability, MOCK_METHOD3(RequestAppAvailability,
void(CastSocket* socket, void(CastSocket* socket,
const std::string& app_id, const std::string& app_id,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment