Commit edc09064 authored by zhaobin's avatar zhaobin Committed by Commit Bot

[cast_channel] Make CastSocketService a global leaky singleton

- make CastSocketService a global leaky singleton and remove CastSocketServiceFactory
- make CastChannelAPI and CastMediaSinkService own CastSocket::Observer
- fix unit tests and browser tests

BUG=687377

Review-Url: https://codereview.chromium.org/2974523002
Cr-Commit-Position: refs/heads/master@{#486886}
parent 26a60c29
...@@ -8,7 +8,6 @@ ...@@ -8,7 +8,6 @@
#include "chrome/browser/browser_process.h" #include "chrome/browser/browser_process.h"
#include "chrome/common/media_router/discovery/media_sink_internal.h" #include "chrome/common/media_router/discovery/media_sink_internal.h"
#include "components/cast_channel/cast_socket_service.h" #include "components/cast_channel/cast_socket_service.h"
#include "components/cast_channel/cast_socket_service_factory.h"
#include "components/net_log/chrome_net_log.h" #include "components/net_log/chrome_net_log.h"
#include "content/public/common/content_client.h" #include "content/public/common/content_client.h"
#include "net/base/host_port_pair.h" #include "net/base/host_port_pair.h"
...@@ -16,8 +15,6 @@ ...@@ -16,8 +15,6 @@
namespace { namespace {
constexpr char kObserverId[] = "browser_observer_id";
enum ErrorType { enum ErrorType {
NONE, NONE,
NOT_CAST_DEVICE, NOT_CAST_DEVICE,
...@@ -90,10 +87,9 @@ const char CastMediaSinkService::kCastServiceType[] = "_googlecast._tcp.local"; ...@@ -90,10 +87,9 @@ const char CastMediaSinkService::kCastServiceType[] = "_googlecast._tcp.local";
CastMediaSinkService::CastMediaSinkService( CastMediaSinkService::CastMediaSinkService(
const OnSinksDiscoveredCallback& callback, const OnSinksDiscoveredCallback& callback,
content::BrowserContext* browser_context) content::BrowserContext* browser_context)
: MediaSinkServiceBase(callback) { : MediaSinkServiceBase(callback),
cast_socket_service_(cast_channel::CastSocketService::GetInstance()) {
DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
cast_socket_service_ = cast_channel::CastSocketServiceFactory::GetInstance()
->GetForBrowserContext(browser_context);
DCHECK(cast_socket_service_); DCHECK(cast_socket_service_);
} }
...@@ -169,17 +165,14 @@ void CastMediaSinkService::OnDnsSdEvent( ...@@ -169,17 +165,14 @@ void CastMediaSinkService::OnDnsSdEvent(
void CastMediaSinkService::OpenChannelOnIOThread( void CastMediaSinkService::OpenChannelOnIOThread(
const DnsSdService& service, const DnsSdService& service,
const net::IPEndPoint& ip_endpoint) { const net::IPEndPoint& ip_endpoint) {
auto* observer = cast_socket_service_->GetObserver(kObserverId); if (!observer_)
if (!observer) { observer_.reset(new CastSocketObserver());
observer = new CastSocketObserver();
cast_socket_service_->AddObserver(kObserverId, base::WrapUnique(observer));
}
cast_socket_service_->OpenSocket( cast_socket_service_->OpenSocket(
ip_endpoint, g_browser_process->net_log(), ip_endpoint, g_browser_process->net_log(),
base::Bind(&CastMediaSinkService::OnChannelOpenedOnIOThread, this, base::Bind(&CastMediaSinkService::OnChannelOpenedOnIOThread, this,
service), service),
observer); observer_.get());
} }
void CastMediaSinkService::OnChannelOpenedOnIOThread( void CastMediaSinkService::OnChannelOpenedOnIOThread(
...@@ -228,7 +221,9 @@ void CastMediaSinkService::OnChannelOpenedOnUIThread( ...@@ -228,7 +221,9 @@ void CastMediaSinkService::OnChannelOpenedOnUIThread(
} }
CastMediaSinkService::CastSocketObserver::CastSocketObserver() {} CastMediaSinkService::CastSocketObserver::CastSocketObserver() {}
CastMediaSinkService::CastSocketObserver::~CastSocketObserver() {} CastMediaSinkService::CastSocketObserver::~CastSocketObserver() {
cast_channel::CastSocketService::GetInstance()->RemoveObserver(this);
}
void CastMediaSinkService::CastSocketObserver::OnError( void CastMediaSinkService::CastSocketObserver::OnError(
const cast_channel::CastSocket& socket, const cast_channel::CastSocket& socket,
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "chrome/browser/media/router/discovery/media_sink_service_base.h" #include "chrome/browser/media/router/discovery/media_sink_service_base.h"
#include "components/cast_channel/cast_channel_enum.h" #include "components/cast_channel/cast_channel_enum.h"
#include "components/cast_channel/cast_socket.h" #include "components/cast_channel/cast_socket.h"
#include "content/public/browser/browser_thread.h"
#include "net/base/ip_endpoint.h" #include "net/base/ip_endpoint.h"
namespace cast_channel { namespace cast_channel {
...@@ -119,8 +120,13 @@ class CastMediaSinkService ...@@ -119,8 +120,13 @@ class CastMediaSinkService
// Service list from current round of discovery. // Service list from current round of discovery.
DnsSdRegistry::DnsSdServiceList current_services_; DnsSdRegistry::DnsSdServiceList current_services_;
// Service managing creating and removing cast channels. // Raw pointer of leaky singleton CastSocketService, which manages creating
scoped_refptr<cast_channel::CastSocketService> cast_socket_service_; // and removing Cast sockets.
cast_channel::CastSocketService* const cast_socket_service_;
std::unique_ptr<cast_channel::CastSocket::Observer,
content::BrowserThread::DeleteOnIOThread>
observer_;
THREAD_CHECKER(thread_checker_); THREAD_CHECKER(thread_checker_);
......
...@@ -68,23 +68,10 @@ void VerifyMediaSinkInternal(const media_router::MediaSinkInternal& cast_sink, ...@@ -68,23 +68,10 @@ void VerifyMediaSinkInternal(const media_router::MediaSinkInternal& cast_sink,
namespace media_router { namespace media_router {
class MockCastSocketService : public cast_channel::CastSocketService {
public:
MOCK_METHOD4(OpenSocket,
int(const net::IPEndPoint& ip_endpoint,
net::NetLog* net_log,
const cast_channel::CastSocket::OnOpenCallback& open_cb,
cast_channel::CastSocket::Observer* observer));
MOCK_CONST_METHOD1(GetSocket, cast_channel::CastSocket*(int channel_id));
private:
~MockCastSocketService() {}
};
class CastMediaSinkServiceTest : public ::testing::Test { class CastMediaSinkServiceTest : public ::testing::Test {
public: public:
CastMediaSinkServiceTest() CastMediaSinkServiceTest()
: mock_cast_socket_service_(new MockCastSocketService()), : mock_cast_socket_service_(new cast_channel::MockCastSocketService()),
media_sink_service_( media_sink_service_(
new CastMediaSinkService(mock_sink_discovered_cb_.Get(), new CastMediaSinkService(mock_sink_discovered_cb_.Get(),
mock_cast_socket_service_.get())), mock_cast_socket_service_.get())),
...@@ -103,7 +90,8 @@ class CastMediaSinkServiceTest : public ::testing::Test { ...@@ -103,7 +90,8 @@ class CastMediaSinkServiceTest : public ::testing::Test {
base::MockCallback<MediaSinkService::OnSinksDiscoveredCallback> base::MockCallback<MediaSinkService::OnSinksDiscoveredCallback>
mock_sink_discovered_cb_; mock_sink_discovered_cb_;
scoped_refptr<MockCastSocketService> mock_cast_socket_service_; std::unique_ptr<cast_channel::MockCastSocketService>
mock_cast_socket_service_;
scoped_refptr<CastMediaSinkService> media_sink_service_; scoped_refptr<CastMediaSinkService> media_sink_service_;
MockDnsSdRegistry test_dns_sd_registry_; MockDnsSdRegistry test_dns_sd_registry_;
base::MockTimer* mock_timer_; base::MockTimer* mock_timer_;
......
...@@ -18,7 +18,9 @@ chrome.cast.channel.open({ ...@@ -18,7 +18,9 @@ chrome.cast.channel.open({
chrome.test.assertEq(channel.keepAlive, true); chrome.test.assertEq(channel.keepAlive, true);
if (channel.readyState == 'closed' && if (channel.readyState == 'closed' &&
error.errorState == 'ping_timeout') { error.errorState == 'ping_timeout') {
chrome.test.sendMessage('timeout_ssl'); chrome.cast.channel.close(channel, () => {
chrome.test.sendMessage('timeout_ssl');
});
} }
}); });
chrome.test.notifyPass(); chrome.test.notifyPass();
......
...@@ -17,7 +17,9 @@ chrome.cast.channel.open({ ...@@ -17,7 +17,9 @@ chrome.cast.channel.open({
chrome.test.assertEq(channel.keepAlive, true); chrome.test.assertEq(channel.keepAlive, true);
if (channel.readyState == 'closed' && if (channel.readyState == 'closed' &&
error.errorState == 'ping_timeout') { error.errorState == 'ping_timeout') {
chrome.test.sendMessage('timeout_ssl_verified'); chrome.cast.channel.close(channel, () => {
chrome.test.sendMessage('timeout_ssl_verified');
});
} }
}); });
chrome.test.notifyPass(); chrome.test.notifyPass();
......
...@@ -16,8 +16,6 @@ static_library("cast_channel") { ...@@ -16,8 +16,6 @@ static_library("cast_channel") {
"cast_socket.h", "cast_socket.h",
"cast_socket_service.cc", "cast_socket_service.cc",
"cast_socket_service.h", "cast_socket_service.h",
"cast_socket_service_factory.cc",
"cast_socket_service_factory.h",
"cast_transport.cc", "cast_transport.cc",
"cast_transport.h", "cast_transport.h",
"keep_alive_delegate.cc", "keep_alive_delegate.cc",
......
...@@ -296,6 +296,11 @@ void CastSocketImpl::AddObserver(Observer* observer) { ...@@ -296,6 +296,11 @@ void CastSocketImpl::AddObserver(Observer* observer) {
observers_.AddObserver(observer); observers_.AddObserver(observer);
} }
void CastSocketImpl::RemoveObserver(Observer* observer) {
DCHECK(observer);
observers_.RemoveObserver(observer);
}
void CastSocketImpl::OnConnectTimeout() { void CastSocketImpl::OnConnectTimeout() {
DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
// Stop all pending connection setup tasks and report back to the client. // Stop all pending connection setup tasks and report back to the client.
......
...@@ -129,6 +129,9 @@ class CastSocket { ...@@ -129,6 +129,9 @@ class CastSocket {
// Registers |observer| with the socket to receive messages and error events. // Registers |observer| with the socket to receive messages and error events.
virtual void AddObserver(Observer* observer) = 0; virtual void AddObserver(Observer* observer) = 0;
// Unregisters |observer|.
virtual void RemoveObserver(Observer* observer) = 0;
}; };
// This class implements a channel between Chrome and a Cast device using a TCP // This class implements a channel between Chrome and a Cast device using a TCP
...@@ -181,6 +184,7 @@ class CastSocketImpl : public CastSocket { ...@@ -181,6 +184,7 @@ class CastSocketImpl : public CastSocket {
bool keep_alive() const override; bool keep_alive() const override;
bool audio_only() const override; bool audio_only() const override;
void AddObserver(Observer* observer) override; void AddObserver(Observer* observer) override;
void RemoveObserver(Observer* observer) override;
protected: protected:
// CastTransport::Delegate methods for receiving handshake messages. // CastTransport::Delegate methods for receiving handshake messages.
......
...@@ -28,15 +28,17 @@ namespace cast_channel { ...@@ -28,15 +28,17 @@ namespace cast_channel {
int CastSocketService::last_channel_id_ = 0; int CastSocketService::last_channel_id_ = 0;
CastSocketService::CastSocketService() CastSocketService::CastSocketService() : logger_(new Logger()) {
: RefcountedKeyedService(
BrowserThread::GetTaskRunnerForThread(BrowserThread::IO)),
logger_(new Logger()) {
DETACH_FROM_THREAD(thread_checker_); DETACH_FROM_THREAD(thread_checker_);
} }
CastSocketService::~CastSocketService() { // This is a leaky singleton and the dtor won't be called.
DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); CastSocketService::~CastSocketService() = default;
// static
CastSocketService* CastSocketService::GetInstance() {
return base::Singleton<CastSocketService,
base::LeakySingletonTraits<CastSocketService>>::get();
} }
scoped_refptr<Logger> CastSocketService::GetLogger() { scoped_refptr<Logger> CastSocketService::GetLogger() {
...@@ -128,17 +130,9 @@ int CastSocketService::OpenSocket(const net::IPEndPoint& ip_endpoint, ...@@ -128,17 +130,9 @@ int CastSocketService::OpenSocket(const net::IPEndPoint& ip_endpoint,
observer); observer);
} }
CastSocket::Observer* CastSocketService::GetObserver(const std::string& id) { void CastSocketService::RemoveObserver(CastSocket::Observer* observer) {
auto it = socket_observer_map_.find(id); for (auto& socket_it : sockets_)
return it == socket_observer_map_.end() ? nullptr : it->second.get(); socket_it.second->RemoveObserver(observer);
}
CastSocket::Observer* CastSocketService::AddObserver(
const std::string& id,
std::unique_ptr<CastSocket::Observer> observer) {
CastSocket::Observer* observer_ptr = observer.get();
socket_observer_map_.insert(std::make_pair(id, std::move(observer)));
return observer_ptr;
} }
void CastSocketService::SetSocketForTest( void CastSocketService::SetSocketForTest(
...@@ -146,6 +140,4 @@ void CastSocketService::SetSocketForTest( ...@@ -146,6 +140,4 @@ void CastSocketService::SetSocketForTest(
socket_for_test_ = std::move(socket_for_test); socket_for_test_ = std::move(socket_for_test);
} }
void CastSocketService::ShutdownOnUIThread() {}
} // namespace cast_channel } // namespace cast_channel
...@@ -9,9 +9,9 @@ ...@@ -9,9 +9,9 @@
#include <memory> #include <memory>
#include "base/macros.h" #include "base/macros.h"
#include "base/memory/singleton.h"
#include "base/threading/thread_checker.h" #include "base/threading/thread_checker.h"
#include "components/cast_channel/cast_socket.h" #include "components/cast_channel/cast_socket.h"
#include "components/keyed_service/core/refcounted_keyed_service.h"
#include "content/public/browser/browser_thread.h" #include "content/public/browser/browser_thread.h"
namespace cast_channel { namespace cast_channel {
...@@ -20,9 +20,9 @@ namespace cast_channel { ...@@ -20,9 +20,9 @@ namespace cast_channel {
// to underlying storage. // to underlying storage.
// Instance of this class is created on the UI thread and destroyed on the IO // Instance of this class is created on the UI thread and destroyed on the IO
// thread. All public API must be called from the IO thread. // thread. All public API must be called from the IO thread.
class CastSocketService : public RefcountedKeyedService { class CastSocketService {
public: public:
CastSocketService(); static CastSocketService* GetInstance();
// Returns a pointer to the Logger member variable. // Returns a pointer to the Logger member variable.
scoped_refptr<cast_channel::Logger> GetLogger(); scoped_refptr<cast_channel::Logger> GetLogger();
...@@ -77,24 +77,20 @@ class CastSocketService : public RefcountedKeyedService { ...@@ -77,24 +77,20 @@ class CastSocketService : public RefcountedKeyedService {
const CastSocket::OnOpenCallback& open_cb, const CastSocket::OnOpenCallback& open_cb,
CastSocket::Observer* observer); CastSocket::Observer* observer);
// Returns an observer corresponding to |id|. // Remove |observer| from each socket in |sockets_|
CastSocket::Observer* GetObserver(const std::string& id); void RemoveObserver(CastSocket::Observer* observer);
// Adds |observer| to |socket_observer_map_| keyed by |id|. Return raw pointer
// of the newly added observer.
CastSocket::Observer* AddObserver(
const std::string& id,
std::unique_ptr<CastSocket::Observer> observer);
// Allow test to inject a mock cast socket. // Allow test to inject a mock cast socket.
void SetSocketForTest(std::unique_ptr<CastSocket> socket_for_test); void SetSocketForTest(std::unique_ptr<CastSocket> socket_for_test);
protected:
~CastSocketService() override;
private: private:
// RefcountedKeyedService implementation. friend class CastSocketServiceTest;
void ShutdownOnUIThread() override; friend class MockCastSocketService;
friend struct base::DefaultSingletonTraits<CastSocketService>;
friend struct std::default_delete<CastSocketService>;
CastSocketService();
virtual ~CastSocketService();
// Used to generate CastSocket id. // Used to generate CastSocket id.
static int last_channel_id_; static int last_channel_id_;
...@@ -102,13 +98,7 @@ class CastSocketService : public RefcountedKeyedService { ...@@ -102,13 +98,7 @@ class CastSocketService : public RefcountedKeyedService {
// The collection of CastSocket keyed by channel_id. // The collection of CastSocket keyed by channel_id.
std::map<int, std::unique_ptr<CastSocket>> sockets_; std::map<int, std::unique_ptr<CastSocket>> sockets_;
// Map of CastSocket::Observer keyed by observer id. For extension side scoped_refptr<Logger> logger_;
// observers, id is extension_id; For browser side observers, id is a hard
// coded string.
std::map<std::string, std::unique_ptr<CastSocket::Observer>>
socket_observer_map_;
scoped_refptr<cast_channel::Logger> logger_;
std::unique_ptr<CastSocket> socket_for_test_; std::unique_ptr<CastSocket> socket_for_test_;
......
// Copyright 2017 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/cast_channel/cast_socket_service_factory.h"
#include "components/cast_channel/cast_socket_service.h"
#include "components/keyed_service/content/browser_context_dependency_manager.h"
namespace cast_channel {
using content::BrowserContext;
namespace {
base::LazyInstance<CastSocketServiceFactory>::DestructorAtExit service_factory =
LAZY_INSTANCE_INITIALIZER;
} // namespace
// static
scoped_refptr<CastSocketService> CastSocketServiceFactory::GetForBrowserContext(
BrowserContext* context) {
DCHECK(context);
// GetServiceForBrowserContext returns a KeyedService hence the static_cast<>
// to construct a temporary scoped_refptr on the stack for the return value.
return static_cast<CastSocketService*>(
service_factory.Get().GetServiceForBrowserContext(context, true).get());
}
// static
CastSocketServiceFactory* CastSocketServiceFactory::GetInstance() {
return &service_factory.Get();
}
CastSocketServiceFactory::CastSocketServiceFactory()
: RefcountedBrowserContextKeyedServiceFactory(
"CastSocketService",
BrowserContextDependencyManager::GetInstance()) {}
CastSocketServiceFactory::~CastSocketServiceFactory() {}
content::BrowserContext* CastSocketServiceFactory::GetBrowserContextToUse(
content::BrowserContext* context) const {
return context;
}
scoped_refptr<RefcountedKeyedService>
CastSocketServiceFactory::BuildServiceInstanceFor(
BrowserContext* context) const {
return make_scoped_refptr(new CastSocketService());
}
} // namespace cast_channel
// Copyright 2017 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef COMPONENTS_CAST_CHANNEL_CAST_SOCKET_SERVICE_FACTORY_H_
#define COMPONENTS_CAST_CHANNEL_CAST_SOCKET_SERVICE_FACTORY_H_
#include "base/lazy_instance.h"
#include "base/macros.h"
#include "components/keyed_service/content/refcounted_browser_context_keyed_service_factory.h"
namespace cast_channel {
class CastSocketService;
// TODO(crbug.com/725717): CastSocket created by one profile (browser context)
// could be shared with other profiles.
class CastSocketServiceFactory
: public RefcountedBrowserContextKeyedServiceFactory {
public:
// Caller needs to make sure that it passes in the same |context| instance to
// this function for both normal profile and incognito profile.
static scoped_refptr<CastSocketService> GetForBrowserContext(
content::BrowserContext* context);
static CastSocketServiceFactory* GetInstance();
private:
friend struct base::LazyInstanceTraitsBase<CastSocketServiceFactory>;
CastSocketServiceFactory();
~CastSocketServiceFactory() override;
// BrowserContextKeyedServiceFactory interface.
content::BrowserContext* GetBrowserContextToUse(
content::BrowserContext* context) const override;
scoped_refptr<RefcountedKeyedService> BuildServiceInstanceFor(
content::BrowserContext* context) const override;
DISALLOW_COPY_AND_ASSIGN(CastSocketServiceFactory);
};
} // namespace cast_channel
#endif // COMPONENTS_CAST_CHANNEL_CAST_SOCKET_SERVICE_FACTORY_H_
...@@ -34,7 +34,7 @@ class CastSocketServiceTest : public testing::Test { ...@@ -34,7 +34,7 @@ class CastSocketServiceTest : public testing::Test {
protected: protected:
content::TestBrowserThreadBundle thread_bundle_; content::TestBrowserThreadBundle thread_bundle_;
scoped_refptr<CastSocketService> cast_socket_service_; std::unique_ptr<CastSocketService> cast_socket_service_;
base::MockCallback<CastSocket::OnOpenCallback> mock_on_open_callback_; base::MockCallback<CastSocket::OnOpenCallback> mock_on_open_callback_;
MockCastSocketObserver mock_observer_; MockCastSocketObserver mock_observer_;
}; };
......
...@@ -29,6 +29,9 @@ MockCastTransportDelegate::~MockCastTransportDelegate() {} ...@@ -29,6 +29,9 @@ MockCastTransportDelegate::~MockCastTransportDelegate() {}
MockCastSocketObserver::MockCastSocketObserver() {} MockCastSocketObserver::MockCastSocketObserver() {}
MockCastSocketObserver::~MockCastSocketObserver() {} MockCastSocketObserver::~MockCastSocketObserver() {}
MockCastSocketService::MockCastSocketService() {}
MockCastSocketService::~MockCastSocketService() {}
MockCastSocket::MockCastSocket() MockCastSocket::MockCastSocket()
: channel_id_(0), : channel_id_(0),
error_state_(ChannelError::NONE), error_state_(ChannelError::NONE),
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include "base/macros.h" #include "base/macros.h"
#include "base/threading/thread_task_runner_handle.h" #include "base/threading/thread_task_runner_handle.h"
#include "components/cast_channel/cast_socket.h" #include "components/cast_channel/cast_socket.h"
#include "components/cast_channel/cast_socket_service.h"
#include "components/cast_channel/cast_transport.h" #include "components/cast_channel/cast_transport.h"
#include "components/cast_channel/proto/cast_channel.pb.h" #include "components/cast_channel/proto/cast_channel.pb.h"
#include "net/base/ip_endpoint.h" #include "net/base/ip_endpoint.h"
...@@ -64,6 +65,19 @@ class MockCastSocketObserver : public CastSocket::Observer { ...@@ -64,6 +65,19 @@ class MockCastSocketObserver : public CastSocket::Observer {
void(const CastSocket& socket, const CastMessage& message)); void(const CastSocket& socket, const CastMessage& message));
}; };
class MockCastSocketService : public CastSocketService {
public:
MockCastSocketService();
~MockCastSocketService() override;
MOCK_METHOD4(OpenSocket,
int(const net::IPEndPoint& ip_endpoint,
net::NetLog* net_log,
const CastSocket::OnOpenCallback& open_cb,
CastSocket::Observer* observer));
MOCK_CONST_METHOD1(GetSocket, CastSocket*(int channel_id));
};
class MockCastSocket : public CastSocket { class MockCastSocket : public CastSocket {
public: public:
MockCastSocket(); MockCastSocket();
...@@ -73,6 +87,7 @@ class MockCastSocket : public CastSocket { ...@@ -73,6 +87,7 @@ class MockCastSocket : public CastSocket {
MOCK_METHOD1(Close, void(const net::CompletionCallback& callback)); MOCK_METHOD1(Close, void(const net::CompletionCallback& callback));
MOCK_CONST_METHOD0(ready_state, ReadyState()); MOCK_CONST_METHOD0(ready_state, ReadyState());
MOCK_METHOD1(AddObserver, void(Observer* observer)); MOCK_METHOD1(AddObserver, void(Observer* observer));
MOCK_METHOD1(RemoveObserver, void(Observer* observer));
const net::IPEndPoint& ip_endpoint() const override { return ip_endpoint_; } const net::IPEndPoint& ip_endpoint() const override { return ip_endpoint_; }
void SetIPEndpoint(const net::IPEndPoint& ip_endpoint) { void SetIPEndpoint(const net::IPEndPoint& ip_endpoint) {
......
...@@ -22,7 +22,6 @@ ...@@ -22,7 +22,6 @@
#include "components/cast_channel/cast_message_util.h" #include "components/cast_channel/cast_message_util.h"
#include "components/cast_channel/cast_socket.h" #include "components/cast_channel/cast_socket.h"
#include "components/cast_channel/cast_socket_service.h" #include "components/cast_channel/cast_socket_service.h"
#include "components/cast_channel/cast_socket_service_factory.h"
#include "components/cast_channel/keep_alive_delegate.h" #include "components/cast_channel/keep_alive_delegate.h"
#include "components/cast_channel/logger.h" #include "components/cast_channel/logger.h"
#include "components/cast_channel/proto/cast_channel.pb.h" #include "components/cast_channel/proto/cast_channel.pb.h"
...@@ -137,6 +136,18 @@ void CastChannelAPI::SendEvent(const std::string& extension_id, ...@@ -137,6 +136,18 @@ void CastChannelAPI::SendEvent(const std::string& extension_id,
} }
} }
cast_channel::CastSocket::Observer* CastChannelAPI::GetObserver(
const std::string& extension_id,
scoped_refptr<cast_channel::Logger> logger) {
DCHECK_CURRENTLY_ON(BrowserThread::IO);
if (!observer_) {
observer_.reset(new CastMessageHandler(
base::Bind(&CastChannelAPI::SendEvent, this->AsWeakPtr(), extension_id),
logger));
}
return observer_.get();
}
static base::LazyInstance< static base::LazyInstance<
BrowserContextKeyedAPIFactory<CastChannelAPI>>::DestructorAtExit g_factory = BrowserContextKeyedAPIFactory<CastChannelAPI>>::DestructorAtExit g_factory =
LAZY_INSTANCE_INITIALIZER; LAZY_INSTANCE_INITIALIZER;
...@@ -168,9 +179,7 @@ CastChannelAsyncApiFunction::CastChannelAsyncApiFunction() ...@@ -168,9 +179,7 @@ CastChannelAsyncApiFunction::CastChannelAsyncApiFunction()
CastChannelAsyncApiFunction::~CastChannelAsyncApiFunction() { } CastChannelAsyncApiFunction::~CastChannelAsyncApiFunction() { }
bool CastChannelAsyncApiFunction::PrePrepare() { bool CastChannelAsyncApiFunction::PrePrepare() {
cast_socket_service_ = cast_socket_service_ = cast_channel::CastSocketService::GetInstance();
cast_channel::CastSocketServiceFactory::GetForBrowserContext(
browser_context());
DCHECK(cast_socket_service_); DCHECK(cast_socket_service_);
return true; return true;
} }
...@@ -277,14 +286,8 @@ void CastChannelOpenFunction::AsyncWorkStart() { ...@@ -277,14 +286,8 @@ void CastChannelOpenFunction::AsyncWorkStart() {
if (test_socket.get()) if (test_socket.get())
cast_socket_service_->SetSocketForTest(std::move(test_socket)); cast_socket_service_->SetSocketForTest(std::move(test_socket));
auto* observer = cast_socket_service_->GetObserver(extension_->id()); auto* observer =
if (!observer) { api_->GetObserver(extension_->id(), cast_socket_service_->GetLogger());
observer = cast_socket_service_->AddObserver(
extension_->id(), base::MakeUnique<CastMessageHandler>(
base::Bind(&CastChannelAPI::SendEvent,
api_->AsWeakPtr(), extension_->id()),
cast_socket_service_->GetLogger()));
}
cast_socket_service_->OpenSocket( cast_socket_service_->OpenSocket(
*ip_endpoint_, ExtensionsBrowserClient::Get()->GetNetLog(), *ip_endpoint_, ExtensionsBrowserClient::Get()->GetNetLog(),
...@@ -422,20 +425,23 @@ void CastChannelCloseFunction::OnClose(int result) { ...@@ -422,20 +425,23 @@ void CastChannelCloseFunction::OnClose(int result) {
AsyncWorkCompleted(); AsyncWorkCompleted();
} }
CastChannelOpenFunction::CastMessageHandler::CastMessageHandler( CastChannelAPI::CastMessageHandler::CastMessageHandler(
const EventDispatchCallback& ui_dispatch_cb, const EventDispatchCallback& ui_dispatch_cb,
scoped_refptr<Logger> logger) scoped_refptr<Logger> logger)
: ui_dispatch_cb_(ui_dispatch_cb), logger_(logger) { : ui_dispatch_cb_(ui_dispatch_cb), logger_(logger) {
DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
DCHECK(logger_); DCHECK(logger_);
} }
CastChannelOpenFunction::CastMessageHandler::~CastMessageHandler() { CastChannelAPI::CastMessageHandler::~CastMessageHandler() {
DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
cast_channel::CastSocketService::GetInstance()->RemoveObserver(this);
} }
void CastChannelOpenFunction::CastMessageHandler::OnError( void CastChannelAPI::CastMessageHandler::OnError(
const cast_channel::CastSocket& socket, const cast_channel::CastSocket& socket,
ChannelError error_state) { ChannelError error_state) {
DCHECK_CURRENTLY_ON(BrowserThread::IO); DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
ChannelInfo channel_info; ChannelInfo channel_info;
FillChannelInfo(socket, &channel_info); FillChannelInfo(socket, &channel_info);
...@@ -453,10 +459,10 @@ void CastChannelOpenFunction::CastMessageHandler::OnError( ...@@ -453,10 +459,10 @@ void CastChannelOpenFunction::CastMessageHandler::OnError(
base::Bind(ui_dispatch_cb_, base::Passed(std::move(event)))); base::Bind(ui_dispatch_cb_, base::Passed(std::move(event))));
} }
void CastChannelOpenFunction::CastMessageHandler::OnMessage( void CastChannelAPI::CastMessageHandler::OnMessage(
const cast_channel::CastSocket& socket, const cast_channel::CastSocket& socket,
const CastMessage& message) { const CastMessage& message) {
DCHECK_CURRENTLY_ON(BrowserThread::IO); DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
MessageInfo message_info; MessageInfo message_info;
CastMessageToMessageInfo(message, &message_info); CastMessageToMessageInfo(message, &message_info);
......
...@@ -60,11 +60,49 @@ class CastChannelAPI : public BrowserContextKeyedAPI, ...@@ -60,11 +60,49 @@ class CastChannelAPI : public BrowserContextKeyedAPI,
// Sends an event to the extension's EventRouter, if it exists. // Sends an event to the extension's EventRouter, if it exists.
void SendEvent(const std::string& extension_id, std::unique_ptr<Event> event); void SendEvent(const std::string& extension_id, std::unique_ptr<Event> event);
// Registers |extension_id| with |observer_| and returns |observer_|.
cast_channel::CastSocket::Observer* GetObserver(
const std::string& extension_id,
scoped_refptr<cast_channel::Logger> logger);
private: private:
friend class BrowserContextKeyedAPIFactory<CastChannelAPI>; friend class BrowserContextKeyedAPIFactory<CastChannelAPI>;
friend class ::CastChannelAPITest; friend class ::CastChannelAPITest;
friend class CastTransportDelegate; friend class CastTransportDelegate;
// Defines a callback used to send events to the extension's
// EventRouter.
// Parameter #0 is a unique pointer to the event payload.
using EventDispatchCallback = base::Callback<void(std::unique_ptr<Event>)>;
// Receives incoming messages and errors and provides additional API context.
class CastMessageHandler : public cast_channel::CastSocket::Observer {
public:
CastMessageHandler(const EventDispatchCallback& ui_dispatch_cb,
scoped_refptr<cast_channel::Logger> logger);
~CastMessageHandler() override;
// CastSocket::Observer implementation.
void OnError(const cast_channel::CastSocket& socket,
cast_channel::ChannelError error_state) override;
void OnMessage(const cast_channel::CastSocket& socket,
const cast_channel::CastMessage& message) override;
void RegisterExtensionId(const std::string& extension_id);
private:
// Callback for sending events to the extension.
// Should be bound to a weak pointer, to prevent any use-after-free
// conditions.
EventDispatchCallback const ui_dispatch_cb_;
// Logger object for reporting error details.
scoped_refptr<cast_channel::Logger> logger_;
THREAD_CHECKER(thread_checker_);
DISALLOW_COPY_AND_ASSIGN(CastMessageHandler);
};
~CastChannelAPI() override; ~CastChannelAPI() override;
// BrowserContextKeyedAPI implementation. // BrowserContextKeyedAPI implementation.
...@@ -72,6 +110,9 @@ class CastChannelAPI : public BrowserContextKeyedAPI, ...@@ -72,6 +110,9 @@ class CastChannelAPI : public BrowserContextKeyedAPI,
content::BrowserContext* const browser_context_; content::BrowserContext* const browser_context_;
std::unique_ptr<cast_channel::CastSocket> socket_for_test_; std::unique_ptr<cast_channel::CastSocket> socket_for_test_;
// Created and destroyed on the IO thread.
std::unique_ptr<CastMessageHandler, content::BrowserThread::DeleteOnIOThread>
observer_;
DISALLOW_COPY_AND_ASSIGN(CastChannelAPI); DISALLOW_COPY_AND_ASSIGN(CastChannelAPI);
}; };
...@@ -96,8 +137,9 @@ class CastChannelAsyncApiFunction : public AsyncApiFunction { ...@@ -96,8 +137,9 @@ class CastChannelAsyncApiFunction : public AsyncApiFunction {
void SetResultFromError(int channel_id, void SetResultFromError(int channel_id,
api::cast_channel::ChannelError error); api::cast_channel::ChannelError error);
// Manages creating and removing Cast sockets. // Raw pointer of leaky singleton CastSocketService, which manages creating
scoped_refptr<cast_channel::CastSocketService> cast_socket_service_; // and removing Cast sockets.
cast_channel::CastSocketService* cast_socket_service_;
private: private:
// Sets the function result from |channel_info|. // Sets the function result from |channel_info|.
...@@ -120,35 +162,6 @@ class CastChannelOpenFunction : public CastChannelAsyncApiFunction { ...@@ -120,35 +162,6 @@ class CastChannelOpenFunction : public CastChannelAsyncApiFunction {
private: private:
DECLARE_EXTENSION_FUNCTION("cast.channel.open", CAST_CHANNEL_OPEN) DECLARE_EXTENSION_FUNCTION("cast.channel.open", CAST_CHANNEL_OPEN)
// Defines a callback used to send events to the extension's
// EventRouter.
// Parameter #0 is a scoped pointer to the event payload.
using EventDispatchCallback = base::Callback<void(std::unique_ptr<Event>)>;
// Receives incoming messages and errors and provides additional API context.
class CastMessageHandler : public cast_channel::CastSocket::Observer {
public:
CastMessageHandler(const EventDispatchCallback& ui_dispatch_cb,
scoped_refptr<cast_channel::Logger> logger);
~CastMessageHandler() override;
// CastSocket::Observer implementation.
void OnError(const cast_channel::CastSocket& socket,
cast_channel::ChannelError error_state) override;
void OnMessage(const cast_channel::CastSocket& socket,
const cast_channel::CastMessage& message) override;
private:
// Callback for sending events to the extension.
// Should be bound to a weak pointer, to prevent any use-after-free
// conditions.
EventDispatchCallback const ui_dispatch_cb_;
// Logger object for reporting error details.
scoped_refptr<cast_channel::Logger> logger_;
DISALLOW_COPY_AND_ASSIGN(CastMessageHandler);
};
// Validates that |connect_info| represents a valid IP end point and returns a // Validates that |connect_info| represents a valid IP end point and returns a
// new IPEndPoint if so. Otherwise returns nullptr. // new IPEndPoint if so. Otherwise returns nullptr.
static net::IPEndPoint* ParseConnectInfo( static net::IPEndPoint* ParseConnectInfo(
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
#include "chrome/browser/ui/browser.h" #include "chrome/browser/ui/browser.h"
#include "components/cast_channel/cast_socket.h" #include "components/cast_channel/cast_socket.h"
#include "components/cast_channel/cast_socket_service.h" #include "components/cast_channel/cast_socket_service.h"
#include "components/cast_channel/cast_socket_service_factory.h"
#include "components/cast_channel/cast_test_util.h" #include "components/cast_channel/cast_test_util.h"
#include "components/cast_channel/logger.h" #include "components/cast_channel/logger.h"
#include "components/cast_channel/proto/cast_channel.pb.h" #include "components/cast_channel/proto/cast_channel.pb.h"
...@@ -145,8 +144,11 @@ class CastChannelAPITest : public ExtensionApiTest { ...@@ -145,8 +144,11 @@ class CastChannelAPITest : public ExtensionApiTest {
.WillOnce(Return(ReadyState::OPEN)) .WillOnce(Return(ReadyState::OPEN))
.RetiresOnSaturation(); .RetiresOnSaturation();
EXPECT_CALL(*mock_cast_socket_, ready_state()) EXPECT_CALL(*mock_cast_socket_, ready_state())
.WillOnce(Return(ReadyState::CLOSED)); .Times(2)
.WillRepeatedly(Return(ReadyState::CLOSED));
} }
EXPECT_CALL(*mock_cast_socket_, Close(_))
.WillOnce(InvokeCompletionCallback<0>(net::OK));
} }
extensions::CastChannelAPI* GetApi() { extensions::CastChannelAPI* GetApi() {
...@@ -154,9 +156,7 @@ class CastChannelAPITest : public ExtensionApiTest { ...@@ -154,9 +156,7 @@ class CastChannelAPITest : public ExtensionApiTest {
} }
cast_channel::CastSocketService* GetCastSocketService() { cast_channel::CastSocketService* GetCastSocketService() {
return cast_channel::CastSocketServiceFactory::GetForBrowserContext( return cast_channel::CastSocketService::GetInstance();
profile())
.get();
} }
// Logs some bogus error details and calls the OnError handler. // Logs some bogus error details and calls the OnError handler.
...@@ -224,7 +224,7 @@ ACTION_P2(InvokeObserverOnError, api_test, cast_socket_service) { ...@@ -224,7 +224,7 @@ ACTION_P2(InvokeObserverOnError, api_test, cast_socket_service) {
content::BrowserThread::PostTask( content::BrowserThread::PostTask(
content::BrowserThread::IO, FROM_HERE, content::BrowserThread::IO, FROM_HERE,
base::Bind(&CastChannelAPITest::DoCallOnError, base::Unretained(api_test), base::Bind(&CastChannelAPITest::DoCallOnError, base::Unretained(api_test),
base::RetainedRef(cast_socket_service))); base::Unretained(cast_socket_service)));
} }
// TODO(kmarshall): Win Dbg has a workaround that makes RunExtensionSubtest // TODO(kmarshall): Win Dbg has a workaround that makes RunExtensionSubtest
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment