Commit 08a23048 authored by Bin Zhao's avatar Bin Zhao Committed by Commit Bot

[cast_channel] Make CastSocket::OnOpenCallback take CastSocket* parameter

Currently we have

  CastSocket::OnOpenCallBack 
    = base::OnceCallback<void(int channel_id, ChannelError error_state)>. 

We need to call 

  CastSocket* socket = cast_socket_service_->GetSocket(channel_id); 

to get socket object in callback function, which seems unnecessary.

Make CastSocket::OnOpenCallback take CastSocket* parameter instead. Callback is invoked 
by CastSocket object with 'this' pointer. Since CastSocket only runs on the IO thread, so
do callback functions, no post task is involved, raw pointer seems safe.

Resolve code review comments for: https://chromium-review.googlesource.com/c/575247

Bug: 749762
Change-Id: Iaab109774fa2c67d99a7fa5afffdf2315b32fd59
Reviewed-on: https://chromium-review.googlesource.com/590588
Commit-Queue: Bin Zhao <zhaobin@chromium.org>
Reviewed-by: default avatarDerek Cheng <imcheng@chromium.org>
Reviewed-by: default avatarmark a. foltz <mfoltz@chromium.org>
Cr-Commit-Position: refs/heads/master@{#491423}
parent 998e3702
...@@ -176,25 +176,19 @@ void CastMediaSinkService::OpenChannelOnIOThread( ...@@ -176,25 +176,19 @@ void CastMediaSinkService::OpenChannelOnIOThread(
void CastMediaSinkService::OnChannelOpenedOnIOThread( void CastMediaSinkService::OnChannelOpenedOnIOThread(
const DnsSdService& service, const DnsSdService& service,
int channel_id, cast_channel::CastSocket* socket) {
cast_channel::ChannelError channel_error) { DCHECK(socket);
if (channel_error != cast_channel::ChannelError::NONE) { if (socket->error_state() != cast_channel::ChannelError::NONE) {
DVLOG(2) << "Fail to open channel " << service.ip_address << ": " DVLOG(2) << "Fail to open channel " << service.ip_address << ": "
<< service.service_host_port.ToString() << service.service_host_port.ToString() << " [ChannelError]: "
<< " [ChannelError]: " << (int)channel_error; << cast_channel::ChannelErrorToString(socket->error_state());
return;
}
auto* socket = cast_socket_service_->GetSocket(channel_id);
if (!socket) {
DVLOG(2) << "Fail to find socket with [channel_id]: " << channel_id;
return; return;
} }
content::BrowserThread::PostTask( content::BrowserThread::PostTask(
content::BrowserThread::UI, FROM_HERE, content::BrowserThread::UI, FROM_HERE,
base::Bind(&CastMediaSinkService::OnChannelOpenedOnUIThread, this, base::Bind(&CastMediaSinkService::OnChannelOpenedOnUIThread, this,
service, channel_id, socket->audio_only())); service, socket->id(), socket->audio_only()));
} }
void CastMediaSinkService::OnChannelOpenedOnUIThread( void CastMediaSinkService::OnChannelOpenedOnUIThread(
......
...@@ -99,11 +99,10 @@ class CastMediaSinkService ...@@ -99,11 +99,10 @@ class CastMediaSinkService
// Invoked when opening cast channel on IO thread completes. // Invoked when opening cast channel on IO thread completes.
// |service|: mDNS service description. // |service|: mDNS service description.
// |channel_id|: channel id of newly created cast channel. // |socket|: raw pointer of newly created cast channel. Does not take
// |channel_error|: error encounted when opending cast channel. // ownership of |socket|.
void OnChannelOpenedOnIOThread(const DnsSdService& service, void OnChannelOpenedOnIOThread(const DnsSdService& service,
int channel_id, cast_channel::CastSocket* socket);
cast_channel::ChannelError channel_error);
// Invoked by |OnChannelOpenedOnIOThread| to post task on UI thread. // Invoked by |OnChannelOpenedOnIOThread| to post task on UI thread.
// |service|: mDNS service description. // |service|: mDNS service description.
......
...@@ -137,12 +137,10 @@ TEST_F(CastMediaSinkServiceTest, TestMultipleStartAndStop) { ...@@ -137,12 +137,10 @@ TEST_F(CastMediaSinkServiceTest, TestMultipleStartAndStop) {
TEST_F(CastMediaSinkServiceTest, TestOnChannelOpenedOnIOThread) { TEST_F(CastMediaSinkServiceTest, TestOnChannelOpenedOnIOThread) {
DnsSdService service = CreateDnsService(1); DnsSdService service = CreateDnsService(1);
cast_channel::MockCastSocket socket; cast_channel::MockCastSocket socket;
EXPECT_CALL(*mock_cast_socket_service_, GetSocket(1)) socket.set_id(1);
.WillOnce(Return(&socket));
media_sink_service_->current_services_.push_back(service); media_sink_service_->current_services_.push_back(service);
media_sink_service_->OnChannelOpenedOnIOThread( media_sink_service_->OnChannelOpenedOnIOThread(service, &socket);
service, 1, cast_channel::ChannelError::NONE);
// Invoke CastMediaSinkService::OnChannelOpenedOnUIThread on the UI thread. // Invoke CastMediaSinkService::OnChannelOpenedOnUIThread on the UI thread.
base::RunLoop().RunUntilIdle(); base::RunLoop().RunUntilIdle();
...@@ -158,24 +156,16 @@ TEST_F(CastMediaSinkServiceTest, TestMultipleOnChannelOpenedOnIOThread) { ...@@ -158,24 +156,16 @@ TEST_F(CastMediaSinkServiceTest, TestMultipleOnChannelOpenedOnIOThread) {
DnsSdService service3 = CreateDnsService(3); DnsSdService service3 = CreateDnsService(3);
cast_channel::MockCastSocket socket2; cast_channel::MockCastSocket socket2;
socket2.set_id(2);
cast_channel::MockCastSocket socket3; cast_channel::MockCastSocket socket3;
// Fail to open channel 1. socket3.set_id(3);
EXPECT_CALL(*mock_cast_socket_service_, GetSocket(1))
.WillOnce(Return(nullptr));
EXPECT_CALL(*mock_cast_socket_service_, GetSocket(2))
.WillOnce(Return(&socket2));
EXPECT_CALL(*mock_cast_socket_service_, GetSocket(3))
.WillOnce(Return(&socket2));
// Current round of Dns discovery finds service1 and service 2. // Current round of Dns discovery finds service1 and service 2.
media_sink_service_->current_services_.push_back(service1); media_sink_service_->current_services_.push_back(service1);
media_sink_service_->current_services_.push_back(service2); media_sink_service_->current_services_.push_back(service2);
media_sink_service_->OnChannelOpenedOnIOThread( // Fail to open channel 1.
service1, 1, cast_channel::ChannelError::NONE); media_sink_service_->OnChannelOpenedOnIOThread(service2, &socket2);
media_sink_service_->OnChannelOpenedOnIOThread( media_sink_service_->OnChannelOpenedOnIOThread(service3, &socket3);
service2, 2, cast_channel::ChannelError::NONE);
media_sink_service_->OnChannelOpenedOnIOThread(
service3, 3, cast_channel::ChannelError::NONE);
// Invoke CastMediaSinkService::OnChannelOpenedOnUIThread on the UI thread. // Invoke CastMediaSinkService::OnChannelOpenedOnUIThread on the UI thread.
base::RunLoop().RunUntilIdle(); base::RunLoop().RunUntilIdle();
...@@ -209,15 +199,12 @@ TEST_F(CastMediaSinkServiceTest, TestOnDnsSdEvent) { ...@@ -209,15 +199,12 @@ TEST_F(CastMediaSinkServiceTest, TestOnDnsSdEvent) {
base::RunLoop().RunUntilIdle(); base::RunLoop().RunUntilIdle();
cast_channel::MockCastSocket socket1; cast_channel::MockCastSocket socket1;
socket1.set_id(1);
cast_channel::MockCastSocket socket2; cast_channel::MockCastSocket socket2;
socket2.set_id(2);
EXPECT_CALL(*mock_cast_socket_service_, GetSocket(1)) callback1.Run(&socket1);
.WillOnce(Return(&socket1)); callback2.Run(&socket2);
EXPECT_CALL(*mock_cast_socket_service_, GetSocket(2))
.WillOnce(Return(&socket2));
callback1.Run(1, cast_channel::ChannelError::NONE);
callback2.Run(2, cast_channel::ChannelError::NONE);
// Invoke CastMediaSinkService::OnChannelOpenedOnUIThread on the UI thread. // Invoke CastMediaSinkService::OnChannelOpenedOnUIThread on the UI thread.
base::RunLoop().RunUntilIdle(); base::RunLoop().RunUntilIdle();
......
...@@ -115,8 +115,9 @@ CastSocketImpl::~CastSocketImpl() { ...@@ -115,8 +115,9 @@ CastSocketImpl::~CastSocketImpl() {
// would result in re-entrancy. // would result in re-entrancy.
CloseInternal(); CloseInternal();
error_state_ = ChannelError::UNKNOWN;
for (auto& connect_callback : connect_callbacks_) for (auto& connect_callback : connect_callbacks_)
std::move(connect_callback).Run(channel_id_, ChannelError::UNKNOWN); std::move(connect_callback).Run(this);
connect_callbacks_.clear(); connect_callbacks_.clear();
} }
...@@ -230,10 +231,12 @@ void CastSocketImpl::Connect(OnOpenCallback callback) { ...@@ -230,10 +231,12 @@ void CastSocketImpl::Connect(OnOpenCallback callback) {
connect_callbacks_.push_back(std::move(callback)); connect_callbacks_.push_back(std::move(callback));
break; break;
case ReadyState::OPEN: case ReadyState::OPEN:
std::move(callback).Run(channel_id_, ChannelError::NONE); error_state_ = ChannelError::NONE;
std::move(callback).Run(this);
break; break;
case ReadyState::CLOSED: case ReadyState::CLOSED:
std::move(callback).Run(channel_id_, ChannelError::CONNECT_ERROR); error_state_ = ChannelError::CONNECT_ERROR;
std::move(callback).Run(this);
break; break;
default: default:
NOTREACHED() << "Unknown ReadyState: " NOTREACHED() << "Unknown ReadyState: "
...@@ -552,7 +555,7 @@ void CastSocketImpl::DoConnectCallback() { ...@@ -552,7 +555,7 @@ void CastSocketImpl::DoConnectCallback() {
} }
for (auto& connect_callback : connect_callbacks_) for (auto& connect_callback : connect_callbacks_)
std::move(connect_callback).Run(channel_id_, error_state_); std::move(connect_callback).Run(this);
connect_callbacks_.clear(); connect_callbacks_.clear();
} }
......
...@@ -56,8 +56,10 @@ enum CastDeviceCapability { ...@@ -56,8 +56,10 @@ enum CastDeviceCapability {
// Public interface of the CastSocket class. // Public interface of the CastSocket class.
class CastSocket { class CastSocket {
public: public:
using OnOpenCallback = // Invoked when CastSocket opens.
base::OnceCallback<void(int channel_id, ChannelError error_state)>; // |socket|: raw pointer of opened socket (this pointer). Guaranteed to be
// valid in callback function. Do not pass |socket| around.
using OnOpenCallback = base::OnceCallback<void(CastSocket* socket)>;
class Observer { class Observer {
public: public:
......
...@@ -83,9 +83,9 @@ TEST_F(CastSocketServiceTest, TestOpenChannel) { ...@@ -83,9 +83,9 @@ TEST_F(CastSocketServiceTest, TestOpenChannel) {
EXPECT_CALL(*mock_socket, ConnectInternal(_)) EXPECT_CALL(*mock_socket, ConnectInternal(_))
.WillOnce(WithArgs<0>( .WillOnce(WithArgs<0>(
Invoke([&](const MockCastSocket::MockOnOpenCallback& callback) { Invoke([&](const MockCastSocket::MockOnOpenCallback& callback) {
callback.Run(mock_socket->id(), ChannelError::NONE); callback.Run(mock_socket);
}))); })));
EXPECT_CALL(mock_on_open_callback_, Run(_, ChannelError::NONE)); EXPECT_CALL(mock_on_open_callback_, Run(mock_socket));
EXPECT_CALL(*mock_socket, AddObserver(_)); EXPECT_CALL(*mock_socket, AddObserver(_));
cast_socket_service_->OpenSocket(ip_endpoint, nullptr /* net_log */, cast_socket_service_->OpenSocket(ip_endpoint, nullptr /* net_log */,
......
...@@ -84,15 +84,14 @@ class MockCastSocketService : public CastSocketService { ...@@ -84,15 +84,14 @@ class MockCastSocketService : public CastSocketService {
MOCK_METHOD4(OpenSocketInternal, MOCK_METHOD4(OpenSocketInternal,
int(const net::IPEndPoint& ip_endpoint, int(const net::IPEndPoint& ip_endpoint,
net::NetLog* net_log, net::NetLog* net_log,
const base::Callback<void(int, ChannelError)>& open_cb, const base::Callback<void(CastSocket*)>& open_cb,
CastSocket::Observer* observer)); CastSocket::Observer* observer));
MOCK_CONST_METHOD1(GetSocket, CastSocket*(int channel_id)); MOCK_CONST_METHOD1(GetSocket, CastSocket*(int channel_id));
}; };
class MockCastSocket : public CastSocket { class MockCastSocket : public CastSocket {
public: public:
using MockOnOpenCallback = using MockOnOpenCallback = base::Callback<void(CastSocket* socket)>;
base::Callback<void(int channel_id, ChannelError error_state)>;
MockCastSocket(); MockCastSocket();
~MockCastSocket() override; ~MockCastSocket() override;
......
...@@ -300,19 +300,18 @@ void CastChannelOpenFunction::AsyncWorkStart() { ...@@ -300,19 +300,18 @@ void CastChannelOpenFunction::AsyncWorkStart() {
observer); observer);
} }
void CastChannelOpenFunction::OnOpen(int channel_id, ChannelError result) { void CastChannelOpenFunction::OnOpen(CastSocket* socket) {
DCHECK_CURRENTLY_ON(BrowserThread::IO); DCHECK_CURRENTLY_ON(BrowserThread::IO);
VLOG(1) << "Connect finished, OnOpen invoked."; VLOG(1) << "Connect finished, OnOpen invoked.";
DCHECK(socket);
// TODO: If we failed to open the CastSocket, we may want to clean up here, // TODO: If we failed to open the CastSocket, we may want to clean up here,
// rather than relying on the extension to call close(). This can be done by // rather than relying on the extension to call close(). This can be done by
// calling RemoveSocket() and api_->GetLogger()->ClearLastError(channel_id). // calling RemoveSocket() and api_->GetLogger()->ClearLastError(channel_id).
if (result != ChannelError::UNKNOWN) { if (socket->error_state() != ChannelError::UNKNOWN) {
CastSocket* socket = cast_socket_service_->GetSocket(channel_id);
CHECK(socket);
SetResultFromSocket(*socket); SetResultFromSocket(*socket);
} else { } else {
// The socket is being destroyed. // The socket is being destroyed.
SetResultFromError(channel_id, api::cast_channel::CHANNEL_ERROR_UNKNOWN); SetResultFromError(socket->id(), api::cast_channel::CHANNEL_ERROR_UNKNOWN);
} }
AsyncWorkCompleted(); AsyncWorkCompleted();
......
...@@ -167,7 +167,9 @@ class CastChannelOpenFunction : public CastChannelAsyncApiFunction { ...@@ -167,7 +167,9 @@ class CastChannelOpenFunction : public CastChannelAsyncApiFunction {
static net::IPEndPoint* ParseConnectInfo( static net::IPEndPoint* ParseConnectInfo(
const api::cast_channel::ConnectInfo& connect_info); const api::cast_channel::ConnectInfo& connect_info);
void OnOpen(int channel_id, cast_channel::ChannelError result); // |socket|: raw pointer of newly created cast channel. Does not take
// ownership of |socket|.
void OnOpen(cast_channel::CastSocket* socket);
std::unique_ptr<api::cast_channel::Open::Params> params_; std::unique_ptr<api::cast_channel::Open::Params> params_;
CastChannelAPI* api_; CastChannelAPI* api_;
......
...@@ -111,7 +111,7 @@ class CastChannelAPITest : public ExtensionApiTest { ...@@ -111,7 +111,7 @@ class CastChannelAPITest : public ExtensionApiTest {
EXPECT_CALL(*mock_cast_socket_, ConnectInternal(_)) EXPECT_CALL(*mock_cast_socket_, ConnectInternal(_))
.WillOnce(WithArgs<0>( .WillOnce(WithArgs<0>(
Invoke([&](const MockCastSocket::MockOnOpenCallback& callback) { Invoke([&](const MockCastSocket::MockOnOpenCallback& callback) {
callback.Run(mock_cast_socket_->id(), ChannelError::NONE); callback.Run(mock_cast_socket_);
}))); })));
EXPECT_CALL(*mock_cast_socket_, ready_state()) EXPECT_CALL(*mock_cast_socket_, ready_state())
.WillOnce(Return(ReadyState::OPEN)); .WillOnce(Return(ReadyState::OPEN));
...@@ -138,7 +138,7 @@ class CastChannelAPITest : public ExtensionApiTest { ...@@ -138,7 +138,7 @@ class CastChannelAPITest : public ExtensionApiTest {
EXPECT_CALL(*mock_cast_socket_, ConnectInternal(_)) EXPECT_CALL(*mock_cast_socket_, ConnectInternal(_))
.WillOnce(WithArgs<0>( .WillOnce(WithArgs<0>(
Invoke([&](const MockCastSocket::MockOnOpenCallback& callback) { Invoke([&](const MockCastSocket::MockOnOpenCallback& callback) {
callback.Run(mock_cast_socket_->id(), ChannelError::NONE); callback.Run(mock_cast_socket_);
}))); })));
EXPECT_CALL(*mock_cast_socket_, ready_state()) EXPECT_CALL(*mock_cast_socket_, ready_state())
.WillOnce(Return(ReadyState::OPEN)) .WillOnce(Return(ReadyState::OPEN))
...@@ -306,7 +306,7 @@ IN_PROC_BROWSER_TEST_F(CastChannelAPITest, MAYBE_TestOpenReceiveClose) { ...@@ -306,7 +306,7 @@ IN_PROC_BROWSER_TEST_F(CastChannelAPITest, MAYBE_TestOpenReceiveClose) {
EXPECT_CALL(*mock_cast_socket_, ConnectInternal(_)) EXPECT_CALL(*mock_cast_socket_, ConnectInternal(_))
.WillOnce(WithArgs<0>( .WillOnce(WithArgs<0>(
Invoke([&](const MockCastSocket::MockOnOpenCallback& callback) { Invoke([&](const MockCastSocket::MockOnOpenCallback& callback) {
callback.Run(mock_cast_socket_->id(), ChannelError::NONE); callback.Run(mock_cast_socket_);
}))); })));
EXPECT_CALL(*mock_cast_socket_, ready_state()) EXPECT_CALL(*mock_cast_socket_, ready_state())
.Times(3) .Times(3)
...@@ -343,7 +343,7 @@ IN_PROC_BROWSER_TEST_F(CastChannelAPITest, MAYBE_TestOpenError) { ...@@ -343,7 +343,7 @@ IN_PROC_BROWSER_TEST_F(CastChannelAPITest, MAYBE_TestOpenError) {
EXPECT_CALL(*mock_cast_socket_, ConnectInternal(_)) EXPECT_CALL(*mock_cast_socket_, ConnectInternal(_))
.WillOnce(WithArgs<0>( .WillOnce(WithArgs<0>(
Invoke([&](const MockCastSocket::MockOnOpenCallback& callback) { Invoke([&](const MockCastSocket::MockOnOpenCallback& callback) {
callback.Run(mock_cast_socket_->id(), ChannelError::CONNECT_ERROR); callback.Run(mock_cast_socket_);
}))); })));
mock_cast_socket_->SetErrorState(ChannelError::CONNECT_ERROR); mock_cast_socket_->SetErrorState(ChannelError::CONNECT_ERROR);
EXPECT_CALL(*mock_cast_socket_, ready_state()) EXPECT_CALL(*mock_cast_socket_, ready_state())
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment