Commit a04c0468 authored by sergeyu's avatar sergeyu Committed by Commit bot

Fix MessageReader to pass errors to the channel

Previously MessageReader was stopping reading after the first error,
but wasn't notifying the client about the problem. This results in some
errors (e.g. from SSL layer) being ignores while they should terminate
connection.

BUG=487451

Review URL: https://codereview.chromium.org/1143443003

Cr-Commit-Position: refs/heads/master@{#329780}
parent 915be3ac
...@@ -57,14 +57,17 @@ void ChannelDispatcherBase::OnChannelReady( ...@@ -57,14 +57,17 @@ void ChannelDispatcherBase::OnChannelReady(
channel_factory_ = nullptr; channel_factory_ = nullptr;
channel_ = socket.Pass(); channel_ = socket.Pass();
writer_.Init(channel_.get(), base::Bind(&ChannelDispatcherBase::OnWriteFailed, writer_.Init(channel_.get(),
base::Unretained(this))); base::Bind(&ChannelDispatcherBase::OnReadWriteFailed,
reader_.StartReading(channel_.get()); base::Unretained(this)));
reader_.StartReading(channel_.get(),
base::Bind(&ChannelDispatcherBase::OnReadWriteFailed,
base::Unretained(this)));
event_handler_->OnChannelInitialized(this); event_handler_->OnChannelInitialized(this);
} }
void ChannelDispatcherBase::OnWriteFailed(int error) { void ChannelDispatcherBase::OnReadWriteFailed(int error) {
event_handler_->OnChannelError(this, CHANNEL_CONNECTION_ERROR); event_handler_->OnChannelError(this, CHANNEL_CONNECTION_ERROR);
} }
......
...@@ -67,7 +67,7 @@ class ChannelDispatcherBase { ...@@ -67,7 +67,7 @@ class ChannelDispatcherBase {
private: private:
void OnChannelReady(scoped_ptr<net::StreamSocket> socket); void OnChannelReady(scoped_ptr<net::StreamSocket> socket);
void OnWriteFailed(int error); void OnReadWriteFailed(int error);
std::string channel_name_; std::string channel_name_;
StreamChannelFactory* channel_factory_; StreamChannelFactory* channel_factory_;
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include "base/bind.h" #include "base/bind.h"
#include "base/callback.h" #include "base/callback.h"
#include "base/callback_helpers.h"
#include "base/location.h" #include "base/location.h"
#include "base/single_thread_task_runner.h" #include "base/single_thread_task_runner.h"
#include "base/stl_util.h" #include "base/stl_util.h"
...@@ -79,7 +80,7 @@ class ChannelMultiplexer::MuxChannel { ...@@ -79,7 +80,7 @@ class ChannelMultiplexer::MuxChannel {
scoped_ptr<net::StreamSocket> CreateSocket(); scoped_ptr<net::StreamSocket> CreateSocket();
void OnIncomingPacket(scoped_ptr<MultiplexPacket> packet, void OnIncomingPacket(scoped_ptr<MultiplexPacket> packet,
const base::Closure& done_task); const base::Closure& done_task);
void OnWriteFailed(); void OnBaseChannelError(int error);
// Called by MuxSocket. // Called by MuxSocket.
void OnSocketDestroyed(); void OnSocketDestroyed();
...@@ -107,7 +108,7 @@ class ChannelMultiplexer::MuxSocket : public net::StreamSocket, ...@@ -107,7 +108,7 @@ class ChannelMultiplexer::MuxSocket : public net::StreamSocket,
~MuxSocket() override; ~MuxSocket() override;
void OnWriteComplete(); void OnWriteComplete();
void OnWriteFailed(); void OnBaseChannelError(int error);
void OnPacketReceived(); void OnPacketReceived();
// net::StreamSocket interface. // net::StreamSocket interface.
...@@ -168,6 +169,8 @@ class ChannelMultiplexer::MuxSocket : public net::StreamSocket, ...@@ -168,6 +169,8 @@ class ChannelMultiplexer::MuxSocket : public net::StreamSocket,
private: private:
MuxChannel* channel_; MuxChannel* channel_;
int base_channel_error_ = net::OK;
net::CompletionCallback read_callback_; net::CompletionCallback read_callback_;
scoped_refptr<net::IOBuffer> read_buffer_; scoped_refptr<net::IOBuffer> read_buffer_;
int read_buffer_size_; int read_buffer_size_;
...@@ -220,9 +223,9 @@ void ChannelMultiplexer::MuxChannel::OnIncomingPacket( ...@@ -220,9 +223,9 @@ void ChannelMultiplexer::MuxChannel::OnIncomingPacket(
} }
} }
void ChannelMultiplexer::MuxChannel::OnWriteFailed() { void ChannelMultiplexer::MuxChannel::OnBaseChannelError(int error) {
if (socket_) if (socket_)
socket_->OnWriteFailed(); socket_->OnBaseChannelError(error);
} }
void ChannelMultiplexer::MuxChannel::OnSocketDestroyed() { void ChannelMultiplexer::MuxChannel::OnSocketDestroyed() {
...@@ -276,6 +279,9 @@ int ChannelMultiplexer::MuxSocket::Read( ...@@ -276,6 +279,9 @@ int ChannelMultiplexer::MuxSocket::Read(
DCHECK(CalledOnValidThread()); DCHECK(CalledOnValidThread());
DCHECK(read_callback_.is_null()); DCHECK(read_callback_.is_null());
if (base_channel_error_ != net::OK)
return base_channel_error_;
int result = channel_->DoRead(buffer, buffer_len); int result = channel_->DoRead(buffer, buffer_len);
if (result == 0) { if (result == 0) {
read_buffer_ = buffer; read_buffer_ = buffer;
...@@ -290,6 +296,10 @@ int ChannelMultiplexer::MuxSocket::Write( ...@@ -290,6 +296,10 @@ int ChannelMultiplexer::MuxSocket::Write(
net::IOBuffer* buffer, int buffer_len, net::IOBuffer* buffer, int buffer_len,
const net::CompletionCallback& callback) { const net::CompletionCallback& callback) {
DCHECK(CalledOnValidThread()); DCHECK(CalledOnValidThread());
DCHECK(write_callback_.is_null());
if (base_channel_error_ != net::OK)
return base_channel_error_;
scoped_ptr<MultiplexPacket> packet(new MultiplexPacket()); scoped_ptr<MultiplexPacket> packet(new MultiplexPacket());
size_t size = std::min(kMaxPacketSize, buffer_len); size_t size = std::min(kMaxPacketSize, buffer_len);
...@@ -317,19 +327,28 @@ int ChannelMultiplexer::MuxSocket::Write( ...@@ -317,19 +327,28 @@ int ChannelMultiplexer::MuxSocket::Write(
void ChannelMultiplexer::MuxSocket::OnWriteComplete() { void ChannelMultiplexer::MuxSocket::OnWriteComplete() {
write_pending_ = false; write_pending_ = false;
if (!write_callback_.is_null()) { if (!write_callback_.is_null())
net::CompletionCallback cb; base::ResetAndReturn(&write_callback_).Run(write_result_);
std::swap(cb, write_callback_);
cb.Run(write_result_);
}
} }
void ChannelMultiplexer::MuxSocket::OnWriteFailed() { void ChannelMultiplexer::MuxSocket::OnBaseChannelError(int error) {
if (!write_callback_.is_null()) { base_channel_error_ = error;
net::CompletionCallback cb;
std::swap(cb, write_callback_); // Here only one of the read and write callbacks is called if both of them are
cb.Run(net::ERR_FAILED); // pending. Ideally both of them should be called in that case, but that would
// require the second one to be called asynchronously which would complicate
// this code. Channels handle read and write errors the same way (see
// ChannelDispatcherBase::OnReadWriteFailed) so calling only one of the
// callbacks is enough.
if (!read_callback_.is_null()) {
base::ResetAndReturn(&read_callback_).Run(error);
return;
} }
if (!write_callback_.is_null())
base::ResetAndReturn(&write_callback_).Run(error);
} }
void ChannelMultiplexer::MuxSocket::OnPacketReceived() { void ChannelMultiplexer::MuxSocket::OnPacketReceived() {
...@@ -337,9 +356,7 @@ void ChannelMultiplexer::MuxSocket::OnPacketReceived() { ...@@ -337,9 +356,7 @@ void ChannelMultiplexer::MuxSocket::OnPacketReceived() {
int result = channel_->DoRead(read_buffer_.get(), read_buffer_size_); int result = channel_->DoRead(read_buffer_.get(), read_buffer_size_);
read_buffer_ = nullptr; read_buffer_ = nullptr;
DCHECK_GT(result, 0); DCHECK_GT(result, 0);
net::CompletionCallback cb; base::ResetAndReturn(&read_callback_).Run(result);
std::swap(cb, read_callback_);
cb.Run(result);
} }
} }
...@@ -403,9 +420,11 @@ void ChannelMultiplexer::OnBaseChannelReady( ...@@ -403,9 +420,11 @@ void ChannelMultiplexer::OnBaseChannelReady(
if (base_channel_.get()) { if (base_channel_.get()) {
// Initialize reader and writer. // Initialize reader and writer.
reader_.StartReading(base_channel_.get()); reader_.StartReading(base_channel_.get(),
base::Bind(&ChannelMultiplexer::OnBaseChannelError,
base::Unretained(this)));
writer_.Init(base_channel_.get(), writer_.Init(base_channel_.get(),
base::Bind(&ChannelMultiplexer::OnWriteFailed, base::Bind(&ChannelMultiplexer::OnBaseChannelError,
base::Unretained(this))); base::Unretained(this)));
} }
...@@ -447,20 +466,21 @@ ChannelMultiplexer::MuxChannel* ChannelMultiplexer::GetOrCreateChannel( ...@@ -447,20 +466,21 @@ ChannelMultiplexer::MuxChannel* ChannelMultiplexer::GetOrCreateChannel(
} }
void ChannelMultiplexer::OnWriteFailed(int error) { void ChannelMultiplexer::OnBaseChannelError(int error) {
for (std::map<std::string, MuxChannel*>::iterator it = channels_.begin(); for (std::map<std::string, MuxChannel*>::iterator it = channels_.begin();
it != channels_.end(); ++it) { it != channels_.end(); ++it) {
base::ThreadTaskRunnerHandle::Get()->PostTask( base::ThreadTaskRunnerHandle::Get()->PostTask(
FROM_HERE, base::Bind(&ChannelMultiplexer::NotifyWriteFailed, FROM_HERE,
weak_factory_.GetWeakPtr(), it->second->name())); base::Bind(&ChannelMultiplexer::NotifyBaseChannelError,
weak_factory_.GetWeakPtr(), it->second->name(), error));
} }
} }
void ChannelMultiplexer::NotifyWriteFailed(const std::string& name) { void ChannelMultiplexer::NotifyBaseChannelError(const std::string& name,
int error) {
std::map<std::string, MuxChannel*>::iterator it = channels_.find(name); std::map<std::string, MuxChannel*>::iterator it = channels_.find(name);
if (it != channels_.end()) { if (it != channels_.end())
it->second->OnWriteFailed(); it->second->OnBaseChannelError(error);
}
} }
void ChannelMultiplexer::OnIncomingPacket(scoped_ptr<MultiplexPacket> packet, void ChannelMultiplexer::OnIncomingPacket(scoped_ptr<MultiplexPacket> packet,
......
...@@ -44,11 +44,12 @@ class ChannelMultiplexer : public StreamChannelFactory { ...@@ -44,11 +44,12 @@ class ChannelMultiplexer : public StreamChannelFactory {
// Helper method used to create channels. // Helper method used to create channels.
MuxChannel* GetOrCreateChannel(const std::string& name); MuxChannel* GetOrCreateChannel(const std::string& name);
// Error handling callback for |writer_|. // Error handling callback for |reader_| and |writer_|.
void OnWriteFailed(int error); void OnBaseChannelError(int error);
// Failed write notifier, queued asynchronously by OnWriteFailed(). // Propagates base channel error to channel |name|, queued asynchronously by
void NotifyWriteFailed(const std::string& name); // OnBaseChannelError().
void NotifyBaseChannelError(const std::string& name, int error);
// Callback for |reader_; // Callback for |reader_;
void OnIncomingPacket(scoped_ptr<MultiplexPacket> packet, void OnIncomingPacket(scoped_ptr<MultiplexPacket> packet,
......
...@@ -36,6 +36,7 @@ class ClientVideoDispatcherTest : public testing::Test, ...@@ -36,6 +36,7 @@ class ClientVideoDispatcherTest : public testing::Test,
protected: protected:
void OnVideoAck(scoped_ptr<VideoAck> ack, const base::Closure& done); void OnVideoAck(scoped_ptr<VideoAck> ack, const base::Closure& done);
void OnReadError(int error);
base::MessageLoop message_loop_; base::MessageLoop message_loop_;
...@@ -72,7 +73,9 @@ ClientVideoDispatcherTest::ClientVideoDispatcherTest() ...@@ -72,7 +73,9 @@ ClientVideoDispatcherTest::ClientVideoDispatcherTest()
DCHECK(initialized_); DCHECK(initialized_);
host_socket_.PairWith( host_socket_.PairWith(
session_.fake_channel_factory().GetFakeChannel(kVideoChannelName)); session_.fake_channel_factory().GetFakeChannel(kVideoChannelName));
reader_.StartReading(&host_socket_); reader_.StartReading(&host_socket_,
base::Bind(&ClientVideoDispatcherTest::OnReadError,
base::Unretained(this)));
writer_.Init(&host_socket_, BufferedSocketWriter::WriteFailedCallback()); writer_.Init(&host_socket_, BufferedSocketWriter::WriteFailedCallback());
} }
...@@ -101,6 +104,10 @@ void ClientVideoDispatcherTest::OnVideoAck(scoped_ptr<VideoAck> ack, ...@@ -101,6 +104,10 @@ void ClientVideoDispatcherTest::OnVideoAck(scoped_ptr<VideoAck> ack,
done.Run(); done.Run();
} }
void ClientVideoDispatcherTest::OnReadError(int error) {
LOG(FATAL) << "Unexpected read error: " << error;
}
// Verify that the client can receive video packets and acks are not sent for // Verify that the client can receive video packets and acks are not sent for
// VideoPackets that don't have frame_id field set. // VideoPackets that don't have frame_id field set.
TEST_F(ClientVideoDispatcherTest, WithoutAcks) { TEST_F(ClientVideoDispatcherTest, WithoutAcks) {
......
...@@ -38,10 +38,15 @@ void MessageReader::SetMessageReceivedCallback( ...@@ -38,10 +38,15 @@ void MessageReader::SetMessageReceivedCallback(
message_received_callback_ = callback; message_received_callback_ = callback;
} }
void MessageReader::StartReading(net::Socket* socket) { void MessageReader::StartReading(
net::Socket* socket,
const ReadFailedCallback& read_failed_callback) {
DCHECK(CalledOnValidThread()); DCHECK(CalledOnValidThread());
DCHECK(socket); DCHECK(socket);
DCHECK(!read_failed_callback.is_null());
socket_ = socket; socket_ = socket;
read_failed_callback_ = read_failed_callback;
DoRead(); DoRead();
} }
...@@ -49,13 +54,16 @@ void MessageReader::DoRead() { ...@@ -49,13 +54,16 @@ void MessageReader::DoRead() {
DCHECK(CalledOnValidThread()); DCHECK(CalledOnValidThread());
// Don't try to read again if there is another read pending or we // Don't try to read again if there is another read pending or we
// have messages that we haven't finished processing yet. // have messages that we haven't finished processing yet.
while (!closed_ && !read_pending_ && pending_messages_ == 0) { bool read_succeeded = true;
while (read_succeeded && !closed_ && !read_pending_ &&
pending_messages_ == 0) {
read_buffer_ = new net::IOBuffer(kReadBufferSize); read_buffer_ = new net::IOBuffer(kReadBufferSize);
int result = socket_->Read( int result = socket_->Read(
read_buffer_.get(), read_buffer_.get(),
kReadBufferSize, kReadBufferSize,
base::Bind(&MessageReader::OnRead, weak_factory_.GetWeakPtr())); base::Bind(&MessageReader::OnRead, weak_factory_.GetWeakPtr()));
HandleReadResult(result);
HandleReadResult(result, &read_succeeded);
} }
} }
...@@ -65,26 +73,34 @@ void MessageReader::OnRead(int result) { ...@@ -65,26 +73,34 @@ void MessageReader::OnRead(int result) {
read_pending_ = false; read_pending_ = false;
if (!closed_) { if (!closed_) {
HandleReadResult(result); bool read_succeeded;
DoRead(); HandleReadResult(result, &read_succeeded);
if (read_succeeded)
DoRead();
} }
} }
void MessageReader::HandleReadResult(int result) { void MessageReader::HandleReadResult(int result, bool* read_succeeded) {
DCHECK(CalledOnValidThread()); DCHECK(CalledOnValidThread());
if (closed_) if (closed_)
return; return;
*read_succeeded = true;
if (result > 0) { if (result > 0) {
OnDataReceived(read_buffer_.get(), result); OnDataReceived(read_buffer_.get(), result);
*read_succeeded = true;
} else if (result == net::ERR_IO_PENDING) { } else if (result == net::ERR_IO_PENDING) {
read_pending_ = true; read_pending_ = true;
} else { } else {
if (result != net::ERR_CONNECTION_CLOSED) { DCHECK_LT(result, 0);
LOG(ERROR) << "Read() returned error " << result;
}
// Stop reading after any error. // Stop reading after any error.
closed_ = true; closed_ = true;
*read_succeeded = false;
LOG(ERROR) << "Read() returned error " << result;
read_failed_callback_.Run(result);
} }
} }
......
...@@ -35,6 +35,7 @@ class MessageReader : public base::NonThreadSafe { ...@@ -35,6 +35,7 @@ class MessageReader : public base::NonThreadSafe {
public: public:
typedef base::Callback<void(scoped_ptr<CompoundBuffer>, const base::Closure&)> typedef base::Callback<void(scoped_ptr<CompoundBuffer>, const base::Closure&)>
MessageReceivedCallback; MessageReceivedCallback;
typedef base::Callback<void(int)> ReadFailedCallback;
MessageReader(); MessageReader();
virtual ~MessageReader(); virtual ~MessageReader();
...@@ -43,16 +44,19 @@ class MessageReader : public base::NonThreadSafe { ...@@ -43,16 +44,19 @@ class MessageReader : public base::NonThreadSafe {
void SetMessageReceivedCallback(const MessageReceivedCallback& callback); void SetMessageReceivedCallback(const MessageReceivedCallback& callback);
// Starts reading from |socket|. // Starts reading from |socket|.
void StartReading(net::Socket* socket); void StartReading(net::Socket* socket,
const ReadFailedCallback& read_failed_callback);
private: private:
void DoRead(); void DoRead();
void OnRead(int result); void OnRead(int result);
void HandleReadResult(int result); void HandleReadResult(int result, bool* read_succeeded);
void OnDataReceived(net::IOBuffer* data, int data_size); void OnDataReceived(net::IOBuffer* data, int data_size);
void RunCallback(scoped_ptr<CompoundBuffer> message); void RunCallback(scoped_ptr<CompoundBuffer> message);
void OnMessageDone(); void OnMessageDone();
ReadFailedCallback read_failed_callback_;
net::Socket* socket_; net::Socket* socket_;
// Set to true, when we have a socket read pending, and expecting // Set to true, when we have a socket read pending, and expecting
......
...@@ -76,7 +76,8 @@ class MessageReaderTest : public testing::Test { ...@@ -76,7 +76,8 @@ class MessageReaderTest : public testing::Test {
void InitReader() { void InitReader() {
reader_->SetMessageReceivedCallback( reader_->SetMessageReceivedCallback(
base::Bind(&MessageReaderTest::OnMessage, base::Unretained(this))); base::Bind(&MessageReaderTest::OnMessage, base::Unretained(this)));
reader_->StartReading(&socket_); reader_->StartReading(&socket_, base::Bind(&MessageReaderTest::OnReadError,
base::Unretained(this)));
} }
void AddMessage(const std::string& message) { void AddMessage(const std::string& message) {
...@@ -92,6 +93,11 @@ class MessageReaderTest : public testing::Test { ...@@ -92,6 +93,11 @@ class MessageReaderTest : public testing::Test {
return result == expected; return result == expected;
} }
void OnReadError(int error) {
read_error_ = error;
reader_.reset();
}
void OnMessage(scoped_ptr<CompoundBuffer> buffer, void OnMessage(scoped_ptr<CompoundBuffer> buffer,
const base::Closure& done_callback) { const base::Closure& done_callback) {
messages_.push_back(buffer.release()); messages_.push_back(buffer.release());
...@@ -102,6 +108,7 @@ class MessageReaderTest : public testing::Test { ...@@ -102,6 +108,7 @@ class MessageReaderTest : public testing::Test {
scoped_ptr<MessageReader> reader_; scoped_ptr<MessageReader> reader_;
FakeStreamSocket socket_; FakeStreamSocket socket_;
MockMessageReceivedCallback callback_; MockMessageReceivedCallback callback_;
int read_error_ = 0;
std::vector<CompoundBuffer*> messages_; std::vector<CompoundBuffer*> messages_;
bool in_callback_; bool in_callback_;
}; };
...@@ -281,13 +288,12 @@ TEST_F(MessageReaderTest, TwoMessages_Separately) { ...@@ -281,13 +288,12 @@ TEST_F(MessageReaderTest, TwoMessages_Separately) {
TEST_F(MessageReaderTest, ReadError) { TEST_F(MessageReaderTest, ReadError) {
socket_.AppendReadError(net::ERR_FAILED); socket_.AppendReadError(net::ERR_FAILED);
// Add a message. It should never be read after the error above. EXPECT_CALL(callback_, OnMessage(_)).Times(0);
AddMessage(kTestMessage1);
EXPECT_CALL(callback_, OnMessage(_))
.Times(0);
InitReader(); InitReader();
EXPECT_EQ(net::ERR_FAILED, read_error_);
EXPECT_FALSE(reader_);
} }
// Verify that we the OnMessage callback is not reentered. // Verify that we the OnMessage callback is not reentered.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment