Commit 3599be80 authored by Shengfa Lin's avatar Shengfa Lin Committed by Commit Bot

[chromedriver] Bidi WebSocket connection(4)

1. Check path and session id from request and accept WebSocket connection
if valid
2. Only allow one connection per session
3. If connection closed, session id can be reused to establish connection
4. Add tests for accepting and rejecting connection
5. Reverted some changes in https://chromium-review.googlesource.com/c/chromium/src/+/1730860

Public design doc:
https://docs.google.com/document/d/1zixFBPtgFFwhc5pT1IfneFoW6yZfpvYdQRsNN4tbghk/

Previous CL chain(I, II and III):
I. https://chromium-review.googlesource.com/c/chromium/src/+/2382299
II. https://chromium-review.googlesource.com/c/chromium/src/+/2389053
II. https://chromium-review.googlesource.com/c/chromium/src/+/2391903

Bug: chromedriver:3588
Change-Id: Idbdd2c17e2b04f23c78c5a36f1f5585e94819d5b
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2406661
Commit-Queue: Shengfa Lin <shengfa@google.com>
Reviewed-by: default avatarYoichi Osato <yoichio@chromium.org>
Reviewed-by: default avatarJohn Chen <johnchen@chromium.org>
Cr-Commit-Position: refs/heads/master@{#810314}
parent d6814a3a
...@@ -215,6 +215,7 @@ source_set("lib") { ...@@ -215,6 +215,7 @@ source_set("lib") {
"command_listener_proxy.h", "command_listener_proxy.h",
"commands.cc", "commands.cc",
"commands.h", "commands.h",
"connection_session_map.h",
"constants/version.h", "constants/version.h",
"devtools_events_logger.cc", "devtools_events_logger.cc",
"devtools_events_logger.h", "devtools_events_logger.h",
......
...@@ -19,7 +19,7 @@ import websocket ...@@ -19,7 +19,7 @@ import websocket
class WebSocketCommands: class WebSocketCommands:
CREATE_WEBSOCKET = \ CREATE_WEBSOCKET = \
'/session/:sessionId/chromium/create_websocket' '/session/:sessionId'
SEND_OVER_WEBSOCKET = \ SEND_OVER_WEBSOCKET = \
'/session/:sessionId/chromium/send_command_from_websocket' '/session/:sessionId/chromium/send_command_from_websocket'
...@@ -37,3 +37,6 @@ class WebSocketConnection(object): ...@@ -37,3 +37,6 @@ class WebSocketConnection(object):
cmd_params['id'] = self._command_id cmd_params['id'] = self._command_id
self._command_id -= 1 self._command_id -= 1
self._websocket.send(json.dumps(cmd_params)) self._websocket.send(json.dumps(cmd_params))
def Close(self):
self._websocket.close();
// Copyright 2020 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef CHROME_TEST_CHROMEDRIVER_CONNECTION_SESSION_MAP_H_
#define CHROME_TEST_CHROMEDRIVER_CONNECTION_SESSION_MAP_H_
#include <string>
#include <unordered_map>
using ConnectionSessionMap = std::unordered_map<int, std::string>;
#endif // CHROME_TEST_CHROMEDRIVER_CONNECTION_SESSION_MAP_H_
...@@ -66,6 +66,13 @@ bool w3cMode(const std::string& session_id, ...@@ -66,6 +66,13 @@ bool w3cMode(const std::string& session_id,
return kW3CDefault; return kW3CDefault;
} }
net::HttpServerResponseInfo createWebSocketRejectResponse(
net::HttpStatusCode code,
const std::string& msg) {
net::HttpServerResponseInfo response(code);
response.AddHeader("X-WebSocket-Reject-Reason", msg);
return response;
}
} // namespace } // namespace
// WrapperURLLoaderFactory subclasses mojom::URLLoaderFactory as non-mojo, cross // WrapperURLLoaderFactory subclasses mojom::URLLoaderFactory as non-mojo, cross
...@@ -143,14 +150,17 @@ HttpHandler::HttpHandler( ...@@ -143,14 +150,17 @@ HttpHandler::HttpHandler(
const scoped_refptr<base::SingleThreadTaskRunner> cmd_task_runner, const scoped_refptr<base::SingleThreadTaskRunner> cmd_task_runner,
const std::string& url_base, const std::string& url_base,
int adb_port) int adb_port)
: quit_func_(quit_func), url_base_(url_base), received_shutdown_(false) { : quit_func_(quit_func),
io_task_runner_(io_task_runner),
url_base_(url_base),
received_shutdown_(false) {
#if defined(OS_MAC) #if defined(OS_MAC)
base::mac::ScopedNSAutoreleasePool autorelease_pool; base::mac::ScopedNSAutoreleasePool autorelease_pool;
#endif #endif
context_getter_ = new URLRequestContextGetter(io_task_runner); context_getter_ = new URLRequestContextGetter(io_task_runner_);
socket_factory_ = CreateSyncWebSocketFactory(context_getter_.get()); socket_factory_ = CreateSyncWebSocketFactory(context_getter_.get());
adb_.reset(new AdbImpl(io_task_runner, adb_port)); adb_ = std::make_unique<AdbImpl>(io_task_runner_, adb_port);
device_manager_.reset(new DeviceManager(adb_.get())); device_manager_ = std::make_unique<DeviceManager>(adb_.get());
url_loader_factory_owner_ = url_loader_factory_owner_ =
std::make_unique<network::TransitionalURLLoaderFactoryOwner>( std::make_unique<network::TransitionalURLLoaderFactoryOwner>(
context_getter_.get()); context_getter_.get());
...@@ -1296,10 +1306,67 @@ HttpHandler::PrepareStandardResponse( ...@@ -1296,10 +1306,67 @@ HttpHandler::PrepareStandardResponse(
return response; return response;
} }
void HttpHandler::OnWebSocketRequest(int connection_id, void HttpHandler::OnWebSocketRequest(HttpServer* http_server,
const net::HttpServerRequestInfo& info) {} int connection_id,
const net::HttpServerRequestInfo& info) {
std::string path = info.path;
std::vector<std::string> path_parts = base::SplitString(
path, "/", base::TRIM_WHITESPACE, base::SPLIT_WANT_NONEMPTY);
if (path_parts.size() != 2 || path_parts[0] != "session") {
std::string err_msg = "bad request received path " + path;
VLOG(0) << "HttpHandler WebSocketRequest error " << err_msg;
SendWebSocketRejectResponse(http_server, connection_id,
net::HTTP_BAD_REQUEST, err_msg);
return;
}
std::string session_id = path_parts[1];
auto it = session_connection_map_.find(session_id);
if (it == session_connection_map_.end()) {
std::string err_msg = "bad request invalid session id " + session_id;
VLOG(0) << "HttpHandler WebSocketRequest error " << err_msg;
SendWebSocketRejectResponse(http_server, connection_id,
net::HTTP_BAD_REQUEST, err_msg);
return;
} else if (it->second != -1) {
std::string err_msg = "bad request only one connection for session id " +
session_id + " is allowed";
VLOG(0) << "HttpHandler WebSocketRequest error " << err_msg;
SendWebSocketRejectResponse(http_server, connection_id,
net::HTTP_BAD_REQUEST, err_msg);
return;
} else {
session_connection_map_[session_id] = connection_id;
connection_session_map_[connection_id] = session_id;
io_task_runner_->PostTask(
FROM_HERE,
base::BindOnce(&HttpServer::AcceptWebSocket,
base::Unretained(http_server), connection_id, info));
}
}
void HttpHandler::OnClose(HttpServer* http_server, int connection_id) {
auto it = connection_session_map_.find(connection_id);
if (it == connection_session_map_.end()) {
return;
}
session_connection_map_[it->second] = -1;
connection_session_map_.erase(it);
}
void HttpHandler::OnClose(int connection_id) {} void HttpHandler::SendWebSocketRejectResponse(HttpServer* http_server,
int connection_id,
net::HttpStatusCode code,
const std::string& msg) {
io_task_runner_->PostTask(
FROM_HERE,
base::BindOnce(&HttpServer::SendResponse, base::Unretained(http_server),
connection_id,
createWebSocketRejectResponse(net::HTTP_BAD_REQUEST, msg),
TRAFFIC_ANNOTATION_FOR_TESTS));
}
namespace internal { namespace internal {
......
...@@ -18,12 +18,14 @@ ...@@ -18,12 +18,14 @@
#include "base/threading/thread_checker.h" #include "base/threading/thread_checker.h"
#include "chrome/test/chromedriver/command.h" #include "chrome/test/chromedriver/command.h"
#include "chrome/test/chromedriver/commands.h" #include "chrome/test/chromedriver/commands.h"
#include "chrome/test/chromedriver/connection_session_map.h"
#include "chrome/test/chromedriver/element_commands.h" #include "chrome/test/chromedriver/element_commands.h"
#include "chrome/test/chromedriver/net/sync_websocket_factory.h" #include "chrome/test/chromedriver/net/sync_websocket_factory.h"
#include "chrome/test/chromedriver/session_commands.h" #include "chrome/test/chromedriver/session_commands.h"
#include "chrome/test/chromedriver/session_connection_map.h" #include "chrome/test/chromedriver/session_connection_map.h"
#include "chrome/test/chromedriver/session_thread_map.h" #include "chrome/test/chromedriver/session_thread_map.h"
#include "chrome/test/chromedriver/window_commands.h" #include "chrome/test/chromedriver/window_commands.h"
#include "net/http/http_status_code.h"
namespace base { namespace base {
class DictionaryValue; class DictionaryValue;
...@@ -127,13 +129,20 @@ class HttpHandler { ...@@ -127,13 +129,20 @@ class HttpHandler {
std::unique_ptr<base::Value> value, std::unique_ptr<base::Value> value,
const std::string& session_id); const std::string& session_id);
void OnWebSocketRequest(int connection_id, void OnWebSocketRequest(HttpServer* http_server,
int connection_id,
const net::HttpServerRequestInfo& info); const net::HttpServerRequestInfo& info);
void OnClose(int connection_id); void OnClose(HttpServer* http_server, int connection_id);
void SendWebSocketRejectResponse(HttpServer* http_server,
int connection_id,
net::HttpStatusCode code,
const std::string& msg);
base::ThreadChecker thread_checker_; base::ThreadChecker thread_checker_;
base::RepeatingClosure quit_func_; base::RepeatingClosure quit_func_;
scoped_refptr<base::SingleThreadTaskRunner> io_task_runner_;
std::string url_base_; std::string url_base_;
bool received_shutdown_; bool received_shutdown_;
scoped_refptr<URLRequestContextGetter> context_getter_; scoped_refptr<URLRequestContextGetter> context_getter_;
...@@ -143,6 +152,7 @@ class HttpHandler { ...@@ -143,6 +152,7 @@ class HttpHandler {
SyncWebSocketFactory socket_factory_; SyncWebSocketFactory socket_factory_;
SessionThreadMap session_thread_map_; SessionThreadMap session_thread_map_;
SessionConnectionMap session_connection_map_; SessionConnectionMap session_connection_map_;
ConnectionSessionMap connection_session_map_;
std::unique_ptr<CommandMap> command_map_; std::unique_ptr<CommandMap> command_map_;
std::unique_ptr<Adb> adb_; std::unique_ptr<Adb> adb_;
std::unique_ptr<DeviceManager> device_manager_; std::unique_ptr<DeviceManager> device_manager_;
......
...@@ -133,69 +133,31 @@ void HttpServer::OnWebSocketRequest(int connection_id, ...@@ -133,69 +133,31 @@ void HttpServer::OnWebSocketRequest(int connection_id,
const net::HttpServerRequestInfo& info) { const net::HttpServerRequestInfo& info) {
cmd_runner_->PostTask( cmd_runner_->PostTask(
FROM_HERE, base::BindOnce(&HttpHandler::OnWebSocketRequest, handler_, FROM_HERE, base::BindOnce(&HttpHandler::OnWebSocketRequest, handler_,
connection_id, info)); this, connection_id, info));
std::string path = info.path;
std::string session_id;
if (!base::StartsWith(path, url_base_, base::CompareCase::SENSITIVE)) {
net::HttpServerResponseInfo response(net::HTTP_BAD_REQUEST);
response.SetBody("invalid websocket request url path", "text/plain");
server_->SendResponse(connection_id, response,
TRAFFIC_ANNOTATION_FOR_TESTS);
return;
}
path.erase(0, url_base_.length());
std::vector<std::string> path_parts =
base::SplitString(path, "/", base::TRIM_WHITESPACE, base::SPLIT_WANT_ALL);
std::vector<std::string> command_path_parts = base::SplitString(
kCreateWebSocketPath, "/", base::TRIM_WHITESPACE, base::SPLIT_WANT_ALL);
if (path_parts.size() != command_path_parts.size()) {
net::HttpServerResponseInfo response(net::HTTP_BAD_REQUEST);
response.SetBody("invalid websocket request url path", "text/plain");
server_->SendResponse(connection_id, response,
TRAFFIC_ANNOTATION_FOR_TESTS);
return;
}
for (size_t i = 0; i < path_parts.size(); ++i) {
if (command_path_parts[i][0] == ':') {
std::string name = command_path_parts[i];
name.erase(0, 1);
CHECK(name.length());
if (name == "sessionId")
session_id = path_parts[i];
} else if (command_path_parts[i] != path_parts[i]) {
net::HttpServerResponseInfo response(net::HTTP_BAD_REQUEST);
response.SetBody("invalid websocket request url path", "text/plain");
server_->SendResponse(connection_id, response,
TRAFFIC_ANNOTATION_FOR_TESTS);
return;
}
}
server_->AcceptWebSocket(connection_id, info, TRAFFIC_ANNOTATION_FOR_TESTS);
connection_to_session_map[connection_id] = session_id;
} }
void HttpServer::OnWebSocketMessage(int connection_id, std::string data) { void HttpServer::OnWebSocketMessage(int connection_id, std::string data) {
base::Optional<base::Value> parsed_data = base::JSONReader::Read(data); // TODO: Make use of WebSocket data
std::string path = url_base_ + kSendCommandFromWebSocket; VLOG(0) << "HttpServer::OnWebSocketMessage received: " << data;
base::ReplaceFirstSubstringAfterOffset(
&path, 0, ":sessionId", connection_to_session_map[connection_id]);
net::HttpServerRequestInfo request;
request.method = "post";
request.path = path;
request.data = data;
OnHttpRequest(connection_id, request);
} }
void HttpServer::OnClose(int connection_id) { void HttpServer::OnClose(int connection_id) {
cmd_runner_->PostTask(FROM_HERE, base::BindOnce(&HttpHandler::OnClose, cmd_runner_->PostTask(
handler_, connection_id)); FROM_HERE,
base::BindOnce(&HttpHandler::OnClose, handler_, this, connection_id));
}
void HttpServer::AcceptWebSocket(int connection_id,
const net::HttpServerRequestInfo& request) {
server_->AcceptWebSocket(connection_id, request,
TRAFFIC_ANNOTATION_FOR_TESTS);
}
void HttpServer::SendResponse(
int connection_id,
const net::HttpServerResponseInfo& response,
const net::NetworkTrafficAnnotationTag& traffic_annotation) {
server_->SendResponse(connection_id, response, traffic_annotation);
} }
void HttpServer::OnResponse( void HttpServer::OnResponse(
......
...@@ -43,6 +43,13 @@ class HttpServer : public net::HttpServer::Delegate { ...@@ -43,6 +43,13 @@ class HttpServer : public net::HttpServer::Delegate {
void OnClose(int connection_id) override; void OnClose(int connection_id) override;
void AcceptWebSocket(int connection_id,
const net::HttpServerRequestInfo& request);
void SendResponse(int connection_id,
const net::HttpServerResponseInfo& response,
const net::NetworkTrafficAnnotationTag& traffic_annotation);
private: private:
void OnResponse(int connection_id, void OnResponse(int connection_id,
bool keep_alive, bool keep_alive,
......
...@@ -45,6 +45,7 @@ sys.path.remove(_PARENT_DIR) ...@@ -45,6 +45,7 @@ sys.path.remove(_PARENT_DIR)
sys.path.insert(1, _CLIENT_DIR) sys.path.insert(1, _CLIENT_DIR)
import chromedriver import chromedriver
import websocket_connection
import webelement import webelement
sys.path.remove(_CLIENT_DIR) sys.path.remove(_CLIENT_DIR)
...@@ -513,10 +514,14 @@ class ChromeDriverWebSocketTest(ChromeDriverBaseTestWithWebServer): ...@@ -513,10 +514,14 @@ class ChromeDriverWebSocketTest(ChromeDriverBaseTestWithWebServer):
def testDefaultSession(self): def testDefaultSession(self):
driver = self.CreateDriver() driver = self.CreateDriver()
self.assertFalse(driver.capabilities.has_key('webSocketUrl')) self.assertFalse(driver.capabilities.has_key('webSocketUrl'))
self.assertRaises(Exception, websocket_connection.WebSocketConnection,
_CHROMEDRIVER_SERVER_URL, driver.GetSessionId())
def testWebSocketUrlFalse(self): def testWebSocketUrlFalse(self):
driver = self.CreateDriver(web_socket_url=False) driver = self.CreateDriver(web_socket_url=False)
self.assertFalse(driver.capabilities.has_key('webSocketUrl')) self.assertFalse(driver.capabilities.has_key('webSocketUrl'))
self.assertRaises(Exception, websocket_connection.WebSocketConnection,
_CHROMEDRIVER_SERVER_URL, driver.GetSessionId())
def testWebSocketUrlTrue(self): def testWebSocketUrlTrue(self):
driver = self.CreateDriver(web_socket_url=True) driver = self.CreateDriver(web_socket_url=True)
...@@ -525,11 +530,37 @@ class ChromeDriverWebSocketTest(ChromeDriverBaseTestWithWebServer): ...@@ -525,11 +530,37 @@ class ChromeDriverWebSocketTest(ChromeDriverBaseTestWithWebServer):
self.assertEquals(driver.capabilities['webSocketUrl'], self.assertEquals(driver.capabilities['webSocketUrl'],
self.composeWebSocketUrl(_CHROMEDRIVER_SERVER_URL, self.composeWebSocketUrl(_CHROMEDRIVER_SERVER_URL,
driver.GetSessionId())) driver.GetSessionId()))
websocket = websocket_connection.WebSocketConnection(
_CHROMEDRIVER_SERVER_URL, driver.GetSessionId())
self.assertNotEqual(None, websocket)
def testWebSocketUrlInvalid(self): def testWebSocketUrlInvalid(self):
self.assertRaises(chromedriver.InvalidArgument, self.assertRaises(chromedriver.InvalidArgument,
self.CreateDriver, web_socket_url='Invalid') self.CreateDriver, web_socket_url='Invalid')
def testWebSocketOneConnectionPerSession(self):
driver = self.CreateDriver(web_socket_url=True)
websocket = websocket_connection.WebSocketConnection(
_CHROMEDRIVER_SERVER_URL, driver.GetSessionId())
self.assertNotEqual(None, websocket)
self.assertRaises(Exception, websocket_connection.WebSocketConnection,
_CHROMEDRIVER_SERVER_URL, driver.GetSessionId())
def testWebSocketInvalidSessionId(self):
driver = self.CreateDriver(web_socket_url=True)
self.assertRaises(Exception, websocket_connection.WebSocketConnection,
_CHROMEDRIVER_SERVER_URL, "random_session_id_123")
def testWebSocketClosedCanReconnect(self):
driver = self.CreateDriver(web_socket_url=True)
websocket = websocket_connection.WebSocketConnection(
_CHROMEDRIVER_SERVER_URL, driver.GetSessionId())
self.assertNotEqual(None, websocket)
websocket.Close()
websocket2 = websocket_connection.WebSocketConnection(
_CHROMEDRIVER_SERVER_URL, driver.GetSessionId())
self.assertNotEqual(None, websocket2)
class ChromeDriverTest(ChromeDriverBaseTestWithWebServer): class ChromeDriverTest(ChromeDriverBaseTestWithWebServer):
"""End to end tests for ChromeDriver.""" """End to end tests for ChromeDriver."""
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment