Commit a6a4b6a1 authored by joedow's avatar joedow Committed by Commit bot

Removing 'AllowScopedIO' exception from SecurityKeyAuthHandlerLinux

This change updates the SecurityKeyAuthHandlerLinux class to run its
socket operations on a task_runner which allows blocking IO.

The first part of the change is plumbing through the file_task_runner
from the Me2Me host (which owns the context with the task runners)
down to the auth handler class which requires it.

The second part of the change was to move the file delete operation
to the file thread and then post back to the main thread.

I've also updated the unit tests for the affected class since it now
needs to handle synchronization between two threads.

BUG=591739
R=sergeyu@chromium.org
TBR=dpranke@chromium.org

Review-Url: https://codereview.chromium.org/2168303003
Cr-Commit-Position: refs/heads/master@{#407649}
parent 1528936c
...@@ -185,8 +185,6 @@ _BANNED_CPP_FUNCTIONS = ( ...@@ -185,8 +185,6 @@ _BANNED_CPP_FUNCTIONS = (
r"simple_platform_shared_buffer_posix\.cc$", r"simple_platform_shared_buffer_posix\.cc$",
r"^net[\\\/]disk_cache[\\\/]cache_util\.cc$", r"^net[\\\/]disk_cache[\\\/]cache_util\.cc$",
r"^net[\\\/]url_request[\\\/]test_url_fetcher_factory\.cc$", r"^net[\\\/]url_request[\\\/]test_url_fetcher_factory\.cc$",
r"^remoting[\\\/]host[\\\/]security_key[\\\/]"
"security_key_auth_handler_linux\.cc$",
r"^ui[\\\/]base[\\\/]material_design[\\\/]" r"^ui[\\\/]base[\\\/]material_design[\\\/]"
"material_design_controller\.cc$", "material_design_controller\.cc$",
r"^ui[\\\/]gl[\\\/]init[\\\/]gl_initializer_mac\.cc$", r"^ui[\\\/]gl[\\\/]init[\\\/]gl_initializer_mac\.cc$",
......
...@@ -13,11 +13,11 @@ ...@@ -13,11 +13,11 @@
#include "base/callback.h" #include "base/callback.h"
#include "base/command_line.h" #include "base/command_line.h"
#include "base/memory/ptr_util.h" #include "base/memory/ptr_util.h"
#include "base/single_thread_task_runner.h"
#include "build/build_config.h" #include "build/build_config.h"
#include "jingle/glue/thread_wrapper.h" #include "jingle/glue/thread_wrapper.h"
#include "remoting/base/constants.h" #include "remoting/base/constants.h"
#include "remoting/base/logging.h" #include "remoting/base/logging.h"
#include "remoting/host/chromoting_host_context.h"
#include "remoting/host/desktop_environment.h" #include "remoting/host/desktop_environment.h"
#include "remoting/host/host_config.h" #include "remoting/host/host_config.h"
#include "remoting/host/input_injector.h" #include "remoting/host/input_injector.h"
......
...@@ -5,8 +5,8 @@ ...@@ -5,8 +5,8 @@
// This file implements a standalone host process for Me2Me. // This file implements a standalone host process for Me2Me.
#include <stddef.h> #include <stddef.h>
#include <stdint.h>
#include <cstdint>
#include <memory> #include <memory>
#include <string> #include <string>
#include <utility> #include <utility>
...@@ -1461,7 +1461,8 @@ void HostProcess::StartHost() { ...@@ -1461,7 +1461,8 @@ void HostProcess::StartHost() {
context_->video_encode_task_runner())); context_->video_encode_task_runner()));
if (security_key_auth_policy_enabled_ && security_key_extension_supported_) { if (security_key_auth_policy_enabled_ && security_key_extension_supported_) {
host_->AddExtension(base::WrapUnique(new SecurityKeyExtension())); host_->AddExtension(base::WrapUnique(
new SecurityKeyExtension(context_->file_task_runner())));
} }
// TODO(simonmorris): Get the maximum session duration from a policy. // TODO(simonmorris): Get the maximum session duration from a policy.
......
...@@ -9,10 +9,12 @@ ...@@ -9,10 +9,12 @@
#include <string> #include <string>
#include "base/callback.h" #include "base/callback.h"
#include "base/memory/ref_counted.h"
#include "base/time/time.h" #include "base/time/time.h"
namespace base { namespace base {
class FilePath; class FilePath;
class SingleThreadTaskRunner;
} // namespace base } // namespace base
namespace remoting { namespace remoting {
...@@ -36,7 +38,8 @@ class SecurityKeyAuthHandler { ...@@ -36,7 +38,8 @@ class SecurityKeyAuthHandler {
// |client_session_details| will be valid until this instance is destroyed. // |client_session_details| will be valid until this instance is destroyed.
static std::unique_ptr<SecurityKeyAuthHandler> Create( static std::unique_ptr<SecurityKeyAuthHandler> Create(
ClientSessionDetails* client_session_details, ClientSessionDetails* client_session_details,
const SendMessageCallback& send_message_callback); const SendMessageCallback& send_message_callback,
scoped_refptr<base::SingleThreadTaskRunner> file_task_runner);
#if defined(OS_LINUX) #if defined(OS_LINUX)
// Specify the name of the socket to listen to security key requests on. // Specify the name of the socket to listen to security key requests on.
......
...@@ -8,7 +8,8 @@ namespace remoting { ...@@ -8,7 +8,8 @@ namespace remoting {
std::unique_ptr<SecurityKeyAuthHandler> SecurityKeyAuthHandler::Create( std::unique_ptr<SecurityKeyAuthHandler> SecurityKeyAuthHandler::Create(
ClientSessionDetails* client_session_details, ClientSessionDetails* client_session_details,
const SendMessageCallback& send_message_callback) { const SendMessageCallback& send_message_callback,
scoped_refptr<base::SingleThreadTaskRunner> file_task_runner) {
return nullptr; return nullptr;
} }
......
...@@ -4,19 +4,26 @@ ...@@ -4,19 +4,26 @@
#include "remoting/host/security_key/security_key_auth_handler.h" #include "remoting/host/security_key/security_key_auth_handler.h"
#include <stdint.h>
#include <unistd.h> #include <unistd.h>
#include <cstdint>
#include <map>
#include <memory> #include <memory>
#include <string>
#include <utility>
#include "base/bind.h" #include "base/bind.h"
#include "base/callback.h"
#include "base/files/file_path.h"
#include "base/files/file_util.h" #include "base/files/file_util.h"
#include "base/lazy_instance.h" #include "base/lazy_instance.h"
#include "base/location.h"
#include "base/logging.h" #include "base/logging.h"
#include "base/stl_util.h" #include "base/memory/ptr_util.h"
#include "base/memory/ref_counted.h"
#include "base/memory/weak_ptr.h"
#include "base/single_thread_task_runner.h"
#include "base/threading/thread_checker.h" #include "base/threading/thread_checker.h"
#include "base/threading/thread_restrictions.h"
#include "base/values.h"
#include "net/base/completion_callback.h" #include "net/base/completion_callback.h"
#include "net/base/net_errors.h" #include "net/base/net_errors.h"
#include "net/socket/stream_socket.h" #include "net/socket/stream_socket.h"
...@@ -53,11 +60,12 @@ namespace remoting { ...@@ -53,11 +60,12 @@ namespace remoting {
class SecurityKeyAuthHandlerLinux : public SecurityKeyAuthHandler { class SecurityKeyAuthHandlerLinux : public SecurityKeyAuthHandler {
public: public:
SecurityKeyAuthHandlerLinux(); explicit SecurityKeyAuthHandlerLinux(
scoped_refptr<base::SingleThreadTaskRunner> file_task_runner);
~SecurityKeyAuthHandlerLinux() override; ~SecurityKeyAuthHandlerLinux() override;
private: private:
typedef std::map<int, SecurityKeySocket*> ActiveSockets; typedef std::map<int, std::unique_ptr<SecurityKeySocket>> ActiveSockets;
// SecurityKeyAuthHandler interface. // SecurityKeyAuthHandler interface.
void CreateSecurityKeyConnection() override; void CreateSecurityKeyConnection() override;
...@@ -69,6 +77,9 @@ class SecurityKeyAuthHandlerLinux : public SecurityKeyAuthHandler { ...@@ -69,6 +77,9 @@ class SecurityKeyAuthHandlerLinux : public SecurityKeyAuthHandler {
size_t GetActiveConnectionCountForTest() const override; size_t GetActiveConnectionCountForTest() const override;
void SetRequestTimeoutForTest(base::TimeDelta timeout) override; void SetRequestTimeoutForTest(base::TimeDelta timeout) override;
// Sets up the socket used for accepting new connections.
void CreateSocket();
// Starts listening for connection. // Starts listening for connection.
void DoAccept(); void DoAccept();
...@@ -101,22 +112,28 @@ class SecurityKeyAuthHandlerLinux : public SecurityKeyAuthHandler { ...@@ -101,22 +112,28 @@ class SecurityKeyAuthHandlerLinux : public SecurityKeyAuthHandler {
SendMessageCallback send_message_callback_; SendMessageCallback send_message_callback_;
// The last assigned security key connection id. // The last assigned security key connection id.
int last_connection_id_; int last_connection_id_ = 0;
// Sockets by connection id used to process gnubbyd requests. // Sockets by connection id used to process gnubbyd requests.
ActiveSockets active_sockets_; ActiveSockets active_sockets_;
// Used to perform blocking File IO.
scoped_refptr<base::SingleThreadTaskRunner> file_task_runner_;
// Timeout used for a request. // Timeout used for a request.
base::TimeDelta request_timeout_; base::TimeDelta request_timeout_;
base::WeakPtrFactory<SecurityKeyAuthHandlerLinux> weak_factory_;
DISALLOW_COPY_AND_ASSIGN(SecurityKeyAuthHandlerLinux); DISALLOW_COPY_AND_ASSIGN(SecurityKeyAuthHandlerLinux);
}; };
std::unique_ptr<SecurityKeyAuthHandler> SecurityKeyAuthHandler::Create( std::unique_ptr<SecurityKeyAuthHandler> SecurityKeyAuthHandler::Create(
ClientSessionDetails* client_session_details, ClientSessionDetails* client_session_details,
const SendMessageCallback& send_message_callback) { const SendMessageCallback& send_message_callback,
scoped_refptr<base::SingleThreadTaskRunner> file_task_runner) {
std::unique_ptr<SecurityKeyAuthHandler> auth_handler( std::unique_ptr<SecurityKeyAuthHandler> auth_handler(
new SecurityKeyAuthHandlerLinux()); new SecurityKeyAuthHandlerLinux(file_task_runner));
auth_handler->SetSendMessageCallback(send_message_callback); auth_handler->SetSendMessageCallback(send_message_callback);
return auth_handler; return auth_handler;
} }
...@@ -126,31 +143,32 @@ void SecurityKeyAuthHandler::SetSecurityKeySocketName( ...@@ -126,31 +143,32 @@ void SecurityKeyAuthHandler::SetSecurityKeySocketName(
g_security_key_socket_name.Get() = security_key_socket_name; g_security_key_socket_name.Get() = security_key_socket_name;
} }
SecurityKeyAuthHandlerLinux::SecurityKeyAuthHandlerLinux() SecurityKeyAuthHandlerLinux::SecurityKeyAuthHandlerLinux(
: last_connection_id_(0), scoped_refptr<base::SingleThreadTaskRunner> file_task_runner)
: file_task_runner_(file_task_runner),
request_timeout_( request_timeout_(
base::TimeDelta::FromSeconds(kDefaultRequestTimeoutSeconds)) {} base::TimeDelta::FromSeconds(kDefaultRequestTimeoutSeconds)),
weak_factory_(this) {}
SecurityKeyAuthHandlerLinux::~SecurityKeyAuthHandlerLinux() { SecurityKeyAuthHandlerLinux::~SecurityKeyAuthHandlerLinux() {}
STLDeleteValues(&active_sockets_);
}
void SecurityKeyAuthHandlerLinux::CreateSecurityKeyConnection() { void SecurityKeyAuthHandlerLinux::CreateSecurityKeyConnection() {
DCHECK(thread_checker_.CalledOnValidThread()); DCHECK(thread_checker_.CalledOnValidThread());
DCHECK(!g_security_key_socket_name.Get().empty()); DCHECK(!g_security_key_socket_name.Get().empty());
{ // We need to run the DeleteFile method on |file_task_runner_| as it is a
// DeleteFile() is a blocking operation, but so is creation of the unix // blocking function call which cannot be run on the main thread. Once
// socket below. Consider moving this class to a different thread if this // that task has completed, the main thread will be called back and we will
// causes any problems. See crbug.com/509807. // resume setting up our security key auth socket there.
// TODO(joedow): Since this code now runs as a host extension, we should file_task_runner_->PostTaskAndReply(
// perform our IO on a separate thread: crbug.com/591739 FROM_HERE, base::Bind(base::IgnoreResult(&base::DeleteFile),
base::ThreadRestrictions::ScopedAllowIO allow_io; g_security_key_socket_name.Get(), false),
base::Bind(&SecurityKeyAuthHandlerLinux::CreateSocket,
// If the file already exists, a socket in use error is returned. weak_factory_.GetWeakPtr()));
base::DeleteFile(g_security_key_socket_name.Get(), false); }
}
void SecurityKeyAuthHandlerLinux::CreateSocket() {
DCHECK(thread_checker_.CalledOnValidThread());
HOST_LOG << "Listening for security key requests on " HOST_LOG << "Listening for security key requests on "
<< g_security_key_socket_name.Get().value(); << g_security_key_socket_name.Get().value();
...@@ -212,6 +230,7 @@ void SecurityKeyAuthHandlerLinux::SetRequestTimeoutForTest( ...@@ -212,6 +230,7 @@ void SecurityKeyAuthHandlerLinux::SetRequestTimeoutForTest(
} }
void SecurityKeyAuthHandlerLinux::DoAccept() { void SecurityKeyAuthHandlerLinux::DoAccept() {
DCHECK(thread_checker_.CalledOnValidThread());
int result = auth_socket_->Accept( int result = auth_socket_->Accept(
&accept_socket_, base::Bind(&SecurityKeyAuthHandlerLinux::OnAccepted, &accept_socket_, base::Bind(&SecurityKeyAuthHandlerLinux::OnAccepted,
base::Unretained(this))); base::Unretained(this)));
...@@ -233,7 +252,7 @@ void SecurityKeyAuthHandlerLinux::OnAccepted(int result) { ...@@ -233,7 +252,7 @@ void SecurityKeyAuthHandlerLinux::OnAccepted(int result) {
std::move(accept_socket_), request_timeout_, std::move(accept_socket_), request_timeout_,
base::Bind(&SecurityKeyAuthHandlerLinux::RequestTimedOut, base::Bind(&SecurityKeyAuthHandlerLinux::RequestTimedOut,
base::Unretained(this), security_key_connection_id)); base::Unretained(this), security_key_connection_id));
active_sockets_[security_key_connection_id] = socket; active_sockets_[security_key_connection_id] = base::WrapUnique(socket);
socket->StartReadingRequest( socket->StartReadingRequest(
base::Bind(&SecurityKeyAuthHandlerLinux::OnReadComplete, base::Bind(&SecurityKeyAuthHandlerLinux::OnReadComplete,
base::Unretained(this), security_key_connection_id)); base::Unretained(this), security_key_connection_id));
...@@ -272,7 +291,6 @@ SecurityKeyAuthHandlerLinux::GetSocketForConnectionId( ...@@ -272,7 +291,6 @@ SecurityKeyAuthHandlerLinux::GetSocketForConnectionId(
void SecurityKeyAuthHandlerLinux::SendErrorAndCloseActiveSocket( void SecurityKeyAuthHandlerLinux::SendErrorAndCloseActiveSocket(
const ActiveSockets::const_iterator& iter) { const ActiveSockets::const_iterator& iter) {
iter->second->SendSshError(); iter->second->SendSshError();
delete iter->second;
active_sockets_.erase(iter); active_sockets_.erase(iter);
} }
......
...@@ -4,21 +4,22 @@ ...@@ -4,21 +4,22 @@
#include <stddef.h> #include <stddef.h>
#include <memory>
#include <string>
#include "base/bind.h"
#include "base/files/file_path.h" #include "base/files/file_path.h"
#include "base/files/scoped_temp_dir.h" #include "base/files/scoped_temp_dir.h"
#include "base/macros.h" #include "base/macros.h"
#include "base/message_loop/message_loop.h" #include "base/message_loop/message_loop.h"
#include "base/run_loop.h" #include "base/run_loop.h"
#include "base/strings/stringprintf.h" #include "base/threading/thread.h"
#include "base/timer/mock_timer.h"
#include "base/values.h"
#include "net/base/io_buffer.h" #include "net/base/io_buffer.h"
#include "net/base/net_errors.h" #include "net/base/net_errors.h"
#include "net/base/test_completion_callback.h" #include "net/base/test_completion_callback.h"
#include "net/socket/unix_domain_client_socket_posix.h" #include "net/socket/unix_domain_client_socket_posix.h"
#include "remoting/host/security_key/security_key_auth_handler.h" #include "remoting/host/security_key/security_key_auth_handler.h"
#include "remoting/host/security_key/security_key_socket.h" #include "remoting/host/security_key/security_key_socket.h"
#include "remoting/proto/internal.pb.h"
#include "testing/gtest/include/gtest/gtest.h" #include "testing/gtest/include/gtest/gtest.h"
namespace remoting { namespace remoting {
...@@ -44,26 +45,45 @@ const unsigned char kRequestData[] = { ...@@ -44,26 +45,45 @@ const unsigned char kRequestData[] = {
0x5e, 0xa3, 0xbc, 0x02, 0x5b, 0xec, 0xe4, 0x4b, 0xae, 0x0e, 0xf2, 0xbd, 0x5e, 0xa3, 0xbc, 0x02, 0x5b, 0xec, 0xe4, 0x4b, 0xae, 0x0e, 0xf2, 0xbd,
0xc8, 0xaa}; 0xc8, 0xaa};
void RunUntilIdle() {
base::RunLoop run_loop;
run_loop.RunUntilIdle();
}
} // namespace } // namespace
class SecurityKeyAuthHandlerLinuxTest : public testing::Test { class SecurityKeyAuthHandlerLinuxTest : public testing::Test {
public: public:
SecurityKeyAuthHandlerLinuxTest() SecurityKeyAuthHandlerLinuxTest()
: run_loop_(new base::RunLoop()), last_connection_id_received_(-1) { : run_loop_(new base::RunLoop()),
file_thread_("SecurityKeyAuthHandlerLinuxTest_FileThread") {
EXPECT_TRUE(temp_dir_.CreateUniqueTempDir()); EXPECT_TRUE(temp_dir_.CreateUniqueTempDir());
socket_path_ = temp_dir_.path().Append(kSocketFilename); socket_path_ = temp_dir_.path().Append(kSocketFilename);
remoting::SecurityKeyAuthHandler::SetSecurityKeySocketName(socket_path_); remoting::SecurityKeyAuthHandler::SetSecurityKeySocketName(socket_path_);
EXPECT_TRUE(file_thread_.StartWithOptions(
base::Thread::Options(base::MessageLoop::TYPE_IO, 0)));
send_message_callback_ = send_message_callback_ =
base::Bind(&SecurityKeyAuthHandlerLinuxTest::SendMessageToClient, base::Bind(&SecurityKeyAuthHandlerLinuxTest::SendMessageToClient,
base::Unretained(this)); base::Unretained(this));
auth_handler_ = remoting::SecurityKeyAuthHandler::Create( auth_handler_ = remoting::SecurityKeyAuthHandler::Create(
nullptr, send_message_callback_); /*client_session_details=*/nullptr, send_message_callback_,
file_thread_.task_runner());
EXPECT_NE(auth_handler_.get(), nullptr);
} }
void WaitForSendMessageToClient() { void CreateSocketAndWait() {
ASSERT_EQ(0u, auth_handler_->GetActiveConnectionCountForTest());
auth_handler_->CreateSecurityKeyConnection();
ASSERT_TRUE(file_thread_.task_runner()->PostTaskAndReply(
FROM_HERE, base::Bind(&RunUntilIdle), run_loop_->QuitClosure()));
run_loop_->Run(); run_loop_->Run();
run_loop_.reset(new base::RunLoop); run_loop_.reset(new base::RunLoop);
ASSERT_EQ(0u, auth_handler_->GetActiveConnectionCountForTest());
} }
void SendMessageToClient(int connection_id, const std::string& data) { void SendMessageToClient(int connection_id, const std::string& data) {
...@@ -72,6 +92,11 @@ class SecurityKeyAuthHandlerLinuxTest : public testing::Test { ...@@ -72,6 +92,11 @@ class SecurityKeyAuthHandlerLinuxTest : public testing::Test {
run_loop_->Quit(); run_loop_->Quit();
} }
void WaitForSendMessageToClient() {
run_loop_->Run();
run_loop_.reset(new base::RunLoop);
}
void CheckHostDataMessage(int id, const std::string& expected_data) { void CheckHostDataMessage(int id, const std::string& expected_data) {
ASSERT_EQ(id, last_connection_id_received_); ASSERT_EQ(id, last_connection_id_received_);
ASSERT_EQ(expected_data.length(), last_message_received_.length()); ASSERT_EQ(expected_data.length(), last_message_received_.length());
...@@ -117,12 +142,14 @@ class SecurityKeyAuthHandlerLinuxTest : public testing::Test { ...@@ -117,12 +142,14 @@ class SecurityKeyAuthHandlerLinuxTest : public testing::Test {
base::MessageLoopForIO message_loop_; base::MessageLoopForIO message_loop_;
std::unique_ptr<base::RunLoop> run_loop_; std::unique_ptr<base::RunLoop> run_loop_;
base::Thread file_thread_;
// Object under test. // Object under test.
std::unique_ptr<SecurityKeyAuthHandler> auth_handler_; std::unique_ptr<SecurityKeyAuthHandler> auth_handler_;
SecurityKeyAuthHandler::SendMessageCallback send_message_callback_; SecurityKeyAuthHandler::SendMessageCallback send_message_callback_;
int last_connection_id_received_; int last_connection_id_received_ = -1;
std::string last_message_received_; std::string last_message_received_;
base::ScopedTempDir temp_dir_; base::ScopedTempDir temp_dir_;
...@@ -134,9 +161,7 @@ class SecurityKeyAuthHandlerLinuxTest : public testing::Test { ...@@ -134,9 +161,7 @@ class SecurityKeyAuthHandlerLinuxTest : public testing::Test {
}; };
TEST_F(SecurityKeyAuthHandlerLinuxTest, NotClosedAfterRequest) { TEST_F(SecurityKeyAuthHandlerLinuxTest, NotClosedAfterRequest) {
ASSERT_EQ(0u, auth_handler_->GetActiveConnectionCountForTest()); CreateSocketAndWait();
auth_handler_->CreateSecurityKeyConnection();
net::UnixDomainClientSocket client_socket(socket_path_.value(), false); net::UnixDomainClientSocket client_socket(socket_path_.value(), false);
net::TestCompletionCallback connect_callback; net::TestCompletionCallback connect_callback;
...@@ -156,9 +181,7 @@ TEST_F(SecurityKeyAuthHandlerLinuxTest, NotClosedAfterRequest) { ...@@ -156,9 +181,7 @@ TEST_F(SecurityKeyAuthHandlerLinuxTest, NotClosedAfterRequest) {
} }
TEST_F(SecurityKeyAuthHandlerLinuxTest, HandleTwoRequests) { TEST_F(SecurityKeyAuthHandlerLinuxTest, HandleTwoRequests) {
ASSERT_EQ(0u, auth_handler_->GetActiveConnectionCountForTest()); CreateSocketAndWait();
auth_handler_->CreateSecurityKeyConnection();
net::UnixDomainClientSocket client_socket(socket_path_.value(), false); net::UnixDomainClientSocket client_socket(socket_path_.value(), false);
net::TestCompletionCallback connect_callback; net::TestCompletionCallback connect_callback;
...@@ -186,9 +209,7 @@ TEST_F(SecurityKeyAuthHandlerLinuxTest, HandleTwoRequests) { ...@@ -186,9 +209,7 @@ TEST_F(SecurityKeyAuthHandlerLinuxTest, HandleTwoRequests) {
} }
TEST_F(SecurityKeyAuthHandlerLinuxTest, HandleTwoIndependentRequests) { TEST_F(SecurityKeyAuthHandlerLinuxTest, HandleTwoIndependentRequests) {
ASSERT_EQ(0u, auth_handler_->GetActiveConnectionCountForTest()); CreateSocketAndWait();
auth_handler_->CreateSecurityKeyConnection();
net::UnixDomainClientSocket client_socket(socket_path_.value(), false); net::UnixDomainClientSocket client_socket(socket_path_.value(), false);
net::TestCompletionCallback connect_callback; net::TestCompletionCallback connect_callback;
...@@ -221,8 +242,7 @@ TEST_F(SecurityKeyAuthHandlerLinuxTest, HandleTwoIndependentRequests) { ...@@ -221,8 +242,7 @@ TEST_F(SecurityKeyAuthHandlerLinuxTest, HandleTwoIndependentRequests) {
} }
TEST_F(SecurityKeyAuthHandlerLinuxTest, DidReadTimeout) { TEST_F(SecurityKeyAuthHandlerLinuxTest, DidReadTimeout) {
ASSERT_EQ(0u, auth_handler_->GetActiveConnectionCountForTest()); CreateSocketAndWait();
auth_handler_->CreateSecurityKeyConnection();
net::UnixDomainClientSocket client_socket(socket_path_.value(), false); net::UnixDomainClientSocket client_socket(socket_path_.value(), false);
net::TestCompletionCallback connect_callback; net::TestCompletionCallback connect_callback;
...@@ -233,8 +253,7 @@ TEST_F(SecurityKeyAuthHandlerLinuxTest, DidReadTimeout) { ...@@ -233,8 +253,7 @@ TEST_F(SecurityKeyAuthHandlerLinuxTest, DidReadTimeout) {
} }
TEST_F(SecurityKeyAuthHandlerLinuxTest, ClientErrorMessageDelivered) { TEST_F(SecurityKeyAuthHandlerLinuxTest, ClientErrorMessageDelivered) {
ASSERT_EQ(0u, auth_handler_->GetActiveConnectionCountForTest()); CreateSocketAndWait();
auth_handler_->CreateSecurityKeyConnection();
net::UnixDomainClientSocket client_socket(socket_path_.value(), false); net::UnixDomainClientSocket client_socket(socket_path_.value(), false);
net::TestCompletionCallback connect_callback; net::TestCompletionCallback connect_callback;
......
...@@ -10,7 +10,8 @@ namespace remoting { ...@@ -10,7 +10,8 @@ namespace remoting {
std::unique_ptr<SecurityKeyAuthHandler> SecurityKeyAuthHandler::Create( std::unique_ptr<SecurityKeyAuthHandler> SecurityKeyAuthHandler::Create(
ClientSessionDetails* client_session_details, ClientSessionDetails* client_session_details,
const SendMessageCallback& send_message_callback) { const SendMessageCallback& send_message_callback,
scoped_refptr<base::SingleThreadTaskRunner> file_task_runner) {
return nullptr; return nullptr;
} }
......
...@@ -133,7 +133,8 @@ class SecurityKeyAuthHandlerWin : public SecurityKeyAuthHandler, ...@@ -133,7 +133,8 @@ class SecurityKeyAuthHandlerWin : public SecurityKeyAuthHandler,
std::unique_ptr<SecurityKeyAuthHandler> SecurityKeyAuthHandler::Create( std::unique_ptr<SecurityKeyAuthHandler> SecurityKeyAuthHandler::Create(
ClientSessionDetails* client_session_details, ClientSessionDetails* client_session_details,
const SendMessageCallback& send_message_callback) { const SendMessageCallback& send_message_callback,
scoped_refptr<base::SingleThreadTaskRunner> file_task_runner) {
std::unique_ptr<SecurityKeyAuthHandler> auth_handler( std::unique_ptr<SecurityKeyAuthHandler> auth_handler(
new SecurityKeyAuthHandlerWin(client_session_details)); new SecurityKeyAuthHandlerWin(client_session_details));
auth_handler->SetSendMessageCallback(send_message_callback); auth_handler->SetSendMessageCallback(send_message_callback);
......
...@@ -122,7 +122,8 @@ SecurityKeyAuthHandlerWinTest::SecurityKeyAuthHandlerWinTest() ...@@ -122,7 +122,8 @@ SecurityKeyAuthHandlerWinTest::SecurityKeyAuthHandlerWinTest()
auth_handler_ = remoting::SecurityKeyAuthHandler::Create( auth_handler_ = remoting::SecurityKeyAuthHandler::Create(
&mock_client_session_details_, &mock_client_session_details_,
base::Bind(&SecurityKeyAuthHandlerWinTest::SendMessageToClient, base::Bind(&SecurityKeyAuthHandlerWinTest::SendMessageToClient,
base::Unretained(this))); base::Unretained(this)),
/*file_task_runner=*/nullptr);
} }
SecurityKeyAuthHandlerWinTest::~SecurityKeyAuthHandlerWinTest() {} SecurityKeyAuthHandlerWinTest::~SecurityKeyAuthHandlerWinTest() {}
......
...@@ -5,6 +5,8 @@ ...@@ -5,6 +5,8 @@
#include "remoting/host/security_key/security_key_extension.h" #include "remoting/host/security_key/security_key_extension.h"
#include "base/memory/ptr_util.h" #include "base/memory/ptr_util.h"
#include "base/memory/ref_counted.h"
#include "base/single_thread_task_runner.h"
#include "remoting/host/security_key/security_key_extension_session.h" #include "remoting/host/security_key/security_key_extension_session.h"
namespace { namespace {
...@@ -15,7 +17,9 @@ const char kCapability[] = ""; ...@@ -15,7 +17,9 @@ const char kCapability[] = "";
namespace remoting { namespace remoting {
SecurityKeyExtension::SecurityKeyExtension() {} SecurityKeyExtension::SecurityKeyExtension(
scoped_refptr<base::SingleThreadTaskRunner> file_task_runner)
: file_task_runner_(file_task_runner) {}
SecurityKeyExtension::~SecurityKeyExtension() {} SecurityKeyExtension::~SecurityKeyExtension() {}
...@@ -31,7 +35,7 @@ SecurityKeyExtension::CreateExtensionSession( ...@@ -31,7 +35,7 @@ SecurityKeyExtension::CreateExtensionSession(
// extension will only send messages through the initial // extension will only send messages through the initial
// |client_stub| and |details| with the current design. // |client_stub| and |details| with the current design.
return base::WrapUnique( return base::WrapUnique(
new SecurityKeyExtensionSession(details, client_stub)); new SecurityKeyExtensionSession(details, client_stub, file_task_runner_));
} }
} // namespace remoting } // namespace remoting
...@@ -9,8 +9,13 @@ ...@@ -9,8 +9,13 @@
#include <string> #include <string>
#include "base/macros.h" #include "base/macros.h"
#include "base/memory/ref_counted.h"
#include "remoting/host/host_extension.h" #include "remoting/host/host_extension.h"
namespace base {
class SingleThreadTaskRunner;
} // namespace base
namespace remoting { namespace remoting {
class ClientSessionDetails; class ClientSessionDetails;
...@@ -19,7 +24,8 @@ class HostExtensionSession; ...@@ -19,7 +24,8 @@ class HostExtensionSession;
// SecurityKeyExtension extends HostExtension to enable Security Key support. // SecurityKeyExtension extends HostExtension to enable Security Key support.
class SecurityKeyExtension : public HostExtension { class SecurityKeyExtension : public HostExtension {
public: public:
SecurityKeyExtension(); explicit SecurityKeyExtension(
scoped_refptr<base::SingleThreadTaskRunner> file_task_runner);
~SecurityKeyExtension() override; ~SecurityKeyExtension() override;
// HostExtension interface. // HostExtension interface.
...@@ -29,6 +35,9 @@ class SecurityKeyExtension : public HostExtension { ...@@ -29,6 +35,9 @@ class SecurityKeyExtension : public HostExtension {
protocol::ClientStub* client_stub) override; protocol::ClientStub* client_stub) override;
private: private:
// Allows underlying auth handler to perform blocking file IO.
scoped_refptr<base::SingleThreadTaskRunner> file_task_runner_;
DISALLOW_COPY_AND_ASSIGN(SecurityKeyExtension); DISALLOW_COPY_AND_ASSIGN(SecurityKeyExtension);
}; };
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include "base/json/json_writer.h" #include "base/json/json_writer.h"
#include "base/logging.h" #include "base/logging.h"
#include "base/macros.h" #include "base/macros.h"
#include "base/single_thread_task_runner.h"
#include "base/values.h" #include "base/values.h"
#include "remoting/base/logging.h" #include "remoting/base/logging.h"
#include "remoting/host/client_session_details.h" #include "remoting/host/client_session_details.h"
...@@ -62,14 +63,16 @@ namespace remoting { ...@@ -62,14 +63,16 @@ namespace remoting {
SecurityKeyExtensionSession::SecurityKeyExtensionSession( SecurityKeyExtensionSession::SecurityKeyExtensionSession(
ClientSessionDetails* client_session_details, ClientSessionDetails* client_session_details,
protocol::ClientStub* client_stub) protocol::ClientStub* client_stub,
scoped_refptr<base::SingleThreadTaskRunner> file_task_runner)
: client_stub_(client_stub) { : client_stub_(client_stub) {
DCHECK(client_stub_); DCHECK(client_stub_);
security_key_auth_handler_ = remoting::SecurityKeyAuthHandler::Create( security_key_auth_handler_ = remoting::SecurityKeyAuthHandler::Create(
client_session_details, client_session_details,
base::Bind(&SecurityKeyExtensionSession::SendMessageToClient, base::Bind(&SecurityKeyExtensionSession::SendMessageToClient,
base::Unretained(this))); base::Unretained(this)),
file_task_runner);
} }
SecurityKeyExtensionSession::~SecurityKeyExtensionSession() {} SecurityKeyExtensionSession::~SecurityKeyExtensionSession() {}
......
...@@ -10,12 +10,14 @@ ...@@ -10,12 +10,14 @@
#include "base/callback.h" #include "base/callback.h"
#include "base/macros.h" #include "base/macros.h"
#include "base/memory/ref_counted.h"
#include "base/threading/thread_checker.h" #include "base/threading/thread_checker.h"
#include "remoting/host/host_extension_session.h" #include "remoting/host/host_extension_session.h"
namespace base { namespace base {
class DictionaryValue; class DictionaryValue;
} class SingleThreadTaskRunner;
} // namespace base
namespace remoting { namespace remoting {
...@@ -29,8 +31,10 @@ class ClientStub; ...@@ -29,8 +31,10 @@ class ClientStub;
// A HostExtensionSession implementation that enables Security Key support. // A HostExtensionSession implementation that enables Security Key support.
class SecurityKeyExtensionSession : public HostExtensionSession { class SecurityKeyExtensionSession : public HostExtensionSession {
public: public:
SecurityKeyExtensionSession(ClientSessionDetails* client_session_details, SecurityKeyExtensionSession(
protocol::ClientStub* client_stub); ClientSessionDetails* client_session_details,
protocol::ClientStub* client_stub,
scoped_refptr<base::SingleThreadTaskRunner> file_task_runner);
~SecurityKeyExtensionSession() override; ~SecurityKeyExtensionSession() override;
// HostExtensionSession interface. // HostExtensionSession interface.
......
...@@ -169,7 +169,9 @@ class SecurityKeyExtensionSessionTest : public testing::Test { ...@@ -169,7 +169,9 @@ class SecurityKeyExtensionSessionTest : public testing::Test {
SecurityKeyExtensionSessionTest::SecurityKeyExtensionSessionTest() SecurityKeyExtensionSessionTest::SecurityKeyExtensionSessionTest()
: security_key_extension_session_( : security_key_extension_session_(
new SecurityKeyExtensionSession(&client_details_, &client_stub_)) { new SecurityKeyExtensionSession(&client_details_,
&client_stub_,
/*file_task_runner=*/nullptr)) {
// We want to retain ownership of mock object so we can use it to inject // We want to retain ownership of mock object so we can use it to inject
// events into the extension session. The mock object should not be used // events into the extension session. The mock object should not be used
// once |security_key_extension_session_| is destroyed. // once |security_key_extension_session_| is destroyed.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment