Commit af79a49b authored by mfoltz@chromium.org's avatar mfoltz@chromium.org

Remove weak pointers from CastSocket by explicitly tracking and resetting...

Remove weak pointers from CastSocket by explicitly tracking and resetting callbacks created inside the class.

Ensure all sockets are closed and callbacks reset in all relevant code paths: Close(), CloseWithError(), and the dtor.

Review URL: https://codereview.chromium.org/417403002

git-svn-id: svn://svn.chromium.org/chrome/trunk/src@287477 0039d316-1c4b-4281-b951-d872f2087c98
parent 7f701013
......@@ -65,7 +65,7 @@ void FillChannelInfo(const CastSocket& socket, ChannelInfo* channel_info) {
bool IsValidConnectInfoPort(const ConnectInfo& connect_info) {
return connect_info.port > 0 && connect_info.port <
std::numeric_limits<unsigned short>::max();
std::numeric_limits<uint16_t>::max();
}
bool IsValidConnectInfoAuth(const ConnectInfo& connect_info) {
......@@ -162,7 +162,8 @@ CastSocket* CastChannelAsyncApiFunction::GetSocketOrCompleteWithError(
int channel_id) {
CastSocket* socket = GetSocket(channel_id);
if (!socket) {
SetResultFromError(cast_channel::CHANNEL_ERROR_INVALID_CHANNEL_ID);
SetResultFromError(channel_id,
cast_channel::CHANNEL_ERROR_INVALID_CHANNEL_ID);
AsyncWorkCompleted();
}
return socket;
......@@ -183,21 +184,24 @@ void CastChannelAsyncApiFunction::RemoveSocket(int channel_id) {
manager_->Remove(extension_->id(), channel_id);
}
void CastChannelAsyncApiFunction::SetResultFromSocket(int channel_id) {
CastSocket* socket = GetSocket(channel_id);
DCHECK(socket);
void CastChannelAsyncApiFunction::SetResultFromSocket(
const CastSocket& socket) {
ChannelInfo channel_info;
FillChannelInfo(*socket, &channel_info);
error_ = socket->error_state();
FillChannelInfo(socket, &channel_info);
error_ = socket.error_state();
SetResultFromChannelInfo(channel_info);
}
void CastChannelAsyncApiFunction::SetResultFromError(ChannelError error) {
void CastChannelAsyncApiFunction::SetResultFromError(int channel_id,
ChannelError error) {
ChannelInfo channel_info;
channel_info.channel_id = -1;
channel_info.channel_id = channel_id;
channel_info.url = "";
channel_info.ready_state = cast_channel::READY_STATE_CLOSED;
channel_info.error_state = error;
channel_info.connect_info.ip_address = "";
channel_info.connect_info.port = 0;
channel_info.connect_info.auth = cast_channel::CHANNEL_AUTH_TYPE_SSL;
SetResultFromChannelInfo(channel_info);
error_ = error;
}
......@@ -338,7 +342,13 @@ void CastChannelOpenFunction::AsyncWorkStart() {
void CastChannelOpenFunction::OnOpen(int result) {
DCHECK_CURRENTLY_ON(BrowserThread::IO);
VLOG(1) << "Connect finished, OnOpen invoked.";
SetResultFromSocket(new_channel_id_);
CastSocket* socket = GetSocket(new_channel_id_);
if (!socket) {
SetResultFromError(new_channel_id_,
cast_channel::CHANNEL_ERROR_CONNECT_ERROR);
} else {
SetResultFromSocket(*socket);
}
AsyncWorkCompleted();
}
......@@ -382,10 +392,13 @@ void CastChannelSendFunction::AsyncWorkStart() {
void CastChannelSendFunction::OnSend(int result) {
DCHECK_CURRENTLY_ON(BrowserThread::IO);
if (result < 0) {
SetResultFromError(cast_channel::CHANNEL_ERROR_SOCKET_ERROR);
int channel_id = params_->channel.channel_id;
CastSocket* socket = GetSocket(channel_id);
if (result < 0 || !socket) {
SetResultFromError(channel_id,
cast_channel::CHANNEL_ERROR_SOCKET_ERROR);
} else {
SetResultFromSocket(params_->channel.channel_id);
SetResultFromSocket(*socket);
}
AsyncWorkCompleted();
}
......@@ -410,12 +423,16 @@ void CastChannelCloseFunction::AsyncWorkStart() {
void CastChannelCloseFunction::OnClose(int result) {
DCHECK_CURRENTLY_ON(BrowserThread::IO);
VLOG(1) << "CastChannelCloseFunction::OnClose result = " << result;
if (result < 0) {
SetResultFromError(cast_channel::CHANNEL_ERROR_SOCKET_ERROR);
int channel_id = params_->channel.channel_id;
CastSocket* socket = GetSocket(channel_id);
if (result < 0 || !socket) {
SetResultFromError(channel_id,
cast_channel::CHANNEL_ERROR_SOCKET_ERROR);
} else {
int channel_id = params_->channel.channel_id;
SetResultFromSocket(channel_id);
SetResultFromSocket(*socket);
// This will delete |socket|.
RemoveSocket(channel_id);
socket = NULL;
}
AsyncWorkCompleted();
}
......
......@@ -98,12 +98,13 @@ class CastChannelAsyncApiFunction : public AsyncApiFunction {
// manager.
void RemoveSocket(int channel_id);
// Sets the function result to a ChannelInfo obtained from the state of the
// CastSocket corresponding to |channel_id|.
void SetResultFromSocket(int channel_id);
// Sets the function result to a ChannelInfo obtained from the state of
// |socket|.
void SetResultFromSocket(const cast_channel::CastSocket& socket);
// Sets the function result to a ChannelInfo with |error|.
void SetResultFromError(cast_channel::ChannelError error);
// Sets the function result to a ChannelInfo populated with |channel_id| and
// |error|.
void SetResultFromError(int channel_id, cast_channel::ChannelError error);
// Returns the socket corresponding to |channel_id| if one exists, or null
// otherwise.
......
......@@ -20,6 +20,9 @@
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gmock_mutant.h"
// TODO(mfoltz): Mock out the ApiResourceManager to resolve threading issues
// (crbug.com/398242) and simulate unloading of the extension.
namespace cast_channel = extensions::core_api::cast_channel;
using cast_channel::CastSocket;
using cast_channel::ChannelError;
......@@ -69,11 +72,6 @@ class MockCastSocket : public CastSocket {
base::TimeDelta::FromMilliseconds(kTimeoutMs)) {}
virtual ~MockCastSocket() {}
virtual bool CalledOnValidThread() const OVERRIDE {
// Always return true in testing.
return true;
}
MOCK_METHOD1(Connect, void(const net::CompletionCallback& callback));
MOCK_METHOD2(SendMessage, void(const MessageInfo& message,
const net::CompletionCallback& callback));
......
......@@ -99,7 +99,11 @@ CastSocket::CastSocket(const std::string& owner_extension_id,
current_read_buffer_ = header_read_buffer_;
}
CastSocket::~CastSocket() { }
CastSocket::~CastSocket() {
// Ensure that resources are freed but do not run pending callbacks to avoid
// any re-entrancy.
CloseInternal();
}
ReadyState CastSocket::ready_state() const {
return ready_state_;
......@@ -176,19 +180,24 @@ void CastSocket::Connect(const net::CompletionCallback& callback) {
connect_callback_ = callback;
connect_state_ = CONN_STATE_TCP_CONNECT;
if (connect_timeout_.InMicroseconds() > 0) {
GetTimer()->Start(
FROM_HERE,
connect_timeout_,
base::Bind(&CastSocket::CancelConnect, AsWeakPtr()));
DCHECK(connect_timeout_callback_.IsCancelled());
connect_timeout_callback_.Reset(base::Bind(&CastSocket::CancelConnect,
base::Unretained(this)));
GetTimer()->Start(FROM_HERE,
connect_timeout_,
connect_timeout_callback_.callback());
}
DoConnectLoop(net::OK);
}
void CastSocket::PostTaskToStartConnectLoop(int result) {
DCHECK(CalledOnValidThread());
base::MessageLoop::current()->PostTask(
FROM_HERE,
base::Bind(&CastSocket::DoConnectLoop, AsWeakPtr(), result));
DCHECK(connect_loop_callback_.IsCancelled());
connect_loop_callback_.Reset(base::Bind(&CastSocket::DoConnectLoop,
base::Unretained(this),
result));
base::MessageLoop::current()->PostTask(FROM_HERE,
connect_loop_callback_.callback());
}
void CastSocket::CancelConnect() {
......@@ -204,6 +213,7 @@ void CastSocket::CancelConnect() {
// 1. Connect method: this starts the flow
// 2. Callback from network operations that finish asynchronously
void CastSocket::DoConnectLoop(int result) {
connect_loop_callback_.Cancel();
if (is_canceled_) {
LOG(ERROR) << "CANCELLED - Aborting DoConnectLoop.";
return;
......@@ -258,11 +268,12 @@ void CastSocket::DoConnectLoop(int result) {
}
int CastSocket::DoTcpConnect() {
DCHECK(connect_loop_callback_.IsCancelled());
VLOG_WITH_CONNECTION(1) << "DoTcpConnect";
connect_state_ = CONN_STATE_TCP_CONNECT_COMPLETE;
tcp_socket_ = CreateTcpSocket();
return tcp_socket_->Connect(
base::Bind(&CastSocket::DoConnectLoop, AsWeakPtr()));
base::Bind(&CastSocket::DoConnectLoop, base::Unretained(this)));
}
int CastSocket::DoTcpConnectComplete(int result) {
......@@ -277,11 +288,12 @@ int CastSocket::DoTcpConnectComplete(int result) {
}
int CastSocket::DoSslConnect() {
DCHECK(connect_loop_callback_.IsCancelled());
VLOG_WITH_CONNECTION(1) << "DoSslConnect";
connect_state_ = CONN_STATE_SSL_CONNECT_COMPLETE;
socket_ = CreateSslSocket(tcp_socket_.PassAs<net::StreamSocket>());
return socket_->Connect(
base::Bind(&CastSocket::DoConnectLoop, AsWeakPtr()));
base::Bind(&CastSocket::DoConnectLoop, base::Unretained(this)));
}
int CastSocket::DoSslConnectComplete(int result) {
......@@ -306,16 +318,28 @@ int CastSocket::DoAuthChallengeSend() {
// Post a task to send auth challenge so that DoWriteLoop is not nested inside
// DoConnectLoop. This is not strictly necessary but keeps the write loop
// code decoupled from connect loop code.
base::MessageLoop::current()->PostTask(
FROM_HERE,
DCHECK(send_auth_challenge_callback_.IsCancelled());
send_auth_challenge_callback_.Reset(
base::Bind(&CastSocket::SendCastMessageInternal,
AsWeakPtr(),
base::Unretained(this),
challenge_message,
base::Bind(&CastSocket::DoConnectLoop, AsWeakPtr())));
base::Bind(&CastSocket::DoAuthChallengeSendWriteComplete,
base::Unretained(this))));
base::MessageLoop::current()->PostTask(
FROM_HERE,
send_auth_challenge_callback_.callback());
// Always return IO_PENDING since the result is always asynchronous.
return net::ERR_IO_PENDING;
}
void CastSocket::DoAuthChallengeSendWriteComplete(int result) {
send_auth_challenge_callback_.Cancel();
VLOG_WITH_CONNECTION(2) << "DoAuthChallengeSendWriteComplete: " << result;
DCHECK_GT(result, 0);
DCHECK_EQ(write_queue_.size(), 1UL);
PostTaskToStartConnectLoop(result);
}
int CastSocket::DoAuthChallengeSendComplete(int result) {
VLOG_WITH_CONNECTION(1) << "DoAuthChallengeSendComplete: " << result;
if (result < 0)
......@@ -354,15 +378,46 @@ void CastSocket::DoConnectCallback(int result) {
}
void CastSocket::Close(const net::CompletionCallback& callback) {
DCHECK(CalledOnValidThread());
CloseInternal();
RunPendingCallbacksOnClose();
// Run this callback last. It may delete the socket.
callback.Run(net::OK);
}
void CastSocket::CloseInternal() {
// TODO(mfoltz): Enforce this when CastChannelAPITest is rewritten to create
// and free sockets on the same thread. crbug.com/398242
// DCHECK(CalledOnValidThread());
if (ready_state_ == READY_STATE_CLOSED) {
return;
}
VLOG_WITH_CONNECTION(1) << "Close ReadyState = " << ready_state_;
tcp_socket_.reset();
socket_.reset();
cert_verifier_.reset();
transport_security_state_.reset();
GetTimer()->Stop();
// Cancel callbacks that we queued ourselves to re-enter the connect or read
// loops.
connect_loop_callback_.Cancel();
send_auth_challenge_callback_.Cancel();
read_loop_callback_.Cancel();
connect_timeout_callback_.Cancel();
ready_state_ = READY_STATE_CLOSED;
callback.Run(net::OK);
// |callback| can delete |this|
}
void CastSocket::RunPendingCallbacksOnClose() {
DCHECK_EQ(ready_state_, READY_STATE_CLOSED);
if (!connect_callback_.is_null()) {
connect_callback_.Run(net::ERR_CONNECTION_FAILED);
connect_callback_.Reset();
}
for (; !write_queue_.empty(); write_queue_.pop()) {
net::CompletionCallback& callback = write_queue_.front().callback;
callback.Run(net::ERR_FAILED);
callback.Reset();
}
}
void CastSocket::SendMessage(const MessageInfo& message,
......@@ -377,7 +432,6 @@ void CastSocket::SendMessage(const MessageInfo& message,
callback.Run(net::ERR_FAILED);
return;
}
SendCastMessageInternal(message_proto, callback);
}
......@@ -454,11 +508,10 @@ int CastSocket::DoWrite() {
<< request.io_buffer->BytesConsumed();
write_state_ = WRITE_STATE_WRITE_COMPLETE;
return socket_->Write(
request.io_buffer.get(),
request.io_buffer->BytesRemaining(),
base::Bind(&CastSocket::DoWriteLoop, AsWeakPtr()));
base::Bind(&CastSocket::DoWriteLoop, base::Unretained(this)));
}
int CastSocket::DoWriteComplete(int result) {
......@@ -483,21 +536,11 @@ int CastSocket::DoWriteComplete(int result) {
int CastSocket::DoWriteCallback() {
DCHECK(!write_queue_.empty());
write_state_ = WRITE_STATE_WRITE;
WriteRequest& request = write_queue_.front();
int bytes_consumed = request.io_buffer->BytesConsumed();
// If inside connection flow, then there should be exaclty one item in
// the write queue.
if (ready_state_ == READY_STATE_CONNECTING) {
write_queue_.pop();
DCHECK(write_queue_.empty());
PostTaskToStartConnectLoop(bytes_consumed);
} else {
WriteRequest& request = write_queue_.front();
request.callback.Run(bytes_consumed);
write_queue_.pop();
}
write_state_ = WRITE_STATE_WRITE;
request.callback.Run(bytes_consumed);
write_queue_.pop();
return net::OK;
}
......@@ -526,12 +569,15 @@ int CastSocket::DoWriteError(int result) {
void CastSocket::PostTaskToStartReadLoop() {
DCHECK(CalledOnValidThread());
base::MessageLoop::current()->PostTask(
FROM_HERE,
base::Bind(&CastSocket::StartReadLoop, AsWeakPtr()));
DCHECK(read_loop_callback_.IsCancelled());
read_loop_callback_.Reset(base::Bind(&CastSocket::StartReadLoop,
base::Unretained(this)));
base::MessageLoop::current()->PostTask(FROM_HERE,
read_loop_callback_.callback());
}
void CastSocket::StartReadLoop() {
read_loop_callback_.Cancel();
// Read loop would have already been started if read state is not NONE
if (read_state_ == READ_STATE_NONE) {
read_state_ = READ_STATE_READ;
......@@ -603,7 +649,7 @@ int CastSocket::DoRead() {
return socket_->Read(
current_read_buffer_.get(),
num_bytes_to_read,
base::Bind(&CastSocket::DoReadLoop, AsWeakPtr()));
base::Bind(&CastSocket::DoReadLoop, base::Unretained(this)));
}
int CastSocket::DoReadComplete(int result) {
......@@ -723,9 +769,9 @@ bool CastSocket::Serialize(const CastMessage& message_proto,
void CastSocket::CloseWithError(ChannelError error) {
DCHECK(CalledOnValidThread());
socket_.reset(NULL);
ready_state_ = READY_STATE_CLOSED;
CloseInternal();
error_state_ = error;
RunPendingCallbacksOnClose();
if (delegate_)
delegate_->OnError(this, error);
}
......@@ -756,7 +802,7 @@ void CastSocket::MessageHeader::SetMessageSize(size_t size) {
void CastSocket::MessageHeader::PrependToString(std::string* str) {
MessageHeader output = *this;
output.message_size = base::HostToNet32(message_size);
size_t header_size = base::checked_cast<size_t,uint32>(
size_t header_size = base::checked_cast<size_t, uint32>(
MessageHeader::header_size());
scoped_ptr<char, base::FreeDeleter> char_array(
static_cast<char*>(malloc(header_size)));
......@@ -769,7 +815,7 @@ void CastSocket::MessageHeader::PrependToString(std::string* str) {
void CastSocket::MessageHeader::ReadFromIOBuffer(
net::GrowableIOBuffer* buffer, MessageHeader* header) {
uint32 message_size;
size_t header_size = base::checked_cast<size_t,uint32>(
size_t header_size = base::checked_cast<size_t, uint32>(
MessageHeader::header_size());
memcpy(&message_size, buffer->StartOfBuffer(), header_size);
header->message_size = base::NetToHost32(message_size);
......
......@@ -9,11 +9,9 @@
#include <string>
#include "base/basictypes.h"
#include "base/callback.h"
#include "base/cancelable_callback.h"
#include "base/gtest_prod_util.h"
#include "base/memory/ref_counted.h"
#include "base/memory/weak_ptr.h"
#include "base/threading/thread_checker.h"
#include "base/timer/timer.h"
#include "extensions/browser/api/api_resource.h"
......@@ -45,17 +43,16 @@ class CastMessage;
//
// NOTE: Not called "CastChannel" to reduce confusion with the generated API
// code.
class CastSocket : public ApiResource,
public base::SupportsWeakPtr<CastSocket> {
class CastSocket : public ApiResource {
public:
// Object to be informed of incoming messages and errors.
// Object to be informed of incoming messages and errors. The CastSocket that
// owns the delegate must not be deleted by it, only by the ApiResourceManager
// or in the callback to Close().
class Delegate {
public:
// An error occurred on the channel.
// It is fine to delete the socket in this callback.
virtual void OnError(const CastSocket* socket, ChannelError error) = 0;
// A message was received on the channel.
// Do NOT delete the socket in this callback.
virtual void OnMessage(const CastSocket* socket,
const MessageInfo& message) = 0;
......@@ -72,6 +69,8 @@ class CastSocket : public ApiResource,
CastSocket::Delegate* delegate,
net::NetLog* net_log,
const base::TimeDelta& connect_timeout);
// Ensures that the socket is closed.
virtual ~CastSocket();
// The IP endpoint for the destination of the channel.
......@@ -98,8 +97,8 @@ class CastSocket : public ApiResource,
virtual ChannelError error_state() const;
// Connects the channel to the peer. If successful, the channel will be in
// READY_STATE_OPEN.
// It is fine to delete the CastSocket object in |callback|.
// READY_STATE_OPEN. DO NOT delete the CastSocket object in |callback|.
// Instead use Close().
virtual void Connect(const net::CompletionCallback& callback);
// Sends a message over a connected channel. The channel must be in
......@@ -108,15 +107,15 @@ class CastSocket : public ApiResource,
// Note that if an error occurs the following happens:
// 1. Completion callbacks for all pending writes are invoked with error.
// 2. Delegate::OnError is called once.
// 3. Castsocket is closed.
// 3. CastSocket is closed.
//
// DO NOT delete the CastSocket object in write completion callback.
// But it is fine to delete the socket in Delegate::OnError
// DO NOT delete the CastSocket object in |callback|. Instead use Close().
virtual void SendMessage(const MessageInfo& message,
const net::CompletionCallback& callback);
// Closes the channel. On completion, the channel will be in
// READY_STATE_CLOSED.
// Closes the channel if not already closed. On completion, the channel will
// be in READY_STATE_CLOSED.
//
// It is fine to delete the CastSocket object in |callback|.
virtual void Close(const net::CompletionCallback& callback);
......@@ -221,6 +220,7 @@ class CastSocket : public ApiResource,
int DoSslConnectComplete(int result);
int DoAuthChallengeSend();
int DoAuthChallengeSendComplete(int result);
void DoAuthChallengeSendWriteComplete(int result);
int DoAuthChallengeReplyComplete(int result);
/////////////////////////////////////////////////////////////////////////////
......@@ -266,9 +266,17 @@ class CastSocket : public ApiResource,
// Parses the contents of body_read_buffer_ and sets current_message_ to
// the message received.
bool ProcessBody();
// Closes socket, updating the error state and signaling the delegate that
// |error| has occurred.
// Closes the socket, sets |error_state_| to |error| and signals the
// delegate that |error| has occurred.
void CloseWithError(ChannelError error);
// Frees resources and cancels pending callbacks. |ready_state_| will be set
// READY_STATE_CLOSED on completion. A no-op if |ready_state_| is already
// READY_STATE_CLOSED.
void CloseInternal();
// Runs pending callbacks that are passed into us to notify API clients that
// pending operations will fail because the socket has been closed.
void RunPendingCallbacksOnClose();
// Serializes the content of message_proto (with a header) to |message_data|.
static bool Serialize(const CastMessage& message_proto,
std::string* message_data);
......@@ -324,6 +332,8 @@ class CastSocket : public ApiResource,
// Callback invoked when the socket is connected or fails to connect.
net::CompletionCallback connect_callback_;
// Callback invoked by |connect_timeout_timer_| to cancel the connection.
base::CancelableClosure connect_timeout_callback_;
// Duration to wait before timing out.
base::TimeDelta connect_timeout_;
// Timer invoked when the connection has timed out.
......@@ -343,6 +353,16 @@ class CastSocket : public ApiResource,
// The current status of the channel.
ReadyState ready_state_;
// Task invoked to (re)start the connect loop. Canceled on entry to the
// connect loop.
base::CancelableClosure connect_loop_callback_;
// Task invoked to send the auth challenge. Canceled when the auth challenge
// has been sent.
base::CancelableClosure send_auth_challenge_callback_;
// Callback invoked to (re)start the read loop. Canceled on entry to the read
// loop.
base::CancelableClosure read_loop_callback_;
// Holds a message to be written to the socket. |callback| is invoked when the
// message is fully written or an error occurrs.
struct WriteRequest {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment