Commit 39458f45 authored by Bettina's avatar Bettina Committed by Commit Bot

Merge PhishingModelSetter into PhishingClassifierDelegate

It seems the only purpose of the PhishingClassifierFilter
is to implement the PhishingModelSetter interface, so we
may as well merge it with the PhishingClassifierDelegate
to have one less interface.

Bug: 1078964
Change-Id: I0a4563df32681b59a203f5aab4ee851b61346ff1
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2208101Reviewed-by: default avatarMustafa Emre Acer <meacer@chromium.org>
Reviewed-by: default avatarDaniel Rubery <drubery@chromium.org>
Commit-Queue: Bettina Dea <bdea@chromium.org>
Cr-Commit-Position: refs/heads/master@{#773948}
parent e8471f23
...@@ -290,6 +290,7 @@ std::unique_ptr<ClientSideDetectionHost> ClientSideDetectionHost::Create( ...@@ -290,6 +290,7 @@ std::unique_ptr<ClientSideDetectionHost> ClientSideDetectionHost::Create(
ClientSideDetectionHost::ClientSideDetectionHost(WebContents* tab) ClientSideDetectionHost::ClientSideDetectionHost(WebContents* tab)
: content::WebContentsObserver(tab), : content::WebContentsObserver(tab),
csd_service_(nullptr), csd_service_(nullptr),
tab_(tab),
classification_request_(nullptr), classification_request_(nullptr),
pageload_complete_(false), pageload_complete_(false),
unsafe_unique_page_id_(-1), unsafe_unique_page_id_(-1),
...@@ -312,6 +313,9 @@ ClientSideDetectionHost::ClientSideDetectionHost(WebContents* tab) ...@@ -312,6 +313,9 @@ ClientSideDetectionHost::ClientSideDetectionHost(WebContents* tab)
ClientSideDetectionHost::~ClientSideDetectionHost() { ClientSideDetectionHost::~ClientSideDetectionHost() {
if (ui_manager_.get()) if (ui_manager_.get())
ui_manager_->RemoveObserver(this); ui_manager_->RemoveObserver(this);
if (csd_service_)
csd_service_->RemoveClientSideDetectionHost(this);
} }
void ClientSideDetectionHost::DidFinishNavigation( void ClientSideDetectionHost::DidFinishNavigation(
...@@ -369,6 +373,44 @@ void ClientSideDetectionHost::DidFinishNavigation( ...@@ -369,6 +373,44 @@ void ClientSideDetectionHost::DidFinishNavigation(
classification_request_->Start(); classification_request_->Start();
} }
void ClientSideDetectionHost::SendModelToRenderFrame(
content::RenderProcessHost* process,
Profile* profile,
ModelLoader* model_loader_standard,
ModelLoader* model_loader_extended) {
DCHECK_CURRENTLY_ON(BrowserThread::UI);
if (!web_contents() || web_contents() != tab_)
return;
for (content::RenderFrameHost* frame : web_contents()->GetAllFrames()) {
if (frame->GetProcess() != process)
continue;
std::string model;
if (IsSafeBrowsingEnabled(*profile->GetPrefs())) {
if (IsExtendedReportingEnabled(*profile->GetPrefs()) ||
IsEnhancedProtectionEnabled(*profile->GetPrefs())) {
DVLOG(2) << "Sending phishing model " << model_loader_extended->name()
<< " to RenderFrameHost @" << frame;
model = model_loader_extended->model_str();
} else {
DVLOG(2) << "Sending phishing model " << model_loader_standard->name()
<< " to RenderFrameHost @" << frame;
model = model_loader_standard->model_str();
}
} else {
DVLOG(2) << "Disabling client-side phishing detection for "
<< "RenderFrameHost @" << frame;
}
if (phishing_detector_)
phishing_detector_.reset();
frame->GetRemoteInterfaces()->GetInterface(
phishing_detector_.BindNewPipeAndPassReceiver());
phishing_detector_->SetPhishingModel(model);
}
}
void ClientSideDetectionHost::OnSafeBrowsingHit( void ClientSideDetectionHost::OnSafeBrowsingHit(
const security_interstitials::UnsafeResource& resource) { const security_interstitials::UnsafeResource& resource) {
if (!web_contents()) if (!web_contents())
...@@ -408,6 +450,8 @@ void ClientSideDetectionHost::WebContentsDestroyed() { ...@@ -408,6 +450,8 @@ void ClientSideDetectionHost::WebContentsDestroyed() {
} }
// Cancel all pending feature extractions. // Cancel all pending feature extractions.
feature_extractor_.reset(); feature_extractor_.reset();
csd_service_->RemoveClientSideDetectionHost(this);
} }
void ClientSideDetectionHost::OnPhishingPreClassificationDone( void ClientSideDetectionHost::OnPhishingPreClassificationDone(
......
...@@ -14,9 +14,11 @@ ...@@ -14,9 +14,11 @@
#include "base/macros.h" #include "base/macros.h"
#include "base/memory/ref_counted.h" #include "base/memory/ref_counted.h"
#include "chrome/browser/safe_browsing/browser_feature_extractor.h" #include "chrome/browser/safe_browsing/browser_feature_extractor.h"
#include "chrome/browser/safe_browsing/client_side_model_loader.h"
#include "chrome/browser/safe_browsing/ui_manager.h" #include "chrome/browser/safe_browsing/ui_manager.h"
#include "components/safe_browsing/content/common/safe_browsing.mojom-shared.h" #include "components/safe_browsing/content/common/safe_browsing.mojom-shared.h"
#include "components/safe_browsing/core/db/database_manager.h" #include "components/safe_browsing/core/db/database_manager.h"
#include "content/public/browser/render_process_host.h"
#include "content/public/browser/web_contents_observer.h" #include "content/public/browser/web_contents_observer.h"
#include "mojo/public/cpp/bindings/remote.h" #include "mojo/public/cpp/bindings/remote.h"
#include "services/service_manager/public/cpp/binder_registry.h" #include "services/service_manager/public/cpp/binder_registry.h"
...@@ -50,6 +52,12 @@ class ClientSideDetectionHost : public content::WebContentsObserver, ...@@ -50,6 +52,12 @@ class ClientSideDetectionHost : public content::WebContentsObserver,
void DidFinishNavigation( void DidFinishNavigation(
content::NavigationHandle* navigation_handle) override; content::NavigationHandle* navigation_handle) override;
// Send the model to the given render frame host.
void SendModelToRenderFrame(content::RenderProcessHost* process,
Profile* profile,
ModelLoader* model_loader_standard,
ModelLoader* model_loader_extended);
// Called when the SafeBrowsingService found a hit with one of the // Called when the SafeBrowsingService found a hit with one of the
// SafeBrowsing lists. This method is called on the UI thread. // SafeBrowsing lists. This method is called on the UI thread.
void OnSafeBrowsingHit( void OnSafeBrowsingHit(
...@@ -112,6 +120,8 @@ class ClientSideDetectionHost : public content::WebContentsObserver, ...@@ -112,6 +120,8 @@ class ClientSideDetectionHost : public content::WebContentsObserver,
// This pointer may be nullptr if client-side phishing detection is disabled. // This pointer may be nullptr if client-side phishing detection is disabled.
ClientSideDetectionService* csd_service_; ClientSideDetectionService* csd_service_;
// The WebContents that the class is observing.
content::WebContents* tab_;
// These pointers may be nullptr if SafeBrowsing is disabled. // These pointers may be nullptr if SafeBrowsing is disabled.
scoped_refptr<SafeBrowsingDatabaseManager> database_manager_; scoped_refptr<SafeBrowsingDatabaseManager> database_manager_;
scoped_refptr<SafeBrowsingUIManager> ui_manager_; scoped_refptr<SafeBrowsingUIManager> ui_manager_;
......
...@@ -17,9 +17,11 @@ ...@@ -17,9 +17,11 @@
#include "base/strings/stringprintf.h" #include "base/strings/stringprintf.h"
#include "base/synchronization/waitable_event.h" #include "base/synchronization/waitable_event.h"
#include "base/test/metrics/histogram_tester.h" #include "base/test/metrics/histogram_tester.h"
#include "base/test/scoped_feature_list.h"
#include "base/test/simple_test_tick_clock.h" #include "base/test/simple_test_tick_clock.h"
#include "chrome/browser/safe_browsing/browser_feature_extractor.h" #include "chrome/browser/safe_browsing/browser_feature_extractor.h"
#include "chrome/browser/safe_browsing/client_side_detection_service.h" #include "chrome/browser/safe_browsing/client_side_detection_service.h"
#include "chrome/browser/safe_browsing/client_side_model_loader.h"
#include "chrome/browser/safe_browsing/safe_browsing_service.h" #include "chrome/browser/safe_browsing/safe_browsing_service.h"
#include "chrome/browser/safe_browsing/ui_manager.h" #include "chrome/browser/safe_browsing/ui_manager.h"
#include "chrome/common/chrome_switches.h" #include "chrome/common/chrome_switches.h"
...@@ -27,8 +29,10 @@ ...@@ -27,8 +29,10 @@
#include "chrome/test/base/testing_profile.h" #include "chrome/test/base/testing_profile.h"
#include "components/prefs/scoped_user_pref_update.h" #include "components/prefs/scoped_user_pref_update.h"
#include "components/safe_browsing/content/common/safe_browsing.mojom-shared.h" #include "components/safe_browsing/content/common/safe_browsing.mojom-shared.h"
#include "components/safe_browsing/core/common/safe_browsing_prefs.h"
#include "components/safe_browsing/core/db/database_manager.h" #include "components/safe_browsing/core/db/database_manager.h"
#include "components/safe_browsing/core/db/test_database_manager.h" #include "components/safe_browsing/core/db/test_database_manager.h"
#include "components/safe_browsing/core/features.h"
#include "components/safe_browsing/core/proto/csd.pb.h" #include "components/safe_browsing/core/proto/csd.pb.h"
#include "components/security_interstitials/content/unsafe_resource_util.h" #include "components/security_interstitials/content/unsafe_resource_util.h"
#include "components/security_interstitials/core/unsafe_resource.h" #include "components/security_interstitials/core/unsafe_resource.h"
...@@ -107,6 +111,18 @@ ACTION(QuitUIMessageLoop) { ...@@ -107,6 +111,18 @@ ACTION(QuitUIMessageLoop) {
base::RunLoop::QuitCurrentWhenIdleDeprecated(); base::RunLoop::QuitCurrentWhenIdleDeprecated();
} }
class MockModelLoader : public ModelLoader {
public:
explicit MockModelLoader() : ModelLoader(base::Closure(), nullptr, false) {}
~MockModelLoader() override = default;
MOCK_METHOD1(ScheduleFetch, void(int64_t));
MOCK_METHOD0(CancelFetcher, void());
private:
DISALLOW_COPY_AND_ASSIGN(MockModelLoader);
};
class MockClientSideDetectionService : public ClientSideDetectionService { class MockClientSideDetectionService : public ClientSideDetectionService {
public: public:
MockClientSideDetectionService() : ClientSideDetectionService(nullptr) {} MockClientSideDetectionService() : ClientSideDetectionService(nullptr) {}
...@@ -190,6 +206,9 @@ class FakePhishingDetector : public mojom::PhishingDetector { ...@@ -190,6 +206,9 @@ class FakePhishingDetector : public mojom::PhishingDetector {
std::move(handle))); std::move(handle)));
} }
// mojom::PhishingDetector
void SetPhishingModel(const std::string& model) override { model_ = model; }
// mojom::PhishingDetector // mojom::PhishingDetector
void StartPhishingDetection( void StartPhishingDetection(
const GURL& url, const GURL& url,
...@@ -216,15 +235,19 @@ class FakePhishingDetector : public mojom::PhishingDetector { ...@@ -216,15 +235,19 @@ class FakePhishingDetector : public mojom::PhishingDetector {
} }
} }
void CheckModel(const std::string& model) { EXPECT_EQ(model, model_); }
void Reset() { void Reset() {
phishing_detection_started_ = false; phishing_detection_started_ = false;
url_ = GURL(); url_ = GURL();
model_ = "";
} }
private: private:
mojo::ReceiverSet<mojom::PhishingDetector> receivers_; mojo::ReceiverSet<mojom::PhishingDetector> receivers_;
bool phishing_detection_started_ = false; bool phishing_detection_started_ = false;
GURL url_; GURL url_;
std::string model_ = "";
DISALLOW_COPY_AND_ASSIGN(FakePhishingDetector); DISALLOW_COPY_AND_ASSIGN(FakePhishingDetector);
}; };
...@@ -1148,4 +1171,43 @@ TEST_F(ClientSideDetectionHostTest, RecordsPhishingDetectionDuration) { ...@@ -1148,4 +1171,43 @@ TEST_F(ClientSideDetectionHostTest, RecordsPhishingDetectionDuration) {
.min); .min);
} }
TEST_F(ClientSideDetectionHostTest, TestSendModelToRenderFrame) {
profile()->GetPrefs()->SetBoolean(prefs::kSafeBrowsingEnabled, false);
profile()->GetPrefs()->SetBoolean(prefs::kSafeBrowsingEnhanced, false);
// Safe Browsing is not enabled.
StrictMock<MockModelLoader> standard;
standard.SetModelStrForTesting("standard");
StrictMock<MockModelLoader> extended;
extended.SetModelStrForTesting("extended");
csd_host_->SendModelToRenderFrame(
web_contents()->GetMainFrame()->GetProcess(), profile(), &standard,
&extended);
base::RunLoop().RunUntilIdle();
fake_phishing_detector_.CheckModel("");
fake_phishing_detector_.Reset();
// Safe Browsing is enabled.
profile()->GetPrefs()->SetBoolean(prefs::kSafeBrowsingEnabled, true);
csd_host_->SendModelToRenderFrame(
web_contents()->GetMainFrame()->GetProcess(), profile(), &standard,
&extended);
base::RunLoop().RunUntilIdle();
fake_phishing_detector_.CheckModel("standard");
fake_phishing_detector_.Reset();
{
base::test::ScopedFeatureList scoped_feature_list;
scoped_feature_list.InitAndEnableFeature(
safe_browsing::kEnhancedProtection);
// Safe Browsing enhanced protection is enabled.
profile()->GetPrefs()->SetBoolean(prefs::kSafeBrowsingEnhanced, true);
csd_host_->SendModelToRenderFrame(
web_contents()->GetMainFrame()->GetProcess(), profile(), &standard,
&extended);
base::RunLoop().RunUntilIdle();
fake_phishing_detector_.CheckModel("extended");
fake_phishing_detector_.Reset();
}
}
} // namespace safe_browsing } // namespace safe_browsing
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include "chrome/browser/browser_process.h" #include "chrome/browser/browser_process.h"
#include "chrome/browser/policy/chrome_browser_policy_connector.h" #include "chrome/browser/policy/chrome_browser_policy_connector.h"
#include "chrome/browser/profiles/profile.h" #include "chrome/browser/profiles/profile.h"
#include "chrome/browser/safe_browsing/client_side_detection_host.h"
#include "chrome/common/pref_names.h" #include "chrome/common/pref_names.h"
#include "components/prefs/pref_service.h" #include "components/prefs/pref_service.h"
#include "components/safe_browsing/content/common/safe_browsing.mojom.h" #include "components/safe_browsing/content/common/safe_browsing.mojom.h"
...@@ -166,6 +167,19 @@ bool ClientSideDetectionService::IsPrivateIPAddress( ...@@ -166,6 +167,19 @@ bool ClientSideDetectionService::IsPrivateIPAddress(
return !address.IsPubliclyRoutable(); return !address.IsPubliclyRoutable();
} }
void ClientSideDetectionService::AddClientSideDetectionHost(
ClientSideDetectionHost* host) {
csd_hosts_.push_back(host);
}
void ClientSideDetectionService::RemoveClientSideDetectionHost(
ClientSideDetectionHost* host) {
std::vector<ClientSideDetectionHost*>::iterator position =
std::find(csd_hosts_.begin(), csd_hosts_.end(), host);
if (position != csd_hosts_.end())
csd_hosts_.erase(position);
}
void ClientSideDetectionService::OnURLLoaderComplete( void ClientSideDetectionService::OnURLLoaderComplete(
network::SimpleURLLoader* url_loader, network::SimpleURLLoader* url_loader,
std::unique_ptr<std::string> response_body) { std::unique_ptr<std::string> response_body) {
...@@ -189,35 +203,13 @@ void ClientSideDetectionService::Observe( ...@@ -189,35 +203,13 @@ void ClientSideDetectionService::Observe(
DCHECK_EQ(content::NOTIFICATION_RENDERER_PROCESS_CREATED, type); DCHECK_EQ(content::NOTIFICATION_RENDERER_PROCESS_CREATED, type);
content::RenderProcessHost* process = content::RenderProcessHost* process =
content::Source<content::RenderProcessHost>(source).ptr(); content::Source<content::RenderProcessHost>(source).ptr();
if (process->GetBrowserContext() == profile_) if (process->GetBrowserContext() == profile_) {
SendModelToProcess(process); for (ClientSideDetectionHost* host : csd_hosts_) {
} host->SendModelToRenderFrame(process, profile_,
model_loader_standard_.get(),
void ClientSideDetectionService::SendModelToProcess( model_loader_extended_.get());
content::RenderProcessHost* process) {
DCHECK(process->IsInitializedAndNotDead());
DCHECK_EQ(process->GetBrowserContext(), profile_);
std::string model;
if (IsSafeBrowsingEnabled(*profile_->GetPrefs())) {
if (IsExtendedReportingEnabled(*profile_->GetPrefs()) ||
IsEnhancedProtectionEnabled(*profile_->GetPrefs())) {
DVLOG(2) << "Sending phishing model " << model_loader_extended_->name()
<< " to RenderProcessHost @" << process;
model = model_loader_extended_->model_str();
} else {
DVLOG(2) << "Sending phishing model " << model_loader_standard_->name()
<< " to RenderProcessHost @" << process;
model = model_loader_standard_->model_str();
} }
} else {
DVLOG(2) << "Disabling client-side phishing detection for "
<< "RenderProcessHost @" << process;
} }
mojo::Remote<safe_browsing::mojom::PhishingModelSetter> phishing;
process->BindReceiver(phishing.BindNewPipeAndPassReceiver());
phishing->SetPhishingModel(model);
} }
void ClientSideDetectionService::SendModelToRenderers() { void ClientSideDetectionService::SendModelToRenderers() {
...@@ -226,8 +218,13 @@ void ClientSideDetectionService::SendModelToRenderers() { ...@@ -226,8 +218,13 @@ void ClientSideDetectionService::SendModelToRenderers() {
!i.IsAtEnd(); i.Advance()) { !i.IsAtEnd(); i.Advance()) {
content::RenderProcessHost* process = i.GetCurrentValue(); content::RenderProcessHost* process = i.GetCurrentValue();
if (process->IsInitializedAndNotDead() && if (process->IsInitializedAndNotDead() &&
process->GetBrowserContext() == profile_) process->GetBrowserContext() == profile_) {
SendModelToProcess(process); for (ClientSideDetectionHost* host : csd_hosts_) {
host->SendModelToRenderFrame(process, profile_,
model_loader_standard_.get(),
model_loader_extended_.get());
}
}
} }
} }
......
...@@ -38,10 +38,6 @@ ...@@ -38,10 +38,6 @@
class Profile; class Profile;
namespace content {
class RenderProcessHost;
}
namespace network { namespace network {
class SimpleURLLoader; class SimpleURLLoader;
class SharedURLLoaderFactory; class SharedURLLoaderFactory;
...@@ -49,6 +45,7 @@ class SharedURLLoaderFactory; ...@@ -49,6 +45,7 @@ class SharedURLLoaderFactory;
namespace safe_browsing { namespace safe_browsing {
class ClientPhishingRequest; class ClientPhishingRequest;
class ClientSideDetectionHost;
// Main service which pushes models to the renderers, responds to classification // Main service which pushes models to the renderers, responds to classification
// requests. This owns two ModelLoader objects. // requests. This owns two ModelLoader objects.
...@@ -72,6 +69,9 @@ class ClientSideDetectionService : public content::NotificationObserver, ...@@ -72,6 +69,9 @@ class ClientSideDetectionService : public content::NotificationObserver,
return enabled_; return enabled_;
} }
void AddClientSideDetectionHost(ClientSideDetectionHost* host);
void RemoveClientSideDetectionHost(ClientSideDetectionHost* host);
void OnURLLoaderComplete(network::SimpleURLLoader* url_loader, void OnURLLoaderComplete(network::SimpleURLLoader* url_loader,
std::unique_ptr<std::string> response_body); std::unique_ptr<std::string> response_body);
...@@ -187,9 +187,6 @@ class ClientSideDetectionService : public content::NotificationObserver, ...@@ -187,9 +187,6 @@ class ClientSideDetectionService : public content::NotificationObserver,
// trims off the old elements. // trims off the old elements.
int GetNumReports(base::queue<base::Time>* report_times); int GetNumReports(base::queue<base::Time>* report_times);
// Send the model to the given renderer.
void SendModelToProcess(content::RenderProcessHost* process);
// Returns the URL that will be used for phishing requests. // Returns the URL that will be used for phishing requests.
static GURL GetClientReportUrl(const std::string& report_url); static GURL GetClientReportUrl(const std::string& report_url);
...@@ -233,6 +230,8 @@ class ClientSideDetectionService : public content::NotificationObserver, ...@@ -233,6 +230,8 @@ class ClientSideDetectionService : public content::NotificationObserver,
// PrefChangeRegistrar used to track when the Safe Browsing pref changes. // PrefChangeRegistrar used to track when the Safe Browsing pref changes.
PrefChangeRegistrar pref_change_registrar_; PrefChangeRegistrar pref_change_registrar_;
std::vector<ClientSideDetectionHost*> csd_hosts_;
// Used to asynchronously call the callbacks for // Used to asynchronously call the callbacks for
// SendClientReportPhishingRequest. // SendClientReportPhishingRequest.
base::WeakPtrFactory<ClientSideDetectionService> weak_factory_{this}; base::WeakPtrFactory<ClientSideDetectionService> weak_factory_{this};
......
...@@ -78,6 +78,9 @@ class ModelLoader { ...@@ -78,6 +78,9 @@ class ModelLoader {
// sequence as ScheduleFetch. // sequence as ScheduleFetch.
virtual void CancelFetcher(); virtual void CancelFetcher();
// Only used in tests.
void SetModelStrForTesting(const std::string& model_str) { model_str_ = model_str; }
const std::string& model_str() const { return model_str_; } const std::string& model_str() const { return model_str_; }
const std::string& name() const { return name_; } const std::string& name() const { return name_; }
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#if BUILDFLAG(SAFE_BROWSING_CSD) #if BUILDFLAG(SAFE_BROWSING_CSD)
#include "chrome/browser/safe_browsing/client_side_detection_host.h" #include "chrome/browser/safe_browsing/client_side_detection_host.h"
#include "chrome/browser/safe_browsing/client_side_detection_service.h"
#include "chrome/browser/safe_browsing/client_side_detection_service_factory.h" #include "chrome/browser/safe_browsing/client_side_detection_service_factory.h"
#include "chrome/common/chrome_render_frame.mojom.h" #include "chrome/common/chrome_render_frame.mojom.h"
#include "third_party/blink/public/common/associated_interfaces/associated_interface_provider.h" #include "third_party/blink/public/common/associated_interfaces/associated_interface_provider.h"
...@@ -55,6 +56,8 @@ SafeBrowsingTabObserver::SafeBrowsingTabObserver( ...@@ -55,6 +56,8 @@ SafeBrowsingTabObserver::SafeBrowsingTabObserver(
g_browser_process->safe_browsing_service() && csd_service) { g_browser_process->safe_browsing_service() && csd_service) {
safebrowsing_detection_host_ = safebrowsing_detection_host_ =
ClientSideDetectionHost::Create(web_contents); ClientSideDetectionHost::Create(web_contents);
csd_service->AddClientSideDetectionHost(
safebrowsing_detection_host_.get());
} }
} }
#endif #endif
...@@ -78,6 +81,8 @@ void SafeBrowsingTabObserver::UpdateSafebrowsingDetectionHost() { ...@@ -78,6 +81,8 @@ void SafeBrowsingTabObserver::UpdateSafebrowsingDetectionHost() {
if (!safebrowsing_detection_host_.get()) { if (!safebrowsing_detection_host_.get()) {
safebrowsing_detection_host_ = safebrowsing_detection_host_ =
ClientSideDetectionHost::Create(web_contents_); ClientSideDetectionHost::Create(web_contents_);
csd_service->AddClientSideDetectionHost(
safebrowsing_detection_host_.get());
} }
} else { } else {
safebrowsing_detection_host_.reset(); safebrowsing_detection_host_.reset();
......
...@@ -12,16 +12,11 @@ ...@@ -12,16 +12,11 @@
#include "chrome/renderer/chrome_content_renderer_client.h" #include "chrome/renderer/chrome_content_renderer_client.h"
#include "chrome/renderer/chrome_render_thread_observer.h" #include "chrome/renderer/chrome_render_thread_observer.h"
#include "chrome/renderer/media/webrtc_logging_agent_impl.h" #include "chrome/renderer/media/webrtc_logging_agent_impl.h"
#include "components/safe_browsing/buildflags.h"
#include "components/spellcheck/spellcheck_buildflags.h" #include "components/spellcheck/spellcheck_buildflags.h"
#include "components/visitedlink/renderer/visitedlink_reader.h" #include "components/visitedlink/renderer/visitedlink_reader.h"
#include "components/web_cache/renderer/web_cache_impl.h" #include "components/web_cache/renderer/web_cache_impl.h"
#include "mojo/public/cpp/bindings/binder_map.h" #include "mojo/public/cpp/bindings/binder_map.h"
#if BUILDFLAG(FULL_SAFE_BROWSING)
#include "chrome/renderer/safe_browsing/phishing_classifier_delegate.h"
#endif
#if BUILDFLAG(ENABLE_SPELLCHECK) #if BUILDFLAG(ENABLE_SPELLCHECK)
#include "components/spellcheck/renderer/spellcheck.h" #include "components/spellcheck/renderer/spellcheck.h"
#endif #endif
...@@ -71,12 +66,6 @@ void ExposeChromeRendererInterfacesToBrowser( ...@@ -71,12 +66,6 @@ void ExposeChromeRendererInterfacesToBrowser(
binders->Add(base::BindRepeating(&BindWebRTCLoggingAgent, client), binders->Add(base::BindRepeating(&BindWebRTCLoggingAgent, client),
base::SequencedTaskRunnerHandle::Get()); base::SequencedTaskRunnerHandle::Get());
#if BUILDFLAG(FULL_SAFE_BROWSING)
binders->Add(
base::BindRepeating(&safe_browsing::PhishingClassifierFilter::Create),
base::SequencedTaskRunnerHandle::Get());
#endif
#if !defined(OS_ANDROID) #if !defined(OS_ANDROID)
binders->Add(base::BindRepeating( binders->Add(base::BindRepeating(
&performance_manager::V8PerFrameMemoryReporterImpl::Create), &performance_manager::V8PerFrameMemoryReporterImpl::Create),
......
...@@ -54,42 +54,6 @@ base::LazyInstance<std::unique_ptr<const safe_browsing::Scorer>>:: ...@@ -54,42 +54,6 @@ base::LazyInstance<std::unique_ptr<const safe_browsing::Scorer>>::
} // namespace } // namespace
// static
void PhishingClassifierFilter::Create(
mojo::PendingReceiver<mojom::PhishingModelSetter> receiver) {
mojo::MakeSelfOwnedReceiver(std::make_unique<PhishingClassifierFilter>(),
std::move(receiver));
}
PhishingClassifierFilter::PhishingClassifierFilter() {}
PhishingClassifierFilter::~PhishingClassifierFilter() {}
void PhishingClassifierFilter::SetPhishingModel(const std::string& model) {
safe_browsing::Scorer* scorer = NULL;
// An empty model string means we should disable client-side phishing
// detection.
if (!model.empty()) {
scorer = safe_browsing::Scorer::Create(model);
if (!scorer) {
DLOG(ERROR) << "Unable to create a PhishingScorer - corrupt model?";
return;
}
}
for (auto* delegate : PhishingClassifierDelegates())
delegate->SetPhishingScorer(scorer);
g_phishing_scorer.Get().reset(scorer);
}
// static
PhishingClassifierDelegate* PhishingClassifierDelegate::Create(
content::RenderFrame* render_frame,
PhishingClassifier* classifier) {
// Private constructor and public static Create() method to facilitate
// stubbing out this class for binary-size reduction purposes.
return new PhishingClassifierDelegate(render_frame, classifier);
}
PhishingClassifierDelegate::PhishingClassifierDelegate( PhishingClassifierDelegate::PhishingClassifierDelegate(
content::RenderFrame* render_frame, content::RenderFrame* render_frame,
PhishingClassifier* classifier) PhishingClassifier* classifier)
...@@ -118,6 +82,31 @@ PhishingClassifierDelegate::~PhishingClassifierDelegate() { ...@@ -118,6 +82,31 @@ PhishingClassifierDelegate::~PhishingClassifierDelegate() {
PhishingClassifierDelegates().erase(this); PhishingClassifierDelegates().erase(this);
} }
void PhishingClassifierDelegate::SetPhishingModel(const std::string& model) {
safe_browsing::Scorer* scorer = nullptr;
// An empty model string means we should disable client-side phishing
// detection.
if (!model.empty()) {
scorer = safe_browsing::Scorer::Create(model);
if (!scorer) {
DLOG(ERROR) << "Unable to create a PhishingScorer - corrupt model?";
return;
}
}
for (auto* delegate : PhishingClassifierDelegates())
delegate->SetPhishingScorer(scorer);
g_phishing_scorer.Get().reset(scorer);
}
// static
PhishingClassifierDelegate* PhishingClassifierDelegate::Create(
content::RenderFrame* render_frame,
PhishingClassifier* classifier) {
// Private constructor and public static Create() method to facilitate
// stubbing out this class for binary-size reduction purposes.
return new PhishingClassifierDelegate(render_frame, classifier);
}
void PhishingClassifierDelegate::SetPhishingScorer( void PhishingClassifierDelegate::SetPhishingScorer(
const safe_browsing::Scorer* scorer) { const safe_browsing::Scorer* scorer) {
if (is_classifying_) { if (is_classifying_) {
......
...@@ -38,21 +38,6 @@ enum class SBPhishingClassifierEvent { ...@@ -38,21 +38,6 @@ enum class SBPhishingClassifierEvent {
kMaxValue = kDestructedBeforeClassificationDone, kMaxValue = kDestructedBeforeClassificationDone,
}; };
class PhishingClassifierFilter : public mojom::PhishingModelSetter {
public:
PhishingClassifierFilter();
~PhishingClassifierFilter() override;
static void Create(
mojo::PendingReceiver<mojom::PhishingModelSetter> receiver);
private:
// mojom::PhishingModelSetter
void SetPhishingModel(const std::string& model) override;
DISALLOW_COPY_AND_ASSIGN(PhishingClassifierFilter);
};
class PhishingClassifierDelegate : public content::RenderFrameObserver, class PhishingClassifierDelegate : public content::RenderFrameObserver,
public mojom::PhishingDetector { public mojom::PhishingDetector {
public: public:
...@@ -63,6 +48,9 @@ class PhishingClassifierDelegate : public content::RenderFrameObserver, ...@@ -63,6 +48,9 @@ class PhishingClassifierDelegate : public content::RenderFrameObserver,
PhishingClassifier* classifier); PhishingClassifier* classifier);
~PhishingClassifierDelegate() override; ~PhishingClassifierDelegate() override;
// mojom::PhishingDetector
void SetPhishingModel(const std::string& model) override;
// Called by the RenderFrame once there is a phishing scorer available. // Called by the RenderFrame once there is a phishing scorer available.
// The scorer is passed on to the classifier. // The scorer is passed on to the classifier.
void SetPhishingScorer(const safe_browsing::Scorer* scorer); void SetPhishingScorer(const safe_browsing::Scorer* scorer);
......
...@@ -229,6 +229,25 @@ TEST_F(PhishingClassifierDelegateTest, Navigation) { ...@@ -229,6 +229,25 @@ TEST_F(PhishingClassifierDelegateTest, Navigation) {
EXPECT_CALL(*classifier_, CancelPendingClassification()); EXPECT_CALL(*classifier_, CancelPendingClassification());
} }
TEST_F(PhishingClassifierDelegateTest, NoPhishingModel) {
ASSERT_FALSE(classifier_->is_ready());
delegate_->SetPhishingModel("");
// The scorer is nullptr so the classifier should still not be ready.
ASSERT_FALSE(classifier_->is_ready());
}
TEST_F(PhishingClassifierDelegateTest, HasPhishingModel) {
ASSERT_FALSE(classifier_->is_ready());
ClientSideModel model;
model.set_max_words_per_term(1);
delegate_->SetPhishingModel(model.SerializeAsString());
ASSERT_TRUE(classifier_->is_ready());
// The delegate will cancel pending classification on destruction.
EXPECT_CALL(*classifier_, CancelPendingClassification());
}
TEST_F(PhishingClassifierDelegateTest, NoScorer) { TEST_F(PhishingClassifierDelegateTest, NoScorer) {
// For this test, we'll create the delegate with no scorer available yet. // For this test, we'll create the delegate with no scorer available yet.
ASSERT_FALSE(classifier_->is_ready()); ASSERT_FALSE(classifier_->is_ready());
......
...@@ -118,7 +118,13 @@ enum PhishingDetectorResult { ...@@ -118,7 +118,13 @@ enum PhishingDetectorResult {
}; };
[EnableIf=full_safe_browsing] [EnableIf=full_safe_browsing]
// Interface for setting the CSD model and to start phishing classification.
interface PhishingDetector { interface PhishingDetector {
// A classification model for client-side phishing detection.
// The string is an encoded safe_browsing::ClientSideModel protocol buffer, or
// empty to disable client-side phishing detection for this renderer.
SetPhishingModel(string model);
// Tells the renderer to begin phishing detection for the given toplevel URL // Tells the renderer to begin phishing detection for the given toplevel URL
// which it has started loading. Returns the serialized request proto and a // which it has started loading. Returns the serialized request proto and a
// |result| enum to indicate failure. If the URL is phishing the request proto // |result| enum to indicate failure. If the URL is phishing the request proto
...@@ -126,10 +132,3 @@ interface PhishingDetector { ...@@ -126,10 +132,3 @@ interface PhishingDetector {
StartPhishingDetection(url.mojom.Url url) StartPhishingDetection(url.mojom.Url url)
=> (PhishingDetectorResult result, string request_proto); => (PhishingDetectorResult result, string request_proto);
}; };
interface PhishingModelSetter {
// A classification model for client-side phishing detection.
// The string is an encoded safe_browsing::ClientSideModel protocol buffer, or
// empty to disable client-side phishing detection for this renderer.
SetPhishingModel(string model);
};
...@@ -96,6 +96,8 @@ class TestPhishingDetector : public mojom::PhishingDetector { ...@@ -96,6 +96,8 @@ class TestPhishingDetector : public mojom::PhishingDetector {
mojo::PendingReceiver<mojom::PhishingDetector>(std::move(handle))); mojo::PendingReceiver<mojom::PhishingDetector>(std::move(handle)));
} }
void SetPhishingModel(const std::string& model) override {}
void StartPhishingDetection( void StartPhishingDetection(
const GURL& url, const GURL& url,
StartPhishingDetectionCallback callback) override { StartPhishingDetectionCallback callback) override {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment