Commit 80258c10 authored by Daniel Rubery's avatar Daniel Rubery Committed by Commit Bot

Attempt to fetch CSD model from cache at startup

This CL adds code to fetch the model from the cache (without making a
network request) immediately at startup. Since the models have a very
long max-age, this should solidly address the high proportion of
MODEL_NEVER_FETCHED in SBClientPhishing.ClassifierNotReadyReason.

See https://uma.googleplex.com/p/chrome/histograms?sid=20464e66b314d32caf0ef36facc35186

Bug: 1068623
Change-Id: I46c56daa4abb5ac32abf54c1970cb4146e79c22a
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2250956
Commit-Queue: Daniel Rubery <drubery@chromium.org>
Reviewed-by: default avatarBettina Dea <bdea@chromium.org>
Cr-Commit-Position: refs/heads/master@{#780063}
parent 2b828ca8
......@@ -120,6 +120,7 @@ ModelLoader::ModelLoader(
url_loader_factory_(url_loader_factory),
last_client_model_status_(ClientModelStatus::MODEL_NEVER_FETCHED) {
DCHECK(url_.is_valid());
StartFetch(/*only_from_cache=*/true);
}
// For testing only
......@@ -141,7 +142,7 @@ ModelLoader::~ModelLoader() {
DCHECK(fetch_sequence_checker_.CalledOnValidSequence());
}
void ModelLoader::StartFetch() {
void ModelLoader::StartFetch(bool only_from_cache) {
if (base::CommandLine::ForCurrentProcess()->HasSwitch(
kOverrideCsdModelFlag)) {
OverrideModelWithLocalFile();
......@@ -190,6 +191,8 @@ void ModelLoader::StartFetch() {
})");
auto resource_request = std::make_unique<network::ResourceRequest>();
resource_request->url = url_;
if (only_from_cache)
resource_request->load_flags = net::LOAD_ONLY_FROM_CACHE;
resource_request->credentials_mode = network::mojom::CredentialsMode::kOmit;
url_loader_ = network::SimpleURLLoader::Create(std::move(resource_request),
traffic_annotation);
......@@ -278,7 +281,8 @@ void ModelLoader::ScheduleFetch(int64_t delay_ms) {
DCHECK(fetch_sequence_checker_.CalledOnValidSequence());
base::SequencedTaskRunnerHandle::Get()->PostDelayedTask(
FROM_HERE,
base::BindOnce(&ModelLoader::StartFetch, weak_factory_.GetWeakPtr()),
base::BindOnce(&ModelLoader::StartFetch, weak_factory_.GetWeakPtr(),
/*only_from_cache=*/false),
base::TimeDelta::FromMilliseconds(delay_ms));
}
......
......@@ -95,8 +95,10 @@ class ModelLoader {
const std::string& model_name);
// This is called periodically to check whether a new client model is
// available for download.
virtual void StartFetch();
// available for download. If |only_from_cache| is true, we will not make a
// network request, but will instead try to fetch the model from the local
// cache.
virtual void StartFetch(bool only_from_cache);
// This method is called when we're done fetching the model either because
// we hit an error somewhere or because we're actually done fetch and
......
......@@ -24,6 +24,7 @@
#include "components/safe_browsing/core/proto/csd.pb.h"
#include "components/variations/variations_associated_data.h"
#include "content/public/test/browser_task_environment.h"
#include "net/base/load_flags.h"
#include "net/url_request/url_request_status.h"
#include "services/network/public/cpp/weak_wrapper_shared_url_loader_factory.h"
#include "services/network/test/test_url_loader_factory.h"
......@@ -61,7 +62,7 @@ class MockModelLoader : public ModelLoader {
class ModelLoaderTest : public testing::Test {
protected:
ModelLoaderTest()
: test_shared_loader_factory_(
: shared_loader_factory_(
base::MakeRefCounted<network::WeakWrapperSharedURLLoaderFactory>(
&test_url_loader_factory_)) {
scoped_feature_list_.Init();
......@@ -107,15 +108,19 @@ class ModelLoaderTest : public testing::Test {
test_url_loader_factory_.AddResponse(model_url_.spec(), response_data);
}
scoped_refptr<network::SharedURLLoaderFactory> test_shared_loader_factory() {
return test_shared_loader_factory_;
scoped_refptr<network::SharedURLLoaderFactory> shared_loader_factory() {
return shared_loader_factory_;
}
network::TestURLLoaderFactory* test_url_loader_factory() {
return &test_url_loader_factory_;
}
private:
content::BrowserTaskEnvironment task_environment_;
network::TestURLLoaderFactory test_url_loader_factory_;
base::test::ScopedFeatureList scoped_feature_list_;
scoped_refptr<network::SharedURLLoaderFactory> test_shared_loader_factory_;
scoped_refptr<network::SharedURLLoaderFactory> shared_loader_factory_;
GURL model_url_;
};
......@@ -124,15 +129,15 @@ ACTION_P(InvokeClosure, closure) {
}
TEST_F(ModelLoaderTest, FetchModelFromLocalFileTest) {
StrictMock<MockModelLoader> loader(
base::Closure(), test_shared_loader_factory(), "top_model.pb");
StrictMock<MockModelLoader> loader(base::Closure(), shared_loader_factory(),
"top_model.pb");
SetModelUrl(loader);
// The model fetch tries to read from local file but is empty.
{
base::CommandLine::ForCurrentProcess()->AppendSwitchASCII(
"csd-model-override-path", "");
loader.StartFetch();
loader.StartFetch(/*only_from_cache=*/false);
Mock::VerifyAndClearExpectations(&loader);
}
......@@ -140,7 +145,7 @@ TEST_F(ModelLoaderTest, FetchModelFromLocalFileTest) {
{
base::CommandLine::ForCurrentProcess()->AppendSwitchASCII(
"csd-model-override-path", "invalid-file");
loader.StartFetch();
loader.StartFetch(/*only_from_cache=*/false);
Mock::VerifyAndClearExpectations(&loader);
}
......@@ -155,7 +160,7 @@ TEST_F(ModelLoaderTest, FetchModelFromLocalFileTest) {
base::CommandLine::ForCurrentProcess()->AppendSwitchASCII(
"csd-model-override-path",
test_path.GetPath().AppendASCII("model.txt").MaybeAsASCII());
loader.StartFetch();
loader.StartFetch(/*only_from_cache=*/false);
Mock::VerifyAndClearExpectations(&loader);
}
......@@ -177,7 +182,7 @@ TEST_F(ModelLoaderTest, FetchModelFromLocalFileTest) {
test_path.GetPath().AppendASCII("model.txt").MaybeAsASCII());
EXPECT_CALL(loader, EndFetch(ModelLoader::MODEL_SUCCESS, _))
.WillOnce(InvokeClosure(loop.QuitClosure()));
loader.StartFetch();
loader.StartFetch(/*only_from_cache=*/false);
loop.Run();
Mock::VerifyAndClearExpectations(&loader);
}
......@@ -185,8 +190,8 @@ TEST_F(ModelLoaderTest, FetchModelFromLocalFileTest) {
// Test the response to many variations of model responses.
TEST_F(ModelLoaderTest, FetchModelTest) {
StrictMock<MockModelLoader> loader(
base::Closure(), test_shared_loader_factory(), "top_model.pb");
StrictMock<MockModelLoader> loader(base::Closure(), shared_loader_factory(),
"top_model.pb");
SetModelUrl(loader);
// The model fetch failed.
......@@ -195,7 +200,7 @@ TEST_F(ModelLoaderTest, FetchModelTest) {
SetModelFetchResponse("blamodel", net::ERR_FAILED);
EXPECT_CALL(loader, EndFetch(ModelLoader::MODEL_FETCH_FAILED, _))
.WillOnce(InvokeClosure(loop.QuitClosure()));
loader.StartFetch();
loader.StartFetch(/*only_from_cache=*/false);
loop.Run();
Mock::VerifyAndClearExpectations(&loader);
}
......@@ -206,7 +211,7 @@ TEST_F(ModelLoaderTest, FetchModelTest) {
SetModelFetchResponse(std::string(), net::OK);
EXPECT_CALL(loader, EndFetch(ModelLoader::MODEL_EMPTY, _))
.WillOnce(InvokeClosure(loop.QuitClosure()));
loader.StartFetch();
loader.StartFetch(/*only_from_cache=*/false);
loop.Run();
Mock::VerifyAndClearExpectations(&loader);
}
......@@ -218,7 +223,7 @@ TEST_F(ModelLoaderTest, FetchModelTest) {
net::OK);
EXPECT_CALL(loader, EndFetch(ModelLoader::MODEL_TOO_LARGE, _))
.WillOnce(InvokeClosure(loop.QuitClosure()));
loader.StartFetch();
loader.StartFetch(/*only_from_cache=*/false);
loop.Run();
Mock::VerifyAndClearExpectations(&loader);
}
......@@ -229,7 +234,7 @@ TEST_F(ModelLoaderTest, FetchModelTest) {
SetModelFetchResponse("Invalid model file", net::OK);
EXPECT_CALL(loader, EndFetch(ModelLoader::MODEL_PARSE_ERROR, _))
.WillOnce(InvokeClosure(loop.QuitClosure()));
loader.StartFetch();
loader.StartFetch(/*only_from_cache=*/false);
loop.Run();
Mock::VerifyAndClearExpectations(&loader);
}
......@@ -242,7 +247,7 @@ TEST_F(ModelLoaderTest, FetchModelTest) {
SetModelFetchResponse(model.SerializePartialAsString(), net::OK);
EXPECT_CALL(loader, EndFetch(ModelLoader::MODEL_MISSING_FIELDS, _))
.WillOnce(InvokeClosure(loop.QuitClosure()));
loader.StartFetch();
loader.StartFetch(/*only_from_cache=*/false);
loop.Run();
Mock::VerifyAndClearExpectations(&loader);
}
......@@ -256,7 +261,7 @@ TEST_F(ModelLoaderTest, FetchModelTest) {
SetModelFetchResponse(model.SerializePartialAsString(), net::OK);
EXPECT_CALL(loader, EndFetch(ModelLoader::MODEL_BAD_HASH_IDS, _))
.WillOnce(InvokeClosure(loop.QuitClosure()));
loader.StartFetch();
loader.StartFetch(/*only_from_cache=*/false);
loop.Run();
Mock::VerifyAndClearExpectations(&loader);
}
......@@ -269,7 +274,7 @@ TEST_F(ModelLoaderTest, FetchModelTest) {
SetModelFetchResponse(model.SerializeAsString(), net::OK);
EXPECT_CALL(loader, EndFetch(ModelLoader::MODEL_INVALID_VERSION_NUMBER, _))
.WillOnce(InvokeClosure(loop.QuitClosure()));
loader.StartFetch();
loader.StartFetch(/*only_from_cache=*/false);
loop.Run();
Mock::VerifyAndClearExpectations(&loader);
}
......@@ -281,7 +286,7 @@ TEST_F(ModelLoaderTest, FetchModelTest) {
SetModelFetchResponse(model.SerializeAsString(), net::OK);
EXPECT_CALL(loader, EndFetch(ModelLoader::MODEL_SUCCESS, _))
.WillOnce(InvokeClosure(loop.QuitClosure()));
loader.StartFetch();
loader.StartFetch(/*only_from_cache=*/false);
loop.Run();
Mock::VerifyAndClearExpectations(&loader);
}
......@@ -295,7 +300,7 @@ TEST_F(ModelLoaderTest, FetchModelTest) {
SetModelFetchResponse(model.SerializeAsString(), net::OK);
EXPECT_CALL(loader, EndFetch(ModelLoader::MODEL_INVALID_VERSION_NUMBER, _))
.WillOnce(InvokeClosure(loop.QuitClosure()));
loader.StartFetch();
loader.StartFetch(/*only_from_cache=*/false);
loop.Run();
Mock::VerifyAndClearExpectations(&loader);
}
......@@ -307,7 +312,7 @@ TEST_F(ModelLoaderTest, FetchModelTest) {
SetModelFetchResponse(model.SerializeAsString(), net::OK);
EXPECT_CALL(loader, EndFetch(ModelLoader::MODEL_NOT_CHANGED, _))
.WillOnce(InvokeClosure(loop.QuitClosure()));
loader.StartFetch();
loader.StartFetch(/*only_from_cache=*/false);
loop.Run();
Mock::VerifyAndClearExpectations(&loader);
}
......@@ -414,4 +419,15 @@ TEST_F(ModelLoaderTest, ModelHasValidHashIds) {
EXPECT_TRUE(ModelLoader::ModelHasValidHashIds(model));
}
TEST_F(ModelLoaderTest, FetchesFromCacheAtStartup) {
ModelLoader model_loader(base::DoNothing(), shared_loader_factory(),
/*is_extended_reporting=*/false);
ASSERT_NE(test_url_loader_factory()->GetPendingRequest(0), nullptr);
// Check the request does not use the network
int load_flags =
test_url_loader_factory()->GetPendingRequest(0)->request.load_flags;
EXPECT_NE((load_flags & net::LOAD_ONLY_FROM_CACHE), 0);
}
} // namespace safe_browsing
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment