Commit 1c28679e authored by Adam Rice's avatar Adam Rice Committed by Commit Bot

Add completion callbacks to WebSocket Send

Add a callback argument to WebSocket Send() for string and ArrayBuffer.
The callback will be called when the send completes, if it happens
asynchronously and the callback is non-null.

The callback is not used yet. It will be used by the WebSocketStream
implementation.

The methods return SENT_SYNCHRONOUSLY if they were able to send the data
over mojo synchronously, as in that case the callback will not be
called.

The Blob Send method hasn't been changed, as it won't be used by
WebSocketStream.

See design doc for WebSocketStream at
https://docs.google.com/document/d/1XuxEshh5VYBYm1qRVKordTamCOsR-uGQBCYFcHXP4L0/edit

BUG=983030

Change-Id: Ia7f3147a49e8bc55f8ead63e611a395124118959
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/1718445
Commit-Queue: Adam Rice <ricea@chromium.org>
Reviewed-by: default avatarYoichi Osato <yoichio@chromium.org>
Reviewed-by: default avatarYutaka Hirano <yhirano@chromium.org>
Cr-Commit-Position: refs/heads/master@{#682209}
parent 1255648c
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include "third_party/blink/renderer/modules/websockets/dom_websocket.h" #include "third_party/blink/renderer/modules/websockets/dom_websocket.h"
#include "base/callback.h"
#include "base/feature_list.h" #include "base/feature_list.h"
#include "base/location.h" #include "base/location.h"
#include "third_party/blink/public/common/features.h" #include "third_party/blink/public/common/features.h"
...@@ -478,7 +479,7 @@ void DOMWebSocket::send(const String& message, ...@@ -478,7 +479,7 @@ void DOMWebSocket::send(const String& message,
DCHECK(channel_); DCHECK(channel_);
buffered_amount_ += encoded_message.length(); buffered_amount_ += encoded_message.length();
channel_->Send(encoded_message); channel_->Send(encoded_message, base::OnceClosure());
} }
void DOMWebSocket::send(DOMArrayBuffer* binary_data, void DOMWebSocket::send(DOMArrayBuffer* binary_data,
...@@ -500,7 +501,8 @@ void DOMWebSocket::send(DOMArrayBuffer* binary_data, ...@@ -500,7 +501,8 @@ void DOMWebSocket::send(DOMArrayBuffer* binary_data,
binary_data->ByteLength()); binary_data->ByteLength());
DCHECK(channel_); DCHECK(channel_);
buffered_amount_ += binary_data->ByteLength(); buffered_amount_ += binary_data->ByteLength();
channel_->Send(*binary_data, 0, binary_data->ByteLength()); channel_->Send(*binary_data, 0, binary_data->ByteLength(),
base::OnceClosure());
} }
void DOMWebSocket::send(NotShared<DOMArrayBufferView> array_buffer_view, void DOMWebSocket::send(NotShared<DOMArrayBufferView> array_buffer_view,
...@@ -523,7 +525,7 @@ void DOMWebSocket::send(NotShared<DOMArrayBufferView> array_buffer_view, ...@@ -523,7 +525,7 @@ void DOMWebSocket::send(NotShared<DOMArrayBufferView> array_buffer_view,
buffered_amount_ += array_buffer_view.View()->byteLength(); buffered_amount_ += array_buffer_view.View()->byteLength();
channel_->Send(*array_buffer_view.View()->buffer(), channel_->Send(*array_buffer_view.View()->buffer(),
array_buffer_view.View()->byteOffset(), array_buffer_view.View()->byteOffset(),
array_buffer_view.View()->byteLength()); array_buffer_view.View()->byteLength(), base::OnceClosure());
} }
void DOMWebSocket::send(Blob* binary_data, ExceptionState& exception_state) { void DOMWebSocket::send(Blob* binary_data, ExceptionState& exception_state) {
......
...@@ -49,8 +49,14 @@ class MockWebSocketChannel : public WebSocketChannel { ...@@ -49,8 +49,14 @@ class MockWebSocketChannel : public WebSocketChannel {
~MockWebSocketChannel() override = default; ~MockWebSocketChannel() override = default;
MOCK_METHOD2(Connect, bool(const KURL&, const String&)); MOCK_METHOD2(Connect, bool(const KURL&, const String&));
MOCK_METHOD1(Send, void(const std::string&)); MOCK_METHOD2(Send,
MOCK_METHOD3(Send, void(const DOMArrayBuffer&, unsigned, unsigned)); WebSocketChannel::SendResult(const std::string&,
base::OnceClosure));
MOCK_METHOD4(Send,
WebSocketChannel::SendResult(const DOMArrayBuffer&,
unsigned,
unsigned,
base::OnceClosure));
MOCK_METHOD1(SendMock, void(BlobDataHandle*)); MOCK_METHOD1(SendMock, void(BlobDataHandle*));
void Send(scoped_refptr<BlobDataHandle> handle) override { void Send(scoped_refptr<BlobDataHandle> handle) override {
SendMock(handle.get()); SendMock(handle.get());
...@@ -705,7 +711,7 @@ TEST(DOMWebSocketTest, sendStringSuccess) { ...@@ -705,7 +711,7 @@ TEST(DOMWebSocketTest, sendStringSuccess) {
EXPECT_CALL(websocket_scope.Channel(), EXPECT_CALL(websocket_scope.Channel(),
Connect(KURL("ws://example.com/"), String())) Connect(KURL("ws://example.com/"), String()))
.WillOnce(Return(true)); .WillOnce(Return(true));
EXPECT_CALL(websocket_scope.Channel(), Send(std::string("hello"))); EXPECT_CALL(websocket_scope.Channel(), Send(std::string("hello"), _));
} }
websocket_scope.Socket().Connect("ws://example.com/", Vector<String>(), websocket_scope.Socket().Connect("ws://example.com/", Vector<String>(),
scope.GetExceptionState()); scope.GetExceptionState());
...@@ -728,7 +734,7 @@ TEST(DOMWebSocketTest, sendNonLatin1String) { ...@@ -728,7 +734,7 @@ TEST(DOMWebSocketTest, sendNonLatin1String) {
Connect(KURL("ws://example.com/"), String())) Connect(KURL("ws://example.com/"), String()))
.WillOnce(Return(true)); .WillOnce(Return(true));
EXPECT_CALL(websocket_scope.Channel(), EXPECT_CALL(websocket_scope.Channel(),
Send(std::string("\xe7\x8b\x90\xe0\xa4\x94"))); Send(std::string("\xe7\x8b\x90\xe0\xa4\x94"), _));
} }
websocket_scope.Socket().Connect("ws://example.com/", Vector<String>(), websocket_scope.Socket().Connect("ws://example.com/", Vector<String>(),
scope.GetExceptionState()); scope.GetExceptionState());
...@@ -829,7 +835,7 @@ TEST(DOMWebSocketTest, sendArrayBufferSuccess) { ...@@ -829,7 +835,7 @@ TEST(DOMWebSocketTest, sendArrayBufferSuccess) {
EXPECT_CALL(websocket_scope.Channel(), EXPECT_CALL(websocket_scope.Channel(),
Connect(KURL("ws://example.com/"), String())) Connect(KURL("ws://example.com/"), String()))
.WillOnce(Return(true)); .WillOnce(Return(true));
EXPECT_CALL(websocket_scope.Channel(), Send(Ref(*view->buffer()), 0, 8)); EXPECT_CALL(websocket_scope.Channel(), Send(Ref(*view->buffer()), 0, 8, _));
} }
websocket_scope.Socket().Connect("ws://example.com/", Vector<String>(), websocket_scope.Socket().Connect("ws://example.com/", Vector<String>(),
scope.GetExceptionState()); scope.GetExceptionState());
...@@ -854,8 +860,8 @@ TEST(DOMWebSocketTest, bufferedAmountUpdated) { ...@@ -854,8 +860,8 @@ TEST(DOMWebSocketTest, bufferedAmountUpdated) {
EXPECT_CALL(websocket_scope.Channel(), EXPECT_CALL(websocket_scope.Channel(),
Connect(KURL("ws://example.com/"), String())) Connect(KURL("ws://example.com/"), String()))
.WillOnce(Return(true)); .WillOnce(Return(true));
EXPECT_CALL(websocket_scope.Channel(), Send(std::string("hello"))); EXPECT_CALL(websocket_scope.Channel(), Send(std::string("hello"), _));
EXPECT_CALL(websocket_scope.Channel(), Send(std::string("world"))); EXPECT_CALL(websocket_scope.Channel(), Send(std::string("world"), _));
} }
websocket_scope.Socket().Connect("ws://example.com/", Vector<String>(), websocket_scope.Socket().Connect("ws://example.com/", Vector<String>(),
scope.GetExceptionState()); scope.GetExceptionState());
...@@ -884,7 +890,7 @@ TEST(DOMWebSocketTest, bufferedAmountUpdatedBeforeOnMessage) { ...@@ -884,7 +890,7 @@ TEST(DOMWebSocketTest, bufferedAmountUpdatedBeforeOnMessage) {
EXPECT_CALL(websocket_scope.Channel(), EXPECT_CALL(websocket_scope.Channel(),
Connect(KURL("ws://example.com/"), String())) Connect(KURL("ws://example.com/"), String()))
.WillOnce(Return(true)); .WillOnce(Return(true));
EXPECT_CALL(websocket_scope.Channel(), Send(std::string("hello"))); EXPECT_CALL(websocket_scope.Channel(), Send(std::string("hello"), _));
} }
websocket_scope.Socket().Connect("ws://example.com/", Vector<String>(), websocket_scope.Socket().Connect("ws://example.com/", Vector<String>(),
scope.GetExceptionState()); scope.GetExceptionState());
......
...@@ -31,7 +31,10 @@ ...@@ -31,7 +31,10 @@
#include "third_party/blink/renderer/modules/websockets/web_pepper_socket_impl.h" #include "third_party/blink/renderer/modules/websockets/web_pepper_socket_impl.h"
#include <stddef.h> #include <stddef.h>
#include <memory> #include <memory>
#include "base/callback.h"
#include "third_party/blink/public/mojom/devtools/console_message.mojom-blink.h" #include "third_party/blink/public/mojom/devtools/console_message.mojom-blink.h"
#include "third_party/blink/public/platform/web_url.h" #include "third_party/blink/public/platform/web_url.h"
#include "third_party/blink/public/web/web_array_buffer.h" #include "third_party/blink/public/web/web_array_buffer.h"
...@@ -93,7 +96,7 @@ bool WebPepperSocketImpl::SendText(const WebString& message) { ...@@ -93,7 +96,7 @@ bool WebPepperSocketImpl::SendText(const WebString& message) {
if (is_closing_or_closed_) if (is_closing_or_closed_)
return true; return true;
private_->Send(encoded_message); private_->Send(encoded_message, base::OnceClosure());
return true; return true;
} }
...@@ -111,7 +114,8 @@ bool WebPepperSocketImpl::SendArrayBuffer( ...@@ -111,7 +114,8 @@ bool WebPepperSocketImpl::SendArrayBuffer(
return true; return true;
DOMArrayBuffer* array_buffer = web_array_buffer; DOMArrayBuffer* array_buffer = web_array_buffer;
private_->Send(*array_buffer, 0, array_buffer->ByteLength()); private_->Send(*array_buffer, 0, array_buffer->ByteLength(),
base::OnceClosure());
return true; return true;
} }
......
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
#define THIRD_PARTY_BLINK_RENDERER_MODULES_WEBSOCKETS_WEBSOCKET_CHANNEL_H_ #define THIRD_PARTY_BLINK_RENDERER_MODULES_WEBSOCKETS_WEBSOCKET_CHANNEL_H_
#include <memory> #include <memory>
#include "base/callback_forward.h"
#include "base/macros.h" #include "base/macros.h"
#include "third_party/blink/public/mojom/devtools/console_message.mojom-blink.h" #include "third_party/blink/public/mojom/devtools/console_message.mojom-blink.h"
#include "third_party/blink/renderer/bindings/core/v8/source_location.h" #include "third_party/blink/renderer/bindings/core/v8/source_location.h"
...@@ -48,6 +49,8 @@ class KURL; ...@@ -48,6 +49,8 @@ class KURL;
class MODULES_EXPORT WebSocketChannel class MODULES_EXPORT WebSocketChannel
: public GarbageCollectedFinalized<WebSocketChannel> { : public GarbageCollectedFinalized<WebSocketChannel> {
public: public:
enum class SendResult { SENT_SYNCHRONOUSLY, CALLBACK_WILL_BE_CALLED };
WebSocketChannel() = default; WebSocketChannel() = default;
enum CloseEventCode { enum CloseEventCode {
...@@ -70,10 +73,15 @@ class MODULES_EXPORT WebSocketChannel ...@@ -70,10 +73,15 @@ class MODULES_EXPORT WebSocketChannel
}; };
virtual bool Connect(const KURL&, const String& protocol) = 0; virtual bool Connect(const KURL&, const String& protocol) = 0;
virtual void Send(const std::string&) = 0; virtual SendResult Send(const std::string&,
virtual void Send(const DOMArrayBuffer&, base::OnceClosure completion_callback) = 0;
unsigned byte_offset, virtual SendResult Send(const DOMArrayBuffer&,
unsigned byte_length) = 0; unsigned byte_offset,
unsigned byte_length,
base::OnceClosure completion_callback) = 0;
// Blobs are always sent asynchronously. No callers currently need completion
// callbacks for Blobs, so they are not implemented.
virtual void Send(scoped_refptr<BlobDataHandle>) = 0; virtual void Send(scoped_refptr<BlobDataHandle>) = 0;
// Do not call |Send| after calling this method. // Do not call |Send| after calling this method.
......
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
#include <memory> #include <memory>
#include "base/callback.h"
#include "base/location.h" #include "base/location.h"
#include "base/memory/ptr_util.h" #include "base/memory/ptr_util.h"
#include "mojo/public/cpp/bindings/interface_request.h" #include "mojo/public/cpp/bindings/interface_request.h"
...@@ -105,9 +106,9 @@ class WebSocketChannelImpl::BlobLoader final ...@@ -105,9 +106,9 @@ class WebSocketChannelImpl::BlobLoader final
class WebSocketChannelImpl::Message class WebSocketChannelImpl::Message
: public GarbageCollectedFinalized<WebSocketChannelImpl::Message> { : public GarbageCollectedFinalized<WebSocketChannelImpl::Message> {
public: public:
explicit Message(const std::string&); Message(const std::string&, base::OnceClosure completion_callback);
explicit Message(scoped_refptr<BlobDataHandle>); explicit Message(scoped_refptr<BlobDataHandle>);
explicit Message(DOMArrayBuffer*); Message(DOMArrayBuffer*, base::OnceClosure completion_callback);
// Close message // Close message
Message(uint16_t code, const String& reason); Message(uint16_t code, const String& reason);
...@@ -120,6 +121,7 @@ class WebSocketChannelImpl::Message ...@@ -120,6 +121,7 @@ class WebSocketChannelImpl::Message
Member<DOMArrayBuffer> array_buffer; Member<DOMArrayBuffer> array_buffer;
uint16_t code; uint16_t code;
String reason; String reason;
base::OnceClosure completion_callback;
}; };
WebSocketChannelImpl::BlobLoader::BlobLoader( WebSocketChannelImpl::BlobLoader::BlobLoader(
...@@ -288,18 +290,28 @@ bool WebSocketChannelImpl::Connect(const KURL& url, const String& protocol) { ...@@ -288,18 +290,28 @@ bool WebSocketChannelImpl::Connect(const KURL& url, const String& protocol) {
return true; return true;
} }
void WebSocketChannelImpl::Send(const std::string& message) { WebSocketChannel::SendResult WebSocketChannelImpl::Send(
const std::string& message,
base::OnceClosure completion_callback) {
NETWORK_DVLOG(1) << this << " Send(" << message << ") (std::string argument)"; NETWORK_DVLOG(1) << this << " Send(" << message << ") (std::string argument)";
probe::DidSendWebSocketMessage(execution_context_, identifier_, probe::DidSendWebSocketMessage(execution_context_, identifier_,
WebSocketOpCode::kOpCodeText, true, WebSocketOpCode::kOpCodeText, true,
message.c_str(), message.length()); message.c_str(), message.length());
if (messages_.empty() && if (messages_.empty() &&
MaybeSendSynchronously(WebSocketHandle::kMessageTypeText, message)) { MaybeSendSynchronously(WebSocketHandle::kMessageTypeText, message)) {
return; return SendResult::SENT_SYNCHRONOUSLY;
} }
messages_.push_back(MakeGarbageCollected<Message>(message)); messages_.push_back(
MakeGarbageCollected<Message>(message, std::move(completion_callback)));
ProcessSendQueue(); ProcessSendQueue();
// If we managed to flush this message synchronously after all, it would mean
// that the callback was fired re-entrantly, which would be bad.
DCHECK(!messages_.empty());
return SendResult::CALLBACK_WILL_BE_CALLED;
} }
void WebSocketChannelImpl::Send( void WebSocketChannelImpl::Send(
...@@ -318,9 +330,11 @@ void WebSocketChannelImpl::Send( ...@@ -318,9 +330,11 @@ void WebSocketChannelImpl::Send(
ProcessSendQueue(); ProcessSendQueue();
} }
void WebSocketChannelImpl::Send(const DOMArrayBuffer& buffer, WebSocketChannel::SendResult WebSocketChannelImpl::Send(
unsigned byte_offset, const DOMArrayBuffer& buffer,
unsigned byte_length) { unsigned byte_offset,
unsigned byte_length,
base::OnceClosure completion_callback) {
NETWORK_DVLOG(1) << this << " Send(" << buffer.Data() << ", " << byte_offset NETWORK_DVLOG(1) << this << " Send(" << buffer.Data() << ", " << byte_offset
<< ", " << byte_length << ") " << ", " << byte_length << ") "
<< "(DOMArrayBuffer argument)"; << "(DOMArrayBuffer argument)";
...@@ -332,13 +346,21 @@ void WebSocketChannelImpl::Send(const DOMArrayBuffer& buffer, ...@@ -332,13 +346,21 @@ void WebSocketChannelImpl::Send(const DOMArrayBuffer& buffer,
WebSocketHandle::kMessageTypeBinary, WebSocketHandle::kMessageTypeBinary,
base::make_span(static_cast<const char*>(buffer.Data()) + byte_offset, base::make_span(static_cast<const char*>(buffer.Data()) + byte_offset,
byte_length))) { byte_length))) {
return; return SendResult::SENT_SYNCHRONOUSLY;
} }
// buffer.Slice copies its contents. // buffer.Slice copies its contents.
messages_.push_back(MakeGarbageCollected<Message>( messages_.push_back(MakeGarbageCollected<Message>(
buffer.Slice(byte_offset, byte_offset + byte_length))); buffer.Slice(byte_offset, byte_offset + byte_length),
std::move(completion_callback)));
ProcessSendQueue(); ProcessSendQueue();
// If we managed to flush this message synchronously after all, it would mean
// that the callback was fired re-entrantly, which would be bad.
DCHECK(!messages_.empty());
return SendResult::CALLBACK_WILL_BE_CALLED;
} }
void WebSocketChannelImpl::Close(int code, const String& reason) { void WebSocketChannelImpl::Close(int code, const String& reason) {
...@@ -394,15 +416,21 @@ void WebSocketChannelImpl::Disconnect() { ...@@ -394,15 +416,21 @@ void WebSocketChannelImpl::Disconnect() {
identifier_ = 0; identifier_ = 0;
} }
WebSocketChannelImpl::Message::Message(const std::string& text) WebSocketChannelImpl::Message::Message(const std::string& text,
: type(kMessageTypeText), text(text) {} base::OnceClosure completion_callback)
: type(kMessageTypeText),
text(text),
completion_callback(std::move(completion_callback)) {}
WebSocketChannelImpl::Message::Message( WebSocketChannelImpl::Message::Message(
scoped_refptr<BlobDataHandle> blob_data_handle) scoped_refptr<BlobDataHandle> blob_data_handle)
: type(kMessageTypeBlob), blob_data_handle(std::move(blob_data_handle)) {} : type(kMessageTypeBlob), blob_data_handle(std::move(blob_data_handle)) {}
WebSocketChannelImpl::Message::Message(DOMArrayBuffer* array_buffer) WebSocketChannelImpl::Message::Message(DOMArrayBuffer* array_buffer,
: type(kMessageTypeArrayBuffer), array_buffer(array_buffer) {} base::OnceClosure completion_callback)
: type(kMessageTypeArrayBuffer),
array_buffer(array_buffer),
completion_callback(std::move(completion_callback)) {}
WebSocketChannelImpl::Message::Message(uint16_t code, const String& reason) WebSocketChannelImpl::Message::Message(uint16_t code, const String& reason)
: type(kMessageTypeClose), code(code), reason(reason) {} : type(kMessageTypeClose), code(code), reason(reason) {}
...@@ -431,6 +459,10 @@ void WebSocketChannelImpl::SendInternal( ...@@ -431,6 +459,10 @@ void WebSocketChannelImpl::SendInternal(
sent_size_of_top_message_ += size; sent_size_of_top_message_ += size;
if (final) { if (final) {
base::OnceClosure completion_callback =
std::move(messages_.front()->completion_callback);
if (!completion_callback.is_null())
std::move(completion_callback).Run();
messages_.pop_front(); messages_.pop_front();
sent_size_of_top_message_ = 0; sent_size_of_top_message_ = 0;
} }
...@@ -797,7 +829,8 @@ void WebSocketChannelImpl::DidFinishLoadingBlob(DOMArrayBuffer* buffer) { ...@@ -797,7 +829,8 @@ void WebSocketChannelImpl::DidFinishLoadingBlob(DOMArrayBuffer* buffer) {
DCHECK_GT(messages_.size(), 0u); DCHECK_GT(messages_.size(), 0u);
DCHECK_EQ(messages_.front()->type, kMessageTypeBlob); DCHECK_EQ(messages_.front()->type, kMessageTypeBlob);
// We replace it with the loaded blob. // We replace it with the loaded blob.
messages_.front() = MakeGarbageCollected<Message>(buffer); messages_.front() =
MakeGarbageCollected<Message>(buffer, base::OnceClosure());
ProcessSendQueue(); ProcessSendQueue();
} }
......
...@@ -91,10 +91,12 @@ class MODULES_EXPORT WebSocketChannelImpl final : public WebSocketChannel { ...@@ -91,10 +91,12 @@ class MODULES_EXPORT WebSocketChannelImpl final : public WebSocketChannel {
// WebSocketChannel functions. // WebSocketChannel functions.
bool Connect(const KURL&, const String& protocol) override; bool Connect(const KURL&, const String& protocol) override;
void Send(const std::string& message) override; SendResult Send(const std::string& message,
void Send(const DOMArrayBuffer&, base::OnceClosure completion_callback) override;
unsigned byte_offset, SendResult Send(const DOMArrayBuffer&,
unsigned byte_length) override; unsigned byte_offset,
unsigned byte_length,
base::OnceClosure completion_callback) override;
void Send(scoped_refptr<BlobDataHandle>) override; void Send(scoped_refptr<BlobDataHandle>) override;
// Start closing handshake. Use the CloseEventCodeNotSpecified for the code // Start closing handshake. Use the CloseEventCodeNotSpecified for the code
// argument to omit payload. // argument to omit payload.
......
...@@ -7,6 +7,8 @@ ...@@ -7,6 +7,8 @@
#include <stdint.h> #include <stdint.h>
#include <memory> #include <memory>
#include "base/callback.h"
#include "base/macros.h"
#include "base/memory/ptr_util.h" #include "base/memory/ptr_util.h"
#include "testing/gmock/include/gmock/gmock.h" #include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h" #include "testing/gtest/include/gtest/gtest.h"
...@@ -21,6 +23,7 @@ ...@@ -21,6 +23,7 @@
#include "third_party/blink/renderer/modules/websockets/websocket_handle.h" #include "third_party/blink/renderer/modules/websockets/websocket_handle.h"
#include "third_party/blink/renderer/platform/heap/handle.h" #include "third_party/blink/renderer/platform/heap/handle.h"
#include "third_party/blink/renderer/platform/weborigin/kurl.h" #include "third_party/blink/renderer/platform/weborigin/kurl.h"
#include "third_party/blink/renderer/platform/wtf/functional.h"
#include "third_party/blink/renderer/platform/wtf/text/wtf_string.h" #include "third_party/blink/renderer/platform/wtf/text/wtf_string.h"
#include "third_party/blink/renderer/platform/wtf/vector.h" #include "third_party/blink/renderer/platform/wtf/vector.h"
#include "third_party/blink/renderer/platform/wtf/wtf_size_t.h" #include "third_party/blink/renderer/platform/wtf/wtf_size_t.h"
...@@ -181,6 +184,26 @@ class WebSocketChannelImplTest : public PageTestBase { ...@@ -181,6 +184,26 @@ class WebSocketChannelImplTest : public PageTestBase {
static const uint64_t kDefaultReceiveQuotaThreshold = 1 << 15; static const uint64_t kDefaultReceiveQuotaThreshold = 1 << 15;
}; };
class CallTrackingClosure {
public:
CallTrackingClosure() = default;
base::OnceClosure Closure() {
// This use of base::Unretained is safe because nothing can call the
// callback once the test has finished.
return WTF::Bind(&CallTrackingClosure::Called, base::Unretained(this));
}
bool WasCalled() const { return was_called_; }
private:
void Called() { was_called_ = true; }
bool was_called_ = false;
DISALLOW_COPY_AND_ASSIGN(CallTrackingClosure);
};
MATCHER_P2(MemEq, MATCHER_P2(MemEq,
p, p,
len, len,
...@@ -243,9 +266,9 @@ TEST_F(WebSocketChannelImplTest, sendText) { ...@@ -243,9 +266,9 @@ TEST_F(WebSocketChannelImplTest, sendText) {
ChannelImpl()->AddSendFlowControlQuota(Handle(), 16); ChannelImpl()->AddSendFlowControlQuota(Handle(), 16);
EXPECT_CALL(*ChannelClient(), DidConsumeBufferedAmount(_)).Times(AnyNumber()); EXPECT_CALL(*ChannelClient(), DidConsumeBufferedAmount(_)).Times(AnyNumber());
Channel()->Send("foo"); Channel()->Send("foo", base::OnceClosure());
Channel()->Send("bar"); Channel()->Send("bar", base::OnceClosure());
Channel()->Send("baz"); Channel()->Send("baz", base::OnceClosure());
EXPECT_EQ(9ul, sum_of_consumed_buffered_amount_); EXPECT_EQ(9ul, sum_of_consumed_buffered_amount_);
} }
...@@ -276,9 +299,10 @@ TEST_F(WebSocketChannelImplTest, sendTextContinuation) { ...@@ -276,9 +299,10 @@ TEST_F(WebSocketChannelImplTest, sendTextContinuation) {
ChannelImpl()->AddSendFlowControlQuota(Handle(), 16); ChannelImpl()->AddSendFlowControlQuota(Handle(), 16);
EXPECT_CALL(*ChannelClient(), DidConsumeBufferedAmount(_)).Times(AnyNumber()); EXPECT_CALL(*ChannelClient(), DidConsumeBufferedAmount(_)).Times(AnyNumber());
Channel()->Send("0123456789abcdefg"); Channel()->Send("0123456789abcdefg", base::OnceClosure());
Channel()->Send("hijk"); Channel()->Send("hijk", base::OnceClosure());
Channel()->Send("lmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"); Channel()->Send("lmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ",
base::OnceClosure());
checkpoint.Call(1); checkpoint.Call(1);
ChannelImpl()->AddSendFlowControlQuota(Handle(), 16); ChannelImpl()->AddSendFlowControlQuota(Handle(), 16);
checkpoint.Call(2); checkpoint.Call(2);
...@@ -301,7 +325,7 @@ TEST_F(WebSocketChannelImplTest, sendBinaryInVector) { ...@@ -301,7 +325,7 @@ TEST_F(WebSocketChannelImplTest, sendBinaryInVector) {
EXPECT_CALL(*ChannelClient(), DidConsumeBufferedAmount(_)).Times(AnyNumber()); EXPECT_CALL(*ChannelClient(), DidConsumeBufferedAmount(_)).Times(AnyNumber());
DOMArrayBuffer* foo_buffer = DOMArrayBuffer::Create("foo", 3); DOMArrayBuffer* foo_buffer = DOMArrayBuffer::Create("foo", 3);
Channel()->Send(*foo_buffer, 0, 3); Channel()->Send(*foo_buffer, 0, 3, base::OnceClosure());
EXPECT_EQ(3ul, sum_of_consumed_buffered_amount_); EXPECT_EQ(3ul, sum_of_consumed_buffered_amount_);
} }
...@@ -325,10 +349,10 @@ TEST_F(WebSocketChannelImplTest, sendBinaryInArrayBufferPartial) { ...@@ -325,10 +349,10 @@ TEST_F(WebSocketChannelImplTest, sendBinaryInArrayBufferPartial) {
DOMArrayBuffer* foobar_buffer = DOMArrayBuffer::Create("foobar", 6); DOMArrayBuffer* foobar_buffer = DOMArrayBuffer::Create("foobar", 6);
DOMArrayBuffer* qbazux_buffer = DOMArrayBuffer::Create("qbazux", 6); DOMArrayBuffer* qbazux_buffer = DOMArrayBuffer::Create("qbazux", 6);
Channel()->Send(*foobar_buffer, 0, 3); Channel()->Send(*foobar_buffer, 0, 3, base::OnceClosure());
Channel()->Send(*foobar_buffer, 3, 3); Channel()->Send(*foobar_buffer, 3, 3, base::OnceClosure());
Channel()->Send(*qbazux_buffer, 1, 3); Channel()->Send(*qbazux_buffer, 1, 3, base::OnceClosure());
Channel()->Send(*qbazux_buffer, 2, 1); Channel()->Send(*qbazux_buffer, 2, 1, base::OnceClosure());
EXPECT_EQ(10ul, sum_of_consumed_buffered_amount_); EXPECT_EQ(10ul, sum_of_consumed_buffered_amount_);
} }
...@@ -352,19 +376,19 @@ TEST_F(WebSocketChannelImplTest, sendBinaryInArrayBufferWithNullBytes) { ...@@ -352,19 +376,19 @@ TEST_F(WebSocketChannelImplTest, sendBinaryInArrayBufferWithNullBytes) {
{ {
DOMArrayBuffer* b = DOMArrayBuffer::Create("\0ar", 3); DOMArrayBuffer* b = DOMArrayBuffer::Create("\0ar", 3);
Channel()->Send(*b, 0, 3); Channel()->Send(*b, 0, 3, base::OnceClosure());
} }
{ {
DOMArrayBuffer* b = DOMArrayBuffer::Create("b\0z", 3); DOMArrayBuffer* b = DOMArrayBuffer::Create("b\0z", 3);
Channel()->Send(*b, 0, 3); Channel()->Send(*b, 0, 3, base::OnceClosure());
} }
{ {
DOMArrayBuffer* b = DOMArrayBuffer::Create("qu\0", 3); DOMArrayBuffer* b = DOMArrayBuffer::Create("qu\0", 3);
Channel()->Send(*b, 0, 3); Channel()->Send(*b, 0, 3, base::OnceClosure());
} }
{ {
DOMArrayBuffer* b = DOMArrayBuffer::Create("\0\0\0", 3); DOMArrayBuffer* b = DOMArrayBuffer::Create("\0\0\0", 3);
Channel()->Send(*b, 0, 3); Channel()->Send(*b, 0, 3, base::OnceClosure());
} }
EXPECT_EQ(12ul, sum_of_consumed_buffered_amount_); EXPECT_EQ(12ul, sum_of_consumed_buffered_amount_);
...@@ -379,7 +403,7 @@ TEST_F(WebSocketChannelImplTest, sendBinaryInArrayBufferNonLatin1UTF8) { ...@@ -379,7 +403,7 @@ TEST_F(WebSocketChannelImplTest, sendBinaryInArrayBufferNonLatin1UTF8) {
EXPECT_CALL(*ChannelClient(), DidConsumeBufferedAmount(_)).Times(AnyNumber()); EXPECT_CALL(*ChannelClient(), DidConsumeBufferedAmount(_)).Times(AnyNumber());
DOMArrayBuffer* b = DOMArrayBuffer::Create("\xe7\x8b\x90", 3); DOMArrayBuffer* b = DOMArrayBuffer::Create("\xe7\x8b\x90", 3);
Channel()->Send(*b, 0, 3); Channel()->Send(*b, 0, 3, base::OnceClosure());
EXPECT_EQ(3ul, sum_of_consumed_buffered_amount_); EXPECT_EQ(3ul, sum_of_consumed_buffered_amount_);
} }
...@@ -393,7 +417,7 @@ TEST_F(WebSocketChannelImplTest, sendBinaryInArrayBufferNonUTF8) { ...@@ -393,7 +417,7 @@ TEST_F(WebSocketChannelImplTest, sendBinaryInArrayBufferNonUTF8) {
EXPECT_CALL(*ChannelClient(), DidConsumeBufferedAmount(_)).Times(AnyNumber()); EXPECT_CALL(*ChannelClient(), DidConsumeBufferedAmount(_)).Times(AnyNumber());
DOMArrayBuffer* b = DOMArrayBuffer::Create("\x80\xff\xe7", 3); DOMArrayBuffer* b = DOMArrayBuffer::Create("\x80\xff\xe7", 3);
Channel()->Send(*b, 0, 3); Channel()->Send(*b, 0, 3, base::OnceClosure());
EXPECT_EQ(3ul, sum_of_consumed_buffered_amount_); EXPECT_EQ(3ul, sum_of_consumed_buffered_amount_);
} }
...@@ -421,7 +445,7 @@ TEST_F(WebSocketChannelImplTest, ...@@ -421,7 +445,7 @@ TEST_F(WebSocketChannelImplTest,
"\xe7\x8b\x90\xe7\x8b\x90\xe7\x8b\x90\xe7\x8b\x90\xe7\x8b\x90\xe7\x8b" "\xe7\x8b\x90\xe7\x8b\x90\xe7\x8b\x90\xe7\x8b\x90\xe7\x8b\x90\xe7\x8b"
"\x90", "\x90",
18); 18);
Channel()->Send(*b, 0, 18); Channel()->Send(*b, 0, 18, base::OnceClosure());
checkpoint.Call(1); checkpoint.Call(1);
ChannelImpl()->AddSendFlowControlQuota(Handle(), 16); ChannelImpl()->AddSendFlowControlQuota(Handle(), 16);
...@@ -429,6 +453,118 @@ TEST_F(WebSocketChannelImplTest, ...@@ -429,6 +453,118 @@ TEST_F(WebSocketChannelImplTest,
EXPECT_EQ(18ul, sum_of_consumed_buffered_amount_); EXPECT_EQ(18ul, sum_of_consumed_buffered_amount_);
} }
TEST_F(WebSocketChannelImplTest, sendTextSync) {
Connect();
{
InSequence s;
EXPECT_CALL(*Handle(), Send(_, _, _, _));
}
ChannelImpl()->AddSendFlowControlQuota(Handle(), 5);
EXPECT_CALL(*ChannelClient(), DidConsumeBufferedAmount(_)).Times(AnyNumber());
CallTrackingClosure closure;
EXPECT_EQ(WebSocketChannel::SendResult::SENT_SYNCHRONOUSLY,
Channel()->Send("hello", closure.Closure()));
EXPECT_FALSE(closure.WasCalled());
}
TEST_F(WebSocketChannelImplTest, sendTextAsyncBecauseQuota) {
Connect();
{
InSequence s;
EXPECT_CALL(*Handle(), Send(_, _, _, _));
EXPECT_CALL(*Handle(), Send(_, _, _, _));
}
ChannelImpl()->AddSendFlowControlQuota(Handle(), 4);
EXPECT_CALL(*ChannelClient(), DidConsumeBufferedAmount(_)).Times(AnyNumber());
CallTrackingClosure closure;
EXPECT_EQ(WebSocketChannel::SendResult::CALLBACK_WILL_BE_CALLED,
Channel()->Send("hello", closure.Closure()));
EXPECT_FALSE(closure.WasCalled());
ChannelImpl()->AddSendFlowControlQuota(Handle(), 1);
EXPECT_TRUE(closure.WasCalled());
}
TEST_F(WebSocketChannelImplTest, sendTextAsyncBecauseQueue) {
Connect();
{
InSequence s;
EXPECT_CALL(*Handle(), Send(_, _, _, _));
EXPECT_CALL(*Handle(), Send(_, _, _, _));
EXPECT_CALL(*Handle(), Send(_, _, _, _));
}
ChannelImpl()->AddSendFlowControlQuota(Handle(), 8);
EXPECT_CALL(*ChannelClient(), DidConsumeBufferedAmount(_)).Times(AnyNumber());
// Ideally we'd use a Blob to block the queue in this test, but setting up a
// working blob environment in a unit-test is complicated, so just block
// behind a larger string instead.
Channel()->Send("0123456789", base::OnceClosure());
CallTrackingClosure closure;
EXPECT_EQ(WebSocketChannel::SendResult::CALLBACK_WILL_BE_CALLED,
Channel()->Send("hello", closure.Closure()));
EXPECT_FALSE(closure.WasCalled());
ChannelImpl()->AddSendFlowControlQuota(Handle(), 7);
EXPECT_TRUE(closure.WasCalled());
}
TEST_F(WebSocketChannelImplTest, sendBinaryInArrayBufferSync) {
Connect();
{
InSequence s;
EXPECT_CALL(*Handle(), Send(_, _, _, _));
}
ChannelImpl()->AddSendFlowControlQuota(Handle(), 5);
EXPECT_CALL(*ChannelClient(), DidConsumeBufferedAmount(_)).Times(AnyNumber());
CallTrackingClosure closure;
const auto* b = DOMArrayBuffer::Create("hello", 5);
EXPECT_EQ(WebSocketChannel::SendResult::SENT_SYNCHRONOUSLY,
Channel()->Send(*b, 0, 5, closure.Closure()));
EXPECT_FALSE(closure.WasCalled());
}
TEST_F(WebSocketChannelImplTest, sendBinaryInArrayBufferAsyncBecauseQuota) {
Connect();
{
InSequence s;
EXPECT_CALL(*Handle(), Send(_, _, _, _));
EXPECT_CALL(*Handle(), Send(_, _, _, _));
}
ChannelImpl()->AddSendFlowControlQuota(Handle(), 4);
EXPECT_CALL(*ChannelClient(), DidConsumeBufferedAmount(_)).Times(AnyNumber());
CallTrackingClosure closure;
const auto* b = DOMArrayBuffer::Create("hello", 5);
EXPECT_EQ(WebSocketChannel::SendResult::CALLBACK_WILL_BE_CALLED,
Channel()->Send(*b, 0, 5, closure.Closure()));
EXPECT_FALSE(closure.WasCalled());
ChannelImpl()->AddSendFlowControlQuota(Handle(), 1);
EXPECT_TRUE(closure.WasCalled());
}
TEST_F(WebSocketChannelImplTest, sendBinaryInArrayBufferAsyncBecauseQueue) {
Connect();
{
InSequence s;
EXPECT_CALL(*Handle(), Send(_, _, _, _));
EXPECT_CALL(*Handle(), Send(_, _, _, _));
EXPECT_CALL(*Handle(), Send(_, _, _, _));
}
ChannelImpl()->AddSendFlowControlQuota(Handle(), 8);
EXPECT_CALL(*ChannelClient(), DidConsumeBufferedAmount(_)).Times(AnyNumber());
Channel()->Send("0123456789", base::OnceClosure());
CallTrackingClosure closure;
const auto* b = DOMArrayBuffer::Create("hello", 5);
EXPECT_EQ(WebSocketChannel::SendResult::CALLBACK_WILL_BE_CALLED,
Channel()->Send(*b, 0, 5, closure.Closure()));
EXPECT_FALSE(closure.WasCalled());
ChannelImpl()->AddSendFlowControlQuota(Handle(), 7);
EXPECT_TRUE(closure.WasCalled());
}
// FIXME: Add tests for WebSocketChannel::send(scoped_refptr<BlobDataHandle>) // FIXME: Add tests for WebSocketChannel::send(scoped_refptr<BlobDataHandle>)
TEST_F(WebSocketChannelImplTest, receiveText) { TEST_F(WebSocketChannelImplTest, receiveText) {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment