Commit 8750ad24 authored by Andrey Kosyakov's avatar Andrey Kosyakov Committed by Commit Bot

DevTools: introduce CrossThreadProtocolCallback

... to assure DevTools protocol callbacks are always invoked (and destroyed) on the
correct thread.

Change-Id: Ie94c67d754138161ba22192bb76891f4021ad280
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/1775501
Commit-Queue: Andrey Kosyakov <caseq@chromium.org>
Reviewed-by: default avatarAlexei Filippov <alph@chromium.org>
Cr-Commit-Position: refs/heads/master@{#691793}
parent 2eb51578
...@@ -673,6 +673,7 @@ jumbo_source_set("browser") { ...@@ -673,6 +673,7 @@ jumbo_source_set("browser") {
"data_url_loader_factory.h", "data_url_loader_factory.h",
"devtools/browser_devtools_agent_host.cc", "devtools/browser_devtools_agent_host.cc",
"devtools/browser_devtools_agent_host.h", "devtools/browser_devtools_agent_host.h",
"devtools/cross_thread_protocol_callback.h",
"devtools/devtools_agent_host_impl.cc", "devtools/devtools_agent_host_impl.cc",
"devtools/devtools_agent_host_impl.h", "devtools/devtools_agent_host_impl.h",
"devtools/devtools_background_services_context_impl.cc", "devtools/devtools_background_services_context_impl.cc",
......
// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef CONTENT_BROWSER_DEVTOOLS_CROSS_THREAD_PROTOCOL_CALLBACK_H_
#define CONTENT_BROWSER_DEVTOOLS_CROSS_THREAD_PROTOCOL_CALLBACK_H_
#include <memory>
#include "base/task/post_task.h"
#include "content/public/browser/browser_thread.h"
namespace content {
// A wrapper for a DevTools protocol method callback that assures the
// underlying callback is called on the correct thread. Use this to pass
// the protocol callback to methods handling DevTools commands on threads
// other than UI.
template <typename ProtocolCallback>
class CrossThreadProtocolCallback {
public:
explicit CrossThreadProtocolCallback(
std::unique_ptr<ProtocolCallback> callback)
: callback_(std::move(callback)) {}
CrossThreadProtocolCallback(CrossThreadProtocolCallback&& r) = default;
template <typename... Args>
void sendSuccess(Args&&... args) {
base::PostTask(
FROM_HERE, {BrowserThread::UI},
base::BindOnce(&ProtocolCallback::sendSuccess, std::move(callback_),
std::forward<Args>(args)...));
}
void sendFailure(protocol::DispatchResponse response) {
base::PostTask(FROM_HERE, {BrowserThread::UI},
base::BindOnce(&ProtocolCallback::sendFailure,
std::move(callback_), std::move(response)));
}
~CrossThreadProtocolCallback() {
BrowserThread::DeleteSoon({BrowserThread::UI}, FROM_HERE,
std::move(callback_));
}
private:
std::unique_ptr<ProtocolCallback> callback_;
};
template <typename ProtocolCallback>
CrossThreadProtocolCallback<ProtocolCallback> WrapForAnotherThread(
std::unique_ptr<ProtocolCallback> callback) {
return CrossThreadProtocolCallback<ProtocolCallback>(std::move(callback));
}
} // namespace content
#endif // CONTENT_BROWSER_DEVTOOLS_CROSS_THREAD_PROTOCOL_CALLBACK_H_
\ No newline at end of file
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include "base/task/post_task.h" #include "base/task/post_task.h"
#include "base/time/time.h" #include "base/time/time.h"
#include "base/unguessable_token.h" #include "base/unguessable_token.h"
#include "content/browser/devtools/cross_thread_protocol_callback.h"
#include "content/browser/devtools/protocol/network.h" #include "content/browser/devtools/protocol/network.h"
#include "content/browser/devtools/protocol/network_handler.h" #include "content/browser/devtools/protocol/network_handler.h"
#include "content/browser/frame_host/frame_tree_node.h" #include "content/browser/frame_host/frame_tree_node.h"
...@@ -37,7 +38,7 @@ using RequestInterceptedCallback = ...@@ -37,7 +38,7 @@ using RequestInterceptedCallback =
DevToolsNetworkInterceptor::RequestInterceptedCallback; DevToolsNetworkInterceptor::RequestInterceptedCallback;
using ContinueInterceptedRequestCallback = using ContinueInterceptedRequestCallback =
DevToolsNetworkInterceptor::ContinueInterceptedRequestCallback; DevToolsNetworkInterceptor::ContinueInterceptedRequestCallback;
using GetResponseBodyForInterceptionCallback = using GetResponseBodyCallback =
DevToolsNetworkInterceptor::GetResponseBodyForInterceptionCallback; DevToolsNetworkInterceptor::GetResponseBodyForInterceptionCallback;
using TakeResponseBodyPipeCallback = using TakeResponseBodyPipeCallback =
DevToolsNetworkInterceptor::TakeResponseBodyPipeCallback; DevToolsNetworkInterceptor::TakeResponseBodyPipeCallback;
...@@ -74,15 +75,14 @@ class BodyReader : public mojo::DataPipeDrainer::Client { ...@@ -74,15 +75,14 @@ class BodyReader : public mojo::DataPipeDrainer::Client {
void StartReading(mojo::ScopedDataPipeConsumerHandle body); void StartReading(mojo::ScopedDataPipeConsumerHandle body);
void AddCallback( using GetBodyCallback = CrossThreadProtocolCallback<GetResponseBodyCallback>;
std::unique_ptr<GetResponseBodyForInterceptionCallback> callback) { void AddCallback(GetBodyCallback callback) {
callbacks_.push_back(std::move(callback));
if (data_complete_) { if (data_complete_) {
DCHECK_EQ(1UL, callbacks_.size()); DCHECK(callbacks_.empty());
base::PostTask(FROM_HERE, {BrowserThread::UI}, callback.sendSuccess(encoded_body_, true);
base::BindOnce(&BodyReader::DispatchBodyOnUI, return;
std::move(callbacks_), encoded_body_));
} }
callbacks_.push_back(std::move(callback));
} }
bool data_complete() const { return data_complete_; } bool data_complete() const { return data_complete_; }
...@@ -93,19 +93,12 @@ class BodyReader : public mojo::DataPipeDrainer::Client { ...@@ -93,19 +93,12 @@ class BodyReader : public mojo::DataPipeDrainer::Client {
} }
void CancelWithError(std::string error) { void CancelWithError(std::string error) {
base::PostTask(FROM_HERE, {BrowserThread::UI}, for (auto& cb : callbacks_)
base::BindOnce(&BodyReader::DispatchErrorOnUI, cb.sendFailure(Response::Error(error));
std::move(callbacks_), std::move(error))); callbacks_.clear();
} }
private: private:
using CallbackVector =
std::vector<std::unique_ptr<GetResponseBodyForInterceptionCallback>>;
static void DispatchBodyOnUI(const CallbackVector& callbacks,
const std::string& body);
static void DispatchErrorOnUI(const CallbackVector& callbacks,
const std::string& error);
void OnDataAvailable(const void* data, size_t num_bytes) override { void OnDataAvailable(const void* data, size_t num_bytes) override {
DCHECK(!data_complete_); DCHECK(!data_complete_);
body_->data().append( body_->data().append(
...@@ -115,7 +108,7 @@ class BodyReader : public mojo::DataPipeDrainer::Client { ...@@ -115,7 +108,7 @@ class BodyReader : public mojo::DataPipeDrainer::Client {
void OnDataComplete() override; void OnDataComplete() override;
std::unique_ptr<mojo::DataPipeDrainer> body_pipe_drainer_; std::unique_ptr<mojo::DataPipeDrainer> body_pipe_drainer_;
CallbackVector callbacks_; std::vector<GetBodyCallback> callbacks_;
base::OnceClosure download_complete_callback_; base::OnceClosure download_complete_callback_;
scoped_refptr<base::RefCountedString> body_; scoped_refptr<base::RefCountedString> body_;
std::string encoded_body_; std::string encoded_body_;
...@@ -136,26 +129,12 @@ void BodyReader::OnDataComplete() { ...@@ -136,26 +129,12 @@ void BodyReader::OnDataComplete() {
body_pipe_drainer_.reset(); body_pipe_drainer_.reset();
// TODO(caseq): only encode if necessary. // TODO(caseq): only encode if necessary.
base::Base64Encode(body_->data(), &encoded_body_); base::Base64Encode(body_->data(), &encoded_body_);
base::PostTask(FROM_HERE, {BrowserThread::UI}, for (auto& cb : callbacks_)
base::BindOnce(&BodyReader::DispatchBodyOnUI, cb.sendSuccess(encoded_body_, true);
std::move(callbacks_), encoded_body_)); callbacks_.clear();
std::move(download_complete_callback_).Run(); std::move(download_complete_callback_).Run();
} }
// static
void BodyReader::DispatchBodyOnUI(const CallbackVector& callbacks,
const std::string& encoded_body) {
for (const auto& cb : callbacks)
cb->sendSuccess(encoded_body, true);
}
// static
void BodyReader::DispatchErrorOnUI(const CallbackVector& callbacks,
const std::string& error) {
for (const auto& cb : callbacks)
cb->sendFailure(Response::Error(error));
}
struct ResponseMetadata { struct ResponseMetadata {
ResponseMetadata() = default; ResponseMetadata() = default;
explicit ResponseMetadata(const network::ResourceResponseHead& head) explicit ResponseMetadata(const network::ResourceResponseHead& head)
...@@ -193,11 +172,11 @@ class InterceptionJob : public network::mojom::URLLoaderClient, ...@@ -193,11 +172,11 @@ class InterceptionJob : public network::mojom::URLLoaderClient,
mojo::PendingRemote<network::mojom::CookieManager> cookie_manager); mojo::PendingRemote<network::mojom::CookieManager> cookie_manager);
void GetResponseBody( void GetResponseBody(
std::unique_ptr<GetResponseBodyForInterceptionCallback> callback); CrossThreadProtocolCallback<GetResponseBodyCallback> callback);
void TakeResponseBodyPipe(TakeResponseBodyPipeCallback callback); void TakeResponseBodyPipe(TakeResponseBodyPipeCallback callback);
void ContinueInterceptedRequest( void ContinueInterceptedRequest(
std::unique_ptr<Modifications> modifications, std::unique_ptr<Modifications> modifications,
std::unique_ptr<ContinueInterceptedRequestCallback> callback); CrossThreadProtocolCallback<ContinueInterceptedRequestCallback> callback);
void Detach(); void Detach();
void OnAuthRequest( void OnAuthRequest(
...@@ -389,7 +368,7 @@ class DevToolsURLLoaderInterceptor::Impl ...@@ -389,7 +368,7 @@ class DevToolsURLLoaderInterceptor::Impl
void GetResponseBody( void GetResponseBody(
const std::string& interception_id, const std::string& interception_id,
std::unique_ptr<GetResponseBodyForInterceptionCallback> callback) { CrossThreadProtocolCallback<GetResponseBodyCallback> callback) {
if (InterceptionJob* job = FindJob(interception_id, &callback)) if (InterceptionJob* job = FindJob(interception_id, &callback))
job->GetResponseBody(std::move(callback)); job->GetResponseBody(std::move(callback));
} }
...@@ -413,7 +392,8 @@ class DevToolsURLLoaderInterceptor::Impl ...@@ -413,7 +392,8 @@ class DevToolsURLLoaderInterceptor::Impl
void ContinueInterceptedRequest( void ContinueInterceptedRequest(
const std::string& interception_id, const std::string& interception_id,
std::unique_ptr<Modifications> modifications, std::unique_ptr<Modifications> modifications,
std::unique_ptr<ContinueInterceptedRequestCallback> callback) { CrossThreadProtocolCallback<ContinueInterceptedRequestCallback>
callback) {
if (InterceptionJob* job = FindJob(interception_id, &callback)) { if (InterceptionJob* job = FindJob(interception_id, &callback)) {
job->ContinueInterceptedRequest(std::move(modifications), job->ContinueInterceptedRequest(std::move(modifications),
std::move(callback)); std::move(callback));
...@@ -425,14 +405,12 @@ class DevToolsURLLoaderInterceptor::Impl ...@@ -425,14 +405,12 @@ class DevToolsURLLoaderInterceptor::Impl
template <typename Callback> template <typename Callback>
InterceptionJob* FindJob(const std::string& id, InterceptionJob* FindJob(const std::string& id,
std::unique_ptr<Callback>* callback) { CrossThreadProtocolCallback<Callback>* callback) {
auto it = jobs_.find(id); auto it = jobs_.find(id);
if (it != jobs_.end()) if (it != jobs_.end())
return it->second; return it->second;
base::PostTask(FROM_HERE, {BrowserThread::UI}, callback->sendFailure(
base::BindOnce(&Callback::sendFailure, std::move(*callback), protocol::Response::InvalidParams("Invalid InterceptionId."));
protocol::Response::InvalidParams(
"Invalid InterceptionId.")));
return nullptr; return nullptr;
} }
...@@ -617,11 +595,11 @@ void DevToolsURLLoaderInterceptor::SetPatterns( ...@@ -617,11 +595,11 @@ void DevToolsURLLoaderInterceptor::SetPatterns(
void DevToolsURLLoaderInterceptor::GetResponseBody( void DevToolsURLLoaderInterceptor::GetResponseBody(
const std::string& interception_id, const std::string& interception_id,
std::unique_ptr<GetResponseBodyForInterceptionCallback> callback) { std::unique_ptr<GetResponseBodyCallback> callback) {
base::PostTask( base::PostTask(FROM_HERE, {BrowserThread::IO},
FROM_HERE, {BrowserThread::IO}, base::BindOnce(&Impl::GetResponseBody,
base::BindOnce(&Impl::GetResponseBody, base::Unretained(impl_.get()), base::Unretained(impl_.get()), interception_id,
interception_id, std::move(callback))); WrapForAnotherThread(std::move(callback))));
} }
void DevToolsURLLoaderInterceptor::TakeResponseBodyPipe( void DevToolsURLLoaderInterceptor::TakeResponseBodyPipe(
...@@ -640,7 +618,8 @@ void DevToolsURLLoaderInterceptor::ContinueInterceptedRequest( ...@@ -640,7 +618,8 @@ void DevToolsURLLoaderInterceptor::ContinueInterceptedRequest(
base::PostTask(FROM_HERE, {BrowserThread::IO}, base::PostTask(FROM_HERE, {BrowserThread::IO},
base::BindOnce(&Impl::ContinueInterceptedRequest, base::BindOnce(&Impl::ContinueInterceptedRequest,
base::Unretained(impl_.get()), interception_id, base::Unretained(impl_.get()), interception_id,
std::move(modifications), std::move(callback))); std::move(modifications),
WrapForAnotherThread(std::move(callback))));
} }
bool DevToolsURLLoaderInterceptor::CreateProxyForInterception( bool DevToolsURLLoaderInterceptor::CreateProxyForInterception(
...@@ -754,14 +733,10 @@ bool InterceptionJob::CanGetResponseBody(std::string* error_reason) { ...@@ -754,14 +733,10 @@ bool InterceptionJob::CanGetResponseBody(std::string* error_reason) {
} }
void InterceptionJob::GetResponseBody( void InterceptionJob::GetResponseBody(
std::unique_ptr<GetResponseBodyForInterceptionCallback> callback) { CrossThreadProtocolCallback<GetResponseBodyCallback> callback) {
std::string error_reason; std::string error_reason;
if (!CanGetResponseBody(&error_reason)) { if (!CanGetResponseBody(&error_reason)) {
base::PostTask( callback.sendFailure(Response::Error(std::move(error_reason)));
FROM_HERE, {BrowserThread::UI},
base::BindOnce(&GetResponseBodyForInterceptionCallback::sendFailure,
std::move(callback),
Response::Error(std::move(error_reason))));
return; return;
} }
if (!body_reader_) { if (!body_reader_) {
...@@ -794,16 +769,13 @@ void InterceptionJob::TakeResponseBodyPipe( ...@@ -794,16 +769,13 @@ void InterceptionJob::TakeResponseBodyPipe(
void InterceptionJob::ContinueInterceptedRequest( void InterceptionJob::ContinueInterceptedRequest(
std::unique_ptr<Modifications> modifications, std::unique_ptr<Modifications> modifications,
std::unique_ptr<ContinueInterceptedRequestCallback> callback) { CrossThreadProtocolCallback<ContinueInterceptedRequestCallback> callback) {
Response response = InnerContinueRequest(std::move(modifications)); Response response = InnerContinueRequest(std::move(modifications));
// |this| may be destroyed at this point. // |this| may be destroyed at this point.
bool success = response.isSuccess(); if (response.isSuccess())
base::OnceClosure task = callback.sendSuccess();
success ? base::BindOnce(&ContinueInterceptedRequestCallback::sendSuccess, else
std::move(callback)) callback.sendFailure(std::move(response));
: base::BindOnce(&ContinueInterceptedRequestCallback::sendFailure,
std::move(callback), std::move(response));
base::PostTask(FROM_HERE, {BrowserThread::UI}, std::move(task));
} }
void InterceptionJob::Detach() { void InterceptionJob::Detach() {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment