Commit a76c901a authored by Clark DuVall's avatar Clark DuVall Committed by Commit Bot

Add FlushAsyncForTesting to InterfacePtr API

This will be used in
URLLoaderFactoryGetter::FlushNetworkInterfaceForTesting to prevent
nested IO message loops.

Hit this in
https://chromium-review.googlesource.com/c/chromium/src/+/1139048.

Bug: 857577
Change-Id: I1b6a53f39737ff0a6d4681b4256991569b91a2d5
Reviewed-on: https://chromium-review.googlesource.com/1142198
Commit-Queue: Clark DuVall <cduvall@chromium.org>
Reviewed-by: default avatarKen Rockot <rockot@chromium.org>
Cr-Commit-Position: refs/heads/master@{#576350}
parent 0e7f1e0f
...@@ -113,6 +113,7 @@ class MOJO_CPP_BINDINGS_EXPORT InterfaceEndpointClient ...@@ -113,6 +113,7 @@ class MOJO_CPP_BINDINGS_EXPORT InterfaceEndpointClient
void QueryVersion(const base::Callback<void(uint32_t)>& callback); void QueryVersion(const base::Callback<void(uint32_t)>& callback);
void RequireVersion(uint32_t version); void RequireVersion(uint32_t version);
void FlushForTesting(); void FlushForTesting();
void FlushAsyncForTesting(base::OnceClosure callback);
private: private:
// Maps from the id of a response to the MessageReceiver that handles the // Maps from the id of a response to the MessageReceiver that handles the
......
...@@ -127,6 +127,12 @@ class InterfacePtr { ...@@ -127,6 +127,12 @@ class InterfacePtr {
// stimulus. // stimulus.
void FlushForTesting() { internal_state_.FlushForTesting(); } void FlushForTesting() { internal_state_.FlushForTesting(); }
// Same as |FlushForTesting()| but will call |callback| when the flush is
// complete.
void FlushAsyncForTesting(base::OnceClosure callback) {
internal_state_.FlushAsyncForTesting(std::move(callback));
}
// Closes the bound message pipe, if any. // Closes the bound message pipe, if any.
void reset() { void reset() {
State doomed; State doomed;
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include "base/callback_helpers.h" #include "base/callback_helpers.h"
#include "base/macros.h" #include "base/macros.h"
#include "base/run_loop.h" #include "base/run_loop.h"
#include "base/threading/sequenced_task_runner_handle.h"
#include "mojo/public/cpp/bindings/lib/serialization.h" #include "mojo/public/cpp/bindings/lib/serialization.h"
#include "mojo/public/cpp/bindings/lib/validation_util.h" #include "mojo/public/cpp/bindings/lib/validation_util.h"
#include "mojo/public/cpp/bindings/message.h" #include "mojo/public/cpp/bindings/message.h"
...@@ -39,12 +40,12 @@ bool ValidateControlResponse(Message* message) { ...@@ -39,12 +40,12 @@ bool ValidateControlResponse(Message* message) {
} }
using RunCallback = using RunCallback =
base::Callback<void(interface_control::RunResponseMessageParamsPtr)>; base::OnceCallback<void(interface_control::RunResponseMessageParamsPtr)>;
class RunResponseForwardToCallback : public MessageReceiver { class RunResponseForwardToCallback : public MessageReceiver {
public: public:
explicit RunResponseForwardToCallback(const RunCallback& callback) explicit RunResponseForwardToCallback(RunCallback callback)
: callback_(callback) {} : callback_(std::move(callback)) {}
bool Accept(Message* message) override; bool Accept(Message* message) override;
private: private:
...@@ -65,13 +66,13 @@ bool RunResponseForwardToCallback::Accept(Message* message) { ...@@ -65,13 +66,13 @@ bool RunResponseForwardToCallback::Accept(Message* message) {
Deserialize<interface_control::RunResponseMessageParamsDataView>( Deserialize<interface_control::RunResponseMessageParamsDataView>(
params, &params_ptr, &context); params, &params_ptr, &context);
callback_.Run(std::move(params_ptr)); std::move(callback_).Run(std::move(params_ptr));
return true; return true;
} }
void SendRunMessage(MessageReceiverWithResponder* receiver, void SendRunMessage(MessageReceiverWithResponder* receiver,
interface_control::RunInputPtr input_ptr, interface_control::RunInputPtr input_ptr,
const RunCallback& callback) { RunCallback callback) {
auto params_ptr = interface_control::RunMessageParams::New(); auto params_ptr = interface_control::RunMessageParams::New();
params_ptr->input = std::move(input_ptr); params_ptr->input = std::move(input_ptr);
Message message(interface_control::kRunMessageId, Message message(interface_control::kRunMessageId,
...@@ -81,7 +82,7 @@ void SendRunMessage(MessageReceiverWithResponder* receiver, ...@@ -81,7 +82,7 @@ void SendRunMessage(MessageReceiverWithResponder* receiver,
Serialize<interface_control::RunMessageParamsDataView>( Serialize<interface_control::RunMessageParamsDataView>(
params_ptr, message.payload_buffer(), &params, &context); params_ptr, message.payload_buffer(), &params, &context);
std::unique_ptr<MessageReceiver> responder = std::unique_ptr<MessageReceiver> responder =
std::make_unique<RunResponseForwardToCallback>(callback); std::make_unique<RunResponseForwardToCallback>(std::move(callback));
ignore_result(receiver->AcceptWithResponder(&message, std::move(responder))); ignore_result(receiver->AcceptWithResponder(&message, std::move(responder)));
} }
...@@ -115,9 +116,9 @@ void RunVersionCallback( ...@@ -115,9 +116,9 @@ void RunVersionCallback(
callback.Run(version); callback.Run(version);
} }
void RunClosure(const base::Closure& callback, void RunClosure(base::OnceClosure callback,
interface_control::RunResponseMessageParamsPtr run_response) { interface_control::RunResponseMessageParamsPtr run_response) {
callback.Run(); std::move(callback).Run();
} }
} // namespace } // namespace
...@@ -133,7 +134,7 @@ void ControlMessageProxy::QueryVersion( ...@@ -133,7 +134,7 @@ void ControlMessageProxy::QueryVersion(
auto input_ptr = interface_control::RunInput::New(); auto input_ptr = interface_control::RunInput::New();
input_ptr->set_query_version(interface_control::QueryVersion::New()); input_ptr->set_query_version(interface_control::QueryVersion::New());
SendRunMessage(receiver_, std::move(input_ptr), SendRunMessage(receiver_, std::move(input_ptr),
base::Bind(&RunVersionCallback, callback)); base::BindOnce(&RunVersionCallback, callback));
} }
void ControlMessageProxy::RequireVersion(uint32_t version) { void ControlMessageProxy::RequireVersion(uint32_t version) {
...@@ -145,29 +146,38 @@ void ControlMessageProxy::RequireVersion(uint32_t version) { ...@@ -145,29 +146,38 @@ void ControlMessageProxy::RequireVersion(uint32_t version) {
} }
void ControlMessageProxy::FlushForTesting() { void ControlMessageProxy::FlushForTesting() {
if (encountered_error_) base::RunLoop run_loop(base::RunLoop::Type::kNestableTasksAllowed);
FlushAsyncForTesting(run_loop.QuitClosure());
run_loop.Run();
}
void ControlMessageProxy::FlushAsyncForTesting(base::OnceClosure callback) {
if (encountered_error_) {
base::SequencedTaskRunnerHandle::Get()->PostTask(FROM_HERE,
std::move(callback));
return; return;
}
auto input_ptr = interface_control::RunInput::New(); auto input_ptr = interface_control::RunInput::New();
input_ptr->set_flush_for_testing(interface_control::FlushForTesting::New()); input_ptr->set_flush_for_testing(interface_control::FlushForTesting::New());
base::RunLoop run_loop(base::RunLoop::Type::kNestableTasksAllowed); DCHECK(!pending_flush_callback_);
run_loop_quit_closure_ = run_loop.QuitClosure(); pending_flush_callback_ = std::move(callback);
SendRunMessage( SendRunMessage(
receiver_, std::move(input_ptr), receiver_, std::move(input_ptr),
base::Bind(&RunClosure, base::BindOnce(
base::Bind(&ControlMessageProxy::RunFlushForTestingClosure, &RunClosure,
base::BindOnce(&ControlMessageProxy::RunFlushForTestingClosure,
base::Unretained(this)))); base::Unretained(this))));
run_loop.Run();
} }
void ControlMessageProxy::RunFlushForTestingClosure() { void ControlMessageProxy::RunFlushForTestingClosure() {
DCHECK(!run_loop_quit_closure_.is_null()); DCHECK(!pending_flush_callback_.is_null());
base::ResetAndReturn(&run_loop_quit_closure_).Run(); std::move(pending_flush_callback_).Run();
} }
void ControlMessageProxy::OnConnectionError() { void ControlMessageProxy::OnConnectionError() {
encountered_error_ = true; encountered_error_ = true;
if (!run_loop_quit_closure_.is_null()) if (!pending_flush_callback_.is_null())
RunFlushForTestingClosure(); RunFlushForTestingClosure();
} }
......
...@@ -29,6 +29,7 @@ class MOJO_CPP_BINDINGS_EXPORT ControlMessageProxy { ...@@ -29,6 +29,7 @@ class MOJO_CPP_BINDINGS_EXPORT ControlMessageProxy {
void RequireVersion(uint32_t version); void RequireVersion(uint32_t version);
void FlushForTesting(); void FlushForTesting();
void FlushAsyncForTesting(base::OnceClosure callback);
void OnConnectionError(); void OnConnectionError();
private: private:
...@@ -38,7 +39,7 @@ class MOJO_CPP_BINDINGS_EXPORT ControlMessageProxy { ...@@ -38,7 +39,7 @@ class MOJO_CPP_BINDINGS_EXPORT ControlMessageProxy {
MessageReceiverWithResponder* receiver_; MessageReceiverWithResponder* receiver_;
bool encountered_error_ = false; bool encountered_error_ = false;
base::Closure run_loop_quit_closure_; base::OnceClosure pending_flush_callback_;
DISALLOW_COPY_AND_ASSIGN(ControlMessageProxy); DISALLOW_COPY_AND_ASSIGN(ControlMessageProxy);
}; };
......
...@@ -347,6 +347,10 @@ void InterfaceEndpointClient::FlushForTesting() { ...@@ -347,6 +347,10 @@ void InterfaceEndpointClient::FlushForTesting() {
control_message_proxy_.FlushForTesting(); control_message_proxy_.FlushForTesting();
} }
void InterfaceEndpointClient::FlushAsyncForTesting(base::OnceClosure callback) {
control_message_proxy_.FlushAsyncForTesting(std::move(callback));
}
void InterfaceEndpointClient::InitControllerIfNecessary() { void InterfaceEndpointClient::InitControllerIfNecessary() {
if (controller_ || handle_.pending_association()) if (controller_ || handle_.pending_association())
return; return;
......
...@@ -131,6 +131,11 @@ class InterfacePtrState : public InterfacePtrStateBase { ...@@ -131,6 +131,11 @@ class InterfacePtrState : public InterfacePtrStateBase {
endpoint_client()->FlushForTesting(); endpoint_client()->FlushForTesting();
} }
void FlushAsyncForTesting(base::OnceClosure callback) {
ConfigureProxyIfNecessary();
endpoint_client()->FlushAsyncForTesting(std::move(callback));
}
void CloseWithReason(uint32_t custom_reason, const std::string& description) { void CloseWithReason(uint32_t custom_reason, const std::string& description) {
ConfigureProxyIfNecessary(); ConfigureProxyIfNecessary();
endpoint_client()->CloseWithReason(custom_reason, description); endpoint_client()->CloseWithReason(custom_reason, description);
......
...@@ -754,6 +754,27 @@ TEST_P(InterfacePtrTest, FlushForTesting) { ...@@ -754,6 +754,27 @@ TEST_P(InterfacePtrTest, FlushForTesting) {
EXPECT_EQ(10.0, calculator_ui.GetOutput()); EXPECT_EQ(10.0, calculator_ui.GetOutput());
} }
TEST_P(InterfacePtrTest, FlushAsyncForTesting) {
math::CalculatorPtr calc;
MathCalculatorImpl calc_impl(MakeRequest(&calc));
calc.set_connection_error_handler(base::BindOnce(&Fail));
MathCalculatorUI calculator_ui(std::move(calc));
calculator_ui.Add(2.0, base::DoNothing());
base::RunLoop run_loop;
calculator_ui.GetInterfacePtr().FlushAsyncForTesting(run_loop.QuitClosure());
run_loop.Run();
EXPECT_EQ(2.0, calculator_ui.GetOutput());
calculator_ui.Multiply(5.0, base::DoNothing());
base::RunLoop run_loop2;
calculator_ui.GetInterfacePtr().FlushAsyncForTesting(run_loop2.QuitClosure());
run_loop2.Run();
EXPECT_EQ(10.0, calculator_ui.GetOutput());
}
void SetBool(bool* value) { void SetBool(bool* value) {
*value = true; *value = true;
} }
...@@ -768,6 +789,21 @@ TEST_P(InterfacePtrTest, FlushForTestingWithClosedPeer) { ...@@ -768,6 +789,21 @@ TEST_P(InterfacePtrTest, FlushForTestingWithClosedPeer) {
calc.FlushForTesting(); calc.FlushForTesting();
} }
TEST_P(InterfacePtrTest, FlushAsyncForTestingWithClosedPeer) {
math::CalculatorPtr calc;
MakeRequest(&calc);
bool called = false;
calc.set_connection_error_handler(base::BindOnce(&SetBool, &called));
base::RunLoop run_loop;
calc.FlushAsyncForTesting(run_loop.QuitClosure());
run_loop.Run();
EXPECT_TRUE(called);
base::RunLoop run_loop2;
calc.FlushAsyncForTesting(run_loop2.QuitClosure());
run_loop2.Run();
}
TEST_P(InterfacePtrTest, ConnectionErrorWithReason) { TEST_P(InterfacePtrTest, ConnectionErrorWithReason) {
math::CalculatorPtr calc; math::CalculatorPtr calc;
MathCalculatorImpl calc_impl(MakeRequest(&calc)); MathCalculatorImpl calc_impl(MakeRequest(&calc));
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment