Commit 39458f45 authored by Bettina's avatar Bettina Committed by Commit Bot

Merge PhishingModelSetter into PhishingClassifierDelegate

It seems the only purpose of the PhishingClassifierFilter
is to implement the PhishingModelSetter interface, so we
may as well merge it with the PhishingClassifierDelegate
to have one less interface.

Bug: 1078964
Change-Id: I0a4563df32681b59a203f5aab4ee851b61346ff1
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2208101Reviewed-by: default avatarMustafa Emre Acer <meacer@chromium.org>
Reviewed-by: default avatarDaniel Rubery <drubery@chromium.org>
Commit-Queue: Bettina Dea <bdea@chromium.org>
Cr-Commit-Position: refs/heads/master@{#773948}
parent e8471f23
......@@ -290,6 +290,7 @@ std::unique_ptr<ClientSideDetectionHost> ClientSideDetectionHost::Create(
ClientSideDetectionHost::ClientSideDetectionHost(WebContents* tab)
: content::WebContentsObserver(tab),
csd_service_(nullptr),
tab_(tab),
classification_request_(nullptr),
pageload_complete_(false),
unsafe_unique_page_id_(-1),
......@@ -312,6 +313,9 @@ ClientSideDetectionHost::ClientSideDetectionHost(WebContents* tab)
ClientSideDetectionHost::~ClientSideDetectionHost() {
if (ui_manager_.get())
ui_manager_->RemoveObserver(this);
if (csd_service_)
csd_service_->RemoveClientSideDetectionHost(this);
}
void ClientSideDetectionHost::DidFinishNavigation(
......@@ -369,6 +373,44 @@ void ClientSideDetectionHost::DidFinishNavigation(
classification_request_->Start();
}
void ClientSideDetectionHost::SendModelToRenderFrame(
content::RenderProcessHost* process,
Profile* profile,
ModelLoader* model_loader_standard,
ModelLoader* model_loader_extended) {
DCHECK_CURRENTLY_ON(BrowserThread::UI);
if (!web_contents() || web_contents() != tab_)
return;
for (content::RenderFrameHost* frame : web_contents()->GetAllFrames()) {
if (frame->GetProcess() != process)
continue;
std::string model;
if (IsSafeBrowsingEnabled(*profile->GetPrefs())) {
if (IsExtendedReportingEnabled(*profile->GetPrefs()) ||
IsEnhancedProtectionEnabled(*profile->GetPrefs())) {
DVLOG(2) << "Sending phishing model " << model_loader_extended->name()
<< " to RenderFrameHost @" << frame;
model = model_loader_extended->model_str();
} else {
DVLOG(2) << "Sending phishing model " << model_loader_standard->name()
<< " to RenderFrameHost @" << frame;
model = model_loader_standard->model_str();
}
} else {
DVLOG(2) << "Disabling client-side phishing detection for "
<< "RenderFrameHost @" << frame;
}
if (phishing_detector_)
phishing_detector_.reset();
frame->GetRemoteInterfaces()->GetInterface(
phishing_detector_.BindNewPipeAndPassReceiver());
phishing_detector_->SetPhishingModel(model);
}
}
void ClientSideDetectionHost::OnSafeBrowsingHit(
const security_interstitials::UnsafeResource& resource) {
if (!web_contents())
......@@ -408,6 +450,8 @@ void ClientSideDetectionHost::WebContentsDestroyed() {
}
// Cancel all pending feature extractions.
feature_extractor_.reset();
csd_service_->RemoveClientSideDetectionHost(this);
}
void ClientSideDetectionHost::OnPhishingPreClassificationDone(
......
......@@ -14,9 +14,11 @@
#include "base/macros.h"
#include "base/memory/ref_counted.h"
#include "chrome/browser/safe_browsing/browser_feature_extractor.h"
#include "chrome/browser/safe_browsing/client_side_model_loader.h"
#include "chrome/browser/safe_browsing/ui_manager.h"
#include "components/safe_browsing/content/common/safe_browsing.mojom-shared.h"
#include "components/safe_browsing/core/db/database_manager.h"
#include "content/public/browser/render_process_host.h"
#include "content/public/browser/web_contents_observer.h"
#include "mojo/public/cpp/bindings/remote.h"
#include "services/service_manager/public/cpp/binder_registry.h"
......@@ -50,6 +52,12 @@ class ClientSideDetectionHost : public content::WebContentsObserver,
void DidFinishNavigation(
content::NavigationHandle* navigation_handle) override;
// Send the model to the given render frame host.
void SendModelToRenderFrame(content::RenderProcessHost* process,
Profile* profile,
ModelLoader* model_loader_standard,
ModelLoader* model_loader_extended);
// Called when the SafeBrowsingService found a hit with one of the
// SafeBrowsing lists. This method is called on the UI thread.
void OnSafeBrowsingHit(
......@@ -112,6 +120,8 @@ class ClientSideDetectionHost : public content::WebContentsObserver,
// This pointer may be nullptr if client-side phishing detection is disabled.
ClientSideDetectionService* csd_service_;
// The WebContents that the class is observing.
content::WebContents* tab_;
// These pointers may be nullptr if SafeBrowsing is disabled.
scoped_refptr<SafeBrowsingDatabaseManager> database_manager_;
scoped_refptr<SafeBrowsingUIManager> ui_manager_;
......
......@@ -17,9 +17,11 @@
#include "base/strings/stringprintf.h"
#include "base/synchronization/waitable_event.h"
#include "base/test/metrics/histogram_tester.h"
#include "base/test/scoped_feature_list.h"
#include "base/test/simple_test_tick_clock.h"
#include "chrome/browser/safe_browsing/browser_feature_extractor.h"
#include "chrome/browser/safe_browsing/client_side_detection_service.h"
#include "chrome/browser/safe_browsing/client_side_model_loader.h"
#include "chrome/browser/safe_browsing/safe_browsing_service.h"
#include "chrome/browser/safe_browsing/ui_manager.h"
#include "chrome/common/chrome_switches.h"
......@@ -27,8 +29,10 @@
#include "chrome/test/base/testing_profile.h"
#include "components/prefs/scoped_user_pref_update.h"
#include "components/safe_browsing/content/common/safe_browsing.mojom-shared.h"
#include "components/safe_browsing/core/common/safe_browsing_prefs.h"
#include "components/safe_browsing/core/db/database_manager.h"
#include "components/safe_browsing/core/db/test_database_manager.h"
#include "components/safe_browsing/core/features.h"
#include "components/safe_browsing/core/proto/csd.pb.h"
#include "components/security_interstitials/content/unsafe_resource_util.h"
#include "components/security_interstitials/core/unsafe_resource.h"
......@@ -107,6 +111,18 @@ ACTION(QuitUIMessageLoop) {
base::RunLoop::QuitCurrentWhenIdleDeprecated();
}
class MockModelLoader : public ModelLoader {
public:
explicit MockModelLoader() : ModelLoader(base::Closure(), nullptr, false) {}
~MockModelLoader() override = default;
MOCK_METHOD1(ScheduleFetch, void(int64_t));
MOCK_METHOD0(CancelFetcher, void());
private:
DISALLOW_COPY_AND_ASSIGN(MockModelLoader);
};
class MockClientSideDetectionService : public ClientSideDetectionService {
public:
MockClientSideDetectionService() : ClientSideDetectionService(nullptr) {}
......@@ -190,6 +206,9 @@ class FakePhishingDetector : public mojom::PhishingDetector {
std::move(handle)));
}
// mojom::PhishingDetector
void SetPhishingModel(const std::string& model) override { model_ = model; }
// mojom::PhishingDetector
void StartPhishingDetection(
const GURL& url,
......@@ -216,15 +235,19 @@ class FakePhishingDetector : public mojom::PhishingDetector {
}
}
void CheckModel(const std::string& model) { EXPECT_EQ(model, model_); }
void Reset() {
phishing_detection_started_ = false;
url_ = GURL();
model_ = "";
}
private:
mojo::ReceiverSet<mojom::PhishingDetector> receivers_;
bool phishing_detection_started_ = false;
GURL url_;
std::string model_ = "";
DISALLOW_COPY_AND_ASSIGN(FakePhishingDetector);
};
......@@ -1148,4 +1171,43 @@ TEST_F(ClientSideDetectionHostTest, RecordsPhishingDetectionDuration) {
.min);
}
TEST_F(ClientSideDetectionHostTest, TestSendModelToRenderFrame) {
profile()->GetPrefs()->SetBoolean(prefs::kSafeBrowsingEnabled, false);
profile()->GetPrefs()->SetBoolean(prefs::kSafeBrowsingEnhanced, false);
// Safe Browsing is not enabled.
StrictMock<MockModelLoader> standard;
standard.SetModelStrForTesting("standard");
StrictMock<MockModelLoader> extended;
extended.SetModelStrForTesting("extended");
csd_host_->SendModelToRenderFrame(
web_contents()->GetMainFrame()->GetProcess(), profile(), &standard,
&extended);
base::RunLoop().RunUntilIdle();
fake_phishing_detector_.CheckModel("");
fake_phishing_detector_.Reset();
// Safe Browsing is enabled.
profile()->GetPrefs()->SetBoolean(prefs::kSafeBrowsingEnabled, true);
csd_host_->SendModelToRenderFrame(
web_contents()->GetMainFrame()->GetProcess(), profile(), &standard,
&extended);
base::RunLoop().RunUntilIdle();
fake_phishing_detector_.CheckModel("standard");
fake_phishing_detector_.Reset();
{
base::test::ScopedFeatureList scoped_feature_list;
scoped_feature_list.InitAndEnableFeature(
safe_browsing::kEnhancedProtection);
// Safe Browsing enhanced protection is enabled.
profile()->GetPrefs()->SetBoolean(prefs::kSafeBrowsingEnhanced, true);
csd_host_->SendModelToRenderFrame(
web_contents()->GetMainFrame()->GetProcess(), profile(), &standard,
&extended);
base::RunLoop().RunUntilIdle();
fake_phishing_detector_.CheckModel("extended");
fake_phishing_detector_.Reset();
}
}
} // namespace safe_browsing
......@@ -22,6 +22,7 @@
#include "chrome/browser/browser_process.h"
#include "chrome/browser/policy/chrome_browser_policy_connector.h"
#include "chrome/browser/profiles/profile.h"
#include "chrome/browser/safe_browsing/client_side_detection_host.h"
#include "chrome/common/pref_names.h"
#include "components/prefs/pref_service.h"
#include "components/safe_browsing/content/common/safe_browsing.mojom.h"
......@@ -166,6 +167,19 @@ bool ClientSideDetectionService::IsPrivateIPAddress(
return !address.IsPubliclyRoutable();
}
void ClientSideDetectionService::AddClientSideDetectionHost(
ClientSideDetectionHost* host) {
csd_hosts_.push_back(host);
}
void ClientSideDetectionService::RemoveClientSideDetectionHost(
ClientSideDetectionHost* host) {
std::vector<ClientSideDetectionHost*>::iterator position =
std::find(csd_hosts_.begin(), csd_hosts_.end(), host);
if (position != csd_hosts_.end())
csd_hosts_.erase(position);
}
void ClientSideDetectionService::OnURLLoaderComplete(
network::SimpleURLLoader* url_loader,
std::unique_ptr<std::string> response_body) {
......@@ -189,35 +203,13 @@ void ClientSideDetectionService::Observe(
DCHECK_EQ(content::NOTIFICATION_RENDERER_PROCESS_CREATED, type);
content::RenderProcessHost* process =
content::Source<content::RenderProcessHost>(source).ptr();
if (process->GetBrowserContext() == profile_)
SendModelToProcess(process);
}
void ClientSideDetectionService::SendModelToProcess(
content::RenderProcessHost* process) {
DCHECK(process->IsInitializedAndNotDead());
DCHECK_EQ(process->GetBrowserContext(), profile_);
std::string model;
if (IsSafeBrowsingEnabled(*profile_->GetPrefs())) {
if (IsExtendedReportingEnabled(*profile_->GetPrefs()) ||
IsEnhancedProtectionEnabled(*profile_->GetPrefs())) {
DVLOG(2) << "Sending phishing model " << model_loader_extended_->name()
<< " to RenderProcessHost @" << process;
model = model_loader_extended_->model_str();
} else {
DVLOG(2) << "Sending phishing model " << model_loader_standard_->name()
<< " to RenderProcessHost @" << process;
model = model_loader_standard_->model_str();
if (process->GetBrowserContext() == profile_) {
for (ClientSideDetectionHost* host : csd_hosts_) {
host->SendModelToRenderFrame(process, profile_,
model_loader_standard_.get(),
model_loader_extended_.get());
}
} else {
DVLOG(2) << "Disabling client-side phishing detection for "
<< "RenderProcessHost @" << process;
}
mojo::Remote<safe_browsing::mojom::PhishingModelSetter> phishing;
process->BindReceiver(phishing.BindNewPipeAndPassReceiver());
phishing->SetPhishingModel(model);
}
void ClientSideDetectionService::SendModelToRenderers() {
......@@ -226,8 +218,13 @@ void ClientSideDetectionService::SendModelToRenderers() {
!i.IsAtEnd(); i.Advance()) {
content::RenderProcessHost* process = i.GetCurrentValue();
if (process->IsInitializedAndNotDead() &&
process->GetBrowserContext() == profile_)
SendModelToProcess(process);
process->GetBrowserContext() == profile_) {
for (ClientSideDetectionHost* host : csd_hosts_) {
host->SendModelToRenderFrame(process, profile_,
model_loader_standard_.get(),
model_loader_extended_.get());
}
}
}
}
......
......@@ -38,10 +38,6 @@
class Profile;
namespace content {
class RenderProcessHost;
}
namespace network {
class SimpleURLLoader;
class SharedURLLoaderFactory;
......@@ -49,6 +45,7 @@ class SharedURLLoaderFactory;
namespace safe_browsing {
class ClientPhishingRequest;
class ClientSideDetectionHost;
// Main service which pushes models to the renderers, responds to classification
// requests. This owns two ModelLoader objects.
......@@ -72,6 +69,9 @@ class ClientSideDetectionService : public content::NotificationObserver,
return enabled_;
}
void AddClientSideDetectionHost(ClientSideDetectionHost* host);
void RemoveClientSideDetectionHost(ClientSideDetectionHost* host);
void OnURLLoaderComplete(network::SimpleURLLoader* url_loader,
std::unique_ptr<std::string> response_body);
......@@ -187,9 +187,6 @@ class ClientSideDetectionService : public content::NotificationObserver,
// trims off the old elements.
int GetNumReports(base::queue<base::Time>* report_times);
// Send the model to the given renderer.
void SendModelToProcess(content::RenderProcessHost* process);
// Returns the URL that will be used for phishing requests.
static GURL GetClientReportUrl(const std::string& report_url);
......@@ -233,6 +230,8 @@ class ClientSideDetectionService : public content::NotificationObserver,
// PrefChangeRegistrar used to track when the Safe Browsing pref changes.
PrefChangeRegistrar pref_change_registrar_;
std::vector<ClientSideDetectionHost*> csd_hosts_;
// Used to asynchronously call the callbacks for
// SendClientReportPhishingRequest.
base::WeakPtrFactory<ClientSideDetectionService> weak_factory_{this};
......
......@@ -78,6 +78,9 @@ class ModelLoader {
// sequence as ScheduleFetch.
virtual void CancelFetcher();
// Only used in tests.
void SetModelStrForTesting(const std::string& model_str) { model_str_ = model_str; }
const std::string& model_str() const { return model_str_; }
const std::string& name() const { return name_; }
......
......@@ -20,6 +20,7 @@
#if BUILDFLAG(SAFE_BROWSING_CSD)
#include "chrome/browser/safe_browsing/client_side_detection_host.h"
#include "chrome/browser/safe_browsing/client_side_detection_service.h"
#include "chrome/browser/safe_browsing/client_side_detection_service_factory.h"
#include "chrome/common/chrome_render_frame.mojom.h"
#include "third_party/blink/public/common/associated_interfaces/associated_interface_provider.h"
......@@ -55,6 +56,8 @@ SafeBrowsingTabObserver::SafeBrowsingTabObserver(
g_browser_process->safe_browsing_service() && csd_service) {
safebrowsing_detection_host_ =
ClientSideDetectionHost::Create(web_contents);
csd_service->AddClientSideDetectionHost(
safebrowsing_detection_host_.get());
}
}
#endif
......@@ -78,6 +81,8 @@ void SafeBrowsingTabObserver::UpdateSafebrowsingDetectionHost() {
if (!safebrowsing_detection_host_.get()) {
safebrowsing_detection_host_ =
ClientSideDetectionHost::Create(web_contents_);
csd_service->AddClientSideDetectionHost(
safebrowsing_detection_host_.get());
}
} else {
safebrowsing_detection_host_.reset();
......
......@@ -12,16 +12,11 @@
#include "chrome/renderer/chrome_content_renderer_client.h"
#include "chrome/renderer/chrome_render_thread_observer.h"
#include "chrome/renderer/media/webrtc_logging_agent_impl.h"
#include "components/safe_browsing/buildflags.h"
#include "components/spellcheck/spellcheck_buildflags.h"
#include "components/visitedlink/renderer/visitedlink_reader.h"
#include "components/web_cache/renderer/web_cache_impl.h"
#include "mojo/public/cpp/bindings/binder_map.h"
#if BUILDFLAG(FULL_SAFE_BROWSING)
#include "chrome/renderer/safe_browsing/phishing_classifier_delegate.h"
#endif
#if BUILDFLAG(ENABLE_SPELLCHECK)
#include "components/spellcheck/renderer/spellcheck.h"
#endif
......@@ -71,12 +66,6 @@ void ExposeChromeRendererInterfacesToBrowser(
binders->Add(base::BindRepeating(&BindWebRTCLoggingAgent, client),
base::SequencedTaskRunnerHandle::Get());
#if BUILDFLAG(FULL_SAFE_BROWSING)
binders->Add(
base::BindRepeating(&safe_browsing::PhishingClassifierFilter::Create),
base::SequencedTaskRunnerHandle::Get());
#endif
#if !defined(OS_ANDROID)
binders->Add(base::BindRepeating(
&performance_manager::V8PerFrameMemoryReporterImpl::Create),
......
......@@ -54,42 +54,6 @@ base::LazyInstance<std::unique_ptr<const safe_browsing::Scorer>>::
} // namespace
// static
void PhishingClassifierFilter::Create(
mojo::PendingReceiver<mojom::PhishingModelSetter> receiver) {
mojo::MakeSelfOwnedReceiver(std::make_unique<PhishingClassifierFilter>(),
std::move(receiver));
}
PhishingClassifierFilter::PhishingClassifierFilter() {}
PhishingClassifierFilter::~PhishingClassifierFilter() {}
void PhishingClassifierFilter::SetPhishingModel(const std::string& model) {
safe_browsing::Scorer* scorer = NULL;
// An empty model string means we should disable client-side phishing
// detection.
if (!model.empty()) {
scorer = safe_browsing::Scorer::Create(model);
if (!scorer) {
DLOG(ERROR) << "Unable to create a PhishingScorer - corrupt model?";
return;
}
}
for (auto* delegate : PhishingClassifierDelegates())
delegate->SetPhishingScorer(scorer);
g_phishing_scorer.Get().reset(scorer);
}
// static
PhishingClassifierDelegate* PhishingClassifierDelegate::Create(
content::RenderFrame* render_frame,
PhishingClassifier* classifier) {
// Private constructor and public static Create() method to facilitate
// stubbing out this class for binary-size reduction purposes.
return new PhishingClassifierDelegate(render_frame, classifier);
}
PhishingClassifierDelegate::PhishingClassifierDelegate(
content::RenderFrame* render_frame,
PhishingClassifier* classifier)
......@@ -118,6 +82,31 @@ PhishingClassifierDelegate::~PhishingClassifierDelegate() {
PhishingClassifierDelegates().erase(this);
}
void PhishingClassifierDelegate::SetPhishingModel(const std::string& model) {
safe_browsing::Scorer* scorer = nullptr;
// An empty model string means we should disable client-side phishing
// detection.
if (!model.empty()) {
scorer = safe_browsing::Scorer::Create(model);
if (!scorer) {
DLOG(ERROR) << "Unable to create a PhishingScorer - corrupt model?";
return;
}
}
for (auto* delegate : PhishingClassifierDelegates())
delegate->SetPhishingScorer(scorer);
g_phishing_scorer.Get().reset(scorer);
}
// static
PhishingClassifierDelegate* PhishingClassifierDelegate::Create(
content::RenderFrame* render_frame,
PhishingClassifier* classifier) {
// Private constructor and public static Create() method to facilitate
// stubbing out this class for binary-size reduction purposes.
return new PhishingClassifierDelegate(render_frame, classifier);
}
void PhishingClassifierDelegate::SetPhishingScorer(
const safe_browsing::Scorer* scorer) {
if (is_classifying_) {
......
......@@ -38,21 +38,6 @@ enum class SBPhishingClassifierEvent {
kMaxValue = kDestructedBeforeClassificationDone,
};
class PhishingClassifierFilter : public mojom::PhishingModelSetter {
public:
PhishingClassifierFilter();
~PhishingClassifierFilter() override;
static void Create(
mojo::PendingReceiver<mojom::PhishingModelSetter> receiver);
private:
// mojom::PhishingModelSetter
void SetPhishingModel(const std::string& model) override;
DISALLOW_COPY_AND_ASSIGN(PhishingClassifierFilter);
};
class PhishingClassifierDelegate : public content::RenderFrameObserver,
public mojom::PhishingDetector {
public:
......@@ -63,6 +48,9 @@ class PhishingClassifierDelegate : public content::RenderFrameObserver,
PhishingClassifier* classifier);
~PhishingClassifierDelegate() override;
// mojom::PhishingDetector
void SetPhishingModel(const std::string& model) override;
// Called by the RenderFrame once there is a phishing scorer available.
// The scorer is passed on to the classifier.
void SetPhishingScorer(const safe_browsing::Scorer* scorer);
......
......@@ -229,6 +229,25 @@ TEST_F(PhishingClassifierDelegateTest, Navigation) {
EXPECT_CALL(*classifier_, CancelPendingClassification());
}
TEST_F(PhishingClassifierDelegateTest, NoPhishingModel) {
ASSERT_FALSE(classifier_->is_ready());
delegate_->SetPhishingModel("");
// The scorer is nullptr so the classifier should still not be ready.
ASSERT_FALSE(classifier_->is_ready());
}
TEST_F(PhishingClassifierDelegateTest, HasPhishingModel) {
ASSERT_FALSE(classifier_->is_ready());
ClientSideModel model;
model.set_max_words_per_term(1);
delegate_->SetPhishingModel(model.SerializeAsString());
ASSERT_TRUE(classifier_->is_ready());
// The delegate will cancel pending classification on destruction.
EXPECT_CALL(*classifier_, CancelPendingClassification());
}
TEST_F(PhishingClassifierDelegateTest, NoScorer) {
// For this test, we'll create the delegate with no scorer available yet.
ASSERT_FALSE(classifier_->is_ready());
......
......@@ -118,7 +118,13 @@ enum PhishingDetectorResult {
};
[EnableIf=full_safe_browsing]
// Interface for setting the CSD model and to start phishing classification.
interface PhishingDetector {
// A classification model for client-side phishing detection.
// The string is an encoded safe_browsing::ClientSideModel protocol buffer, or
// empty to disable client-side phishing detection for this renderer.
SetPhishingModel(string model);
// Tells the renderer to begin phishing detection for the given toplevel URL
// which it has started loading. Returns the serialized request proto and a
// |result| enum to indicate failure. If the URL is phishing the request proto
......@@ -126,10 +132,3 @@ interface PhishingDetector {
StartPhishingDetection(url.mojom.Url url)
=> (PhishingDetectorResult result, string request_proto);
};
interface PhishingModelSetter {
// A classification model for client-side phishing detection.
// The string is an encoded safe_browsing::ClientSideModel protocol buffer, or
// empty to disable client-side phishing detection for this renderer.
SetPhishingModel(string model);
};
......@@ -96,6 +96,8 @@ class TestPhishingDetector : public mojom::PhishingDetector {
mojo::PendingReceiver<mojom::PhishingDetector>(std::move(handle)));
}
void SetPhishingModel(const std::string& model) override {}
void StartPhishingDetection(
const GURL& url,
StartPhishingDetectionCallback callback) override {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment