Commit ffaa6833 authored by Yuwei Huang's avatar Yuwei Huang Committed by Commit Bot

[remoting] Add stream request support to ProtobufHttpStatus

* Add stream request support to ProtobufHttpStatus, which includes some
  refactoring in the code.
* Make ProtobufHttpClient directly own the requests, to make lifetime
  managing less painful.

Bug: 1103416
Change-Id: I9c1f62acb4e02cd968387bc4cc34db8dd6268c25
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2299536
Commit-Queue: Yuwei Huang <yuweih@chromium.org>
Reviewed-by: default avatarJoe Downing <joedow@chromium.org>
Cr-Commit-Position: refs/heads/master@{#790112}
parent 526a61db
...@@ -33,10 +33,16 @@ source_set("base") { ...@@ -33,10 +33,16 @@ source_set("base") {
"protobuf_http_client.h", "protobuf_http_client.h",
"protobuf_http_request.cc", "protobuf_http_request.cc",
"protobuf_http_request.h", "protobuf_http_request.h",
"protobuf_http_request_base.cc",
"protobuf_http_request_base.h",
"protobuf_http_request_config.cc",
"protobuf_http_request_config.h",
"protobuf_http_status.cc", "protobuf_http_status.cc",
"protobuf_http_status.h", "protobuf_http_status.h",
"protobuf_http_stream_parser.cc", "protobuf_http_stream_parser.cc",
"protobuf_http_stream_parser.h", "protobuf_http_stream_parser.h",
"protobuf_http_stream_request.cc",
"protobuf_http_stream_request.h",
"rate_counter.cc", "rate_counter.cc",
"rate_counter.h", "rate_counter.h",
"result.h", "result.h",
......
...@@ -8,7 +8,8 @@ ...@@ -8,7 +8,8 @@
#include "net/base/load_flags.h" #include "net/base/load_flags.h"
#include "net/base/net_errors.h" #include "net/base/net_errors.h"
#include "remoting/base/oauth_token_getter.h" #include "remoting/base/oauth_token_getter.h"
#include "remoting/base/protobuf_http_request.h" #include "remoting/base/protobuf_http_request_base.h"
#include "remoting/base/protobuf_http_request_config.h"
#include "remoting/base/protobuf_http_status.h" #include "remoting/base/protobuf_http_status.h"
#include "services/network/public/cpp/resource_request.h" #include "services/network/public/cpp/resource_request.h"
#include "services/network/public/cpp/shared_url_loader_factory.h" #include "services/network/public/cpp/shared_url_loader_factory.h"
...@@ -19,7 +20,6 @@ ...@@ -19,7 +20,6 @@
namespace { namespace {
constexpr char kAuthorizationHeaderFormat[] = "Authorization: Bearer %s"; constexpr char kAuthorizationHeaderFormat[] = "Authorization: Bearer %s";
constexpr int kMaxResponseSizeKb = 512;
} // namespace } // namespace
...@@ -33,15 +33,15 @@ ProtobufHttpClient::ProtobufHttpClient( ...@@ -33,15 +33,15 @@ ProtobufHttpClient::ProtobufHttpClient(
token_getter_(token_getter), token_getter_(token_getter),
url_loader_factory_(url_loader_factory) {} url_loader_factory_(url_loader_factory) {}
ProtobufHttpClient::~ProtobufHttpClient() = default; ProtobufHttpClient::~ProtobufHttpClient() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
}
void ProtobufHttpClient::ExecuteRequest( void ProtobufHttpClient::ExecuteRequest(
std::unique_ptr<ProtobufHttpRequest> request) { std::unique_ptr<ProtobufHttpRequestBase> request) {
DCHECK(request->request_message); DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
DCHECK(!request->path.empty());
DCHECK(request->response_callback_);
if (!request->authenticated) { if (!request->config().authenticated) {
DoExecuteRequest(std::move(request), OAuthTokenGetter::Status::SUCCESS, {}, DoExecuteRequest(std::move(request), OAuthTokenGetter::Status::SUCCESS, {},
{}); {});
return; return;
...@@ -54,23 +54,37 @@ void ProtobufHttpClient::ExecuteRequest( ...@@ -54,23 +54,37 @@ void ProtobufHttpClient::ExecuteRequest(
} }
void ProtobufHttpClient::CancelPendingRequests() { void ProtobufHttpClient::CancelPendingRequests() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
weak_factory_.InvalidateWeakPtrs(); weak_factory_.InvalidateWeakPtrs();
pending_requests_.clear();
}
bool ProtobufHttpClient::HasPendingRequests() const {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
return !pending_requests_.empty();
} }
void ProtobufHttpClient::DoExecuteRequest( void ProtobufHttpClient::DoExecuteRequest(
std::unique_ptr<ProtobufHttpRequest> request, std::unique_ptr<ProtobufHttpRequestBase> request,
OAuthTokenGetter::Status status, OAuthTokenGetter::Status status,
const std::string& user_email, const std::string& user_email,
const std::string& access_token) { const std::string& access_token) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
if (status != OAuthTokenGetter::Status::SUCCESS) { if (status != OAuthTokenGetter::Status::SUCCESS) {
LOG(ERROR) << "Failed to fetch access token. Status: " << status; std::string error_message =
request->OnResponse( base::StringPrintf("Failed to fetch access token. Status: %d", status);
ProtobufHttpStatus(net::HttpStatusCode::HTTP_UNAUTHORIZED), nullptr); LOG(ERROR) << error_message;
request->OnAuthFailed(ProtobufHttpStatus(
ProtobufHttpStatus::Code::UNAUTHENTICATED, error_message));
return; return;
} }
auto resource_request = std::make_unique<network::ResourceRequest>(); auto resource_request = std::make_unique<network::ResourceRequest>();
resource_request->url = GURL("https://" + server_endpoint_ + request->path); resource_request->url =
GURL("https://" + server_endpoint_ + request->config().path);
resource_request->load_flags = resource_request->load_flags =
net::LOAD_BYPASS_CACHE | net::LOAD_DISABLE_CACHE; net::LOAD_BYPASS_CACHE | net::LOAD_DISABLE_CACHE;
resource_request->credentials_mode = network::mojom::CredentialsMode::kOmit; resource_request->credentials_mode = network::mojom::CredentialsMode::kOmit;
...@@ -85,35 +99,28 @@ void ProtobufHttpClient::DoExecuteRequest( ...@@ -85,35 +99,28 @@ void ProtobufHttpClient::DoExecuteRequest(
std::unique_ptr<network::SimpleURLLoader> send_url_loader = std::unique_ptr<network::SimpleURLLoader> send_url_loader =
network::SimpleURLLoader::Create(std::move(resource_request), network::SimpleURLLoader::Create(std::move(resource_request),
request->traffic_annotation); request->config().traffic_annotation);
send_url_loader->SetTimeoutDuration(request->timeout_duration); base::TimeDelta timeout_duration = request->GetRequestTimeoutDuration();
if (!timeout_duration.is_zero()) {
send_url_loader->SetTimeoutDuration(request->GetRequestTimeoutDuration());
}
send_url_loader->AttachStringForUpload( send_url_loader->AttachStringForUpload(
request->request_message->SerializeAsString(), "application/x-protobuf"); request->config().request_message->SerializeAsString(),
send_url_loader->DownloadToString( "application/x-protobuf");
url_loader_factory_.get(), auto* unowned_request = request.get();
base::BindOnce(&ProtobufHttpClient::OnResponse, base::OnceClosure invalidator = base::BindOnce(
weak_factory_.GetWeakPtr(), std::move(request), &ProtobufHttpClient::CancelRequest, weak_factory_.GetWeakPtr(),
std::move(send_url_loader)), pending_requests_.insert(pending_requests_.end(), std::move(request)));
kMaxResponseSizeKb); unowned_request->StartRequest(url_loader_factory_.get(),
std::move(send_url_loader),
std::move(invalidator));
} }
void ProtobufHttpClient::OnResponse( void ProtobufHttpClient::CancelRequest(
std::unique_ptr<ProtobufHttpRequest> request, const PendingRequestListIterator& request_iterator) {
std::unique_ptr<network::SimpleURLLoader> url_loader, DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
std::unique_ptr<std::string> response_body) {
net::Error net_error = static_cast<net::Error>(url_loader->NetError()); pending_requests_.erase(request_iterator);
if (net_error == net::Error::ERR_HTTP_RESPONSE_CODE_FAILURE &&
(!url_loader->ResponseInfo() || !url_loader->ResponseInfo()->headers)) {
LOG(ERROR) << "Can't find response header.";
net_error = net::Error::ERR_INVALID_RESPONSE;
}
ProtobufHttpStatus status =
(net_error == net::Error::ERR_HTTP_RESPONSE_CODE_FAILURE ||
net_error == net::Error::OK)
? ProtobufHttpStatus(static_cast<net::HttpStatusCode>(
url_loader->ResponseInfo()->headers->response_code()))
: ProtobufHttpStatus(net_error);
request->OnResponse(status, std::move(response_body));
} }
} // namespace remoting } // namespace remoting
...@@ -5,22 +5,23 @@ ...@@ -5,22 +5,23 @@
#ifndef REMOTING_BASE_PROTOBUF_HTTP_CLIENT_H_ #ifndef REMOTING_BASE_PROTOBUF_HTTP_CLIENT_H_
#define REMOTING_BASE_PROTOBUF_HTTP_CLIENT_H_ #define REMOTING_BASE_PROTOBUF_HTTP_CLIENT_H_
#include <list>
#include <memory> #include <memory>
#include <string> #include <string>
#include "base/memory/scoped_refptr.h" #include "base/memory/scoped_refptr.h"
#include "base/memory/weak_ptr.h" #include "base/memory/weak_ptr.h"
#include "base/sequence_checker.h"
#include "net/traffic_annotation/network_traffic_annotation.h" #include "net/traffic_annotation/network_traffic_annotation.h"
#include "remoting/base/oauth_token_getter.h" #include "remoting/base/oauth_token_getter.h"
namespace network { namespace network {
class SharedURLLoaderFactory; class SharedURLLoaderFactory;
class SimpleURLLoader;
} // namespace network } // namespace network
namespace remoting { namespace remoting {
struct ProtobufHttpRequest; class ProtobufHttpRequestBase;
// Helper class for executing REST/Protobuf requests over HTTP. // Helper class for executing REST/Protobuf requests over HTTP.
class ProtobufHttpClient final { class ProtobufHttpClient final {
...@@ -38,26 +39,36 @@ class ProtobufHttpClient final { ...@@ -38,26 +39,36 @@ class ProtobufHttpClient final {
// Executes a unary request. Caller will not be notified of the result if // Executes a unary request. Caller will not be notified of the result if
// CancelPendingRequests() is called or |this| is destroyed. // CancelPendingRequests() is called or |this| is destroyed.
void ExecuteRequest(std::unique_ptr<ProtobufHttpRequest> request); void ExecuteRequest(std::unique_ptr<ProtobufHttpRequestBase> request);
// Tries to cancel all pending requests. Note that this prevents request // Cancel all pending requests.
// callbacks from being called but does not necessarily stop pending requests
// from being sent.
void CancelPendingRequests(); void CancelPendingRequests();
// Indicates whether the client has any pending requests.
bool HasPendingRequests() const;
private: private:
void DoExecuteRequest(std::unique_ptr<ProtobufHttpRequest> request, using PendingRequestList =
std::list<std::unique_ptr<ProtobufHttpRequestBase>>;
// std::list iterators are stable, so they survive list editing and only
// become invalidated when underlying element is deleted.
using PendingRequestListIterator = PendingRequestList::iterator;
void DoExecuteRequest(std::unique_ptr<ProtobufHttpRequestBase> request,
OAuthTokenGetter::Status status, OAuthTokenGetter::Status status,
const std::string& user_email, const std::string& user_email,
const std::string& access_token); const std::string& access_token);
void OnResponse(std::unique_ptr<ProtobufHttpRequest> request, void CancelRequest(const PendingRequestListIterator& request_iterator);
std::unique_ptr<network::SimpleURLLoader> url_loader,
std::unique_ptr<std::string> response_body);
std::string server_endpoint_; std::string server_endpoint_;
OAuthTokenGetter* token_getter_; OAuthTokenGetter* token_getter_;
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory_; scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory_;
PendingRequestList pending_requests_;
SEQUENCE_CHECKER(sequence_checker_);
base::WeakPtrFactory<ProtobufHttpClient> weak_factory_{this}; base::WeakPtrFactory<ProtobufHttpClient> weak_factory_{this};
}; };
......
...@@ -11,10 +11,14 @@ ...@@ -11,10 +11,14 @@
#include "base/test/gmock_callback_support.h" #include "base/test/gmock_callback_support.h"
#include "base/test/mock_callback.h" #include "base/test/mock_callback.h"
#include "base/test/task_environment.h" #include "base/test/task_environment.h"
#include "net/http/http_status_code.h"
#include "net/traffic_annotation/network_traffic_annotation_test_helper.h" #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
#include "remoting/base/protobuf_http_client_messages.pb.h"
#include "remoting/base/protobuf_http_client_test_messages.pb.h" #include "remoting/base/protobuf_http_client_test_messages.pb.h"
#include "remoting/base/protobuf_http_request.h" #include "remoting/base/protobuf_http_request.h"
#include "remoting/base/protobuf_http_request_config.h"
#include "remoting/base/protobuf_http_status.h" #include "remoting/base/protobuf_http_status.h"
#include "remoting/base/protobuf_http_stream_request.h"
#include "services/network/public/cpp/shared_url_loader_factory.h" #include "services/network/public/cpp/shared_url_loader_factory.h"
#include "services/network/public/cpp/weak_wrapper_shared_url_loader_factory.h" #include "services/network/public/cpp/weak_wrapper_shared_url_loader_factory.h"
#include "services/network/test/test_url_loader_factory.h" #include "services/network/test/test_url_loader_factory.h"
...@@ -25,15 +29,21 @@ namespace remoting { ...@@ -25,15 +29,21 @@ namespace remoting {
namespace { namespace {
using protobufhttpclient::StreamBody;
using protobufhttpclienttest::EchoRequest; using protobufhttpclienttest::EchoRequest;
using protobufhttpclienttest::EchoResponse; using protobufhttpclienttest::EchoResponse;
using ::base::test::RunOnceCallback; using ::base::test::RunOnceCallback;
using ::testing::_; using ::testing::_;
using ::testing::InSequence;
using MockEchoResponseCallback = using EchoResponseCallback =
base::MockCallback<base::OnceCallback<void(const ProtobufHttpStatus&, ProtobufHttpRequest::ResponseCallback<EchoResponse>;
std::unique_ptr<EchoResponse>)>>; using MockEchoResponseCallback = base::MockCallback<EchoResponseCallback>;
using MockEchoMessageCallback = base::MockCallback<
ProtobufHttpStreamRequest::MessageCallback<EchoResponse>>;
using MockStreamClosedCallback =
base::MockCallback<ProtobufHttpStreamRequest::StreamClosedCallback>;
constexpr char kTestServerEndpoint[] = "test.com"; constexpr char kTestServerEndpoint[] = "test.com";
constexpr char kTestRpcPath[] = "/v1/echo:echo"; constexpr char kTestRpcPath[] = "/v1/echo:echo";
...@@ -48,10 +58,14 @@ MATCHER_P(HasErrorCode, error_code, "") { ...@@ -48,10 +58,14 @@ MATCHER_P(HasErrorCode, error_code, "") {
return arg.error_code() == error_code; return arg.error_code() == error_code;
} }
MATCHER(IsResponseText, "") { MATCHER(IsDefaultResponseText, "") {
return arg->text() == kResponseText; return arg->text() == kResponseText;
} }
MATCHER_P(IsResponseText, response_text, "") {
return arg->text() == response_text;
}
MATCHER(IsNullResponse, "") { MATCHER(IsNullResponse, "") {
return arg.get() == nullptr; return arg.get() == nullptr;
} }
...@@ -62,30 +76,68 @@ class MockOAuthTokenGetter : public OAuthTokenGetter { ...@@ -62,30 +76,68 @@ class MockOAuthTokenGetter : public OAuthTokenGetter {
MOCK_METHOD0(InvalidateCache, void()); MOCK_METHOD0(InvalidateCache, void());
}; };
std::unique_ptr<ProtobufHttpRequest> CreateDefaultTestRequest() { EchoResponseCallback DoNothingResponse() {
auto request = return base::DoNothing::Once<const ProtobufHttpStatus&,
std::make_unique<ProtobufHttpRequest>(TRAFFIC_ANNOTATION_FOR_TESTS); std::unique_ptr<EchoResponse>>();
}
std::unique_ptr<ProtobufHttpRequestConfig> CreateDefaultRequestConfig() {
auto request_message = std::make_unique<EchoRequest>(); auto request_message = std::make_unique<EchoRequest>();
request_message->set_text(kRequestText); request_message->set_text(kRequestText);
request->request_message = std::move(request_message); auto request_config =
request->SetResponseCallback( std::make_unique<ProtobufHttpRequestConfig>(TRAFFIC_ANNOTATION_FOR_TESTS);
base::DoNothing::Once<const ProtobufHttpStatus&, request_config->request_message = std::move(request_message);
std::unique_ptr<EchoResponse>>()); request_config->path = kTestRpcPath;
request->path = kTestRpcPath; return request_config;
}
std::unique_ptr<ProtobufHttpRequest> CreateDefaultTestRequest() {
auto request =
std::make_unique<ProtobufHttpRequest>(CreateDefaultRequestConfig());
request->SetResponseCallback(DoNothingResponse());
return request; return request;
} }
std::string CreateDefaultResponseContent() { std::unique_ptr<ProtobufHttpStreamRequest> CreateDefaultTestStreamRequest() {
auto request =
std::make_unique<ProtobufHttpStreamRequest>(CreateDefaultRequestConfig());
request->SetStreamReadyCallback(base::DoNothing::Once());
request->SetStreamClosedCallback(
base::DoNothing::Once<const ProtobufHttpStatus&>());
request->SetMessageCallback(
base::DoNothing::Repeatedly<std::unique_ptr<EchoResponse>>());
return request;
}
std::string CreateSerializedEchoResponse(
const std::string& text = kResponseText) {
EchoResponse response; EchoResponse response;
response.set_text(kResponseText); response.set_text(text);
return response.SerializeAsString(); return response.SerializeAsString();
} }
std::string CreateSerializedStreamBodyWithText(
const std::string& text = kResponseText) {
StreamBody stream_body;
stream_body.add_messages(CreateSerializedEchoResponse(text));
return stream_body.SerializeAsString();
}
std::string CreateSerializedStreamBodyWithStatusCode(
ProtobufHttpStatus::Code status_code) {
StreamBody stream_body;
stream_body.mutable_status()->set_code(static_cast<int32_t>(status_code));
return stream_body.SerializeAsString();
}
} // namespace } // namespace
class ProtobufHttpClientTest : public testing::Test { class ProtobufHttpClientTest : public testing::Test {
protected: protected:
base::test::SingleThreadTaskEnvironment task_environment_; void ExpectCallWithToken(bool success);
base::test::SingleThreadTaskEnvironment task_environment_{
base::test::TaskEnvironment::TimeSource::MOCK_TIME};
MockOAuthTokenGetter mock_token_getter_; MockOAuthTokenGetter mock_token_getter_;
network::TestURLLoaderFactory test_url_loader_factory_; network::TestURLLoaderFactory test_url_loader_factory_;
scoped_refptr<network::SharedURLLoaderFactory> test_shared_loader_factory_ = scoped_refptr<network::SharedURLLoaderFactory> test_shared_loader_factory_ =
...@@ -95,16 +147,24 @@ class ProtobufHttpClientTest : public testing::Test { ...@@ -95,16 +147,24 @@ class ProtobufHttpClientTest : public testing::Test {
test_shared_loader_factory_}; test_shared_loader_factory_};
}; };
void ProtobufHttpClientTest::ExpectCallWithToken(bool success) {
EXPECT_CALL(mock_token_getter_, CallWithToken(_))
.WillOnce(RunOnceCallback<0>(success
? OAuthTokenGetter::Status::SUCCESS
: OAuthTokenGetter::Status::AUTH_ERROR,
"", success ? kFakeAccessToken : ""));
}
// Unary request tests.
TEST_F(ProtobufHttpClientTest, SendRequestAndDecodeResponse) { TEST_F(ProtobufHttpClientTest, SendRequestAndDecodeResponse) {
base::RunLoop run_loop; base::RunLoop run_loop;
EXPECT_CALL(mock_token_getter_, CallWithToken(_)) ExpectCallWithToken(/* success= */ true);
.WillOnce(RunOnceCallback<0>(OAuthTokenGetter::Status::SUCCESS, "",
kFakeAccessToken));
MockEchoResponseCallback response_callback; MockEchoResponseCallback response_callback;
EXPECT_CALL(response_callback, EXPECT_CALL(response_callback, Run(HasErrorCode(ProtobufHttpStatus::Code::OK),
Run(HasErrorCode(ProtobufHttpStatus::Code::OK), IsResponseText())) IsDefaultResponseText()))
.WillOnce([&]() { run_loop.Quit(); }); .WillOnce([&]() { run_loop.Quit(); });
auto request = CreateDefaultTestRequest(); auto request = CreateDefaultTestRequest();
...@@ -129,16 +189,20 @@ TEST_F(ProtobufHttpClientTest, SendRequestAndDecodeResponse) { ...@@ -129,16 +189,20 @@ TEST_F(ProtobufHttpClientTest, SendRequestAndDecodeResponse) {
// Respond. // Respond.
test_url_loader_factory_.AddResponse(kTestFullUrl, test_url_loader_factory_.AddResponse(kTestFullUrl,
CreateDefaultResponseContent()); CreateSerializedEchoResponse());
run_loop.Run(); run_loop.Run();
ASSERT_FALSE(client_.HasPendingRequests());
} }
TEST_F(ProtobufHttpClientTest, TEST_F(ProtobufHttpClientTest,
SendUnauthenticatedRequest_TokenGetterNotCalled) { SendUnauthenticatedRequest_TokenGetterNotCalled) {
EXPECT_CALL(mock_token_getter_, CallWithToken(_)).Times(0); EXPECT_CALL(mock_token_getter_, CallWithToken(_)).Times(0);
auto request = CreateDefaultTestRequest(); auto request_config = CreateDefaultRequestConfig();
request->authenticated = false; request_config->authenticated = false;
auto request =
std::make_unique<ProtobufHttpRequest>(std::move(request_config));
request->SetResponseCallback(DoNothingResponse());
client_.ExecuteRequest(std::move(request)); client_.ExecuteRequest(std::move(request));
// Verify that the request is sent with no auth header. // Verify that the request is sent with no auth header.
...@@ -153,9 +217,7 @@ TEST_F(ProtobufHttpClientTest, ...@@ -153,9 +217,7 @@ TEST_F(ProtobufHttpClientTest,
FailedToFetchAuthToken_RejectsWithUnauthorizedError) { FailedToFetchAuthToken_RejectsWithUnauthorizedError) {
base::RunLoop run_loop; base::RunLoop run_loop;
EXPECT_CALL(mock_token_getter_, CallWithToken(_)) ExpectCallWithToken(/* success= */ false);
.WillOnce(
RunOnceCallback<0>(OAuthTokenGetter::Status::AUTH_ERROR, "", ""));
MockEchoResponseCallback response_callback; MockEchoResponseCallback response_callback;
EXPECT_CALL(response_callback, EXPECT_CALL(response_callback,
...@@ -168,14 +230,13 @@ TEST_F(ProtobufHttpClientTest, ...@@ -168,14 +230,13 @@ TEST_F(ProtobufHttpClientTest,
client_.ExecuteRequest(std::move(request)); client_.ExecuteRequest(std::move(request));
run_loop.Run(); run_loop.Run();
ASSERT_FALSE(client_.HasPendingRequests());
} }
TEST_F(ProtobufHttpClientTest, FailedToParseResponse_GetsInvalidResponseError) { TEST_F(ProtobufHttpClientTest, FailedToParseResponse_GetsInvalidResponseError) {
base::RunLoop run_loop; base::RunLoop run_loop;
EXPECT_CALL(mock_token_getter_, CallWithToken(_)) ExpectCallWithToken(/* success= */ true);
.WillOnce(RunOnceCallback<0>(OAuthTokenGetter::Status::SUCCESS, "",
kFakeAccessToken));
MockEchoResponseCallback response_callback; MockEchoResponseCallback response_callback;
EXPECT_CALL( EXPECT_CALL(
...@@ -190,13 +251,13 @@ TEST_F(ProtobufHttpClientTest, FailedToParseResponse_GetsInvalidResponseError) { ...@@ -190,13 +251,13 @@ TEST_F(ProtobufHttpClientTest, FailedToParseResponse_GetsInvalidResponseError) {
// Respond. // Respond.
test_url_loader_factory_.AddResponse(kTestFullUrl, "Invalid content"); test_url_loader_factory_.AddResponse(kTestFullUrl, "Invalid content");
run_loop.Run(); run_loop.Run();
ASSERT_FALSE(client_.HasPendingRequests());
} }
TEST_F(ProtobufHttpClientTest, ServerRespondsWithError) { TEST_F(ProtobufHttpClientTest, ServerRespondsWithError) {
base::RunLoop run_loop; base::RunLoop run_loop;
EXPECT_CALL(mock_token_getter_, CallWithToken(_)) ExpectCallWithToken(/* success= */ true);
.WillOnce(RunOnceCallback<0>(OAuthTokenGetter::Status::SUCCESS, "", ""));
MockEchoResponseCallback response_callback; MockEchoResponseCallback response_callback;
EXPECT_CALL(response_callback, EXPECT_CALL(response_callback,
...@@ -211,9 +272,11 @@ TEST_F(ProtobufHttpClientTest, ServerRespondsWithError) { ...@@ -211,9 +272,11 @@ TEST_F(ProtobufHttpClientTest, ServerRespondsWithError) {
test_url_loader_factory_.AddResponse(kTestFullUrl, "", test_url_loader_factory_.AddResponse(kTestFullUrl, "",
net::HttpStatusCode::HTTP_UNAUTHORIZED); net::HttpStatusCode::HTTP_UNAUTHORIZED);
run_loop.Run(); run_loop.Run();
ASSERT_FALSE(client_.HasPendingRequests());
} }
TEST_F(ProtobufHttpClientTest, CancelPendingRequests_CallbackNotCalled) { TEST_F(ProtobufHttpClientTest,
CancelPendingRequestsBeforeTokenCallback_CallbackNotCalled) {
base::RunLoop run_loop; base::RunLoop run_loop;
OAuthTokenGetter::TokenCallback token_callback; OAuthTokenGetter::TokenCallback token_callback;
...@@ -222,7 +285,10 @@ TEST_F(ProtobufHttpClientTest, CancelPendingRequests_CallbackNotCalled) { ...@@ -222,7 +285,10 @@ TEST_F(ProtobufHttpClientTest, CancelPendingRequests_CallbackNotCalled) {
token_callback = std::move(callback); token_callback = std::move(callback);
}); });
MockEchoResponseCallback not_called_response_callback;
auto request = CreateDefaultTestRequest(); auto request = CreateDefaultTestRequest();
request->SetResponseCallback(not_called_response_callback.Get());
client_.ExecuteRequest(std::move(request)); client_.ExecuteRequest(std::move(request));
client_.CancelPendingRequests(); client_.CancelPendingRequests();
ASSERT_TRUE(token_callback); ASSERT_TRUE(token_callback);
...@@ -231,6 +297,179 @@ TEST_F(ProtobufHttpClientTest, CancelPendingRequests_CallbackNotCalled) { ...@@ -231,6 +297,179 @@ TEST_F(ProtobufHttpClientTest, CancelPendingRequests_CallbackNotCalled) {
// Verify no request. // Verify no request.
ASSERT_FALSE(test_url_loader_factory_.IsPending(kTestFullUrl)); ASSERT_FALSE(test_url_loader_factory_.IsPending(kTestFullUrl));
ASSERT_FALSE(client_.HasPendingRequests());
}
TEST_F(ProtobufHttpClientTest,
CancelPendingRequestsAfterTokenCallback_CallbackNotCalled) {
base::RunLoop run_loop;
ExpectCallWithToken(/* success= */ true);
client_.ExecuteRequest(CreateDefaultTestRequest());
// Respond.
ASSERT_TRUE(test_url_loader_factory_.IsPending(kTestFullUrl));
ASSERT_EQ(1, test_url_loader_factory_.NumPending());
client_.CancelPendingRequests();
test_url_loader_factory_.AddResponse(kTestFullUrl,
CreateSerializedEchoResponse());
run_loop.RunUntilIdle();
ASSERT_FALSE(client_.HasPendingRequests());
}
TEST_F(ProtobufHttpClientTest, RequestTimeout_ReturnsDeadlineExceeded) {
base::RunLoop run_loop;
ExpectCallWithToken(/* success= */ true);
MockEchoResponseCallback response_callback;
EXPECT_CALL(response_callback,
Run(HasErrorCode(ProtobufHttpStatus::Code::DEADLINE_EXCEEDED),
IsNullResponse()))
.WillOnce([&]() { run_loop.Quit(); });
auto request = CreateDefaultTestRequest();
request->SetTimeoutDuration(base::TimeDelta::FromSeconds(15));
request->SetResponseCallback(response_callback.Get());
client_.ExecuteRequest(std::move(request));
ASSERT_TRUE(test_url_loader_factory_.IsPending(kTestFullUrl));
ASSERT_EQ(1, test_url_loader_factory_.NumPending());
task_environment_.FastForwardBy(base::TimeDelta::FromSeconds(16));
run_loop.Run();
ASSERT_FALSE(client_.HasPendingRequests());
}
// Stream request tests.
TEST_F(ProtobufHttpClientTest, StartStreamRequestAndDecodeMessages) {
base::MockOnceClosure stream_ready_callback;
MockEchoMessageCallback message_callback;
MockStreamClosedCallback stream_closed_callback;
{
InSequence s;
ExpectCallWithToken(/* success= */ true);
EXPECT_CALL(stream_ready_callback, Run());
EXPECT_CALL(message_callback, Run(IsResponseText("response text 1")));
EXPECT_CALL(message_callback, Run(IsResponseText("response text 2")));
EXPECT_CALL(stream_closed_callback,
Run(HasErrorCode(ProtobufHttpStatus::Code::CANCELLED)));
}
auto request = CreateDefaultTestStreamRequest();
request->SetStreamReadyCallback(stream_ready_callback.Get());
request->SetMessageCallback(message_callback.Get());
request->SetStreamClosedCallback(stream_closed_callback.Get());
network::SimpleURLLoaderStreamConsumer* stream_consumer = request.get();
client_.ExecuteRequest(std::move(request));
ASSERT_TRUE(test_url_loader_factory_.IsPending(kTestFullUrl));
ASSERT_EQ(1, test_url_loader_factory_.NumPending());
// TestURLLoaderFactory can't simulate streaming, so we invoke the request
// directly.
stream_consumer->OnDataReceived(
CreateSerializedStreamBodyWithText("response text 1"),
base::DoNothing::Once());
stream_consumer->OnDataReceived(
CreateSerializedStreamBodyWithText("response text 2"),
base::DoNothing::Once());
stream_consumer->OnDataReceived(CreateSerializedStreamBodyWithStatusCode(
ProtobufHttpStatus::Code::CANCELLED),
base::DoNothing::Once());
ASSERT_FALSE(client_.HasPendingRequests());
}
TEST_F(ProtobufHttpClientTest, InvalidStreamData_Ignored) {
base::RunLoop run_loop;
base::MockOnceClosure stream_ready_callback;
MockEchoMessageCallback not_called_message_callback;
MockStreamClosedCallback stream_closed_callback;
{
InSequence s;
ExpectCallWithToken(/* success= */ true);
EXPECT_CALL(stream_ready_callback, Run());
EXPECT_CALL(stream_closed_callback,
Run(HasErrorCode(ProtobufHttpStatus::Code::OK)))
.WillOnce([&]() { run_loop.Quit(); });
}
auto request = CreateDefaultTestStreamRequest();
request->SetStreamReadyCallback(stream_ready_callback.Get());
request->SetMessageCallback(not_called_message_callback.Get());
request->SetStreamClosedCallback(stream_closed_callback.Get());
client_.ExecuteRequest(std::move(request));
ASSERT_TRUE(test_url_loader_factory_.IsPending(kTestFullUrl));
ASSERT_EQ(1, test_url_loader_factory_.NumPending());
test_url_loader_factory_.AddResponse(kTestFullUrl, "Invalid stream data",
net::HttpStatusCode::HTTP_OK);
run_loop.Run();
ASSERT_FALSE(client_.HasPendingRequests());
}
TEST_F(ProtobufHttpClientTest, SendHttpStatusOnly_StreamClosesWithHttpStatus) {
base::RunLoop run_loop;
base::MockOnceClosure stream_ready_callback;
MockStreamClosedCallback stream_closed_callback;
{
InSequence s;
ExpectCallWithToken(/* success= */ true);
EXPECT_CALL(stream_closed_callback,
Run(HasErrorCode(ProtobufHttpStatus::Code::UNAUTHENTICATED)))
.WillOnce([&]() { run_loop.Quit(); });
}
auto request = CreateDefaultTestStreamRequest();
request->SetStreamReadyCallback(stream_ready_callback.Get());
request->SetStreamClosedCallback(stream_closed_callback.Get());
client_.ExecuteRequest(std::move(request));
ASSERT_TRUE(test_url_loader_factory_.IsPending(kTestFullUrl));
ASSERT_EQ(1, test_url_loader_factory_.NumPending());
test_url_loader_factory_.AddResponse(kTestFullUrl, /* response_body= */ "",
net::HttpStatusCode::HTTP_UNAUTHORIZED);
run_loop.Run();
ASSERT_FALSE(client_.HasPendingRequests());
}
TEST_F(ProtobufHttpClientTest, SendStreamStatusAndHttpStatus_StreamStatusWins) {
base::RunLoop run_loop;
base::MockOnceClosure stream_ready_callback;
MockStreamClosedCallback stream_closed_callback;
{
InSequence s;
ExpectCallWithToken(/* success= */ true);
EXPECT_CALL(stream_ready_callback, Run());
EXPECT_CALL(stream_closed_callback,
Run(HasErrorCode(ProtobufHttpStatus::Code::CANCELLED)))
.WillOnce([&]() { run_loop.Quit(); });
}
auto request = CreateDefaultTestStreamRequest();
request->SetStreamReadyCallback(stream_ready_callback.Get());
request->SetStreamClosedCallback(stream_closed_callback.Get());
client_.ExecuteRequest(std::move(request));
ASSERT_TRUE(test_url_loader_factory_.IsPending(kTestFullUrl));
ASSERT_EQ(1, test_url_loader_factory_.NumPending());
test_url_loader_factory_.AddResponse(kTestFullUrl,
CreateSerializedStreamBodyWithStatusCode(
ProtobufHttpStatus::Code::CANCELLED),
net::HttpStatusCode::HTTP_OK);
run_loop.Run();
ASSERT_FALSE(client_.HasPendingRequests());
} }
} // namespace remoting } // namespace remoting
...@@ -4,19 +4,52 @@ ...@@ -4,19 +4,52 @@
#include "remoting/base/protobuf_http_request.h" #include "remoting/base/protobuf_http_request.h"
#include "remoting/base/protobuf_http_request_config.h"
#include "services/network/public/cpp/simple_url_loader.h"
#include "third_party/protobuf/src/google/protobuf/message_lite.h"
namespace remoting { namespace remoting {
namespace {
constexpr int kMaxResponseSizeKb = 512;
} // namespace
ProtobufHttpRequest::ProtobufHttpRequest( ProtobufHttpRequest::ProtobufHttpRequest(
const net::NetworkTrafficAnnotationTag& traffic_annotation) std::unique_ptr<ProtobufHttpRequestConfig> config)
: traffic_annotation(traffic_annotation) {} : ProtobufHttpRequestBase(std::move(config)) {}
ProtobufHttpRequest::~ProtobufHttpRequest() = default; ProtobufHttpRequest::~ProtobufHttpRequest() = default;
void ProtobufHttpRequest::SetTimeoutDuration(base::TimeDelta timeout_duration) {
timeout_duration_ = timeout_duration;
}
void ProtobufHttpRequest::OnAuthFailed(const ProtobufHttpStatus& status) {
std::move(response_callback_).Run(status);
}
void ProtobufHttpRequest::StartRequestInternal(
network::mojom::URLLoaderFactory* loader_factory) {
DCHECK(response_callback_);
// Safe to use unretained as callback will not be called once |url_loader_| is
// deleted.
url_loader_->DownloadToString(
loader_factory,
base::BindOnce(&ProtobufHttpRequest::OnResponse, base::Unretained(this)),
kMaxResponseSizeKb);
}
base::TimeDelta ProtobufHttpRequest::GetRequestTimeoutDuration() const {
return timeout_duration_;
}
void ProtobufHttpRequest::OnResponse( void ProtobufHttpRequest::OnResponse(
const ProtobufHttpStatus& status,
std::unique_ptr<std::string> response_body) { std::unique_ptr<std::string> response_body) {
ProtobufHttpStatus status = GetUrlLoaderStatus();
std::move(response_callback_) std::move(response_callback_)
.Run(status.ok() ? ParseResponse(std::move(response_body)) : status); .Run(status.ok() ? ParseResponse(std::move(response_body)) : status);
std::move(invalidator_).Run();
} }
ProtobufHttpStatus ProtobufHttpRequest::ParseResponse( ProtobufHttpStatus ProtobufHttpRequest::ParseResponse(
......
...@@ -5,35 +5,34 @@ ...@@ -5,35 +5,34 @@
#ifndef REMOTING_BASE_PROTOBUF_HTTP_REQUEST_H_ #ifndef REMOTING_BASE_PROTOBUF_HTTP_REQUEST_H_
#define REMOTING_BASE_PROTOBUF_HTTP_REQUEST_H_ #define REMOTING_BASE_PROTOBUF_HTTP_REQUEST_H_
#include <memory>
#include <string>
#include "base/bind.h" #include "base/bind.h"
#include "base/callback.h" #include "base/callback.h"
#include "base/time/time.h" #include "remoting/base/protobuf_http_request_base.h"
#include "net/traffic_annotation/network_traffic_annotation.h"
#include "remoting/base/protobuf_http_status.h" namespace google {
#include "third_party/protobuf/src/google/protobuf/message_lite.h" namespace protobuf {
class MessageLite;
} // namespace protobuf
} // namespace google
namespace remoting { namespace remoting {
// A simple unary request. Caller needs to set all public members and call // A simple unary request.
// SetResponseCallback() before passing it to ProtobufHttpClient. class ProtobufHttpRequest final : public ProtobufHttpRequestBase {
struct ProtobufHttpRequest final { public:
template <typename ResponseType> template <typename ResponseType>
using ResponseCallback = using ResponseCallback =
base::OnceCallback<void(const ProtobufHttpStatus& status, base::OnceCallback<void(const ProtobufHttpStatus& status,
std::unique_ptr<ResponseType> response)>; std::unique_ptr<ResponseType> response)>;
explicit ProtobufHttpRequest( explicit ProtobufHttpRequest(
const net::NetworkTrafficAnnotationTag& traffic_annotation); std::unique_ptr<ProtobufHttpRequestConfig> config);
~ProtobufHttpRequest(); ~ProtobufHttpRequest() override;
const net::NetworkTrafficAnnotationTag traffic_annotation; // Sets the amount of time to wait before giving up on a given network request
std::unique_ptr<google::protobuf::MessageLite> request_message; // and considering it an error. The default value is 30s. Set it to zero to
std::string path; // disable timeout.
bool authenticated = true; void SetTimeoutDuration(base::TimeDelta timeout_duration);
base::TimeDelta timeout_duration = base::TimeDelta::FromSeconds(30);
// Sets the response callback. |ResponseType| needs to be a protobuf message // Sets the response callback. |ResponseType| needs to be a protobuf message
// type. // type.
...@@ -54,15 +53,19 @@ struct ProtobufHttpRequest final { ...@@ -54,15 +53,19 @@ struct ProtobufHttpRequest final {
} }
private: private:
friend class ProtobufHttpClient; // ProtobufHttpRequestBase implementations.
void OnAuthFailed(const ProtobufHttpStatus& status) override;
void StartRequestInternal(
network::mojom::URLLoaderFactory* loader_factory) override;
base::TimeDelta GetRequestTimeoutDuration() const override;
// To be called by ProtobufHttpClient. void OnResponse(std::unique_ptr<std::string> response_body);
void OnResponse(const ProtobufHttpStatus& status,
std::unique_ptr<std::string> response_body);
// Parses |response_body| and writes it to |response_message_|. // Parses |response_body| and writes it to |response_message_|.
ProtobufHttpStatus ParseResponse(std::unique_ptr<std::string> response_body); ProtobufHttpStatus ParseResponse(std::unique_ptr<std::string> response_body);
base::TimeDelta timeout_duration_ = base::TimeDelta::FromSeconds(30);
// This is owned by |response_callback_|. // This is owned by |response_callback_|.
google::protobuf::MessageLite* response_message_; google::protobuf::MessageLite* response_message_;
base::OnceCallback<void(const ProtobufHttpStatus&)> response_callback_; base::OnceCallback<void(const ProtobufHttpStatus&)> response_callback_;
......
// Copyright 2020 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "remoting/base/protobuf_http_request_base.h"
#include "net/base/net_errors.h"
#include "remoting/base/protobuf_http_request_config.h"
#include "services/network/public/cpp/simple_url_loader.h"
namespace remoting {
ProtobufHttpRequestBase::ProtobufHttpRequestBase(
std::unique_ptr<ProtobufHttpRequestConfig> config)
: config_(std::move(config)) {
config_->Validate();
}
ProtobufHttpRequestBase::~ProtobufHttpRequestBase() {
#if DCHECK_IS_ON()
DCHECK(request_deadline_.is_null() ||
request_deadline_ >= base::TimeTicks::Now())
<< "The request must have been deleted before the deadline.";
#endif // DCHECK_IS_ON()
}
ProtobufHttpStatus ProtobufHttpRequestBase::GetUrlLoaderStatus() const {
net::Error net_error = static_cast<net::Error>(url_loader_->NetError());
if (net_error == net::Error::ERR_HTTP_RESPONSE_CODE_FAILURE &&
(!url_loader_->ResponseInfo() || !url_loader_->ResponseInfo()->headers)) {
LOG(ERROR) << "Can't find response header.";
net_error = net::Error::ERR_INVALID_RESPONSE;
}
return (net_error == net::Error::ERR_HTTP_RESPONSE_CODE_FAILURE ||
net_error == net::Error::OK)
? ProtobufHttpStatus(static_cast<net::HttpStatusCode>(
url_loader_->ResponseInfo()->headers->response_code()))
: ProtobufHttpStatus(net_error);
}
void ProtobufHttpRequestBase::StartRequest(
network::mojom::URLLoaderFactory* loader_factory,
std::unique_ptr<network::SimpleURLLoader> url_loader,
base::OnceClosure invalidator) {
DCHECK(!url_loader_);
DCHECK(!invalidator_);
url_loader_ = std::move(url_loader);
invalidator_ = std::move(invalidator);
StartRequestInternal(loader_factory);
#if DCHECK_IS_ON()
base::TimeDelta timeout_duration = GetRequestTimeoutDuration();
if (!timeout_duration.is_zero()) {
// Add a 500ms fuzz to account for task dispatching delay and other stuff.
request_deadline_ = base::TimeTicks::Now() + timeout_duration +
base::TimeDelta::FromMilliseconds(500);
}
#endif // DCHECK_IS_ON()
}
} // namespace remoting
// Copyright 2020 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef REMOTING_BASE_PROTOBUF_HTTP_REQUEST_BASE_H_
#define REMOTING_BASE_PROTOBUF_HTTP_REQUEST_BASE_H_
#include <memory>
#include <string>
#include "base/callback.h"
#include "base/dcheck_is_on.h"
#include "base/memory/weak_ptr.h"
#include "base/time/time.h"
#include "remoting/base/protobuf_http_status.h"
namespace network {
namespace mojom {
class URLLoaderFactory;
} // namespace mojom
class SimpleURLLoader;
} // namespace network
namespace remoting {
class ProtobufHttpClient;
struct ProtobufHttpRequestConfig;
// Base request class for unary and server streaming requests.
class ProtobufHttpRequestBase {
public:
explicit ProtobufHttpRequestBase(
std::unique_ptr<ProtobufHttpRequestConfig> config);
virtual ~ProtobufHttpRequestBase();
const ProtobufHttpRequestConfig& config() const { return *config_; }
protected:
virtual void OnAuthFailed(const ProtobufHttpStatus& status) = 0;
virtual void StartRequestInternal(
network::mojom::URLLoaderFactory* loader_factory) = 0;
// Returns a deadline for when the request has to be finished. Returns zero
// if the request doesn't timeout. This is generally only useful for unary
// requests.
virtual base::TimeDelta GetRequestTimeoutDuration() const = 0;
// Returns the http status from |url_loader_|. Only useful when |url_loader_|
// informs that the request has been completed.
ProtobufHttpStatus GetUrlLoaderStatus() const;
std::unique_ptr<network::SimpleURLLoader> url_loader_;
// Subclass should run this closure whenever its lifetime ends, e.g. response
// is received or stream is closed. This will delete |this| from the parent
// ProtobufHttpClient.
base::OnceClosure invalidator_;
private:
friend class ProtobufHttpClient;
// Called by ProtobufHttpClient.
void StartRequest(network::mojom::URLLoaderFactory* loader_factory,
std::unique_ptr<network::SimpleURLLoader> url_loader,
base::OnceClosure invalidator);
std::unique_ptr<ProtobufHttpRequestConfig> config_;
#if DCHECK_IS_ON()
base::TimeTicks request_deadline_;
#endif // DCHECK_IS_ON()
};
} // namespace remoting
#endif // REMOTING_BASE_PROTOBUF_HTTP_REQUEST_BASE_H_
// Copyright 2020 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "remoting/base/protobuf_http_request_config.h"
#include "third_party/protobuf/src/google/protobuf/message_lite.h"
namespace remoting {
ProtobufHttpRequestConfig::ProtobufHttpRequestConfig(
const net::NetworkTrafficAnnotationTag& traffic_annotation)
: traffic_annotation(traffic_annotation) {}
ProtobufHttpRequestConfig::~ProtobufHttpRequestConfig() = default;
void ProtobufHttpRequestConfig::Validate() const {
DCHECK(request_message);
DCHECK(!path.empty());
}
} // namespace remoting
// Copyright 2020 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef REMOTING_BASE_PROTOBUF_HTTP_REQUEST_CONFIG_H_
#define REMOTING_BASE_PROTOBUF_HTTP_REQUEST_CONFIG_H_
#include <memory>
#include <string>
#include "net/traffic_annotation/network_traffic_annotation.h"
namespace google {
namespace protobuf {
class MessageLite;
} // namespace protobuf
} // namespace google
namespace remoting {
// Common configurations for unary and stream protobuf http requests. Caller
// needs to set all fields in this struct.
struct ProtobufHttpRequestConfig {
explicit ProtobufHttpRequestConfig(
const net::NetworkTrafficAnnotationTag& traffic_annotation);
~ProtobufHttpRequestConfig();
// Runs DCHECK's on the fields to make sure all fields have been set up.
void Validate() const;
const net::NetworkTrafficAnnotationTag traffic_annotation;
std::unique_ptr<google::protobuf::MessageLite> request_message;
std::string path;
bool authenticated = true;
};
} // namespace remoting
#endif // REMOTING_BASE_PROTOBUF_HTTP_REQUEST_CONFIG_H_
// Copyright 2020 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "remoting/base/protobuf_http_stream_request.h"
#include "base/bind.h"
#include "base/logging.h"
#include "remoting/base/protobuf_http_client.h"
#include "remoting/base/protobuf_http_request_config.h"
#include "remoting/base/protobuf_http_status.h"
#include "remoting/base/protobuf_http_stream_parser.h"
#include "services/network/public/cpp/simple_url_loader.h"
#include "third_party/protobuf/src/google/protobuf/message_lite.h"
namespace remoting {
ProtobufHttpStreamRequest::ProtobufHttpStreamRequest(
std::unique_ptr<ProtobufHttpRequestConfig> config)
: ProtobufHttpRequestBase(std::move(config)) {}
ProtobufHttpStreamRequest::~ProtobufHttpStreamRequest() = default;
void ProtobufHttpStreamRequest::SetStreamReadyCallback(
base::OnceClosure callback) {
stream_ready_callback_ = std::move(callback);
}
void ProtobufHttpStreamRequest::SetStreamClosedCallback(
StreamClosedCallback callback) {
stream_closed_callback_ = std::move(callback);
}
void ProtobufHttpStreamRequest::OnMessage(const std::string& message) {
std::unique_ptr<google::protobuf::MessageLite> protobuf_message(
default_message_->New());
if (protobuf_message->ParseFromString(message)) {
message_callback_.Run(std::move(protobuf_message));
} else {
LOG(ERROR) << "Failed to parse a stream message.";
}
}
void ProtobufHttpStreamRequest::OnStreamClosed(
const ProtobufHttpStatus& status) {
DCHECK(stream_closed_callback_);
DCHECK(invalidator_);
std::move(stream_closed_callback_).Run(status);
std::move(invalidator_).Run();
}
void ProtobufHttpStreamRequest::OnAuthFailed(const ProtobufHttpStatus& status) {
OnStreamClosed(status);
}
void ProtobufHttpStreamRequest::StartRequestInternal(
network::mojom::URLLoaderFactory* loader_factory) {
DCHECK(default_message_);
DCHECK(stream_ready_callback_);
DCHECK(stream_closed_callback_);
DCHECK(message_callback_);
// Safe to use unretained, as callbacks won't be called after |stream_parser_|
// is deleted.
stream_parser_ = std::make_unique<ProtobufHttpStreamParser>(
base::BindRepeating(&ProtobufHttpStreamRequest::OnMessage,
base::Unretained(this)),
base::BindRepeating(&ProtobufHttpStreamRequest::OnStreamClosed,
base::Unretained(this)));
url_loader_->DownloadAsStream(loader_factory, this);
}
base::TimeDelta ProtobufHttpStreamRequest::GetRequestTimeoutDuration() const {
return base::TimeDelta();
}
void ProtobufHttpStreamRequest::OnDataReceived(base::StringPiece string_piece,
base::OnceClosure resume) {
if (stream_ready_callback_) {
std::move(stream_ready_callback_).Run();
}
DCHECK(stream_parser_);
stream_parser_->Append(string_piece);
std::move(resume).Run();
}
void ProtobufHttpStreamRequest::OnComplete(bool success) {
OnStreamClosed(success ? ProtobufHttpStatus::OK : GetUrlLoaderStatus());
}
void ProtobufHttpStreamRequest::OnRetry(base::OnceClosure start_retry) {
NOTIMPLEMENTED();
}
} // namespace remoting
// Copyright 2020 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef REMOTING_BASE_PROTOBUF_HTTP_STREAM_REQUEST_H_
#define REMOTING_BASE_PROTOBUF_HTTP_STREAM_REQUEST_H_
#include "base/bind.h"
#include "base/callback.h"
#include "base/memory/weak_ptr.h"
#include "remoting/base/protobuf_http_request_base.h"
#include "services/network/public/cpp/simple_url_loader_stream_consumer.h"
namespace google {
namespace protobuf {
class MessageLite;
} // namespace protobuf
} // namespace google
namespace remoting {
class ProtobufHttpClient;
class ProtobufHttpStatus;
class ProtobufHttpStreamParser;
// A server streaming request.
class ProtobufHttpStreamRequest final
: public ProtobufHttpRequestBase,
public network::SimpleURLLoaderStreamConsumer {
public:
template <typename MessageType>
using MessageCallback =
base::RepeatingCallback<void(std::unique_ptr<MessageType> message)>;
using StreamClosedCallback =
base::OnceCallback<void(const ProtobufHttpStatus& status)>;
explicit ProtobufHttpStreamRequest(
std::unique_ptr<ProtobufHttpRequestConfig> config);
~ProtobufHttpStreamRequest() override;
// Sets a callback that gets called when the stream is ready to receive data.
void SetStreamReadyCallback(base::OnceClosure callback);
// Sets a callback that gets called when the stream is closed.
void SetStreamClosedCallback(StreamClosedCallback callback);
// Sets the callback to be called every time a new message is received.
// |MessageType| needs to be a protobuf message type.
template <typename MessageType>
void SetMessageCallback(const MessageCallback<MessageType>& callback) {
default_message_ = &MessageType::default_instance();
message_callback_ = base::BindRepeating(
[](MessageCallback<MessageType> callback,
std::unique_ptr<google::protobuf::MessageLite> generic_message) {
std::move(callback).Run(std::unique_ptr<MessageType>(
static_cast<MessageType*>(generic_message.release())));
},
callback);
}
// TODO(yuweih): Consider adding an option to set timeout for
// |stream_ready_callback_|.
private:
friend class ProtobufHttpClient;
// ProtobufHttpStreamParser callbacks.
void OnMessage(const std::string& message);
void OnStreamClosed(const ProtobufHttpStatus& status);
// ProtobufHttpRequestBase implementations.
void OnAuthFailed(const ProtobufHttpStatus& status) override;
void StartRequestInternal(
network::mojom::URLLoaderFactory* loader_factory) override;
base::TimeDelta GetRequestTimeoutDuration() const override;
// network::SimpleURLLoaderStreamConsumer implementations.
void OnDataReceived(base::StringPiece string_piece,
base::OnceClosure resume) override;
void OnComplete(bool success) override;
void OnRetry(base::OnceClosure start_retry) override;
// Used to create new response message instances.
const google::protobuf::MessageLite* default_message_;
std::unique_ptr<ProtobufHttpStreamParser> stream_parser_;
base::OnceClosure stream_ready_callback_;
StreamClosedCallback stream_closed_callback_;
base::RepeatingCallback<void(std::unique_ptr<google::protobuf::MessageLite>)>
message_callback_;
};
} // namespace remoting
#endif // REMOTING_BASE_PROTOBUF_HTTP_STREAM_REQUEST_H_
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment