Commit e9605182 authored by Wez's avatar Wez Committed by Commit Bot

Delay Channel::OnError() in case of kDisconnected during Write().

Write() operations to a Channel can fail due to the peer having closed
it, while there are still messages waiting to be read from it. We must
therefore defer notifying the caller of the Channel::Error until we
observe end-of-stream via a readable notification, otherwise those
messages may be dropped (depending on whether the posted OnError task
is processed before or after a pending Channel-readable event).

Bug: 816620
Change-Id: I75bd34a48edf4022809d27ce49f9cfba7a5d4daf
Reviewed-on: https://chromium-review.googlesource.com/956932
Commit-Queue: Wez <wez@chromium.org>
Reviewed-by: default avatarKen Rockot <rockot@chromium.org>
Cr-Commit-Position: refs/heads/master@{#542634}
parent 91d8b186
...@@ -201,14 +201,14 @@ class ChannelFuchsia : public Channel, ...@@ -201,14 +201,14 @@ class ChannelFuchsia : public Channel,
StartOnIOThread(); StartOnIOThread();
} else { } else {
io_task_runner_->PostTask( io_task_runner_->PostTask(
FROM_HERE, base::Bind(&ChannelFuchsia::StartOnIOThread, this)); FROM_HERE, base::BindOnce(&ChannelFuchsia::StartOnIOThread, this));
} }
} }
void ShutDownImpl() override { void ShutDownImpl() override {
// Always shut down asynchronously when called through the public interface. // Always shut down asynchronously when called through the public interface.
io_task_runner_->PostTask( io_task_runner_->PostTask(
FROM_HERE, base::Bind(&ChannelFuchsia::ShutDownOnIOThread, this)); FROM_HERE, base::BindOnce(&ChannelFuchsia::ShutDownOnIOThread, this));
} }
void Write(MessagePtr message) override { void Write(MessagePtr message) override {
...@@ -221,11 +221,11 @@ class ChannelFuchsia : public Channel, ...@@ -221,11 +221,11 @@ class ChannelFuchsia : public Channel,
reject_writes_ = write_error = true; reject_writes_ = write_error = true;
} }
if (write_error) { if (write_error) {
// Do not synchronously invoke OnError(). Write() may have been called by // Do not synchronously invoke OnWriteError(). Write() may have been
// the delegate and we don't want to re-enter it. // called by the delegate and we don't want to re-enter it.
io_task_runner_->PostTask( io_task_runner_->PostTask(
FROM_HERE, FROM_HERE, base::BindOnce(&ChannelFuchsia::OnWriteError, this,
base::Bind(&ChannelFuchsia::OnError, this, Error::kDisconnected)); Error::kDisconnected));
} }
} }
...@@ -410,6 +410,24 @@ class ChannelFuchsia : public Channel, ...@@ -410,6 +410,24 @@ class ChannelFuchsia : public Channel,
return true; return true;
} }
void OnWriteError(Error error) {
DCHECK(io_task_runner_->RunsTasksInCurrentSequence());
DCHECK(reject_writes_);
if (error == Error::kDisconnected) {
// If we can't write because the pipe is disconnected then continue
// reading to fetch any in-flight messages, relying on end-of-stream to
// signal the actual disconnection.
if (read_watch_) {
// TODO: When we add flow-control for writes, we also need to reset the
// write-watcher here.
return;
}
}
OnError(error);
}
// Keeps the Channel alive at least until explicit shutdown on the IO thread. // Keeps the Channel alive at least until explicit shutdown on the IO thread.
scoped_refptr<Channel> self_; scoped_refptr<Channel> self_;
......
...@@ -105,14 +105,14 @@ class ChannelPosix : public Channel, ...@@ -105,14 +105,14 @@ class ChannelPosix : public Channel,
StartOnIOThread(); StartOnIOThread();
} else { } else {
io_task_runner_->PostTask( io_task_runner_->PostTask(
FROM_HERE, base::Bind(&ChannelPosix::StartOnIOThread, this)); FROM_HERE, base::BindOnce(&ChannelPosix::StartOnIOThread, this));
} }
} }
void ShutDownImpl() override { void ShutDownImpl() override {
// Always shut down asynchronously when called through the public interface. // Always shut down asynchronously when called through the public interface.
io_task_runner_->PostTask( io_task_runner_->PostTask(
FROM_HERE, base::Bind(&ChannelPosix::ShutDownOnIOThread, this)); FROM_HERE, base::BindOnce(&ChannelPosix::ShutDownOnIOThread, this));
} }
void Write(MessagePtr message) override { void Write(MessagePtr message) override {
...@@ -129,11 +129,11 @@ class ChannelPosix : public Channel, ...@@ -129,11 +129,11 @@ class ChannelPosix : public Channel,
} }
} }
if (write_error) { if (write_error) {
// Do not synchronously invoke OnError(). Write() may have been called by // Invoke OnWriteError() asynchronously on the IO thread, in case Write()
// the delegate and we don't want to re-enter it. // was called by the delegate, in which case we should not re-enter it.
io_task_runner_->PostTask( io_task_runner_->PostTask(
FROM_HERE, FROM_HERE, base::BindOnce(&ChannelPosix::OnWriteError, this,
base::Bind(&ChannelPosix::OnError, this, Error::kDisconnected)); Error::kDisconnected));
} }
} }
...@@ -243,7 +243,8 @@ class ChannelPosix : public Channel, ...@@ -243,7 +243,8 @@ class ChannelPosix : public Channel,
base::MessageLoopForIO::WATCH_WRITE, write_watcher_.get(), this); base::MessageLoopForIO::WATCH_WRITE, write_watcher_.get(), this);
} else { } else {
io_task_runner_->PostTask( io_task_runner_->PostTask(
FROM_HERE, base::Bind(&ChannelPosix::WaitForWriteOnIOThread, this)); FROM_HERE,
base::BindOnce(&ChannelPosix::WaitForWriteOnIOThread, this));
} }
} }
...@@ -340,7 +341,7 @@ class ChannelPosix : public Channel, ...@@ -340,7 +341,7 @@ class ChannelPosix : public Channel,
reject_writes_ = write_error = true; reject_writes_ = write_error = true;
} }
if (write_error) if (write_error)
OnError(Error::kDisconnected); OnWriteError(Error::kDisconnected);
} }
// Attempts to write a message directly to the channel. If the full message // Attempts to write a message directly to the channel. If the full message
...@@ -524,6 +525,23 @@ class ChannelPosix : public Channel, ...@@ -524,6 +525,23 @@ class ChannelPosix : public Channel,
} }
#endif // defined(OS_MACOSX) #endif // defined(OS_MACOSX)
void OnWriteError(Error error) {
DCHECK(io_task_runner_->RunsTasksInCurrentSequence());
DCHECK(reject_writes_);
if (error == Error::kDisconnected) {
// If we can't write because the pipe is disconnected then continue
// reading to fetch any in-flight messages, relying on end-of-stream to
// signal the actual disconnection.
if (read_watcher_) {
write_watcher_.reset();
return;
}
}
OnError(error);
}
// Keeps the Channel alive at least until explicit shutdown on the IO thread. // Keeps the Channel alive at least until explicit shutdown on the IO thread.
scoped_refptr<Channel> self_; scoped_refptr<Channel> self_;
......
...@@ -3,7 +3,11 @@ ...@@ -3,7 +3,11 @@
// found in the LICENSE file. // found in the LICENSE file.
#include "mojo/edk/system/channel.h" #include "mojo/edk/system/channel.h"
#include "base/bind.h"
#include "base/memory/ptr_util.h" #include "base/memory/ptr_util.h"
#include "base/threading/thread.h"
#include "mojo/edk/embedder/platform_channel_pair.h"
#include "testing/gmock/include/gmock/gmock.h" #include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h" #include "testing/gtest/include/gtest/gtest.h"
...@@ -172,6 +176,101 @@ TEST(ChannelTest, OnReadNonLegacyMessage) { ...@@ -172,6 +176,101 @@ TEST(ChannelTest, OnReadNonLegacyMessage) {
channel_delegate.GetReceivedPayloadSize()); channel_delegate.GetReceivedPayloadSize());
} }
class ChannelTestShutdownAndWriteDelegate : public Channel::Delegate {
public:
ChannelTestShutdownAndWriteDelegate(
ScopedPlatformHandle handle,
scoped_refptr<base::TaskRunner> task_runner,
scoped_refptr<Channel> client_channel,
std::unique_ptr<base::Thread> client_thread,
base::RepeatingClosure quit_closure)
: quit_closure_(std::move(quit_closure)),
client_channel_(std::move(client_channel)),
client_thread_(std::move(client_thread)) {
channel_ = Channel::Create(
this, ConnectionParams(TransportProtocol::kLegacy, std::move(handle)),
std::move(task_runner));
channel_->Start();
}
~ChannelTestShutdownAndWriteDelegate() override { channel_->ShutDown(); }
// Channel::Delegate implementation
void OnChannelMessage(const void* payload,
size_t payload_size,
std::vector<ScopedPlatformHandle> handles) override {
++message_count_;
// If |client_channel_| exists then close it and its thread.
if (client_channel_) {
// Write a fresh message, making our channel readable again.
Channel::MessagePtr message = std::make_unique<Channel::Message>(1, 0);
client_thread_->task_runner()->PostTask(
FROM_HERE, base::BindOnce(&Channel::Write, client_channel_,
base::Passed(&message)));
// Close the channel and wait for it to shutdown.
client_channel_->ShutDown();
client_channel_ = nullptr;
client_thread_->Stop();
client_thread_ = nullptr;
}
// Write a message to the channel, to verify whether this triggers an
// OnChannelError callback before all messages were read.
Channel::MessagePtr message = std::make_unique<Channel::Message>(1, 0);
channel_->Write(std::move(message));
}
void OnChannelError(Channel::Error error) override {
EXPECT_EQ(2, message_count_);
quit_closure_.Run();
}
base::RepeatingClosure quit_closure_;
int message_count_ = 0;
scoped_refptr<Channel> channel_;
scoped_refptr<Channel> client_channel_;
std::unique_ptr<base::Thread> client_thread_;
};
TEST(ChannelTest, PeerShutdownDuringRead) {
base::MessageLoop message_loop(base::MessageLoop::TYPE_IO);
PlatformChannelPair channel_pair;
// Create a "client" Channel with one end of the pipe, and Start() it.
std::unique_ptr<base::Thread> client_thread =
std::make_unique<base::Thread>("clientio_thread");
client_thread->StartWithOptions(
base::Thread::Options(base::MessageLoop::TYPE_IO, 0));
scoped_refptr<Channel> client_channel =
Channel::Create(nullptr,
ConnectionParams(TransportProtocol::kLegacy,
channel_pair.PassClientHandle()),
client_thread->task_runner());
client_channel->Start();
// On the "client" IO thread, create and write a message.
Channel::MessagePtr message = std::make_unique<Channel::Message>(1, 0);
client_thread->task_runner()->PostTask(
FROM_HERE,
base::BindOnce(&Channel::Write, client_channel, base::Passed(&message)));
// Create a "server" Channel with the other end of the pipe, and process the
// messages from it. The |server_delegate| will ShutDown the client end of
// the pipe after the first message, and quit the RunLoop when OnChannelError
// is received.
base::RunLoop run_loop;
ChannelTestShutdownAndWriteDelegate server_delegate(
channel_pair.PassServerHandle(), message_loop.task_runner(),
std::move(client_channel), std::move(client_thread),
run_loop.QuitClosure());
run_loop.Run();
}
} // namespace } // namespace
} // namespace edk } // namespace edk
} // namespace mojo } // namespace mojo
...@@ -81,19 +81,17 @@ class ChannelWin : public Channel, ...@@ -81,19 +81,17 @@ class ChannelWin : public Channel,
handle_(std::move(handle)), handle_(std::move(handle)),
io_task_runner_(io_task_runner) { io_task_runner_(io_task_runner) {
CHECK(handle_.is_valid()); CHECK(handle_.is_valid());
wait_for_connect_ = handle_.get().needs_connection;
} }
void Start() override { void Start() override {
io_task_runner_->PostTask( io_task_runner_->PostTask(
FROM_HERE, base::Bind(&ChannelWin::StartOnIOThread, this)); FROM_HERE, base::BindOnce(&ChannelWin::StartOnIOThread, this));
} }
void ShutDownImpl() override { void ShutDownImpl() override {
// Always shut down asynchronously when called through the public interface. // Always shut down asynchronously when called through the public interface.
io_task_runner_->PostTask( io_task_runner_->PostTask(
FROM_HERE, base::Bind(&ChannelWin::ShutDownOnIOThread, this)); FROM_HERE, base::BindOnce(&ChannelWin::ShutDownOnIOThread, this));
} }
void Write(MessagePtr message) override { void Write(MessagePtr message) override {
...@@ -109,11 +107,11 @@ class ChannelWin : public Channel, ...@@ -109,11 +107,11 @@ class ChannelWin : public Channel,
reject_writes_ = write_error = true; reject_writes_ = write_error = true;
} }
if (write_error) { if (write_error) {
// Do not synchronously invoke OnError(). Write() may have been called by // Do not synchronously invoke OnWriteError(). Write() may have been
// the delegate and we don't want to re-enter it. // called by the delegate and we don't want to re-enter it.
io_task_runner_->PostTask( io_task_runner_->PostTask(FROM_HERE,
FROM_HERE, base::BindOnce(&ChannelWin::OnWriteError, this,
base::Bind(&ChannelWin::OnError, this, Error::kDisconnected)); Error::kDisconnected));
} }
} }
...@@ -153,7 +151,7 @@ class ChannelWin : public Channel, ...@@ -153,7 +151,7 @@ class ChannelWin : public Channel,
base::MessageLoopForIO::current()->RegisterIOHandler( base::MessageLoopForIO::current()->RegisterIOHandler(
handle_.get().handle, this); handle_.get().handle, this);
if (wait_for_connect_) { if (handle_.get().needs_connection) {
BOOL ok = ConnectNamedPipe(handle_.get().handle, BOOL ok = ConnectNamedPipe(handle_.get().handle,
&connect_context_.overlapped); &connect_context_.overlapped);
if (ok) { if (ok) {
...@@ -165,12 +163,12 @@ class ChannelWin : public Channel, ...@@ -165,12 +163,12 @@ class ChannelWin : public Channel,
const DWORD err = GetLastError(); const DWORD err = GetLastError();
switch (err) { switch (err) {
case ERROR_PIPE_CONNECTED: case ERROR_PIPE_CONNECTED:
wait_for_connect_ = false;
break; break;
case ERROR_IO_PENDING: case ERROR_IO_PENDING:
AddRef(); is_connect_pending_ = this;
return; return;
case ERROR_NO_DATA: case ERROR_NO_DATA:
default:
OnError(Error::kConnectionFailed); OnError(Error::kConnectionFailed);
return; return;
} }
...@@ -201,7 +199,7 @@ class ChannelWin : public Channel, ...@@ -201,7 +199,7 @@ class ChannelWin : public Channel,
ignore_result(handle_.release()); ignore_result(handle_.release());
handle_.reset(); handle_.reset();
// May destroy the |this| if it was the last reference. // Allow |this| to be destroyed as soon as no IO is pending.
self_ = nullptr; self_ = nullptr;
} }
...@@ -217,10 +215,13 @@ class ChannelWin : public Channel, ...@@ -217,10 +215,13 @@ class ChannelWin : public Channel,
DWORD bytes_transfered, DWORD bytes_transfered,
DWORD error) override { DWORD error) override {
if (error != ERROR_SUCCESS) { if (error != ERROR_SUCCESS) {
OnError(Error::kDisconnected); if (context == &write_context_)
OnWriteError(Error::kDisconnected);
else
OnError(Error::kDisconnected);
} else if (context == &connect_context_) { } else if (context == &connect_context_) {
DCHECK(wait_for_connect_); DCHECK(is_connect_pending_);
wait_for_connect_ = false; scoped_refptr<ChannelWin> self(std::move(is_connect_pending_));
ReadMore(0); ReadMore(0);
base::AutoLock lock(write_lock_); base::AutoLock lock(write_lock_);
...@@ -229,12 +230,14 @@ class ChannelWin : public Channel, ...@@ -229,12 +230,14 @@ class ChannelWin : public Channel,
WriteNextNoLock(); WriteNextNoLock();
} }
} else if (context == &read_context_) { } else if (context == &read_context_) {
scoped_refptr<ChannelWin> self(std::move(is_read_pending_));
OnReadDone(static_cast<size_t>(bytes_transfered)); OnReadDone(static_cast<size_t>(bytes_transfered));
} else { } else {
CHECK(context == &write_context_); CHECK(context == &write_context_);
scoped_refptr<ChannelWin> self(std::move(is_write_pending_));
OnWriteDone(static_cast<size_t>(bytes_transfered)); OnWriteDone(static_cast<size_t>(bytes_transfered));
} }
Release(); // Balancing reference taken after ReadFile / WriteFile. // |this| may have been deleted by the time we reach here.
} }
void OnReadDone(size_t bytes_read) { void OnReadDone(size_t bytes_read) {
...@@ -276,7 +279,7 @@ class ChannelWin : public Channel, ...@@ -276,7 +279,7 @@ class ChannelWin : public Channel,
reject_writes_ = write_error = true; reject_writes_ = write_error = true;
} }
if (write_error) if (write_error)
OnError(Error::kDisconnected); OnWriteError(Error::kDisconnected);
} }
void ReadMore(size_t next_read_size_hint) { void ReadMore(size_t next_read_size_hint) {
...@@ -291,7 +294,7 @@ class ChannelWin : public Channel, ...@@ -291,7 +294,7 @@ class ChannelWin : public Channel,
&read_context_.overlapped); &read_context_.overlapped);
if (ok || GetLastError() == ERROR_IO_PENDING) { if (ok || GetLastError() == ERROR_IO_PENDING) {
AddRef(); // Will be balanced in OnIOCompleted is_read_pending_ = this;
} else { } else {
OnError(Error::kDisconnected); OnError(Error::kDisconnected);
} }
...@@ -308,7 +311,7 @@ class ChannelWin : public Channel, ...@@ -308,7 +311,7 @@ class ChannelWin : public Channel,
&write_context_.overlapped); &write_context_.overlapped);
if (ok || GetLastError() == ERROR_IO_PENDING) { if (ok || GetLastError() == ERROR_IO_PENDING) {
AddRef(); // Will be balanced in OnIOCompleted. is_write_pending_ = this;
return true; return true;
} }
return false; return false;
...@@ -320,6 +323,21 @@ class ChannelWin : public Channel, ...@@ -320,6 +323,21 @@ class ChannelWin : public Channel,
return WriteNoLock(outgoing_messages_.front()); return WriteNoLock(outgoing_messages_.front());
} }
void OnWriteError(Error error) {
DCHECK(io_task_runner_->RunsTasksInCurrentSequence());
DCHECK(reject_writes_);
if (error == Error::kDisconnected) {
// If we can't write because the pipe is disconnected then continue
// reading to fetch any in-flight messages, relying on end-of-stream to
// signal the actual disconnection.
if (is_read_pending_ || is_connect_pending_)
return;
}
OnError(error);
}
// Keeps the Channel alive at least until explicit shutdown on the IO thread. // Keeps the Channel alive at least until explicit shutdown on the IO thread.
scoped_refptr<Channel> self_; scoped_refptr<Channel> self_;
...@@ -329,16 +347,15 @@ class ChannelWin : public Channel, ...@@ -329,16 +347,15 @@ class ChannelWin : public Channel,
base::MessageLoopForIO::IOContext connect_context_; base::MessageLoopForIO::IOContext connect_context_;
base::MessageLoopForIO::IOContext read_context_; base::MessageLoopForIO::IOContext read_context_;
base::MessageLoopForIO::IOContext write_context_; base::MessageLoopForIO::IOContext write_context_;
scoped_refptr<ChannelWin> is_connect_pending_;
scoped_refptr<ChannelWin> is_read_pending_;
scoped_refptr<ChannelWin> is_write_pending_;
// Protects |reject_writes_| and |outgoing_messages_|. // Protects |delay_writes_|, |reject_writes_| and |outgoing_messages_|.
base::Lock write_lock_; base::Lock write_lock_;
base::circular_deque<MessageView> outgoing_messages_;
bool delay_writes_ = true; bool delay_writes_ = true;
bool reject_writes_ = false; bool reject_writes_ = false;
base::circular_deque<MessageView> outgoing_messages_;
bool wait_for_connect_;
bool leak_handle_ = false; bool leak_handle_ = false;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment