Commit f4f56e28 authored by Bence Béky's avatar Bence Béky Committed by Commit Bot

Use CompletionOnceCallback in WebSocketStream.

Bug: 807724
Change-Id: I71ddcf48ef586baa4d784a76f3d281988e6c2fc4
Reviewed-on: https://chromium-review.googlesource.com/1128180Reviewed-by: default avatarAdam Rice <ricea@chromium.org>
Commit-Queue: Bence Béky <bnc@chromium.org>
Cr-Commit-Position: refs/heads/master@{#575521}
parent 67b6e214
......@@ -115,55 +115,21 @@ WebSocketBasicStream::~WebSocketBasicStream() { Close(); }
int WebSocketBasicStream::ReadFrames(
std::vector<std::unique_ptr<WebSocketFrame>>* frames,
const CompletionCallback& callback) {
DCHECK(frames->empty());
// If there is data left over after parsing the HTTP headers, attempt to parse
// it as WebSocket frames.
if (http_read_buffer_.get()) {
DCHECK_GE(http_read_buffer_->offset(), 0);
// We cannot simply copy the data into read_buffer_, as it might be too
// large.
scoped_refptr<GrowableIOBuffer> buffered_data;
buffered_data.swap(http_read_buffer_);
DCHECK(!http_read_buffer_);
std::vector<std::unique_ptr<WebSocketFrameChunk>> frame_chunks;
if (!parser_.Decode(buffered_data->StartOfBuffer(),
buffered_data->offset(),
&frame_chunks))
return WebSocketErrorToNetError(parser_.websocket_error());
if (!frame_chunks.empty()) {
int result = ConvertChunksToFrames(&frame_chunks, frames);
if (result != ERR_IO_PENDING)
return result;
}
}
CompletionOnceCallback callback) {
read_callback_ = std::move(callback);
// Run until socket stops giving us data or we get some frames.
while (true) {
// base::Unretained(this) here is safe because net::Socket guarantees not to
// call any callbacks after Disconnect(), which we call from the
// destructor. The caller of ReadFrames() is required to keep |frames|
// valid.
int result = connection_->Read(
read_buffer_.get(), read_buffer_->size(),
base::Bind(&WebSocketBasicStream::OnReadComplete,
base::Unretained(this), base::Unretained(frames), callback));
if (result == ERR_IO_PENDING)
return result;
result = HandleReadResult(result, frames);
if (result != ERR_IO_PENDING)
return result;
DCHECK(frames->empty());
}
return ReadEverything(frames);
}
int WebSocketBasicStream::WriteFrames(
std::vector<std::unique_ptr<WebSocketFrame>>* frames,
const CompletionCallback& callback) {
CompletionOnceCallback callback) {
// This function always concatenates all frames into a single buffer.
// TODO(ricea): Investigate whether it would be better in some cases to
// perform multiple writes with smaller buffers.
//
write_callback_ = std::move(callback);
// First calculate the size of the buffer we need to allocate.
int total_size = CalculateSerializedSizeAndTurnOnMaskBit(frames);
auto combined_buffer = base::MakeRefCounted<IOBufferWithSize>(total_size);
......@@ -196,7 +162,7 @@ int WebSocketBasicStream::WriteFrames(
<< remaining_size << " bytes left over.";
auto drainable_buffer = base::MakeRefCounted<DrainableIOBuffer>(
combined_buffer.get(), total_size);
return WriteEverything(drainable_buffer, callback);
return WriteEverything(drainable_buffer);
}
void WebSocketBasicStream::Close() {
......@@ -225,17 +191,68 @@ WebSocketBasicStream::CreateWebSocketBasicStreamForTesting(
return stream;
}
int WebSocketBasicStream::ReadEverything(
std::vector<std::unique_ptr<WebSocketFrame>>* frames) {
DCHECK(frames->empty());
// If there is data left over after parsing the HTTP headers, attempt to parse
// it as WebSocket frames.
if (http_read_buffer_.get()) {
DCHECK_GE(http_read_buffer_->offset(), 0);
// We cannot simply copy the data into read_buffer_, as it might be too
// large.
scoped_refptr<GrowableIOBuffer> buffered_data;
buffered_data.swap(http_read_buffer_);
DCHECK(!http_read_buffer_);
std::vector<std::unique_ptr<WebSocketFrameChunk>> frame_chunks;
if (!parser_.Decode(buffered_data->StartOfBuffer(), buffered_data->offset(),
&frame_chunks))
return WebSocketErrorToNetError(parser_.websocket_error());
if (!frame_chunks.empty()) {
int result = ConvertChunksToFrames(&frame_chunks, frames);
if (result != ERR_IO_PENDING)
return result;
}
}
// Run until socket stops giving us data or we get some frames.
while (true) {
// base::Unretained(this) here is safe because net::Socket guarantees not to
// call any callbacks after Disconnect(), which we call from the destructor.
// The caller of ReadEverything() is required to keep |frames| valid.
int result = connection_->Read(
read_buffer_.get(), read_buffer_->size(),
base::BindOnce(&WebSocketBasicStream::OnReadComplete,
base::Unretained(this), base::Unretained(frames)));
if (result == ERR_IO_PENDING)
return result;
result = HandleReadResult(result, frames);
if (result != ERR_IO_PENDING)
return result;
DCHECK(frames->empty());
}
}
void WebSocketBasicStream::OnReadComplete(
std::vector<std::unique_ptr<WebSocketFrame>>* frames,
int result) {
result = HandleReadResult(result, frames);
if (result == ERR_IO_PENDING)
result = ReadEverything(frames);
if (result != ERR_IO_PENDING)
std::move(read_callback_).Run(result);
}
int WebSocketBasicStream::WriteEverything(
const scoped_refptr<DrainableIOBuffer>& buffer,
const CompletionCallback& callback) {
const scoped_refptr<DrainableIOBuffer>& buffer) {
while (buffer->BytesRemaining() > 0) {
// The use of base::Unretained() here is safe because on destruction we
// disconnect the socket, preventing any further callbacks.
int result =
connection_->Write(buffer.get(), buffer->BytesRemaining(),
base::Bind(&WebSocketBasicStream::OnWriteComplete,
base::Unretained(this), buffer, callback),
kTrafficAnnotation);
int result = connection_->Write(
buffer.get(), buffer->BytesRemaining(),
base::BindOnce(&WebSocketBasicStream::OnWriteComplete,
base::Unretained(this), buffer),
kTrafficAnnotation);
if (result > 0) {
UMA_HISTOGRAM_COUNTS_100000("Net.WebSocket.DataUse.Upstream", result);
buffer->DidConsume(result);
......@@ -248,11 +265,10 @@ int WebSocketBasicStream::WriteEverything(
void WebSocketBasicStream::OnWriteComplete(
const scoped_refptr<DrainableIOBuffer>& buffer,
const CompletionCallback& callback,
int result) {
if (result < 0) {
DCHECK_NE(ERR_IO_PENDING, result);
callback.Run(result);
std::move(write_callback_).Run(result);
return;
}
......@@ -260,9 +276,9 @@ void WebSocketBasicStream::OnWriteComplete(
UMA_HISTOGRAM_COUNTS_100000("Net.WebSocket.DataUse.Upstream", result);
buffer->DidConsume(result);
result = WriteEverything(buffer, callback);
result = WriteEverything(buffer);
if (result != ERR_IO_PENDING)
callback.Run(result);
std::move(write_callback_).Run(result);
}
int WebSocketBasicStream::HandleReadResult(
......@@ -441,15 +457,4 @@ void WebSocketBasicStream::AddToIncompleteControlFrameBody(
incomplete_control_frame_body_->set_offset(new_offset);
}
void WebSocketBasicStream::OnReadComplete(
std::vector<std::unique_ptr<WebSocketFrame>>* frames,
const CompletionCallback& callback,
int result) {
result = HandleReadResult(result, frames);
if (result == ERR_IO_PENDING)
result = ReadFrames(frames, callback);
if (result != ERR_IO_PENDING)
callback.Run(result);
}
} // namespace net
......@@ -10,7 +10,6 @@
#include <vector>
#include "base/memory/scoped_refptr.h"
#include "net/base/completion_callback.h"
#include "net/base/completion_once_callback.h"
#include "net/base/net_export.h"
#include "net/traffic_annotation/network_traffic_annotation.h"
......@@ -68,10 +67,10 @@ class NET_EXPORT_PRIVATE WebSocketBasicStream : public WebSocketStream {
// WebSocketStream implementation.
int ReadFrames(std::vector<std::unique_ptr<WebSocketFrame>>* frames,
const CompletionCallback& callback) override;
CompletionOnceCallback callback) override;
int WriteFrames(std::vector<std::unique_ptr<WebSocketFrame>>* frames,
const CompletionCallback& callback) override;
CompletionOnceCallback callback) override;
void Close() override;
......@@ -91,14 +90,24 @@ class NET_EXPORT_PRIVATE WebSocketBasicStream : public WebSocketStream {
WebSocketMaskingKeyGeneratorFunction key_generator_function);
private:
// Returns OK or calls |callback| when the |buffer| is fully drained or
// something has failed.
int WriteEverything(const scoped_refptr<DrainableIOBuffer>& buffer,
const CompletionCallback& callback);
// Reads until socket read returns asynchronously or returns error.
// If returns ERR_IO_PENDING, then |read_callback_| will be called with result
// later.
int ReadEverything(std::vector<std::unique_ptr<WebSocketFrame>>* frames);
// Wraps the |callback| to continue writing until everything has been written.
// Called when a read completes. Parses the result, tries to read more.
// Might call |read_callback_|.
void OnReadComplete(std::vector<std::unique_ptr<WebSocketFrame>>* frames,
int result);
// Writes until |buffer| is fully drained (in which case returns OK) or a
// socket write returns asynchronously or returns an error. If returns
// ERR_IO_PENDING, then |write_callback_| will be called with result later.
int WriteEverything(const scoped_refptr<DrainableIOBuffer>& buffer);
// Called when a write completes. Tries to write more.
// Might call |write_callback_|.
void OnWriteComplete(const scoped_refptr<DrainableIOBuffer>& buffer,
const CompletionCallback& callback,
int result);
// Attempts to parse the output of a read as WebSocket frames. On success,
......@@ -138,12 +147,6 @@ class NET_EXPORT_PRIVATE WebSocketBasicStream : public WebSocketStream {
void AddToIncompleteControlFrameBody(
const scoped_refptr<IOBufferWithSize>& data_buffer);
// Called when a read completes. Parses the result and (unless no complete
// header has been received) calls |callback|.
void OnReadComplete(std::vector<std::unique_ptr<WebSocketFrame>>* frames,
const CompletionCallback& callback,
int result);
// Storage for pending reads. All active WebSockets spend all the time with a
// call to ReadFrames() pending, so there is no benefit in trying to share
// this between sockets.
......@@ -185,6 +188,10 @@ class NET_EXPORT_PRIVATE WebSocketBasicStream : public WebSocketStream {
// use a Callback here because a function pointer is faster and good enough
// for our purposes.
WebSocketMaskingKeyGeneratorFunction generate_websocket_masking_key_;
// User callback saved for asynchronous writes and reads.
CompletionOnceCallback write_callback_;
CompletionOnceCallback read_callback_;
};
} // namespace net
......
This diff is collapsed.
......@@ -14,7 +14,6 @@
#include "base/bind.h"
#include "base/logging.h"
#include "base/memory/scoped_refptr.h"
#include "net/base/completion_callback.h"
#include "net/base/io_buffer.h"
#include "net/base/net_errors.h"
#include "net/websockets/websocket_deflate_parameters.h"
......@@ -63,30 +62,28 @@ WebSocketDeflateStream::~WebSocketDeflateStream() = default;
int WebSocketDeflateStream::ReadFrames(
std::vector<std::unique_ptr<WebSocketFrame>>* frames,
const CompletionCallback& callback) {
CompletionOnceCallback callback) {
read_callback_ = std::move(callback);
int result = stream_->ReadFrames(
frames,
base::Bind(&WebSocketDeflateStream::OnReadComplete,
base::Unretained(this),
base::Unretained(frames),
callback));
frames, base::BindOnce(&WebSocketDeflateStream::OnReadComplete,
base::Unretained(this), base::Unretained(frames)));
if (result < 0)
return result;
DCHECK_EQ(OK, result);
DCHECK(!frames->empty());
return InflateAndReadIfNecessary(frames, callback);
return InflateAndReadIfNecessary(frames);
}
int WebSocketDeflateStream::WriteFrames(
std::vector<std::unique_ptr<WebSocketFrame>>* frames,
const CompletionCallback& callback) {
CompletionOnceCallback callback) {
int result = Deflate(frames);
if (result != OK)
return result;
if (frames->empty())
return OK;
return stream_->WriteFrames(frames, callback);
return stream_->WriteFrames(frames, std::move(callback));
}
void WebSocketDeflateStream::Close() { stream_->Close(); }
......@@ -101,17 +98,16 @@ std::string WebSocketDeflateStream::GetExtensions() const {
void WebSocketDeflateStream::OnReadComplete(
std::vector<std::unique_ptr<WebSocketFrame>>* frames,
const CompletionCallback& callback,
int result) {
if (result != OK) {
frames->clear();
callback.Run(result);
std::move(read_callback_).Run(result);
return;
}
int r = InflateAndReadIfNecessary(frames, callback);
int r = InflateAndReadIfNecessary(frames);
if (r != ERR_IO_PENDING)
callback.Run(r);
std::move(read_callback_).Run(r);
}
int WebSocketDeflateStream::Deflate(
......@@ -375,18 +371,15 @@ int WebSocketDeflateStream::Inflate(
}
int WebSocketDeflateStream::InflateAndReadIfNecessary(
std::vector<std::unique_ptr<WebSocketFrame>>* frames,
const CompletionCallback& callback) {
std::vector<std::unique_ptr<WebSocketFrame>>* frames) {
int result = Inflate(frames);
while (result == ERR_IO_PENDING) {
DCHECK(frames->empty());
result = stream_->ReadFrames(
frames,
base::Bind(&WebSocketDeflateStream::OnReadComplete,
base::Unretained(this),
base::Unretained(frames),
callback));
base::BindOnce(&WebSocketDeflateStream::OnReadComplete,
base::Unretained(this), base::Unretained(frames)));
if (result < 0)
break;
DCHECK_EQ(OK, result);
......
......@@ -12,7 +12,7 @@
#include <vector>
#include "base/macros.h"
#include "net/base/completion_callback.h"
#include "net/base/completion_once_callback.h"
#include "net/base/net_export.h"
#include "net/websockets/websocket_deflater.h"
#include "net/websockets/websocket_frame.h"
......@@ -47,9 +47,9 @@ class NET_EXPORT_PRIVATE WebSocketDeflateStream : public WebSocketStream {
// WebSocketStream functions.
int ReadFrames(std::vector<std::unique_ptr<WebSocketFrame>>* frames,
const CompletionCallback& callback) override;
CompletionOnceCallback callback) override;
int WriteFrames(std::vector<std::unique_ptr<WebSocketFrame>>* frames,
const CompletionCallback& callback) override;
CompletionOnceCallback callback) override;
void Close() override;
std::string GetSubProtocol() const override;
std::string GetExtensions() const override;
......@@ -70,7 +70,6 @@ class NET_EXPORT_PRIVATE WebSocketDeflateStream : public WebSocketStream {
// Handles asynchronous completion of ReadFrames() call on |stream_|.
void OnReadComplete(std::vector<std::unique_ptr<WebSocketFrame>>* frames,
const CompletionCallback& callback,
int result);
// This function deflates |frames| and stores the result to |frames| itself.
......@@ -89,8 +88,7 @@ class NET_EXPORT_PRIVATE WebSocketDeflateStream : public WebSocketStream {
int Inflate(std::vector<std::unique_ptr<WebSocketFrame>>* frames);
int InflateAndReadIfNecessary(
std::vector<std::unique_ptr<WebSocketFrame>>* frames,
const CompletionCallback& callback);
std::vector<std::unique_ptr<WebSocketFrame>>* frames);
const std::unique_ptr<WebSocketStream> stream_;
WebSocketDeflater deflater_;
......@@ -101,6 +99,9 @@ class NET_EXPORT_PRIVATE WebSocketDeflateStream : public WebSocketStream {
WebSocketFrameHeader::OpCode current_writing_opcode_;
std::unique_ptr<WebSocketDeflatePredictor> predictor_;
// User callback saved for asynchronous reads.
CompletionOnceCallback read_callback_;
DISALLOW_COPY_AND_ASSIGN(WebSocketDeflateStream);
};
......
......@@ -13,7 +13,6 @@
#include "base/strings/string_number_conversions.h"
#include "base/strings/string_piece.h"
#include "base/test/fuzzed_data_provider.h"
#include "net/base/completion_callback.h"
#include "net/base/io_buffer.h"
#include "net/base/net_errors.h"
#include "net/websockets/websocket_deflate_parameters.h"
......@@ -48,7 +47,7 @@ class WebSocketFuzzedStream final : public WebSocketStream {
: fuzzed_data_provider_(fuzzed_data_provider) {}
int ReadFrames(std::vector<std::unique_ptr<WebSocketFrame>>* frames,
const CompletionCallback& callback) override {
CompletionOnceCallback callback) override {
if (fuzzed_data_provider_->remaining_bytes() < MIN_BYTES_TO_CREATE_A_FRAME)
return ERR_CONNECTION_CLOSED;
while (fuzzed_data_provider_->remaining_bytes() > 0)
......@@ -57,7 +56,7 @@ class WebSocketFuzzedStream final : public WebSocketStream {
}
int WriteFrames(std::vector<std::unique_ptr<WebSocketFrame>>* frames,
const CompletionCallback& callback) override {
CompletionOnceCallback callback) override {
return ERR_FILE_NOT_FOUND;
}
......@@ -117,7 +116,7 @@ void WebSocketDeflateStreamFuzz(const uint8_t* data, size_t size) {
std::make_unique<WebSocketFuzzedStream>(&fuzzed_data_provider),
parameters, std::make_unique<WebSocketDeflatePredictorImpl>());
std::vector<std::unique_ptr<net::WebSocketFrame>> frames;
deflate_stream.ReadFrames(&frames, CompletionCallback());
deflate_stream.ReadFrames(&frames, CompletionOnceCallback());
}
} // namespace
......
......@@ -14,7 +14,7 @@
#include "base/memory/scoped_refptr.h"
#include "base/optional.h"
#include "base/time/time.h"
#include "net/base/completion_callback.h"
#include "net/base/completion_once_callback.h"
#include "net/base/net_export.h"
#include "net/websockets/websocket_event_interface.h"
#include "net/websockets/websocket_handshake_request_info.h"
......@@ -223,7 +223,7 @@ class NET_EXPORT_PRIVATE WebSocketStream {
// set correctly. If the reserved header bits are set incorrectly, it is okay
// to leave it to the caller to report the error.
virtual int ReadFrames(std::vector<std::unique_ptr<WebSocketFrame>>* frames,
const CompletionCallback& callback) = 0;
CompletionOnceCallback callback) = 0;
// Writes WebSocket frame data.
//
......@@ -240,7 +240,7 @@ class NET_EXPORT_PRIVATE WebSocketStream {
// this. This generally means returning to the event loop immediately after
// calling the callback.
virtual int WriteFrames(std::vector<std::unique_ptr<WebSocketFrame>>* frames,
const CompletionCallback& callback) = 0;
CompletionOnceCallback callback) = 0;
// Closes the stream. All pending I/O operations (if any) are cancelled
// at this point, so |frames| can be freed.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment