Commit fb1021a7 authored by Ken MacKay's avatar Ken MacKay Committed by Commit Bot

[Chromecast] Make FakeStreamSocket more realistic

 * Don't allow pending reads to complete within a write call.
 * Signal EOS when the other end of the socket disconnects.

Merge-With: eureka-internal/381432
Change-Id: Ia091361fc0c72da52c76c9fcf98b7c24a04d9f13
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2129996Reviewed-by: default avatarYuchen Liu <yucliu@chromium.org>
Commit-Queue: Kenneth MacKay <kmackay@chromium.org>
Cr-Commit-Position: refs/heads/master@{#755031}
parent 615014e4
...@@ -8,8 +8,12 @@ ...@@ -8,8 +8,12 @@
#include <cstring> #include <cstring>
#include <vector> #include <vector>
#include "base/bind.h"
#include "base/callback_helpers.h" #include "base/callback_helpers.h"
#include "base/location.h"
#include "base/logging.h" #include "base/logging.h"
#include "base/memory/weak_ptr.h"
#include "base/threading/sequenced_task_runner_handle.h"
#include "net/base/io_buffer.h" #include "net/base/io_buffer.h"
#include "net/base/net_errors.h" #include "net/base/net_errors.h"
#include "net/socket/next_proto.h" #include "net/socket/next_proto.h"
...@@ -31,8 +35,11 @@ class SocketBuffer { ...@@ -31,8 +35,11 @@ class SocketBuffer {
int Read(char* data, size_t len, net::CompletionOnceCallback callback) { int Read(char* data, size_t len, net::CompletionOnceCallback callback) {
DCHECK(data); DCHECK(data);
DCHECK_GT(len, 0u); DCHECK_GT(len, 0u);
DCHECK(!callback.is_null()); DCHECK(callback);
if (data_.empty()) { if (data_.empty()) {
if (eos_) {
return 0;
}
pending_read_data_ = data; pending_read_data_ = data;
pending_read_len_ = len; pending_read_len_ = len;
pending_read_callback_ = std::move(callback); pending_read_callback_ = std::move(callback);
...@@ -51,7 +58,15 @@ class SocketBuffer { ...@@ -51,7 +58,15 @@ class SocketBuffer {
int result = ReadInternal(pending_read_data_, pending_read_len_); int result = ReadInternal(pending_read_data_, pending_read_len_);
pending_read_data_ = nullptr; pending_read_data_ = nullptr;
pending_read_len_ = 0; pending_read_len_ = 0;
std::move(pending_read_callback_).Run(result); PostReadCallback(std::move(pending_read_callback_), result);
}
}
// Called when the remote end of the fake connection disconnects.
void ReceiveEOS() {
eos_ = true;
if (pending_read_callback_ && data_.empty()) {
PostReadCallback(std::move(pending_read_callback_), 0);
} }
} }
...@@ -65,10 +80,25 @@ class SocketBuffer { ...@@ -65,10 +80,25 @@ class SocketBuffer {
return len; return len;
} }
void PostReadCallback(net::CompletionOnceCallback callback, int result) {
base::SequencedTaskRunnerHandle::Get()->PostTask(
FROM_HERE, base::BindOnce(&SocketBuffer::CallReadCallback,
weak_factory_.GetWeakPtr(),
std::move(callback), result));
}
// Need a member function to asynchronously call the read callback, so we
// can use weak ptr.
void CallReadCallback(net::CompletionOnceCallback callback, int result) {
std::move(callback).Run(result);
}
std::vector<char> data_; std::vector<char> data_;
char* pending_read_data_; char* pending_read_data_;
size_t pending_read_len_; size_t pending_read_len_;
net::CompletionOnceCallback pending_read_callback_; net::CompletionOnceCallback pending_read_callback_;
bool eos_ = false;
base::WeakPtrFactory<SocketBuffer> weak_factory_{this};
DISALLOW_COPY_AND_ASSIGN(SocketBuffer); DISALLOW_COPY_AND_ASSIGN(SocketBuffer);
}; };
...@@ -82,7 +112,7 @@ FakeStreamSocket::FakeStreamSocket(const net::IPEndPoint& local_address) ...@@ -82,7 +112,7 @@ FakeStreamSocket::FakeStreamSocket(const net::IPEndPoint& local_address)
FakeStreamSocket::~FakeStreamSocket() { FakeStreamSocket::~FakeStreamSocket() {
if (peer_) { if (peer_) {
peer_->peer_ = nullptr; peer_->RemoteDisconnected();
} }
} }
...@@ -91,6 +121,11 @@ void FakeStreamSocket::SetPeer(FakeStreamSocket* peer) { ...@@ -91,6 +121,11 @@ void FakeStreamSocket::SetPeer(FakeStreamSocket* peer) {
peer_ = peer; peer_ = peer;
} }
void FakeStreamSocket::RemoteDisconnected() {
peer_ = nullptr;
buffer_->ReceiveEOS();
}
void FakeStreamSocket::SetBadSenderMode(bool bad_sender) { void FakeStreamSocket::SetBadSenderMode(bool bad_sender) {
bad_sender_mode_ = bad_sender; bad_sender_mode_ = bad_sender;
} }
......
...@@ -61,6 +61,8 @@ class FakeStreamSocket : public net::StreamSocket { ...@@ -61,6 +61,8 @@ class FakeStreamSocket : public net::StreamSocket {
void ApplySocketTag(const net::SocketTag& tag) override; void ApplySocketTag(const net::SocketTag& tag) override;
private: private:
void RemoteDisconnected();
const net::IPEndPoint local_address_; const net::IPEndPoint local_address_;
const std::unique_ptr<SocketBuffer> buffer_; const std::unique_ptr<SocketBuffer> buffer_;
FakeStreamSocket* peer_; FakeStreamSocket* peer_;
......
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include "base/bind.h" #include "base/bind.h"
#include "base/logging.h" #include "base/logging.h"
#include "base/memory/ref_counted.h" #include "base/memory/ref_counted.h"
#include "base/test/task_environment.h"
#include "chromecast/net/fake_stream_socket.h" #include "chromecast/net/fake_stream_socket.h"
#include "net/base/io_buffer.h" #include "net/base/io_buffer.h"
#include "net/base/ip_endpoint.h" #include "net/base/ip_endpoint.h"
...@@ -45,6 +46,7 @@ class FakeStreamSocketTest : public ::testing::Test { ...@@ -45,6 +46,7 @@ class FakeStreamSocketTest : public ::testing::Test {
socket_2_(endpoint_2_) {} socket_2_(endpoint_2_) {}
~FakeStreamSocketTest() override {} ~FakeStreamSocketTest() override {}
base::test::TaskEnvironment task_environment_;
net::IPEndPoint endpoint_1_; net::IPEndPoint endpoint_1_;
FakeStreamSocket socket_1_; FakeStreamSocket socket_1_;
net::IPEndPoint endpoint_2_; net::IPEndPoint endpoint_2_;
...@@ -65,6 +67,7 @@ TEST_F(FakeStreamSocketTest, GetPeerAddressWithoutPeer) { ...@@ -65,6 +67,7 @@ TEST_F(FakeStreamSocketTest, GetPeerAddressWithoutPeer) {
TEST_F(FakeStreamSocketTest, GetPeerAddressWithPeer) { TEST_F(FakeStreamSocketTest, GetPeerAddressWithPeer) {
socket_1_.SetPeer(&socket_2_); socket_1_.SetPeer(&socket_2_);
socket_2_.SetPeer(&socket_1_);
net::IPEndPoint peer_address; net::IPEndPoint peer_address;
ASSERT_EQ(net::OK, socket_1_.GetPeerAddress(&peer_address)); ASSERT_EQ(net::OK, socket_1_.GetPeerAddress(&peer_address));
EXPECT_EQ(endpoint_2_, peer_address); EXPECT_EQ(endpoint_2_, peer_address);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment