Commit dda7c815 authored by Daniel Rubery's avatar Daniel Rubery Committed by Commit Bot

Make ClientSideDetectionService track SBER prefs and only load one model

Previously, the ClientSideDetectionService was running two ModelLoaders,
one for the regular model, and one for the SBER model. This CL makes the
ClientSideDetectionService track the SB prefs, so it can only load the
model needed for that profile.

Bug: 1068623
Change-Id: I544c059caa3a2545dc41771fc0bee04a808b5992
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2249162Reviewed-by: default avatarBettina Dea <bdea@chromium.org>
Commit-Queue: Daniel Rubery <drubery@chromium.org>
Cr-Commit-Position: refs/heads/master@{#780046}
parent a9be5907
...@@ -376,8 +376,7 @@ void ClientSideDetectionHost::DidFinishNavigation( ...@@ -376,8 +376,7 @@ void ClientSideDetectionHost::DidFinishNavigation(
void ClientSideDetectionHost::SendModelToRenderFrame( void ClientSideDetectionHost::SendModelToRenderFrame(
content::RenderProcessHost* process, content::RenderProcessHost* process,
Profile* profile, Profile* profile,
ModelLoader* model_loader_standard, ModelLoader* model_loader) {
ModelLoader* model_loader_extended) {
DCHECK_CURRENTLY_ON(BrowserThread::UI); DCHECK_CURRENTLY_ON(BrowserThread::UI);
if (!web_contents() || web_contents() != tab_) if (!web_contents() || web_contents() != tab_)
return; return;
...@@ -386,21 +385,11 @@ void ClientSideDetectionHost::SendModelToRenderFrame( ...@@ -386,21 +385,11 @@ void ClientSideDetectionHost::SendModelToRenderFrame(
if (frame->GetProcess() != process) if (frame->GetProcess() != process)
continue; continue;
std::string model;
if (IsSafeBrowsingEnabled(*profile->GetPrefs())) {
if (IsExtendedReportingEnabled(*profile->GetPrefs()) ||
IsEnhancedProtectionEnabled(*profile->GetPrefs())) {
model = model_loader_extended->model_str();
} else {
model = model_loader_standard->model_str();
}
}
if (phishing_detector_) if (phishing_detector_)
phishing_detector_.reset(); phishing_detector_.reset();
frame->GetRemoteInterfaces()->GetInterface( frame->GetRemoteInterfaces()->GetInterface(
phishing_detector_.BindNewPipeAndPassReceiver()); phishing_detector_.BindNewPipeAndPassReceiver());
phishing_detector_->SetPhishingModel(model); phishing_detector_->SetPhishingModel(model_loader->model_str());
} }
} }
...@@ -473,12 +462,8 @@ void ClientSideDetectionHost::PhishingDetectionDone( ...@@ -473,12 +462,8 @@ void ClientSideDetectionHost::PhishingDetectionDone(
base::TimeTicks::Now() - phishing_detection_start_time_); base::TimeTicks::Now() - phishing_detection_start_time_);
UMA_HISTOGRAM_ENUMERATION("SBClientPhishing.PhishingDetectorResult", result); UMA_HISTOGRAM_ENUMERATION("SBClientPhishing.PhishingDetectorResult", result);
if (result == mojom::PhishingDetectorResult::CLASSIFIER_NOT_READY) { if (result == mojom::PhishingDetectorResult::CLASSIFIER_NOT_READY) {
Profile* profile = UMA_HISTOGRAM_ENUMERATION("SBClientPhishing.ClassifierNotReadyReason",
Profile::FromBrowserContext(web_contents()->GetBrowserContext()); csd_service_->GetLastModelStatus());
UMA_HISTOGRAM_ENUMERATION(
"SBClientPhishing.ClassifierNotReadyReason",
csd_service_->GetLastModelStatus(
IsExtendedReportingEnabled(*profile->GetPrefs())));
} }
if (result != mojom::PhishingDetectorResult::SUCCESS) if (result != mojom::PhishingDetectorResult::SUCCESS)
return; return;
......
...@@ -55,8 +55,7 @@ class ClientSideDetectionHost : public content::WebContentsObserver, ...@@ -55,8 +55,7 @@ class ClientSideDetectionHost : public content::WebContentsObserver,
// Send the model to the given render frame host. // Send the model to the given render frame host.
void SendModelToRenderFrame(content::RenderProcessHost* process, void SendModelToRenderFrame(content::RenderProcessHost* process,
Profile* profile, Profile* profile,
ModelLoader* model_loader_standard, ModelLoader* model_loader);
ModelLoader* model_loader_extended);
// Called when the SafeBrowsingService found a hit with one of the // Called when the SafeBrowsingService found a hit with one of the
// SafeBrowsing lists. This method is called on the UI thread. // SafeBrowsing lists. This method is called on the UI thread.
......
...@@ -1160,42 +1160,13 @@ TEST_F(ClientSideDetectionHostTest, RecordsPhishingDetectionDuration) { ...@@ -1160,42 +1160,13 @@ TEST_F(ClientSideDetectionHostTest, RecordsPhishingDetectionDuration) {
} }
TEST_F(ClientSideDetectionHostTest, TestSendModelToRenderFrame) { TEST_F(ClientSideDetectionHostTest, TestSendModelToRenderFrame) {
profile()->GetPrefs()->SetBoolean(prefs::kSafeBrowsingEnabled, false); StrictMock<MockModelLoader> loader;
profile()->GetPrefs()->SetBoolean(prefs::kSafeBrowsingEnhanced, false); loader.SetModelStrForTesting("standard");
// Safe Browsing is not enabled.
StrictMock<MockModelLoader> standard;
standard.SetModelStrForTesting("standard");
StrictMock<MockModelLoader> extended;
extended.SetModelStrForTesting("extended");
csd_host_->SendModelToRenderFrame( csd_host_->SendModelToRenderFrame(
web_contents()->GetMainFrame()->GetProcess(), profile(), &standard, web_contents()->GetMainFrame()->GetProcess(), profile(), &loader);
&extended);
base::RunLoop().RunUntilIdle();
fake_phishing_detector_.CheckModel("");
fake_phishing_detector_.Reset();
// Safe Browsing is enabled.
profile()->GetPrefs()->SetBoolean(prefs::kSafeBrowsingEnabled, true);
csd_host_->SendModelToRenderFrame(
web_contents()->GetMainFrame()->GetProcess(), profile(), &standard,
&extended);
base::RunLoop().RunUntilIdle(); base::RunLoop().RunUntilIdle();
fake_phishing_detector_.CheckModel("standard"); fake_phishing_detector_.CheckModel("standard");
fake_phishing_detector_.Reset(); fake_phishing_detector_.Reset();
{
base::test::ScopedFeatureList scoped_feature_list;
scoped_feature_list.InitAndEnableFeature(
safe_browsing::kEnhancedProtection);
// Safe Browsing enhanced protection is enabled.
profile()->GetPrefs()->SetBoolean(prefs::kSafeBrowsingEnhanced, true);
csd_host_->SendModelToRenderFrame(
web_contents()->GetMainFrame()->GetProcess(), profile(), &standard,
&extended);
base::RunLoop().RunUntilIdle();
fake_phishing_detector_.CheckModel("extended");
fake_phishing_detector_.Reset();
}
} }
} // namespace safe_browsing } // namespace safe_browsing
...@@ -73,8 +73,8 @@ ClientSideDetectionService::ClientSideDetectionService(Profile* profile) ...@@ -73,8 +73,8 @@ ClientSideDetectionService::ClientSideDetectionService(Profile* profile)
: nullptr) { : nullptr) {
profile_ = profile; profile_ = profile;
// |profile_| and |url_loader_factory_| can be null in unit tests // |profile_| can be null in unit tests
if (!profile_ || !url_loader_factory_) if (!profile_)
return; return;
pref_change_registrar_.Init(profile_->GetPrefs()); pref_change_registrar_.Init(profile_->GetPrefs());
...@@ -82,6 +82,14 @@ ClientSideDetectionService::ClientSideDetectionService(Profile* profile) ...@@ -82,6 +82,14 @@ ClientSideDetectionService::ClientSideDetectionService(Profile* profile)
prefs::kSafeBrowsingEnabled, prefs::kSafeBrowsingEnabled,
base::Bind(&ClientSideDetectionService::OnPrefsUpdated, base::Bind(&ClientSideDetectionService::OnPrefsUpdated,
base::Unretained(this))); base::Unretained(this)));
pref_change_registrar_.Add(
prefs::kSafeBrowsingEnhanced,
base::Bind(&ClientSideDetectionService::OnPrefsUpdated,
base::Unretained(this)));
pref_change_registrar_.Add(
prefs::kSafeBrowsingScoutReportingEnabled,
base::Bind(&ClientSideDetectionService::OnPrefsUpdated,
base::Unretained(this)));
// Do an initial check of the prefs. // Do an initial check of the prefs.
OnPrefsUpdated(); OnPrefsUpdated();
...@@ -93,10 +101,6 @@ ClientSideDetectionService::ClientSideDetectionService( ...@@ -93,10 +101,6 @@ ClientSideDetectionService::ClientSideDetectionService(
base::Closure update_renderers = base::Closure update_renderers =
base::Bind(&ClientSideDetectionService::SendModelToRenderers, base::Bind(&ClientSideDetectionService::SendModelToRenderers,
base::Unretained(this)); base::Unretained(this));
model_loader_standard_.reset(
new ModelLoader(update_renderers, url_loader_factory_, false));
model_loader_extended_.reset(
new ModelLoader(update_renderers, url_loader_factory_, true));
registrar_.Add(this, content::NOTIFICATION_RENDERER_PROCESS_CREATED, registrar_.Add(this, content::NOTIFICATION_RENDERER_PROCESS_CREATED,
content::NotificationService::AllBrowserContextsAndSources()); content::NotificationService::AllBrowserContextsAndSources());
...@@ -111,27 +115,38 @@ void ClientSideDetectionService::Shutdown() { ...@@ -111,27 +115,38 @@ void ClientSideDetectionService::Shutdown() {
} }
void ClientSideDetectionService::OnPrefsUpdated() { void ClientSideDetectionService::OnPrefsUpdated() {
SetEnabledAndRefreshState(IsSafeBrowsingEnabled(*profile_->GetPrefs()));
}
void ClientSideDetectionService::SetEnabledAndRefreshState(bool enabled) {
DCHECK_CURRENTLY_ON(BrowserThread::UI); DCHECK_CURRENTLY_ON(BrowserThread::UI);
SendModelToRenderers(); // always refresh the renderer state SendModelToRenderers(); // always refresh the renderer state
if (enabled == enabled_) bool enabled = IsSafeBrowsingEnabled(*profile_->GetPrefs());
bool extended_reporting =
IsEnhancedProtectionEnabled(*profile_->GetPrefs()) ||
IsExtendedReportingEnabled(*profile_->GetPrefs());
if (enabled == enabled_ && extended_reporting_ == extended_reporting)
return; return;
enabled_ = enabled; enabled_ = enabled;
extended_reporting_ = extended_reporting;
if (enabled_) { if (enabled_) {
if (!model_factory_.is_null()) {
model_loader_ = model_factory_.Run();
} else {
model_loader_ = std::make_unique<ModelLoader>(
base::BindRepeating(&ClientSideDetectionService::SendModelToRenderers,
base::Unretained(this)),
url_loader_factory_, extended_reporting_);
}
// Refresh the models when the service is enabled. This can happen when // Refresh the models when the service is enabled. This can happen when
// either of the preferences are toggled, or early during startup if // either of the preferences are toggled, or early during startup if
// safe browsing is already enabled. In a lot of cases the model will be // safe browsing is already enabled. In a lot of cases the model will be
// in the cache so it won't actually be fetched from the network. // in the cache so it won't actually be fetched from the network.
// We delay the first model fetches to avoid slowing down browser startup. // We delay the first model fetches to avoid slowing down browser startup.
model_loader_standard_->ScheduleFetch(kInitialClientModelFetchDelayMs); model_loader_->ScheduleFetch(kInitialClientModelFetchDelayMs);
model_loader_extended_->ScheduleFetch(kInitialClientModelFetchDelayMs);
} else { } else {
// Cancel model loads in progress. if (model_loader_) {
model_loader_standard_->CancelFetcher(); // Cancel model loads in progress.
model_loader_extended_->CancelFetcher(); model_loader_->CancelFetcher();
}
// Invoke pending callbacks with a false verdict. // Invoke pending callbacks with a false verdict.
for (auto it = client_phishing_reports_.begin(); for (auto it = client_phishing_reports_.begin();
it != client_phishing_reports_.end(); ++it) { it != client_phishing_reports_.end(); ++it) {
...@@ -207,9 +222,7 @@ void ClientSideDetectionService::Observe( ...@@ -207,9 +222,7 @@ void ClientSideDetectionService::Observe(
content::Source<content::RenderProcessHost>(source).ptr(); content::Source<content::RenderProcessHost>(source).ptr();
if (process->GetBrowserContext() == profile_) { if (process->GetBrowserContext() == profile_) {
for (ClientSideDetectionHost* host : csd_hosts_) { for (ClientSideDetectionHost* host : csd_hosts_) {
host->SendModelToRenderFrame(process, profile_, host->SendModelToRenderFrame(process, profile_, model_loader_.get());
model_loader_standard_.get(),
model_loader_extended_.get());
} }
} }
} }
...@@ -222,9 +235,7 @@ void ClientSideDetectionService::SendModelToRenderers() { ...@@ -222,9 +235,7 @@ void ClientSideDetectionService::SendModelToRenderers() {
if (process->IsInitializedAndNotDead() && if (process->IsInitializedAndNotDead() &&
process->GetBrowserContext() == profile_) { process->GetBrowserContext() == profile_) {
for (ClientSideDetectionHost* host : csd_hosts_) { for (ClientSideDetectionHost* host : csd_hosts_) {
host->SendModelToRenderFrame(process, profile_, host->SendModelToRenderFrame(process, profile_, model_loader_.get());
model_loader_standard_.get(),
model_loader_extended_.get());
} }
} }
} }
...@@ -245,8 +256,8 @@ void ClientSideDetectionService::StartClientReportPhishingRequest( ...@@ -245,8 +256,8 @@ void ClientSideDetectionService::StartClientReportPhishingRequest(
} }
// Fill in metadata about which model we used. // Fill in metadata about which model we used.
request->set_model_filename(model_loader_->name());
if (is_extended_reporting || is_enhanced_reporting) { if (is_extended_reporting || is_enhanced_reporting) {
request->set_model_filename(model_loader_extended_->name());
if (is_enhanced_reporting) { if (is_enhanced_reporting) {
request->mutable_population()->set_user_population( request->mutable_population()->set_user_population(
ChromeUserPopulation::ENHANCED_PROTECTION); ChromeUserPopulation::ENHANCED_PROTECTION);
...@@ -255,7 +266,6 @@ void ClientSideDetectionService::StartClientReportPhishingRequest( ...@@ -255,7 +266,6 @@ void ClientSideDetectionService::StartClientReportPhishingRequest(
ChromeUserPopulation::EXTENDED_REPORTING); ChromeUserPopulation::EXTENDED_REPORTING);
} }
} else { } else {
request->set_model_filename(model_loader_standard_->name());
request->mutable_population()->set_user_population( request->mutable_population()->set_user_population(
ChromeUserPopulation::SAFE_BROWSING); ChromeUserPopulation::SAFE_BROWSING);
} }
...@@ -438,11 +448,21 @@ GURL ClientSideDetectionService::GetClientReportUrl( ...@@ -438,11 +448,21 @@ GURL ClientSideDetectionService::GetClientReportUrl(
return url; return url;
} }
ModelLoader::ClientModelStatus ClientSideDetectionService::GetLastModelStatus( ModelLoader::ClientModelStatus
bool use_extended_model) { ClientSideDetectionService::GetLastModelStatus() {
ModelLoader* model_loader = use_extended_model ? model_loader_extended_.get() // |model_loader_| can be null in tests
: model_loader_standard_.get(); return model_loader_ ? model_loader_->last_client_model_status()
return model_loader->last_client_model_status(); : ModelLoader::MODEL_NEVER_FETCHED;
}
void ClientSideDetectionService::SetModelLoaderFactoryForTesting(
base::RepeatingCallback<std::unique_ptr<ModelLoader>()> factory) {
model_factory_ = factory;
}
void ClientSideDetectionService::SetURLLoaderFactoryForTesting(
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory) {
url_loader_factory_ = url_loader_factory;
} }
} // namespace safe_browsing } // namespace safe_browsing
...@@ -122,9 +122,17 @@ class ClientSideDetectionService : public content::NotificationObserver, ...@@ -122,9 +122,17 @@ class ClientSideDetectionService : public content::NotificationObserver,
base::WeakPtr<ClientSideDetectionService> GetWeakPtr(); base::WeakPtr<ClientSideDetectionService> GetWeakPtr();
// Get the model status for the given client-side model (extended reporting or // Get the model status for the given client-side model.
// regular). ModelLoader::ClientModelStatus GetLastModelStatus();
ModelLoader::ClientModelStatus GetLastModelStatus(bool use_extended_model);
// Makes ModelLoaders be constructed by calling |factory| rather than the
// default constructor.
void SetModelLoaderFactoryForTesting(
base::RepeatingCallback<std::unique_ptr<ModelLoader>()> factory);
// Overrides the SharedURLLoaderFactory
void SetURLLoaderFactoryForTesting(
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory);
private: private:
friend class ClientSideDetectionServiceTest; friend class ClientSideDetectionServiceTest;
...@@ -135,6 +143,8 @@ class ClientSideDetectionService : public content::NotificationObserver, ...@@ -135,6 +143,8 @@ class ClientSideDetectionService : public content::NotificationObserver,
FRIEND_TEST_ALL_PREFIXES(ClientSideDetectionServiceTest, FRIEND_TEST_ALL_PREFIXES(ClientSideDetectionServiceTest,
SendClientReportPhishingRequest); SendClientReportPhishingRequest);
FRIEND_TEST_ALL_PREFIXES(ClientSideDetectionServiceTest, GetNumReportTest); FRIEND_TEST_ALL_PREFIXES(ClientSideDetectionServiceTest, GetNumReportTest);
FRIEND_TEST_ALL_PREFIXES(ClientSideDetectionServiceTest,
TestModelFollowsPrefs);
// CacheState holds all information necessary to respond to a caller without // CacheState holds all information necessary to respond to a caller without
// actually making a HTTP request. // actually making a HTTP request.
...@@ -152,15 +162,13 @@ class ClientSideDetectionService : public content::NotificationObserver, ...@@ -152,15 +162,13 @@ class ClientSideDetectionService : public content::NotificationObserver,
static const int kNegativeCacheIntervalDays; static const int kNegativeCacheIntervalDays;
static const int kPositiveCacheIntervalMinutes; static const int kPositiveCacheIntervalMinutes;
// Called when the prefs have changed in a way we may need to respond to. // Called when the prefs have changed in a way we may need to respond to. May
void OnPrefsUpdated(); // enable or disable the service and refresh the state of all renderers.
// Enables or disables the service, and refreshes the state of all renderers.
// Disabling cancels any pending requests; existing ClientSideDetectionHosts // Disabling cancels any pending requests; existing ClientSideDetectionHosts
// will have their callbacks called with "false" verdicts. Enabling starts // will have their callbacks called with "false" verdicts. Enabling starts
// downloading the model after a delay. In all cases, each render process is // downloading the model after a delay. In all cases, each render process is
// updated to match the state // updated to match the state
void SetEnabledAndRefreshState(bool enabled); void OnPrefsUpdated();
// Starts sending the request to the client-side detection frontends. // Starts sending the request to the client-side detection frontends.
// This method takes ownership of both pointers. // This method takes ownership of both pointers.
...@@ -198,10 +206,11 @@ class ClientSideDetectionService : public content::NotificationObserver, ...@@ -198,10 +206,11 @@ class ClientSideDetectionService : public content::NotificationObserver,
// it won't download the model nor report detected phishing URLs. // it won't download the model nor report detected phishing URLs.
bool enabled_; bool enabled_;
// We load two models: One for stadard Safe Browsing profiles, // Whether the service is in extended reporting mode or not. This affects the
// and one for those opted into extended reporting. // choice of model.
std::unique_ptr<ModelLoader> model_loader_standard_; bool extended_reporting_;
std::unique_ptr<ModelLoader> model_loader_extended_;
std::unique_ptr<ModelLoader> model_loader_;
// Map of client report phishing request to the corresponding callback that // Map of client report phishing request to the corresponding callback that
// has to be invoked when the request is done. // has to be invoked when the request is done.
...@@ -233,6 +242,9 @@ class ClientSideDetectionService : public content::NotificationObserver, ...@@ -233,6 +242,9 @@ class ClientSideDetectionService : public content::NotificationObserver,
std::vector<ClientSideDetectionHost*> csd_hosts_; std::vector<ClientSideDetectionHost*> csd_hosts_;
// Factory used for constructing ModelLoaders
base::RepeatingCallback<std::unique_ptr<ModelLoader>()> model_factory_;
// Used to asynchronously call the callbacks for // Used to asynchronously call the callbacks for
// SendClientReportPhishingRequest. // SendClientReportPhishingRequest.
base::WeakPtrFactory<ClientSideDetectionService> weak_factory_{this}; base::WeakPtrFactory<ClientSideDetectionService> weak_factory_{this};
......
...@@ -17,7 +17,12 @@ ...@@ -17,7 +17,12 @@
#include "base/metrics/field_trial.h" #include "base/metrics/field_trial.h"
#include "base/run_loop.h" #include "base/run_loop.h"
#include "base/strings/string_number_conversions.h" #include "base/strings/string_number_conversions.h"
#include "base/test/bind_test_util.h"
#include "base/time/time.h" #include "base/time/time.h"
#include "chrome/test/base/testing_browser_process.h"
#include "chrome/test/base/testing_profile.h"
#include "chrome/test/base/testing_profile_manager.h"
#include "components/safe_browsing/core/common/safe_browsing_prefs.h"
#include "components/safe_browsing/core/proto/client_model.pb.h" #include "components/safe_browsing/core/proto/client_model.pb.h"
#include "components/safe_browsing/core/proto/csd.pb.h" #include "components/safe_browsing/core/proto/csd.pb.h"
#include "components/variations/variations_associated_data.h" #include "components/variations/variations_associated_data.h"
...@@ -55,19 +60,16 @@ class MockModelLoader : public ModelLoader { ...@@ -55,19 +60,16 @@ class MockModelLoader : public ModelLoader {
DISALLOW_COPY_AND_ASSIGN(MockModelLoader); DISALLOW_COPY_AND_ASSIGN(MockModelLoader);
}; };
class MockClientSideDetectionService : public ClientSideDetectionService {
public:
MockClientSideDetectionService() : ClientSideDetectionService(nullptr) {}
~MockClientSideDetectionService() override {}
private:
DISALLOW_COPY_AND_ASSIGN(MockClientSideDetectionService);
};
} // namespace } // namespace
class ClientSideDetectionServiceTest : public testing::Test { class ClientSideDetectionServiceTest : public testing::Test {
public:
ClientSideDetectionServiceTest()
: profile_manager_(TestingBrowserProcess::GetGlobal()) {
EXPECT_TRUE(profile_manager_.SetUp());
profile_ = profile_manager_.CreateTestingProfile("test-user");
}
protected: protected:
void SetUp() override { void SetUp() override {
test_shared_loader_factory_ = test_shared_loader_factory_ =
...@@ -197,15 +199,13 @@ class ClientSideDetectionServiceTest : public testing::Test { ...@@ -197,15 +199,13 @@ class ClientSideDetectionServiceTest : public testing::Test {
protected: protected:
content::BrowserTaskEnvironment task_environment_; content::BrowserTaskEnvironment task_environment_;
TestingProfileManager profile_manager_;
TestingProfile* profile_;
std::unique_ptr<ClientSideDetectionService> csd_service_; std::unique_ptr<ClientSideDetectionService> csd_service_;
network::TestURLLoaderFactory test_url_loader_factory_; network::TestURLLoaderFactory test_url_loader_factory_;
scoped_refptr<network::SharedURLLoaderFactory> test_shared_loader_factory_; scoped_refptr<network::SharedURLLoaderFactory> test_shared_loader_factory_;
#if defined(OS_CHROMEOS)
chromeos::ScopedStubInstallAttributes test_install_attributes_;
#endif
private: private:
void SendRequestDone(base::OnceClosure continuation_callback, void SendRequestDone(base::OnceClosure continuation_callback,
GURL phishing_url, GURL phishing_url,
...@@ -224,10 +224,9 @@ class ClientSideDetectionServiceTest : public testing::Test { ...@@ -224,10 +224,9 @@ class ClientSideDetectionServiceTest : public testing::Test {
TEST_F(ClientSideDetectionServiceTest, ServiceObjectDeletedBeforeCallbackDone) { TEST_F(ClientSideDetectionServiceTest, ServiceObjectDeletedBeforeCallbackDone) {
SetModelFetchResponses(); SetModelFetchResponses();
csd_service_ = csd_service_ = std::make_unique<ClientSideDetectionService>(profile_);
std::make_unique<ClientSideDetectionService>(test_shared_loader_factory_); profile_->GetPrefs()->SetBoolean(prefs::kSafeBrowsingEnabled, true);
csd_service_->SetEnabledAndRefreshState(true); EXPECT_NE(csd_service_.get(), nullptr);
EXPECT_TRUE(csd_service_.get() != NULL);
// We delete the client-side detection service class even though the callbacks // We delete the client-side detection service class even though the callbacks
// haven't run yet. // haven't run yet.
csd_service_.reset(); csd_service_.reset();
...@@ -238,17 +237,17 @@ TEST_F(ClientSideDetectionServiceTest, ServiceObjectDeletedBeforeCallbackDone) { ...@@ -238,17 +237,17 @@ TEST_F(ClientSideDetectionServiceTest, ServiceObjectDeletedBeforeCallbackDone) {
TEST_F(ClientSideDetectionServiceTest, SendClientReportPhishingRequest) { TEST_F(ClientSideDetectionServiceTest, SendClientReportPhishingRequest) {
SetModelFetchResponses(); SetModelFetchResponses();
csd_service_ = csd_service_ = std::make_unique<ClientSideDetectionService>(profile_);
std::make_unique<ClientSideDetectionService>(test_shared_loader_factory_); csd_service_->SetURLLoaderFactoryForTesting(test_shared_loader_factory_);
GURL url("http://a.com/"); GURL url("http://a.com/");
float score = 0.4f; // Some random client score. float score = 0.4f; // Some random client score.
// Safe browsing is not enabled. // Safe browsing is not enabled.
csd_service_->SetEnabledAndRefreshState(false); profile_->GetPrefs()->SetBoolean(prefs::kSafeBrowsingEnabled, false);
EXPECT_FALSE(SendClientReportPhishingRequest(url, score, false, true)); EXPECT_FALSE(SendClientReportPhishingRequest(url, score, false, true));
csd_service_->SetEnabledAndRefreshState(true); profile_->GetPrefs()->SetBoolean(prefs::kSafeBrowsingEnabled, true);
base::Time before = base::Time::Now(); base::Time before = base::Time::Now();
// Invalid response body from the server. // Invalid response body from the server.
...@@ -294,8 +293,7 @@ TEST_F(ClientSideDetectionServiceTest, SendClientReportPhishingRequest) { ...@@ -294,8 +293,7 @@ TEST_F(ClientSideDetectionServiceTest, SendClientReportPhishingRequest) {
TEST_F(ClientSideDetectionServiceTest, GetNumReportTest) { TEST_F(ClientSideDetectionServiceTest, GetNumReportTest) {
SetModelFetchResponses(); SetModelFetchResponses();
csd_service_ = csd_service_ = std::make_unique<ClientSideDetectionService>(profile_);
std::make_unique<ClientSideDetectionService>(test_shared_loader_factory_);
base::queue<base::Time>& report_times = GetPhishingReportTimes(); base::queue<base::Time>& report_times = GetPhishingReportTimes();
base::Time now = base::Time::Now(); base::Time now = base::Time::Now();
...@@ -311,16 +309,14 @@ TEST_F(ClientSideDetectionServiceTest, GetNumReportTest) { ...@@ -311,16 +309,14 @@ TEST_F(ClientSideDetectionServiceTest, GetNumReportTest) {
TEST_F(ClientSideDetectionServiceTest, CacheTest) { TEST_F(ClientSideDetectionServiceTest, CacheTest) {
SetModelFetchResponses(); SetModelFetchResponses();
csd_service_ = csd_service_ = std::make_unique<ClientSideDetectionService>(profile_);
std::make_unique<ClientSideDetectionService>(test_shared_loader_factory_);
TestCache(); TestCache();
} }
TEST_F(ClientSideDetectionServiceTest, IsPrivateIPAddress) { TEST_F(ClientSideDetectionServiceTest, IsPrivateIPAddress) {
SetModelFetchResponses(); SetModelFetchResponses();
csd_service_ = csd_service_ = std::make_unique<ClientSideDetectionService>(profile_);
std::make_unique<ClientSideDetectionService>(test_shared_loader_factory_);
EXPECT_TRUE(csd_service_->IsPrivateIPAddress("10.1.2.3")); EXPECT_TRUE(csd_service_->IsPrivateIPAddress("10.1.2.3"));
EXPECT_TRUE(csd_service_->IsPrivateIPAddress("127.0.0.1")); EXPECT_TRUE(csd_service_->IsPrivateIPAddress("127.0.0.1"));
...@@ -343,64 +339,78 @@ TEST_F(ClientSideDetectionServiceTest, IsPrivateIPAddress) { ...@@ -343,64 +339,78 @@ TEST_F(ClientSideDetectionServiceTest, IsPrivateIPAddress) {
TEST_F(ClientSideDetectionServiceTest, SetEnabledAndRefreshState) { TEST_F(ClientSideDetectionServiceTest, SetEnabledAndRefreshState) {
// Check that the model isn't downloaded until the service is enabled. // Check that the model isn't downloaded until the service is enabled.
csd_service_ = profile_->GetPrefs()->SetBoolean(prefs::kSafeBrowsingEnabled, false);
std::make_unique<ClientSideDetectionService>(test_shared_loader_factory_); csd_service_ = std::make_unique<ClientSideDetectionService>(profile_);
EXPECT_FALSE(csd_service_->enabled()); EXPECT_FALSE(csd_service_->enabled());
EXPECT_TRUE(csd_service_->model_loader_standard_->url_loader_.get() == NULL); EXPECT_TRUE(csd_service_->model_loader_ == nullptr);
// Use a MockClientSideDetectionService for the rest of the test, to avoid // Inject mock loader.
// the scheduling delay. csd_service_->SetModelLoaderFactoryForTesting(base::BindLambdaForTesting([] {
MockClientSideDetectionService* service = auto loader = std::make_unique<StrictMock<MockModelLoader>>("model1");
new StrictMock<MockClientSideDetectionService>(); return std::unique_ptr<ModelLoader>(std::move(loader));
// Inject mock loaders. }));
MockModelLoader* loader_1 = new StrictMock<MockModelLoader>("model1");
MockModelLoader* loader_2 = new StrictMock<MockModelLoader>("model2");
service->model_loader_standard_.reset(loader_1);
service->model_loader_extended_.reset(loader_2);
csd_service_.reset(service);
EXPECT_FALSE(csd_service_->enabled()); EXPECT_FALSE(csd_service_->enabled());
// No calls expected yet.
Mock::VerifyAndClearExpectations(service);
Mock::VerifyAndClearExpectations(loader_1);
Mock::VerifyAndClearExpectations(loader_2);
// Check that initial ScheduleFetch() calls are made. // Check that initial ScheduleFetch() calls are made.
EXPECT_CALL(*loader_1, csd_service_->SetModelLoaderFactoryForTesting(base::BindLambdaForTesting([] {
ScheduleFetch( auto loader = std::make_unique<StrictMock<MockModelLoader>>("model1");
ClientSideDetectionService::kInitialClientModelFetchDelayMs)); EXPECT_CALL(
EXPECT_CALL(*loader_2, *loader,
ScheduleFetch( ScheduleFetch(
ClientSideDetectionService::kInitialClientModelFetchDelayMs)); ClientSideDetectionService::kInitialClientModelFetchDelayMs));
csd_service_->SetEnabledAndRefreshState(true);
// Whenever this model is torn down, CancelFetcher will be called.
EXPECT_CALL(*loader, CancelFetcher());
return std::unique_ptr<ModelLoader>(std::move(loader));
}));
profile_->GetPrefs()->SetBoolean(prefs::kSafeBrowsingEnabled, true);
base::RunLoop().RunUntilIdle(); base::RunLoop().RunUntilIdle();
Mock::VerifyAndClearExpectations(service);
Mock::VerifyAndClearExpectations(loader_1);
Mock::VerifyAndClearExpectations(loader_2);
// Check that enabling again doesn't request the model. // Check that enabling again doesn't request the model.
csd_service_->SetEnabledAndRefreshState(true); profile_->GetPrefs()->SetBoolean(prefs::kSafeBrowsingEnabled, true);
// No calls expected. // No calls expected.
base::RunLoop().RunUntilIdle(); base::RunLoop().RunUntilIdle();
Mock::VerifyAndClearExpectations(service);
Mock::VerifyAndClearExpectations(loader_1); // Check that disabling the service cancels pending requests. CancelFetch will
Mock::VerifyAndClearExpectations(loader_2); // be called here.
profile_->GetPrefs()->SetBoolean(prefs::kSafeBrowsingEnabled, false);
// Check that disabling the service cancels pending requests.
EXPECT_CALL(*loader_1, CancelFetcher());
EXPECT_CALL(*loader_2, CancelFetcher());
csd_service_->SetEnabledAndRefreshState(false);
base::RunLoop().RunUntilIdle(); base::RunLoop().RunUntilIdle();
Mock::VerifyAndClearExpectations(service);
Mock::VerifyAndClearExpectations(loader_1);
Mock::VerifyAndClearExpectations(loader_2);
// Check that disabling again doesn't request the model. // Check that disabling again doesn't request the model.
csd_service_->SetEnabledAndRefreshState(false); profile_->GetPrefs()->SetBoolean(prefs::kSafeBrowsingEnabled, false);
// No calls expected. // No calls expected.
base::RunLoop().RunUntilIdle(); base::RunLoop().RunUntilIdle();
Mock::VerifyAndClearExpectations(service);
Mock::VerifyAndClearExpectations(loader_1);
Mock::VerifyAndClearExpectations(loader_2);
} }
TEST_F(ClientSideDetectionServiceTest, TestModelFollowsPrefs) {
profile_->GetPrefs()->SetBoolean(prefs::kSafeBrowsingEnabled, false);
profile_->GetPrefs()->SetBoolean(prefs::kSafeBrowsingScoutReportingEnabled,
false);
profile_->GetPrefs()->SetBoolean(prefs::kSafeBrowsingEnhanced, false);
csd_service_ = std::make_unique<ClientSideDetectionService>(profile_);
// Safe Browsing is not enabled.
EXPECT_EQ(csd_service_->model_loader_, nullptr);
// Safe Browsing is enabled.
profile_->GetPrefs()->SetBoolean(prefs::kSafeBrowsingEnabled, true);
ASSERT_NE(csd_service_->model_loader_, nullptr);
EXPECT_EQ(csd_service_->model_loader_->name(),
"client_model_v5_variation_0.pb");
// Safe Browsing extended reporting is enabled
profile_->GetPrefs()->SetBoolean(prefs::kSafeBrowsingScoutReportingEnabled,
true);
ASSERT_NE(csd_service_->model_loader_, nullptr);
EXPECT_EQ(csd_service_->model_loader_->name(),
"client_model_v5_ext_variation_0.pb");
// Safe Browsing enhanced protection is enabled.
profile_->GetPrefs()->SetBoolean(prefs::kSafeBrowsingEnhanced, true);
ASSERT_NE(csd_service_->model_loader_, nullptr);
EXPECT_EQ(csd_service_->model_loader_->name(),
"client_model_v5_ext_variation_0.pb");
}
} // namespace safe_browsing } // namespace safe_browsing
...@@ -148,6 +148,10 @@ void ModelLoader::StartFetch() { ...@@ -148,6 +148,10 @@ void ModelLoader::StartFetch() {
return; return;
} }
// |url_loader_factory_| can be null in tests.
if (!url_loader_factory_)
return;
// Start fetching the model either from the cache or possibly from the // Start fetching the model either from the cache or possibly from the
// network if the model isn't in the cache. // network if the model isn't in the cache.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment