Commit 3e64a8a0 authored by Wez's avatar Wez Committed by Commit Bot

Reland "Delay Channel::OnError() in case of kDisconnected during Write()."

This is a reland of e9605182

The original CL added a unit-test which created a single-byte
Channel::Message, without actually initializing that single-byte. This
caused the MSAN bots to (correctly) spot that uninitialized data was
being read during serialization.

Original change's description:
> Delay Channel::OnError() in case of kDisconnected during Write().
>
> Write() operations to a Channel can fail due to the peer having closed
> it, while there are still messages waiting to be read from it. We must
> therefore defer notifying the caller of the Channel::Error until we
> observe end-of-stream via a readable notification, otherwise those
> messages may be dropped (depending on whether the posted OnError task
> is processed before or after a pending Channel-readable event).
>
> Bug: 816620
> Change-Id: I75bd34a48edf4022809d27ce49f9cfba7a5d4daf
> Reviewed-on: https://chromium-review.googlesource.com/956932
> Commit-Queue: Wez <wez@chromium.org>
> Reviewed-by: Ken Rockot <rockot@chromium.org>
> Cr-Commit-Position: refs/heads/master@{#542634}

TBR: rockot
Bug: 816620
Change-Id: I1a1d6eb7fa712e50b3d9c86591878900f0aeb388
Reviewed-on: https://chromium-review.googlesource.com/959762Reviewed-by: default avatarWez <wez@chromium.org>
Commit-Queue: Wez <wez@chromium.org>
Cr-Commit-Position: refs/heads/master@{#542739}
parent f0b2b594
...@@ -201,14 +201,14 @@ class ChannelFuchsia : public Channel, ...@@ -201,14 +201,14 @@ class ChannelFuchsia : public Channel,
StartOnIOThread(); StartOnIOThread();
} else { } else {
io_task_runner_->PostTask( io_task_runner_->PostTask(
FROM_HERE, base::Bind(&ChannelFuchsia::StartOnIOThread, this)); FROM_HERE, base::BindOnce(&ChannelFuchsia::StartOnIOThread, this));
} }
} }
void ShutDownImpl() override { void ShutDownImpl() override {
// Always shut down asynchronously when called through the public interface. // Always shut down asynchronously when called through the public interface.
io_task_runner_->PostTask( io_task_runner_->PostTask(
FROM_HERE, base::Bind(&ChannelFuchsia::ShutDownOnIOThread, this)); FROM_HERE, base::BindOnce(&ChannelFuchsia::ShutDownOnIOThread, this));
} }
void Write(MessagePtr message) override { void Write(MessagePtr message) override {
...@@ -221,11 +221,11 @@ class ChannelFuchsia : public Channel, ...@@ -221,11 +221,11 @@ class ChannelFuchsia : public Channel,
reject_writes_ = write_error = true; reject_writes_ = write_error = true;
} }
if (write_error) { if (write_error) {
// Do not synchronously invoke OnError(). Write() may have been called by // Do not synchronously invoke OnWriteError(). Write() may have been
// the delegate and we don't want to re-enter it. // called by the delegate and we don't want to re-enter it.
io_task_runner_->PostTask( io_task_runner_->PostTask(
FROM_HERE, FROM_HERE, base::BindOnce(&ChannelFuchsia::OnWriteError, this,
base::Bind(&ChannelFuchsia::OnError, this, Error::kDisconnected)); Error::kDisconnected));
} }
} }
...@@ -410,6 +410,24 @@ class ChannelFuchsia : public Channel, ...@@ -410,6 +410,24 @@ class ChannelFuchsia : public Channel,
return true; return true;
} }
void OnWriteError(Error error) {
DCHECK(io_task_runner_->RunsTasksInCurrentSequence());
DCHECK(reject_writes_);
if (error == Error::kDisconnected) {
// If we can't write because the pipe is disconnected then continue
// reading to fetch any in-flight messages, relying on end-of-stream to
// signal the actual disconnection.
if (read_watch_) {
// TODO: When we add flow-control for writes, we also need to reset the
// write-watcher here.
return;
}
}
OnError(error);
}
// Keeps the Channel alive at least until explicit shutdown on the IO thread. // Keeps the Channel alive at least until explicit shutdown on the IO thread.
scoped_refptr<Channel> self_; scoped_refptr<Channel> self_;
......
...@@ -105,14 +105,14 @@ class ChannelPosix : public Channel, ...@@ -105,14 +105,14 @@ class ChannelPosix : public Channel,
StartOnIOThread(); StartOnIOThread();
} else { } else {
io_task_runner_->PostTask( io_task_runner_->PostTask(
FROM_HERE, base::Bind(&ChannelPosix::StartOnIOThread, this)); FROM_HERE, base::BindOnce(&ChannelPosix::StartOnIOThread, this));
} }
} }
void ShutDownImpl() override { void ShutDownImpl() override {
// Always shut down asynchronously when called through the public interface. // Always shut down asynchronously when called through the public interface.
io_task_runner_->PostTask( io_task_runner_->PostTask(
FROM_HERE, base::Bind(&ChannelPosix::ShutDownOnIOThread, this)); FROM_HERE, base::BindOnce(&ChannelPosix::ShutDownOnIOThread, this));
} }
void Write(MessagePtr message) override { void Write(MessagePtr message) override {
...@@ -129,11 +129,11 @@ class ChannelPosix : public Channel, ...@@ -129,11 +129,11 @@ class ChannelPosix : public Channel,
} }
} }
if (write_error) { if (write_error) {
// Do not synchronously invoke OnError(). Write() may have been called by // Invoke OnWriteError() asynchronously on the IO thread, in case Write()
// the delegate and we don't want to re-enter it. // was called by the delegate, in which case we should not re-enter it.
io_task_runner_->PostTask( io_task_runner_->PostTask(
FROM_HERE, FROM_HERE, base::BindOnce(&ChannelPosix::OnWriteError, this,
base::Bind(&ChannelPosix::OnError, this, Error::kDisconnected)); Error::kDisconnected));
} }
} }
...@@ -243,7 +243,8 @@ class ChannelPosix : public Channel, ...@@ -243,7 +243,8 @@ class ChannelPosix : public Channel,
base::MessageLoopForIO::WATCH_WRITE, write_watcher_.get(), this); base::MessageLoopForIO::WATCH_WRITE, write_watcher_.get(), this);
} else { } else {
io_task_runner_->PostTask( io_task_runner_->PostTask(
FROM_HERE, base::Bind(&ChannelPosix::WaitForWriteOnIOThread, this)); FROM_HERE,
base::BindOnce(&ChannelPosix::WaitForWriteOnIOThread, this));
} }
} }
...@@ -340,7 +341,7 @@ class ChannelPosix : public Channel, ...@@ -340,7 +341,7 @@ class ChannelPosix : public Channel,
reject_writes_ = write_error = true; reject_writes_ = write_error = true;
} }
if (write_error) if (write_error)
OnError(Error::kDisconnected); OnWriteError(Error::kDisconnected);
} }
// Attempts to write a message directly to the channel. If the full message // Attempts to write a message directly to the channel. If the full message
...@@ -524,6 +525,23 @@ class ChannelPosix : public Channel, ...@@ -524,6 +525,23 @@ class ChannelPosix : public Channel,
} }
#endif // defined(OS_MACOSX) #endif // defined(OS_MACOSX)
void OnWriteError(Error error) {
DCHECK(io_task_runner_->RunsTasksInCurrentSequence());
DCHECK(reject_writes_);
if (error == Error::kDisconnected) {
// If we can't write because the pipe is disconnected then continue
// reading to fetch any in-flight messages, relying on end-of-stream to
// signal the actual disconnection.
if (read_watcher_) {
write_watcher_.reset();
return;
}
}
OnError(error);
}
// Keeps the Channel alive at least until explicit shutdown on the IO thread. // Keeps the Channel alive at least until explicit shutdown on the IO thread.
scoped_refptr<Channel> self_; scoped_refptr<Channel> self_;
......
...@@ -3,7 +3,11 @@ ...@@ -3,7 +3,11 @@
// found in the LICENSE file. // found in the LICENSE file.
#include "mojo/edk/system/channel.h" #include "mojo/edk/system/channel.h"
#include "base/bind.h"
#include "base/memory/ptr_util.h" #include "base/memory/ptr_util.h"
#include "base/threading/thread.h"
#include "mojo/edk/embedder/platform_channel_pair.h"
#include "testing/gmock/include/gmock/gmock.h" #include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h" #include "testing/gtest/include/gtest/gtest.h"
...@@ -172,6 +176,101 @@ TEST(ChannelTest, OnReadNonLegacyMessage) { ...@@ -172,6 +176,101 @@ TEST(ChannelTest, OnReadNonLegacyMessage) {
channel_delegate.GetReceivedPayloadSize()); channel_delegate.GetReceivedPayloadSize());
} }
class ChannelTestShutdownAndWriteDelegate : public Channel::Delegate {
public:
ChannelTestShutdownAndWriteDelegate(
ScopedPlatformHandle handle,
scoped_refptr<base::TaskRunner> task_runner,
scoped_refptr<Channel> client_channel,
std::unique_ptr<base::Thread> client_thread,
base::RepeatingClosure quit_closure)
: quit_closure_(std::move(quit_closure)),
client_channel_(std::move(client_channel)),
client_thread_(std::move(client_thread)) {
channel_ = Channel::Create(
this, ConnectionParams(TransportProtocol::kLegacy, std::move(handle)),
std::move(task_runner));
channel_->Start();
}
~ChannelTestShutdownAndWriteDelegate() override { channel_->ShutDown(); }
// Channel::Delegate implementation
void OnChannelMessage(const void* payload,
size_t payload_size,
std::vector<ScopedPlatformHandle> handles) override {
++message_count_;
// If |client_channel_| exists then close it and its thread.
if (client_channel_) {
// Write a fresh message, making our channel readable again.
Channel::MessagePtr message = CreateDefaultMessage(false);
client_thread_->task_runner()->PostTask(
FROM_HERE, base::BindOnce(&Channel::Write, client_channel_,
base::Passed(&message)));
// Close the channel and wait for it to shutdown.
client_channel_->ShutDown();
client_channel_ = nullptr;
client_thread_->Stop();
client_thread_ = nullptr;
}
// Write a message to the channel, to verify whether this triggers an
// OnChannelError callback before all messages were read.
Channel::MessagePtr message = CreateDefaultMessage(false);
channel_->Write(std::move(message));
}
void OnChannelError(Channel::Error error) override {
EXPECT_EQ(2, message_count_);
quit_closure_.Run();
}
base::RepeatingClosure quit_closure_;
int message_count_ = 0;
scoped_refptr<Channel> channel_;
scoped_refptr<Channel> client_channel_;
std::unique_ptr<base::Thread> client_thread_;
};
TEST(ChannelTest, PeerShutdownDuringRead) {
base::MessageLoop message_loop(base::MessageLoop::TYPE_IO);
PlatformChannelPair channel_pair;
// Create a "client" Channel with one end of the pipe, and Start() it.
std::unique_ptr<base::Thread> client_thread =
std::make_unique<base::Thread>("clientio_thread");
client_thread->StartWithOptions(
base::Thread::Options(base::MessageLoop::TYPE_IO, 0));
scoped_refptr<Channel> client_channel =
Channel::Create(nullptr,
ConnectionParams(TransportProtocol::kLegacy,
channel_pair.PassClientHandle()),
client_thread->task_runner());
client_channel->Start();
// On the "client" IO thread, create and write a message.
Channel::MessagePtr message = CreateDefaultMessage(false);
client_thread->task_runner()->PostTask(
FROM_HERE,
base::BindOnce(&Channel::Write, client_channel, base::Passed(&message)));
// Create a "server" Channel with the other end of the pipe, and process the
// messages from it. The |server_delegate| will ShutDown the client end of
// the pipe after the first message, and quit the RunLoop when OnChannelError
// is received.
base::RunLoop run_loop;
ChannelTestShutdownAndWriteDelegate server_delegate(
channel_pair.PassServerHandle(), message_loop.task_runner(),
std::move(client_channel), std::move(client_thread),
run_loop.QuitClosure());
run_loop.Run();
}
} // namespace } // namespace
} // namespace edk } // namespace edk
} // namespace mojo } // namespace mojo
...@@ -81,19 +81,17 @@ class ChannelWin : public Channel, ...@@ -81,19 +81,17 @@ class ChannelWin : public Channel,
handle_(std::move(handle)), handle_(std::move(handle)),
io_task_runner_(io_task_runner) { io_task_runner_(io_task_runner) {
CHECK(handle_.is_valid()); CHECK(handle_.is_valid());
wait_for_connect_ = handle_.get().needs_connection;
} }
void Start() override { void Start() override {
io_task_runner_->PostTask( io_task_runner_->PostTask(
FROM_HERE, base::Bind(&ChannelWin::StartOnIOThread, this)); FROM_HERE, base::BindOnce(&ChannelWin::StartOnIOThread, this));
} }
void ShutDownImpl() override { void ShutDownImpl() override {
// Always shut down asynchronously when called through the public interface. // Always shut down asynchronously when called through the public interface.
io_task_runner_->PostTask( io_task_runner_->PostTask(
FROM_HERE, base::Bind(&ChannelWin::ShutDownOnIOThread, this)); FROM_HERE, base::BindOnce(&ChannelWin::ShutDownOnIOThread, this));
} }
void Write(MessagePtr message) override { void Write(MessagePtr message) override {
...@@ -109,11 +107,11 @@ class ChannelWin : public Channel, ...@@ -109,11 +107,11 @@ class ChannelWin : public Channel,
reject_writes_ = write_error = true; reject_writes_ = write_error = true;
} }
if (write_error) { if (write_error) {
// Do not synchronously invoke OnError(). Write() may have been called by // Do not synchronously invoke OnWriteError(). Write() may have been
// the delegate and we don't want to re-enter it. // called by the delegate and we don't want to re-enter it.
io_task_runner_->PostTask( io_task_runner_->PostTask(FROM_HERE,
FROM_HERE, base::BindOnce(&ChannelWin::OnWriteError, this,
base::Bind(&ChannelWin::OnError, this, Error::kDisconnected)); Error::kDisconnected));
} }
} }
...@@ -153,7 +151,7 @@ class ChannelWin : public Channel, ...@@ -153,7 +151,7 @@ class ChannelWin : public Channel,
base::MessageLoopForIO::current()->RegisterIOHandler( base::MessageLoopForIO::current()->RegisterIOHandler(
handle_.get().handle, this); handle_.get().handle, this);
if (wait_for_connect_) { if (handle_.get().needs_connection) {
BOOL ok = ConnectNamedPipe(handle_.get().handle, BOOL ok = ConnectNamedPipe(handle_.get().handle,
&connect_context_.overlapped); &connect_context_.overlapped);
if (ok) { if (ok) {
...@@ -165,12 +163,12 @@ class ChannelWin : public Channel, ...@@ -165,12 +163,12 @@ class ChannelWin : public Channel,
const DWORD err = GetLastError(); const DWORD err = GetLastError();
switch (err) { switch (err) {
case ERROR_PIPE_CONNECTED: case ERROR_PIPE_CONNECTED:
wait_for_connect_ = false;
break; break;
case ERROR_IO_PENDING: case ERROR_IO_PENDING:
AddRef(); is_connect_pending_ = this;
return; return;
case ERROR_NO_DATA: case ERROR_NO_DATA:
default:
OnError(Error::kConnectionFailed); OnError(Error::kConnectionFailed);
return; return;
} }
...@@ -201,7 +199,7 @@ class ChannelWin : public Channel, ...@@ -201,7 +199,7 @@ class ChannelWin : public Channel,
ignore_result(handle_.release()); ignore_result(handle_.release());
handle_.reset(); handle_.reset();
// May destroy the |this| if it was the last reference. // Allow |this| to be destroyed as soon as no IO is pending.
self_ = nullptr; self_ = nullptr;
} }
...@@ -217,10 +215,13 @@ class ChannelWin : public Channel, ...@@ -217,10 +215,13 @@ class ChannelWin : public Channel,
DWORD bytes_transfered, DWORD bytes_transfered,
DWORD error) override { DWORD error) override {
if (error != ERROR_SUCCESS) { if (error != ERROR_SUCCESS) {
OnError(Error::kDisconnected); if (context == &write_context_)
OnWriteError(Error::kDisconnected);
else
OnError(Error::kDisconnected);
} else if (context == &connect_context_) { } else if (context == &connect_context_) {
DCHECK(wait_for_connect_); DCHECK(is_connect_pending_);
wait_for_connect_ = false; scoped_refptr<ChannelWin> self(std::move(is_connect_pending_));
ReadMore(0); ReadMore(0);
base::AutoLock lock(write_lock_); base::AutoLock lock(write_lock_);
...@@ -229,12 +230,14 @@ class ChannelWin : public Channel, ...@@ -229,12 +230,14 @@ class ChannelWin : public Channel,
WriteNextNoLock(); WriteNextNoLock();
} }
} else if (context == &read_context_) { } else if (context == &read_context_) {
scoped_refptr<ChannelWin> self(std::move(is_read_pending_));
OnReadDone(static_cast<size_t>(bytes_transfered)); OnReadDone(static_cast<size_t>(bytes_transfered));
} else { } else {
CHECK(context == &write_context_); CHECK(context == &write_context_);
scoped_refptr<ChannelWin> self(std::move(is_write_pending_));
OnWriteDone(static_cast<size_t>(bytes_transfered)); OnWriteDone(static_cast<size_t>(bytes_transfered));
} }
Release(); // Balancing reference taken after ReadFile / WriteFile. // |this| may have been deleted by the time we reach here.
} }
void OnReadDone(size_t bytes_read) { void OnReadDone(size_t bytes_read) {
...@@ -276,7 +279,7 @@ class ChannelWin : public Channel, ...@@ -276,7 +279,7 @@ class ChannelWin : public Channel,
reject_writes_ = write_error = true; reject_writes_ = write_error = true;
} }
if (write_error) if (write_error)
OnError(Error::kDisconnected); OnWriteError(Error::kDisconnected);
} }
void ReadMore(size_t next_read_size_hint) { void ReadMore(size_t next_read_size_hint) {
...@@ -291,7 +294,7 @@ class ChannelWin : public Channel, ...@@ -291,7 +294,7 @@ class ChannelWin : public Channel,
&read_context_.overlapped); &read_context_.overlapped);
if (ok || GetLastError() == ERROR_IO_PENDING) { if (ok || GetLastError() == ERROR_IO_PENDING) {
AddRef(); // Will be balanced in OnIOCompleted is_read_pending_ = this;
} else { } else {
OnError(Error::kDisconnected); OnError(Error::kDisconnected);
} }
...@@ -308,7 +311,7 @@ class ChannelWin : public Channel, ...@@ -308,7 +311,7 @@ class ChannelWin : public Channel,
&write_context_.overlapped); &write_context_.overlapped);
if (ok || GetLastError() == ERROR_IO_PENDING) { if (ok || GetLastError() == ERROR_IO_PENDING) {
AddRef(); // Will be balanced in OnIOCompleted. is_write_pending_ = this;
return true; return true;
} }
return false; return false;
...@@ -320,6 +323,21 @@ class ChannelWin : public Channel, ...@@ -320,6 +323,21 @@ class ChannelWin : public Channel,
return WriteNoLock(outgoing_messages_.front()); return WriteNoLock(outgoing_messages_.front());
} }
void OnWriteError(Error error) {
DCHECK(io_task_runner_->RunsTasksInCurrentSequence());
DCHECK(reject_writes_);
if (error == Error::kDisconnected) {
// If we can't write because the pipe is disconnected then continue
// reading to fetch any in-flight messages, relying on end-of-stream to
// signal the actual disconnection.
if (is_read_pending_ || is_connect_pending_)
return;
}
OnError(error);
}
// Keeps the Channel alive at least until explicit shutdown on the IO thread. // Keeps the Channel alive at least until explicit shutdown on the IO thread.
scoped_refptr<Channel> self_; scoped_refptr<Channel> self_;
...@@ -329,16 +347,15 @@ class ChannelWin : public Channel, ...@@ -329,16 +347,15 @@ class ChannelWin : public Channel,
base::MessageLoopForIO::IOContext connect_context_; base::MessageLoopForIO::IOContext connect_context_;
base::MessageLoopForIO::IOContext read_context_; base::MessageLoopForIO::IOContext read_context_;
base::MessageLoopForIO::IOContext write_context_; base::MessageLoopForIO::IOContext write_context_;
scoped_refptr<ChannelWin> is_connect_pending_;
scoped_refptr<ChannelWin> is_read_pending_;
scoped_refptr<ChannelWin> is_write_pending_;
// Protects |reject_writes_| and |outgoing_messages_|. // Protects |delay_writes_|, |reject_writes_| and |outgoing_messages_|.
base::Lock write_lock_; base::Lock write_lock_;
base::circular_deque<MessageView> outgoing_messages_;
bool delay_writes_ = true; bool delay_writes_ = true;
bool reject_writes_ = false; bool reject_writes_ = false;
base::circular_deque<MessageView> outgoing_messages_;
bool wait_for_connect_;
bool leak_handle_ = false; bool leak_handle_ = false;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment