Commit 5063b5f2 authored by Helen Li's avatar Helen Li Committed by Commit Bot

Add ReadIfReady to SOCKSClientSocket

This is to support mojo sockets. With mojo sockets, we always try to read from
mojo data pipes, and we need a way to cancel pending reads without having
buffered data in socket subclasses.

Bug: 875855
Change-Id: I1b34c15ff9b848517185caa222cce9109497b78e
Reviewed-on: https://chromium-review.googlesource.com/1183996
Commit-Queue: Helen Li <xunjieli@chromium.org>
Reviewed-by: default avatarMatt Menke <mmenke@chromium.org>
Cr-Commit-Position: refs/heads/master@{#585184}
parent 27d8383e
...@@ -1174,7 +1174,13 @@ int MockTCPClientSocket::ReadIfReadyImpl(IOBuffer* buf, ...@@ -1174,7 +1174,13 @@ int MockTCPClientSocket::ReadIfReadyImpl(IOBuffer* buf,
if (read_data_.mode == ASYNC) { if (read_data_.mode == ASYNC) {
DCHECK(!callback.is_null()); DCHECK(!callback.is_null());
read_data_.mode = SYNCHRONOUS; read_data_.mode = SYNCHRONOUS;
RunCallbackAsync(std::move(callback), result); pending_read_if_ready_callback_ = std::move(callback);
// base::Unretained() is safe here because RunCallbackAsync will wrap it
// with a callback associated with a weak ptr.
RunCallbackAsync(
base::BindOnce(&MockTCPClientSocket::RunReadIfReadyCallback,
base::Unretained(this)),
result);
return ERR_IO_PENDING; return ERR_IO_PENDING;
} }
...@@ -1195,6 +1201,13 @@ int MockTCPClientSocket::ReadIfReadyImpl(IOBuffer* buf, ...@@ -1195,6 +1201,13 @@ int MockTCPClientSocket::ReadIfReadyImpl(IOBuffer* buf,
return result; return result;
} }
void MockTCPClientSocket::RunReadIfReadyCallback(int result) {
// If ReadIfReady is already canceled, do nothing.
if (!pending_read_if_ready_callback_)
return;
std::move(pending_read_if_ready_callback_).Run(result);
}
MockProxyClientSocket::MockProxyClientSocket( MockProxyClientSocket::MockProxyClientSocket(
std::unique_ptr<ClientSocketHandle> transport_socket, std::unique_ptr<ClientSocketHandle> transport_socket,
HttpAuthController* auth_controller, HttpAuthController* auth_controller,
......
...@@ -700,6 +700,10 @@ class MockTCPClientSocket : public MockClientSocket, public AsyncSocket { ...@@ -700,6 +700,10 @@ class MockTCPClientSocket : public MockClientSocket, public AsyncSocket {
int ReadIfReadyImpl(IOBuffer* buf, int ReadIfReadyImpl(IOBuffer* buf,
int buf_len, int buf_len,
CompletionOnceCallback callback); CompletionOnceCallback callback);
// Helper method to run |pending_read_if_ready_callback_| if it is not null.
void RunReadIfReadyCallback(int result);
AddressList addresses_; AddressList addresses_;
SocketDataProvider* data_; SocketDataProvider* data_;
......
...@@ -185,6 +185,26 @@ int SOCKSClientSocket::Read(IOBuffer* buf, ...@@ -185,6 +185,26 @@ int SOCKSClientSocket::Read(IOBuffer* buf,
return rv; return rv;
} }
int SOCKSClientSocket::ReadIfReady(IOBuffer* buf,
int buf_len,
CompletionOnceCallback callback) {
DCHECK(completed_handshake_);
DCHECK_EQ(STATE_NONE, next_state_);
DCHECK(user_callback_.is_null());
DCHECK(!callback.is_null());
// Pass |callback| directly instead of wrapping it with OnReadWriteComplete.
// This is to avoid setting |was_ever_used_| unless data is actually read.
int rv = transport_->socket()->ReadIfReady(buf, buf_len, std::move(callback));
if (rv > 0)
was_ever_used_ = true;
return rv;
}
int SOCKSClientSocket::CancelReadIfReady() {
return transport_->socket()->CancelReadIfReady();
}
// Write is called by the transport layer. This can only be done if the // Write is called by the transport layer. This can only be done if the
// SOCKS handshake is complete. // SOCKS handshake is complete.
int SOCKSClientSocket::Write( int SOCKSClientSocket::Write(
......
...@@ -63,6 +63,10 @@ class NET_EXPORT_PRIVATE SOCKSClientSocket : public StreamSocket { ...@@ -63,6 +63,10 @@ class NET_EXPORT_PRIVATE SOCKSClientSocket : public StreamSocket {
int Read(IOBuffer* buf, int Read(IOBuffer* buf,
int buf_len, int buf_len,
CompletionOnceCallback callback) override; CompletionOnceCallback callback) override;
int ReadIfReady(IOBuffer* buf,
int buf_len,
CompletionOnceCallback callback) override;
int CancelReadIfReady() override;
int Write(IOBuffer* buf, int Write(IOBuffer* buf,
int buf_len, int buf_len,
CompletionOnceCallback callback, CompletionOnceCallback callback,
......
...@@ -82,22 +82,25 @@ std::unique_ptr<SOCKSClientSocket> SOCKSClientSocketTest::BuildMockSocket( ...@@ -82,22 +82,25 @@ std::unique_ptr<SOCKSClientSocket> SOCKSClientSocketTest::BuildMockSocket(
NetLog* net_log) { NetLog* net_log) {
TestCompletionCallback callback; TestCompletionCallback callback;
data_.reset(new StaticSocketDataProvider(reads, writes)); data_.reset(new StaticSocketDataProvider(reads, writes));
tcp_sock_ = new MockTCPClientSocket(address_list_, net_log, data_.get()); auto socket = std::make_unique<MockTCPClientSocket>(address_list_, net_log,
data_.get());
socket->set_enable_read_if_ready(true);
int rv = tcp_sock_->Connect(callback.callback()); int rv = socket->Connect(callback.callback());
EXPECT_THAT(rv, IsError(ERR_IO_PENDING)); EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
rv = callback.WaitForResult(); rv = callback.WaitForResult();
EXPECT_THAT(rv, IsOk()); EXPECT_THAT(rv, IsOk());
EXPECT_TRUE(tcp_sock_->IsConnected()); EXPECT_TRUE(socket->IsConnected());
std::unique_ptr<ClientSocketHandle> connection(new ClientSocketHandle); auto connection = std::make_unique<ClientSocketHandle>();
// |connection| takes ownership of |tcp_sock_|, but keep a // |connection| takes ownership of |socket|, but |tcp_socket_| keeps a
// non-owning pointer to it. // non-owning pointer to it.
connection->SetSocket(std::unique_ptr<StreamSocket>(tcp_sock_)); tcp_sock_ = socket.get();
return std::unique_ptr<SOCKSClientSocket>(new SOCKSClientSocket( connection->SetSocket(std::move(socket));
return std::make_unique<SOCKSClientSocket>(
std::move(connection), std::move(connection),
HostResolver::RequestInfo(HostPortPair(hostname, port)), DEFAULT_PRIORITY, HostResolver::RequestInfo(HostPortPair(hostname, port)), DEFAULT_PRIORITY,
host_resolver, TRAFFIC_ANNOTATION_FOR_TESTS)); host_resolver, TRAFFIC_ANNOTATION_FOR_TESTS);
} }
// Implementation of HostResolver that never completes its resolve request. // Implementation of HostResolver that never completes its resolve request.
...@@ -183,55 +186,104 @@ class HangingHostResolverWithCancel : public HostResolver { ...@@ -183,55 +186,104 @@ class HangingHostResolverWithCancel : public HostResolver {
// Tests a complete handshake and the disconnection. // Tests a complete handshake and the disconnection.
TEST_F(SOCKSClientSocketTest, CompleteHandshake) { TEST_F(SOCKSClientSocketTest, CompleteHandshake) {
const std::string payload_write = "random data"; // Run the test twice. Once with ReadIfReady() and once with Read().
const std::string payload_read = "moar random data"; for (bool use_read_if_ready : {true, false}) {
const std::string payload_write = "random data";
const std::string payload_read = "moar random data";
MockWrite data_writes[] = { MockWrite data_writes[] = {
MockWrite(ASYNC, kSOCKS4OkRequestLocalHostPort80, MockWrite(ASYNC, kSOCKS4OkRequestLocalHostPort80,
kSOCKS4OkRequestLocalHostPort80Length), kSOCKS4OkRequestLocalHostPort80Length),
MockWrite(ASYNC, payload_write.data(), payload_write.size())}; MockWrite(ASYNC, payload_write.data(), payload_write.size())};
MockRead data_reads[] = {
MockRead(ASYNC, kSOCKS4OkReply, kSOCKS4OkReplyLength),
MockRead(ASYNC, payload_read.data(), payload_read.size())};
TestNetLog log;
user_sock_ = BuildMockSocket(data_reads, data_writes, host_resolver_.get(),
"localhost", 80, &log);
// At this state the TCP connection is completed but not the SOCKS
// handshake.
EXPECT_TRUE(tcp_sock_->IsConnected());
EXPECT_FALSE(user_sock_->IsConnected());
int rv = user_sock_->Connect(callback_.callback());
EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
TestNetLogEntry::List entries;
log.GetEntries(&entries);
EXPECT_TRUE(
LogContainsBeginEvent(entries, 0, NetLogEventType::SOCKS_CONNECT));
EXPECT_FALSE(user_sock_->IsConnected());
rv = callback_.WaitForResult();
EXPECT_THAT(rv, IsOk());
EXPECT_TRUE(user_sock_->IsConnected());
log.GetEntries(&entries);
EXPECT_TRUE(
LogContainsEndEvent(entries, -1, NetLogEventType::SOCKS_CONNECT));
scoped_refptr<IOBuffer> buffer(new IOBuffer(payload_write.size()));
memcpy(buffer->data(), payload_write.data(), payload_write.size());
rv = user_sock_->Write(buffer.get(), payload_write.size(),
callback_.callback(), TRAFFIC_ANNOTATION_FOR_TESTS);
EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
rv = callback_.WaitForResult();
EXPECT_EQ(static_cast<int>(payload_write.size()), rv);
buffer = new IOBuffer(payload_read.size());
if (use_read_if_ready) {
rv = user_sock_->ReadIfReady(buffer.get(), payload_read.size(),
callback_.callback());
EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
rv = callback_.WaitForResult();
EXPECT_EQ(net::OK, rv);
rv = user_sock_->ReadIfReady(buffer.get(), payload_read.size(),
callback_.callback());
} else {
rv = user_sock_->Read(buffer.get(), payload_read.size(),
callback_.callback());
EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
rv = callback_.WaitForResult();
}
EXPECT_EQ(static_cast<int>(payload_read.size()), rv);
EXPECT_EQ(payload_read, std::string(buffer->data(), payload_read.size()));
user_sock_->Disconnect();
EXPECT_FALSE(tcp_sock_->IsConnected());
EXPECT_FALSE(user_sock_->IsConnected());
}
}
TEST_F(SOCKSClientSocketTest, CancelPendingReadIfReady) {
const std::string payload_read = "random data";
MockWrite data_writes[] = {MockWrite(ASYNC, kSOCKS4OkRequestLocalHostPort80,
kSOCKS4OkRequestLocalHostPort80Length)};
MockRead data_reads[] = { MockRead data_reads[] = {
MockRead(ASYNC, kSOCKS4OkReply, kSOCKS4OkReplyLength), MockRead(ASYNC, kSOCKS4OkReply, kSOCKS4OkReplyLength),
MockRead(ASYNC, payload_read.data(), payload_read.size())}; MockRead(ASYNC, payload_read.data(), payload_read.size())};
TestNetLog log;
user_sock_ = BuildMockSocket(data_reads, data_writes, host_resolver_.get(), user_sock_ = BuildMockSocket(data_reads, data_writes, host_resolver_.get(),
"localhost", 80, &log); "localhost", 80, nullptr);
// At this state the TCP connection is completed but not the SOCKS handshake. // At this state the TCP connection is completed but not the SOCKS
// handshake.
EXPECT_TRUE(tcp_sock_->IsConnected()); EXPECT_TRUE(tcp_sock_->IsConnected());
EXPECT_FALSE(user_sock_->IsConnected()); EXPECT_FALSE(user_sock_->IsConnected());
int rv = user_sock_->Connect(callback_.callback()); int rv = user_sock_->Connect(callback_.callback());
EXPECT_THAT(rv, IsError(ERR_IO_PENDING)); EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
TestNetLogEntry::List entries;
log.GetEntries(&entries);
EXPECT_TRUE(
LogContainsBeginEvent(entries, 0, NetLogEventType::SOCKS_CONNECT));
EXPECT_FALSE(user_sock_->IsConnected());
rv = callback_.WaitForResult(); rv = callback_.WaitForResult();
EXPECT_THAT(rv, IsOk()); EXPECT_THAT(rv, IsOk());
EXPECT_TRUE(user_sock_->IsConnected()); EXPECT_TRUE(user_sock_->IsConnected());
log.GetEntries(&entries);
EXPECT_TRUE(LogContainsEndEvent(entries, -1, NetLogEventType::SOCKS_CONNECT));
scoped_refptr<IOBuffer> buffer(new IOBuffer(payload_write.size()));
memcpy(buffer->data(), payload_write.data(), payload_write.size());
rv = user_sock_->Write(buffer.get(), payload_write.size(),
callback_.callback(), TRAFFIC_ANNOTATION_FOR_TESTS);
EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
rv = callback_.WaitForResult();
EXPECT_EQ(static_cast<int>(payload_write.size()), rv);
buffer = new IOBuffer(payload_read.size()); auto buffer = base::MakeRefCounted<IOBuffer>(payload_read.size());
rv = rv = user_sock_->ReadIfReady(buffer.get(), payload_read.size(),
user_sock_->Read(buffer.get(), payload_read.size(), callback_.callback()); callback_.callback());
EXPECT_THAT(rv, IsError(ERR_IO_PENDING)); EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
rv = callback_.WaitForResult(); rv = user_sock_->CancelReadIfReady();
EXPECT_EQ(static_cast<int>(payload_read.size()), rv); EXPECT_EQ(net::OK, rv);
EXPECT_EQ(payload_read, std::string(buffer->data(), payload_read.size()));
user_sock_->Disconnect(); user_sock_->Disconnect();
EXPECT_FALSE(tcp_sock_->IsConnected()); EXPECT_FALSE(tcp_sock_->IsConnected());
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment