Commit 60fedc79 authored by Junbo Ke's avatar Junbo Ke Committed by Commit Bot

Propagate the peer address to network::TCPServerSocket in the callback.

StreamSocket::GetPeerAddress returns ERR_SOCKET_NOT_CONNECTED when the
client socket has already disconnected before OnAcceptCompleted, which
causes network::server::HttpServer to exit the accept loop because rv
is not net::OK.

Bug: b/149013559
Test: net_unittests, manual test on device
Change-Id: I86173aa204a249768efe2d69cf3a6dc8287a96d9
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2057964
Commit-Queue: Junbo Ke <juke@chromium.org>
Reviewed-by: default avatarMatt Menke <mmenke@chromium.org>
Cr-Commit-Position: refs/heads/master@{#742423}
parent 68e89bd7
......@@ -25,4 +25,13 @@ int ServerSocket::ListenWithAddressAndPort(const std::string& address_string,
return Listen(IPEndPoint(ip_address, port), backlog);
}
int ServerSocket::Accept(std::unique_ptr<StreamSocket>* socket,
net::CompletionOnceCallback callback,
net::IPEndPoint* peer_address) {
if (peer_address) {
*peer_address = IPEndPoint();
}
return Accept(socket, std::move(callback));
}
} // namespace net
......@@ -42,6 +42,13 @@ class NET_EXPORT ServerSocket {
virtual int Accept(std::unique_ptr<StreamSocket>* socket,
CompletionOnceCallback callback) = 0;
// Accepts connection. Callback is called when new connection is accepted.
// Note: |peer_address| may or may not be populated depending on the
// implementation.
virtual int Accept(std::unique_ptr<StreamSocket>* socket,
CompletionOnceCallback callback,
IPEndPoint* peer_address);
private:
DISALLOW_COPY_AND_ASSIGN(ServerSocket);
};
......
......@@ -62,6 +62,12 @@ int TCPServerSocket::GetLocalAddress(IPEndPoint* address) const {
int TCPServerSocket::Accept(std::unique_ptr<StreamSocket>* socket,
CompletionOnceCallback callback) {
return Accept(socket, std::move(callback), nullptr);
}
int TCPServerSocket::Accept(std::unique_ptr<StreamSocket>* socket,
CompletionOnceCallback callback,
IPEndPoint* peer_address) {
DCHECK(socket);
DCHECK(!callback.is_null());
......@@ -72,16 +78,16 @@ int TCPServerSocket::Accept(std::unique_ptr<StreamSocket>* socket,
// It is safe to use base::Unretained(this). |socket_| is owned by this class,
// and the callback won't be run after |socket_| is destroyed.
CompletionOnceCallback accept_callback =
base::BindOnce(&TCPServerSocket::OnAcceptCompleted,
base::Unretained(this), socket, std::move(callback));
CompletionOnceCallback accept_callback = base::BindOnce(
&TCPServerSocket::OnAcceptCompleted, base::Unretained(this), socket,
peer_address, std::move(callback));
int result = socket_->Accept(&accepted_socket_, &accepted_address_,
std::move(accept_callback));
if (result != ERR_IO_PENDING) {
// |accept_callback| won't be called so we need to run
// ConvertAcceptedSocket() ourselves in order to do the conversion from
// |accepted_socket_| to |socket|.
result = ConvertAcceptedSocket(result, socket);
result = ConvertAcceptedSocket(result, socket, peer_address);
} else {
pending_accept_ = true;
}
......@@ -95,12 +101,16 @@ void TCPServerSocket::DetachFromThread() {
int TCPServerSocket::ConvertAcceptedSocket(
int result,
std::unique_ptr<StreamSocket>* output_accepted_socket) {
std::unique_ptr<StreamSocket>* output_accepted_socket,
IPEndPoint* output_accepted_address) {
// Make sure the TCPSocket object is destroyed in any case.
std::unique_ptr<TCPSocket> temp_accepted_socket(std::move(accepted_socket_));
if (result != OK)
return result;
if (output_accepted_address)
*output_accepted_address = accepted_address_;
output_accepted_socket->reset(
new TCPClientSocket(std::move(temp_accepted_socket), accepted_address_));
......@@ -109,9 +119,11 @@ int TCPServerSocket::ConvertAcceptedSocket(
void TCPServerSocket::OnAcceptCompleted(
std::unique_ptr<StreamSocket>* output_accepted_socket,
IPEndPoint* output_accepted_address,
CompletionOnceCallback forward_callback,
int result) {
result = ConvertAcceptedSocket(result, output_accepted_socket);
result = ConvertAcceptedSocket(result, output_accepted_socket,
output_accepted_address);
pending_accept_ = false;
std::move(forward_callback).Run(result);
}
......
......@@ -41,8 +41,11 @@ class NET_EXPORT TCPServerSocket : public ServerSocket {
int GetLocalAddress(IPEndPoint* address) const override;
int Accept(std::unique_ptr<StreamSocket>* socket,
CompletionOnceCallback callback) override;
int Accept(std::unique_ptr<StreamSocket>* socket,
CompletionOnceCallback callback,
IPEndPoint* peer_address) override;
// Detachs from the current thread, to allow the socket to be transferred to
// Detaches from the current thread, to allow the socket to be transferred to
// a new thread. Should only be called when the object is no longer used by
// the old thread.
void DetachFromThread();
......@@ -54,9 +57,11 @@ class NET_EXPORT TCPServerSocket : public ServerSocket {
// set to NULL in any case.
int ConvertAcceptedSocket(
int result,
std::unique_ptr<StreamSocket>* output_accepted_socket);
std::unique_ptr<StreamSocket>* output_accepted_socket,
IPEndPoint* output_accepted_address);
// Completion callback for calling TCPSocket::Accept().
void OnAcceptCompleted(std::unique_ptr<StreamSocket>* output_accepted_socket,
IPEndPoint* output_accepted_address,
CompletionOnceCallback forward_callback,
int result);
......
......@@ -79,12 +79,17 @@ TEST_F(TCPServerSocketTest, Accept) {
TestCompletionCallback accept_callback;
std::unique_ptr<StreamSocket> accepted_socket;
int result = socket_.Accept(&accepted_socket, accept_callback.callback());
IPEndPoint peer_address;
int result = socket_.Accept(&accepted_socket, accept_callback.callback(),
&peer_address);
result = accept_callback.GetResult(result);
ASSERT_THAT(result, IsOk());
ASSERT_TRUE(accepted_socket.get() != nullptr);
// |peer_address| should be correctly populated.
EXPECT_EQ(peer_address.address(), local_address_.address());
// Both sockets should be on the loopback network interface.
EXPECT_EQ(GetPeerAddress(accepted_socket.get()).address(),
local_address_.address());
......@@ -98,8 +103,10 @@ TEST_F(TCPServerSocketTest, AcceptAsync) {
TestCompletionCallback accept_callback;
std::unique_ptr<StreamSocket> accepted_socket;
IPEndPoint peer_address;
ASSERT_THAT(socket_.Accept(&accepted_socket, accept_callback.callback()),
ASSERT_THAT(socket_.Accept(&accepted_socket, accept_callback.callback(),
&peer_address),
IsError(ERR_IO_PENDING));
TestCompletionCallback connect_callback;
......@@ -112,20 +119,51 @@ TEST_F(TCPServerSocketTest, AcceptAsync) {
EXPECT_TRUE(accepted_socket != nullptr);
// |peer_address| should be correctly populated.
EXPECT_EQ(peer_address.address(), local_address_.address());
// Both sockets should be on the loopback network interface.
EXPECT_EQ(GetPeerAddress(accepted_socket.get()).address(),
local_address_.address());
}
// Test Accept() when client disconnects right after trying to connect.
TEST_F(TCPServerSocketTest, AcceptClientDisconnectAfterConnect) {
ASSERT_NO_FATAL_FAILURE(SetUpIPv4());
TestCompletionCallback accept_callback;
std::unique_ptr<StreamSocket> accepted_socket;
IPEndPoint peer_address;
TestCompletionCallback connect_callback;
TCPClientSocket connecting_socket(local_address_list(), nullptr, nullptr,
NetLogSource());
int connect_result = connecting_socket.Connect(connect_callback.callback());
EXPECT_THAT(connect_callback.GetResult(connect_result), IsOk());
int accept_result = socket_.Accept(&accepted_socket,
accept_callback.callback(), &peer_address);
connecting_socket.Disconnect();
EXPECT_THAT(accept_callback.GetResult(accept_result), IsOk());
EXPECT_TRUE(accepted_socket != nullptr);
// |peer_address| should be correctly populated.
EXPECT_EQ(peer_address.address(), local_address_.address());
}
// Accept two connections simultaneously.
TEST_F(TCPServerSocketTest, Accept2Connections) {
ASSERT_NO_FATAL_FAILURE(SetUpIPv4());
TestCompletionCallback accept_callback;
std::unique_ptr<StreamSocket> accepted_socket;
IPEndPoint peer_address;
ASSERT_EQ(ERR_IO_PENDING,
socket_.Accept(&accepted_socket, accept_callback.callback()));
socket_.Accept(&accepted_socket, accept_callback.callback(),
&peer_address));
TestCompletionCallback connect_callback;
TCPClientSocket connecting_socket(local_address_list(), nullptr, nullptr,
......@@ -142,7 +180,9 @@ TEST_F(TCPServerSocketTest, Accept2Connections) {
TestCompletionCallback accept_callback2;
std::unique_ptr<StreamSocket> accepted_socket2;
int result = socket_.Accept(&accepted_socket2, accept_callback2.callback());
IPEndPoint peer_address2;
int result = socket_.Accept(&accepted_socket2, accept_callback2.callback(),
&peer_address2);
result = accept_callback2.GetResult(result);
ASSERT_THAT(result, IsOk());
......@@ -153,8 +193,10 @@ TEST_F(TCPServerSocketTest, Accept2Connections) {
EXPECT_TRUE(accepted_socket2 != nullptr);
EXPECT_NE(accepted_socket.get(), accepted_socket2.get());
EXPECT_EQ(peer_address.address(), local_address_.address());
EXPECT_EQ(GetPeerAddress(accepted_socket.get()).address(),
local_address_.address());
EXPECT_EQ(peer_address2.address(), local_address_.address());
EXPECT_EQ(GetPeerAddress(accepted_socket2.get()).address(),
local_address_.address());
}
......@@ -172,12 +214,17 @@ TEST_F(TCPServerSocketTest, AcceptIPv6) {
TestCompletionCallback accept_callback;
std::unique_ptr<StreamSocket> accepted_socket;
int result = socket_.Accept(&accepted_socket, accept_callback.callback());
IPEndPoint peer_address;
int result = socket_.Accept(&accepted_socket, accept_callback.callback(),
&peer_address);
result = accept_callback.GetResult(result);
ASSERT_THAT(result, IsOk());
ASSERT_TRUE(accepted_socket.get() != nullptr);
// |peer_address| should be correctly populated.
EXPECT_EQ(peer_address.address(), local_address_.address());
// Both sockets should be on the loopback network interface.
EXPECT_EQ(GetPeerAddress(accepted_socket.get()).address(),
local_address_.address());
......@@ -195,11 +242,16 @@ TEST_F(TCPServerSocketTest, AcceptIO) {
TestCompletionCallback accept_callback;
std::unique_ptr<StreamSocket> accepted_socket;
int result = socket_.Accept(&accepted_socket, accept_callback.callback());
IPEndPoint peer_address;
int result = socket_.Accept(&accepted_socket, accept_callback.callback(),
&peer_address);
ASSERT_THAT(accept_callback.GetResult(result), IsOk());
ASSERT_TRUE(accepted_socket.get() != nullptr);
// |peer_address| should be correctly populated.
EXPECT_EQ(peer_address.address(), local_address_.address());
// Both sockets should be on the loopback network interface.
EXPECT_EQ(GetPeerAddress(accepted_socket.get()).address(),
local_address_.address());
......
......@@ -92,12 +92,8 @@ void TCPServerSocket::OnAcceptCompleted(int result) {
auto pending_accept = std::move(pending_accepts_queue_.front());
pending_accepts_queue_.erase(pending_accepts_queue_.begin());
net::IPEndPoint peer_addr;
if (result == net::OK) {
DCHECK(accepted_socket_);
result = accepted_socket_->GetPeerAddress(&peer_addr);
}
if (result == net::OK) {
mojo::DataPipe send_pipe;
mojo::DataPipe receive_pipe;
mojo::PendingRemote<mojom::TCPConnectedSocket> socket;
......@@ -110,7 +106,7 @@ void TCPServerSocket::OnAcceptCompleted(int result) {
delegate_->OnAccept(std::move(connected_socket),
socket.InitWithNewPipeAndPassReceiver());
std::move(pending_accept->callback)
.Run(result, peer_addr, std::move(socket),
.Run(result, accepted_address_, std::move(socket),
std::move(receive_pipe.consumer_handle),
std::move(send_pipe.producer_handle));
} else {
......@@ -128,7 +124,8 @@ void TCPServerSocket::ProcessNextAccept() {
int result =
socket_->Accept(&accepted_socket_,
base::BindRepeating(&TCPServerSocket::OnAcceptCompleted,
base::Unretained(this)));
base::Unretained(this)),
&accepted_address_);
if (result == net::ERR_IO_PENDING)
return;
OnAcceptCompleted(result);
......
......@@ -88,6 +88,7 @@ class COMPONENT_EXPORT(NETWORK_SERVICE) TCPServerSocket
int backlog_;
std::vector<std::unique_ptr<PendingAccept>> pending_accepts_queue_;
std::unique_ptr<net::StreamSocket> accepted_socket_;
net::IPEndPoint accepted_address_;
net::NetworkTrafficAnnotationTag traffic_annotation_;
base::WeakPtrFactory<TCPServerSocket> weak_factory_{this};
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment