Commit a4418459 authored by Daniel Rubery's avatar Daniel Rubery Committed by Commit Bot

Move visual feature scoring to helper thread

This CL calls the GetMatchingVisualTargets in a worker thread, since
the process of scoring visual features takes longer than expected
and was blocking interaction.

Fixed: 1121375
Change-Id: I8f3e33d20f812645a72adfc16fe33241b49a6ff4
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2372869
Commit-Queue: Daniel Rubery <drubery@chromium.org>
Reviewed-by: default avatarBettina Dea <bdea@chromium.org>
Reviewed-by: default avatarVarun Khaneja <vakh@chromium.org>
Cr-Commit-Position: refs/heads/master@{#801907}
parent 1fad9e7f
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include "chrome/renderer/safe_browsing/phishing_classifier.h" #include "chrome/renderer/safe_browsing/phishing_classifier.h"
#include <memory>
#include <string> #include <string>
#include <utility> #include <utility>
...@@ -15,6 +16,8 @@ ...@@ -15,6 +16,8 @@
#include "base/metrics/histogram_macros.h" #include "base/metrics/histogram_macros.h"
#include "base/single_thread_task_runner.h" #include "base/single_thread_task_runner.h"
#include "base/strings/string_util.h" #include "base/strings/string_util.h"
#include "base/task/task_traits.h"
#include "base/task/thread_pool.h"
#include "base/threading/thread_task_runner_handle.h" #include "base/threading/thread_task_runner_handle.h"
#include "cc/paint/skia_paint_canvas.h" #include "cc/paint/skia_paint_canvas.h"
#include "chrome/common/url_constants.h" #include "chrome/common/url_constants.h"
...@@ -26,6 +29,7 @@ ...@@ -26,6 +29,7 @@
#include "components/paint_preview/common/paint_preview_tracker.h" #include "components/paint_preview/common/paint_preview_tracker.h"
#include "components/safe_browsing/core/proto/csd.pb.h" #include "components/safe_browsing/core/proto/csd.pb.h"
#include "content/public/renderer/render_frame.h" #include "content/public/renderer/render_frame.h"
#include "content/public/renderer/render_thread.h"
#include "crypto/sha2.h" #include "crypto/sha2.h"
#include "third_party/blink/public/platform/web_url.h" #include "third_party/blink/public/platform/web_url.h"
#include "third_party/blink/public/platform/web_url_request.h" #include "third_party/blink/public/platform/web_url_request.h"
...@@ -165,6 +169,7 @@ void PhishingClassifier::TermExtractionFinished(bool success) { ...@@ -165,6 +169,7 @@ void PhishingClassifier::TermExtractionFinished(bool success) {
} }
void PhishingClassifier::ExtractVisualFeatures() { void PhishingClassifier::ExtractVisualFeatures() {
DCHECK(content::RenderThread::IsMainThread());
base::TimeTicks start_time = base::TimeTicks::Now(); base::TimeTicks start_time = base::TimeTicks::Now();
blink::WebLocalFrame* frame = render_frame_->GetWebFrame(); blink::WebLocalFrame* frame = render_frame_->GetWebFrame();
...@@ -193,6 +198,7 @@ void PhishingClassifier::ExtractVisualFeatures() { ...@@ -193,6 +198,7 @@ void PhishingClassifier::ExtractVisualFeatures() {
} }
void PhishingClassifier::VisualExtractionFinished(bool success) { void PhishingClassifier::VisualExtractionFinished(bool success) {
DCHECK(content::RenderThread::IsMainThread());
if (!success) { if (!success) {
RunFailureCallback(); RunFailureCallback();
return; return;
...@@ -203,32 +209,48 @@ void PhishingClassifier::VisualExtractionFinished(bool success) { ...@@ -203,32 +209,48 @@ void PhishingClassifier::VisualExtractionFinished(bool success) {
// Hash all of the features so that they match the model, then compute // Hash all of the features so that they match the model, then compute
// the score. // the score.
FeatureMap hashed_features; FeatureMap hashed_features;
ClientPhishingRequest verdict; std::unique_ptr<ClientPhishingRequest> verdict =
verdict.set_model_version(scorer_->model_version()); std::make_unique<ClientPhishingRequest>();
verdict.set_url(main_frame->GetDocument().Url().GetString().Utf8()); verdict->set_model_version(scorer_->model_version());
verdict->set_url(main_frame->GetDocument().Url().GetString().Utf8());
for (const auto& it : features_->features()) { for (const auto& it : features_->features()) {
bool result = hashed_features.AddRealFeature( bool result = hashed_features.AddRealFeature(
crypto::SHA256HashString(it.first), it.second); crypto::SHA256HashString(it.first), it.second);
DCHECK(result); DCHECK(result);
ClientPhishingRequest::Feature* feature = verdict.add_feature_map(); ClientPhishingRequest::Feature* feature = verdict->add_feature_map();
feature->set_name(it.first); feature->set_name(it.first);
feature->set_value(it.second); feature->set_value(it.second);
} }
for (const auto& it : *shingle_hashes_) { for (const auto& it : *shingle_hashes_) {
verdict.add_shingle_hashes(it); verdict->add_shingle_hashes(it);
} }
float score = static_cast<float>(scorer_->ComputeScore(hashed_features)); float score = static_cast<float>(scorer_->ComputeScore(hashed_features));
verdict.set_client_score(score); verdict->set_client_score(score);
verdict.set_is_phishing(score >= scorer_->threshold_probability()); verdict->set_is_phishing(score >= scorer_->threshold_probability());
visual_matching_start_ = base::TimeTicks::Now();
// Perform scoring off the UI thread to avoid blocking.
base::ThreadPool::PostTaskAndReplyWithResult(
FROM_HERE, {base::WithBaseSyncPrimitives()},
base::BindOnce(&Scorer::GetMatchingVisualTargets,
// base::Unretained is safe because the classification will
// be cancelled before the scorer is removed.
base::Unretained(scorer_), *bitmap_, std::move(verdict)),
base::BindOnce(&PhishingClassifier::OnVisualTargetsMatched,
weak_factory_.GetWeakPtr()));
}
base::TimeTicks visual_matching_start = base::TimeTicks::Now(); void PhishingClassifier::OnVisualTargetsMatched(
if (scorer_->GetMatchingVisualTargets(*bitmap_, &verdict)) { std::unique_ptr<ClientPhishingRequest> verdict) {
verdict.set_is_phishing(true); DCHECK(content::RenderThread::IsMainThread());
if (!verdict->vision_match().empty()) {
verdict->set_is_phishing(true);
} }
base::UmaHistogramTimes("SBClientPhishing.VisualComparisonTime", base::UmaHistogramTimes("SBClientPhishing.VisualComparisonTime",
base::TimeTicks::Now() - visual_matching_start); base::TimeTicks::Now() - visual_matching_start_);
RunCallback(verdict); RunCallback(*verdict);
} }
void PhishingClassifier::RunCallback(const ClientPhishingRequest& verdict) { void PhishingClassifier::RunCallback(const ClientPhishingRequest& verdict) {
......
...@@ -121,18 +121,23 @@ class PhishingClassifier { ...@@ -121,18 +121,23 @@ class PhishingClassifier {
// non-phishy verdict. // non-phishy verdict.
void VisualExtractionFinished(bool success); void VisualExtractionFinished(bool success);
// Callback when visual features have been scored and compared against the
// model.
void OnVisualTargetsMatched(std::unique_ptr<ClientPhishingRequest> verdict);
// Helper method to run the DoneCallback and clear the state. // Helper method to run the DoneCallback and clear the state.
void RunCallback(const ClientPhishingRequest& verdict); void RunCallback(const ClientPhishingRequest& verdict);
// Helper to run the DoneCallback when feature extraction has failed. // Helper to run the DoneCallback when feature extraction has failed.
// This always signals a non-phishy verdict for the page, with kInvalidScore. // This always signals a non-phishy verdict for the page, with
// |kInvalidScore|.
void RunFailureCallback(); void RunFailureCallback();
// Clears the current state of the PhishingClassifier. // Clears the current state of the PhishingClassifier.
void Clear(); void Clear();
content::RenderFrame* render_frame_; // owns us content::RenderFrame* render_frame_; // owns us
const Scorer* scorer_; // owned by the caller const Scorer* scorer_; // owned by the caller
std::unique_ptr<PhishingUrlFeatureExtractor> url_extractor_; std::unique_ptr<PhishingUrlFeatureExtractor> url_extractor_;
std::unique_ptr<PhishingDOMFeatureExtractor> dom_extractor_; std::unique_ptr<PhishingDOMFeatureExtractor> dom_extractor_;
std::unique_ptr<PhishingTermFeatureExtractor> term_extractor_; std::unique_ptr<PhishingTermFeatureExtractor> term_extractor_;
...@@ -144,6 +149,9 @@ class PhishingClassifier { ...@@ -144,6 +149,9 @@ class PhishingClassifier {
std::unique_ptr<SkBitmap> bitmap_; std::unique_ptr<SkBitmap> bitmap_;
DoneCallback done_callback_; DoneCallback done_callback_;
// Used to record the duration of visual feature scoring.
base::TimeTicks visual_matching_start_;
// Used in scheduling BeginFeatureExtraction tasks. // Used in scheduling BeginFeatureExtraction tasks.
// These pointers are invalidated if classification is cancelled. // These pointers are invalidated if classification is cancelled.
base::WeakPtrFactory<PhishingClassifier> weak_factory_{this}; base::WeakPtrFactory<PhishingClassifier> weak_factory_{this};
......
...@@ -127,7 +127,7 @@ class PhishingClassifierTest : public ChromeRenderViewTest { ...@@ -127,7 +127,7 @@ class PhishingClassifierTest : public ChromeRenderViewTest {
page_text, page_text,
base::BindOnce(&PhishingClassifierTest::ClassificationFinished, base::BindOnce(&PhishingClassifierTest::ClassificationFinished,
base::Unretained(this))); base::Unretained(this)));
base::RunLoop().RunUntilIdle(); run_loop_.Run();
} }
// Completion callback for classification. // Completion callback for classification.
...@@ -141,6 +141,8 @@ class PhishingClassifierTest : public ChromeRenderViewTest { ...@@ -141,6 +141,8 @@ class PhishingClassifierTest : public ChromeRenderViewTest {
screenshot_digest_ = verdict.screenshot_digest(); screenshot_digest_ = verdict.screenshot_digest();
screenshot_phash_ = verdict.screenshot_phash(); screenshot_phash_ = verdict.screenshot_phash();
phash_dimension_size_ = verdict.phash_dimension_size(); phash_dimension_size_ = verdict.phash_dimension_size();
run_loop_.Quit();
} }
void LoadHtml(const GURL& url, const std::string& content) { void LoadHtml(const GURL& url, const std::string& content) {
...@@ -156,6 +158,7 @@ class PhishingClassifierTest : public ChromeRenderViewTest { ...@@ -156,6 +158,7 @@ class PhishingClassifierTest : public ChromeRenderViewTest {
std::string response_content_; std::string response_content_;
std::unique_ptr<Scorer> scorer_; std::unique_ptr<Scorer> scorer_;
std::unique_ptr<PhishingClassifier> classifier_; std::unique_ptr<PhishingClassifier> classifier_;
base::RunLoop run_loop_;
// Features that are in the model. // Features that are in the model.
const std::string url_tld_token_net_; const std::string url_tld_token_net_;
......
...@@ -86,15 +86,14 @@ double Scorer::ComputeScore(const FeatureMap& features) const { ...@@ -86,15 +86,14 @@ double Scorer::ComputeScore(const FeatureMap& features) const {
return LogOdds2Prob(logodds); return LogOdds2Prob(logodds);
} }
bool Scorer::GetMatchingVisualTargets(const SkBitmap& bitmap, std::unique_ptr<ClientPhishingRequest> Scorer::GetMatchingVisualTargets(
ClientPhishingRequest* request) const { const SkBitmap& bitmap,
bool has_match = false; std::unique_ptr<ClientPhishingRequest> request) const {
for (const VisualTarget& target : model_.vision_model().targets()) { for (const VisualTarget& target : model_.vision_model().targets()) {
base::Optional<VisionMatchResult> result = base::Optional<VisionMatchResult> result =
visual_utils::IsVisualMatch(bitmap, target); visual_utils::IsVisualMatch(bitmap, target);
if (result.has_value()) { if (result.has_value()) {
*request->add_vision_match() = result.value(); *request->add_vision_match() = result.value();
has_match = true;
} }
} }
...@@ -112,7 +111,7 @@ bool Scorer::GetMatchingVisualTargets(const SkBitmap& bitmap, ...@@ -112,7 +111,7 @@ bool Scorer::GetMatchingVisualTargets(const SkBitmap& bitmap,
} }
} }
return has_match; return request;
} }
int Scorer::model_version() const { int Scorer::model_version() const {
......
...@@ -42,10 +42,11 @@ class Scorer { ...@@ -42,10 +42,11 @@ class Scorer {
// (range is inclusive on both ends). // (range is inclusive on both ends).
virtual double ComputeScore(const FeatureMap& features) const; virtual double ComputeScore(const FeatureMap& features) const;
// This method matches the given |bitmap| against the visual model. It returns // This method matches the given |bitmap| against the visual model. It
// true if any visual target matches, and populates |request| appropriately. // modifies |request| appropriately, and returns the new request.
virtual bool GetMatchingVisualTargets(const SkBitmap& bitmap, virtual std::unique_ptr<ClientPhishingRequest> GetMatchingVisualTargets(
ClientPhishingRequest* request) const; const SkBitmap& bitmap,
std::unique_ptr<ClientPhishingRequest> request) const;
// Returns the version number of the loaded client model. // Returns the version number of the loaded client model.
int model_version() const; int model_version() const;
......
...@@ -189,10 +189,11 @@ TEST_F(PhishingScorerTest, GetMatchingVisualTargetsMatchOne) { ...@@ -189,10 +189,11 @@ TEST_F(PhishingScorerTest, GetMatchingVisualTargetsMatchOne) {
for (int x = 0; x < 164; x++) for (int x = 0; x < 164; x++)
*bitmap_.getAddr32(x, 0) = 0xff000000; *bitmap_.getAddr32(x, 0) = 0xff000000;
ClientPhishingRequest request; std::unique_ptr<ClientPhishingRequest> request =
scorer->GetMatchingVisualTargets(bitmap_, &request); std::make_unique<ClientPhishingRequest>();
ASSERT_EQ(request.vision_match_size(), 1); request = scorer->GetMatchingVisualTargets(bitmap_, std::move(request));
EXPECT_EQ(request.vision_match(0).matched_target_digest(), "target1"); ASSERT_EQ(request->vision_match_size(), 1);
EXPECT_EQ(request->vision_match(0).matched_target_digest(), "target1");
} }
TEST_F(PhishingScorerTest, GetMatchingVisualTargetsMatchBoth) { TEST_F(PhishingScorerTest, GetMatchingVisualTargetsMatchBoth) {
...@@ -212,11 +213,12 @@ TEST_F(PhishingScorerTest, GetMatchingVisualTargetsMatchBoth) { ...@@ -212,11 +213,12 @@ TEST_F(PhishingScorerTest, GetMatchingVisualTargetsMatchBoth) {
for (int x = 168; x < 248; x++) for (int x = 168; x < 248; x++)
*bitmap_.getAddr32(x, 0) = 0xff000000; *bitmap_.getAddr32(x, 0) = 0xff000000;
ClientPhishingRequest request; std::unique_ptr<ClientPhishingRequest> request =
scorer->GetMatchingVisualTargets(bitmap_, &request); std::make_unique<ClientPhishingRequest>();
ASSERT_EQ(request.vision_match_size(), 2); request = scorer->GetMatchingVisualTargets(bitmap_, std::move(request));
EXPECT_EQ(request.vision_match(0).matched_target_digest(), "target1"); ASSERT_EQ(request->vision_match_size(), 2);
EXPECT_EQ(request.vision_match(1).matched_target_digest(), "target2"); EXPECT_EQ(request->vision_match(0).matched_target_digest(), "target1");
EXPECT_EQ(request->vision_match(1).matched_target_digest(), "target2");
} }
} // namespace safe_browsing } // namespace safe_browsing
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment