Commit ca100d59 authored by cmasone's avatar cmasone Committed by Commit bot

Raw SocketDescriptor variant of UnixDomainServerSocket::Accept

The Mojo code on CrOS needs to accept inbound connections
on a unix domain socket, and then 'promote' the resulting
sockets to Mojo MessagePipes. This really requires access
to the underying file descriptor, so provide a mechanism
to accept a connection and get back a SocketDescriptor.

BUG=407782
TEST=UnixDomain*SocketTest
R=mmenke@chromium.org

Review URL: https://codereview.chromium.org/509133002

Cr-Commit-Position: refs/heads/master@{#293172}
parent 292bdba7
......@@ -106,6 +106,13 @@ int SocketLibevent::AdoptConnectedSocket(SocketDescriptor socket,
return OK;
}
SocketDescriptor SocketLibevent::ReleaseConnectedSocket() {
StopWatchingAndCleanUp();
SocketDescriptor socket_fd = socket_fd_;
socket_fd_ = kInvalidSocket;
return socket_fd;
}
int SocketLibevent::Bind(const SockaddrStorage& address) {
DCHECK(thread_checker_.CalledOnValidThread());
DCHECK_NE(kInvalidSocket, socket_fd_);
......@@ -326,38 +333,13 @@ bool SocketLibevent::HasPeerAddress() const {
void SocketLibevent::Close() {
DCHECK(thread_checker_.CalledOnValidThread());
bool ok = accept_socket_watcher_.StopWatchingFileDescriptor();
DCHECK(ok);
ok = read_socket_watcher_.StopWatchingFileDescriptor();
DCHECK(ok);
ok = write_socket_watcher_.StopWatchingFileDescriptor();
DCHECK(ok);
StopWatchingAndCleanUp();
if (socket_fd_ != kInvalidSocket) {
if (IGNORE_EINTR(close(socket_fd_)) < 0)
PLOG(ERROR) << "close() returned an error, errno=" << errno;
socket_fd_ = kInvalidSocket;
}
if (!accept_callback_.is_null()) {
accept_socket_ = NULL;
accept_callback_.Reset();
}
if (!read_callback_.is_null()) {
read_buf_ = NULL;
read_buf_len_ = 0;
read_callback_.Reset();
}
if (!write_callback_.is_null()) {
write_buf_ = NULL;
write_buf_len_ = 0;
write_callback_.Reset();
}
waiting_connect_ = false;
peer_address_.reset();
}
void SocketLibevent::OnFileCanReadWithoutBlocking(int fd) {
......@@ -468,4 +450,33 @@ void SocketLibevent::WriteCompleted() {
base::ResetAndReturn(&write_callback_).Run(rv);
}
void SocketLibevent::StopWatchingAndCleanUp() {
bool ok = accept_socket_watcher_.StopWatchingFileDescriptor();
DCHECK(ok);
ok = read_socket_watcher_.StopWatchingFileDescriptor();
DCHECK(ok);
ok = write_socket_watcher_.StopWatchingFileDescriptor();
DCHECK(ok);
if (!accept_callback_.is_null()) {
accept_socket_ = NULL;
accept_callback_.Reset();
}
if (!read_callback_.is_null()) {
read_buf_ = NULL;
read_buf_len_ = 0;
read_callback_.Reset();
}
if (!write_callback_.is_null()) {
write_buf_ = NULL;
write_buf_len_ = 0;
write_callback_.Reset();
}
waiting_connect_ = false;
peer_address_.reset();
}
} // namespace net
......@@ -23,7 +23,8 @@ class IPEndPoint;
// Socket class to provide asynchronous read/write operations on top of the
// posix socket api. It supports AF_INET, AF_INET6, and AF_UNIX addresses.
class SocketLibevent : public base::MessageLoopForIO::Watcher {
class NET_EXPORT_PRIVATE SocketLibevent
: public base::MessageLoopForIO::Watcher {
public:
SocketLibevent();
virtual ~SocketLibevent();
......@@ -34,6 +35,8 @@ class SocketLibevent : public base::MessageLoopForIO::Watcher {
// Takes ownership of |socket|.
int AdoptConnectedSocket(SocketDescriptor socket,
const SockaddrStorage& peer_address);
// Releases ownership of |socket_fd_| to caller.
SocketDescriptor ReleaseConnectedSocket();
int Bind(const SockaddrStorage& address);
......@@ -93,6 +96,8 @@ class SocketLibevent : public base::MessageLoopForIO::Watcher {
int DoWrite(IOBuffer* buf, int buf_len);
void WriteCompleted();
void StopWatchingAndCleanUp();
SocketDescriptor socket_fd_;
base::MessageLoopForIO::FileDescriptorWatcher accept_socket_watcher_;
......
......@@ -159,4 +159,13 @@ int UnixDomainClientSocket::SetSendBufferSize(int32 size) {
return ERR_NOT_IMPLEMENTED;
}
SocketDescriptor UnixDomainClientSocket::ReleaseConnectedSocket() {
DCHECK(socket_);
DCHECK(socket_->IsConnected());
SocketDescriptor socket_fd = socket_->ReleaseConnectedSocket();
socket_.reset();
return socket_fd;
}
} // namespace net
......@@ -13,6 +13,7 @@
#include "net/base/completion_callback.h"
#include "net/base/net_export.h"
#include "net/base/net_log.h"
#include "net/socket/socket_descriptor.h"
#include "net/socket/stream_socket.h"
namespace net {
......@@ -63,6 +64,11 @@ class NET_EXPORT UnixDomainClientSocket : public StreamSocket {
virtual int SetReceiveBufferSize(int32 size) OVERRIDE;
virtual int SetSendBufferSize(int32 size) OVERRIDE;
// Releases ownership of underlying SocketDescriptor to caller.
// Internal state is reset so that this object can be used again.
// Socket must be connected in order to release it.
SocketDescriptor ReleaseConnectedSocket();
private:
const std::string socket_path_;
const bool use_abstract_namespace_;
......
......@@ -10,9 +10,11 @@
#include "base/files/file_path.h"
#include "base/files/scoped_temp_dir.h"
#include "base/memory/scoped_ptr.h"
#include "base/posix/eintr_wrapper.h"
#include "net/base/io_buffer.h"
#include "net/base/net_errors.h"
#include "net/base/test_completion_callback.h"
#include "net/socket/socket_libevent.h"
#include "net/socket/unix_domain_server_socket_posix.h"
#include "testing/gtest/include/gtest/gtest.h"
......@@ -148,6 +150,54 @@ TEST_F(UnixDomainClientSocketTest, Connect) {
EXPECT_TRUE(accepted_socket->IsConnected());
}
TEST_F(UnixDomainClientSocketTest, ConnectWithSocketDescriptor) {
const bool kUseAbstractNamespace = false;
UnixDomainServerSocket server_socket(CreateAuthCallback(true),
kUseAbstractNamespace);
EXPECT_EQ(OK, server_socket.ListenWithAddressAndPort(socket_path_, 0, 1));
SocketDescriptor accepted_socket_fd = kInvalidSocket;
TestCompletionCallback accept_callback;
EXPECT_EQ(ERR_IO_PENDING,
server_socket.AcceptSocketDescriptor(&accepted_socket_fd,
accept_callback.callback()));
EXPECT_EQ(kInvalidSocket, accepted_socket_fd);
UnixDomainClientSocket client_socket(socket_path_, kUseAbstractNamespace);
EXPECT_FALSE(client_socket.IsConnected());
EXPECT_EQ(OK, ConnectSynchronously(&client_socket));
EXPECT_TRUE(client_socket.IsConnected());
// Server has not yet been notified of the connection.
EXPECT_EQ(kInvalidSocket, accepted_socket_fd);
EXPECT_EQ(OK, accept_callback.WaitForResult());
EXPECT_NE(kInvalidSocket, accepted_socket_fd);
SocketDescriptor client_socket_fd = client_socket.ReleaseConnectedSocket();
EXPECT_NE(kInvalidSocket, client_socket_fd);
// Now, re-wrap client_socket_fd in a UnixDomainClientSocket and try a read
// to be sure it hasn't gotten accidentally closed.
SockaddrStorage addr;
ASSERT_TRUE(UnixDomainClientSocket::FillAddress(socket_path_, false, &addr));
scoped_ptr<SocketLibevent> adopter(new SocketLibevent);
adopter->AdoptConnectedSocket(client_socket_fd, addr);
UnixDomainClientSocket rewrapped_socket(adopter.Pass());
EXPECT_TRUE(rewrapped_socket.IsConnected());
// Try to read data.
const int kReadDataSize = 10;
scoped_refptr<IOBuffer> read_buffer(new IOBuffer(kReadDataSize));
TestCompletionCallback read_callback;
EXPECT_EQ(ERR_IO_PENDING,
rewrapped_socket.Read(
read_buffer.get(), kReadDataSize, read_callback.callback()));
EXPECT_EQ(0, IGNORE_EINTR(close(accepted_socket_fd)));
}
TEST_F(UnixDomainClientSocketTest, ConnectWithAbstractNamespace) {
const bool kUseAbstractNamespace = true;
......
......@@ -16,6 +16,21 @@
namespace net {
namespace {
// Intended for use as SetterCallbacks in Accept() helper methods.
void SetStreamSocket(scoped_ptr<StreamSocket>* socket,
scoped_ptr<SocketLibevent> accepted_socket) {
socket->reset(new UnixDomainClientSocket(accepted_socket.Pass()));
}
void SetSocketDescriptor(SocketDescriptor* socket,
scoped_ptr<SocketLibevent> accepted_socket) {
*socket = accepted_socket->ReleaseConnectedSocket();
}
} // anonymous namespace
UnixDomainServerSocket::UnixDomainServerSocket(
const AuthCallback& auth_callback,
bool use_abstract_namespace)
......@@ -95,6 +110,23 @@ int UnixDomainServerSocket::GetLocalAddress(IPEndPoint* address) const {
int UnixDomainServerSocket::Accept(scoped_ptr<StreamSocket>* socket,
const CompletionCallback& callback) {
DCHECK(socket);
SetterCallback setter_callback = base::Bind(&SetStreamSocket, socket);
return DoAccept(setter_callback, callback);
}
int UnixDomainServerSocket::AcceptSocketDescriptor(
SocketDescriptor* socket,
const CompletionCallback& callback) {
DCHECK(socket);
SetterCallback setter_callback = base::Bind(&SetSocketDescriptor, socket);
return DoAccept(setter_callback, callback);
}
int UnixDomainServerSocket::DoAccept(const SetterCallback& setter_callback,
const CompletionCallback& callback) {
DCHECK(!setter_callback.is_null());
DCHECK(!callback.is_null());
DCHECK(listen_socket_);
DCHECK(!accept_socket_);
......@@ -103,38 +135,41 @@ int UnixDomainServerSocket::Accept(scoped_ptr<StreamSocket>* socket,
int rv = listen_socket_->Accept(
&accept_socket_,
base::Bind(&UnixDomainServerSocket::AcceptCompleted,
base::Unretained(this), socket, callback));
base::Unretained(this),
setter_callback,
callback));
if (rv != OK)
return rv;
if (AuthenticateAndGetStreamSocket(socket))
if (AuthenticateAndGetStreamSocket(setter_callback))
return OK;
// Accept another socket because authentication error should be transparent
// to the caller.
}
}
void UnixDomainServerSocket::AcceptCompleted(scoped_ptr<StreamSocket>* socket,
const CompletionCallback& callback,
int rv) {
void UnixDomainServerSocket::AcceptCompleted(
const SetterCallback& setter_callback,
const CompletionCallback& callback,
int rv) {
if (rv != OK) {
callback.Run(rv);
return;
}
if (AuthenticateAndGetStreamSocket(socket)) {
if (AuthenticateAndGetStreamSocket(setter_callback)) {
callback.Run(OK);
return;
}
// Accept another socket because authentication error should be transparent
// to the caller.
rv = Accept(socket, callback);
rv = DoAccept(setter_callback, callback);
if (rv != ERR_IO_PENDING)
callback.Run(rv);
}
bool UnixDomainServerSocket::AuthenticateAndGetStreamSocket(
scoped_ptr<StreamSocket>* socket) {
const SetterCallback& setter_callback) {
DCHECK(accept_socket_);
Credentials credentials;
......@@ -144,7 +179,7 @@ bool UnixDomainServerSocket::AuthenticateAndGetStreamSocket(
return false;
}
socket->reset(new UnixDomainClientSocket(accept_socket_.Pass()));
setter_callback.Run(accept_socket_.Pass());
return true;
}
......
......@@ -59,11 +59,23 @@ class NET_EXPORT UnixDomainServerSocket : public ServerSocket {
virtual int Accept(scoped_ptr<StreamSocket>* socket,
const CompletionCallback& callback) OVERRIDE;
// Accepts an incoming connection on |listen_socket_|, but passes back
// a raw SocketDescriptor instead of a StreamSocket.
int AcceptSocketDescriptor(SocketDescriptor* socket_descriptor,
const CompletionCallback& callback);
private:
void AcceptCompleted(scoped_ptr<StreamSocket>* socket,
// A callback to wrap the setting of the out-parameter to Accept().
// This allows the internal machinery of that call to be implemented in
// a manner that's agnostic to the caller's desired output.
typedef base::Callback<void(scoped_ptr<SocketLibevent>)> SetterCallback;
int DoAccept(const SetterCallback& setter_callback,
const CompletionCallback& callback);
void AcceptCompleted(const SetterCallback& setter_callback,
const CompletionCallback& callback,
int rv);
bool AuthenticateAndGetStreamSocket(scoped_ptr<StreamSocket>* socket);
bool AuthenticateAndGetStreamSocket(const SetterCallback& setter_callback);
scoped_ptr<SocketLibevent> listen_socket_;
const AuthCallback auth_callback_;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment