Commit 40c999fb authored by John Abd-El-Malek's avatar John Abd-El-Malek Committed by Commit Bot

Convert Download Protection to use SimpleURLLoader instead of URLFetcher.

Bug: 825242
Cq-Include-Trybots: master.tryserver.chromium.linux:linux_mojo
Change-Id: I21ab64e7570f7600031dbf566fc46214051bbc67
Reviewed-on: https://chromium-review.googlesource.com/989034Reviewed-by: default avatarVarun Khaneja <vakh@chromium.org>
Commit-Queue: John Abd-El-Malek <jam@chromium.org>
Cr-Commit-Position: refs/heads/master@{#547787}
parent 98db6a59
specific_include_rules = {
".*test\.cc": [
"+services/network/network_context.h",
]
}
......@@ -28,6 +28,8 @@
#include "content/public/common/service_manager_connection.h"
#include "net/http/http_cache.h"
#include "net/http/http_status_code.h"
#include "services/network/public/cpp/shared_url_loader_factory.h"
#include "services/network/public/cpp/simple_url_loader.h"
namespace safe_browsing {
......@@ -163,13 +165,13 @@ void CheckClientDownloadRequest::StartTimeout() {
void CheckClientDownloadRequest::Cancel() {
DCHECK_CURRENTLY_ON(BrowserThread::UI);
cancelable_task_tracker_.TryCancelAll();
if (fetcher_.get()) {
if (loader_.get()) {
// The DownloadProtectionService is going to release its reference, so we
// might be destroyed before the URLFetcher completes. Cancel the
// fetcher so it does not try to invoke OnURLFetchComplete.
fetcher_.reset();
// might be destroyed before the URLLoader completes. Cancel the
// loader so it does not try to invoke OnURLFetchComplete.
loader_.reset();
}
// Note: If there is no fetcher, then some callback is still holding a
// Note: If there is no loader, then some callback is still holding a
// reference to this object. We'll eventually wind up in some method on
// the UI thread that will call FinishRequest() again. If FinishRequest()
// is called a second time, it will be a no-op.
......@@ -186,31 +188,28 @@ void CheckClientDownloadRequest::OnDownloadDestroyed(
}
// TODO: this method puts "DownloadProtectionService::" in front of a lot of
// stuff to avoid referencing the enums i copied to this .h file. From the
// net::URLFetcherDelegate interface.
void CheckClientDownloadRequest::OnURLFetchComplete(
const net::URLFetcher* source) {
// stuff to avoid referencing the enums i copied to this .h file.
void CheckClientDownloadRequest::OnURLLoaderComplete(
std::unique_ptr<std::string> response_body) {
DCHECK_CURRENTLY_ON(BrowserThread::UI);
DCHECK_EQ(source, fetcher_.get());
bool success = loader_->NetError() == net::OK;
int response_code = 0;
if (loader_->ResponseInfo() && loader_->ResponseInfo()->headers)
response_code = loader_->ResponseInfo()->headers->response_code();
DVLOG(2) << "Received a response for URL: " << item_->GetUrlChain().back()
<< ": success=" << source->GetStatus().is_success()
<< " response_code=" << source->GetResponseCode();
if (source->GetStatus().is_success()) {
<< ": success=" << success << " response_code=" << response_code;
if (success) {
base::UmaHistogramSparse("SBClientDownload.DownloadRequestResponseCode",
source->GetResponseCode());
response_code);
}
base::UmaHistogramSparse("SBClientDownload.DownloadRequestNetError",
-source->GetStatus().error());
-loader_->NetError());
DownloadCheckResultReason reason = REASON_SERVER_PING_FAILED;
DownloadCheckResult result = DownloadCheckResult::UNKNOWN;
std::string token;
if (source->GetStatus().is_success() &&
net::HTTP_OK == source->GetResponseCode()) {
if (success && net::HTTP_OK == response_code) {
ClientDownloadResponse response;
std::string data;
bool got_data = source->GetResponseAsString(&data);
DCHECK(got_data);
if (!response.ParseFromString(data)) {
if (!response.ParseFromString(*response_body.get())) {
reason = REASON_INVALID_RESPONSE_PROTO;
result = DownloadCheckResult::UNKNOWN;
} else if (type_ == ClientDownloadRequest::SAMPLED_UNSUPPORTED_FILE) {
......@@ -260,10 +259,11 @@ void CheckClientDownloadRequest::OnURLFetchComplete(
bool upload_requested = response.upload();
DownloadFeedbackService::MaybeStorePingsForDownload(
result, upload_requested, item_, client_download_request_data_, data);
result, upload_requested, item_, client_download_request_data_,
*response_body.get());
}
// We don't need the fetcher anymore.
fetcher_.reset();
// We don't need the loader anymore.
loader_.reset();
UMA_HISTOGRAM_TIMES("SBClientDownload.DownloadRequestDuration",
base::TimeTicks::Now() - start_time_);
UMA_HISTOGRAM_TIMES("SBClientDownload.DownloadRequestNetworkDuration",
......@@ -772,7 +772,7 @@ void CheckClientDownloadRequest::CheckCertificateChainAgainstWhitelist() {
return;
}
// The URLFetcher is owned by the UI thread, so post a message to
// The URLLoader is owned by the UI thread, so post a message to
// start the pingback.
BrowserThread::PostTask(
BrowserThread::UI, FROM_HERE,
......@@ -1004,20 +1004,21 @@ void CheckClientDownloadRequest::SendRequest() {
}
}
})");
fetcher_ =
net::URLFetcher::Create(0, PPAPIDownloadRequest::GetDownloadRequestUrl(),
net::URLFetcher::POST, this, traffic_annotation);
data_use_measurement::DataUseUserData::AttachToFetcher(
fetcher_.get(), data_use_measurement::DataUseUserData::SAFE_BROWSING);
fetcher_->SetLoadFlags(net::LOAD_DISABLE_CACHE);
fetcher_->SetAutomaticallyRetryOn5xx(false); // Don't retry on error.
fetcher_->SetRequestContext(service_->request_context_getter_.get());
fetcher_->SetUploadData("application/octet-stream",
client_download_request_data_);
auto resource_request = std::make_unique<network::ResourceRequest>();
resource_request->url = PPAPIDownloadRequest::GetDownloadRequestUrl();
resource_request->method = "POST";
resource_request->load_flags = net::LOAD_DISABLE_CACHE;
loader_ = network::SimpleURLLoader::Create(std::move(resource_request),
traffic_annotation);
loader_->AttachStringForUpload(client_download_request_data_,
"application/octet-stream");
loader_->DownloadToStringOfUnboundedSizeUntilCrashAndDie(
service_->url_loader_factory_.get(),
base::BindOnce(&CheckClientDownloadRequest::OnURLLoaderComplete,
base::Unretained(this)));
request_start_time_ = base::TimeTicks::Now();
UMA_HISTOGRAM_COUNTS("SBClientDownload.DownloadRequestPayloadSize",
client_download_request_data_.size());
fetcher_->Start();
}
void CheckClientDownloadRequest::PostFinishTask(
......
......@@ -28,9 +28,6 @@
#include "components/history/core/browser/history_service.h"
#include "components/safe_browsing/db/database_manager.h"
#include "content/public/browser/browser_thread.h"
#include "net/url_request/url_fetcher.h"
#include "net/url_request/url_fetcher_delegate.h"
#include "net/url_request/url_request_context_getter.h"
#include "url/gurl.h"
#if defined(OS_MACOSX)
......@@ -40,12 +37,15 @@
using content::BrowserThread;
namespace network {
class SimpleURLLoader;
}
namespace safe_browsing {
class CheckClientDownloadRequest
: public base::RefCountedThreadSafe<CheckClientDownloadRequest,
BrowserThread::DeleteOnUIThread>,
public net::URLFetcherDelegate,
public download::DownloadItem::Observer {
public:
CheckClientDownloadRequest(
......@@ -59,7 +59,7 @@ class CheckClientDownloadRequest
void StartTimeout();
void Cancel();
void OnDownloadDestroyed(download::DownloadItem* download) override;
void OnURLFetchComplete(const net::URLFetcher* source) override;
void OnURLLoaderComplete(std::unique_ptr<std::string> response_body);
static bool IsSupportedDownload(const download::DownloadItem& item,
const base::FilePath& target_path,
DownloadCheckResultReason* reason,
......@@ -141,7 +141,7 @@ class CheckClientDownloadRequest
scoped_refptr<BinaryFeatureExtractor> binary_feature_extractor_;
scoped_refptr<SafeBrowsingDatabaseManager> database_manager_;
const bool pingback_enabled_;
std::unique_ptr<net::URLFetcher> fetcher_;
std::unique_ptr<network::SimpleURLLoader> loader_;
scoped_refptr<SandboxedRarAnalyzer> rar_analyzer_;
scoped_refptr<SandboxedZipAnalyzer> zip_analyzer_;
base::TimeTicks rar_analysis_start_time_;
......
......@@ -13,6 +13,7 @@
#include "components/safe_browsing/proto/csd.pb.h"
#include "content/public/browser/browser_thread.h"
#include "net/base/net_errors.h"
#include "services/network/public/cpp/shared_url_loader_factory.h"
namespace safe_browsing {
......@@ -36,11 +37,12 @@ enum UploadResultType {
// download feedback service.
class DownloadFeedbackImpl : public DownloadFeedback {
public:
DownloadFeedbackImpl(net::URLRequestContextGetter* request_context_getter,
base::TaskRunner* file_task_runner,
const base::FilePath& file_path,
const std::string& ping_request,
const std::string& ping_response);
DownloadFeedbackImpl(
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
base::TaskRunner* file_task_runner,
const base::FilePath& file_path,
const std::string& ping_request,
const std::string& ping_response);
~DownloadFeedbackImpl() override;
void Start(const base::Closure& finish_callback) override;
......@@ -64,7 +66,7 @@ class DownloadFeedbackImpl : public DownloadFeedback {
void RecordUploadResult(UploadResultType result);
scoped_refptr<net::URLRequestContextGetter> request_context_getter_;
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory_;
scoped_refptr<base::TaskRunner> file_task_runner_;
const base::FilePath file_path_;
int64_t file_size_;
......@@ -80,12 +82,12 @@ class DownloadFeedbackImpl : public DownloadFeedback {
};
DownloadFeedbackImpl::DownloadFeedbackImpl(
net::URLRequestContextGetter* request_context_getter,
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
base::TaskRunner* file_task_runner,
const base::FilePath& file_path,
const std::string& ping_request,
const std::string& ping_response)
: request_context_getter_(request_context_getter),
: url_loader_factory_(url_loader_factory),
file_task_runner_(file_task_runner),
file_path_(file_path),
file_size_(-1),
......@@ -164,9 +166,8 @@ void DownloadFeedbackImpl::Start(const base::Closure& finish_callback) {
})");
uploader_ = TwoPhaseUploader::Create(
request_context_getter_.get(), file_task_runner_.get(),
GURL(kSbFeedbackURL), metadata_string, file_path_,
TwoPhaseUploader::ProgressCallback(),
url_loader_factory_, file_task_runner_.get(), GURL(kSbFeedbackURL),
metadata_string, file_path_,
base::Bind(&DownloadFeedbackImpl::FinishedUpload, base::Unretained(this),
finish_callback),
traffic_annotation);
......@@ -237,18 +238,18 @@ DownloadFeedbackFactory* DownloadFeedback::factory_ = nullptr;
// static
std::unique_ptr<DownloadFeedback> DownloadFeedback::Create(
net::URLRequestContextGetter* request_context_getter,
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
base::TaskRunner* file_task_runner,
const base::FilePath& file_path,
const std::string& ping_request,
const std::string& ping_response) {
if (!factory_) {
return base::WrapUnique(
new DownloadFeedbackImpl(request_context_getter, file_task_runner,
new DownloadFeedbackImpl(url_loader_factory, file_task_runner,
file_path, ping_request, ping_response));
}
return DownloadFeedback::factory_->CreateDownloadFeedback(
request_context_getter, file_task_runner, file_path, ping_request,
url_loader_factory, file_task_runner, file_path, ping_request,
ping_response);
}
......
......@@ -24,7 +24,7 @@ class DownloadFeedback {
// Takes ownership of the file pointed to be |file_path|, it will be deleted
// when the DownloadFeedback is destructed.
static std::unique_ptr<DownloadFeedback> Create(
net::URLRequestContextGetter* request_context_getter,
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
base::TaskRunner* file_task_runner,
const base::FilePath& file_path,
const std::string& ping_request,
......@@ -66,7 +66,7 @@ class DownloadFeedbackFactory {
virtual ~DownloadFeedbackFactory() {}
virtual std::unique_ptr<DownloadFeedback> CreateDownloadFeedback(
net::URLRequestContextGetter* request_context_getter,
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
base::TaskRunner* file_task_runner,
const base::FilePath& file_path,
const std::string& ping_request,
......
......@@ -15,6 +15,7 @@
#include "chrome/browser/safe_browsing/download_protection/download_feedback.h"
#include "components/download/public/common/download_item.h"
#include "content/public/browser/browser_thread.h"
#include "services/network/public/cpp/shared_url_loader_factory.h"
namespace safe_browsing {
......@@ -68,9 +69,9 @@ DownloadFeedbackPings* DownloadFeedbackPings::FromDownload(
} // namespace
DownloadFeedbackService::DownloadFeedbackService(
net::URLRequestContextGetter* request_context_getter,
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
base::TaskRunner* file_task_runner)
: request_context_getter_(request_context_getter),
: url_loader_factory_(url_loader_factory),
file_task_runner_(file_task_runner),
weak_ptr_factory_(this) {
DCHECK_CURRENTLY_ON(content::BrowserThread::UI);
......@@ -182,9 +183,9 @@ void DownloadFeedbackService::BeginFeedback(const std::string& ping_request,
const std::string& ping_response,
const base::FilePath& path) {
DCHECK_CURRENTLY_ON(content::BrowserThread::UI);
std::unique_ptr<DownloadFeedback> feedback(DownloadFeedback::Create(
request_context_getter_.get(), file_task_runner_.get(), path,
ping_request, ping_response));
std::unique_ptr<DownloadFeedback> feedback(
DownloadFeedback::Create(url_loader_factory_, file_task_runner_.get(),
path, ping_request, ping_response));
active_feedback_.push(std::move(feedback));
UMA_HISTOGRAM_COUNTS_100("SBDownloadFeedback.ActiveFeedbacks",
active_feedback_.size());
......
......@@ -24,8 +24,8 @@ namespace download {
class DownloadItem;
}
namespace net {
class URLRequestContextGetter;
namespace network {
class SharedURLLoaderFactory;
}
namespace safe_browsing {
......@@ -37,8 +37,9 @@ class DownloadFeedback;
// Lives on the UI thread.
class DownloadFeedbackService {
public:
DownloadFeedbackService(net::URLRequestContextGetter* request_context_getter,
base::TaskRunner* file_task_runner);
DownloadFeedbackService(
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
base::TaskRunner* file_task_runner);
~DownloadFeedbackService();
// Stores the request and response ping data from the download check, if the
......@@ -86,7 +87,7 @@ class DownloadFeedbackService {
const base::FilePath& path);
void FeedbackComplete();
scoped_refptr<net::URLRequestContextGetter> request_context_getter_;
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory_;
scoped_refptr<base::TaskRunner> file_task_runner_;
// Currently active & pending uploads. The first item is active, remaining
......
......@@ -22,6 +22,7 @@
#include "content/public/test/test_browser_thread_bundle.h"
#include "content/public/test/test_utils.h"
#include "net/url_request/url_request_test_util.h"
#include "services/network/public/cpp/shared_url_loader_factory.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
......@@ -36,8 +37,7 @@ namespace {
class FakeDownloadFeedback : public DownloadFeedback {
public:
FakeDownloadFeedback(net::URLRequestContextGetter* request_context_getter,
base::TaskRunner* file_task_runner,
FakeDownloadFeedback(base::TaskRunner* file_task_runner,
const base::FilePath& file_path,
const std::string& ping_request,
const std::string& ping_response,
......@@ -67,7 +67,6 @@ class FakeDownloadFeedback : public DownloadFeedback {
bool start_called() const { return start_called_; }
private:
scoped_refptr<net::URLRequestContextGetter> request_context_getter_;
scoped_refptr<base::TaskRunner> file_task_runner_;
base::FilePath file_path_;
std::string ping_request_;
......@@ -83,14 +82,13 @@ class FakeDownloadFeedbackFactory : public DownloadFeedbackFactory {
~FakeDownloadFeedbackFactory() override {}
std::unique_ptr<DownloadFeedback> CreateDownloadFeedback(
net::URLRequestContextGetter* request_context_getter,
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
base::TaskRunner* file_task_runner,
const base::FilePath& file_path,
const std::string& ping_request,
const std::string& ping_response) override {
FakeDownloadFeedback* feedback = new FakeDownloadFeedback(
request_context_getter, file_task_runner, file_path, ping_request,
ping_response,
file_task_runner, file_path, ping_request, ping_response,
base::Bind(&FakeDownloadFeedbackFactory::DownloadFeedbackSent,
base::Unretained(this), feedbacks_.size()));
feedbacks_.push_back(feedback);
......@@ -125,11 +123,7 @@ class DownloadFeedbackServiceTest : public testing::Test {
public:
DownloadFeedbackServiceTest()
: file_task_runner_(base::CreateSequencedTaskRunnerWithTraits(
{base::MayBlock(), base::TaskPriority::BACKGROUND})),
io_task_runner_(content::BrowserThread::GetTaskRunnerForThread(
content::BrowserThread::IO)),
request_context_getter_(
new net::TestURLRequestContextGetter(io_task_runner_)) {}
{base::MayBlock(), base::TaskPriority::BACKGROUND})) {}
void SetUp() override {
ASSERT_TRUE(temp_dir_.CreateUniqueTempDir());
......@@ -160,8 +154,6 @@ class DownloadFeedbackServiceTest : public testing::Test {
base::ScopedTempDir temp_dir_;
content::TestBrowserThreadBundle thread_bundle_;
scoped_refptr<base::SequencedTaskRunner> file_task_runner_;
scoped_refptr<base::SingleThreadTaskRunner> io_task_runner_;
scoped_refptr<net::TestURLRequestContextGetter> request_context_getter_;
FakeDownloadFeedbackFactory download_feedback_factory_;
};
......@@ -219,8 +211,7 @@ TEST_F(DownloadFeedbackServiceTest, SingleFeedbackCompleteAndDiscardDownload) {
StealDangerousDownload(true /*delete_file_after_feedback*/, _))
.WillOnce(SaveArg<1>(&download_discarded_callback));
DownloadFeedbackService service(request_context_getter_.get(),
file_task_runner_.get());
DownloadFeedbackService service(nullptr, file_task_runner_.get());
service.MaybeStorePingsForDownload(DownloadCheckResult::UNCOMMON,
true /* upload_requested */, &item,
ping_request, ping_response);
......@@ -261,8 +252,7 @@ TEST_F(DownloadFeedbackServiceTest, SingleFeedbackCompleteAndKeepDownload) {
GURL empty_url;
EXPECT_CALL(item, GetURL()).WillOnce(ReturnRef(empty_url));
DownloadFeedbackService service(request_context_getter_.get(),
file_task_runner_.get());
DownloadFeedbackService service(nullptr, file_task_runner_.get());
service.MaybeStorePingsForDownload(DownloadCheckResult::UNCOMMON,
true /* upload_requested */, &item,
ping_request, ping_response);
......@@ -310,8 +300,7 @@ TEST_F(DownloadFeedbackServiceTest, MultiplePendingFeedbackComplete) {
}
{
DownloadFeedbackService service(request_context_getter_.get(),
file_task_runner_.get());
DownloadFeedbackService service(nullptr, file_task_runner_.get());
for (size_t i = 0; i < kNumDownloads; ++i) {
SCOPED_TRACE(i);
service.BeginFeedbackForDownload(&item[i], DownloadCommands::DISCARD);
......@@ -380,8 +369,7 @@ TEST_F(DownloadFeedbackServiceTest, MultiFeedbackWithIncomplete) {
}
{
DownloadFeedbackService service(request_context_getter_.get(),
file_task_runner_.get());
DownloadFeedbackService service(nullptr, file_task_runner_.get());
for (size_t i = 0; i < kNumDownloads; ++i) {
SCOPED_TRACE(i);
service.BeginFeedbackForDownload(&item[i], DownloadCommands::DISCARD);
......
......@@ -18,6 +18,8 @@
#include "content/public/test/test_utils.h"
#include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
#include "net/url_request/url_request_test_util.h"
#include "services/network/network_context.h"
#include "services/network/public/cpp/shared_url_loader_factory.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace safe_browsing {
......@@ -26,42 +28,33 @@ namespace {
class FakeUploader : public TwoPhaseUploader {
public:
FakeUploader(net::URLRequestContextGetter* url_request_context_getter,
base::TaskRunner* file_task_runner,
FakeUploader(base::TaskRunner* file_task_runner,
const GURL& base_url,
const std::string& metadata,
const base::FilePath& file_path,
const ProgressCallback& progress_callback,
const FinishCallback& finish_callback);
~FakeUploader() override {}
void Start() override { start_called_ = true; }
scoped_refptr<net::URLRequestContextGetter> url_request_context_getter_;
scoped_refptr<base::TaskRunner> file_task_runner_;
GURL base_url_;
std::string metadata_;
base::FilePath file_path_;
ProgressCallback progress_callback_;
FinishCallback finish_callback_;
bool start_called_;
};
FakeUploader::FakeUploader(
net::URLRequestContextGetter* url_request_context_getter,
base::TaskRunner* file_task_runner,
const GURL& base_url,
const std::string& metadata,
const base::FilePath& file_path,
const ProgressCallback& progress_callback,
const FinishCallback& finish_callback)
: url_request_context_getter_(url_request_context_getter),
file_task_runner_(file_task_runner),
FakeUploader::FakeUploader(base::TaskRunner* file_task_runner,
const GURL& base_url,
const std::string& metadata,
const base::FilePath& file_path,
const FinishCallback& finish_callback)
: file_task_runner_(file_task_runner),
base_url_(base_url),
metadata_(metadata),
file_path_(file_path),
progress_callback_(progress_callback),
finish_callback_(finish_callback),
start_called_(false) {}
......@@ -71,12 +64,11 @@ class FakeUploaderFactory : public TwoPhaseUploaderFactory {
~FakeUploaderFactory() override {}
std::unique_ptr<TwoPhaseUploader> CreateTwoPhaseUploader(
net::URLRequestContextGetter* url_request_context_getter,
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
base::TaskRunner* file_task_runner,
const GURL& base_url,
const std::string& metadata,
const base::FilePath& file_path,
const TwoPhaseUploader::ProgressCallback& progress_callback,
const TwoPhaseUploader::FinishCallback& finish_callback,
const net::NetworkTrafficAnnotationTag& traffic_annotation) override;
......@@ -84,22 +76,52 @@ class FakeUploaderFactory : public TwoPhaseUploaderFactory {
};
std::unique_ptr<TwoPhaseUploader> FakeUploaderFactory::CreateTwoPhaseUploader(
net::URLRequestContextGetter* url_request_context_getter,
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
base::TaskRunner* file_task_runner,
const GURL& base_url,
const std::string& metadata,
const base::FilePath& file_path,
const TwoPhaseUploader::ProgressCallback& progress_callback,
const TwoPhaseUploader::FinishCallback& finish_callback,
const net::NetworkTrafficAnnotationTag& traffic_annotation) {
EXPECT_FALSE(uploader_);
uploader_ =
new FakeUploader(url_request_context_getter, file_task_runner, base_url,
metadata, file_path, progress_callback, finish_callback);
uploader_ = new FakeUploader(file_task_runner, base_url, metadata, file_path,
finish_callback);
return base::WrapUnique(uploader_);
}
class SharedURLLoaderFactory : public network::SharedURLLoaderFactory {
public:
explicit SharedURLLoaderFactory(
network::mojom::URLLoaderFactory* url_loader_factory)
: url_loader_factory_(url_loader_factory) {}
std::unique_ptr<network::SharedURLLoaderFactoryInfo> Clone() override {
NOTREACHED();
return nullptr;
}
// network::URLLoaderFactory implementation:
void CreateLoaderAndStart(network::mojom::URLLoaderRequest loader,
int32_t routing_id,
int32_t request_id,
uint32_t options,
const network::ResourceRequest& request,
network::mojom::URLLoaderClientPtr client,
const net::MutableNetworkTrafficAnnotationTag&
traffic_annotation) override {
url_loader_factory_->CreateLoaderAndStart(
std::move(loader), routing_id, request_id, options, std::move(request),
std::move(client), traffic_annotation);
}
private:
friend class base::RefCounted<SharedURLLoaderFactory>;
~SharedURLLoaderFactory() override = default;
network::mojom::URLLoaderFactory* url_loader_factory_;
};
} // namespace
class DownloadFeedbackTest : public testing::Test {
......@@ -113,6 +135,14 @@ class DownloadFeedbackTest : public testing::Test {
new net::TestURLRequestContextGetter(io_task_runner_)),
feedback_finish_called_(false) {
EXPECT_NE(io_task_runner_, file_task_runner_);
network::mojom::NetworkContextPtr network_context;
network_context_ = std::make_unique<network::NetworkContext>(
nullptr, mojo::MakeRequest(&network_context),
url_request_context_getter_);
network_context_->CreateURLLoaderFactory(
mojo::MakeRequest(&url_loader_factory_), 0);
shared_url_loader_factory_ =
base::MakeRefCounted<SharedURLLoaderFactory>(url_loader_factory_.get());
}
void SetUp() override {
......@@ -145,6 +175,9 @@ class DownloadFeedbackTest : public testing::Test {
scoped_refptr<base::SingleThreadTaskRunner> io_task_runner_;
FakeUploaderFactory two_phase_uploader_factory_;
scoped_refptr<net::TestURLRequestContextGetter> url_request_context_getter_;
std::unique_ptr<network::NetworkContext> network_context_;
network::mojom::URLLoaderFactoryPtr url_loader_factory_;
scoped_refptr<SharedURLLoaderFactory> shared_url_loader_factory_;
bool feedback_finish_called_;
};
......@@ -163,8 +196,8 @@ TEST_F(DownloadFeedbackTest, CompleteUpload) {
expected_report_metadata.download_response().SerializeAsString());
std::unique_ptr<DownloadFeedback> feedback = DownloadFeedback::Create(
url_request_context_getter_.get(), file_task_runner_.get(),
upload_file_path_, ping_request, ping_response);
shared_url_loader_factory_, file_task_runner_.get(), upload_file_path_,
ping_request, ping_response);
EXPECT_FALSE(uploader());
feedback->Start(base::Bind(&DownloadFeedbackTest::FinishCallback,
......@@ -173,8 +206,6 @@ TEST_F(DownloadFeedbackTest, CompleteUpload) {
EXPECT_FALSE(feedback_finish_called_);
EXPECT_TRUE(uploader()->start_called_);
EXPECT_EQ(url_request_context_getter_,
uploader()->url_request_context_getter_);
EXPECT_EQ(file_task_runner_, uploader()->file_task_runner_);
EXPECT_EQ(upload_file_path_, uploader()->file_path_);
EXPECT_EQ(expected_report_metadata.SerializeAsString(),
......@@ -206,8 +237,8 @@ TEST_F(DownloadFeedbackTest, CancelUpload) {
expected_report_metadata.download_response().SerializeAsString());
std::unique_ptr<DownloadFeedback> feedback = DownloadFeedback::Create(
url_request_context_getter_.get(), file_task_runner_.get(),
upload_file_path_, ping_request, ping_response);
shared_url_loader_factory_, file_task_runner_.get(), upload_file_path_,
ping_request, ping_response);
EXPECT_FALSE(uploader());
feedback->Start(base::Bind(&DownloadFeedbackTest::FinishCallback,
......
......@@ -26,6 +26,7 @@
#include "content/public/browser/web_contents.h"
#include "net/base/url_util.h"
#include "net/cert/x509_util.h"
#include "services/network/public/cpp/shared_url_loader_factory.h"
using content::BrowserThread;
namespace safe_browsing {
......@@ -64,13 +65,13 @@ DownloadProtectionService::DownloadProtectionService(
SafeBrowsingService* sb_service)
: sb_service_(sb_service),
navigation_observer_manager_(nullptr),
request_context_getter_(sb_service ? sb_service->url_request_context()
: nullptr),
url_loader_factory_(sb_service ? sb_service->GetURLLoaderFactory()
: nullptr),
enabled_(false),
binary_feature_extractor_(new BinaryFeatureExtractor()),
download_request_timeout_ms_(kDownloadRequestTimeoutMs),
feedback_service_(new DownloadFeedbackService(
request_context_getter_.get(),
url_loader_factory_,
base::CreateSequencedTaskRunnerWithTraits(
{base::MayBlock(), base::TaskPriority::BACKGROUND})
.get())),
......
......@@ -27,7 +27,6 @@
#include "chrome/browser/safe_browsing/safe_browsing_navigation_observer_manager.h"
#include "chrome/browser/safe_browsing/ui_manager.h"
#include "components/safe_browsing/db/database_manager.h"
#include "net/url_request/url_request_context_getter.h"
#include "url/gurl.h"
namespace content {
......@@ -42,6 +41,10 @@ namespace net {
class X509Certificate;
} // namespace net
namespace network {
class SharedURLLoaderFactory;
}
class Profile;
namespace safe_browsing {
......@@ -250,8 +253,8 @@ class DownloadProtectionService {
scoped_refptr<SafeBrowsingNavigationObserverManager>
navigation_observer_manager_;
// The context we use to issue network requests.
scoped_refptr<net::URLRequestContextGetter> request_context_getter_;
// The loader factory we use to issue network requests.
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory_;
// Set of pending server requests for DownloadManager mediated downloads.
std::set<scoped_refptr<CheckClientDownloadRequest>> download_requests_;
......
......@@ -20,6 +20,8 @@
#include "net/http/http_cache.h"
#include "net/http/http_status_code.h"
#include "net/url_request/url_fetcher.h"
#include "services/network/public/cpp/shared_url_loader_factory.h"
#include "services/network/public/cpp/simple_url_loader.h"
using content::BrowserThread;
......@@ -66,7 +68,7 @@ PPAPIDownloadRequest::PPAPIDownloadRequest(
}
PPAPIDownloadRequest::~PPAPIDownloadRequest() {
if (fetcher_ && !callback_.is_null())
if (loader_ && !callback_.is_null())
Finish(RequestOutcome::REQUEST_DESTROYED, DownloadCheckResult::UNKNOWN);
}
......@@ -236,35 +238,35 @@ void PPAPIDownloadRequest::SendRequest() {
}
}
})");
fetcher_ =
net::URLFetcher::Create(0, GetDownloadRequestUrl(), net::URLFetcher::POST,
this, traffic_annotation);
data_use_measurement::DataUseUserData::AttachToFetcher(
fetcher_.get(), data_use_measurement::DataUseUserData::SAFE_BROWSING);
fetcher_->SetLoadFlags(net::LOAD_DISABLE_CACHE);
fetcher_->SetAutomaticallyRetryOn5xx(false);
fetcher_->SetRequestContext(service_->request_context_getter_.get());
fetcher_->SetUploadData("application/octet-stream",
client_download_request_data_);
fetcher_->Start();
auto resource_request = std::make_unique<network::ResourceRequest>();
resource_request->url = GetDownloadRequestUrl();
resource_request->method = "POST";
resource_request->load_flags = net::LOAD_DISABLE_CACHE;
loader_ = network::SimpleURLLoader::Create(std::move(resource_request),
traffic_annotation);
loader_->AttachStringForUpload(client_download_request_data_,
"application/octet-stream");
loader_->DownloadToStringOfUnboundedSizeUntilCrashAndDie(
service_->url_loader_factory_.get(),
base::BindOnce(&PPAPIDownloadRequest::OnURLLoaderComplete,
base::Unretained(this)));
}
// net::URLFetcherDelegate
void PPAPIDownloadRequest::OnURLFetchComplete(const net::URLFetcher* source) {
void PPAPIDownloadRequest::OnURLLoaderComplete(
std::unique_ptr<std::string> response_body) {
DCHECK_CURRENTLY_ON(BrowserThread::UI);
if (!source->GetStatus().is_success() ||
net::HTTP_OK != source->GetResponseCode()) {
int response_code = 0;
if (loader_->ResponseInfo() && loader_->ResponseInfo()->headers)
response_code = loader_->ResponseInfo()->headers->response_code();
if (loader_->NetError() != net::OK || net::HTTP_OK != response_code) {
Finish(RequestOutcome::FETCH_FAILED, DownloadCheckResult::UNKNOWN);
return;
}
ClientDownloadResponse response;
std::string response_body;
bool got_data = source->GetResponseAsString(&response_body);
DCHECK(got_data);
if (response.ParseFromString(response_body)) {
if (response.ParseFromString(*response_body.get())) {
Finish(RequestOutcome::SUCCEEDED,
DownloadCheckResultFromClientDownloadResponse(response.verdict()));
} else {
......@@ -291,7 +293,7 @@ void PPAPIDownloadRequest::Finish(RequestOutcome reason,
base::TimeTicks::Now() - start_time_);
if (!callback_.is_null())
base::ResetAndReturn(&callback_).Run(response);
fetcher_.reset();
loader_.reset();
weakptr_factory_.InvalidateWeakPtrs();
// If the request is being destroyed, don't notify the service_. It already
......
......@@ -11,13 +11,16 @@
#include "base/files/file_path.h"
#include "base/memory/weak_ptr.h"
#include "chrome/browser/safe_browsing/download_protection/download_protection_util.h"
#include "net/url_request/url_fetcher_delegate.h"
#include "url/gurl.h"
namespace content {
class WebContents;
} // namespace content
namespace network {
class SimpleURLLoader;
}
class Profile;
namespace safe_browsing {
......@@ -39,7 +42,7 @@ class PPAPIDownloadRequest;
//
// PPAPIDownloadRequest objects are owned by the DownloadProtectionService
// indicated by |service|.
class PPAPIDownloadRequest : public net::URLFetcherDelegate {
class PPAPIDownloadRequest {
public:
// The outcome of the request. These values are used for UMA. New values
// should only be added at the end.
......@@ -66,7 +69,7 @@ class PPAPIDownloadRequest : public net::URLFetcherDelegate {
DownloadProtectionService* service,
scoped_refptr<SafeBrowsingDatabaseManager> database_manager);
~PPAPIDownloadRequest() override;
~PPAPIDownloadRequest();
// Start the process of checking the download request. The callback passed as
// the |callback| parameter to the constructor will be invoked with the result
......@@ -100,8 +103,7 @@ class PPAPIDownloadRequest : public net::URLFetcherDelegate {
void SendRequest();
// net::URLFetcherDelegate
void OnURLFetchComplete(const net::URLFetcher* source) override;
void OnURLLoaderComplete(std::unique_ptr<std::string> response_body);
void OnRequestTimedOut();
......@@ -118,7 +120,7 @@ class PPAPIDownloadRequest : public net::URLFetcherDelegate {
const base::FilePath& default_file_path,
const std::vector<base::FilePath::StringType>& alternate_extensions);
std::unique_ptr<net::URLFetcher> fetcher_;
std::unique_ptr<network::SimpleURLLoader> loader_;
std::string client_download_request_data_;
// URL of document that requested the PPAPI download.
......
......@@ -13,14 +13,13 @@
#include "base/callback.h"
#include "base/files/file_path.h"
#include "net/traffic_annotation/network_traffic_annotation.h"
#include "net/url_request/url_request_context_getter.h"
#include "url/gurl.h"
namespace base {
class TaskRunner;
}
namespace net {
class URLRequestContextGetter;
namespace network {
class SharedURLLoaderFactory;
}
class TwoPhaseUploaderFactory;
......@@ -42,7 +41,6 @@ class TwoPhaseUploader {
UPLOAD_FILE,
STATE_SUCCESS,
};
using ProgressCallback = base::Callback<void(int64_t sent, int64_t total)>;
using FinishCallback = base::Callback<void(State state,
int net_error,
int response_code,
......@@ -55,8 +53,6 @@ class TwoPhaseUploader {
// The uploaded |file_path| will be read on |file_task_runner|.
// The first phase request will be sent to |base_url|, with |metadata|
// included.
// |progress_callback| will be called periodically as the second phase
// progresses, if it is non-null.
// On success |finish_callback| will be called with state = STATE_SUCCESS and
// the server response in response_data. On failure, state will specify
// which step the failure occurred in, and net_error, response_code, and
......@@ -64,12 +60,11 @@ class TwoPhaseUploader {
// will not be called if the upload is cancelled by destructing the
// TwoPhaseUploader object before completion.
static std::unique_ptr<TwoPhaseUploader> Create(
net::URLRequestContextGetter* url_request_context_getter,
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
base::TaskRunner* file_task_runner,
const GURL& base_url,
const std::string& metadata,
const base::FilePath& file_path,
const ProgressCallback& progress_callback,
const FinishCallback& finish_callback,
const net::NetworkTrafficAnnotationTag& traffic_annotation);
......@@ -93,12 +88,11 @@ class TwoPhaseUploaderFactory {
virtual ~TwoPhaseUploaderFactory() {}
virtual std::unique_ptr<TwoPhaseUploader> CreateTwoPhaseUploader(
net::URLRequestContextGetter* url_request_context_getter,
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
base::TaskRunner* file_task_runner,
const GURL& base_url,
const std::string& metadata,
const base::FilePath& file_path,
const TwoPhaseUploader::ProgressCallback& progress_callback,
const TwoPhaseUploader::FinishCallback& finish_callback,
const net::NetworkTrafficAnnotationTag& traffic_annotation) = 0;
};
......
......@@ -17,6 +17,8 @@
#include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
#include "net/url_request/url_fetcher.h"
#include "net/url_request/url_request_test_util.h"
#include "services/network/network_context.h"
#include "services/network/public/cpp/shared_url_loader_factory.h"
#include "testing/gtest/include/gtest/gtest.h"
using content::BrowserThread;
......@@ -30,8 +32,6 @@ class Delegate {
public:
Delegate() : state_(TwoPhaseUploader::STATE_NONE) {}
void ProgressCallback(int64_t current, int64_t total) {}
void FinishCallback(scoped_refptr<MessageLoopRunner> runner,
TwoPhaseUploader::State state,
int net_error,
......@@ -60,6 +60,38 @@ base::FilePath GetTestFilePath() {
return file_path;
}
class SharedURLLoaderFactory : public network::SharedURLLoaderFactory {
public:
explicit SharedURLLoaderFactory(
network::mojom::URLLoaderFactory* url_loader_factory)
: url_loader_factory_(url_loader_factory) {}
std::unique_ptr<network::SharedURLLoaderFactoryInfo> Clone() override {
NOTREACHED();
return nullptr;
}
// network::URLLoaderFactory implementation:
void CreateLoaderAndStart(network::mojom::URLLoaderRequest loader,
int32_t routing_id,
int32_t request_id,
uint32_t options,
const network::ResourceRequest& request,
network::mojom::URLLoaderClientPtr client,
const net::MutableNetworkTrafficAnnotationTag&
traffic_annotation) override {
url_loader_factory_->CreateLoaderAndStart(
std::move(loader), routing_id, request_id, options, std::move(request),
std::move(client), traffic_annotation);
}
private:
friend class base::RefCounted<SharedURLLoaderFactory>;
~SharedURLLoaderFactory() override = default;
network::mojom::URLLoaderFactory* url_loader_factory_;
};
} // namespace
class TwoPhaseUploaderTest : public testing::Test {
......@@ -67,7 +99,16 @@ class TwoPhaseUploaderTest : public testing::Test {
TwoPhaseUploaderTest()
: thread_bundle_(content::TestBrowserThreadBundle::IO_MAINLOOP),
url_request_context_getter_(new net::TestURLRequestContextGetter(
BrowserThread::GetTaskRunnerForThread(BrowserThread::IO))) {}
BrowserThread::GetTaskRunnerForThread(BrowserThread::IO))) {
network::mojom::NetworkContextPtr network_context;
network_context_ = std::make_unique<network::NetworkContext>(
nullptr, mojo::MakeRequest(&network_context),
url_request_context_getter_);
network_context_->CreateURLLoaderFactory(
mojo::MakeRequest(&url_loader_factory_), 0);
shared_url_loader_factory_ =
base::MakeRefCounted<SharedURLLoaderFactory>(url_loader_factory_.get());
}
protected:
content::TestBrowserThreadBundle thread_bundle_;
......@@ -76,6 +117,9 @@ class TwoPhaseUploaderTest : public testing::Test {
const scoped_refptr<base::SequencedTaskRunner> task_runner_ =
base::CreateSequencedTaskRunnerWithTraits(
{base::MayBlock(), base::TaskPriority::BACKGROUND});
std::unique_ptr<network::NetworkContext> network_context_;
network::mojom::URLLoaderFactoryPtr url_loader_factory_;
scoped_refptr<SharedURLLoaderFactory> shared_url_loader_factory_;
};
TEST_F(TwoPhaseUploaderTest, UploadFile) {
......@@ -84,9 +128,8 @@ TEST_F(TwoPhaseUploaderTest, UploadFile) {
ASSERT_TRUE(test_server.Start());
Delegate delegate;
std::unique_ptr<TwoPhaseUploader> uploader(TwoPhaseUploader::Create(
url_request_context_getter_.get(), task_runner_.get(),
shared_url_loader_factory_, task_runner_.get(),
test_server.GetURL("start"), "metadata", GetTestFilePath(),
base::Bind(&Delegate::ProgressCallback, base::Unretained(&delegate)),
base::Bind(&Delegate::FinishCallback, base::Unretained(&delegate),
runner),
TRAFFIC_ANNOTATION_FOR_TESTS));
......@@ -108,9 +151,8 @@ TEST_F(TwoPhaseUploaderTest, BadPhaseOneResponse) {
ASSERT_TRUE(test_server.Start());
Delegate delegate;
std::unique_ptr<TwoPhaseUploader> uploader(TwoPhaseUploader::Create(
url_request_context_getter_.get(), task_runner_.get(),
shared_url_loader_factory_, task_runner_.get(),
test_server.GetURL("start?p1code=500"), "metadata", GetTestFilePath(),
base::Bind(&Delegate::ProgressCallback, base::Unretained(&delegate)),
base::Bind(&Delegate::FinishCallback, base::Unretained(&delegate),
runner),
TRAFFIC_ANNOTATION_FOR_TESTS));
......@@ -128,9 +170,8 @@ TEST_F(TwoPhaseUploaderTest, BadPhaseTwoResponse) {
ASSERT_TRUE(test_server.Start());
Delegate delegate;
std::unique_ptr<TwoPhaseUploader> uploader(TwoPhaseUploader::Create(
url_request_context_getter_.get(), task_runner_.get(),
shared_url_loader_factory_, task_runner_.get(),
test_server.GetURL("start?p2code=500"), "metadata", GetTestFilePath(),
base::Bind(&Delegate::ProgressCallback, base::Unretained(&delegate)),
base::Bind(&Delegate::FinishCallback, base::Unretained(&delegate),
runner),
TRAFFIC_ANNOTATION_FOR_TESTS));
......@@ -152,9 +193,8 @@ TEST_F(TwoPhaseUploaderTest, PhaseOneConnectionClosed) {
ASSERT_TRUE(test_server.Start());
Delegate delegate;
std::unique_ptr<TwoPhaseUploader> uploader(TwoPhaseUploader::Create(
url_request_context_getter_.get(), task_runner_.get(),
shared_url_loader_factory_, task_runner_.get(),
test_server.GetURL("start?p1close=1"), "metadata", GetTestFilePath(),
base::Bind(&Delegate::ProgressCallback, base::Unretained(&delegate)),
base::Bind(&Delegate::FinishCallback, base::Unretained(&delegate),
runner),
TRAFFIC_ANNOTATION_FOR_TESTS));
......@@ -162,7 +202,6 @@ TEST_F(TwoPhaseUploaderTest, PhaseOneConnectionClosed) {
runner->Run();
EXPECT_EQ(TwoPhaseUploader::UPLOAD_METADATA, delegate.state_);
EXPECT_EQ(net::ERR_EMPTY_RESPONSE, delegate.net_error_);
EXPECT_EQ(net::URLFetcher::RESPONSE_CODE_INVALID, delegate.response_code_);
EXPECT_EQ("", delegate.response_);
}
......@@ -172,9 +211,8 @@ TEST_F(TwoPhaseUploaderTest, PhaseTwoConnectionClosed) {
ASSERT_TRUE(test_server.Start());
Delegate delegate;
std::unique_ptr<TwoPhaseUploader> uploader(TwoPhaseUploader::Create(
url_request_context_getter_.get(), task_runner_.get(),
shared_url_loader_factory_, task_runner_.get(),
test_server.GetURL("start?p2close=1"), "metadata", GetTestFilePath(),
base::Bind(&Delegate::ProgressCallback, base::Unretained(&delegate)),
base::Bind(&Delegate::FinishCallback, base::Unretained(&delegate),
runner),
TRAFFIC_ANNOTATION_FOR_TESTS));
......@@ -182,7 +220,6 @@ TEST_F(TwoPhaseUploaderTest, PhaseTwoConnectionClosed) {
runner->Run();
EXPECT_EQ(TwoPhaseUploader::UPLOAD_FILE, delegate.state_);
EXPECT_EQ(net::ERR_EMPTY_RESPONSE, delegate.net_error_);
EXPECT_EQ(net::URLFetcher::RESPONSE_CODE_INVALID, delegate.response_code_);
EXPECT_EQ("", delegate.response_);
}
......
......@@ -152,7 +152,7 @@ class SafeBrowsingService : public base::RefCountedThreadSafe<
// NetworkContext and URLLoaderFactory used for safe browsing requests.
network::mojom::NetworkContext* GetNetworkContext();
scoped_refptr<network::SharedURLLoaderFactory> GetURLLoaderFactory();
virtual scoped_refptr<network::SharedURLLoaderFactory> GetURLLoaderFactory();
// Called on IO thread thread when QUIC should be disabled (e.g. because of
// policy). This should not be necessary anymore when http://crbug.com/678653
......
......@@ -60,6 +60,10 @@ void TestURLLoaderFactory::ClearResponses() {
responses_.clear();
}
void TestURLLoaderFactory::SetInterceptor(const Interceptor& interceptor) {
interceptor_ = interceptor;
}
void TestURLLoaderFactory::CreateLoaderAndStart(
mojom::URLLoaderRequest request,
int32_t routing_id,
......@@ -68,6 +72,9 @@ void TestURLLoaderFactory::CreateLoaderAndStart(
const ResourceRequest& url_request,
mojom::URLLoaderClientPtr client,
const net::MutableNetworkTrafficAnnotationTag& traffic_annotation) {
if (interceptor_)
interceptor_.Run(url_request);
if (CreateLoaderAndStartInternal(url_request.url, client.get()))
return;
......
......@@ -41,6 +41,9 @@ class TestURLLoaderFactory : public mojom::URLLoaderFactory {
// Clear all the responses that were previously set.
void ClearResponses();
using Interceptor = base::RepeatingCallback<void(const ResourceRequest&)>;
void SetInterceptor(const Interceptor& interceptor);
// mojom::URLLoaderFactory implementation.
void CreateLoaderAndStart(mojom::URLLoaderRequest request,
int32_t routing_id,
......@@ -78,6 +81,8 @@ class TestURLLoaderFactory : public mojom::URLLoaderFactory {
};
std::vector<Pending> pending_;
Interceptor interceptor_;
DISALLOW_COPY_AND_ASSIGN(TestURLLoaderFactory);
};
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment