Commit 60fedc79 authored by Junbo Ke's avatar Junbo Ke Committed by Commit Bot

Propagate the peer address to network::TCPServerSocket in the callback.

StreamSocket::GetPeerAddress returns ERR_SOCKET_NOT_CONNECTED when the
client socket has already disconnected before OnAcceptCompleted, which
causes network::server::HttpServer to exit the accept loop because rv
is not net::OK.

Bug: b/149013559
Test: net_unittests, manual test on device
Change-Id: I86173aa204a249768efe2d69cf3a6dc8287a96d9
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2057964
Commit-Queue: Junbo Ke <juke@chromium.org>
Reviewed-by: default avatarMatt Menke <mmenke@chromium.org>
Cr-Commit-Position: refs/heads/master@{#742423}
parent 68e89bd7
...@@ -25,4 +25,13 @@ int ServerSocket::ListenWithAddressAndPort(const std::string& address_string, ...@@ -25,4 +25,13 @@ int ServerSocket::ListenWithAddressAndPort(const std::string& address_string,
return Listen(IPEndPoint(ip_address, port), backlog); return Listen(IPEndPoint(ip_address, port), backlog);
} }
int ServerSocket::Accept(std::unique_ptr<StreamSocket>* socket,
net::CompletionOnceCallback callback,
net::IPEndPoint* peer_address) {
if (peer_address) {
*peer_address = IPEndPoint();
}
return Accept(socket, std::move(callback));
}
} // namespace net } // namespace net
...@@ -42,6 +42,13 @@ class NET_EXPORT ServerSocket { ...@@ -42,6 +42,13 @@ class NET_EXPORT ServerSocket {
virtual int Accept(std::unique_ptr<StreamSocket>* socket, virtual int Accept(std::unique_ptr<StreamSocket>* socket,
CompletionOnceCallback callback) = 0; CompletionOnceCallback callback) = 0;
// Accepts connection. Callback is called when new connection is accepted.
// Note: |peer_address| may or may not be populated depending on the
// implementation.
virtual int Accept(std::unique_ptr<StreamSocket>* socket,
CompletionOnceCallback callback,
IPEndPoint* peer_address);
private: private:
DISALLOW_COPY_AND_ASSIGN(ServerSocket); DISALLOW_COPY_AND_ASSIGN(ServerSocket);
}; };
......
...@@ -62,6 +62,12 @@ int TCPServerSocket::GetLocalAddress(IPEndPoint* address) const { ...@@ -62,6 +62,12 @@ int TCPServerSocket::GetLocalAddress(IPEndPoint* address) const {
int TCPServerSocket::Accept(std::unique_ptr<StreamSocket>* socket, int TCPServerSocket::Accept(std::unique_ptr<StreamSocket>* socket,
CompletionOnceCallback callback) { CompletionOnceCallback callback) {
return Accept(socket, std::move(callback), nullptr);
}
int TCPServerSocket::Accept(std::unique_ptr<StreamSocket>* socket,
CompletionOnceCallback callback,
IPEndPoint* peer_address) {
DCHECK(socket); DCHECK(socket);
DCHECK(!callback.is_null()); DCHECK(!callback.is_null());
...@@ -72,16 +78,16 @@ int TCPServerSocket::Accept(std::unique_ptr<StreamSocket>* socket, ...@@ -72,16 +78,16 @@ int TCPServerSocket::Accept(std::unique_ptr<StreamSocket>* socket,
// It is safe to use base::Unretained(this). |socket_| is owned by this class, // It is safe to use base::Unretained(this). |socket_| is owned by this class,
// and the callback won't be run after |socket_| is destroyed. // and the callback won't be run after |socket_| is destroyed.
CompletionOnceCallback accept_callback = CompletionOnceCallback accept_callback = base::BindOnce(
base::BindOnce(&TCPServerSocket::OnAcceptCompleted, &TCPServerSocket::OnAcceptCompleted, base::Unretained(this), socket,
base::Unretained(this), socket, std::move(callback)); peer_address, std::move(callback));
int result = socket_->Accept(&accepted_socket_, &accepted_address_, int result = socket_->Accept(&accepted_socket_, &accepted_address_,
std::move(accept_callback)); std::move(accept_callback));
if (result != ERR_IO_PENDING) { if (result != ERR_IO_PENDING) {
// |accept_callback| won't be called so we need to run // |accept_callback| won't be called so we need to run
// ConvertAcceptedSocket() ourselves in order to do the conversion from // ConvertAcceptedSocket() ourselves in order to do the conversion from
// |accepted_socket_| to |socket|. // |accepted_socket_| to |socket|.
result = ConvertAcceptedSocket(result, socket); result = ConvertAcceptedSocket(result, socket, peer_address);
} else { } else {
pending_accept_ = true; pending_accept_ = true;
} }
...@@ -95,12 +101,16 @@ void TCPServerSocket::DetachFromThread() { ...@@ -95,12 +101,16 @@ void TCPServerSocket::DetachFromThread() {
int TCPServerSocket::ConvertAcceptedSocket( int TCPServerSocket::ConvertAcceptedSocket(
int result, int result,
std::unique_ptr<StreamSocket>* output_accepted_socket) { std::unique_ptr<StreamSocket>* output_accepted_socket,
IPEndPoint* output_accepted_address) {
// Make sure the TCPSocket object is destroyed in any case. // Make sure the TCPSocket object is destroyed in any case.
std::unique_ptr<TCPSocket> temp_accepted_socket(std::move(accepted_socket_)); std::unique_ptr<TCPSocket> temp_accepted_socket(std::move(accepted_socket_));
if (result != OK) if (result != OK)
return result; return result;
if (output_accepted_address)
*output_accepted_address = accepted_address_;
output_accepted_socket->reset( output_accepted_socket->reset(
new TCPClientSocket(std::move(temp_accepted_socket), accepted_address_)); new TCPClientSocket(std::move(temp_accepted_socket), accepted_address_));
...@@ -109,9 +119,11 @@ int TCPServerSocket::ConvertAcceptedSocket( ...@@ -109,9 +119,11 @@ int TCPServerSocket::ConvertAcceptedSocket(
void TCPServerSocket::OnAcceptCompleted( void TCPServerSocket::OnAcceptCompleted(
std::unique_ptr<StreamSocket>* output_accepted_socket, std::unique_ptr<StreamSocket>* output_accepted_socket,
IPEndPoint* output_accepted_address,
CompletionOnceCallback forward_callback, CompletionOnceCallback forward_callback,
int result) { int result) {
result = ConvertAcceptedSocket(result, output_accepted_socket); result = ConvertAcceptedSocket(result, output_accepted_socket,
output_accepted_address);
pending_accept_ = false; pending_accept_ = false;
std::move(forward_callback).Run(result); std::move(forward_callback).Run(result);
} }
......
...@@ -41,8 +41,11 @@ class NET_EXPORT TCPServerSocket : public ServerSocket { ...@@ -41,8 +41,11 @@ class NET_EXPORT TCPServerSocket : public ServerSocket {
int GetLocalAddress(IPEndPoint* address) const override; int GetLocalAddress(IPEndPoint* address) const override;
int Accept(std::unique_ptr<StreamSocket>* socket, int Accept(std::unique_ptr<StreamSocket>* socket,
CompletionOnceCallback callback) override; CompletionOnceCallback callback) override;
int Accept(std::unique_ptr<StreamSocket>* socket,
CompletionOnceCallback callback,
IPEndPoint* peer_address) override;
// Detachs from the current thread, to allow the socket to be transferred to // Detaches from the current thread, to allow the socket to be transferred to
// a new thread. Should only be called when the object is no longer used by // a new thread. Should only be called when the object is no longer used by
// the old thread. // the old thread.
void DetachFromThread(); void DetachFromThread();
...@@ -54,9 +57,11 @@ class NET_EXPORT TCPServerSocket : public ServerSocket { ...@@ -54,9 +57,11 @@ class NET_EXPORT TCPServerSocket : public ServerSocket {
// set to NULL in any case. // set to NULL in any case.
int ConvertAcceptedSocket( int ConvertAcceptedSocket(
int result, int result,
std::unique_ptr<StreamSocket>* output_accepted_socket); std::unique_ptr<StreamSocket>* output_accepted_socket,
IPEndPoint* output_accepted_address);
// Completion callback for calling TCPSocket::Accept(). // Completion callback for calling TCPSocket::Accept().
void OnAcceptCompleted(std::unique_ptr<StreamSocket>* output_accepted_socket, void OnAcceptCompleted(std::unique_ptr<StreamSocket>* output_accepted_socket,
IPEndPoint* output_accepted_address,
CompletionOnceCallback forward_callback, CompletionOnceCallback forward_callback,
int result); int result);
......
...@@ -79,12 +79,17 @@ TEST_F(TCPServerSocketTest, Accept) { ...@@ -79,12 +79,17 @@ TEST_F(TCPServerSocketTest, Accept) {
TestCompletionCallback accept_callback; TestCompletionCallback accept_callback;
std::unique_ptr<StreamSocket> accepted_socket; std::unique_ptr<StreamSocket> accepted_socket;
int result = socket_.Accept(&accepted_socket, accept_callback.callback()); IPEndPoint peer_address;
int result = socket_.Accept(&accepted_socket, accept_callback.callback(),
&peer_address);
result = accept_callback.GetResult(result); result = accept_callback.GetResult(result);
ASSERT_THAT(result, IsOk()); ASSERT_THAT(result, IsOk());
ASSERT_TRUE(accepted_socket.get() != nullptr); ASSERT_TRUE(accepted_socket.get() != nullptr);
// |peer_address| should be correctly populated.
EXPECT_EQ(peer_address.address(), local_address_.address());
// Both sockets should be on the loopback network interface. // Both sockets should be on the loopback network interface.
EXPECT_EQ(GetPeerAddress(accepted_socket.get()).address(), EXPECT_EQ(GetPeerAddress(accepted_socket.get()).address(),
local_address_.address()); local_address_.address());
...@@ -98,8 +103,10 @@ TEST_F(TCPServerSocketTest, AcceptAsync) { ...@@ -98,8 +103,10 @@ TEST_F(TCPServerSocketTest, AcceptAsync) {
TestCompletionCallback accept_callback; TestCompletionCallback accept_callback;
std::unique_ptr<StreamSocket> accepted_socket; std::unique_ptr<StreamSocket> accepted_socket;
IPEndPoint peer_address;
ASSERT_THAT(socket_.Accept(&accepted_socket, accept_callback.callback()), ASSERT_THAT(socket_.Accept(&accepted_socket, accept_callback.callback(),
&peer_address),
IsError(ERR_IO_PENDING)); IsError(ERR_IO_PENDING));
TestCompletionCallback connect_callback; TestCompletionCallback connect_callback;
...@@ -112,20 +119,51 @@ TEST_F(TCPServerSocketTest, AcceptAsync) { ...@@ -112,20 +119,51 @@ TEST_F(TCPServerSocketTest, AcceptAsync) {
EXPECT_TRUE(accepted_socket != nullptr); EXPECT_TRUE(accepted_socket != nullptr);
// |peer_address| should be correctly populated.
EXPECT_EQ(peer_address.address(), local_address_.address());
// Both sockets should be on the loopback network interface. // Both sockets should be on the loopback network interface.
EXPECT_EQ(GetPeerAddress(accepted_socket.get()).address(), EXPECT_EQ(GetPeerAddress(accepted_socket.get()).address(),
local_address_.address()); local_address_.address());
} }
// Test Accept() when client disconnects right after trying to connect.
TEST_F(TCPServerSocketTest, AcceptClientDisconnectAfterConnect) {
ASSERT_NO_FATAL_FAILURE(SetUpIPv4());
TestCompletionCallback accept_callback;
std::unique_ptr<StreamSocket> accepted_socket;
IPEndPoint peer_address;
TestCompletionCallback connect_callback;
TCPClientSocket connecting_socket(local_address_list(), nullptr, nullptr,
NetLogSource());
int connect_result = connecting_socket.Connect(connect_callback.callback());
EXPECT_THAT(connect_callback.GetResult(connect_result), IsOk());
int accept_result = socket_.Accept(&accepted_socket,
accept_callback.callback(), &peer_address);
connecting_socket.Disconnect();
EXPECT_THAT(accept_callback.GetResult(accept_result), IsOk());
EXPECT_TRUE(accepted_socket != nullptr);
// |peer_address| should be correctly populated.
EXPECT_EQ(peer_address.address(), local_address_.address());
}
// Accept two connections simultaneously. // Accept two connections simultaneously.
TEST_F(TCPServerSocketTest, Accept2Connections) { TEST_F(TCPServerSocketTest, Accept2Connections) {
ASSERT_NO_FATAL_FAILURE(SetUpIPv4()); ASSERT_NO_FATAL_FAILURE(SetUpIPv4());
TestCompletionCallback accept_callback; TestCompletionCallback accept_callback;
std::unique_ptr<StreamSocket> accepted_socket; std::unique_ptr<StreamSocket> accepted_socket;
IPEndPoint peer_address;
ASSERT_EQ(ERR_IO_PENDING, ASSERT_EQ(ERR_IO_PENDING,
socket_.Accept(&accepted_socket, accept_callback.callback())); socket_.Accept(&accepted_socket, accept_callback.callback(),
&peer_address));
TestCompletionCallback connect_callback; TestCompletionCallback connect_callback;
TCPClientSocket connecting_socket(local_address_list(), nullptr, nullptr, TCPClientSocket connecting_socket(local_address_list(), nullptr, nullptr,
...@@ -142,7 +180,9 @@ TEST_F(TCPServerSocketTest, Accept2Connections) { ...@@ -142,7 +180,9 @@ TEST_F(TCPServerSocketTest, Accept2Connections) {
TestCompletionCallback accept_callback2; TestCompletionCallback accept_callback2;
std::unique_ptr<StreamSocket> accepted_socket2; std::unique_ptr<StreamSocket> accepted_socket2;
int result = socket_.Accept(&accepted_socket2, accept_callback2.callback()); IPEndPoint peer_address2;
int result = socket_.Accept(&accepted_socket2, accept_callback2.callback(),
&peer_address2);
result = accept_callback2.GetResult(result); result = accept_callback2.GetResult(result);
ASSERT_THAT(result, IsOk()); ASSERT_THAT(result, IsOk());
...@@ -153,8 +193,10 @@ TEST_F(TCPServerSocketTest, Accept2Connections) { ...@@ -153,8 +193,10 @@ TEST_F(TCPServerSocketTest, Accept2Connections) {
EXPECT_TRUE(accepted_socket2 != nullptr); EXPECT_TRUE(accepted_socket2 != nullptr);
EXPECT_NE(accepted_socket.get(), accepted_socket2.get()); EXPECT_NE(accepted_socket.get(), accepted_socket2.get());
EXPECT_EQ(peer_address.address(), local_address_.address());
EXPECT_EQ(GetPeerAddress(accepted_socket.get()).address(), EXPECT_EQ(GetPeerAddress(accepted_socket.get()).address(),
local_address_.address()); local_address_.address());
EXPECT_EQ(peer_address2.address(), local_address_.address());
EXPECT_EQ(GetPeerAddress(accepted_socket2.get()).address(), EXPECT_EQ(GetPeerAddress(accepted_socket2.get()).address(),
local_address_.address()); local_address_.address());
} }
...@@ -172,12 +214,17 @@ TEST_F(TCPServerSocketTest, AcceptIPv6) { ...@@ -172,12 +214,17 @@ TEST_F(TCPServerSocketTest, AcceptIPv6) {
TestCompletionCallback accept_callback; TestCompletionCallback accept_callback;
std::unique_ptr<StreamSocket> accepted_socket; std::unique_ptr<StreamSocket> accepted_socket;
int result = socket_.Accept(&accepted_socket, accept_callback.callback()); IPEndPoint peer_address;
int result = socket_.Accept(&accepted_socket, accept_callback.callback(),
&peer_address);
result = accept_callback.GetResult(result); result = accept_callback.GetResult(result);
ASSERT_THAT(result, IsOk()); ASSERT_THAT(result, IsOk());
ASSERT_TRUE(accepted_socket.get() != nullptr); ASSERT_TRUE(accepted_socket.get() != nullptr);
// |peer_address| should be correctly populated.
EXPECT_EQ(peer_address.address(), local_address_.address());
// Both sockets should be on the loopback network interface. // Both sockets should be on the loopback network interface.
EXPECT_EQ(GetPeerAddress(accepted_socket.get()).address(), EXPECT_EQ(GetPeerAddress(accepted_socket.get()).address(),
local_address_.address()); local_address_.address());
...@@ -195,11 +242,16 @@ TEST_F(TCPServerSocketTest, AcceptIO) { ...@@ -195,11 +242,16 @@ TEST_F(TCPServerSocketTest, AcceptIO) {
TestCompletionCallback accept_callback; TestCompletionCallback accept_callback;
std::unique_ptr<StreamSocket> accepted_socket; std::unique_ptr<StreamSocket> accepted_socket;
int result = socket_.Accept(&accepted_socket, accept_callback.callback()); IPEndPoint peer_address;
int result = socket_.Accept(&accepted_socket, accept_callback.callback(),
&peer_address);
ASSERT_THAT(accept_callback.GetResult(result), IsOk()); ASSERT_THAT(accept_callback.GetResult(result), IsOk());
ASSERT_TRUE(accepted_socket.get() != nullptr); ASSERT_TRUE(accepted_socket.get() != nullptr);
// |peer_address| should be correctly populated.
EXPECT_EQ(peer_address.address(), local_address_.address());
// Both sockets should be on the loopback network interface. // Both sockets should be on the loopback network interface.
EXPECT_EQ(GetPeerAddress(accepted_socket.get()).address(), EXPECT_EQ(GetPeerAddress(accepted_socket.get()).address(),
local_address_.address()); local_address_.address());
......
...@@ -92,12 +92,8 @@ void TCPServerSocket::OnAcceptCompleted(int result) { ...@@ -92,12 +92,8 @@ void TCPServerSocket::OnAcceptCompleted(int result) {
auto pending_accept = std::move(pending_accepts_queue_.front()); auto pending_accept = std::move(pending_accepts_queue_.front());
pending_accepts_queue_.erase(pending_accepts_queue_.begin()); pending_accepts_queue_.erase(pending_accepts_queue_.begin());
net::IPEndPoint peer_addr;
if (result == net::OK) { if (result == net::OK) {
DCHECK(accepted_socket_); DCHECK(accepted_socket_);
result = accepted_socket_->GetPeerAddress(&peer_addr);
}
if (result == net::OK) {
mojo::DataPipe send_pipe; mojo::DataPipe send_pipe;
mojo::DataPipe receive_pipe; mojo::DataPipe receive_pipe;
mojo::PendingRemote<mojom::TCPConnectedSocket> socket; mojo::PendingRemote<mojom::TCPConnectedSocket> socket;
...@@ -110,7 +106,7 @@ void TCPServerSocket::OnAcceptCompleted(int result) { ...@@ -110,7 +106,7 @@ void TCPServerSocket::OnAcceptCompleted(int result) {
delegate_->OnAccept(std::move(connected_socket), delegate_->OnAccept(std::move(connected_socket),
socket.InitWithNewPipeAndPassReceiver()); socket.InitWithNewPipeAndPassReceiver());
std::move(pending_accept->callback) std::move(pending_accept->callback)
.Run(result, peer_addr, std::move(socket), .Run(result, accepted_address_, std::move(socket),
std::move(receive_pipe.consumer_handle), std::move(receive_pipe.consumer_handle),
std::move(send_pipe.producer_handle)); std::move(send_pipe.producer_handle));
} else { } else {
...@@ -128,7 +124,8 @@ void TCPServerSocket::ProcessNextAccept() { ...@@ -128,7 +124,8 @@ void TCPServerSocket::ProcessNextAccept() {
int result = int result =
socket_->Accept(&accepted_socket_, socket_->Accept(&accepted_socket_,
base::BindRepeating(&TCPServerSocket::OnAcceptCompleted, base::BindRepeating(&TCPServerSocket::OnAcceptCompleted,
base::Unretained(this))); base::Unretained(this)),
&accepted_address_);
if (result == net::ERR_IO_PENDING) if (result == net::ERR_IO_PENDING)
return; return;
OnAcceptCompleted(result); OnAcceptCompleted(result);
......
...@@ -88,6 +88,7 @@ class COMPONENT_EXPORT(NETWORK_SERVICE) TCPServerSocket ...@@ -88,6 +88,7 @@ class COMPONENT_EXPORT(NETWORK_SERVICE) TCPServerSocket
int backlog_; int backlog_;
std::vector<std::unique_ptr<PendingAccept>> pending_accepts_queue_; std::vector<std::unique_ptr<PendingAccept>> pending_accepts_queue_;
std::unique_ptr<net::StreamSocket> accepted_socket_; std::unique_ptr<net::StreamSocket> accepted_socket_;
net::IPEndPoint accepted_address_;
net::NetworkTrafficAnnotationTag traffic_annotation_; net::NetworkTrafficAnnotationTag traffic_annotation_;
base::WeakPtrFactory<TCPServerSocket> weak_factory_{this}; base::WeakPtrFactory<TCPServerSocket> weak_factory_{this};
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment