Commit 74222e77 authored by Helen Li's avatar Helen Li Committed by Commit Bot

Fix gcm::SocketOutputStream to handle partial writes

gcm::SocketOutputStream is not keeping track of partial writes correctly.
It will always try to write |next_pos_| bytes.
This CL adds a regression test.

Bug: 866635
Change-Id: Iccfc4c88a9247ff151073d3691e20ba052f082b0
Reviewed-on: https://chromium-review.googlesource.com/1147475
Commit-Queue: Helen Li <xunjieli@chromium.org>
Reviewed-by: default avatarPeter Beverloo <peter@chromium.org>
Reviewed-by: default avatarMaks Orlovich <morlovich@chromium.org>
Cr-Commit-Position: refs/heads/master@{#577531}
parent e88a3b6a
...@@ -216,9 +216,7 @@ SocketOutputStream::SocketOutputStream( ...@@ -216,9 +216,7 @@ SocketOutputStream::SocketOutputStream(
net::StreamSocket* socket, net::StreamSocket* socket,
const net::NetworkTrafficAnnotationTag& traffic_annotation) const net::NetworkTrafficAnnotationTag& traffic_annotation)
: socket_(socket), : socket_(socket),
io_buffer_(new net::IOBuffer(kDefaultBufferSize)), io_buffer_(new net::IOBufferWithSize(kDefaultBufferSize)),
write_buffer_(
new net::DrainableIOBuffer(io_buffer_.get(), kDefaultBufferSize)),
next_pos_(0), next_pos_(0),
last_error_(net::OK), last_error_(net::OK),
traffic_annotation_(traffic_annotation), traffic_annotation_(traffic_annotation),
...@@ -232,12 +230,12 @@ SocketOutputStream::~SocketOutputStream() { ...@@ -232,12 +230,12 @@ SocketOutputStream::~SocketOutputStream() {
bool SocketOutputStream::Next(void** data, int* size) { bool SocketOutputStream::Next(void** data, int* size) {
DCHECK_NE(GetState(), CLOSED); DCHECK_NE(GetState(), CLOSED);
DCHECK_NE(GetState(), FLUSHING); DCHECK_NE(GetState(), FLUSHING);
if (next_pos_ == write_buffer_->size()) if (next_pos_ == io_buffer_->size())
return false; return false;
*data = write_buffer_->data() + next_pos_; *data = io_buffer_->data() + next_pos_;
*size = write_buffer_->size() - next_pos_; *size = io_buffer_->size() - next_pos_;
next_pos_ = write_buffer_->size(); next_pos_ = io_buffer_->size();
return true; return true;
} }
...@@ -259,15 +257,21 @@ int64_t SocketOutputStream::ByteCount() const { ...@@ -259,15 +257,21 @@ int64_t SocketOutputStream::ByteCount() const {
net::Error SocketOutputStream::Flush(const base::Closure& callback) { net::Error SocketOutputStream::Flush(const base::Closure& callback) {
DCHECK_EQ(GetState(), READY); DCHECK_EQ(GetState(), READY);
if (!write_buffer_) {
write_buffer_ = base::MakeRefCounted<net::DrainableIOBuffer>(
io_buffer_.get(), next_pos_);
}
if (!socket_->IsConnected()) { if (!socket_->IsConnected()) {
LOG(ERROR) << "Socket was disconnected, closing output stream"; LOG(ERROR) << "Socket was disconnected, closing output stream";
last_error_ = net::ERR_CONNECTION_CLOSED; last_error_ = net::ERR_CONNECTION_CLOSED;
return net::OK; return net::OK;
} }
DVLOG(1) << "Flushing " << next_pos_ << " bytes into socket."; DVLOG(1) << "Flushing " << write_buffer_->BytesRemaining()
<< " bytes into socket.";
int result = int result =
socket_->Write(write_buffer_.get(), next_pos_, socket_->Write(write_buffer_.get(), write_buffer_->BytesRemaining(),
base::Bind(&SocketOutputStream::FlushCompletionCallback, base::Bind(&SocketOutputStream::FlushCompletionCallback,
weak_ptr_factory_.GetWeakPtr(), callback), weak_ptr_factory_.GetWeakPtr(), callback),
traffic_annotation_); traffic_annotation_);
...@@ -322,17 +326,17 @@ void SocketOutputStream::FlushCompletionCallback( ...@@ -322,17 +326,17 @@ void SocketOutputStream::FlushCompletionCallback(
DCHECK_GT(result, net::OK); DCHECK_GT(result, net::OK);
last_error_ = net::OK; last_error_ = net::OK;
write_buffer_->DidConsume(result);
if (write_buffer_->BytesConsumed() + result < next_pos_) { if (write_buffer_->BytesRemaining() > 0) {
DVLOG(1) << "Partial flush complete. Retrying."; DVLOG(1) << "Partial flush complete. Retrying.";
// Only a partial write was completed. Flush again to finish the write. // Only a partial write was completed. Flush again to finish the write.
write_buffer_->DidConsume(result);
Flush(callback); Flush(callback);
return; return;
} }
DVLOG(1) << "Socket flush complete."; DVLOG(1) << "Socket flush complete.";
write_buffer_->SetOffset(0); write_buffer_ = nullptr;
next_pos_ = 0; next_pos_ = 0;
if (!callback.is_null()) if (!callback.is_null())
callback.Run(); callback.Run();
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
namespace net { namespace net {
class DrainableIOBuffer; class DrainableIOBuffer;
class IOBuffer; class IOBuffer;
class IOBufferWithSize;
class StreamSocket; class StreamSocket;
} // namespace net } // namespace net
...@@ -186,10 +187,10 @@ class GCM_EXPORT SocketOutputStream ...@@ -186,10 +187,10 @@ class GCM_EXPORT SocketOutputStream
// Internal net components. // Internal net components.
net::StreamSocket* const socket_; net::StreamSocket* const socket_;
const scoped_refptr<net::IOBuffer> io_buffer_; const scoped_refptr<net::IOBufferWithSize> io_buffer_;
// IOBuffer implementation that wraps the data within |io_buffer_| that hasn't // IOBuffer implementation that wraps the data within |io_buffer_| that hasn't
// been written to the socket yet. // been written to the socket yet.
const scoped_refptr<net::DrainableIOBuffer> write_buffer_; scoped_refptr<net::DrainableIOBuffer> write_buffer_;
// Starting position of the data within |io_buffer_| to consume on subsequent // Starting position of the data within |io_buffer_| to consume on subsequent
// Next(..) call. 0 <= write_buffer_.BytesConsumed() <= next_pos_ // Next(..) call. 0 <= write_buffer_.BytesConsumed() <= next_pos_
......
...@@ -7,6 +7,9 @@ ...@@ -7,6 +7,9 @@
#include <stdint.h> #include <stdint.h>
#include <memory> #include <memory>
#include <string>
#include <utility>
#include <vector>
#include "base/bind.h" #include "base/bind.h"
#include "base/macros.h" #include "base/macros.h"
...@@ -32,6 +35,77 @@ const int kReadData2Size = arraysize(kReadData2) - 1; ...@@ -32,6 +35,77 @@ const int kReadData2Size = arraysize(kReadData2) - 1;
const char kWriteData[] = "write_data"; const char kWriteData[] = "write_data";
const int kWriteDataSize = arraysize(kWriteData) - 1; const int kWriteDataSize = arraysize(kWriteData) - 1;
// A net::StreamSocket that returns a partial write only for the first time.
class FirstWritePartialSocket : public net::StreamSocket {
public:
FirstWritePartialSocket() {}
~FirstWritePartialSocket() override {}
// Returns the data that is actually written to the socket.
const std::string& actual_data_written() { return actual_data_written_; }
// net::Socket implementation.
int Write(
net::IOBuffer* buf,
int buf_len,
net::CompletionOnceCallback callback,
const net::NetworkTrafficAnnotationTag& traffic_annotation) override {
// Make the first write as a partial write.
if (!write_invoked_) {
write_invoked_ = true;
actual_data_written_.append(buf->data(), buf_len / 2);
return buf_len / 2;
}
// For subsequent writes, write everything that caller has passed to us.
actual_data_written_.append(buf->data(), buf_len);
return buf_len;
}
int Read(net::IOBuffer* buf,
int buf_len,
net::CompletionOnceCallback callback) override {
return net::ERR_IO_PENDING;
}
int ReadIfReady(net::IOBuffer* buf,
int buf_len,
net::CompletionOnceCallback callback) override {
return net::ERR_IO_PENDING;
}
int CancelReadIfReady() override { return net::OK; }
int SetReceiveBufferSize(int32_t size) override { return net::OK; }
int SetSendBufferSize(int32_t size) override { return net::OK; }
// net::StreamSocket implementation.
int Connect(net::CompletionOnceCallback callback) override { return net::OK; }
void Disconnect() override {}
bool IsConnected() const override { return true; }
bool IsConnectedAndIdle() const override { return true; };
int GetPeerAddress(net::IPEndPoint* address) const override {
return net::OK;
}
int GetLocalAddress(net::IPEndPoint* address) const override {
return net::OK;
}
const net::NetLogWithSource& NetLog() const override { return net_log_; }
bool WasEverUsed() const override { return true; }
bool WasAlpnNegotiated() const override { return false; }
net::NextProto GetNegotiatedProtocol() const override {
return net::kProtoUnknown;
}
bool GetSSLInfo(net::SSLInfo* ssl_info) override { return false; }
void GetConnectionAttempts(net::ConnectionAttempts* out) const override {}
void ClearConnectionAttempts() override {}
void AddConnectionAttempts(const net::ConnectionAttempts& attempts) override {
}
int64_t GetTotalReceivedBytes() const override { return 0; }
void ApplySocketTag(const net::SocketTag& tag) override {}
private:
net::NetLogWithSource net_log_;
std::string actual_data_written_;
// Whether Write() has been invoked before.
bool write_invoked_ = false;
};
class GCMSocketStreamTest : public testing::Test { class GCMSocketStreamTest : public testing::Test {
public: public:
GCMSocketStreamTest(); GCMSocketStreamTest();
...@@ -59,6 +133,10 @@ class GCMSocketStreamTest : public testing::Test { ...@@ -59,6 +133,10 @@ class GCMSocketStreamTest : public testing::Test {
SocketOutputStream* output_stream() { return socket_output_stream_.get(); } SocketOutputStream* output_stream() { return socket_output_stream_.get(); }
net::StreamSocket* socket() { return socket_.get(); } net::StreamSocket* socket() { return socket_.get(); }
void set_socket_output_stream(std::unique_ptr<SocketOutputStream> stream) {
socket_output_stream_ = std::move(stream);
}
private: private:
void OpenConnection(); void OpenConnection();
void ResetInputStream(); void ResetInputStream();
...@@ -329,6 +407,17 @@ TEST_F(GCMSocketStreamTest, WritePartial) { ...@@ -329,6 +407,17 @@ TEST_F(GCMSocketStreamTest, WritePartial) {
kWriteDataSize))); kWriteDataSize)));
} }
// Regression test for crbug.com/866635.
TEST_F(GCMSocketStreamTest, WritePartialWithLengthChecking) {
auto socket = std::make_unique<FirstWritePartialSocket>();
auto socket_output_stream = std::make_unique<SocketOutputStream>(
socket.get(), TRAFFIC_ANNOTATION_FOR_TESTS);
set_socket_output_stream(std::move(socket_output_stream));
ASSERT_EQ(kWriteDataSize,
DoOutputStreamWrite(base::StringPiece(kWriteData, kWriteDataSize)));
EXPECT_EQ(kWriteData, socket->actual_data_written());
}
// Write a message completely asynchronously (returns IO_PENDING before // Write a message completely asynchronously (returns IO_PENDING before
// finishing the write in two go's). // finishing the write in two go's).
TEST_F(GCMSocketStreamTest, WriteNone) { TEST_F(GCMSocketStreamTest, WriteNone) {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment