Commit a1559706 authored by Mark Pilgrim's avatar Mark Pilgrim Committed by Commit Bot

Migrate AssistRanker to SimpleURLLoader

As part of the effort to bring the new Network Service online, we are
deprecating URLFetcher in favor of SimpleURLLoader.

Background: https://docs.google.com/document/d/1YZDPeg3bf46QPU_vUotFwOEPHquww36b-UdRlO3ZpMU/edit

Bug: 844937
Change-Id: I4a00c42b481f10c03657aa52b9a3d197febdc6ed
Reviewed-on: https://chromium-review.googlesource.com/1070110Reviewed-by: default avatarAndrew Moylan <amoylan@chromium.org>
Reviewed-by: default avatarRachel Blum <groby@chromium.org>
Reviewed-by: default avatarMatt Menke <mmenke@chromium.org>
Commit-Queue: Mark Pilgrim <pilgrim@chromium.org>
Cr-Commit-Position: refs/heads/master@{#561845}
parent d7e6a562
...@@ -5,11 +5,13 @@ ...@@ -5,11 +5,13 @@
#include "chrome/browser/assist_ranker/assist_ranker_service_factory.h" #include "chrome/browser/assist_ranker/assist_ranker_service_factory.h"
#include "chrome/browser/browser_process.h" #include "chrome/browser/browser_process.h"
#include "chrome/browser/net/system_network_context_manager.h"
#include "chrome/browser/profiles/incognito_helpers.h" #include "chrome/browser/profiles/incognito_helpers.h"
#include "components/assist_ranker/assist_ranker_service_impl.h" #include "components/assist_ranker/assist_ranker_service_impl.h"
#include "components/keyed_service/content/browser_context_dependency_manager.h" #include "components/keyed_service/content/browser_context_dependency_manager.h"
#include "components/keyed_service/core/keyed_service.h" #include "components/keyed_service/core/keyed_service.h"
#include "content/public/browser/browser_context.h" #include "content/public/browser/browser_context.h"
#include "services/network/public/cpp/shared_url_loader_factory.h"
namespace assist_ranker { namespace assist_ranker {
...@@ -35,7 +37,9 @@ AssistRankerServiceFactory::~AssistRankerServiceFactory() {} ...@@ -35,7 +37,9 @@ AssistRankerServiceFactory::~AssistRankerServiceFactory() {}
KeyedService* AssistRankerServiceFactory::BuildServiceInstanceFor( KeyedService* AssistRankerServiceFactory::BuildServiceInstanceFor(
content::BrowserContext* browser_context) const { content::BrowserContext* browser_context) const {
return new AssistRankerServiceImpl( return new AssistRankerServiceImpl(
browser_context->GetPath(), g_browser_process->system_request_context()); browser_context->GetPath(),
g_browser_process->system_network_context_manager()
->GetSharedURLLoaderFactory());
} }
content::BrowserContext* AssistRankerServiceFactory::GetBrowserContextToUse( content::BrowserContext* AssistRankerServiceFactory::GetBrowserContextToUse(
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include "base/strings/string_util.h" #include "base/strings/string_util.h"
#include "build/build_config.h" #include "build/build_config.h"
#include "chrome/browser/browser_process.h" #include "chrome/browser/browser_process.h"
#include "chrome/browser/net/system_network_context_manager.h"
#include "chrome/browser/translate/chrome_translate_client.h" #include "chrome/browser/translate/chrome_translate_client.h"
#include "chrome/common/chrome_switches.h" #include "chrome/common/chrome_switches.h"
#include "chrome/common/pref_names.h" #include "chrome/common/pref_names.h"
...@@ -53,6 +54,13 @@ void TranslateService::Initialize() { ...@@ -53,6 +54,13 @@ void TranslateService::Initialize() {
translate::TranslateDownloadManager::GetInstance(); translate::TranslateDownloadManager::GetInstance();
download_manager->set_request_context( download_manager->set_request_context(
g_browser_process->system_request_context()); g_browser_process->system_request_context());
SystemNetworkContextManager* system_network_context_manager =
g_browser_process->system_network_context_manager();
// Manager will be null if called from InitializeForTesting.
if (system_network_context_manager) {
download_manager->set_url_loader_factory(
system_network_context_manager->GetSharedURLLoaderFactory());
}
download_manager->set_application_locale( download_manager->set_application_locale(
g_browser_process->GetApplicationLocale()); g_browser_process->GetApplicationLocale());
} }
......
...@@ -39,6 +39,9 @@ static_library("assist_ranker") { ...@@ -39,6 +39,9 @@ static_library("assist_ranker") {
"//components/keyed_service/core", "//components/keyed_service/core",
"//net", "//net",
"//services/metrics/public/cpp:metrics_cpp", "//services/metrics/public/cpp:metrics_cpp",
"//services/network:network_service",
"//services/network/public/cpp",
"//services/network/public/mojom",
"//url", "//url",
] ]
} }
...@@ -62,6 +65,8 @@ source_set("unit_tests") { ...@@ -62,6 +65,8 @@ source_set("unit_tests") {
"//components/assist_ranker/proto", "//components/assist_ranker/proto",
"//components/ukm:test_support", "//components/ukm:test_support",
"//net:test_support", "//net:test_support",
"//services/network:test_support",
"//services/network/public/cpp",
"//testing/gtest", "//testing/gtest",
] ]
} }
...@@ -5,5 +5,8 @@ include_rules = [ ...@@ -5,5 +5,8 @@ include_rules = [
"+components/ukm", "+components/ukm",
"+net", "+net",
"+services/metrics/public", "+services/metrics/public",
"+services/network/public/cpp",
"+services/network/public/mojom",
"+services/network/test",
"+third_party/protobuf", "+third_party/protobuf",
] ]
...@@ -6,15 +6,15 @@ ...@@ -6,15 +6,15 @@
#include "base/memory/weak_ptr.h" #include "base/memory/weak_ptr.h"
#include "components/assist_ranker/binary_classifier_predictor.h" #include "components/assist_ranker/binary_classifier_predictor.h"
#include "components/assist_ranker/ranker_model_loader_impl.h" #include "components/assist_ranker/ranker_model_loader_impl.h"
#include "net/url_request/url_request_context_getter.h" #include "services/network/public/cpp/shared_url_loader_factory.h"
#include "url/gurl.h" #include "url/gurl.h"
namespace assist_ranker { namespace assist_ranker {
AssistRankerServiceImpl::AssistRankerServiceImpl( AssistRankerServiceImpl::AssistRankerServiceImpl(
base::FilePath base_path, base::FilePath base_path,
net::URLRequestContextGetter* url_request_context_getter) scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory)
: url_request_context_getter_(url_request_context_getter), : url_loader_factory_(std::move(url_loader_factory)),
base_path_(std::move(base_path)) {} base_path_(std::move(base_path)) {}
AssistRankerServiceImpl::~AssistRankerServiceImpl() {} AssistRankerServiceImpl::~AssistRankerServiceImpl() {}
...@@ -35,7 +35,7 @@ AssistRankerServiceImpl::FetchBinaryClassifierPredictor( ...@@ -35,7 +35,7 @@ AssistRankerServiceImpl::FetchBinaryClassifierPredictor(
DVLOG(1) << "Initializing predictor: " << model_name; DVLOG(1) << "Initializing predictor: " << model_name;
std::unique_ptr<BinaryClassifierPredictor> predictor = std::unique_ptr<BinaryClassifierPredictor> predictor =
BinaryClassifierPredictor::Create(config, GetModelPath(model_name), BinaryClassifierPredictor::Create(config, GetModelPath(model_name),
url_request_context_getter_.get()); url_loader_factory_);
base::WeakPtr<BinaryClassifierPredictor> weak_ptr = base::WeakPtr<BinaryClassifierPredictor> weak_ptr =
base::AsWeakPtr(predictor.get()); base::AsWeakPtr(predictor.get());
predictor_map_[model_name] = std::move(predictor); predictor_map_[model_name] = std::move(predictor);
......
...@@ -15,8 +15,8 @@ ...@@ -15,8 +15,8 @@
#include "components/assist_ranker/assist_ranker_service.h" #include "components/assist_ranker/assist_ranker_service.h"
#include "components/assist_ranker/predictor_config.h" #include "components/assist_ranker/predictor_config.h"
namespace net { namespace network {
class URLRequestContextGetter; class SharedURLLoaderFactory;
} }
namespace assist_ranker { namespace assist_ranker {
...@@ -28,7 +28,7 @@ class AssistRankerServiceImpl : public AssistRankerService { ...@@ -28,7 +28,7 @@ class AssistRankerServiceImpl : public AssistRankerService {
public: public:
AssistRankerServiceImpl( AssistRankerServiceImpl(
base::FilePath base_path, base::FilePath base_path,
net::URLRequestContextGetter* url_request_context_getter); scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory);
~AssistRankerServiceImpl() override; ~AssistRankerServiceImpl() override;
// AssistRankerService... // AssistRankerService...
...@@ -39,8 +39,8 @@ class AssistRankerServiceImpl : public AssistRankerService { ...@@ -39,8 +39,8 @@ class AssistRankerServiceImpl : public AssistRankerService {
// Returns the full path to the model cache. // Returns the full path to the model cache.
base::FilePath GetModelPath(const std::string& model_filename); base::FilePath GetModelPath(const std::string& model_filename);
// Request Context Getter used for RankerURLFetcher. // URL loader factory used for RankerURLFetcher.
scoped_refptr<net::URLRequestContextGetter> url_request_context_getter_; scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory_;
// Base path where models are stored. // Base path where models are stored.
const base::FilePath base_path_; const base::FilePath base_path_;
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
#include "components/assist_ranker/proto/ranker_model.pb.h" #include "components/assist_ranker/proto/ranker_model.pb.h"
#include "components/assist_ranker/ranker_model.h" #include "components/assist_ranker/ranker_model.h"
#include "components/assist_ranker/ranker_model_loader_impl.h" #include "components/assist_ranker/ranker_model_loader_impl.h"
#include "net/url_request/url_request_context_getter.h" #include "services/network/public/cpp/shared_url_loader_factory.h"
namespace assist_ranker { namespace assist_ranker {
...@@ -26,7 +26,7 @@ BinaryClassifierPredictor::~BinaryClassifierPredictor(){}; ...@@ -26,7 +26,7 @@ BinaryClassifierPredictor::~BinaryClassifierPredictor(){};
std::unique_ptr<BinaryClassifierPredictor> BinaryClassifierPredictor::Create( std::unique_ptr<BinaryClassifierPredictor> BinaryClassifierPredictor::Create(
const PredictorConfig& config, const PredictorConfig& config,
const base::FilePath& model_path, const base::FilePath& model_path,
net::URLRequestContextGetter* request_context_getter) { scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory) {
std::unique_ptr<BinaryClassifierPredictor> predictor( std::unique_ptr<BinaryClassifierPredictor> predictor(
new BinaryClassifierPredictor(config)); new BinaryClassifierPredictor(config));
if (!predictor->is_query_enabled()) { if (!predictor->is_query_enabled()) {
...@@ -40,7 +40,7 @@ std::unique_ptr<BinaryClassifierPredictor> BinaryClassifierPredictor::Create( ...@@ -40,7 +40,7 @@ std::unique_ptr<BinaryClassifierPredictor> BinaryClassifierPredictor::Create(
base::BindRepeating(&BinaryClassifierPredictor::ValidateModel), base::BindRepeating(&BinaryClassifierPredictor::ValidateModel),
base::BindRepeating(&BinaryClassifierPredictor::OnModelAvailable, base::BindRepeating(&BinaryClassifierPredictor::OnModelAvailable,
base::Unretained(predictor.get())), base::Unretained(predictor.get())),
request_context_getter, model_path, model_url, config.uma_prefix); url_loader_factory, model_path, model_url, config.uma_prefix);
predictor->LoadModel(std::move(model_loader)); predictor->LoadModel(std::move(model_loader));
return predictor; return predictor;
} }
......
...@@ -13,8 +13,8 @@ namespace base { ...@@ -13,8 +13,8 @@ namespace base {
class FilePath; class FilePath;
} }
namespace net { namespace network {
class URLRequestContextGetter; class SharedURLLoaderFactory;
} }
namespace assist_ranker { namespace assist_ranker {
...@@ -32,7 +32,8 @@ class BinaryClassifierPredictor : public BasePredictor { ...@@ -32,7 +32,8 @@ class BinaryClassifierPredictor : public BasePredictor {
static std::unique_ptr<BinaryClassifierPredictor> Create( static std::unique_ptr<BinaryClassifierPredictor> Create(
const PredictorConfig& config, const PredictorConfig& config,
const base::FilePath& model_path, const base::FilePath& model_path,
net::URLRequestContextGetter* request_context_getter) WARN_UNUSED_RESULT; scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory)
WARN_UNUSED_RESULT;
// Fills in a boolean decision given a RankerExample. Returns false if a // Fills in a boolean decision given a RankerExample. Returns false if a
// prediction could not be made (e.g. the model is not loaded yet). // prediction could not be made (e.g. the model is not loaded yet).
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include "base/threading/sequenced_task_runner_handle.h" #include "base/threading/sequenced_task_runner_handle.h"
#include "components/assist_ranker/proto/ranker_model.pb.h" #include "components/assist_ranker/proto/ranker_model.pb.h"
#include "components/assist_ranker/ranker_url_fetcher.h" #include "components/assist_ranker/ranker_url_fetcher.h"
#include "services/network/public/cpp/shared_url_loader_factory.h"
namespace assist_ranker { namespace assist_ranker {
namespace { namespace {
...@@ -90,7 +91,7 @@ void SaveToFile(const GURL& model_url, ...@@ -90,7 +91,7 @@ void SaveToFile(const GURL& model_url,
RankerModelLoaderImpl::RankerModelLoaderImpl( RankerModelLoaderImpl::RankerModelLoaderImpl(
ValidateModelCallback validate_model_cb, ValidateModelCallback validate_model_cb,
OnModelAvailableCallback on_model_available_cb, OnModelAvailableCallback on_model_available_cb,
net::URLRequestContextGetter* request_context_getter, scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
base::FilePath model_path, base::FilePath model_path,
GURL model_url, GURL model_url,
std::string uma_prefix) std::string uma_prefix)
...@@ -99,7 +100,7 @@ RankerModelLoaderImpl::RankerModelLoaderImpl( ...@@ -99,7 +100,7 @@ RankerModelLoaderImpl::RankerModelLoaderImpl(
base::TaskShutdownBehavior::SKIP_ON_SHUTDOWN})), base::TaskShutdownBehavior::SKIP_ON_SHUTDOWN})),
validate_model_cb_(std::move(validate_model_cb)), validate_model_cb_(std::move(validate_model_cb)),
on_model_available_cb_(std::move(on_model_available_cb)), on_model_available_cb_(std::move(on_model_available_cb)),
request_context_getter_(request_context_getter), url_loader_factory_(std::move(url_loader_factory)),
model_path_(std::move(model_path)), model_path_(std::move(model_path)),
model_url_(std::move(model_url)), model_url_(std::move(model_url)),
uma_prefix_(std::move(uma_prefix)), uma_prefix_(std::move(uma_prefix)),
...@@ -222,7 +223,7 @@ void RankerModelLoaderImpl::StartLoadFromURL() { ...@@ -222,7 +223,7 @@ void RankerModelLoaderImpl::StartLoadFromURL() {
url_fetcher_->Request(model_url_, url_fetcher_->Request(model_url_,
base::Bind(&RankerModelLoaderImpl::OnURLFetched, base::Bind(&RankerModelLoaderImpl::OnURLFetched,
weak_ptr_factory_.GetWeakPtr()), weak_ptr_factory_.GetWeakPtr()),
request_context_getter_.get()); url_loader_factory_.get());
// |url_fetcher_| maintains a request retry counter. If all allowed attempts // |url_fetcher_| maintains a request retry counter. If all allowed attempts
// have already been exhausted, then the loader is finished and has abandoned // have already been exhausted, then the loader is finished and has abandoned
......
...@@ -16,13 +16,16 @@ ...@@ -16,13 +16,16 @@
#include "base/memory/weak_ptr.h" #include "base/memory/weak_ptr.h"
#include "base/sequence_checker.h" #include "base/sequence_checker.h"
#include "base/time/time.h" #include "base/time/time.h"
#include "net/url_request/url_request_context_getter.h"
#include "url/gurl.h" #include "url/gurl.h"
namespace base { namespace base {
class SequencedTaskRunner; class SequencedTaskRunner;
} // namespace base } // namespace base
namespace network {
class SharedURLLoaderFactory;
}
namespace assist_ranker { namespace assist_ranker {
class RankerURLFetcher; class RankerURLFetcher;
...@@ -48,9 +51,10 @@ class RankerModelLoaderImpl : public RankerModelLoader { ...@@ -48,9 +51,10 @@ class RankerModelLoaderImpl : public RankerModelLoader {
// //
// |uma_prefix| will be used as a prefix for the names of all UMA metrics // |uma_prefix| will be used as a prefix for the names of all UMA metrics
// generated by this loader. // generated by this loader.
RankerModelLoaderImpl(ValidateModelCallback validate_model_callback, RankerModelLoaderImpl(
ValidateModelCallback validate_model_callback,
OnModelAvailableCallback on_model_available_callback, OnModelAvailableCallback on_model_available_callback,
net::URLRequestContextGetter* request_context_getter, scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
base::FilePath model_path, base::FilePath model_path,
GURL model_url, GURL model_url,
std::string uma_prefix); std::string uma_prefix);
...@@ -130,8 +134,8 @@ class RankerModelLoaderImpl : public RankerModelLoader { ...@@ -130,8 +134,8 @@ class RankerModelLoaderImpl : public RankerModelLoader {
// constructed. // constructed.
const OnModelAvailableCallback on_model_available_cb_; const OnModelAvailableCallback on_model_available_cb_;
// Request Context Getter used for RankerURLFetcher. // URL loader factory used for RankerURLFetcher.
scoped_refptr<net::URLRequestContextGetter> request_context_getter_; scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory_;
// The path at which the model is (or should be) cached. // The path at which the model is (or should be) cached.
const base::FilePath model_path_; const base::FilePath model_path_;
......
...@@ -20,8 +20,9 @@ ...@@ -20,8 +20,9 @@
#include "components/assist_ranker/proto/ranker_model.pb.h" #include "components/assist_ranker/proto/ranker_model.pb.h"
#include "components/assist_ranker/proto/translate_ranker_model.pb.h" #include "components/assist_ranker/proto/translate_ranker_model.pb.h"
#include "components/assist_ranker/ranker_model.h" #include "components/assist_ranker/ranker_model.h"
#include "net/url_request/test_url_fetcher_factory.h" #include "services/network/public/cpp/shared_url_loader_factory.h"
#include "net/url_request/url_request_test_util.h" #include "services/network/public/cpp/weak_wrapper_shared_url_loader_factory.h"
#include "services/network/test/test_url_loader_factory.h"
#include "testing/gtest/include/gtest/gtest.h" #include "testing/gtest/include/gtest/gtest.h"
namespace { namespace {
...@@ -76,15 +77,13 @@ class RankerModelLoaderImplTest : public ::testing::Test { ...@@ -76,15 +77,13 @@ class RankerModelLoaderImplTest : public ::testing::Test {
// Sets up the task scheduling/task-runner environment for each test. // Sets up the task scheduling/task-runner environment for each test.
base::test::ScopedTaskEnvironment scoped_task_environment_; base::test::ScopedTaskEnvironment scoped_task_environment_;
// Override the default URL fetcher to return custom responses for tests. // Override the default URL loader to return custom responses for tests.
net::FakeURLFetcherFactory url_fetcher_factory_; network::TestURLLoaderFactory test_loader_factory_;
scoped_refptr<network::SharedURLLoaderFactory> test_shared_loader_factory_;
// Temporary directory for model files. // Temporary directory for model files.
base::ScopedTempDir scoped_temp_dir_; base::ScopedTempDir scoped_temp_dir_;
// Used for URLFetcher.
scoped_refptr<net::TestURLRequestContextGetter> request_context_getter_;
// A queue of responses to return from Validate(). If empty, validate will // A queue of responses to return from Validate(). If empty, validate will
// return 'OK'. // return 'OK'.
base::circular_deque<RankerModelStatus> validate_model_response_; base::circular_deque<RankerModelStatus> validate_model_response_;
...@@ -114,13 +113,13 @@ class RankerModelLoaderImplTest : public ::testing::Test { ...@@ -114,13 +113,13 @@ class RankerModelLoaderImplTest : public ::testing::Test {
DISALLOW_COPY_AND_ASSIGN(RankerModelLoaderImplTest); DISALLOW_COPY_AND_ASSIGN(RankerModelLoaderImplTest);
}; };
RankerModelLoaderImplTest::RankerModelLoaderImplTest() RankerModelLoaderImplTest::RankerModelLoaderImplTest() {
: url_fetcher_factory_(nullptr) {} test_shared_loader_factory_ =
base::MakeRefCounted<network::WeakWrapperSharedURLLoaderFactory>(
&test_loader_factory_);
}
void RankerModelLoaderImplTest::SetUp() { void RankerModelLoaderImplTest::SetUp() {
request_context_getter_ =
new net::TestURLRequestContextGetter(base::ThreadTaskRunnerHandle::Get());
ASSERT_TRUE(scoped_temp_dir_.CreateUniqueTempDir()); ASSERT_TRUE(scoped_temp_dir_.CreateUniqueTempDir());
const auto& temp_dir_path = scoped_temp_dir_.GetPath(); const auto& temp_dir_path = scoped_temp_dir_.GetPath();
...@@ -172,7 +171,7 @@ bool RankerModelLoaderImplTest::DoLoaderTest(const base::FilePath& model_path, ...@@ -172,7 +171,7 @@ bool RankerModelLoaderImplTest::DoLoaderTest(const base::FilePath& model_path,
base::Unretained(this)), base::Unretained(this)),
base::Bind(&RankerModelLoaderImplTest::OnModelAvailable, base::Bind(&RankerModelLoaderImplTest::OnModelAvailable,
base::Unretained(this)), base::Unretained(this)),
request_context_getter_.get(), model_path, model_url, test_shared_loader_factory_, model_path, model_url,
"RankerModelLoaderImplTest"); "RankerModelLoaderImplTest");
loader->NotifyOfRankerActivity(); loader->NotifyOfRankerActivity();
scoped_task_environment_.RunUntilIdle(); scoped_task_environment_.RunUntilIdle();
...@@ -182,15 +181,13 @@ bool RankerModelLoaderImplTest::DoLoaderTest(const base::FilePath& model_path, ...@@ -182,15 +181,13 @@ bool RankerModelLoaderImplTest::DoLoaderTest(const base::FilePath& model_path,
void RankerModelLoaderImplTest::InitRemoteModels() { void RankerModelLoaderImplTest::InitRemoteModels() {
InitModel(remote_model_url_, base::Time(), base::TimeDelta(), &remote_model_); InitModel(remote_model_url_, base::Time(), base::TimeDelta(), &remote_model_);
url_fetcher_factory_.SetFakeResponse( test_loader_factory_.AddResponse(remote_model_url_.spec(),
remote_model_url_, remote_model_.SerializeAsString(), net::HTTP_OK, remote_model_.SerializeAsString());
net::URLRequestStatus::SUCCESS); test_loader_factory_.AddResponse(invalid_model_url_.spec(),
url_fetcher_factory_.SetFakeResponse(invalid_model_url_, kInvalidModelData, kInvalidModelData);
net::HTTP_OK, test_loader_factory_.AddResponse(
net::URLRequestStatus::SUCCESS); failed_model_url_, network::ResourceResponseHead(), "",
url_fetcher_factory_.SetFakeResponse(failed_model_url_, "", network::URLLoaderCompletionStatus(net::HTTP_INTERNAL_SERVER_ERROR));
net::HTTP_INTERNAL_SERVER_ERROR,
net::URLRequestStatus::FAILED);
} }
void RankerModelLoaderImplTest::InitLocalModels() { void RankerModelLoaderImplTest::InitLocalModels() {
......
...@@ -9,8 +9,9 @@ ...@@ -9,8 +9,9 @@
#include "net/base/load_flags.h" #include "net/base/load_flags.h"
#include "net/http/http_status_code.h" #include "net/http/http_status_code.h"
#include "net/traffic_annotation/network_traffic_annotation.h" #include "net/traffic_annotation/network_traffic_annotation.h"
#include "net/url_request/url_fetcher.h" #include "services/network/public/cpp/resource_request.h"
#include "net/url_request/url_request_status.h" #include "services/network/public/cpp/simple_url_loader.h"
#include "services/network/public/mojom/url_loader_factory.mojom.h"
namespace assist_ranker { namespace assist_ranker {
...@@ -28,7 +29,7 @@ RankerURLFetcher::~RankerURLFetcher() {} ...@@ -28,7 +29,7 @@ RankerURLFetcher::~RankerURLFetcher() {}
bool RankerURLFetcher::Request( bool RankerURLFetcher::Request(
const GURL& url, const GURL& url,
const RankerURLFetcher::Callback& callback, const RankerURLFetcher::Callback& callback,
net::URLRequestContextGetter* request_context_getter) { network::mojom::URLLoaderFactory* url_loader_factory) {
// This function is not supposed to be called if the previous operation is not // This function is not supposed to be called if the previous operation is not
// finished. // finished.
if (state_ == REQUESTING) { if (state_ == REQUESTING) {
...@@ -44,7 +45,7 @@ bool RankerURLFetcher::Request( ...@@ -44,7 +45,7 @@ bool RankerURLFetcher::Request(
url_ = url; url_ = url;
callback_ = callback; callback_ = callback;
if (request_context_getter == nullptr) if (url_loader_factory == nullptr)
return false; return false;
net::NetworkTrafficAnnotationTag traffic_annotation = net::NetworkTrafficAnnotationTag traffic_annotation =
...@@ -73,38 +74,37 @@ bool RankerURLFetcher::Request( ...@@ -73,38 +74,37 @@ bool RankerURLFetcher::Request(
policy_exception_justification: policy_exception_justification:
"Not implemented, considered not necessary as no user data is sent." "Not implemented, considered not necessary as no user data is sent."
})"); })");
// Create and initialize the URL fetcher. auto resource_request = std::make_unique<network::ResourceRequest>();
fetcher_ = net::URLFetcher::Create(url_, net::URLFetcher::GET, this, resource_request->url = url_;
traffic_annotation); resource_request->load_flags =
data_use_measurement::DataUseUserData::AttachToFetcher( net::LOAD_DO_NOT_SEND_COOKIES | net::LOAD_DO_NOT_SAVE_COOKIES;
fetcher_.get(), // TODO(https://crbug.com/808498): Re-add data use measurement once
data_use_measurement::DataUseUserData::MACHINE_INTELLIGENCE); // SimpleURLLoader supports it.
fetcher_->SetLoadFlags(net::LOAD_DO_NOT_SEND_COOKIES | // ID=data_use_measurement::DataUseUserData::MACHINE_INTELLIGENCE
net::LOAD_DO_NOT_SAVE_COOKIES); simple_url_loader_ = network::SimpleURLLoader::Create(
fetcher_->SetRequestContext(request_context_getter); std::move(resource_request), traffic_annotation);
if (max_retry_on_5xx_ > 0) {
// Set retry parameter for HTTP status code 5xx. This doesn't work against simple_url_loader_->SetRetryOptions(max_retry_on_5xx_,
// 106 (net::ERR_INTERNET_DISCONNECTED) and so on. network::SimpleURLLoader::RETRY_ON_5XX);
fetcher_->SetMaxRetriesOn5xx(max_retry_on_5xx_); }
fetcher_->Start(); simple_url_loader_->DownloadToStringOfUnboundedSizeUntilCrashAndDie(
url_loader_factory,
base::BindOnce(&RankerURLFetcher::OnSimpleLoaderComplete,
base::Unretained(this)));
return true; return true;
} }
void RankerURLFetcher::OnURLFetchComplete(const net::URLFetcher* source) { void RankerURLFetcher::OnSimpleLoaderComplete(
DCHECK(fetcher_.get() == source); std::unique_ptr<std::string> response_body) {
std::string data; std::string data;
if (source->GetStatus().status() == net::URLRequestStatus::SUCCESS && if (response_body) {
source->GetResponseCode() == net::HTTP_OK) {
state_ = COMPLETED; state_ = COMPLETED;
source->GetResponseAsString(&data); data = std::move(*response_body);
} else { } else {
state_ = FAILED; state_ = FAILED;
} }
simple_url_loader_.reset();
// Transfer URLFetcher's ownership before invoking a callback.
std::unique_ptr<const net::URLFetcher> delete_ptr(fetcher_.release());
callback_.Run(state_ == COMPLETED, data); callback_.Run(state_ == COMPLETED, data);
} }
......
...@@ -9,14 +9,19 @@ ...@@ -9,14 +9,19 @@
#include "base/callback.h" #include "base/callback.h"
#include "base/macros.h" #include "base/macros.h"
#include "net/url_request/url_fetcher_delegate.h"
#include "net/url_request/url_request_context_getter.h"
#include "url/gurl.h" #include "url/gurl.h"
namespace network {
class SimpleURLLoader;
namespace mojom {
class URLLoaderFactory;
} // namespace mojom
} // namespace network
namespace assist_ranker { namespace assist_ranker {
// Downloads Ranker models. // Downloads Ranker models.
class RankerURLFetcher : public net::URLFetcherDelegate { class RankerURLFetcher {
public: public:
// Callback type for Request(). // Callback type for Request().
typedef base::Callback<void(bool, const std::string&)> Callback; typedef base::Callback<void(bool, const std::string&)> Callback;
...@@ -30,7 +35,7 @@ class RankerURLFetcher : public net::URLFetcherDelegate { ...@@ -30,7 +35,7 @@ class RankerURLFetcher : public net::URLFetcherDelegate {
}; };
RankerURLFetcher(); RankerURLFetcher();
~RankerURLFetcher() override; ~RankerURLFetcher();
int max_retry_on_5xx() { return max_retry_on_5xx_; } int max_retry_on_5xx() { return max_retry_on_5xx_; }
void set_max_retry_on_5xx(int count) { max_retry_on_5xx_ = count; } void set_max_retry_on_5xx(int count) { max_retry_on_5xx_ = count; }
...@@ -41,23 +46,22 @@ class RankerURLFetcher : public net::URLFetcherDelegate { ...@@ -41,23 +46,22 @@ class RankerURLFetcher : public net::URLFetcherDelegate {
// is omitted due to retry limitation. // is omitted due to retry limitation.
bool Request(const GURL& url, bool Request(const GURL& url,
const Callback& callback, const Callback& callback,
net::URLRequestContextGetter* request_context); network::mojom::URLLoaderFactory* url_loader_factory);
// Gets internal state. // Gets internal state.
State state() { return state_; } State state() { return state_; }
// net::URLFetcherDelegate implementation:
void OnURLFetchComplete(const net::URLFetcher* source) override;
private: private:
void OnSimpleLoaderComplete(std::unique_ptr<std::string> response_body);
// URL to send the request. // URL to send the request.
GURL url_; GURL url_;
// Internal state. // Internal state.
enum State state_; enum State state_;
// URLFetcher instance. // SimpleURLLoader instance.
std::unique_ptr<net::URLFetcher> fetcher_; std::unique_ptr<network::SimpleURLLoader> simple_url_loader_;
// Callback passed at Request(). It will be invoked when an asynchronous // Callback passed at Request(). It will be invoked when an asynchronous
// fetch operation is finished. // fetch operation is finished.
......
...@@ -12,6 +12,7 @@ include_rules = [ ...@@ -12,6 +12,7 @@ include_rules = [
"+components/variations", "+components/variations",
"+google_apis", "+google_apis",
"+net", "+net",
"+services/network/public/cpp",
"+ui", "+ui",
"+third_party/metrics_proto", "+third_party/metrics_proto",
......
...@@ -63,6 +63,7 @@ static_library("browser") { ...@@ -63,6 +63,7 @@ static_library("browser") {
"//net", "//net",
"//services/metrics/public/cpp:metrics_cpp", "//services/metrics/public/cpp:metrics_cpp",
"//services/metrics/public/cpp:ukm_builders", "//services/metrics/public/cpp:ukm_builders",
"//services/network/public/cpp:cpp",
"//third_party/icu", "//third_party/icu",
"//third_party/metrics_proto", "//third_party/metrics_proto",
"//ui/base", "//ui/base",
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
#include "components/translate/core/browser/translate_language_list.h" #include "components/translate/core/browser/translate_language_list.h"
#include "components/translate/core/browser/translate_script.h" #include "components/translate/core/browser/translate_script.h"
#include "net/url_request/url_request_context_getter.h" #include "net/url_request/url_request_context_getter.h"
#include "services/network/public/cpp/shared_url_loader_factory.h"
namespace base { namespace base {
template <typename T> struct DefaultSingletonTraits; template <typename T> struct DefaultSingletonTraits;
...@@ -38,6 +39,16 @@ class TranslateDownloadManager { ...@@ -38,6 +39,16 @@ class TranslateDownloadManager {
request_context_ = context; request_context_ = context;
} }
// The URL loader factory used to download the resources.
// Should be set before this class can be used.
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory() {
return url_loader_factory_;
}
void set_url_loader_factory(
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory) {
url_loader_factory_ = std::move(url_loader_factory);
}
// The application locale. // The application locale.
// Should be set before this class can be used. // Should be set before this class can be used.
const std::string& application_locale() { const std::string& application_locale() {
...@@ -110,6 +121,7 @@ class TranslateDownloadManager { ...@@ -110,6 +121,7 @@ class TranslateDownloadManager {
std::string application_locale_; std::string application_locale_;
scoped_refptr<net::URLRequestContextGetter> request_context_; scoped_refptr<net::URLRequestContextGetter> request_context_;
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory_;
}; };
} // namespace translate } // namespace translate
......
...@@ -168,8 +168,8 @@ TranslateRankerImpl::TranslateRankerImpl(const base::FilePath& model_path, ...@@ -168,8 +168,8 @@ TranslateRankerImpl::TranslateRankerImpl(const base::FilePath& model_path,
base::Bind(&ValidateModel), base::Bind(&ValidateModel),
base::Bind(&TranslateRankerImpl::OnModelAvailable, base::Bind(&TranslateRankerImpl::OnModelAvailable,
weak_ptr_factory_.GetWeakPtr()), weak_ptr_factory_.GetWeakPtr()),
TranslateDownloadManager::GetInstance()->request_context(), model_path, TranslateDownloadManager::GetInstance()->url_loader_factory(),
model_url, kUmaPrefix); model_path, model_url, kUmaPrefix);
// Kick off the initial load from cache. // Kick off the initial load from cache.
model_loader_->NotifyOfRankerActivity(); model_loader_->NotifyOfRankerActivity();
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment