Commit dbc13c5e authored by Yuwei Huang's avatar Yuwei Huang Committed by Commit Bot

[remoting][gRPC] Prevent task from running after scoped stream is deleted

Previously deleting the scoped stream will only prevent new tasks from
being scheduled, and it doesn't drop pending tasks. This is a potential
issue if the caller expect its method not to be called after the scoped
stream is deleted.

This CL fixes this by making the stream request check the scoped
stream's validity before running the task.

Bug: 927962
Change-Id: I6a43b21772b89231ee89a8def68ba40c2b50c821
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/1574893Reviewed-by: default avatarJoe Downing <joedow@chromium.org>
Commit-Queue: Yuwei Huang <yuweih@chromium.org>
Cr-Commit-Position: refs/heads/master@{#652592}
parent c53f440c
...@@ -8,6 +8,18 @@ ...@@ -8,6 +8,18 @@
namespace remoting { namespace remoting {
namespace {
void RunTaskIfScopedStreamIsAlive(
base::WeakPtr<ScopedGrpcServerStream> scoped_stream,
base::OnceClosure task) {
if (scoped_stream) {
std::move(task).Run();
}
}
} // namespace
GrpcAsyncServerStreamingRequestBase::GrpcAsyncServerStreamingRequestBase( GrpcAsyncServerStreamingRequestBase::GrpcAsyncServerStreamingRequestBase(
std::unique_ptr<grpc::ClientContext> context, std::unique_ptr<grpc::ClientContext> context,
base::OnceCallback<void(const grpc::Status&)> on_channel_closed, base::OnceCallback<void(const grpc::Status&)> on_channel_closed,
...@@ -18,11 +30,19 @@ GrpcAsyncServerStreamingRequestBase::GrpcAsyncServerStreamingRequestBase( ...@@ -18,11 +30,19 @@ GrpcAsyncServerStreamingRequestBase::GrpcAsyncServerStreamingRequestBase(
on_channel_closed_ = std::move(on_channel_closed); on_channel_closed_ = std::move(on_channel_closed);
*scoped_stream = *scoped_stream =
std::make_unique<ScopedGrpcServerStream>(weak_factory_.GetWeakPtr()); std::make_unique<ScopedGrpcServerStream>(weak_factory_.GetWeakPtr());
scoped_stream_ = (*scoped_stream)->GetWeakPtr();
} }
GrpcAsyncServerStreamingRequestBase::~GrpcAsyncServerStreamingRequestBase() = GrpcAsyncServerStreamingRequestBase::~GrpcAsyncServerStreamingRequestBase() =
default; default;
void GrpcAsyncServerStreamingRequestBase::RunTask(base::OnceClosure task) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
DCHECK(run_task_callback_);
run_task_callback_.Run(base::BindOnce(&RunTaskIfScopedStreamIsAlive,
scoped_stream_, std::move(task)));
}
bool GrpcAsyncServerStreamingRequestBase::OnDequeue(bool operation_succeeded) { bool GrpcAsyncServerStreamingRequestBase::OnDequeue(bool operation_succeeded) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
if (state_ == State::CLOSED) { if (state_ == State::CLOSED) {
...@@ -45,10 +65,13 @@ bool GrpcAsyncServerStreamingRequestBase::OnDequeue(bool operation_succeeded) { ...@@ -45,10 +65,13 @@ bool GrpcAsyncServerStreamingRequestBase::OnDequeue(bool operation_succeeded) {
state_ = State::STREAMING; state_ = State::STREAMING;
return true; return true;
} }
DCHECK_EQ(State::STREAMING, state_); if (state_ == State::STREAMING) {
VLOG(1) << "Streaming call received message: " << this; VLOG(1) << "Streaming call received message: " << this;
ResolveIncomingMessage(); ResolveIncomingMessage();
return true; return true;
}
NOTREACHED();
return false;
} }
void GrpcAsyncServerStreamingRequestBase::Reenqueue(void* event_tag) { void GrpcAsyncServerStreamingRequestBase::Reenqueue(void* event_tag) {
...@@ -68,9 +91,6 @@ void GrpcAsyncServerStreamingRequestBase::Reenqueue(void* event_tag) { ...@@ -68,9 +91,6 @@ void GrpcAsyncServerStreamingRequestBase::Reenqueue(void* event_tag) {
void GrpcAsyncServerStreamingRequestBase::OnRequestCanceled() { void GrpcAsyncServerStreamingRequestBase::OnRequestCanceled() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
if (state_ == State::CLOSED) {
return;
}
state_ = State::CLOSED; state_ = State::CLOSED;
status_ = grpc::Status::CANCELLED; status_ = grpc::Status::CANCELLED;
weak_factory_.InvalidateWeakPtrs(); weak_factory_.InvalidateWeakPtrs();
...@@ -83,9 +103,7 @@ bool GrpcAsyncServerStreamingRequestBase::CanStartRequest() const { ...@@ -83,9 +103,7 @@ bool GrpcAsyncServerStreamingRequestBase::CanStartRequest() const {
void GrpcAsyncServerStreamingRequestBase::ResolveChannelClosed() { void GrpcAsyncServerStreamingRequestBase::ResolveChannelClosed() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
DCHECK(run_task_callback_); RunTask(base::BindOnce(std::move(on_channel_closed_), status_));
run_task_callback_.Run(
base::BindOnce(std::move(on_channel_closed_), status_));
} }
} // namespace remoting } // namespace remoting
...@@ -49,12 +49,18 @@ class GrpcAsyncServerStreamingRequestBase : public GrpcAsyncRequest { ...@@ -49,12 +49,18 @@ class GrpcAsyncServerStreamingRequestBase : public GrpcAsyncRequest {
CLOSED, CLOSED,
}; };
void set_run_task_callback(const RunTaskCallback& callback) {
run_task_callback_ = callback;
}
// Schedules a task with |run_task_callback_|. Drops it if the scoped stream
// has been deleted right before it is being executed.
void RunTask(base::OnceClosure task);
virtual void ResolveIncomingMessage() = 0; virtual void ResolveIncomingMessage() = 0;
virtual void WaitForIncomingMessage(void* event_tag) = 0; virtual void WaitForIncomingMessage(void* event_tag) = 0;
virtual void FinishStream(void* event_tag) = 0; virtual void FinishStream(void* event_tag) = 0;
RunTaskCallback run_task_callback_;
private: private:
// GrpcAsyncRequest implementations. // GrpcAsyncRequest implementations.
bool OnDequeue(bool operation_succeeded) override; bool OnDequeue(bool operation_succeeded) override;
...@@ -66,6 +72,9 @@ class GrpcAsyncServerStreamingRequestBase : public GrpcAsyncRequest { ...@@ -66,6 +72,9 @@ class GrpcAsyncServerStreamingRequestBase : public GrpcAsyncRequest {
base::OnceCallback<void(const grpc::Status&)> on_channel_closed_; base::OnceCallback<void(const grpc::Status&)> on_channel_closed_;
State state_ = State::STARTING; State state_ = State::STARTING;
RunTaskCallback run_task_callback_;
base::WeakPtr<ScopedGrpcServerStream> scoped_stream_;
SEQUENCE_CHECKER(sequence_checker_); SEQUENCE_CHECKER(sequence_checker_);
base::WeakPtrFactory<GrpcAsyncServerStreamingRequestBase> weak_factory_; base::WeakPtrFactory<GrpcAsyncServerStreamingRequestBase> weak_factory_;
...@@ -104,13 +113,12 @@ class GrpcAsyncServerStreamingRequest ...@@ -104,13 +113,12 @@ class GrpcAsyncServerStreamingRequest
grpc::CompletionQueue* cq, grpc::CompletionQueue* cq,
void* event_tag) override { void* event_tag) override {
reader_ = std::move(create_reader_callback_).Run(cq, event_tag); reader_ = std::move(create_reader_callback_).Run(cq, event_tag);
run_task_callback_ = run_task_cb; set_run_task_callback(run_task_cb);
} }
// GrpcAsyncServerStreamingRequestBase implementations. // GrpcAsyncServerStreamingRequestBase implementations.
void ResolveIncomingMessage() override { void ResolveIncomingMessage() override {
DCHECK(run_task_callback_); RunTask(base::BindOnce(on_incoming_msg_, response_));
run_task_callback_.Run(base::BindOnce(on_incoming_msg_, response_));
} }
void WaitForIncomingMessage(void* event_tag) override { void WaitForIncomingMessage(void* event_tag) override {
......
...@@ -10,7 +10,7 @@ namespace remoting { ...@@ -10,7 +10,7 @@ namespace remoting {
ScopedGrpcServerStream::ScopedGrpcServerStream( ScopedGrpcServerStream::ScopedGrpcServerStream(
base::WeakPtr<GrpcAsyncServerStreamingRequestBase> request) base::WeakPtr<GrpcAsyncServerStreamingRequestBase> request)
: request_(request) {} : request_(request), weak_factory_(this) {}
ScopedGrpcServerStream::~ScopedGrpcServerStream() { ScopedGrpcServerStream::~ScopedGrpcServerStream() {
if (request_) { if (request_) {
...@@ -18,4 +18,8 @@ ScopedGrpcServerStream::~ScopedGrpcServerStream() { ...@@ -18,4 +18,8 @@ ScopedGrpcServerStream::~ScopedGrpcServerStream() {
} }
} }
base::WeakPtr<ScopedGrpcServerStream> ScopedGrpcServerStream::GetWeakPtr() {
return weak_factory_.GetWeakPtr();
}
} // namespace remoting } // namespace remoting
...@@ -20,8 +20,11 @@ class ScopedGrpcServerStream { ...@@ -20,8 +20,11 @@ class ScopedGrpcServerStream {
base::WeakPtr<GrpcAsyncServerStreamingRequestBase> request); base::WeakPtr<GrpcAsyncServerStreamingRequestBase> request);
virtual ~ScopedGrpcServerStream(); virtual ~ScopedGrpcServerStream();
base::WeakPtr<ScopedGrpcServerStream> GetWeakPtr();
private: private:
base::WeakPtr<GrpcAsyncServerStreamingRequestBase> request_; base::WeakPtr<GrpcAsyncServerStreamingRequestBase> request_;
base::WeakPtrFactory<ScopedGrpcServerStream> weak_factory_;
DISALLOW_COPY_AND_ASSIGN(ScopedGrpcServerStream); DISALLOW_COPY_AND_ASSIGN(ScopedGrpcServerStream);
}; };
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment