Commit f955ff0a authored by slan's avatar slan Committed by Commit bot

Handle non-HTTP/1.1 requests more gracefully in net::HttpServer.

Currently, HTTP/1.0 requests are causing http_server.cc to crash.
Code shouldn't crash when it is supplied with bad data. Simply
move into an error state and close the connection instead.

BUG=

Review-Url: https://codereview.chromium.org/2314073003
Cr-Commit-Position: refs/heads/master@{#419327}
parent 60378269
...@@ -232,6 +232,13 @@ int HttpServer::HandleReadResult(HttpConnection* connection, int rv) { ...@@ -232,6 +232,13 @@ int HttpServer::HandleReadResult(HttpConnection* connection, int rv) {
size_t pos = 0; size_t pos = 0;
if (!ParseHeaders(read_buf->StartOfBuffer(), read_buf->GetSize(), if (!ParseHeaders(read_buf->StartOfBuffer(), read_buf->GetSize(),
&request, &pos)) { &request, &pos)) {
// An error has occured. Close the connection.
Close(connection->id());
return ERR_CONNECTION_CLOSED;
} else if (!pos) {
// If pos is 0, all the data in read_buf has been consumed, but the
// headers have not been fully parsed yet. Continue parsing when more data
// rolls in.
break; break;
} }
...@@ -405,8 +412,10 @@ bool HttpServer::ParseHeaders(const char* data, ...@@ -405,8 +412,10 @@ bool HttpServer::ParseHeaders(const char* data,
buffer.clear(); buffer.clear();
break; break;
case ST_PROTO: case ST_PROTO:
// TODO(mbelshe): Deal better with parsing protocol. if (buffer != "HTTP/1.1") {
DCHECK(buffer == "HTTP/1.1"); LOG(ERROR) << "Cannot handle request with protocol: " << buffer;
next_state = ST_ERR;
}
buffer.clear(); buffer.clear();
break; break;
case ST_NAME: case ST_NAME:
...@@ -448,8 +457,10 @@ bool HttpServer::ParseHeaders(const char* data, ...@@ -448,8 +457,10 @@ bool HttpServer::ParseHeaders(const char* data,
} }
} }
} }
// No more characters, but we haven't finished parsing yet. // No more characters, but we haven't finished parsing yet. Signal this to
return false; // the caller by setting |pos| to zero.
pos = 0;
return true;
} }
HttpConnection* HttpServer::FindConnection(int connection_id) { HttpConnection* HttpServer::FindConnection(int connection_id) {
......
...@@ -96,7 +96,9 @@ class HttpServer { ...@@ -96,7 +96,9 @@ class HttpServer {
// Expects the raw data to be stored in recv_data_. If parsing is successful, // Expects the raw data to be stored in recv_data_. If parsing is successful,
// will remove the data parsed from recv_data_, leaving only the unused // will remove the data parsed from recv_data_, leaving only the unused
// recv data. // recv data. If all data has been consumed successfully, but the headers are
// not fully parsed, *pos will be set to zero. Returns false if an error is
// encountered while parsing, true otherwise.
bool ParseHeaders(const char* data, bool ParseHeaders(const char* data,
size_t data_len, size_t data_len,
HttpServerRequestInfo* info, HttpServerRequestInfo* info,
......
...@@ -132,6 +132,8 @@ class TestHttpClient { ...@@ -132,6 +132,8 @@ class TestHttpClient {
return true; return true;
} }
TCPClientSocket& socket() { return *socket_; }
private: private:
void OnConnect(const base::Closure& quit_loop, int result) { void OnConnect(const base::Closure& quit_loop, int result) {
connect_result_ = result; connect_result_ = result;
...@@ -198,7 +200,10 @@ class HttpServerTest : public testing::Test, ...@@ -198,7 +200,10 @@ class HttpServerTest : public testing::Test,
ASSERT_THAT(server_->GetLocalAddress(&server_address_), IsOk()); ASSERT_THAT(server_->GetLocalAddress(&server_address_), IsOk());
} }
void OnConnect(int connection_id) override {} void OnConnect(int connection_id) override {
DCHECK(connection_map_.find(connection_id) == connection_map_.end());
connection_map_[connection_id] = true;
}
void OnHttpRequest(int connection_id, void OnHttpRequest(int connection_id,
const HttpServerRequestInfo& info) override { const HttpServerRequestInfo& info) override {
...@@ -216,7 +221,10 @@ class HttpServerTest : public testing::Test, ...@@ -216,7 +221,10 @@ class HttpServerTest : public testing::Test,
NOTREACHED(); NOTREACHED();
} }
void OnClose(int connection_id) override {} void OnClose(int connection_id) override {
DCHECK(connection_map_.find(connection_id) != connection_map_.end());
connection_map_[connection_id] = false;
}
bool RunUntilRequestsReceived(size_t count) { bool RunUntilRequestsReceived(size_t count) {
quit_after_request_count_ = count; quit_after_request_count_ = count;
...@@ -243,11 +251,15 @@ class HttpServerTest : public testing::Test, ...@@ -243,11 +251,15 @@ class HttpServerTest : public testing::Test,
server_->HandleAcceptResult(OK); server_->HandleAcceptResult(OK);
} }
std::unordered_map<int, bool>& connection_map() { return connection_map_; }
protected: protected:
std::unique_ptr<HttpServer> server_; std::unique_ptr<HttpServer> server_;
IPEndPoint server_address_; IPEndPoint server_address_;
base::Closure run_loop_quit_func_; base::Closure run_loop_quit_func_;
std::vector<std::pair<HttpServerRequestInfo, int> > requests_; std::vector<std::pair<HttpServerRequestInfo, int> > requests_;
std::unordered_map<int /* connection_id */, bool /* connected */>
connection_map_;
private: private:
size_t quit_after_request_count_; size_t quit_after_request_count_;
...@@ -472,6 +484,38 @@ TEST_F(HttpServerTest, SendRaw) { ...@@ -472,6 +484,38 @@ TEST_F(HttpServerTest, SendRaw) {
ASSERT_EQ(expected_response, response); ASSERT_EQ(expected_response, response);
} }
TEST_F(HttpServerTest, WrongProtocolRequest) {
const char* const kBadProtocolRequests[] = {
"GET /test HTTP/1.0\r\n\r\n",
"GET /test foo\r\n\r\n",
"GET /test \r\n\r\n",
};
for (size_t i = 0; i < arraysize(kBadProtocolRequests); ++i) {
TestHttpClient client;
ASSERT_THAT(client.ConnectAndWait(server_address_), IsOk());
client.Send(kBadProtocolRequests[i]);
ASSERT_FALSE(RunUntilRequestsReceived(1));
// Assert that the delegate was updated properly.
ASSERT_EQ(1u, connection_map().size());
ASSERT_FALSE(connection_map().begin()->second);
// Assert that the socket was opened...
ASSERT_TRUE(client.socket().WasEverUsed());
// ...then closed when the server disconnected. Verify that the socket was
// closed by checking that a Read() fails.
std::string response;
ASSERT_FALSE(client.Read(&response, 1u));
ASSERT_EQ(std::string(), response);
// Reset the state of the connection map.
connection_map().clear();
}
}
class MockStreamSocket : public StreamSocket { class MockStreamSocket : public StreamSocket {
public: public:
MockStreamSocket() MockStreamSocket()
...@@ -640,6 +684,7 @@ TEST_F(HttpServerTest, MultipleRequestsOnSameConnection) { ...@@ -640,6 +684,7 @@ TEST_F(HttpServerTest, MultipleRequestsOnSameConnection) {
class CloseOnConnectHttpServerTest : public HttpServerTest { class CloseOnConnectHttpServerTest : public HttpServerTest {
public: public:
void OnConnect(int connection_id) override { void OnConnect(int connection_id) override {
HttpServerTest::OnConnect(connection_id);
connection_ids_.push_back(connection_id); connection_ids_.push_back(connection_id);
server_->Close(connection_id); server_->Close(connection_id);
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment