Commit dda7c815 authored by Daniel Rubery's avatar Daniel Rubery Committed by Commit Bot

Make ClientSideDetectionService track SBER prefs and only load one model

Previously, the ClientSideDetectionService was running two ModelLoaders,
one for the regular model, and one for the SBER model. This CL makes the
ClientSideDetectionService track the SB prefs, so it can only load the
model needed for that profile.

Bug: 1068623
Change-Id: I544c059caa3a2545dc41771fc0bee04a808b5992
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2249162Reviewed-by: default avatarBettina Dea <bdea@chromium.org>
Commit-Queue: Daniel Rubery <drubery@chromium.org>
Cr-Commit-Position: refs/heads/master@{#780046}
parent a9be5907
......@@ -376,8 +376,7 @@ void ClientSideDetectionHost::DidFinishNavigation(
void ClientSideDetectionHost::SendModelToRenderFrame(
content::RenderProcessHost* process,
Profile* profile,
ModelLoader* model_loader_standard,
ModelLoader* model_loader_extended) {
ModelLoader* model_loader) {
DCHECK_CURRENTLY_ON(BrowserThread::UI);
if (!web_contents() || web_contents() != tab_)
return;
......@@ -386,21 +385,11 @@ void ClientSideDetectionHost::SendModelToRenderFrame(
if (frame->GetProcess() != process)
continue;
std::string model;
if (IsSafeBrowsingEnabled(*profile->GetPrefs())) {
if (IsExtendedReportingEnabled(*profile->GetPrefs()) ||
IsEnhancedProtectionEnabled(*profile->GetPrefs())) {
model = model_loader_extended->model_str();
} else {
model = model_loader_standard->model_str();
}
}
if (phishing_detector_)
phishing_detector_.reset();
frame->GetRemoteInterfaces()->GetInterface(
phishing_detector_.BindNewPipeAndPassReceiver());
phishing_detector_->SetPhishingModel(model);
phishing_detector_->SetPhishingModel(model_loader->model_str());
}
}
......@@ -473,12 +462,8 @@ void ClientSideDetectionHost::PhishingDetectionDone(
base::TimeTicks::Now() - phishing_detection_start_time_);
UMA_HISTOGRAM_ENUMERATION("SBClientPhishing.PhishingDetectorResult", result);
if (result == mojom::PhishingDetectorResult::CLASSIFIER_NOT_READY) {
Profile* profile =
Profile::FromBrowserContext(web_contents()->GetBrowserContext());
UMA_HISTOGRAM_ENUMERATION(
"SBClientPhishing.ClassifierNotReadyReason",
csd_service_->GetLastModelStatus(
IsExtendedReportingEnabled(*profile->GetPrefs())));
UMA_HISTOGRAM_ENUMERATION("SBClientPhishing.ClassifierNotReadyReason",
csd_service_->GetLastModelStatus());
}
if (result != mojom::PhishingDetectorResult::SUCCESS)
return;
......
......@@ -55,8 +55,7 @@ class ClientSideDetectionHost : public content::WebContentsObserver,
// Send the model to the given render frame host.
void SendModelToRenderFrame(content::RenderProcessHost* process,
Profile* profile,
ModelLoader* model_loader_standard,
ModelLoader* model_loader_extended);
ModelLoader* model_loader);
// Called when the SafeBrowsingService found a hit with one of the
// SafeBrowsing lists. This method is called on the UI thread.
......
......@@ -1160,42 +1160,13 @@ TEST_F(ClientSideDetectionHostTest, RecordsPhishingDetectionDuration) {
}
TEST_F(ClientSideDetectionHostTest, TestSendModelToRenderFrame) {
profile()->GetPrefs()->SetBoolean(prefs::kSafeBrowsingEnabled, false);
profile()->GetPrefs()->SetBoolean(prefs::kSafeBrowsingEnhanced, false);
// Safe Browsing is not enabled.
StrictMock<MockModelLoader> standard;
standard.SetModelStrForTesting("standard");
StrictMock<MockModelLoader> extended;
extended.SetModelStrForTesting("extended");
StrictMock<MockModelLoader> loader;
loader.SetModelStrForTesting("standard");
csd_host_->SendModelToRenderFrame(
web_contents()->GetMainFrame()->GetProcess(), profile(), &standard,
&extended);
base::RunLoop().RunUntilIdle();
fake_phishing_detector_.CheckModel("");
fake_phishing_detector_.Reset();
// Safe Browsing is enabled.
profile()->GetPrefs()->SetBoolean(prefs::kSafeBrowsingEnabled, true);
csd_host_->SendModelToRenderFrame(
web_contents()->GetMainFrame()->GetProcess(), profile(), &standard,
&extended);
web_contents()->GetMainFrame()->GetProcess(), profile(), &loader);
base::RunLoop().RunUntilIdle();
fake_phishing_detector_.CheckModel("standard");
fake_phishing_detector_.Reset();
{
base::test::ScopedFeatureList scoped_feature_list;
scoped_feature_list.InitAndEnableFeature(
safe_browsing::kEnhancedProtection);
// Safe Browsing enhanced protection is enabled.
profile()->GetPrefs()->SetBoolean(prefs::kSafeBrowsingEnhanced, true);
csd_host_->SendModelToRenderFrame(
web_contents()->GetMainFrame()->GetProcess(), profile(), &standard,
&extended);
base::RunLoop().RunUntilIdle();
fake_phishing_detector_.CheckModel("extended");
fake_phishing_detector_.Reset();
}
}
} // namespace safe_browsing
......@@ -73,8 +73,8 @@ ClientSideDetectionService::ClientSideDetectionService(Profile* profile)
: nullptr) {
profile_ = profile;
// |profile_| and |url_loader_factory_| can be null in unit tests
if (!profile_ || !url_loader_factory_)
// |profile_| can be null in unit tests
if (!profile_)
return;
pref_change_registrar_.Init(profile_->GetPrefs());
......@@ -82,6 +82,14 @@ ClientSideDetectionService::ClientSideDetectionService(Profile* profile)
prefs::kSafeBrowsingEnabled,
base::Bind(&ClientSideDetectionService::OnPrefsUpdated,
base::Unretained(this)));
pref_change_registrar_.Add(
prefs::kSafeBrowsingEnhanced,
base::Bind(&ClientSideDetectionService::OnPrefsUpdated,
base::Unretained(this)));
pref_change_registrar_.Add(
prefs::kSafeBrowsingScoutReportingEnabled,
base::Bind(&ClientSideDetectionService::OnPrefsUpdated,
base::Unretained(this)));
// Do an initial check of the prefs.
OnPrefsUpdated();
......@@ -93,10 +101,6 @@ ClientSideDetectionService::ClientSideDetectionService(
base::Closure update_renderers =
base::Bind(&ClientSideDetectionService::SendModelToRenderers,
base::Unretained(this));
model_loader_standard_.reset(
new ModelLoader(update_renderers, url_loader_factory_, false));
model_loader_extended_.reset(
new ModelLoader(update_renderers, url_loader_factory_, true));
registrar_.Add(this, content::NOTIFICATION_RENDERER_PROCESS_CREATED,
content::NotificationService::AllBrowserContextsAndSources());
......@@ -111,27 +115,38 @@ void ClientSideDetectionService::Shutdown() {
}
void ClientSideDetectionService::OnPrefsUpdated() {
SetEnabledAndRefreshState(IsSafeBrowsingEnabled(*profile_->GetPrefs()));
}
void ClientSideDetectionService::SetEnabledAndRefreshState(bool enabled) {
DCHECK_CURRENTLY_ON(BrowserThread::UI);
SendModelToRenderers(); // always refresh the renderer state
if (enabled == enabled_)
bool enabled = IsSafeBrowsingEnabled(*profile_->GetPrefs());
bool extended_reporting =
IsEnhancedProtectionEnabled(*profile_->GetPrefs()) ||
IsExtendedReportingEnabled(*profile_->GetPrefs());
if (enabled == enabled_ && extended_reporting_ == extended_reporting)
return;
enabled_ = enabled;
extended_reporting_ = extended_reporting;
if (enabled_) {
if (!model_factory_.is_null()) {
model_loader_ = model_factory_.Run();
} else {
model_loader_ = std::make_unique<ModelLoader>(
base::BindRepeating(&ClientSideDetectionService::SendModelToRenderers,
base::Unretained(this)),
url_loader_factory_, extended_reporting_);
}
// Refresh the models when the service is enabled. This can happen when
// either of the preferences are toggled, or early during startup if
// safe browsing is already enabled. In a lot of cases the model will be
// in the cache so it won't actually be fetched from the network.
// We delay the first model fetches to avoid slowing down browser startup.
model_loader_standard_->ScheduleFetch(kInitialClientModelFetchDelayMs);
model_loader_extended_->ScheduleFetch(kInitialClientModelFetchDelayMs);
model_loader_->ScheduleFetch(kInitialClientModelFetchDelayMs);
} else {
// Cancel model loads in progress.
model_loader_standard_->CancelFetcher();
model_loader_extended_->CancelFetcher();
if (model_loader_) {
// Cancel model loads in progress.
model_loader_->CancelFetcher();
}
// Invoke pending callbacks with a false verdict.
for (auto it = client_phishing_reports_.begin();
it != client_phishing_reports_.end(); ++it) {
......@@ -207,9 +222,7 @@ void ClientSideDetectionService::Observe(
content::Source<content::RenderProcessHost>(source).ptr();
if (process->GetBrowserContext() == profile_) {
for (ClientSideDetectionHost* host : csd_hosts_) {
host->SendModelToRenderFrame(process, profile_,
model_loader_standard_.get(),
model_loader_extended_.get());
host->SendModelToRenderFrame(process, profile_, model_loader_.get());
}
}
}
......@@ -222,9 +235,7 @@ void ClientSideDetectionService::SendModelToRenderers() {
if (process->IsInitializedAndNotDead() &&
process->GetBrowserContext() == profile_) {
for (ClientSideDetectionHost* host : csd_hosts_) {
host->SendModelToRenderFrame(process, profile_,
model_loader_standard_.get(),
model_loader_extended_.get());
host->SendModelToRenderFrame(process, profile_, model_loader_.get());
}
}
}
......@@ -245,8 +256,8 @@ void ClientSideDetectionService::StartClientReportPhishingRequest(
}
// Fill in metadata about which model we used.
request->set_model_filename(model_loader_->name());
if (is_extended_reporting || is_enhanced_reporting) {
request->set_model_filename(model_loader_extended_->name());
if (is_enhanced_reporting) {
request->mutable_population()->set_user_population(
ChromeUserPopulation::ENHANCED_PROTECTION);
......@@ -255,7 +266,6 @@ void ClientSideDetectionService::StartClientReportPhishingRequest(
ChromeUserPopulation::EXTENDED_REPORTING);
}
} else {
request->set_model_filename(model_loader_standard_->name());
request->mutable_population()->set_user_population(
ChromeUserPopulation::SAFE_BROWSING);
}
......@@ -438,11 +448,21 @@ GURL ClientSideDetectionService::GetClientReportUrl(
return url;
}
ModelLoader::ClientModelStatus ClientSideDetectionService::GetLastModelStatus(
bool use_extended_model) {
ModelLoader* model_loader = use_extended_model ? model_loader_extended_.get()
: model_loader_standard_.get();
return model_loader->last_client_model_status();
ModelLoader::ClientModelStatus
ClientSideDetectionService::GetLastModelStatus() {
// |model_loader_| can be null in tests
return model_loader_ ? model_loader_->last_client_model_status()
: ModelLoader::MODEL_NEVER_FETCHED;
}
void ClientSideDetectionService::SetModelLoaderFactoryForTesting(
base::RepeatingCallback<std::unique_ptr<ModelLoader>()> factory) {
model_factory_ = factory;
}
void ClientSideDetectionService::SetURLLoaderFactoryForTesting(
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory) {
url_loader_factory_ = url_loader_factory;
}
} // namespace safe_browsing
......@@ -122,9 +122,17 @@ class ClientSideDetectionService : public content::NotificationObserver,
base::WeakPtr<ClientSideDetectionService> GetWeakPtr();
// Get the model status for the given client-side model (extended reporting or
// regular).
ModelLoader::ClientModelStatus GetLastModelStatus(bool use_extended_model);
// Get the model status for the given client-side model.
ModelLoader::ClientModelStatus GetLastModelStatus();
// Makes ModelLoaders be constructed by calling |factory| rather than the
// default constructor.
void SetModelLoaderFactoryForTesting(
base::RepeatingCallback<std::unique_ptr<ModelLoader>()> factory);
// Overrides the SharedURLLoaderFactory
void SetURLLoaderFactoryForTesting(
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory);
private:
friend class ClientSideDetectionServiceTest;
......@@ -135,6 +143,8 @@ class ClientSideDetectionService : public content::NotificationObserver,
FRIEND_TEST_ALL_PREFIXES(ClientSideDetectionServiceTest,
SendClientReportPhishingRequest);
FRIEND_TEST_ALL_PREFIXES(ClientSideDetectionServiceTest, GetNumReportTest);
FRIEND_TEST_ALL_PREFIXES(ClientSideDetectionServiceTest,
TestModelFollowsPrefs);
// CacheState holds all information necessary to respond to a caller without
// actually making a HTTP request.
......@@ -152,15 +162,13 @@ class ClientSideDetectionService : public content::NotificationObserver,
static const int kNegativeCacheIntervalDays;
static const int kPositiveCacheIntervalMinutes;
// Called when the prefs have changed in a way we may need to respond to.
void OnPrefsUpdated();
// Enables or disables the service, and refreshes the state of all renderers.
// Called when the prefs have changed in a way we may need to respond to. May
// enable or disable the service and refresh the state of all renderers.
// Disabling cancels any pending requests; existing ClientSideDetectionHosts
// will have their callbacks called with "false" verdicts. Enabling starts
// downloading the model after a delay. In all cases, each render process is
// updated to match the state
void SetEnabledAndRefreshState(bool enabled);
void OnPrefsUpdated();
// Starts sending the request to the client-side detection frontends.
// This method takes ownership of both pointers.
......@@ -198,10 +206,11 @@ class ClientSideDetectionService : public content::NotificationObserver,
// it won't download the model nor report detected phishing URLs.
bool enabled_;
// We load two models: One for stadard Safe Browsing profiles,
// and one for those opted into extended reporting.
std::unique_ptr<ModelLoader> model_loader_standard_;
std::unique_ptr<ModelLoader> model_loader_extended_;
// Whether the service is in extended reporting mode or not. This affects the
// choice of model.
bool extended_reporting_;
std::unique_ptr<ModelLoader> model_loader_;
// Map of client report phishing request to the corresponding callback that
// has to be invoked when the request is done.
......@@ -233,6 +242,9 @@ class ClientSideDetectionService : public content::NotificationObserver,
std::vector<ClientSideDetectionHost*> csd_hosts_;
// Factory used for constructing ModelLoaders
base::RepeatingCallback<std::unique_ptr<ModelLoader>()> model_factory_;
// Used to asynchronously call the callbacks for
// SendClientReportPhishingRequest.
base::WeakPtrFactory<ClientSideDetectionService> weak_factory_{this};
......
......@@ -148,6 +148,10 @@ void ModelLoader::StartFetch() {
return;
}
// |url_loader_factory_| can be null in tests.
if (!url_loader_factory_)
return;
// Start fetching the model either from the cache or possibly from the
// network if the model isn't in the cache.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment