Commit 5063b5f2 authored by Helen Li's avatar Helen Li Committed by Commit Bot

Add ReadIfReady to SOCKSClientSocket

This is to support mojo sockets. With mojo sockets, we always try to read from
mojo data pipes, and we need a way to cancel pending reads without having
buffered data in socket subclasses.

Bug: 875855
Change-Id: I1b34c15ff9b848517185caa222cce9109497b78e
Reviewed-on: https://chromium-review.googlesource.com/1183996
Commit-Queue: Helen Li <xunjieli@chromium.org>
Reviewed-by: default avatarMatt Menke <mmenke@chromium.org>
Cr-Commit-Position: refs/heads/master@{#585184}
parent 27d8383e
......@@ -1174,7 +1174,13 @@ int MockTCPClientSocket::ReadIfReadyImpl(IOBuffer* buf,
if (read_data_.mode == ASYNC) {
DCHECK(!callback.is_null());
read_data_.mode = SYNCHRONOUS;
RunCallbackAsync(std::move(callback), result);
pending_read_if_ready_callback_ = std::move(callback);
// base::Unretained() is safe here because RunCallbackAsync will wrap it
// with a callback associated with a weak ptr.
RunCallbackAsync(
base::BindOnce(&MockTCPClientSocket::RunReadIfReadyCallback,
base::Unretained(this)),
result);
return ERR_IO_PENDING;
}
......@@ -1195,6 +1201,13 @@ int MockTCPClientSocket::ReadIfReadyImpl(IOBuffer* buf,
return result;
}
void MockTCPClientSocket::RunReadIfReadyCallback(int result) {
// If ReadIfReady is already canceled, do nothing.
if (!pending_read_if_ready_callback_)
return;
std::move(pending_read_if_ready_callback_).Run(result);
}
MockProxyClientSocket::MockProxyClientSocket(
std::unique_ptr<ClientSocketHandle> transport_socket,
HttpAuthController* auth_controller,
......
......@@ -700,6 +700,10 @@ class MockTCPClientSocket : public MockClientSocket, public AsyncSocket {
int ReadIfReadyImpl(IOBuffer* buf,
int buf_len,
CompletionOnceCallback callback);
// Helper method to run |pending_read_if_ready_callback_| if it is not null.
void RunReadIfReadyCallback(int result);
AddressList addresses_;
SocketDataProvider* data_;
......
......@@ -185,6 +185,26 @@ int SOCKSClientSocket::Read(IOBuffer* buf,
return rv;
}
int SOCKSClientSocket::ReadIfReady(IOBuffer* buf,
int buf_len,
CompletionOnceCallback callback) {
DCHECK(completed_handshake_);
DCHECK_EQ(STATE_NONE, next_state_);
DCHECK(user_callback_.is_null());
DCHECK(!callback.is_null());
// Pass |callback| directly instead of wrapping it with OnReadWriteComplete.
// This is to avoid setting |was_ever_used_| unless data is actually read.
int rv = transport_->socket()->ReadIfReady(buf, buf_len, std::move(callback));
if (rv > 0)
was_ever_used_ = true;
return rv;
}
int SOCKSClientSocket::CancelReadIfReady() {
return transport_->socket()->CancelReadIfReady();
}
// Write is called by the transport layer. This can only be done if the
// SOCKS handshake is complete.
int SOCKSClientSocket::Write(
......
......@@ -63,6 +63,10 @@ class NET_EXPORT_PRIVATE SOCKSClientSocket : public StreamSocket {
int Read(IOBuffer* buf,
int buf_len,
CompletionOnceCallback callback) override;
int ReadIfReady(IOBuffer* buf,
int buf_len,
CompletionOnceCallback callback) override;
int CancelReadIfReady() override;
int Write(IOBuffer* buf,
int buf_len,
CompletionOnceCallback callback,
......
......@@ -82,22 +82,25 @@ std::unique_ptr<SOCKSClientSocket> SOCKSClientSocketTest::BuildMockSocket(
NetLog* net_log) {
TestCompletionCallback callback;
data_.reset(new StaticSocketDataProvider(reads, writes));
tcp_sock_ = new MockTCPClientSocket(address_list_, net_log, data_.get());
auto socket = std::make_unique<MockTCPClientSocket>(address_list_, net_log,
data_.get());
socket->set_enable_read_if_ready(true);
int rv = tcp_sock_->Connect(callback.callback());
int rv = socket->Connect(callback.callback());
EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
rv = callback.WaitForResult();
EXPECT_THAT(rv, IsOk());
EXPECT_TRUE(tcp_sock_->IsConnected());
EXPECT_TRUE(socket->IsConnected());
std::unique_ptr<ClientSocketHandle> connection(new ClientSocketHandle);
// |connection| takes ownership of |tcp_sock_|, but keep a
auto connection = std::make_unique<ClientSocketHandle>();
// |connection| takes ownership of |socket|, but |tcp_socket_| keeps a
// non-owning pointer to it.
connection->SetSocket(std::unique_ptr<StreamSocket>(tcp_sock_));
return std::unique_ptr<SOCKSClientSocket>(new SOCKSClientSocket(
tcp_sock_ = socket.get();
connection->SetSocket(std::move(socket));
return std::make_unique<SOCKSClientSocket>(
std::move(connection),
HostResolver::RequestInfo(HostPortPair(hostname, port)), DEFAULT_PRIORITY,
host_resolver, TRAFFIC_ANNOTATION_FOR_TESTS));
host_resolver, TRAFFIC_ANNOTATION_FOR_TESTS);
}
// Implementation of HostResolver that never completes its resolve request.
......@@ -183,55 +186,104 @@ class HangingHostResolverWithCancel : public HostResolver {
// Tests a complete handshake and the disconnection.
TEST_F(SOCKSClientSocketTest, CompleteHandshake) {
const std::string payload_write = "random data";
const std::string payload_read = "moar random data";
// Run the test twice. Once with ReadIfReady() and once with Read().
for (bool use_read_if_ready : {true, false}) {
const std::string payload_write = "random data";
const std::string payload_read = "moar random data";
MockWrite data_writes[] = {
MockWrite(ASYNC, kSOCKS4OkRequestLocalHostPort80,
kSOCKS4OkRequestLocalHostPort80Length),
MockWrite(ASYNC, payload_write.data(), payload_write.size())};
MockWrite data_writes[] = {
MockWrite(ASYNC, kSOCKS4OkRequestLocalHostPort80,
kSOCKS4OkRequestLocalHostPort80Length),
MockWrite(ASYNC, payload_write.data(), payload_write.size())};
MockRead data_reads[] = {
MockRead(ASYNC, kSOCKS4OkReply, kSOCKS4OkReplyLength),
MockRead(ASYNC, payload_read.data(), payload_read.size())};
TestNetLog log;
user_sock_ = BuildMockSocket(data_reads, data_writes, host_resolver_.get(),
"localhost", 80, &log);
// At this state the TCP connection is completed but not the SOCKS
// handshake.
EXPECT_TRUE(tcp_sock_->IsConnected());
EXPECT_FALSE(user_sock_->IsConnected());
int rv = user_sock_->Connect(callback_.callback());
EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
TestNetLogEntry::List entries;
log.GetEntries(&entries);
EXPECT_TRUE(
LogContainsBeginEvent(entries, 0, NetLogEventType::SOCKS_CONNECT));
EXPECT_FALSE(user_sock_->IsConnected());
rv = callback_.WaitForResult();
EXPECT_THAT(rv, IsOk());
EXPECT_TRUE(user_sock_->IsConnected());
log.GetEntries(&entries);
EXPECT_TRUE(
LogContainsEndEvent(entries, -1, NetLogEventType::SOCKS_CONNECT));
scoped_refptr<IOBuffer> buffer(new IOBuffer(payload_write.size()));
memcpy(buffer->data(), payload_write.data(), payload_write.size());
rv = user_sock_->Write(buffer.get(), payload_write.size(),
callback_.callback(), TRAFFIC_ANNOTATION_FOR_TESTS);
EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
rv = callback_.WaitForResult();
EXPECT_EQ(static_cast<int>(payload_write.size()), rv);
buffer = new IOBuffer(payload_read.size());
if (use_read_if_ready) {
rv = user_sock_->ReadIfReady(buffer.get(), payload_read.size(),
callback_.callback());
EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
rv = callback_.WaitForResult();
EXPECT_EQ(net::OK, rv);
rv = user_sock_->ReadIfReady(buffer.get(), payload_read.size(),
callback_.callback());
} else {
rv = user_sock_->Read(buffer.get(), payload_read.size(),
callback_.callback());
EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
rv = callback_.WaitForResult();
}
EXPECT_EQ(static_cast<int>(payload_read.size()), rv);
EXPECT_EQ(payload_read, std::string(buffer->data(), payload_read.size()));
user_sock_->Disconnect();
EXPECT_FALSE(tcp_sock_->IsConnected());
EXPECT_FALSE(user_sock_->IsConnected());
}
}
TEST_F(SOCKSClientSocketTest, CancelPendingReadIfReady) {
const std::string payload_read = "random data";
MockWrite data_writes[] = {MockWrite(ASYNC, kSOCKS4OkRequestLocalHostPort80,
kSOCKS4OkRequestLocalHostPort80Length)};
MockRead data_reads[] = {
MockRead(ASYNC, kSOCKS4OkReply, kSOCKS4OkReplyLength),
MockRead(ASYNC, payload_read.data(), payload_read.size())};
TestNetLog log;
user_sock_ = BuildMockSocket(data_reads, data_writes, host_resolver_.get(),
"localhost", 80, &log);
"localhost", 80, nullptr);
// At this state the TCP connection is completed but not the SOCKS handshake.
// At this state the TCP connection is completed but not the SOCKS
// handshake.
EXPECT_TRUE(tcp_sock_->IsConnected());
EXPECT_FALSE(user_sock_->IsConnected());
int rv = user_sock_->Connect(callback_.callback());
EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
TestNetLogEntry::List entries;
log.GetEntries(&entries);
EXPECT_TRUE(
LogContainsBeginEvent(entries, 0, NetLogEventType::SOCKS_CONNECT));
EXPECT_FALSE(user_sock_->IsConnected());
rv = callback_.WaitForResult();
EXPECT_THAT(rv, IsOk());
EXPECT_TRUE(user_sock_->IsConnected());
log.GetEntries(&entries);
EXPECT_TRUE(LogContainsEndEvent(entries, -1, NetLogEventType::SOCKS_CONNECT));
scoped_refptr<IOBuffer> buffer(new IOBuffer(payload_write.size()));
memcpy(buffer->data(), payload_write.data(), payload_write.size());
rv = user_sock_->Write(buffer.get(), payload_write.size(),
callback_.callback(), TRAFFIC_ANNOTATION_FOR_TESTS);
EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
rv = callback_.WaitForResult();
EXPECT_EQ(static_cast<int>(payload_write.size()), rv);
buffer = new IOBuffer(payload_read.size());
rv =
user_sock_->Read(buffer.get(), payload_read.size(), callback_.callback());
auto buffer = base::MakeRefCounted<IOBuffer>(payload_read.size());
rv = user_sock_->ReadIfReady(buffer.get(), payload_read.size(),
callback_.callback());
EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
rv = callback_.WaitForResult();
EXPECT_EQ(static_cast<int>(payload_read.size()), rv);
EXPECT_EQ(payload_read, std::string(buffer->data(), payload_read.size()));
rv = user_sock_->CancelReadIfReady();
EXPECT_EQ(net::OK, rv);
user_sock_->Disconnect();
EXPECT_FALSE(tcp_sock_->IsConnected());
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment