Commit d07eb404 authored by Roger Tawa's avatar Roger Tawa Committed by Commit Bot

Refactor BinaryUploadService::Request to make getting data size async.

Bug: 999143, 999141
Change-Id: I395da7e76649e46bea9cc69bdb0009168833f90e
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/1793822Reviewed-by: default avatarDaniel Rubery <drubery@chromium.org>
Commit-Queue: Roger Tawa <rogerta@chromium.org>
Cr-Commit-Position: refs/heads/master@{#695020}
parent 71943ac0
...@@ -25,7 +25,6 @@ ...@@ -25,7 +25,6 @@
namespace safe_browsing { namespace safe_browsing {
namespace { namespace {
const size_t kMaxUploadSizeBytes = 50 * 1024 * 1024; // 50 MB
const int kScanningTimeoutSeconds = 5 * 60; // 5 minutes const int kScanningTimeoutSeconds = 5 * 60; // 5 minutes
const char kSbBinaryUploadUrl[] = const char kSbBinaryUploadUrl[] =
"https://safebrowsing.google.com/safebrowsing/uploads/webprotect"; "https://safebrowsing.google.com/safebrowsing/uploads/webprotect";
...@@ -55,12 +54,6 @@ void BinaryUploadService::UploadForDeepScanning( ...@@ -55,12 +54,6 @@ void BinaryUploadService::UploadForDeepScanning(
Request* raw_request = request.get(); Request* raw_request = request.get();
active_requests_[raw_request] = std::move(request); active_requests_[raw_request] = std::move(request);
if (raw_request->GetFileSize() > kMaxUploadSizeBytes) {
FinishRequest(raw_request, Result::FILE_TOO_LARGE,
DeepScanningClientResponse());
return;
}
if (!binary_fcm_service_) { if (!binary_fcm_service_) {
FinishRequest(raw_request, Result::FAILED_TO_GET_TOKEN, FinishRequest(raw_request, Result::FAILED_TO_GET_TOKEN,
DeepScanningClientResponse()); DeepScanningClientResponse());
...@@ -96,16 +89,22 @@ void BinaryUploadService::OnGetInstanceID(Request* request, ...@@ -96,16 +89,22 @@ void BinaryUploadService::OnGetInstanceID(Request* request,
} }
request->set_fcm_token(instance_id); request->set_fcm_token(instance_id);
request->GetFileContents( request->GetRequestData(base::BindOnce(&BinaryUploadService::OnGetRequestData,
base::BindOnce(&BinaryUploadService::OnGetFileContents, weakptr_factory_.GetWeakPtr(),
weakptr_factory_.GetWeakPtr(), request)); request));
} }
void BinaryUploadService::OnGetFileContents(Request* request, void BinaryUploadService::OnGetRequestData(Request* request,
const std::string& file_contents) { Result result,
const Request::Data& data) {
if (!IsActive(request)) if (!IsActive(request))
return; return;
if (result != Result::SUCCESS) {
FinishRequest(request, result, DeepScanningClientResponse());
return;
}
net::NetworkTrafficAnnotationTag traffic_annotation = net::NetworkTrafficAnnotationTag traffic_annotation =
net::DefineNetworkTrafficAnnotation("safe_browsing_binary_upload", R"( net::DefineNetworkTrafficAnnotation("safe_browsing_binary_upload", R"(
semantics { semantics {
...@@ -149,7 +148,7 @@ void BinaryUploadService::OnGetFileContents(Request* request, ...@@ -149,7 +148,7 @@ void BinaryUploadService::OnGetFileContents(Request* request,
base::Base64Encode(metadata, &metadata); base::Base64Encode(metadata, &metadata);
auto upload_request = MultipartUploadRequest::Create( auto upload_request = MultipartUploadRequest::Create(
url_loader_factory_, GURL(kSbBinaryUploadUrl), metadata, file_contents, url_loader_factory_, GURL(kSbBinaryUploadUrl), metadata, data.contents,
traffic_annotation, traffic_annotation,
base::BindOnce(&BinaryUploadService::OnUploadComplete, base::BindOnce(&BinaryUploadService::OnUploadComplete,
weakptr_factory_.GetWeakPtr(), request)); weakptr_factory_.GetWeakPtr(), request));
...@@ -255,6 +254,8 @@ void BinaryUploadService::FinishRequest(Request* request, ...@@ -255,6 +254,8 @@ void BinaryUploadService::FinishRequest(Request* request,
} }
} }
BinaryUploadService::Request::Data::Data() = default;
BinaryUploadService::Request::Request(Callback callback) BinaryUploadService::Request::Request(Callback callback)
: callback_(std::move(callback)) {} : callback_(std::move(callback)) {}
......
...@@ -23,6 +23,9 @@ namespace safe_browsing { ...@@ -23,6 +23,9 @@ namespace safe_browsing {
// and asynchronously retrieving a verdict. // and asynchronously retrieving a verdict.
class BinaryUploadService { class BinaryUploadService {
public: public:
// The maximum size of data that can be uploaded via this service.
constexpr static size_t kMaxUploadSizeBytes = 50 * 1024 * 1024; // 50 MB
BinaryUploadService( BinaryUploadService(
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory, scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
Profile* profile); Profile* profile);
...@@ -72,16 +75,22 @@ class BinaryUploadService { ...@@ -72,16 +75,22 @@ class BinaryUploadService {
Request(Request&&) = delete; Request(Request&&) = delete;
Request& operator=(Request&&) = delete; Request& operator=(Request&&) = delete;
// Asynchronously returns the file contents to upload. // Structure of data returned in the callback to GetRequestData().
// TODO(drubery): This could allocate up to 50MB of memory for a large file struct Data {
// upload. We should see how often that causes errors, and possibly Data();
// implement some sort of streaming interface so we don't use so much std::string contents;
// memory. };
virtual void GetFileContents(
base::OnceCallback<void(const std::string&)> callback) = 0;
// Returns the content size. // Asynchronously returns the file contents to upload.
virtual size_t GetFileSize() = 0; // TODO(drubery): This could allocate up to kMaxUploadSizeBytes of memory
// for a large file upload. We should see how often that causes errors,
// and possibly implement some sort of streaming interface so we don't use
// so much memory.
//
// |result| is set to SUCCESS if getting the request data succeeded or
// some value describing the error.
using DataCallback = base::OnceCallback<void(Result, const Data&)>;
virtual void GetRequestData(DataCallback callback) = 0;
// Returns the metadata to upload, as a DeepScanningClientRequest. // Returns the metadata to upload, as a DeepScanningClientRequest.
const DeepScanningClientRequest& deep_scanning_request() const { const DeepScanningClientRequest& deep_scanning_request() const {
...@@ -119,7 +128,9 @@ class BinaryUploadService { ...@@ -119,7 +128,9 @@ class BinaryUploadService {
void OnGetInstanceID(Request* request, const std::string& token); void OnGetInstanceID(Request* request, const std::string& token);
void OnGetFileContents(Request* request, const std::string& file_contents); void OnGetRequestData(Request* request,
Result result,
const Request::Data& data);
void OnUploadComplete(Request* request, void OnUploadComplete(Request* request,
bool success, bool success,
......
...@@ -28,9 +28,7 @@ class MockRequest : public BinaryUploadService::Request { ...@@ -28,9 +28,7 @@ class MockRequest : public BinaryUploadService::Request {
public: public:
explicit MockRequest(BinaryUploadService::Callback callback) explicit MockRequest(BinaryUploadService::Callback callback)
: BinaryUploadService::Request(std::move(callback)) {} : BinaryUploadService::Request(std::move(callback)) {}
MOCK_METHOD1(GetFileContents, MOCK_METHOD1(GetRequestData, void(DataCallback));
void(base::OnceCallback<void(const std::string&)>));
MOCK_METHOD0(GetFileSize, size_t());
}; };
class FakeMultipartUploadRequest : public MultipartUploadRequest { class FakeMultipartUploadRequest : public MultipartUploadRequest {
...@@ -143,11 +141,13 @@ class BinaryUploadServiceTest : public testing::Test { ...@@ -143,11 +141,13 @@ class BinaryUploadServiceTest : public testing::Test {
*target_response = response; *target_response = response;
}, },
scanning_result, scanning_response)); scanning_result, scanning_response));
ON_CALL(*request, GetFileSize()).WillByDefault(Return(strlen("contents"))); ON_CALL(*request, GetRequestData(_))
ON_CALL(*request, GetFileContents(_))
.WillByDefault( .WillByDefault(
Invoke([](base::OnceCallback<void(const std::string&)> callback) { Invoke([](BinaryUploadService::Request::DataCallback callback) {
std::move(callback).Run("contents"); BinaryUploadService::Request::Data data;
data.contents = "contents";
std::move(callback).Run(BinaryUploadService::Result::SUCCESS,
data);
})); }));
return request; return request;
} }
...@@ -163,9 +163,16 @@ TEST_F(BinaryUploadServiceTest, FailsForLargeFile) { ...@@ -163,9 +163,16 @@ TEST_F(BinaryUploadServiceTest, FailsForLargeFile) {
BinaryUploadService::Result scanning_result; BinaryUploadService::Result scanning_result;
DeepScanningClientResponse scanning_response; DeepScanningClientResponse scanning_response;
ExpectInstanceID("valid id");
std::unique_ptr<MockRequest> request = std::unique_ptr<MockRequest> request =
MakeRequest(&scanning_result, &scanning_response); MakeRequest(&scanning_result, &scanning_response);
ON_CALL(*request, GetFileSize()).WillByDefault(Return(100 * 1024 * 1024)); ON_CALL(*request, GetRequestData(_))
.WillByDefault(
Invoke([](BinaryUploadService::Request::DataCallback callback) {
BinaryUploadService::Request::Data data;
std::move(callback).Run(BinaryUploadService::Result::FILE_TOO_LARGE,
data);
}));
service_->UploadForDeepScanning(std::move(request)); service_->UploadForDeepScanning(std::move(request));
content::RunAllTasksUntilIdle(); content::RunAllTasksUntilIdle();
......
...@@ -43,7 +43,6 @@ DownloadItemRequest::DownloadItemRequest(download::DownloadItem* item, ...@@ -43,7 +43,6 @@ DownloadItemRequest::DownloadItemRequest(download::DownloadItem* item,
BinaryUploadService::Callback callback) BinaryUploadService::Callback callback)
: Request(std::move(callback)), : Request(std::move(callback)),
item_(item), item_(item),
download_item_renamed_(false),
weakptr_factory_(this) { weakptr_factory_(this) {
item_->AddObserver(this); item_->AddObserver(this);
} }
...@@ -53,43 +52,39 @@ DownloadItemRequest::~DownloadItemRequest() { ...@@ -53,43 +52,39 @@ DownloadItemRequest::~DownloadItemRequest() {
item_->RemoveObserver(this); item_->RemoveObserver(this);
} }
void DownloadItemRequest::GetFileContents( void DownloadItemRequest::GetRequestData(DataCallback callback) {
base::OnceCallback<void(const std::string&)> callback) {
if (item_ == nullptr) { if (item_ == nullptr) {
std::move(callback).Run(""); std::move(callback).Run(BinaryUploadService::Result::UNKNOWN, Data());
return; return;
} }
pending_callbacks_.push_back(std::move(callback)); if (static_cast<size_t>(item_->GetTotalBytes()) >
BinaryUploadService::kMaxUploadSizeBytes) {
std::move(callback).Run(BinaryUploadService::Result::FILE_TOO_LARGE,
Data());
return;
}
if (is_data_valid_) {
std::move(callback).Run(BinaryUploadService::Result::SUCCESS, data_);
return;
}
if (download_item_renamed_) pending_callbacks_.push_back(std::move(callback));
RunPendingGetFileContentsCallbacks();
} }
void DownloadItemRequest::RunPendingGetFileContentsCallbacks() { void DownloadItemRequest::RunPendingGetFileContentsCallbacks() {
for (auto it = pending_callbacks_.begin(); it != pending_callbacks_.end(); for (auto it = pending_callbacks_.begin(); it != pending_callbacks_.end();
it++) { it++) {
base::PostTaskAndReplyWithResult( std::move(*it).Run(BinaryUploadService::Result::SUCCESS, data_);
FROM_HERE,
{base::ThreadPool(), base::TaskPriority::USER_VISIBLE,
base::MayBlock()},
base::BindOnce(&GetFileContentsBlocking, item_->GetFullPath()),
base::BindOnce(&DownloadItemRequest::OnGotFileContents,
weakptr_factory_.GetWeakPtr(), std::move(*it)));
} }
pending_callbacks_.clear(); pending_callbacks_.clear();
} }
size_t DownloadItemRequest::GetFileSize() {
return item_ == nullptr ? 0 : item_->GetTotalBytes();
}
void DownloadItemRequest::OnDownloadUpdated(download::DownloadItem* download) { void DownloadItemRequest::OnDownloadUpdated(download::DownloadItem* download) {
if (download == item_ && item_->GetFullPath() == item_->GetTargetFilePath()) { if (download == item_ && item_->GetFullPath() == item_->GetTargetFilePath())
download_item_renamed_ = true; ReadFile();
RunPendingGetFileContentsCallbacks();
}
} }
void DownloadItemRequest::OnDownloadDestroyed( void DownloadItemRequest::OnDownloadDestroyed(
...@@ -98,10 +93,19 @@ void DownloadItemRequest::OnDownloadDestroyed( ...@@ -98,10 +93,19 @@ void DownloadItemRequest::OnDownloadDestroyed(
item_ = nullptr; item_ = nullptr;
} }
void DownloadItemRequest::OnGotFileContents( void DownloadItemRequest::ReadFile() {
base::OnceCallback<void(const std::string&)> callback, base::PostTaskAndReplyWithResult(
const std::string& contents) { FROM_HERE,
std::move(callback).Run(contents); {base::ThreadPool(), base::TaskPriority::USER_VISIBLE, base::MayBlock()},
base::BindOnce(&GetFileContentsBlocking, item_->GetFullPath()),
base::BindOnce(&DownloadItemRequest::OnGotFileContents,
weakptr_factory_.GetWeakPtr()));
}
void DownloadItemRequest::OnGotFileContents(std::string contents) {
data_.contents = std::move(contents);
is_data_valid_ = true;
RunPendingGetFileContentsCallbacks();
} }
} // namespace safe_browsing } // namespace safe_browsing
...@@ -27,17 +27,16 @@ class DownloadItemRequest : public BinaryUploadService::Request, ...@@ -27,17 +27,16 @@ class DownloadItemRequest : public BinaryUploadService::Request,
DownloadItemRequest& operator=(DownloadItemRequest&&) = delete; DownloadItemRequest& operator=(DownloadItemRequest&&) = delete;
// BinaryUploadService::Request implementation. // BinaryUploadService::Request implementation.
void GetFileContents( void GetRequestData(DataCallback callback) override;
base::OnceCallback<void(const std::string&)> callback) override;
size_t GetFileSize() override;
// download::DownloadItem::Observer implementation. // download::DownloadItem::Observer implementation.
void OnDownloadDestroyed(download::DownloadItem* download) override; void OnDownloadDestroyed(download::DownloadItem* download) override;
void OnDownloadUpdated(download::DownloadItem* download) override; void OnDownloadUpdated(download::DownloadItem* download) override;
private: private:
void OnGotFileContents(base::OnceCallback<void(const std::string&)> callback, void ReadFile();
const std::string& contents);
void OnGotFileContents(std::string contents);
// Calls to GetFileContents can be deferred if the download item is not yet // Calls to GetFileContents can be deferred if the download item is not yet
// renamed to its final location. When ready, this method runs those // renamed to its final location. When ready, this method runs those
...@@ -48,11 +47,15 @@ class DownloadItemRequest : public BinaryUploadService::Request, ...@@ -48,11 +47,15 @@ class DownloadItemRequest : public BinaryUploadService::Request,
// thread. Unowned. // thread. Unowned.
download::DownloadItem* item_; download::DownloadItem* item_;
// Whether the download item has been renamed to its final destination yet. // The file's data.
bool download_item_renamed_; Data data_;
// Is the |data_| member valid? It becomes valid once the file has been
// read successfully.
bool is_data_valid_ = false;
// All pending callbacks to GetFileContents before the download item is ready. // All pending callbacks to GetFileContents before the download item is ready.
std::vector<base::OnceCallback<void(const std::string&)>> pending_callbacks_; std::vector<DataCallback> pending_callbacks_;
base::WeakPtrFactory<DownloadItemRequest> weakptr_factory_; base::WeakPtrFactory<DownloadItemRequest> weakptr_factory_;
}; };
......
...@@ -51,18 +51,15 @@ class DownloadItemRequestTest : public ::testing::Test { ...@@ -51,18 +51,15 @@ class DownloadItemRequestTest : public ::testing::Test {
std::string download_contents_; std::string download_contents_;
}; };
TEST_F(DownloadItemRequestTest, GetsSize) {
EXPECT_EQ(request_.GetFileSize(), download_contents_.size());
}
TEST_F(DownloadItemRequestTest, GetsContentsWaitsUntilRename) { TEST_F(DownloadItemRequestTest, GetsContentsWaitsUntilRename) {
ON_CALL(item_, GetFullPath()) ON_CALL(item_, GetFullPath())
.WillByDefault(ReturnRef(download_temporary_path_)); .WillByDefault(ReturnRef(download_temporary_path_));
std::string download_contents = ""; std::string download_contents = "";
request_.GetFileContents(base::BindOnce( request_.GetRequestData(base::BindOnce(
[](std::string* target_contents, const std::string& contents) { [](std::string* target_contents, BinaryUploadService::Result result,
*target_contents = contents; const BinaryUploadService::Request::Data& data) {
*target_contents = data.contents;
}, },
&download_contents)); &download_contents));
content::RunAllTasksUntilIdle(); content::RunAllTasksUntilIdle();
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment