Refactor unix domain socket.

This is a pre-requisite for http server refactoring,
https://codereview.chromium.org/296053012/.

1) Define UnixDomainClientSocket and UnixDomainServerSocket utilizing
   SocketLibevent.
2) Rename UnixDomainSocket to UnixDomainListenSocket to reduce confusion.
3) unittests

BUG=371906

Review URL: https://codereview.chromium.org/376323002

git-svn-id: svn://svn.chromium.org/chrome/trunk/src@285148 0039d316-1c4b-4281-b951-d872f2087c98
parent 960cf134
......@@ -19,7 +19,7 @@
#include "content/public/browser/web_contents.h"
#include "content/public/common/user_agent.h"
#include "jni/AwDevToolsServer_jni.h"
#include "net/socket/unix_domain_socket_posix.h"
#include "net/socket/unix_domain_listen_socket_posix.h"
using content::DevToolsAgentHost;
using content::RenderViewHost;
......@@ -176,7 +176,7 @@ void AwDevToolsServer::Start() {
return;
protocol_handler_ = content::DevToolsHttpHandler::Start(
new net::UnixDomainSocketWithAbstractNamespaceFactory(
new net::deprecated::UnixDomainListenSocketWithAbstractNamespaceFactory(
base::StringPrintf(kSocketNameFormat, getpid()),
"",
base::Bind(&content::CanUserConnectToDevTools)),
......
......@@ -40,7 +40,7 @@
#include "content/public/common/user_agent.h"
#include "grit/browser_resources.h"
#include "jni/DevToolsServer_jni.h"
#include "net/socket/unix_domain_socket_posix.h"
#include "net/socket/unix_domain_listen_socket_posix.h"
#include "net/url_request/url_request_context_getter.h"
#include "ui/base/resource/resource_bundle.h"
......@@ -366,12 +366,13 @@ class DevToolsServerDelegate : public content::DevToolsHttpHandlerDelegate {
std::string* name) OVERRIDE {
*name = base::StringPrintf(
kTetheringSocketName, getpid(), ++last_tethering_socket_);
return net::UnixDomainSocket::CreateAndListenWithAbstractNamespace(
*name,
"",
delegate,
base::Bind(&content::CanUserConnectToDevTools))
.PassAs<net::StreamListenSocket>();
return net::deprecated::UnixDomainListenSocket::
CreateAndListenWithAbstractNamespace(
*name,
"",
delegate,
base::Bind(&content::CanUserConnectToDevTools))
.PassAs<net::StreamListenSocket>();
}
private:
......@@ -423,7 +424,7 @@ void DevToolsServer::Start() {
return;
protocol_handler_ = content::DevToolsHttpHandler::Start(
new net::UnixDomainSocketWithAbstractNamespaceFactory(
new net::deprecated::UnixDomainListenSocketWithAbstractNamespaceFactory(
socket_name_,
base::StringPrintf("%s_%d", socket_name_.c_str(), getpid()),
base::Bind(&content::CanUserConnectToDevTools)),
......
......@@ -29,7 +29,7 @@
#if defined(OS_ANDROID)
#include "content/public/browser/android/devtools_auth.h"
#include "net/socket/unix_domain_socket_posix.h"
#include "net/socket/unix_domain_listen_socket_posix.h"
#endif
using content::DevToolsAgentHost;
......@@ -52,8 +52,9 @@ net::StreamListenSocketFactory* CreateSocketFactory() {
socket_name = command_line.GetSwitchValueASCII(
switches::kRemoteDebuggingSocketName);
}
return new net::UnixDomainSocketWithAbstractNamespaceFactory(
socket_name, "", base::Bind(&content::CanUserConnectToDevTools));
return new net::deprecated::
UnixDomainListenSocketWithAbstractNamespaceFactory(
socket_name, "", base::Bind(&content::CanUserConnectToDevTools));
#else
// See if the user specified a port on the command line (useful for
// automation). If not, use an ephemeral port by specifying 0.
......
......@@ -1297,7 +1297,9 @@ test("net_unittests") {
# The following tests are disabled because they don't apply to
# iOS.
# OS is not "linux" or "freebsd" or "openbsd".
"socket/unix_domain_socket_posix_unittest.cc",
"socket/unix_domain_client_socket_posix_unittest.cc",
"socket/unix_domain_listen_socket_posix_unittest.cc",
"socket/unix_domain_server_socket_posix_unittest.cc",
# See bug http://crbug.com/344533.
"disk_cache/blockfile/index_table_v3_unittest.cc",
......
......@@ -784,7 +784,9 @@
# The following tests are disabled because they don't apply to
# iOS.
# OS is not "linux" or "freebsd" or "openbsd".
'socket/unix_domain_socket_posix_unittest.cc',
'socket/unix_domain_client_socket_posix_unittest.cc',
'socket/unix_domain_listen_socket_posix_unittest.cc',
'socket/unix_domain_server_socket_posix_unittest.cc',
# See bug http://crbug.com/344533.
'disk_cache/blockfile/index_table_v3_unittest.cc',
......
......@@ -990,8 +990,12 @@
'socket/tcp_socket_win.h',
'socket/transport_client_socket_pool.cc',
'socket/transport_client_socket_pool.h',
'socket/unix_domain_socket_posix.cc',
'socket/unix_domain_socket_posix.h',
'socket/unix_domain_client_socket_posix.cc',
'socket/unix_domain_client_socket_posix.h',
'socket/unix_domain_listen_socket_posix.cc',
'socket/unix_domain_listen_socket_posix.h',
'socket/unix_domain_server_socket_posix.cc',
'socket/unix_domain_server_socket_posix.h',
'socket/websocket_endpoint_lock_manager.cc',
'socket/websocket_endpoint_lock_manager.h',
'socket/websocket_transport_client_socket_pool.cc',
......@@ -1571,7 +1575,9 @@
'socket/transport_client_socket_pool_test_util.h',
'socket/transport_client_socket_pool_unittest.cc',
'socket/transport_client_socket_unittest.cc',
'socket/unix_domain_socket_posix_unittest.cc',
'socket/unix_domain_client_socket_posix_unittest.cc',
'socket/unix_domain_listen_socket_posix_unittest.cc',
'socket/unix_domain_server_socket_posix_unittest.cc',
'socket/websocket_endpoint_lock_manager_unittest.cc',
'socket/websocket_transport_client_socket_pool_unittest.cc',
'socket_stream/socket_stream_metrics_unittest.cc',
......
// Copyright 2014 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net/socket/unix_domain_client_socket_posix.h"
#include <sys/socket.h>
#include <sys/un.h>
#include "base/logging.h"
#include "base/posix/eintr_wrapper.h"
#include "net/base/net_errors.h"
#include "net/base/net_util.h"
#include "net/socket/socket_libevent.h"
namespace net {
UnixDomainClientSocket::UnixDomainClientSocket(const std::string& socket_path,
bool use_abstract_namespace)
: socket_path_(socket_path),
use_abstract_namespace_(use_abstract_namespace) {
}
UnixDomainClientSocket::UnixDomainClientSocket(
scoped_ptr<SocketLibevent> socket)
: use_abstract_namespace_(false),
socket_(socket.Pass()) {
}
UnixDomainClientSocket::~UnixDomainClientSocket() {
Disconnect();
}
// static
bool UnixDomainClientSocket::FillAddress(const std::string& socket_path,
bool use_abstract_namespace,
SockaddrStorage* address) {
struct sockaddr_un* socket_addr =
reinterpret_cast<struct sockaddr_un*>(address->addr);
size_t path_max = address->addr_len - offsetof(struct sockaddr_un, sun_path);
// Non abstract namespace pathname should be null-terminated. Abstract
// namespace pathname must start with '\0'. So, the size is always greater
// than socket_path size by 1.
size_t path_size = socket_path.size() + 1;
if (path_size > path_max)
return false;
memset(socket_addr, 0, address->addr_len);
socket_addr->sun_family = AF_UNIX;
address->addr_len = path_size + offsetof(struct sockaddr_un, sun_path);
if (!use_abstract_namespace) {
memcpy(socket_addr->sun_path, socket_path.c_str(), socket_path.size());
return true;
}
#if defined(OS_ANDROID) || defined(OS_LINUX)
// Convert the path given into abstract socket name. It must start with
// the '\0' character, so we are adding it. |addr_len| must specify the
// length of the structure exactly, as potentially the socket name may
// have '\0' characters embedded (although we don't support this).
// Note that addr.sun_path is already zero initialized.
memcpy(socket_addr->sun_path + 1, socket_path.c_str(), socket_path.size());
return true;
#else
return false;
#endif
}
int UnixDomainClientSocket::Connect(const CompletionCallback& callback) {
DCHECK(!socket_);
if (socket_path_.empty())
return ERR_ADDRESS_INVALID;
SockaddrStorage address;
if (!FillAddress(socket_path_, use_abstract_namespace_, &address))
return ERR_ADDRESS_INVALID;
socket_.reset(new SocketLibevent);
int rv = socket_->Open(AF_UNIX);
DCHECK_NE(ERR_IO_PENDING, rv);
if (rv != OK)
return rv;
return socket_->Connect(address, callback);
}
void UnixDomainClientSocket::Disconnect() {
socket_.reset();
}
bool UnixDomainClientSocket::IsConnected() const {
return socket_ && socket_->IsConnected();
}
bool UnixDomainClientSocket::IsConnectedAndIdle() const {
return socket_ && socket_->IsConnectedAndIdle();
}
int UnixDomainClientSocket::GetPeerAddress(IPEndPoint* address) const {
NOTIMPLEMENTED();
return ERR_NOT_IMPLEMENTED;
}
int UnixDomainClientSocket::GetLocalAddress(IPEndPoint* address) const {
NOTIMPLEMENTED();
return ERR_NOT_IMPLEMENTED;
}
const BoundNetLog& UnixDomainClientSocket::NetLog() const {
return net_log_;
}
void UnixDomainClientSocket::SetSubresourceSpeculation() {
}
void UnixDomainClientSocket::SetOmniboxSpeculation() {
}
bool UnixDomainClientSocket::WasEverUsed() const {
return true; // We don't care.
}
bool UnixDomainClientSocket::UsingTCPFastOpen() const {
return false;
}
bool UnixDomainClientSocket::WasNpnNegotiated() const {
return false;
}
NextProto UnixDomainClientSocket::GetNegotiatedProtocol() const {
return kProtoUnknown;
}
bool UnixDomainClientSocket::GetSSLInfo(SSLInfo* ssl_info) {
return false;
}
int UnixDomainClientSocket::Read(IOBuffer* buf, int buf_len,
const CompletionCallback& callback) {
DCHECK(socket_);
return socket_->Read(buf, buf_len, callback);
}
int UnixDomainClientSocket::Write(IOBuffer* buf, int buf_len,
const CompletionCallback& callback) {
DCHECK(socket_);
return socket_->Write(buf, buf_len, callback);
}
int UnixDomainClientSocket::SetReceiveBufferSize(int32 size) {
NOTIMPLEMENTED();
return ERR_NOT_IMPLEMENTED;
}
int UnixDomainClientSocket::SetSendBufferSize(int32 size) {
NOTIMPLEMENTED();
return ERR_NOT_IMPLEMENTED;
}
} // namespace net
// Copyright 2014 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef NET_SOCKET_UNIX_DOMAIN_CLIENT_SOCKET_POSIX_H_
#define NET_SOCKET_UNIX_DOMAIN_CLIENT_SOCKET_POSIX_H_
#include <string>
#include "base/basictypes.h"
#include "base/macros.h"
#include "base/memory/scoped_ptr.h"
#include "net/base/completion_callback.h"
#include "net/base/net_export.h"
#include "net/base/net_log.h"
#include "net/socket/stream_socket.h"
namespace net {
class SocketLibevent;
struct SockaddrStorage;
// A client socket that uses unix domain socket as the transport layer.
class NET_EXPORT UnixDomainClientSocket : public StreamSocket {
public:
// Builds a client socket with |socket_path|. The caller should call Connect()
// to connect to a server socket.
UnixDomainClientSocket(const std::string& socket_path,
bool use_abstract_namespace);
// Builds a client socket with socket libevent which is already connected.
// UnixDomainServerSocket uses this after it accepts a connection.
explicit UnixDomainClientSocket(scoped_ptr<SocketLibevent> socket);
virtual ~UnixDomainClientSocket();
// Fills |address| with |socket_path| and its length. For Android or Linux
// platform, this supports abstract namespaces.
static bool FillAddress(const std::string& socket_path,
bool use_abstract_namespace,
SockaddrStorage* address);
// StreamSocket implementation.
virtual int Connect(const CompletionCallback& callback) OVERRIDE;
virtual void Disconnect() OVERRIDE;
virtual bool IsConnected() const OVERRIDE;
virtual bool IsConnectedAndIdle() const OVERRIDE;
virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE;
virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE;
virtual const BoundNetLog& NetLog() const OVERRIDE;
virtual void SetSubresourceSpeculation() OVERRIDE;
virtual void SetOmniboxSpeculation() OVERRIDE;
virtual bool WasEverUsed() const OVERRIDE;
virtual bool UsingTCPFastOpen() const OVERRIDE;
virtual bool WasNpnNegotiated() const OVERRIDE;
virtual NextProto GetNegotiatedProtocol() const OVERRIDE;
virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE;
// Socket implementation.
virtual int Read(IOBuffer* buf, int buf_len,
const CompletionCallback& callback) OVERRIDE;
virtual int Write(IOBuffer* buf, int buf_len,
const CompletionCallback& callback) OVERRIDE;
virtual int SetReceiveBufferSize(int32 size) OVERRIDE;
virtual int SetSendBufferSize(int32 size) OVERRIDE;
private:
const std::string socket_path_;
const bool use_abstract_namespace_;
scoped_ptr<SocketLibevent> socket_;
// This net log is just to comply StreamSocket::NetLog(). It throws away
// everything.
BoundNetLog net_log_;
DISALLOW_COPY_AND_ASSIGN(UnixDomainClientSocket);
};
} // namespace net
#endif // NET_SOCKET_UNIX_DOMAIN_CLIENT_SOCKET_POSIX_H_
This diff is collapsed.
// Copyright (c) 2012 The Chromium Authors. All rights reserved.
// Copyright 2014 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net/socket/unix_domain_socket_posix.h"
#include <cstring>
#include <string>
#include "net/socket/unix_domain_listen_socket_posix.h"
#include <errno.h>
#include <sys/socket.h>
......@@ -14,6 +11,9 @@
#include <sys/un.h>
#include <unistd.h>
#include <cstring>
#include <string>
#include "base/bind.h"
#include "base/callback.h"
#include "base/posix/eintr_wrapper.h"
......@@ -22,57 +22,65 @@
#include "net/base/net_errors.h"
#include "net/base/net_util.h"
#include "net/socket/socket_descriptor.h"
#include "net/socket/unix_domain_client_socket_posix.h"
namespace net {
namespace deprecated {
namespace {
bool NoAuthenticationCallback(uid_t, gid_t) {
return true;
}
int CreateAndBind(const std::string& socket_path,
bool use_abstract_namespace,
SocketDescriptor* socket_fd) {
DCHECK(socket_fd);
bool GetPeerIds(int socket, uid_t* user_id, gid_t* group_id) {
#if defined(OS_LINUX) || defined(OS_ANDROID)
struct ucred user_cred;
socklen_t len = sizeof(user_cred);
if (getsockopt(socket, SOL_SOCKET, SO_PEERCRED, &user_cred, &len) == -1)
return false;
*user_id = user_cred.uid;
*group_id = user_cred.gid;
#else
if (getpeereid(socket, user_id, group_id) == -1)
return false;
#endif
return true;
}
SockaddrStorage address;
if (!UnixDomainClientSocket::FillAddress(socket_path,
use_abstract_namespace,
&address)) {
return ERR_ADDRESS_INVALID;
}
} // namespace
SocketDescriptor fd = CreatePlatformSocket(PF_UNIX, SOCK_STREAM, 0);
if (fd == kInvalidSocket)
return errno ? MapSystemError(errno) : ERR_UNEXPECTED;
// static
UnixDomainSocket::AuthCallback UnixDomainSocket::NoAuthentication() {
return base::Bind(NoAuthenticationCallback);
if (bind(fd, address.addr, address.addr_len) < 0) {
int rv = MapSystemError(errno);
close(fd);
PLOG(ERROR) << "Could not bind unix domain socket to " << socket_path
<< (use_abstract_namespace ? " (with abstract namespace)" : "");
return rv;
}
*socket_fd = fd;
return OK;
}
} // namespace
// static
scoped_ptr<UnixDomainSocket> UnixDomainSocket::CreateAndListenInternal(
scoped_ptr<UnixDomainListenSocket>
UnixDomainListenSocket::CreateAndListenInternal(
const std::string& path,
const std::string& fallback_path,
StreamListenSocket::Delegate* del,
const AuthCallback& auth_callback,
bool use_abstract_namespace) {
SocketDescriptor s = CreateAndBind(path, use_abstract_namespace);
if (s == kInvalidSocket && !fallback_path.empty())
s = CreateAndBind(fallback_path, use_abstract_namespace);
if (s == kInvalidSocket)
return scoped_ptr<UnixDomainSocket>();
scoped_ptr<UnixDomainSocket> sock(
new UnixDomainSocket(s, del, auth_callback));
SocketDescriptor socket_fd = kInvalidSocket;
int rv = CreateAndBind(path, use_abstract_namespace, &socket_fd);
if (rv != OK && !fallback_path.empty())
rv = CreateAndBind(fallback_path, use_abstract_namespace, &socket_fd);
if (rv != OK)
return scoped_ptr<UnixDomainListenSocket>();
scoped_ptr<UnixDomainListenSocket> sock(
new UnixDomainListenSocket(socket_fd, del, auth_callback));
sock->Listen();
return sock.Pass();
}
// static
scoped_ptr<UnixDomainSocket> UnixDomainSocket::CreateAndListen(
scoped_ptr<UnixDomainListenSocket> UnixDomainListenSocket::CreateAndListen(
const std::string& path,
StreamListenSocket::Delegate* del,
const AuthCallback& auth_callback) {
......@@ -81,8 +89,8 @@ scoped_ptr<UnixDomainSocket> UnixDomainSocket::CreateAndListen(
#if defined(SOCKET_ABSTRACT_NAMESPACE_SUPPORTED)
// static
scoped_ptr<UnixDomainSocket>
UnixDomainSocket::CreateAndListenWithAbstractNamespace(
scoped_ptr<UnixDomainListenSocket>
UnixDomainListenSocket::CreateAndListenWithAbstractNamespace(
const std::string& path,
const std::string& fallback_path,
StreamListenSocket::Delegate* del,
......@@ -92,105 +100,70 @@ UnixDomainSocket::CreateAndListenWithAbstractNamespace(
}
#endif
UnixDomainSocket::UnixDomainSocket(
UnixDomainListenSocket::UnixDomainListenSocket(
SocketDescriptor s,
StreamListenSocket::Delegate* del,
const AuthCallback& auth_callback)
: StreamListenSocket(s, del),
auth_callback_(auth_callback) {}
UnixDomainSocket::~UnixDomainSocket() {}
// static
SocketDescriptor UnixDomainSocket::CreateAndBind(const std::string& path,
bool use_abstract_namespace) {
sockaddr_un addr;
static const size_t kPathMax = sizeof(addr.sun_path);
if (use_abstract_namespace + path.size() + 1 /* '\0' */ > kPathMax)
return kInvalidSocket;
const SocketDescriptor s = CreatePlatformSocket(PF_UNIX, SOCK_STREAM, 0);
if (s == kInvalidSocket)
return kInvalidSocket;
memset(&addr, 0, sizeof(addr));
addr.sun_family = AF_UNIX;
socklen_t addr_len;
if (use_abstract_namespace) {
// Convert the path given into abstract socket name. It must start with
// the '\0' character, so we are adding it. |addr_len| must specify the
// length of the structure exactly, as potentially the socket name may
// have '\0' characters embedded (although we don't support this).
// Note that addr.sun_path is already zero initialized.
memcpy(addr.sun_path + 1, path.c_str(), path.size());
addr_len = path.size() + offsetof(struct sockaddr_un, sun_path) + 1;
} else {
memcpy(addr.sun_path, path.c_str(), path.size());
addr_len = sizeof(sockaddr_un);
}
if (bind(s, reinterpret_cast<sockaddr*>(&addr), addr_len)) {
LOG(ERROR) << "Could not bind unix domain socket to " << path;
if (use_abstract_namespace)
LOG(ERROR) << " (with abstract namespace enabled)";
if (IGNORE_EINTR(close(s)) < 0)
LOG(ERROR) << "close() error";
return kInvalidSocket;
}
return s;
}
UnixDomainListenSocket::~UnixDomainListenSocket() {}
void UnixDomainSocket::Accept() {
void UnixDomainListenSocket::Accept() {
SocketDescriptor conn = StreamListenSocket::AcceptSocket();
if (conn == kInvalidSocket)
return;
uid_t user_id;
gid_t group_id;
if (!GetPeerIds(conn, &user_id, &group_id) ||
if (!UnixDomainServerSocket::GetPeerIds(conn, &user_id, &group_id) ||
!auth_callback_.Run(user_id, group_id)) {
if (IGNORE_EINTR(close(conn)) < 0)
LOG(ERROR) << "close() error";
return;
}
scoped_ptr<UnixDomainSocket> sock(
new UnixDomainSocket(conn, socket_delegate_, auth_callback_));
scoped_ptr<UnixDomainListenSocket> sock(
new UnixDomainListenSocket(conn, socket_delegate_, auth_callback_));
// It's up to the delegate to AddRef if it wants to keep it around.
sock->WatchSocket(WAITING_READ);
socket_delegate_->DidAccept(this, sock.PassAs<StreamListenSocket>());
}
UnixDomainSocketFactory::UnixDomainSocketFactory(
UnixDomainListenSocketFactory::UnixDomainListenSocketFactory(
const std::string& path,
const UnixDomainSocket::AuthCallback& auth_callback)
const UnixDomainListenSocket::AuthCallback& auth_callback)
: path_(path),
auth_callback_(auth_callback) {}
UnixDomainSocketFactory::~UnixDomainSocketFactory() {}
UnixDomainListenSocketFactory::~UnixDomainListenSocketFactory() {}
scoped_ptr<StreamListenSocket> UnixDomainSocketFactory::CreateAndListen(
scoped_ptr<StreamListenSocket> UnixDomainListenSocketFactory::CreateAndListen(
StreamListenSocket::Delegate* delegate) const {
return UnixDomainSocket::CreateAndListen(
return UnixDomainListenSocket::CreateAndListen(
path_, delegate, auth_callback_).PassAs<StreamListenSocket>();
}
#if defined(SOCKET_ABSTRACT_NAMESPACE_SUPPORTED)
UnixDomainSocketWithAbstractNamespaceFactory::
UnixDomainSocketWithAbstractNamespaceFactory(
UnixDomainListenSocketWithAbstractNamespaceFactory::
UnixDomainListenSocketWithAbstractNamespaceFactory(
const std::string& path,
const std::string& fallback_path,
const UnixDomainSocket::AuthCallback& auth_callback)
: UnixDomainSocketFactory(path, auth_callback),
const UnixDomainListenSocket::AuthCallback& auth_callback)
: UnixDomainListenSocketFactory(path, auth_callback),
fallback_path_(fallback_path) {}
UnixDomainSocketWithAbstractNamespaceFactory::
~UnixDomainSocketWithAbstractNamespaceFactory() {}
UnixDomainListenSocketWithAbstractNamespaceFactory::
~UnixDomainListenSocketWithAbstractNamespaceFactory() {}
scoped_ptr<StreamListenSocket>
UnixDomainSocketWithAbstractNamespaceFactory::CreateAndListen(
UnixDomainListenSocketWithAbstractNamespaceFactory::CreateAndListen(
StreamListenSocket::Delegate* delegate) const {
return UnixDomainSocket::CreateAndListenWithAbstractNamespace(
return UnixDomainListenSocket::CreateAndListenWithAbstractNamespace(
path_, fallback_path_, delegate, auth_callback_)
.PassAs<StreamListenSocket>();
}
#endif
} // namespace deprecated
} // namespace net
// Copyright (c) 2012 The Chromium Authors. All rights reserved.
// Copyright 2014 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef NET_SOCKET_UNIX_DOMAIN_SOCKET_POSIX_H_
#define NET_SOCKET_UNIX_DOMAIN_SOCKET_POSIX_H_
#ifndef NET_SOCKET_UNIX_DOMAIN_LISTEN_SOCKET_POSIX_H_
#define NET_SOCKET_UNIX_DOMAIN_LISTEN_SOCKET_POSIX_H_
#include <string>
#include "base/basictypes.h"
#include "base/callback_forward.h"
#include "base/compiler_specific.h"
#include "base/macros.h"
#include "build/build_config.h"
#include "net/base/net_export.h"
#include "net/socket/stream_listen_socket.h"
#include "net/socket/unix_domain_server_socket_posix.h"
#if defined(OS_ANDROID) || defined(OS_LINUX)
// Feature only supported on Linux currently. This lets the Unix Domain Socket
......@@ -21,25 +23,18 @@
#endif
namespace net {
namespace deprecated {
// Unix Domain Socket Implementation. Supports abstract namespaces on Linux.
class NET_EXPORT UnixDomainSocket : public StreamListenSocket {
class NET_EXPORT UnixDomainListenSocket : public StreamListenSocket {
public:
virtual ~UnixDomainSocket();
typedef UnixDomainServerSocket::AuthCallback AuthCallback;
// Callback that returns whether the already connected client, identified by
// its process |user_id| and |group_id|, is allowed to keep the connection
// open. Note that the socket is closed immediately in case the callback
// returns false.
typedef base::Callback<bool (uid_t user_id, gid_t group_id)> AuthCallback;
virtual ~UnixDomainListenSocket();
// Returns an authentication callback that always grants access for
// convenience in case you don't want to use authentication.
static AuthCallback NoAuthentication();
// Note that the returned UnixDomainSocket instance does not take ownership of
// |del|.
static scoped_ptr<UnixDomainSocket> CreateAndListen(
// Note that the returned UnixDomainListenSocket instance does not take
// ownership of |del|.
static scoped_ptr<UnixDomainListenSocket> CreateAndListen(
const std::string& path,
StreamListenSocket::Delegate* del,
const AuthCallback& auth_callback);
......@@ -48,7 +43,8 @@ class NET_EXPORT UnixDomainSocket : public StreamListenSocket {
// Same as above except that the created socket uses the abstract namespace
// which is a Linux-only feature. If |fallback_path| is not empty,
// make the second attempt with the provided fallback name.
static scoped_ptr<UnixDomainSocket> CreateAndListenWithAbstractNamespace(
static scoped_ptr<UnixDomainListenSocket>
CreateAndListenWithAbstractNamespace(
const std::string& path,
const std::string& fallback_path,
StreamListenSocket::Delegate* del,
......@@ -56,35 +52,34 @@ class NET_EXPORT UnixDomainSocket : public StreamListenSocket {
#endif
private:
UnixDomainSocket(SocketDescriptor s,
StreamListenSocket::Delegate* del,
const AuthCallback& auth_callback);
UnixDomainListenSocket(SocketDescriptor s,
StreamListenSocket::Delegate* del,
const AuthCallback& auth_callback);
static scoped_ptr<UnixDomainSocket> CreateAndListenInternal(
static scoped_ptr<UnixDomainListenSocket> CreateAndListenInternal(
const std::string& path,
const std::string& fallback_path,
StreamListenSocket::Delegate* del,
const AuthCallback& auth_callback,
bool use_abstract_namespace);
static SocketDescriptor CreateAndBind(const std::string& path,
bool use_abstract_namespace);
// StreamListenSocket:
virtual void Accept() OVERRIDE;
AuthCallback auth_callback_;
DISALLOW_COPY_AND_ASSIGN(UnixDomainSocket);
DISALLOW_COPY_AND_ASSIGN(UnixDomainListenSocket);
};
// Factory that can be used to instantiate UnixDomainSocket.
class NET_EXPORT UnixDomainSocketFactory : public StreamListenSocketFactory {
// Factory that can be used to instantiate UnixDomainListenSocket.
class NET_EXPORT UnixDomainListenSocketFactory
: public StreamListenSocketFactory {
public:
// Note that this class does not take ownership of the provided delegate.
UnixDomainSocketFactory(const std::string& path,
const UnixDomainSocket::AuthCallback& auth_callback);
virtual ~UnixDomainSocketFactory();
UnixDomainListenSocketFactory(
const std::string& path,
const UnixDomainListenSocket::AuthCallback& auth_callback);
virtual ~UnixDomainListenSocketFactory();
// StreamListenSocketFactory:
virtual scoped_ptr<StreamListenSocket> CreateAndListen(
......@@ -92,35 +87,36 @@ class NET_EXPORT UnixDomainSocketFactory : public StreamListenSocketFactory {
protected:
const std::string path_;
const UnixDomainSocket::AuthCallback auth_callback_;
const UnixDomainListenSocket::AuthCallback auth_callback_;
private:
DISALLOW_COPY_AND_ASSIGN(UnixDomainSocketFactory);
DISALLOW_COPY_AND_ASSIGN(UnixDomainListenSocketFactory);
};
#if defined(SOCKET_ABSTRACT_NAMESPACE_SUPPORTED)
// Use this factory to instantiate UnixDomainSocket using the abstract
// Use this factory to instantiate UnixDomainListenSocket using the abstract
// namespace feature (only supported on Linux).
class NET_EXPORT UnixDomainSocketWithAbstractNamespaceFactory
: public UnixDomainSocketFactory {
class NET_EXPORT UnixDomainListenSocketWithAbstractNamespaceFactory
: public UnixDomainListenSocketFactory {
public:
UnixDomainSocketWithAbstractNamespaceFactory(
UnixDomainListenSocketWithAbstractNamespaceFactory(
const std::string& path,
const std::string& fallback_path,
const UnixDomainSocket::AuthCallback& auth_callback);
virtual ~UnixDomainSocketWithAbstractNamespaceFactory();
const UnixDomainListenSocket::AuthCallback& auth_callback);
virtual ~UnixDomainListenSocketWithAbstractNamespaceFactory();
// UnixDomainSocketFactory:
// UnixDomainListenSocketFactory:
virtual scoped_ptr<StreamListenSocket> CreateAndListen(
StreamListenSocket::Delegate* delegate) const OVERRIDE;
private:
std::string fallback_path_;
DISALLOW_COPY_AND_ASSIGN(UnixDomainSocketWithAbstractNamespaceFactory);
DISALLOW_COPY_AND_ASSIGN(UnixDomainListenSocketWithAbstractNamespaceFactory);
};
#endif
} // namespace deprecated
} // namespace net
#endif // NET_SOCKET_UNIX_DOMAIN_SOCKET_POSIX_H_
#endif // NET_SOCKET_UNIX_DOMAIN_LISTEN_SOCKET_POSIX_H_
// Copyright (c) 2012 The Chromium Authors. All rights reserved.
// Copyright 2014 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net/socket/unix_domain_listen_socket_posix.h"
#include <errno.h>
#include <fcntl.h>
#include <poll.h>
......@@ -21,6 +23,7 @@
#include "base/compiler_specific.h"
#include "base/file_util.h"
#include "base/files/file_path.h"
#include "base/files/scoped_temp_dir.h"
#include "base/memory/ref_counted.h"
#include "base/memory/scoped_ptr.h"
#include "base/message_loop/message_loop.h"
......@@ -30,16 +33,16 @@
#include "base/threading/platform_thread.h"
#include "base/threading/thread.h"
#include "net/socket/socket_descriptor.h"
#include "net/socket/unix_domain_socket_posix.h"
#include "testing/gtest/include/gtest/gtest.h"
using std::queue;
using std::string;
namespace net {
namespace deprecated {
namespace {
const char kSocketFilename[] = "unix_domain_socket_for_testing";
const char kSocketFilename[] = "socket_for_testing";
const char kInvalidSocketPath[] = "/invalid/path";
const char kMsg[] = "hello";
......@@ -52,16 +55,6 @@ enum EventType {
EVENT_READ,
};
string MakeSocketPath(const string& socket_file_name) {
base::FilePath temp_dir;
base::GetTempDir(&temp_dir);
return temp_dir.Append(socket_file_name).value();
}
string MakeSocketPath() {
return MakeSocketPath(kSocketFilename);
}
class EventManager : public base::RefCounted<EventManager> {
public:
EventManager() : condition_(&mutex_) {}
......@@ -151,41 +144,46 @@ bool UserCanConnectCallback(
return allow_user;
}
class UnixDomainSocketTestHelper : public testing::Test {
class UnixDomainListenSocketTestHelper : public testing::Test {
public:
void CreateAndListen() {
socket_ = UnixDomainSocket::CreateAndListen(
socket_ = UnixDomainListenSocket::CreateAndListen(
file_path_.value(), socket_delegate_.get(), MakeAuthCallback());
socket_delegate_->OnListenCompleted();
}
protected:
UnixDomainSocketTestHelper(const string& path, bool allow_user)
: file_path_(path),
allow_user_(allow_user) {}
UnixDomainListenSocketTestHelper(const string& path_str, bool allow_user)
: allow_user_(allow_user) {
file_path_ = base::FilePath(path_str);
if (!file_path_.IsAbsolute()) {
EXPECT_TRUE(temp_dir_.CreateUniqueTempDir());
file_path_ = GetTempSocketPath(file_path_.value());
}
// Beware that if path_str is an absolute path, this class doesn't delete
// the file. It must be an invalid path and cannot be created by unittests.
}
base::FilePath GetTempSocketPath(const std::string socket_name) {
DCHECK(temp_dir_.IsValid());
return temp_dir_.path().Append(socket_name);
}
virtual void SetUp() OVERRIDE {
event_manager_ = new EventManager();
socket_delegate_.reset(new TestListenSocketDelegate(event_manager_));
DeleteSocketFile();
}
virtual void TearDown() OVERRIDE {
DeleteSocketFile();
socket_.reset();
socket_delegate_.reset();
event_manager_ = NULL;
}
UnixDomainSocket::AuthCallback MakeAuthCallback() {
UnixDomainListenSocket::AuthCallback MakeAuthCallback() {
return base::Bind(&UserCanConnectCallback, allow_user_, event_manager_);
}
void DeleteSocketFile() {
ASSERT_FALSE(file_path_.empty());
base::DeleteFile(file_path_, false /* not recursive */);
}
SocketDescriptor CreateClientSocket() {
const SocketDescriptor sock = CreatePlatformSocket(PF_UNIX, SOCK_STREAM, 0);
if (sock < 0) {
......@@ -199,7 +197,8 @@ class UnixDomainSocketTestHelper : public testing::Test {
strncpy(addr.sun_path, file_path_.value().c_str(), sizeof(addr.sun_path));
addr_len = sizeof(sockaddr_un);
if (connect(sock, reinterpret_cast<sockaddr*>(&addr), addr_len) != 0) {
LOG(ERROR) << "connect() error";
LOG(ERROR) << "connect() error: " << strerror(errno)
<< ": path=" << file_path_.value();
return kInvalidSocket;
}
return sock;
......@@ -212,43 +211,48 @@ class UnixDomainSocketTestHelper : public testing::Test {
thread->StartWithOptions(options);
thread->message_loop()->PostTask(
FROM_HERE,
base::Bind(&UnixDomainSocketTestHelper::CreateAndListen,
base::Bind(&UnixDomainListenSocketTestHelper::CreateAndListen,
base::Unretained(this)));
return thread.Pass();
}
const base::FilePath file_path_;
base::ScopedTempDir temp_dir_;
base::FilePath file_path_;
const bool allow_user_;
scoped_refptr<EventManager> event_manager_;
scoped_ptr<TestListenSocketDelegate> socket_delegate_;
scoped_ptr<UnixDomainSocket> socket_;
scoped_ptr<UnixDomainListenSocket> socket_;
};
class UnixDomainSocketTest : public UnixDomainSocketTestHelper {
class UnixDomainListenSocketTest : public UnixDomainListenSocketTestHelper {
protected:
UnixDomainSocketTest()
: UnixDomainSocketTestHelper(MakeSocketPath(), true /* allow user */) {}
UnixDomainListenSocketTest()
: UnixDomainListenSocketTestHelper(kSocketFilename,
true /* allow user */) {}
};
class UnixDomainSocketTestWithInvalidPath : public UnixDomainSocketTestHelper {
class UnixDomainListenSocketTestWithInvalidPath
: public UnixDomainListenSocketTestHelper {
protected:
UnixDomainSocketTestWithInvalidPath()
: UnixDomainSocketTestHelper(kInvalidSocketPath, true) {}
UnixDomainListenSocketTestWithInvalidPath()
: UnixDomainListenSocketTestHelper(kInvalidSocketPath, true) {}
};
class UnixDomainSocketTestWithForbiddenUser
: public UnixDomainSocketTestHelper {
class UnixDomainListenSocketTestWithForbiddenUser
: public UnixDomainListenSocketTestHelper {
protected:
UnixDomainSocketTestWithForbiddenUser()
: UnixDomainSocketTestHelper(MakeSocketPath(), false /* forbid user */) {}
UnixDomainListenSocketTestWithForbiddenUser()
: UnixDomainListenSocketTestHelper(kSocketFilename,
false /* forbid user */) {}
};
TEST_F(UnixDomainSocketTest, CreateAndListen) {
TEST_F(UnixDomainListenSocketTest, CreateAndListen) {
CreateAndListen();
EXPECT_FALSE(socket_.get() == NULL);
}
TEST_F(UnixDomainSocketTestWithInvalidPath, CreateAndListenWithInvalidPath) {
TEST_F(UnixDomainListenSocketTestWithInvalidPath,
CreateAndListenWithInvalidPath) {
CreateAndListen();
EXPECT_TRUE(socket_.get() == NULL);
}
......@@ -256,35 +260,35 @@ TEST_F(UnixDomainSocketTestWithInvalidPath, CreateAndListenWithInvalidPath) {
#ifdef SOCKET_ABSTRACT_NAMESPACE_SUPPORTED
// Test with an invalid path to make sure that the socket is not backed by a
// file.
TEST_F(UnixDomainSocketTestWithInvalidPath,
TEST_F(UnixDomainListenSocketTestWithInvalidPath,
CreateAndListenWithAbstractNamespace) {
socket_ = UnixDomainSocket::CreateAndListenWithAbstractNamespace(
socket_ = UnixDomainListenSocket::CreateAndListenWithAbstractNamespace(
file_path_.value(), "", socket_delegate_.get(), MakeAuthCallback());
EXPECT_FALSE(socket_.get() == NULL);
}
TEST_F(UnixDomainSocketTest, TestFallbackName) {
scoped_ptr<UnixDomainSocket> existing_socket =
UnixDomainSocket::CreateAndListenWithAbstractNamespace(
TEST_F(UnixDomainListenSocketTest, TestFallbackName) {
scoped_ptr<UnixDomainListenSocket> existing_socket =
UnixDomainListenSocket::CreateAndListenWithAbstractNamespace(
file_path_.value(), "", socket_delegate_.get(), MakeAuthCallback());
EXPECT_FALSE(existing_socket.get() == NULL);
// First, try to bind socket with the same name with no fallback name.
socket_ =
UnixDomainSocket::CreateAndListenWithAbstractNamespace(
UnixDomainListenSocket::CreateAndListenWithAbstractNamespace(
file_path_.value(), "", socket_delegate_.get(), MakeAuthCallback());
EXPECT_TRUE(socket_.get() == NULL);
// Now with a fallback name.
const char kFallbackSocketName[] = "unix_domain_socket_for_testing_2";
socket_ = UnixDomainSocket::CreateAndListenWithAbstractNamespace(
const char kFallbackSocketName[] = "socket_for_testing_2";
socket_ = UnixDomainListenSocket::CreateAndListenWithAbstractNamespace(
file_path_.value(),
MakeSocketPath(kFallbackSocketName),
GetTempSocketPath(kFallbackSocketName).value(),
socket_delegate_.get(),
MakeAuthCallback());
EXPECT_FALSE(socket_.get() == NULL);
}
#endif
TEST_F(UnixDomainSocketTest, TestWithClient) {
TEST_F(UnixDomainListenSocketTest, TestWithClient) {
const scoped_ptr<base::Thread> server_thread = CreateAndRunServerThread();
EventType event = event_manager_->WaitForEvent();
ASSERT_EQ(EVENT_LISTEN, event);
......@@ -311,7 +315,7 @@ TEST_F(UnixDomainSocketTest, TestWithClient) {
ASSERT_EQ(EVENT_CLOSE, event);
}
TEST_F(UnixDomainSocketTestWithForbiddenUser, TestWithForbiddenUser) {
TEST_F(UnixDomainListenSocketTestWithForbiddenUser, TestWithForbiddenUser) {
const scoped_ptr<base::Thread> server_thread = CreateAndRunServerThread();
EventType event = event_manager_->WaitForEvent();
ASSERT_EQ(EVENT_LISTEN, event);
......@@ -335,4 +339,5 @@ TEST_F(UnixDomainSocketTestWithForbiddenUser, TestWithForbiddenUser) {
}
} // namespace
} // namespace deprecated
} // namespace net
// Copyright 2014 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net/socket/unix_domain_server_socket_posix.h"
#include <errno.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <unistd.h>
#include "base/logging.h"
#include "net/base/net_errors.h"
#include "net/socket/socket_libevent.h"
#include "net/socket/unix_domain_client_socket_posix.h"
namespace net {
UnixDomainServerSocket::UnixDomainServerSocket(
const AuthCallback& auth_callback,
bool use_abstract_namespace)
: auth_callback_(auth_callback),
use_abstract_namespace_(use_abstract_namespace) {
DCHECK(!auth_callback_.is_null());
}
UnixDomainServerSocket::~UnixDomainServerSocket() {
}
// static
bool UnixDomainServerSocket::GetPeerIds(SocketDescriptor socket,
uid_t* user_id,
gid_t* group_id) {
#if defined(OS_LINUX) || defined(OS_ANDROID)
struct ucred user_cred;
socklen_t len = sizeof(user_cred);
if (getsockopt(socket, SOL_SOCKET, SO_PEERCRED, &user_cred, &len) < 0)
return false;
*user_id = user_cred.uid;
*group_id = user_cred.gid;
return true;
#else
return getpeereid(socket, user_id, group_id) == 0;
#endif
}
int UnixDomainServerSocket::Listen(const IPEndPoint& address, int backlog) {
NOTIMPLEMENTED();
return ERR_NOT_IMPLEMENTED;
}
int UnixDomainServerSocket::ListenWithAddressAndPort(
const std::string& unix_domain_path,
int port_unused,
int backlog) {
DCHECK(!listen_socket_);
SockaddrStorage address;
if (!UnixDomainClientSocket::FillAddress(unix_domain_path,
use_abstract_namespace_,
&address)) {
return ERR_ADDRESS_INVALID;
}
listen_socket_.reset(new SocketLibevent);
int rv = listen_socket_->Open(AF_UNIX);
DCHECK_NE(ERR_IO_PENDING, rv);
if (rv != OK)
return rv;
rv = listen_socket_->Bind(address);
DCHECK_NE(ERR_IO_PENDING, rv);
if (rv != OK) {
PLOG(ERROR)
<< "Could not bind unix domain socket to " << unix_domain_path
<< (use_abstract_namespace_ ? " (with abstract namespace)" : "");
return rv;
}
return listen_socket_->Listen(backlog);
}
int UnixDomainServerSocket::GetLocalAddress(IPEndPoint* address) const {
NOTIMPLEMENTED();
return ERR_NOT_IMPLEMENTED;
}
int UnixDomainServerSocket::Accept(scoped_ptr<StreamSocket>* socket,
const CompletionCallback& callback) {
DCHECK(socket);
DCHECK(!callback.is_null());
DCHECK(listen_socket_);
DCHECK(!accept_socket_);
while (true) {
int rv = listen_socket_->Accept(
&accept_socket_,
base::Bind(&UnixDomainServerSocket::AcceptCompleted,
base::Unretained(this), socket, callback));
if (rv != OK)
return rv;
if (AuthenticateAndGetStreamSocket(socket))
return OK;
// Accept another socket because authentication error should be transparent
// to the caller.
}
}
void UnixDomainServerSocket::AcceptCompleted(scoped_ptr<StreamSocket>* socket,
const CompletionCallback& callback,
int rv) {
if (rv != OK) {
callback.Run(rv);
return;
}
if (AuthenticateAndGetStreamSocket(socket)) {
callback.Run(OK);
return;
}
// Accept another socket because authentication error should be transparent
// to the caller.
rv = Accept(socket, callback);
if (rv != ERR_IO_PENDING)
callback.Run(rv);
}
bool UnixDomainServerSocket::AuthenticateAndGetStreamSocket(
scoped_ptr<StreamSocket>* socket) {
DCHECK(accept_socket_);
uid_t user_id;
gid_t group_id;
if (!GetPeerIds(accept_socket_->socket_fd(), &user_id, &group_id) ||
!auth_callback_.Run(user_id, group_id)) {
accept_socket_.reset();
return false;
}
socket->reset(new UnixDomainClientSocket(accept_socket_.Pass()));
return true;
}
} // namespace net
// Copyright 2014 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef NET_SOCKET_UNIX_DOMAIN_SERVER_SOCKET_POSIX_H_
#define NET_SOCKET_UNIX_DOMAIN_SERVER_SOCKET_POSIX_H_
#include <sys/types.h>
#include <string>
#include "base/basictypes.h"
#include "base/callback.h"
#include "base/macros.h"
#include "base/memory/scoped_ptr.h"
#include "net/base/net_export.h"
#include "net/socket/server_socket.h"
#include "net/socket/socket_descriptor.h"
namespace net {
class SocketLibevent;
// Unix Domain Server Socket Implementation. Supports abstract namespaces on
// Linux and Android.
class NET_EXPORT UnixDomainServerSocket : public ServerSocket {
public:
// Callback that returns whether the already connected client, identified by
// its process |user_id| and |group_id|, is allowed to keep the connection
// open. Note that the socket is closed immediately in case the callback
// returns false.
typedef base::Callback<bool (uid_t user_id, gid_t group_id)> AuthCallback;
UnixDomainServerSocket(const AuthCallback& auth_callack,
bool use_abstract_namespace);
virtual ~UnixDomainServerSocket();
// Gets UID and GID of peer to check permissions.
static bool GetPeerIds(SocketDescriptor socket_fd,
uid_t* user_id,
gid_t* group_id);
// ServerSocket implementation.
virtual int Listen(const IPEndPoint& address, int backlog) OVERRIDE;
virtual int ListenWithAddressAndPort(const std::string& unix_domain_path,
int port_unused,
int backlog) OVERRIDE;
virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE;
virtual int Accept(scoped_ptr<StreamSocket>* socket,
const CompletionCallback& callback) OVERRIDE;
private:
void AcceptCompleted(scoped_ptr<StreamSocket>* socket,
const CompletionCallback& callback,
int rv);
bool AuthenticateAndGetStreamSocket(scoped_ptr<StreamSocket>* socket);
scoped_ptr<SocketLibevent> listen_socket_;
const AuthCallback auth_callback_;
const bool use_abstract_namespace_;
scoped_ptr<SocketLibevent> accept_socket_;
DISALLOW_COPY_AND_ASSIGN(UnixDomainServerSocket);
};
} // namespace net
#endif // NET_SOCKET_UNIX_DOMAIN_SOCKET_POSIX_H_
// Copyright 2014 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net/socket/unix_domain_server_socket_posix.h"
#include <vector>
#include "base/bind.h"
#include "base/files/file_path.h"
#include "base/files/scoped_temp_dir.h"
#include "base/memory/scoped_ptr.h"
#include "base/run_loop.h"
#include "base/stl_util.h"
#include "net/base/io_buffer.h"
#include "net/base/net_errors.h"
#include "net/base/test_completion_callback.h"
#include "net/socket/unix_domain_client_socket_posix.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace net {
namespace {
const char kSocketFilename[] = "socket_for_testing";
const char kInvalidSocketPath[] = "/invalid/path";
bool UserCanConnectCallback(bool allow_user, uid_t uid, gid_t gid) {
// Here peers are running in same process.
EXPECT_EQ(getuid(), uid);
EXPECT_EQ(getgid(), gid);
return allow_user;
}
UnixDomainServerSocket::AuthCallback CreateAuthCallback(bool allow_user) {
return base::Bind(&UserCanConnectCallback, allow_user);
}
class UnixDomainServerSocketTest : public testing::Test {
protected:
UnixDomainServerSocketTest() {
EXPECT_TRUE(temp_dir_.CreateUniqueTempDir());
socket_path_ = temp_dir_.path().Append(kSocketFilename).value();
}
base::ScopedTempDir temp_dir_;
std::string socket_path_;
};
TEST_F(UnixDomainServerSocketTest, ListenWithInvalidPath) {
const bool kUseAbstractNamespace = false;
UnixDomainServerSocket server_socket(CreateAuthCallback(true),
kUseAbstractNamespace);
EXPECT_EQ(ERR_FILE_NOT_FOUND,
server_socket.ListenWithAddressAndPort(kInvalidSocketPath, 0, 1));
}
TEST_F(UnixDomainServerSocketTest, ListenWithInvalidPathWithAbstractNamespace) {
const bool kUseAbstractNamespace = true;
UnixDomainServerSocket server_socket(CreateAuthCallback(true),
kUseAbstractNamespace);
#if defined(OS_ANDROID) || defined(OS_LINUX)
EXPECT_EQ(OK,
server_socket.ListenWithAddressAndPort(kInvalidSocketPath, 0, 1));
#else
EXPECT_EQ(ERR_ADDRESS_INVALID,
server_socket.ListenWithAddressAndPort(kInvalidSocketPath, 0, 1));
#endif
}
TEST_F(UnixDomainServerSocketTest, AcceptWithForbiddenUser) {
const bool kUseAbstractNamespace = false;
UnixDomainServerSocket server_socket(CreateAuthCallback(false),
kUseAbstractNamespace);
EXPECT_EQ(OK, server_socket.ListenWithAddressAndPort(socket_path_, 0, 1));
scoped_ptr<StreamSocket> accepted_socket;
TestCompletionCallback accept_callback;
EXPECT_EQ(ERR_IO_PENDING,
server_socket.Accept(&accepted_socket, accept_callback.callback()));
EXPECT_FALSE(accepted_socket);
UnixDomainClientSocket client_socket(socket_path_, kUseAbstractNamespace);
EXPECT_FALSE(client_socket.IsConnected());
// Connect() will return OK before the server rejects the connection.
TestCompletionCallback connect_callback;
int rv = client_socket.Connect(connect_callback.callback());
if (rv == ERR_IO_PENDING) {
rv = connect_callback.WaitForResult();
} else {
EXPECT_TRUE(client_socket.IsConnected());
}
EXPECT_EQ(OK, rv);
// Cannot use accept_callback.WaitForResult() because authentication error is
// invisible to the caller.
base::RunLoop().RunUntilIdle();
// Server disconnects the connection.
EXPECT_FALSE(client_socket.IsConnected());
// But, server didn't create the accepted socket.
EXPECT_FALSE(accepted_socket);
const int read_buffer_size = 10;
scoped_refptr<IOBuffer> read_buffer(new IOBuffer(read_buffer_size));
TestCompletionCallback read_callback;
EXPECT_EQ(0, /* EOF */
client_socket.Read(read_buffer, read_buffer_size,
read_callback.callback()));
}
// Normal cases including read/write are tested by UnixDomainClientSocketTest.
} // namespace
} // namespace net
......@@ -14,7 +14,7 @@
#include "base/lazy_instance.h"
#include "base/stl_util.h"
#include "base/values.h"
#include "net/socket/unix_domain_socket_posix.h"
#include "net/socket/unix_domain_listen_socket_posix.h"
#include "remoting/base/logging.h"
#include "remoting/host/gnubby_socket.h"
#include "remoting/proto/control.pb.h"
......@@ -263,7 +263,7 @@ void GnubbyAuthHandlerPosix::CreateAuthorizationSocket() {
HOST_LOG << "Listening for gnubby requests on "
<< g_gnubby_socket_name.Get().value();
auth_socket_ = net::UnixDomainSocket::CreateAndListen(
auth_socket_ = net::deprecated::UnixDomainListenSocket::CreateAndListen(
g_gnubby_socket_name.Get().value(), this, base::Bind(MatchUid));
if (!auth_socket_.get()) {
LOG(ERROR) << "Failed to open socket for gnubby requests";
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment