Commit b1c22e2d authored by Bin Zhao's avatar Bin Zhao Committed by Commit Bot

[cast_channel] clean up CastSocketImpl::OpenSocket() parameters

Create a CastSocketConfig struct to hold cast socket related settings and clean up CastSocketImpl::OpenSocket() parameters

Bug: 734855
Change-Id: I71a9847079a57f5f5c248c3b41354ac5f37bff66
Reviewed-on: https://chromium-review.googlesource.com/575793Reviewed-by: default avatarDerek Cheng <imcheng@chromium.org>
Reviewed-by: default avatarmark a. foltz <mfoltz@chromium.org>
Commit-Queue: Bin Zhao <zhaobin@chromium.org>
Cr-Commit-Position: refs/heads/master@{#488553}
parent 6726378f
...@@ -51,8 +51,8 @@ ...@@ -51,8 +51,8 @@
// Helper for logging data with remote host IP and authentication state. // Helper for logging data with remote host IP and authentication state.
// Assumes |ip_endpoint_| of type net::IPEndPoint and |channel_auth_| of enum // Assumes |ip_endpoint_| of type net::IPEndPoint and |channel_auth_| of enum
// type ChannelAuthType are available in the current scope. // type ChannelAuthType are available in the current scope.
#define CONNECTION_INFO() \ #define CONNECTION_INFO() \
"[" << ip_endpoint_.ToString() << ", auth=SSL_VERIFIED" \ "[" << open_params_.ip_endpoint.ToString() << ", auth=SSL_VERIFIED" \
<< "] " << "] "
#define VLOG_WITH_CONNECTION(level) VLOG(level) << CONNECTION_INFO() #define VLOG_WITH_CONNECTION(level) VLOG(level) << CONNECTION_INFO()
#define LOG_WITH_CONNECTION(level) LOG(level) << CONNECTION_INFO() #define LOG_WITH_CONNECTION(level) LOG(level) << CONNECTION_INFO()
...@@ -86,49 +86,28 @@ class FakeCertVerifier : public net::CertVerifier { ...@@ -86,49 +86,28 @@ class FakeCertVerifier : public net::CertVerifier {
} // namespace } // namespace
CastSocketImpl::CastSocketImpl(const net::IPEndPoint& ip_endpoint, CastSocketImpl::CastSocketImpl(const CastSocketOpenParams& open_params,
net::NetLog* net_log, const scoped_refptr<Logger>& logger)
base::TimeDelta timeout, : CastSocketImpl(open_params, logger, AuthContext::Create()) {}
base::TimeDelta liveness_timeout,
base::TimeDelta ping_interval, CastSocketImpl::CastSocketImpl(const CastSocketOpenParams& open_params,
const scoped_refptr<Logger>& logger,
uint64_t device_capabilities)
: CastSocketImpl(ip_endpoint,
net_log,
timeout,
liveness_timeout,
ping_interval,
logger,
device_capabilities,
AuthContext::Create()) {}
CastSocketImpl::CastSocketImpl(const net::IPEndPoint& ip_endpoint,
net::NetLog* net_log,
base::TimeDelta timeout,
base::TimeDelta liveness_timeout,
base::TimeDelta ping_interval,
const scoped_refptr<Logger>& logger, const scoped_refptr<Logger>& logger,
uint64_t device_capabilities,
const AuthContext& auth_context) const AuthContext& auth_context)
: channel_id_(0), : channel_id_(0),
ip_endpoint_(ip_endpoint), open_params_(open_params),
net_log_(net_log),
liveness_timeout_(liveness_timeout),
ping_interval_(ping_interval),
logger_(logger), logger_(logger),
auth_context_(auth_context), auth_context_(auth_context),
connect_timeout_(timeout),
connect_timeout_timer_(new base::OneShotTimer), connect_timeout_timer_(new base::OneShotTimer),
is_canceled_(false), is_canceled_(false),
device_capabilities_(device_capabilities),
audio_only_(false), audio_only_(false),
connect_state_(ConnectionState::START_CONNECT), connect_state_(ConnectionState::START_CONNECT),
error_state_(ChannelError::NONE), error_state_(ChannelError::NONE),
ready_state_(ReadyState::NONE), ready_state_(ReadyState::NONE),
auth_delegate_(nullptr) { auth_delegate_(nullptr) {
DCHECK(net_log_); DCHECK(open_params.ip_endpoint.address().IsValid());
DCHECK(open_params_.net_log);
net_log_source_.type = net::NetLogSourceType::SOCKET; net_log_source_.type = net::NetLogSourceType::SOCKET;
net_log_source_.id = net_log_->NextID(); net_log_source_.id = open_params_.net_log->NextID();
} }
CastSocketImpl::~CastSocketImpl() { CastSocketImpl::~CastSocketImpl() {
...@@ -150,7 +129,7 @@ ChannelError CastSocketImpl::error_state() const { ...@@ -150,7 +129,7 @@ ChannelError CastSocketImpl::error_state() const {
} }
const net::IPEndPoint& CastSocketImpl::ip_endpoint() const { const net::IPEndPoint& CastSocketImpl::ip_endpoint() const {
return ip_endpoint_; return open_params_.ip_endpoint;
} }
int CastSocketImpl::id() const { int CastSocketImpl::id() const {
...@@ -162,7 +141,7 @@ void CastSocketImpl::set_id(int id) { ...@@ -162,7 +141,7 @@ void CastSocketImpl::set_id(int id) {
} }
bool CastSocketImpl::keep_alive() const { bool CastSocketImpl::keep_alive() const {
return liveness_timeout_ > base::TimeDelta(); return open_params_.liveness_timeout > base::TimeDelta();
} }
bool CastSocketImpl::audio_only() const { bool CastSocketImpl::audio_only() const {
...@@ -170,9 +149,9 @@ bool CastSocketImpl::audio_only() const { ...@@ -170,9 +149,9 @@ bool CastSocketImpl::audio_only() const {
} }
std::unique_ptr<net::TCPClientSocket> CastSocketImpl::CreateTcpSocket() { std::unique_ptr<net::TCPClientSocket> CastSocketImpl::CreateTcpSocket() {
net::AddressList addresses(ip_endpoint_); net::AddressList addresses(open_params_.ip_endpoint);
return std::unique_ptr<net::TCPClientSocket>( return std::unique_ptr<net::TCPClientSocket>(new net::TCPClientSocket(
new net::TCPClientSocket(addresses, nullptr, net_log_, net_log_source_)); addresses, nullptr, open_params_.net_log, net_log_source_));
// Options cannot be set on the TCPClientSocket yet, because the // Options cannot be set on the TCPClientSocket yet, because the
// underlying platform socket will not be created until Bind() // underlying platform socket will not be created until Bind()
// or Connect() is called. // or Connect() is called.
...@@ -197,7 +176,7 @@ std::unique_ptr<net::SSLClientSocket> CastSocketImpl::CreateSslSocket( ...@@ -197,7 +176,7 @@ std::unique_ptr<net::SSLClientSocket> CastSocketImpl::CreateSslSocket(
new net::ClientSocketHandle); new net::ClientSocketHandle);
connection->SetSocket(std::move(socket)); connection->SetSocket(std::move(socket));
net::HostPortPair host_and_port = net::HostPortPair host_and_port =
net::HostPortPair::FromIPEndPoint(ip_endpoint_); net::HostPortPair::FromIPEndPoint(open_params_.ip_endpoint);
return net::ClientSocketFactory::GetDefaultFactory()->CreateSSLClientSocket( return net::ClientSocketFactory::GetDefaultFactory()->CreateSSLClientSocket(
std::move(connection), host_and_port, ssl_config, context); std::move(connection), host_and_port, ssl_config, context);
...@@ -213,8 +192,8 @@ scoped_refptr<net::X509Certificate> CastSocketImpl::ExtractPeerCert() { ...@@ -213,8 +192,8 @@ scoped_refptr<net::X509Certificate> CastSocketImpl::ExtractPeerCert() {
bool CastSocketImpl::VerifyChannelPolicy(const AuthResult& result) { bool CastSocketImpl::VerifyChannelPolicy(const AuthResult& result) {
audio_only_ = (result.channel_policies & AuthResult::POLICY_AUDIO_ONLY) != 0; audio_only_ = (result.channel_policies & AuthResult::POLICY_AUDIO_ONLY) != 0;
if (audio_only_ && if (audio_only_ && (open_params_.device_capabilities &
(device_capabilities_ & CastDeviceCapability::VIDEO_OUT) != 0) { CastDeviceCapability::VIDEO_OUT) != 0) {
LOG_WITH_CONNECTION(ERROR) LOG_WITH_CONNECTION(ERROR)
<< "Audio only channel policy enforced for video out capable device"; << "Audio only channel policy enforced for video out capable device";
return false; return false;
...@@ -275,11 +254,11 @@ void CastSocketImpl::Connect() { ...@@ -275,11 +254,11 @@ void CastSocketImpl::Connect() {
SetConnectState(ConnectionState::TCP_CONNECT); SetConnectState(ConnectionState::TCP_CONNECT);
// Set up connection timeout. // Set up connection timeout.
if (connect_timeout_.InMicroseconds() > 0) { if (open_params_.connect_timeout.InMicroseconds() > 0) {
DCHECK(connect_timeout_callback_.IsCancelled()); DCHECK(connect_timeout_callback_.IsCancelled());
connect_timeout_callback_.Reset( connect_timeout_callback_.Reset(
base::Bind(&CastSocketImpl::OnConnectTimeout, base::Unretained(this))); base::Bind(&CastSocketImpl::OnConnectTimeout, base::Unretained(this)));
GetTimer()->Start(FROM_HERE, connect_timeout_, GetTimer()->Start(FROM_HERE, open_params_.connect_timeout,
connect_timeout_callback_.callback()); connect_timeout_callback_.callback());
} }
...@@ -447,8 +426,8 @@ int CastSocketImpl::DoSslConnectComplete(int result) { ...@@ -447,8 +426,8 @@ int CastSocketImpl::DoSslConnectComplete(int result) {
if (!transport_.get()) { if (!transport_.get()) {
// Create a channel transport if one wasn't already set (e.g. by test // Create a channel transport if one wasn't already set (e.g. by test
// code). // code).
transport_.reset(new CastTransportImpl(this->socket_.get(), channel_id_, transport_.reset(new CastTransportImpl(
ip_endpoint_, logger_)); this->socket_.get(), channel_id_, open_params_.ip_endpoint, logger_));
} }
auth_delegate_ = new AuthTransportDelegate(this); auth_delegate_ = new AuthTransportDelegate(this);
transport_->SetReadDelegate(base::WrapUnique(auth_delegate_)); transport_->SetReadDelegate(base::WrapUnique(auth_delegate_));
...@@ -562,9 +541,9 @@ void CastSocketImpl::DoConnectCallback() { ...@@ -562,9 +541,9 @@ void CastSocketImpl::DoConnectCallback() {
if (error_state_ == ChannelError::NONE) { if (error_state_ == ChannelError::NONE) {
SetReadyState(ReadyState::OPEN); SetReadyState(ReadyState::OPEN);
if (keep_alive()) { if (keep_alive()) {
auto* keep_alive_delegate = auto* keep_alive_delegate = new KeepAliveDelegate(
new KeepAliveDelegate(this, logger_, std::move(delegate_), this, logger_, std::move(delegate_), open_params_.ping_interval,
ping_interval_, liveness_timeout_); open_params_.liveness_timeout);
delegate_.reset(keep_alive_delegate); delegate_.reset(keep_alive_delegate);
} }
transport_->SetReadDelegate(std::move(delegate_)); transport_->SetReadDelegate(std::move(delegate_));
...@@ -655,5 +634,26 @@ void CastSocketImpl::CastSocketMessageDelegate::OnMessage( ...@@ -655,5 +634,26 @@ void CastSocketImpl::CastSocketMessageDelegate::OnMessage(
void CastSocketImpl::CastSocketMessageDelegate::Start() {} void CastSocketImpl::CastSocketMessageDelegate::Start() {}
CastSocketOpenParams::CastSocketOpenParams(const net::IPEndPoint& ip_endpoint,
net::NetLog* net_log,
base::TimeDelta connect_timeout)
: ip_endpoint(ip_endpoint),
net_log(net_log),
connect_timeout(connect_timeout),
device_capabilities(cast_channel::CastDeviceCapability::NONE) {}
CastSocketOpenParams::CastSocketOpenParams(const net::IPEndPoint& ip_endpoint,
net::NetLog* net_log,
base::TimeDelta connect_timeout,
base::TimeDelta liveness_timeout,
base::TimeDelta ping_interval,
uint64_t device_capabilities)
: ip_endpoint(ip_endpoint),
net_log(net_log),
connect_timeout(connect_timeout),
liveness_timeout(liveness_timeout),
ping_interval(ping_interval),
device_capabilities(device_capabilities) {}
} // namespace cast_channel } // namespace cast_channel
#undef VLOG_WITH_CONNECTION #undef VLOG_WITH_CONNECTION
...@@ -134,6 +134,47 @@ class CastSocket { ...@@ -134,6 +134,47 @@ class CastSocket {
virtual void RemoveObserver(Observer* observer) = 0; virtual void RemoveObserver(Observer* observer) = 0;
}; };
// Holds parameters necessary to open a Cast channel (CastSocket) to a Cast
// device.
struct CastSocketOpenParams {
// IP endpoint of the Cast device.
net::IPEndPoint ip_endpoint;
// Log of socket events.
net::NetLog* net_log;
// Connection timeout interval. If this value is not set, Cast socket will not
// report CONNECT_TIMEOUT error and may hang when connecting to a Cast device.
base::TimeDelta connect_timeout;
// Amount of idle time to wait before disconnecting. Cast socket will ping
// Cast device periodically at |ping_interval| to check liveness. If it does
// not receive response in |liveness_timeout|, it reports PING_TIMEOUT error.
// |liveness_timeout| should always be larger than or equal to
// |ping_interval|.
// If this value is not set, there is not periodic ping and Cast socket is
// always assumed alive.
base::TimeDelta liveness_timeout;
// Amount of idle time to wait before pinging the Cast device. See comments
// for |liveness_timeout|.
base::TimeDelta ping_interval;
// A bit vector representing the capabilities of the sink. The values are
// defined in components/cast_channel/cast_socket.h.
uint64_t device_capabilities;
CastSocketOpenParams(const net::IPEndPoint& ip_endpoint,
net::NetLog* net_log,
base::TimeDelta connect_timeout);
CastSocketOpenParams(const net::IPEndPoint& ip_endpoint,
net::NetLog* net_log,
base::TimeDelta connect_timeout,
base::TimeDelta liveness_timeout,
base::TimeDelta ping_interval,
uint64_t device_capabilities);
};
// This class implements a channel between Chrome and a Cast device using a TCP // This class implements a channel between Chrome and a Cast device using a TCP
// socket with SSL. The channel may authenticate that the receiver is a genuine // socket with SSL. The channel may authenticate that the receiver is a genuine
// Cast device. All CastSocketImpl objects must be used only on the IO thread. // Cast device. All CastSocketImpl objects must be used only on the IO thread.
...@@ -142,31 +183,11 @@ class CastSocket { ...@@ -142,31 +183,11 @@ class CastSocket {
// code. // code.
class CastSocketImpl : public CastSocket { class CastSocketImpl : public CastSocket {
public: public:
// Creates a new CastSocket that connects to |ip_endpoint|. CastSocketImpl(const CastSocketOpenParams& open_params,
// Parameters: const scoped_refptr<Logger>& logger);
// |ip_endpoint|: IP address of the remote host.
// |net_log|: Log of socket events. CastSocketImpl(const CastSocketOpenParams& open_params,
// |connect_timeout|: Connection timeout interval.
// |liveness_timeout|: Amount of idle time to wait before disconnecting.
// |ping_interval|: Amount of idle time to wait before pinging the receiver.
// |logger|: Log of cast channel events.
CastSocketImpl(const net::IPEndPoint& ip_endpoint,
net::NetLog* net_log,
base::TimeDelta connect_timeout,
base::TimeDelta liveness_timeout,
base::TimeDelta ping_interval,
const scoped_refptr<Logger>& logger,
uint64_t device_capabilities);
// For test-only.
// This constructor allows for setting a custom AuthContext.
CastSocketImpl(const net::IPEndPoint& ip_endpoint,
net::NetLog* net_log,
base::TimeDelta connect_timeout,
base::TimeDelta liveness_timeout,
base::TimeDelta ping_interval,
const scoped_refptr<Logger>& logger, const scoped_refptr<Logger>& logger,
uint64_t device_capabilities,
const AuthContext& auth_context); const AuthContext& auth_context);
// Ensures that the socket is closed. // Ensures that the socket is closed.
...@@ -313,20 +334,12 @@ class CastSocketImpl : public CastSocket { ...@@ -313,20 +334,12 @@ class CastSocketImpl : public CastSocket {
// The id of the channel. // The id of the channel.
int channel_id_; int channel_id_;
// The IP endpoint that the the channel is connected to.
net::IPEndPoint ip_endpoint_;
// The NetLog for this service.
net::NetLog* net_log_;
// The NetLog source for this service. // The NetLog source for this service.
net::NetLogSource net_log_source_; net::NetLogSource net_log_source_;
// Amount of idle time to wait before disconnecting. If |liveness_timeout_| is // Cast socket related settings.
// set, wraps |delegate_| with a KeepAliveDelegate. CastSocketOpenParams open_params_;
base::TimeDelta liveness_timeout_;
// Amount of idle time to wait before pinging the receiver, used to create
// KeepAliveDelegate.
base::TimeDelta ping_interval_;
// Shared logging object, used to log CastSocket events for diagnostics. // Shared logging object, used to log CastSocket events for diagnostics.
scoped_refptr<Logger> logger_; scoped_refptr<Logger> logger_;
...@@ -361,9 +374,6 @@ class CastSocketImpl : public CastSocket { ...@@ -361,9 +374,6 @@ class CastSocketImpl : public CastSocket {
// Callback invoked by |connect_timeout_timer_| to cancel the connection. // Callback invoked by |connect_timeout_timer_| to cancel the connection.
base::CancelableClosure connect_timeout_callback_; base::CancelableClosure connect_timeout_callback_;
// Duration to wait before timing out.
base::TimeDelta connect_timeout_;
// Timer invoked when the connection has timed out. // Timer invoked when the connection has timed out.
std::unique_ptr<base::Timer> connect_timeout_timer_; std::unique_ptr<base::Timer> connect_timeout_timer_;
...@@ -371,9 +381,6 @@ class CastSocketImpl : public CastSocket { ...@@ -371,9 +381,6 @@ class CastSocketImpl : public CastSocket {
// canceled. // canceled.
bool is_canceled_; bool is_canceled_;
// Capabilities declared by the cast device.
uint64_t device_capabilities_;
// Whether the channel is audio only as identified by the device // Whether the channel is audio only as identified by the device
// certificate during channel authentication. // certificate during channel authentication.
bool audio_only_; bool audio_only_;
......
...@@ -88,26 +88,19 @@ CastSocket* CastSocketService::GetSocket( ...@@ -88,26 +88,19 @@ CastSocket* CastSocketService::GetSocket(
return it == sockets_.end() ? nullptr : it->second.get(); return it == sockets_.end() ? nullptr : it->second.get();
} }
int CastSocketService::OpenSocket(const net::IPEndPoint& ip_endpoint, int CastSocketService::OpenSocket(const CastSocketOpenParams& open_params,
net::NetLog* net_log,
const base::TimeDelta& connect_timeout,
const base::TimeDelta& liveness_timeout,
const base::TimeDelta& ping_interval,
uint64_t device_capabilities,
CastSocket::OnOpenCallback open_cb, CastSocket::OnOpenCallback open_cb,
CastSocket::Observer* observer) { CastSocket::Observer* observer) {
DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
DCHECK(observer); DCHECK(observer);
auto* socket = GetSocket(ip_endpoint); auto* socket = GetSocket(open_params.ip_endpoint);
if (!socket) { if (!socket) {
// If cast socket does not exist. // If cast socket does not exist.
if (socket_for_test_) { if (socket_for_test_) {
socket = AddSocket(std::move(socket_for_test_)); socket = AddSocket(std::move(socket_for_test_));
} else { } else {
socket = new CastSocketImpl(ip_endpoint, net_log, connect_timeout, socket = new CastSocketImpl(open_params, logger_);
liveness_timeout, ping_interval, logger_,
device_capabilities);
AddSocket(base::WrapUnique(socket)); AddSocket(base::WrapUnique(socket));
} }
} }
...@@ -125,9 +118,11 @@ int CastSocketService::OpenSocket(const net::IPEndPoint& ip_endpoint, ...@@ -125,9 +118,11 @@ int CastSocketService::OpenSocket(const net::IPEndPoint& ip_endpoint,
auto ping_interval = base::TimeDelta::FromSeconds(kPingIntervalInSecs); auto ping_interval = base::TimeDelta::FromSeconds(kPingIntervalInSecs);
auto liveness_timeout = auto liveness_timeout =
base::TimeDelta::FromSeconds(kConnectLivenessTimeoutSecs); base::TimeDelta::FromSeconds(kConnectLivenessTimeoutSecs);
return OpenSocket(ip_endpoint, net_log, connect_timeout, liveness_timeout, CastSocketOpenParams open_params(ip_endpoint, net_log, connect_timeout,
ping_interval, CastDeviceCapability::NONE, liveness_timeout, ping_interval,
std::move(open_cb), observer); CastDeviceCapability::NONE);
return OpenSocket(open_params, std::move(open_cb), observer);
} }
void CastSocketService::RemoveObserver(CastSocket::Observer* observer) { void CastSocketService::RemoveObserver(CastSocket::Observer* observer) {
......
...@@ -45,22 +45,11 @@ class CastSocketService { ...@@ -45,22 +45,11 @@ class CastSocketService {
// operation finishes. If cast socket with |ip_endpoint| already exists, // operation finishes. If cast socket with |ip_endpoint| already exists,
// invoke |open_cb| directly with existing socket's channel ID. // invoke |open_cb| directly with existing socket's channel ID.
// Parameters: // Parameters:
// |ip_endpoint|: IP address and port of the remote host. // |open_params|: Parameters necessary to open a Cast channel.
// |net_log|: Log of socket events.
// |connect_timeout|: Connection timeout interval.
// |liveness_timeout|: Liveness timeout for connect calls.
// |ping_interval|: Ping interval.
// |logger|: Log of cast channel events.
// |device_capabilities|: Device capabilities.
// |open_cb|: OnOpenCallback invoked when cast socket is opened. // |open_cb|: OnOpenCallback invoked when cast socket is opened.
// |observer|: Observer handles messages and errors on newly opened socket. // |observer|: Observer handles messages and errors on newly opened socket.
// Does not take ownership of |observer|. // Does not take ownership of |observer|.
int OpenSocket(const net::IPEndPoint& ip_endpoint, int OpenSocket(const CastSocketOpenParams& open_params,
net::NetLog* net_log,
const base::TimeDelta& connect_timeout,
const base::TimeDelta& liveness_timeout,
const base::TimeDelta& ping_interval,
uint64_t device_capabilities,
CastSocket::OnOpenCallback open_cb, CastSocket::OnOpenCallback open_cb,
CastSocket::Observer* observer); CastSocket::Observer* observer);
......
...@@ -177,31 +177,9 @@ class CompleteHandler { ...@@ -177,31 +177,9 @@ class CompleteHandler {
class TestCastSocketBase : public CastSocketImpl { class TestCastSocketBase : public CastSocketImpl {
public: public:
TestCastSocketBase(const net::IPEndPoint& ip_endpoint, TestCastSocketBase(const CastSocketOpenParams& open_params, Logger* logger)
int64_t timeout_ms, : CastSocketImpl(open_params, logger, AuthContext::Create()),
Logger* logger, ip_(open_params.ip_endpoint),
uint64_t device_capabilities)
: TestCastSocketBase(ip_endpoint,
timeout_ms,
logger,
new net::TestNetLog(),
device_capabilities) {}
TestCastSocketBase(const net::IPEndPoint& ip_endpoint,
int64_t timeout_ms,
Logger* logger,
net::TestNetLog* capturing_net_log,
uint64_t device_capabilities)
: CastSocketImpl(ip_endpoint,
capturing_net_log,
base::TimeDelta::FromMilliseconds(timeout_ms),
base::TimeDelta(),
base::TimeDelta(),
logger,
device_capabilities,
AuthContext::Create()),
capturing_net_log_(capturing_net_log),
ip_(ip_endpoint),
extract_cert_result_(true), extract_cert_result_(true),
verify_challenge_result_(true), verify_challenge_result_(true),
verify_challenge_disallow_(false), verify_challenge_disallow_(false),
...@@ -239,7 +217,6 @@ class TestCastSocketBase : public CastSocketImpl { ...@@ -239,7 +217,6 @@ class TestCastSocketBase : public CastSocketImpl {
base::Timer* GetTimer() override { return mock_timer_.get(); } base::Timer* GetTimer() override { return mock_timer_.get(); }
std::unique_ptr<net::TestNetLog> capturing_net_log_;
net::IPEndPoint ip_; net::IPEndPoint ip_;
// Simulated result of peer cert extraction. // Simulated result of peer cert extraction.
bool extract_cert_result_; bool extract_cert_result_;
...@@ -255,15 +232,18 @@ class TestCastSocketBase : public CastSocketImpl { ...@@ -255,15 +232,18 @@ class TestCastSocketBase : public CastSocketImpl {
class MockTestCastSocket : public TestCastSocketBase { class MockTestCastSocket : public TestCastSocketBase {
public: public:
static std::unique_ptr<MockTestCastSocket> CreateSecure( static std::unique_ptr<MockTestCastSocket> CreateSecure(
Logger* logger, const CastSocketOpenParams& open_params,
uint64_t device_capabilities = cast_channel::CastDeviceCapability::NONE) { Logger* logger) {
return std::unique_ptr<MockTestCastSocket>( return std::unique_ptr<MockTestCastSocket>(
new MockTestCastSocket(CreateIPEndPointForTest(), kDistantTimeoutMillis, new MockTestCastSocket(open_params, logger));
logger, device_capabilities));
} }
using TestCastSocketBase::TestCastSocketBase; using TestCastSocketBase::TestCastSocketBase;
MockTestCastSocket(const CastSocketOpenParams& open_params, Logger* logger)
: TestCastSocketBase(open_params, logger),
mock_net_log_(open_params.net_log) {}
~MockTestCastSocket() override {} ~MockTestCastSocket() override {}
void SetupMockTransport() { void SetupMockTransport() {
...@@ -324,9 +304,10 @@ class MockTestCastSocket : public TestCastSocketBase { ...@@ -324,9 +304,10 @@ class MockTestCastSocket : public TestCastSocketBase {
ssl_data_.reset(new net::StaticSocketDataProvider( ssl_data_.reset(new net::StaticSocketDataProvider(
reads_.data(), reads_.size(), writes_.data(), writes_.size())); reads_.data(), reads_.size(), writes_.data(), writes_.size()));
ssl_data_->set_connect_data(*ssl_connect_data_); ssl_data_->set_connect_data(*ssl_connect_data_);
// NOTE: net::MockTCPClientSocket inherits from net::SSLClientSocket !! // NOTE: net::MockTCPClientSocket inherits from net::SSLClientSocket !!
return std::unique_ptr<net::SSLClientSocket>(new net::MockTCPClientSocket( return std::unique_ptr<net::SSLClientSocket>(new net::MockTCPClientSocket(
net::AddressList(), capturing_net_log_.get(), ssl_data_.get())); net::AddressList(), mock_net_log_, ssl_data_.get()));
} }
// Simulated connect data // Simulated connect data
...@@ -339,6 +320,7 @@ class MockTestCastSocket : public TestCastSocketBase { ...@@ -339,6 +320,7 @@ class MockTestCastSocket : public TestCastSocketBase {
// If true, makes TCP connection process stall. For timeout testing. // If true, makes TCP connection process stall. For timeout testing.
bool tcp_unresponsive_ = false; bool tcp_unresponsive_ = false;
MockCastTransport* mock_transport_ = nullptr; MockCastTransport* mock_transport_ = nullptr;
net::NetLog* mock_net_log_ = nullptr;
DISALLOW_COPY_AND_ASSIGN(MockTestCastSocket); DISALLOW_COPY_AND_ASSIGN(MockTestCastSocket);
}; };
...@@ -346,11 +328,10 @@ class MockTestCastSocket : public TestCastSocketBase { ...@@ -346,11 +328,10 @@ class MockTestCastSocket : public TestCastSocketBase {
class SslTestCastSocket : public TestCastSocketBase { class SslTestCastSocket : public TestCastSocketBase {
public: public:
static std::unique_ptr<SslTestCastSocket> CreateSecure( static std::unique_ptr<SslTestCastSocket> CreateSecure(
Logger* logger, const CastSocketOpenParams& open_params,
uint64_t device_capabilities = cast_channel::CastDeviceCapability::NONE) { Logger* logger) {
return std::unique_ptr<SslTestCastSocket>( return std::unique_ptr<SslTestCastSocket>(
new SslTestCastSocket(CreateIPEndPointForTest(), kDistantTimeoutMillis, new SslTestCastSocket(open_params, logger));
logger, device_capabilities));
} }
using TestCastSocketBase::TestCastSocketBase; using TestCastSocketBase::TestCastSocketBase;
...@@ -372,7 +353,12 @@ class CastSocketTestBase : public testing::Test { ...@@ -372,7 +353,12 @@ class CastSocketTestBase : public testing::Test {
CastSocketTestBase() CastSocketTestBase()
: thread_bundle_(content::TestBrowserThreadBundle::IO_MAINLOOP), : thread_bundle_(content::TestBrowserThreadBundle::IO_MAINLOOP),
logger_(new Logger()), logger_(new Logger()),
observer_(new MockCastSocketObserver()) {} observer_(new MockCastSocketObserver()),
capturing_net_log_(new net::TestNetLog()),
socket_open_params_(
CreateIPEndPointForTest(),
capturing_net_log_.get(),
base::TimeDelta::FromMilliseconds(kDistantTimeoutMillis)) {}
~CastSocketTestBase() override {} ~CastSocketTestBase() override {}
void SetUp() override { EXPECT_CALL(*observer_, OnMessage(_, _)).Times(0); } void SetUp() override { EXPECT_CALL(*observer_, OnMessage(_, _)).Times(0); }
...@@ -387,6 +373,8 @@ class CastSocketTestBase : public testing::Test { ...@@ -387,6 +373,8 @@ class CastSocketTestBase : public testing::Test {
Logger* logger_; Logger* logger_;
CompleteHandler handler_; CompleteHandler handler_;
std::unique_ptr<MockCastSocketObserver> observer_; std::unique_ptr<MockCastSocketObserver> observer_;
std::unique_ptr<net::TestNetLog> capturing_net_log_;
CastSocketOpenParams socket_open_params_;
private: private:
DISALLOW_COPY_AND_ASSIGN(CastSocketTestBase); DISALLOW_COPY_AND_ASSIGN(CastSocketTestBase);
...@@ -405,7 +393,7 @@ class MockCastSocketTest : public CastSocketTestBase { ...@@ -405,7 +393,7 @@ class MockCastSocketTest : public CastSocketTestBase {
} }
void CreateCastSocketSecure() { void CreateCastSocketSecure() {
socket_ = MockTestCastSocket::CreateSecure(logger_); socket_ = MockTestCastSocket::CreateSecure(socket_open_params_, logger_);
} }
void HandleAuthHandshake() { void HandleAuthHandshake() {
...@@ -444,7 +432,7 @@ class SslCastSocketTest : public CastSocketTestBase { ...@@ -444,7 +432,7 @@ class SslCastSocketTest : public CastSocketTestBase {
} }
void CreateSockets() { void CreateSockets() {
socket_ = SslTestCastSocket::CreateSecure(logger_); socket_ = SslTestCastSocket::CreateSecure(socket_open_params_, logger_);
server_cert_ = server_cert_ =
net::ImportCertFromFile(GetTestCertsDirectory(), "self_signed.pem"); net::ImportCertFromFile(GetTestCertsDirectory(), "self_signed.pem");
......
...@@ -286,15 +286,18 @@ void CastChannelOpenFunction::AsyncWorkStart() { ...@@ -286,15 +286,18 @@ void CastChannelOpenFunction::AsyncWorkStart() {
auto* observer = auto* observer =
api_->GetObserver(extension_->id(), cast_socket_service_->GetLogger()); api_->GetObserver(extension_->id(), cast_socket_service_->GetLogger());
cast_socket_service_->OpenSocket( cast_channel::CastSocketOpenParams open_params(
*ip_endpoint_, ExtensionsBrowserClient::Get()->GetNetLog(), *ip_endpoint_, ExtensionsBrowserClient::Get()->GetNetLog(),
base::TimeDelta::FromMilliseconds(connect_info.timeout.get() base::TimeDelta::FromMilliseconds(connect_info.timeout.get()
? *connect_info.timeout ? *connect_info.timeout
: kDefaultConnectTimeoutMillis), : kDefaultConnectTimeoutMillis),
liveness_timeout_, ping_interval_, liveness_timeout_, ping_interval_,
connect_info.capabilities.get() ? *connect_info.capabilities connect_info.capabilities.get() ? *connect_info.capabilities
: CastDeviceCapability::NONE, : CastDeviceCapability::NONE);
base::Bind(&CastChannelOpenFunction::OnOpen, this), observer);
cast_socket_service_->OpenSocket(
open_params, base::Bind(&CastChannelOpenFunction::OnOpen, this),
observer);
} }
void CastChannelOpenFunction::OnOpen(int channel_id, ChannelError result) { void CastChannelOpenFunction::OnOpen(int channel_id, ChannelError result) {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment