Commit 389b08fa authored by Mark Pilgrim's avatar Mark Pilgrim Committed by Commit Bot

Reland Migrate AssistRanker to SimpleURLLoader

As part of the effort to bring the new Network Service online, we are
deprecating URLFetcher in favor of SimpleURLLoader.

Background: https://docs.google.com/document/d/1YZDPeg3bf46QPU_vUotFwOEPHquww36b-UdRlO3ZpMU/edit
Previous CL: https://chromium-review.googlesource.com/c/chromium/src/+/1070110
Reverted: https://chromium-review.googlesource.com/c/chromium/src/+/1073656

Refactoring exposed an underlying bug in the existing code (max_retries_on_5xx_ could
be used before being initialized). This CL initializes it to 0 in the constructor.

Bug: 844937
Change-Id: I419565dc899b2b48118b351206ade4f15090836a
TBR: groby@chromium.org, charleszhao@chromium.org
Reviewed-on: https://chromium-review.googlesource.com/1076407
Commit-Queue: Mark Pilgrim <pilgrim@chromium.org>
Reviewed-by: default avatarMatt Menke <mmenke@chromium.org>
Reviewed-by: default avatarAndrew Moylan <amoylan@chromium.org>
Cr-Commit-Position: refs/heads/master@{#562806}
parent cf709734
......@@ -5,11 +5,13 @@
#include "chrome/browser/assist_ranker/assist_ranker_service_factory.h"
#include "chrome/browser/browser_process.h"
#include "chrome/browser/net/system_network_context_manager.h"
#include "chrome/browser/profiles/incognito_helpers.h"
#include "components/assist_ranker/assist_ranker_service_impl.h"
#include "components/keyed_service/content/browser_context_dependency_manager.h"
#include "components/keyed_service/core/keyed_service.h"
#include "content/public/browser/browser_context.h"
#include "services/network/public/cpp/shared_url_loader_factory.h"
namespace assist_ranker {
......@@ -35,7 +37,9 @@ AssistRankerServiceFactory::~AssistRankerServiceFactory() {}
KeyedService* AssistRankerServiceFactory::BuildServiceInstanceFor(
content::BrowserContext* browser_context) const {
return new AssistRankerServiceImpl(
browser_context->GetPath(), g_browser_process->system_request_context());
browser_context->GetPath(),
g_browser_process->system_network_context_manager()
->GetSharedURLLoaderFactory());
}
content::BrowserContext* AssistRankerServiceFactory::GetBrowserContextToUse(
......
......@@ -11,6 +11,7 @@
#include "base/strings/string_util.h"
#include "build/build_config.h"
#include "chrome/browser/browser_process.h"
#include "chrome/browser/net/system_network_context_manager.h"
#include "chrome/browser/translate/chrome_translate_client.h"
#include "chrome/common/chrome_switches.h"
#include "chrome/common/pref_names.h"
......@@ -53,6 +54,13 @@ void TranslateService::Initialize() {
translate::TranslateDownloadManager::GetInstance();
download_manager->set_request_context(
g_browser_process->system_request_context());
SystemNetworkContextManager* system_network_context_manager =
g_browser_process->system_network_context_manager();
// Manager will be null if called from InitializeForTesting.
if (system_network_context_manager) {
download_manager->set_url_loader_factory(
system_network_context_manager->GetSharedURLLoaderFactory());
}
download_manager->set_application_locale(
g_browser_process->GetApplicationLocale());
}
......
......@@ -39,6 +39,9 @@ static_library("assist_ranker") {
"//components/keyed_service/core",
"//net",
"//services/metrics/public/cpp:metrics_cpp",
"//services/network:network_service",
"//services/network/public/cpp",
"//services/network/public/mojom",
"//url",
]
}
......@@ -62,6 +65,8 @@ source_set("unit_tests") {
"//components/assist_ranker/proto",
"//components/ukm:test_support",
"//net:test_support",
"//services/network:test_support",
"//services/network/public/cpp",
"//testing/gtest",
]
}
......@@ -5,5 +5,8 @@ include_rules = [
"+components/ukm",
"+net",
"+services/metrics/public",
"+services/network/public/cpp",
"+services/network/public/mojom",
"+services/network/test",
"+third_party/protobuf",
]
......@@ -6,15 +6,15 @@
#include "base/memory/weak_ptr.h"
#include "components/assist_ranker/binary_classifier_predictor.h"
#include "components/assist_ranker/ranker_model_loader_impl.h"
#include "net/url_request/url_request_context_getter.h"
#include "services/network/public/cpp/shared_url_loader_factory.h"
#include "url/gurl.h"
namespace assist_ranker {
AssistRankerServiceImpl::AssistRankerServiceImpl(
base::FilePath base_path,
net::URLRequestContextGetter* url_request_context_getter)
: url_request_context_getter_(url_request_context_getter),
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory)
: url_loader_factory_(std::move(url_loader_factory)),
base_path_(std::move(base_path)) {}
AssistRankerServiceImpl::~AssistRankerServiceImpl() {}
......@@ -35,7 +35,7 @@ AssistRankerServiceImpl::FetchBinaryClassifierPredictor(
DVLOG(1) << "Initializing predictor: " << model_name;
std::unique_ptr<BinaryClassifierPredictor> predictor =
BinaryClassifierPredictor::Create(config, GetModelPath(model_name),
url_request_context_getter_.get());
url_loader_factory_);
base::WeakPtr<BinaryClassifierPredictor> weak_ptr =
base::AsWeakPtr(predictor.get());
predictor_map_[model_name] = std::move(predictor);
......
......@@ -15,8 +15,8 @@
#include "components/assist_ranker/assist_ranker_service.h"
#include "components/assist_ranker/predictor_config.h"
namespace net {
class URLRequestContextGetter;
namespace network {
class SharedURLLoaderFactory;
}
namespace assist_ranker {
......@@ -28,7 +28,7 @@ class AssistRankerServiceImpl : public AssistRankerService {
public:
AssistRankerServiceImpl(
base::FilePath base_path,
net::URLRequestContextGetter* url_request_context_getter);
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory);
~AssistRankerServiceImpl() override;
// AssistRankerService...
......@@ -39,8 +39,8 @@ class AssistRankerServiceImpl : public AssistRankerService {
// Returns the full path to the model cache.
base::FilePath GetModelPath(const std::string& model_filename);
// Request Context Getter used for RankerURLFetcher.
scoped_refptr<net::URLRequestContextGetter> url_request_context_getter_;
// URL loader factory used for RankerURLFetcher.
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory_;
// Base path where models are stored.
const base::FilePath base_path_;
......
......@@ -13,7 +13,7 @@
#include "components/assist_ranker/proto/ranker_model.pb.h"
#include "components/assist_ranker/ranker_model.h"
#include "components/assist_ranker/ranker_model_loader_impl.h"
#include "net/url_request/url_request_context_getter.h"
#include "services/network/public/cpp/shared_url_loader_factory.h"
namespace assist_ranker {
......@@ -26,7 +26,7 @@ BinaryClassifierPredictor::~BinaryClassifierPredictor(){};
std::unique_ptr<BinaryClassifierPredictor> BinaryClassifierPredictor::Create(
const PredictorConfig& config,
const base::FilePath& model_path,
net::URLRequestContextGetter* request_context_getter) {
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory) {
std::unique_ptr<BinaryClassifierPredictor> predictor(
new BinaryClassifierPredictor(config));
if (!predictor->is_query_enabled()) {
......@@ -40,7 +40,7 @@ std::unique_ptr<BinaryClassifierPredictor> BinaryClassifierPredictor::Create(
base::BindRepeating(&BinaryClassifierPredictor::ValidateModel),
base::BindRepeating(&BinaryClassifierPredictor::OnModelAvailable,
base::Unretained(predictor.get())),
request_context_getter, model_path, model_url, config.uma_prefix);
url_loader_factory, model_path, model_url, config.uma_prefix);
predictor->LoadModel(std::move(model_loader));
return predictor;
}
......
......@@ -13,8 +13,8 @@ namespace base {
class FilePath;
}
namespace net {
class URLRequestContextGetter;
namespace network {
class SharedURLLoaderFactory;
}
namespace assist_ranker {
......@@ -32,7 +32,8 @@ class BinaryClassifierPredictor : public BasePredictor {
static std::unique_ptr<BinaryClassifierPredictor> Create(
const PredictorConfig& config,
const base::FilePath& model_path,
net::URLRequestContextGetter* request_context_getter) WARN_UNUSED_RESULT;
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory)
WARN_UNUSED_RESULT;
// Fills in a boolean decision given a RankerExample. Returns false if a
// prediction could not be made (e.g. the model is not loaded yet).
......
......@@ -22,6 +22,7 @@
#include "base/threading/sequenced_task_runner_handle.h"
#include "components/assist_ranker/proto/ranker_model.pb.h"
#include "components/assist_ranker/ranker_url_fetcher.h"
#include "services/network/public/cpp/shared_url_loader_factory.h"
namespace assist_ranker {
namespace {
......@@ -90,7 +91,7 @@ void SaveToFile(const GURL& model_url,
RankerModelLoaderImpl::RankerModelLoaderImpl(
ValidateModelCallback validate_model_cb,
OnModelAvailableCallback on_model_available_cb,
net::URLRequestContextGetter* request_context_getter,
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
base::FilePath model_path,
GURL model_url,
std::string uma_prefix)
......@@ -99,7 +100,7 @@ RankerModelLoaderImpl::RankerModelLoaderImpl(
base::TaskShutdownBehavior::SKIP_ON_SHUTDOWN})),
validate_model_cb_(std::move(validate_model_cb)),
on_model_available_cb_(std::move(on_model_available_cb)),
request_context_getter_(request_context_getter),
url_loader_factory_(std::move(url_loader_factory)),
model_path_(std::move(model_path)),
model_url_(std::move(model_url)),
uma_prefix_(std::move(uma_prefix)),
......@@ -222,7 +223,7 @@ void RankerModelLoaderImpl::StartLoadFromURL() {
url_fetcher_->Request(model_url_,
base::Bind(&RankerModelLoaderImpl::OnURLFetched,
weak_ptr_factory_.GetWeakPtr()),
request_context_getter_.get());
url_loader_factory_.get());
// |url_fetcher_| maintains a request retry counter. If all allowed attempts
// have already been exhausted, then the loader is finished and has abandoned
......
......@@ -16,13 +16,16 @@
#include "base/memory/weak_ptr.h"
#include "base/sequence_checker.h"
#include "base/time/time.h"
#include "net/url_request/url_request_context_getter.h"
#include "url/gurl.h"
namespace base {
class SequencedTaskRunner;
} // namespace base
namespace network {
class SharedURLLoaderFactory;
}
namespace assist_ranker {
class RankerURLFetcher;
......@@ -48,12 +51,13 @@ class RankerModelLoaderImpl : public RankerModelLoader {
//
// |uma_prefix| will be used as a prefix for the names of all UMA metrics
// generated by this loader.
RankerModelLoaderImpl(ValidateModelCallback validate_model_callback,
OnModelAvailableCallback on_model_available_callback,
net::URLRequestContextGetter* request_context_getter,
base::FilePath model_path,
GURL model_url,
std::string uma_prefix);
RankerModelLoaderImpl(
ValidateModelCallback validate_model_callback,
OnModelAvailableCallback on_model_available_callback,
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
base::FilePath model_path,
GURL model_url,
std::string uma_prefix);
~RankerModelLoaderImpl() override;
......@@ -130,8 +134,8 @@ class RankerModelLoaderImpl : public RankerModelLoader {
// constructed.
const OnModelAvailableCallback on_model_available_cb_;
// Request Context Getter used for RankerURLFetcher.
scoped_refptr<net::URLRequestContextGetter> request_context_getter_;
// URL loader factory used for RankerURLFetcher.
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory_;
// The path at which the model is (or should be) cached.
const base::FilePath model_path_;
......
......@@ -20,8 +20,9 @@
#include "components/assist_ranker/proto/ranker_model.pb.h"
#include "components/assist_ranker/proto/translate_ranker_model.pb.h"
#include "components/assist_ranker/ranker_model.h"
#include "net/url_request/test_url_fetcher_factory.h"
#include "net/url_request/url_request_test_util.h"
#include "services/network/public/cpp/shared_url_loader_factory.h"
#include "services/network/public/cpp/weak_wrapper_shared_url_loader_factory.h"
#include "services/network/test/test_url_loader_factory.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace {
......@@ -76,15 +77,13 @@ class RankerModelLoaderImplTest : public ::testing::Test {
// Sets up the task scheduling/task-runner environment for each test.
base::test::ScopedTaskEnvironment scoped_task_environment_;
// Override the default URL fetcher to return custom responses for tests.
net::FakeURLFetcherFactory url_fetcher_factory_;
// Override the default URL loader to return custom responses for tests.
network::TestURLLoaderFactory test_loader_factory_;
scoped_refptr<network::SharedURLLoaderFactory> test_shared_loader_factory_;
// Temporary directory for model files.
base::ScopedTempDir scoped_temp_dir_;
// Used for URLFetcher.
scoped_refptr<net::TestURLRequestContextGetter> request_context_getter_;
// A queue of responses to return from Validate(). If empty, validate will
// return 'OK'.
base::circular_deque<RankerModelStatus> validate_model_response_;
......@@ -114,13 +113,13 @@ class RankerModelLoaderImplTest : public ::testing::Test {
DISALLOW_COPY_AND_ASSIGN(RankerModelLoaderImplTest);
};
RankerModelLoaderImplTest::RankerModelLoaderImplTest()
: url_fetcher_factory_(nullptr) {}
RankerModelLoaderImplTest::RankerModelLoaderImplTest() {
test_shared_loader_factory_ =
base::MakeRefCounted<network::WeakWrapperSharedURLLoaderFactory>(
&test_loader_factory_);
}
void RankerModelLoaderImplTest::SetUp() {
request_context_getter_ =
new net::TestURLRequestContextGetter(base::ThreadTaskRunnerHandle::Get());
ASSERT_TRUE(scoped_temp_dir_.CreateUniqueTempDir());
const auto& temp_dir_path = scoped_temp_dir_.GetPath();
......@@ -172,7 +171,7 @@ bool RankerModelLoaderImplTest::DoLoaderTest(const base::FilePath& model_path,
base::Unretained(this)),
base::Bind(&RankerModelLoaderImplTest::OnModelAvailable,
base::Unretained(this)),
request_context_getter_.get(), model_path, model_url,
test_shared_loader_factory_, model_path, model_url,
"RankerModelLoaderImplTest");
loader->NotifyOfRankerActivity();
scoped_task_environment_.RunUntilIdle();
......@@ -182,15 +181,13 @@ bool RankerModelLoaderImplTest::DoLoaderTest(const base::FilePath& model_path,
void RankerModelLoaderImplTest::InitRemoteModels() {
InitModel(remote_model_url_, base::Time(), base::TimeDelta(), &remote_model_);
url_fetcher_factory_.SetFakeResponse(
remote_model_url_, remote_model_.SerializeAsString(), net::HTTP_OK,
net::URLRequestStatus::SUCCESS);
url_fetcher_factory_.SetFakeResponse(invalid_model_url_, kInvalidModelData,
net::HTTP_OK,
net::URLRequestStatus::SUCCESS);
url_fetcher_factory_.SetFakeResponse(failed_model_url_, "",
net::HTTP_INTERNAL_SERVER_ERROR,
net::URLRequestStatus::FAILED);
test_loader_factory_.AddResponse(remote_model_url_.spec(),
remote_model_.SerializeAsString());
test_loader_factory_.AddResponse(invalid_model_url_.spec(),
kInvalidModelData);
test_loader_factory_.AddResponse(
failed_model_url_, network::ResourceResponseHead(), "",
network::URLLoaderCompletionStatus(net::HTTP_INTERNAL_SERVER_ERROR));
}
void RankerModelLoaderImplTest::InitLocalModels() {
......
......@@ -9,8 +9,9 @@
#include "net/base/load_flags.h"
#include "net/http/http_status_code.h"
#include "net/traffic_annotation/network_traffic_annotation.h"
#include "net/url_request/url_fetcher.h"
#include "net/url_request/url_request_status.h"
#include "services/network/public/cpp/resource_request.h"
#include "services/network/public/cpp/simple_url_loader.h"
#include "services/network/public/mojom/url_loader_factory.mojom.h"
namespace assist_ranker {
......@@ -21,14 +22,15 @@ const int kMaxRetry = 16;
} // namespace
RankerURLFetcher::RankerURLFetcher() : state_(IDLE), retry_count_(0) {}
RankerURLFetcher::RankerURLFetcher()
: state_(IDLE), retry_count_(0), max_retry_on_5xx_(0) {}
RankerURLFetcher::~RankerURLFetcher() {}
bool RankerURLFetcher::Request(
const GURL& url,
const RankerURLFetcher::Callback& callback,
net::URLRequestContextGetter* request_context_getter) {
network::mojom::URLLoaderFactory* url_loader_factory) {
// This function is not supposed to be called if the previous operation is not
// finished.
if (state_ == REQUESTING) {
......@@ -44,7 +46,7 @@ bool RankerURLFetcher::Request(
url_ = url;
callback_ = callback;
if (request_context_getter == nullptr)
if (url_loader_factory == nullptr)
return false;
net::NetworkTrafficAnnotationTag traffic_annotation =
......@@ -73,38 +75,37 @@ bool RankerURLFetcher::Request(
policy_exception_justification:
"Not implemented, considered not necessary as no user data is sent."
})");
// Create and initialize the URL fetcher.
fetcher_ = net::URLFetcher::Create(url_, net::URLFetcher::GET, this,
traffic_annotation);
data_use_measurement::DataUseUserData::AttachToFetcher(
fetcher_.get(),
data_use_measurement::DataUseUserData::MACHINE_INTELLIGENCE);
fetcher_->SetLoadFlags(net::LOAD_DO_NOT_SEND_COOKIES |
net::LOAD_DO_NOT_SAVE_COOKIES);
fetcher_->SetRequestContext(request_context_getter);
// Set retry parameter for HTTP status code 5xx. This doesn't work against
// 106 (net::ERR_INTERNET_DISCONNECTED) and so on.
fetcher_->SetMaxRetriesOn5xx(max_retry_on_5xx_);
fetcher_->Start();
auto resource_request = std::make_unique<network::ResourceRequest>();
resource_request->url = url_;
resource_request->load_flags =
net::LOAD_DO_NOT_SEND_COOKIES | net::LOAD_DO_NOT_SAVE_COOKIES;
// TODO(https://crbug.com/808498): Re-add data use measurement once
// SimpleURLLoader supports it.
// ID=data_use_measurement::DataUseUserData::MACHINE_INTELLIGENCE
simple_url_loader_ = network::SimpleURLLoader::Create(
std::move(resource_request), traffic_annotation);
if (max_retry_on_5xx_ > 0) {
simple_url_loader_->SetRetryOptions(max_retry_on_5xx_,
network::SimpleURLLoader::RETRY_ON_5XX);
}
simple_url_loader_->DownloadToStringOfUnboundedSizeUntilCrashAndDie(
url_loader_factory,
base::BindOnce(&RankerURLFetcher::OnSimpleLoaderComplete,
base::Unretained(this)));
return true;
}
void RankerURLFetcher::OnURLFetchComplete(const net::URLFetcher* source) {
DCHECK(fetcher_.get() == source);
void RankerURLFetcher::OnSimpleLoaderComplete(
std::unique_ptr<std::string> response_body) {
std::string data;
if (source->GetStatus().status() == net::URLRequestStatus::SUCCESS &&
source->GetResponseCode() == net::HTTP_OK) {
if (response_body) {
state_ = COMPLETED;
source->GetResponseAsString(&data);
data = std::move(*response_body);
} else {
state_ = FAILED;
}
// Transfer URLFetcher's ownership before invoking a callback.
std::unique_ptr<const net::URLFetcher> delete_ptr(fetcher_.release());
simple_url_loader_.reset();
callback_.Run(state_ == COMPLETED, data);
}
......
......@@ -9,14 +9,19 @@
#include "base/callback.h"
#include "base/macros.h"
#include "net/url_request/url_fetcher_delegate.h"
#include "net/url_request/url_request_context_getter.h"
#include "url/gurl.h"
namespace network {
class SimpleURLLoader;
namespace mojom {
class URLLoaderFactory;
} // namespace mojom
} // namespace network
namespace assist_ranker {
// Downloads Ranker models.
class RankerURLFetcher : public net::URLFetcherDelegate {
class RankerURLFetcher {
public:
// Callback type for Request().
typedef base::Callback<void(bool, const std::string&)> Callback;
......@@ -30,7 +35,7 @@ class RankerURLFetcher : public net::URLFetcherDelegate {
};
RankerURLFetcher();
~RankerURLFetcher() override;
~RankerURLFetcher();
int max_retry_on_5xx() { return max_retry_on_5xx_; }
void set_max_retry_on_5xx(int count) { max_retry_on_5xx_ = count; }
......@@ -41,23 +46,22 @@ class RankerURLFetcher : public net::URLFetcherDelegate {
// is omitted due to retry limitation.
bool Request(const GURL& url,
const Callback& callback,
net::URLRequestContextGetter* request_context);
network::mojom::URLLoaderFactory* url_loader_factory);
// Gets internal state.
State state() { return state_; }
// net::URLFetcherDelegate implementation:
void OnURLFetchComplete(const net::URLFetcher* source) override;
private:
void OnSimpleLoaderComplete(std::unique_ptr<std::string> response_body);
// URL to send the request.
GURL url_;
// Internal state.
enum State state_;
// URLFetcher instance.
std::unique_ptr<net::URLFetcher> fetcher_;
// SimpleURLLoader instance.
std::unique_ptr<network::SimpleURLLoader> simple_url_loader_;
// Callback passed at Request(). It will be invoked when an asynchronous
// fetch operation is finished.
......
......@@ -12,6 +12,7 @@ include_rules = [
"+components/variations",
"+google_apis",
"+net",
"+services/network/public/cpp",
"+ui",
"+third_party/metrics_proto",
......
......@@ -63,6 +63,7 @@ static_library("browser") {
"//net",
"//services/metrics/public/cpp:metrics_cpp",
"//services/metrics/public/cpp:ukm_builders",
"//services/network/public/cpp:cpp",
"//third_party/icu",
"//third_party/metrics_proto",
"//ui/base",
......
......@@ -13,6 +13,7 @@
#include "components/translate/core/browser/translate_language_list.h"
#include "components/translate/core/browser/translate_script.h"
#include "net/url_request/url_request_context_getter.h"
#include "services/network/public/cpp/shared_url_loader_factory.h"
namespace base {
template <typename T> struct DefaultSingletonTraits;
......@@ -38,6 +39,16 @@ class TranslateDownloadManager {
request_context_ = context;
}
// The URL loader factory used to download the resources.
// Should be set before this class can be used.
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory() {
return url_loader_factory_;
}
void set_url_loader_factory(
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory) {
url_loader_factory_ = std::move(url_loader_factory);
}
// The application locale.
// Should be set before this class can be used.
const std::string& application_locale() {
......@@ -110,6 +121,7 @@ class TranslateDownloadManager {
std::string application_locale_;
scoped_refptr<net::URLRequestContextGetter> request_context_;
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory_;
};
} // namespace translate
......
......@@ -168,8 +168,8 @@ TranslateRankerImpl::TranslateRankerImpl(const base::FilePath& model_path,
base::Bind(&ValidateModel),
base::Bind(&TranslateRankerImpl::OnModelAvailable,
weak_ptr_factory_.GetWeakPtr()),
TranslateDownloadManager::GetInstance()->request_context(), model_path,
model_url, kUmaPrefix);
TranslateDownloadManager::GetInstance()->url_loader_factory(),
model_path, model_url, kUmaPrefix);
// Kick off the initial load from cache.
model_loader_->NotifyOfRankerActivity();
}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment