Commit 569e09ce authored by yhirano's avatar yhirano Committed by Commit bot

Implement handshake timeout on WebSocket

Previously, WebSocket has no handshake timeout, i.e. It waited unlimitedly
when the server accepts the TCP connection but doesn't send any handshake
response. This CL fixes the issue.

BUG=391263
R=ricea@chromium.org

Review URL: https://codereview.chromium.org/565573002

Cr-Commit-Position: refs/heads/master@{#294373}
parent 655e378c
...@@ -8,6 +8,8 @@ ...@@ -8,6 +8,8 @@
#include "base/memory/scoped_ptr.h" #include "base/memory/scoped_ptr.h"
#include "base/metrics/histogram.h" #include "base/metrics/histogram.h"
#include "base/metrics/sparse_histogram.h" #include "base/metrics/sparse_histogram.h"
#include "base/time/time.h"
#include "base/timer/timer.h"
#include "net/base/load_flags.h" #include "net/base/load_flags.h"
#include "net/http/http_request_headers.h" #include "net/http/http_request_headers.h"
#include "net/http/http_response_headers.h" #include "net/http/http_response_headers.h"
...@@ -27,6 +29,12 @@ ...@@ -27,6 +29,12 @@
namespace net { namespace net {
namespace { namespace {
// The timeout duration of WebSocket handshake.
// It is defined as the same value as the TCP connection timeout value in
// net/socket/websocket_transport_client_socket_pool.cc to make it hard for
// JavaScript programs to recognize the timeout cause.
const int kHandshakeTimeoutIntervalInSeconds = 240;
class StreamRequestImpl; class StreamRequestImpl;
class Delegate : public URLRequest::Delegate { class Delegate : public URLRequest::Delegate {
...@@ -112,22 +120,36 @@ class StreamRequestImpl : public WebSocketStreamRequest { ...@@ -112,22 +120,36 @@ class StreamRequestImpl : public WebSocketStreamRequest {
// and so terminates the handshake if it is incomplete. // and so terminates the handshake if it is incomplete.
virtual ~StreamRequestImpl() {} virtual ~StreamRequestImpl() {}
void Start() { void Start(scoped_ptr<base::Timer> timer) {
DCHECK(timer);
TimeDelta timeout(TimeDelta::FromSeconds(
kHandshakeTimeoutIntervalInSeconds));
timer_ = timer.Pass();
timer_->Start(FROM_HERE, timeout,
base::Bind(&StreamRequestImpl::OnTimeout,
base::Unretained(this)));
url_request_->Start(); url_request_->Start();
} }
void PerformUpgrade() { void PerformUpgrade() {
DCHECK(timer_);
timer_->Stop();
connect_delegate_->OnSuccess(create_helper_->Upgrade()); connect_delegate_->OnSuccess(create_helper_->Upgrade());
} }
void ReportFailure() { void ReportFailure() {
DCHECK(timer_);
timer_->Stop();
if (failure_message_.empty()) { if (failure_message_.empty()) {
switch (url_request_->status().status()) { switch (url_request_->status().status()) {
case URLRequestStatus::SUCCESS: case URLRequestStatus::SUCCESS:
case URLRequestStatus::IO_PENDING: case URLRequestStatus::IO_PENDING:
break; break;
case URLRequestStatus::CANCELED: case URLRequestStatus::CANCELED:
failure_message_ = "WebSocket opening handshake was canceled"; if (url_request_->status().error() == ERR_TIMED_OUT)
failure_message_ = "WebSocket opening handshake timed out";
else
failure_message_ = "WebSocket opening handshake was canceled";
break; break;
case URLRequestStatus::FAILED: case URLRequestStatus::FAILED:
failure_message_ = failure_message_ =
...@@ -154,6 +176,10 @@ class StreamRequestImpl : public WebSocketStreamRequest { ...@@ -154,6 +176,10 @@ class StreamRequestImpl : public WebSocketStreamRequest {
return connect_delegate_.get(); return connect_delegate_.get();
} }
void OnTimeout() {
url_request_->CancelWithError(ERR_TIMED_OUT);
}
private: private:
// |delegate_| needs to be declared before |url_request_| so that it gets // |delegate_| needs to be declared before |url_request_| so that it gets
// initialised first. // initialised first.
...@@ -170,6 +196,9 @@ class StreamRequestImpl : public WebSocketStreamRequest { ...@@ -170,6 +196,9 @@ class StreamRequestImpl : public WebSocketStreamRequest {
// The failure message supplied by WebSocketBasicHandshakeStream, if any. // The failure message supplied by WebSocketBasicHandshakeStream, if any.
std::string failure_message_; std::string failure_message_;
// A timer for handshake timeout.
scoped_ptr<base::Timer> timer_;
}; };
class SSLErrorCallbacks : public WebSocketEventInterface::SSLErrorCallbacks { class SSLErrorCallbacks : public WebSocketEventInterface::SSLErrorCallbacks {
...@@ -286,7 +315,7 @@ scoped_ptr<WebSocketStreamRequest> WebSocketStream::CreateAndConnectStream( ...@@ -286,7 +315,7 @@ scoped_ptr<WebSocketStreamRequest> WebSocketStream::CreateAndConnectStream(
origin, origin,
connect_delegate.Pass(), connect_delegate.Pass(),
create_helper.Pass())); create_helper.Pass()));
request->Start(); request->Start(scoped_ptr<base::Timer>(new base::Timer(false, false)));
return request.PassAs<WebSocketStreamRequest>(); return request.PassAs<WebSocketStreamRequest>();
} }
...@@ -297,14 +326,15 @@ scoped_ptr<WebSocketStreamRequest> CreateAndConnectStreamForTesting( ...@@ -297,14 +326,15 @@ scoped_ptr<WebSocketStreamRequest> CreateAndConnectStreamForTesting(
const url::Origin& origin, const url::Origin& origin,
URLRequestContext* url_request_context, URLRequestContext* url_request_context,
const BoundNetLog& net_log, const BoundNetLog& net_log,
scoped_ptr<WebSocketStream::ConnectDelegate> connect_delegate) { scoped_ptr<WebSocketStream::ConnectDelegate> connect_delegate,
scoped_ptr<base::Timer> timer) {
scoped_ptr<StreamRequestImpl> request( scoped_ptr<StreamRequestImpl> request(
new StreamRequestImpl(socket_url, new StreamRequestImpl(socket_url,
url_request_context, url_request_context,
origin, origin,
connect_delegate.Pass(), connect_delegate.Pass(),
create_helper.Pass())); create_helper.Pass()));
request->Start(); request->Start(timer.Pass());
return request.PassAs<WebSocketStreamRequest>(); return request.PassAs<WebSocketStreamRequest>();
} }
......
...@@ -16,6 +16,8 @@ ...@@ -16,6 +16,8 @@
#include "base/metrics/statistics_recorder.h" #include "base/metrics/statistics_recorder.h"
#include "base/run_loop.h" #include "base/run_loop.h"
#include "base/strings/stringprintf.h" #include "base/strings/stringprintf.h"
#include "base/timer/mock_timer.h"
#include "base/timer/timer.h"
#include "net/base/net_errors.h" #include "net/base/net_errors.h"
#include "net/base/test_data_directory.h" #include "net/base/test_data_directory.h"
#include "net/http/http_request_headers.h" #include "net/http/http_request_headers.h"
...@@ -79,6 +81,13 @@ scoped_ptr<DeterministicSocketData> BuildNullSocketData() { ...@@ -79,6 +81,13 @@ scoped_ptr<DeterministicSocketData> BuildNullSocketData() {
return make_scoped_ptr(new DeterministicSocketData(NULL, 0, NULL, 0)); return make_scoped_ptr(new DeterministicSocketData(NULL, 0, NULL, 0));
} }
class MockWeakTimer : public base::MockTimer,
public base::SupportsWeakPtr<MockWeakTimer> {
public:
MockWeakTimer(bool retain_user_task, bool is_repeating)
: MockTimer(retain_user_task, is_repeating) {}
};
// A sub-class of WebSocketHandshakeStreamCreateHelper which always sets a // A sub-class of WebSocketHandshakeStreamCreateHelper which always sets a
// deterministic key to use in the WebSocket handshake. // deterministic key to use in the WebSocket handshake.
class DeterministicKeyWebSocketHandshakeStreamCreateHelper class DeterministicKeyWebSocketHandshakeStreamCreateHelper
...@@ -105,11 +114,12 @@ class WebSocketStreamCreateTest : public ::testing::Test { ...@@ -105,11 +114,12 @@ class WebSocketStreamCreateTest : public ::testing::Test {
const std::vector<std::string>& sub_protocols, const std::vector<std::string>& sub_protocols,
const std::string& origin, const std::string& origin,
const std::string& extra_request_headers, const std::string& extra_request_headers,
const std::string& response_body) { const std::string& response_body,
scoped_ptr<base::Timer> timer = scoped_ptr<base::Timer>()) {
url_request_context_host_.SetExpectations( url_request_context_host_.SetExpectations(
WebSocketStandardRequest(socket_path, origin, extra_request_headers), WebSocketStandardRequest(socket_path, origin, extra_request_headers),
response_body); response_body);
CreateAndConnectStream(socket_url, sub_protocols, origin); CreateAndConnectStream(socket_url, sub_protocols, origin, timer.Pass());
} }
// |extra_request_headers| and |extra_response_headers| must end in "\r\n" or // |extra_request_headers| and |extra_response_headers| must end in "\r\n" or
...@@ -119,23 +129,27 @@ class WebSocketStreamCreateTest : public ::testing::Test { ...@@ -119,23 +129,27 @@ class WebSocketStreamCreateTest : public ::testing::Test {
const std::vector<std::string>& sub_protocols, const std::vector<std::string>& sub_protocols,
const std::string& origin, const std::string& origin,
const std::string& extra_request_headers, const std::string& extra_request_headers,
const std::string& extra_response_headers) { const std::string& extra_response_headers,
scoped_ptr<base::Timer> timer =
scoped_ptr<base::Timer>()) {
CreateAndConnectCustomResponse( CreateAndConnectCustomResponse(
socket_url, socket_url,
socket_path, socket_path,
sub_protocols, sub_protocols,
origin, origin,
extra_request_headers, extra_request_headers,
WebSocketStandardResponse(extra_response_headers)); WebSocketStandardResponse(extra_response_headers),
timer.Pass());
} }
void CreateAndConnectRawExpectations( void CreateAndConnectRawExpectations(
const std::string& socket_url, const std::string& socket_url,
const std::vector<std::string>& sub_protocols, const std::vector<std::string>& sub_protocols,
const std::string& origin, const std::string& origin,
scoped_ptr<DeterministicSocketData> socket_data) { scoped_ptr<DeterministicSocketData> socket_data,
scoped_ptr<base::Timer> timer = scoped_ptr<base::Timer>()) {
AddRawExpectations(socket_data.Pass()); AddRawExpectations(socket_data.Pass());
CreateAndConnectStream(socket_url, sub_protocols, origin); CreateAndConnectStream(socket_url, sub_protocols, origin, timer.Pass());
} }
// Add additional raw expectations for sockets created before the final one. // Add additional raw expectations for sockets created before the final one.
...@@ -147,7 +161,8 @@ class WebSocketStreamCreateTest : public ::testing::Test { ...@@ -147,7 +161,8 @@ class WebSocketStreamCreateTest : public ::testing::Test {
// parameters. // parameters.
void CreateAndConnectStream(const std::string& socket_url, void CreateAndConnectStream(const std::string& socket_url,
const std::vector<std::string>& sub_protocols, const std::vector<std::string>& sub_protocols,
const std::string& origin) { const std::string& origin,
scoped_ptr<base::Timer> timer) {
for (size_t i = 0; i < ssl_data_.size(); ++i) { for (size_t i = 0; i < ssl_data_.size(); ++i) {
scoped_ptr<SSLSocketDataProvider> ssl_data(ssl_data_[i]); scoped_ptr<SSLSocketDataProvider> ssl_data(ssl_data_[i]);
ssl_data_[i] = NULL; ssl_data_[i] = NULL;
...@@ -166,7 +181,9 @@ class WebSocketStreamCreateTest : public ::testing::Test { ...@@ -166,7 +181,9 @@ class WebSocketStreamCreateTest : public ::testing::Test {
url::Origin(origin), url::Origin(origin),
url_request_context_host_.GetURLRequestContext(), url_request_context_host_.GetURLRequestContext(),
BoundNetLog(), BoundNetLog(),
connect_delegate.Pass()); connect_delegate.Pass(),
timer ? timer.Pass() : scoped_ptr<base::Timer>(
new base::Timer(false, false)));
} }
static void RunUntilIdle() { base::RunLoop().RunUntilIdle(); } static void RunUntilIdle() { base::RunLoop().RunUntilIdle(); }
...@@ -1098,6 +1115,67 @@ TEST_F(WebSocketStreamCreateTest, ConnectionTimeout) { ...@@ -1098,6 +1115,67 @@ TEST_F(WebSocketStreamCreateTest, ConnectionTimeout) {
failure_message()); failure_message());
} }
// The server doesn't respond to the opening handshake.
TEST_F(WebSocketStreamCreateTest, HandshakeTimeout) {
scoped_ptr<DeterministicSocketData> socket_data(BuildNullSocketData());
socket_data->set_connect_data(MockConnect(SYNCHRONOUS, ERR_IO_PENDING));
scoped_ptr<MockWeakTimer> timer(new MockWeakTimer(false, false));
base::WeakPtr<MockWeakTimer> weak_timer = timer->AsWeakPtr();
CreateAndConnectRawExpectations("ws://localhost/", NoSubProtocols(),
"http://localhost", socket_data.Pass(),
timer.PassAs<base::Timer>());
EXPECT_FALSE(has_failed());
ASSERT_TRUE(weak_timer.get());
EXPECT_TRUE(weak_timer->IsRunning());
weak_timer->Fire();
RunUntilIdle();
EXPECT_TRUE(has_failed());
EXPECT_EQ("WebSocket opening handshake timed out", failure_message());
ASSERT_TRUE(weak_timer.get());
EXPECT_FALSE(weak_timer->IsRunning());
}
// When the connection establishes the timer should be stopped.
TEST_F(WebSocketStreamCreateTest, HandshakeTimerOnSuccess) {
scoped_ptr<MockWeakTimer> timer(new MockWeakTimer(false, false));
base::WeakPtr<MockWeakTimer> weak_timer = timer->AsWeakPtr();
CreateAndConnectStandard(
"ws://localhost/", "/", NoSubProtocols(), "http://localhost", "", "",
timer.PassAs<base::Timer>());
ASSERT_TRUE(weak_timer);
EXPECT_TRUE(weak_timer->IsRunning());
RunUntilIdle();
EXPECT_FALSE(has_failed());
EXPECT_TRUE(stream_);
ASSERT_TRUE(weak_timer);
EXPECT_FALSE(weak_timer->IsRunning());
}
// When the connection fails the timer should be stopped.
TEST_F(WebSocketStreamCreateTest, HandshakeTimerOnFailure) {
scoped_ptr<DeterministicSocketData> socket_data(BuildNullSocketData());
socket_data->set_connect_data(
MockConnect(SYNCHRONOUS, ERR_CONNECTION_REFUSED));
scoped_ptr<MockWeakTimer> timer(new MockWeakTimer(false, false));
base::WeakPtr<MockWeakTimer> weak_timer = timer->AsWeakPtr();
CreateAndConnectRawExpectations("ws://localhost/", NoSubProtocols(),
"http://localhost", socket_data.Pass(),
timer.PassAs<base::Timer>());
ASSERT_TRUE(weak_timer.get());
EXPECT_TRUE(weak_timer->IsRunning());
RunUntilIdle();
EXPECT_TRUE(has_failed());
EXPECT_EQ("Error in connection establishment: net::ERR_CONNECTION_REFUSED",
failure_message());
ASSERT_TRUE(weak_timer.get());
EXPECT_FALSE(weak_timer->IsRunning());
}
// Cancellation during connect works. // Cancellation during connect works.
TEST_F(WebSocketStreamCreateTest, CancellationDuringConnect) { TEST_F(WebSocketStreamCreateTest, CancellationDuringConnect) {
scoped_ptr<DeterministicSocketData> socket_data(BuildNullSocketData()); scoped_ptr<DeterministicSocketData> socket_data(BuildNullSocketData());
......
...@@ -14,6 +14,10 @@ ...@@ -14,6 +14,10 @@
class GURL; class GURL;
namespace base {
class Timer;
} // namespace base
namespace url { namespace url {
class Origin; class Origin;
} // namespace url } // namespace url
...@@ -37,8 +41,9 @@ class LinearCongruentialGenerator { ...@@ -37,8 +41,9 @@ class LinearCongruentialGenerator {
}; };
// Alternate version of WebSocketStream::CreateAndConnectStream() for testing // Alternate version of WebSocketStream::CreateAndConnectStream() for testing
// use only. The difference is the use of a |create_helper| argument in place of // use only. The differences are the use of a |create_helper| argument in place
// |requested_subprotocols|. Implemented in websocket_stream.cc. // of |requested_subprotocols| and taking |timer| as the handshake timeout
// timer. Implemented in websocket_stream.cc.
NET_EXPORT_PRIVATE extern scoped_ptr<WebSocketStreamRequest> NET_EXPORT_PRIVATE extern scoped_ptr<WebSocketStreamRequest>
CreateAndConnectStreamForTesting( CreateAndConnectStreamForTesting(
const GURL& socket_url, const GURL& socket_url,
...@@ -46,7 +51,8 @@ NET_EXPORT_PRIVATE extern scoped_ptr<WebSocketStreamRequest> ...@@ -46,7 +51,8 @@ NET_EXPORT_PRIVATE extern scoped_ptr<WebSocketStreamRequest>
const url::Origin& origin, const url::Origin& origin,
URLRequestContext* url_request_context, URLRequestContext* url_request_context,
const BoundNetLog& net_log, const BoundNetLog& net_log,
scoped_ptr<WebSocketStream::ConnectDelegate> connect_delegate); scoped_ptr<WebSocketStream::ConnectDelegate> connect_delegate,
scoped_ptr<base::Timer> timer);
// Generates a standard WebSocket handshake request. The challenge key used is // Generates a standard WebSocket handshake request. The challenge key used is
// "dGhlIHNhbXBsZSBub25jZQ==". Each header in |extra_headers| must be terminated // "dGhlIHNhbXBsZSBub25jZQ==". Each header in |extra_headers| must be terminated
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment