Commit a4418459 authored by Daniel Rubery, committed by Commit Bot

Move visual feature scoring to helper thread

This CL calls Scorer::GetMatchingVisualTargets on a worker thread, since
scoring visual features takes longer than expected and was blocking
interaction.

Fixed: 1121375
Change-Id: I8f3e33d20f812645a72adfc16fe33241b49a6ff4
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2372869
Commit-Queue: Daniel Rubery <drubery@chromium.org>
Reviewed-by: Bettina Dea <bdea@chromium.org>
Reviewed-by: Varun Khaneja <vakh@chromium.org>
Cr-Commit-Position: refs/heads/master@{#801907}
parent 1fad9e7f
......@@ -4,6 +4,7 @@
#include "chrome/renderer/safe_browsing/phishing_classifier.h"
#include <memory>
#include <string>
#include <utility>
......@@ -15,6 +16,8 @@
#include "base/metrics/histogram_macros.h"
#include "base/single_thread_task_runner.h"
#include "base/strings/string_util.h"
#include "base/task/task_traits.h"
#include "base/task/thread_pool.h"
#include "base/threading/thread_task_runner_handle.h"
#include "cc/paint/skia_paint_canvas.h"
#include "chrome/common/url_constants.h"
......@@ -26,6 +29,7 @@
#include "components/paint_preview/common/paint_preview_tracker.h"
#include "components/safe_browsing/core/proto/csd.pb.h"
#include "content/public/renderer/render_frame.h"
#include "content/public/renderer/render_thread.h"
#include "crypto/sha2.h"
#include "third_party/blink/public/platform/web_url.h"
#include "third_party/blink/public/platform/web_url_request.h"
......@@ -165,6 +169,7 @@ void PhishingClassifier::TermExtractionFinished(bool success) {
}
void PhishingClassifier::ExtractVisualFeatures() {
DCHECK(content::RenderThread::IsMainThread());
base::TimeTicks start_time = base::TimeTicks::Now();
blink::WebLocalFrame* frame = render_frame_->GetWebFrame();
......@@ -193,6 +198,7 @@ void PhishingClassifier::ExtractVisualFeatures() {
}
void PhishingClassifier::VisualExtractionFinished(bool success) {
DCHECK(content::RenderThread::IsMainThread());
if (!success) {
RunFailureCallback();
return;
......@@ -203,32 +209,48 @@ void PhishingClassifier::VisualExtractionFinished(bool success) {
// Hash all of the features so that they match the model, then compute
// the score.
FeatureMap hashed_features;
ClientPhishingRequest verdict;
verdict.set_model_version(scorer_->model_version());
verdict.set_url(main_frame->GetDocument().Url().GetString().Utf8());
std::unique_ptr<ClientPhishingRequest> verdict =
std::make_unique<ClientPhishingRequest>();
verdict->set_model_version(scorer_->model_version());
verdict->set_url(main_frame->GetDocument().Url().GetString().Utf8());
for (const auto& it : features_->features()) {
bool result = hashed_features.AddRealFeature(
crypto::SHA256HashString(it.first), it.second);
DCHECK(result);
ClientPhishingRequest::Feature* feature = verdict.add_feature_map();
ClientPhishingRequest::Feature* feature = verdict->add_feature_map();
feature->set_name(it.first);
feature->set_value(it.second);
}
for (const auto& it : *shingle_hashes_) {
verdict.add_shingle_hashes(it);
verdict->add_shingle_hashes(it);
}
float score = static_cast<float>(scorer_->ComputeScore(hashed_features));
verdict.set_client_score(score);
verdict.set_is_phishing(score >= scorer_->threshold_probability());
verdict->set_client_score(score);
verdict->set_is_phishing(score >= scorer_->threshold_probability());
visual_matching_start_ = base::TimeTicks::Now();
// Perform scoring off the main thread to avoid blocking.
base::ThreadPool::PostTaskAndReplyWithResult(
FROM_HERE, {base::WithBaseSyncPrimitives()},
base::BindOnce(&Scorer::GetMatchingVisualTargets,
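// base::Unretained is safe because the classification will
// be cancelled before the scorer is removed.
// NOTE(review): cancellation invalidates the weak pointer used
// for the reply, but this task is bound without one -- confirm
// the scorer outlives a task that is already running.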
base::Unretained(scorer_), *bitmap_, std::move(verdict)),
base::BindOnce(&PhishingClassifier::OnVisualTargetsMatched,
weak_factory_.GetWeakPtr()));
}
base::TimeTicks visual_matching_start = base::TimeTicks::Now();
if (scorer_->GetMatchingVisualTargets(*bitmap_, &verdict)) {
verdict.set_is_phishing(true);
void PhishingClassifier::OnVisualTargetsMatched(
std::unique_ptr<ClientPhishingRequest> verdict) {
DCHECK(content::RenderThread::IsMainThread());
if (!verdict->vision_match().empty()) {
verdict->set_is_phishing(true);
}
base::UmaHistogramTimes("SBClientPhishing.VisualComparisonTime",
base::TimeTicks::Now() - visual_matching_start);
base::TimeTicks::Now() - visual_matching_start_);
RunCallback(verdict);
RunCallback(*verdict);
}
void PhishingClassifier::RunCallback(const ClientPhishingRequest& verdict) {
......
......@@ -121,18 +121,23 @@ class PhishingClassifier {
// non-phishy verdict.
void VisualExtractionFinished(bool success);
// Callback when visual features have been scored and compared against the
// model.
void OnVisualTargetsMatched(std::unique_ptr<ClientPhishingRequest> verdict);
// Helper method to run the DoneCallback and clear the state.
void RunCallback(const ClientPhishingRequest& verdict);
// Helper to run the DoneCallback when feature extraction has failed.
// This always signals a non-phishy verdict for the page, with kInvalidScore.
// This always signals a non-phishy verdict for the page, with
// |kInvalidScore|.
void RunFailureCallback();
// Clears the current state of the PhishingClassifier.
void Clear();
content::RenderFrame* render_frame_; // owns us
const Scorer* scorer_; // owned by the caller
const Scorer* scorer_; // owned by the caller
std::unique_ptr<PhishingUrlFeatureExtractor> url_extractor_;
std::unique_ptr<PhishingDOMFeatureExtractor> dom_extractor_;
std::unique_ptr<PhishingTermFeatureExtractor> term_extractor_;
......@@ -144,6 +149,9 @@ class PhishingClassifier {
std::unique_ptr<SkBitmap> bitmap_;
DoneCallback done_callback_;
// Used to record the duration of visual feature scoring.
base::TimeTicks visual_matching_start_;
// Used in scheduling BeginFeatureExtraction tasks.
// These pointers are invalidated if classification is cancelled.
base::WeakPtrFactory<PhishingClassifier> weak_factory_{this};
......
......@@ -127,7 +127,7 @@ class PhishingClassifierTest : public ChromeRenderViewTest {
page_text,
base::BindOnce(&PhishingClassifierTest::ClassificationFinished,
base::Unretained(this)));
base::RunLoop().RunUntilIdle();
run_loop_.Run();
}
// Completion callback for classification.
......@@ -141,6 +141,8 @@ class PhishingClassifierTest : public ChromeRenderViewTest {
screenshot_digest_ = verdict.screenshot_digest();
screenshot_phash_ = verdict.screenshot_phash();
phash_dimension_size_ = verdict.phash_dimension_size();
run_loop_.Quit();
}
void LoadHtml(const GURL& url, const std::string& content) {
......@@ -156,6 +158,7 @@ class PhishingClassifierTest : public ChromeRenderViewTest {
std::string response_content_;
std::unique_ptr<Scorer> scorer_;
std::unique_ptr<PhishingClassifier> classifier_;
base::RunLoop run_loop_;
// Features that are in the model.
const std::string url_tld_token_net_;
......
......@@ -86,15 +86,14 @@ double Scorer::ComputeScore(const FeatureMap& features) const {
return LogOdds2Prob(logodds);
}
bool Scorer::GetMatchingVisualTargets(const SkBitmap& bitmap,
ClientPhishingRequest* request) const {
bool has_match = false;
std::unique_ptr<ClientPhishingRequest> Scorer::GetMatchingVisualTargets(
const SkBitmap& bitmap,
std::unique_ptr<ClientPhishingRequest> request) const {
for (const VisualTarget& target : model_.vision_model().targets()) {
base::Optional<VisionMatchResult> result =
visual_utils::IsVisualMatch(bitmap, target);
if (result.has_value()) {
*request->add_vision_match() = result.value();
has_match = true;
}
}
......@@ -112,7 +111,7 @@ bool Scorer::GetMatchingVisualTargets(const SkBitmap& bitmap,
}
}
return has_match;
return request;
}
int Scorer::model_version() const {
......
......@@ -42,10 +42,11 @@ class Scorer {
// (range is inclusive on both ends).
virtual double ComputeScore(const FeatureMap& features) const;
// This method matches the given |bitmap| against the visual model. It returns
// true if any visual target matches, and populates |request| appropriately.
virtual bool GetMatchingVisualTargets(const SkBitmap& bitmap,
ClientPhishingRequest* request) const;
// This method matches the given |bitmap| against the visual model. It
// modifies |request| appropriately, and returns the new request.
virtual std::unique_ptr<ClientPhishingRequest> GetMatchingVisualTargets(
const SkBitmap& bitmap,
std::unique_ptr<ClientPhishingRequest> request) const;
// Returns the version number of the loaded client model.
int model_version() const;
......
......@@ -189,10 +189,11 @@ TEST_F(PhishingScorerTest, GetMatchingVisualTargetsMatchOne) {
for (int x = 0; x < 164; x++)
*bitmap_.getAddr32(x, 0) = 0xff000000;
ClientPhishingRequest request;
scorer->GetMatchingVisualTargets(bitmap_, &request);
ASSERT_EQ(request.vision_match_size(), 1);
EXPECT_EQ(request.vision_match(0).matched_target_digest(), "target1");
std::unique_ptr<ClientPhishingRequest> request =
std::make_unique<ClientPhishingRequest>();
request = scorer->GetMatchingVisualTargets(bitmap_, std::move(request));
ASSERT_EQ(request->vision_match_size(), 1);
EXPECT_EQ(request->vision_match(0).matched_target_digest(), "target1");
}
TEST_F(PhishingScorerTest, GetMatchingVisualTargetsMatchBoth) {
......@@ -212,11 +213,12 @@ TEST_F(PhishingScorerTest, GetMatchingVisualTargetsMatchBoth) {
for (int x = 168; x < 248; x++)
*bitmap_.getAddr32(x, 0) = 0xff000000;
ClientPhishingRequest request;
scorer->GetMatchingVisualTargets(bitmap_, &request);
ASSERT_EQ(request.vision_match_size(), 2);
EXPECT_EQ(request.vision_match(0).matched_target_digest(), "target1");
EXPECT_EQ(request.vision_match(1).matched_target_digest(), "target2");
std::unique_ptr<ClientPhishingRequest> request =
std::make_unique<ClientPhishingRequest>();
request = scorer->GetMatchingVisualTargets(bitmap_, std::move(request));
ASSERT_EQ(request->vision_match_size(), 2);
EXPECT_EQ(request->vision_match(0).matched_target_digest(), "target1");
EXPECT_EQ(request->vision_match(1).matched_target_digest(), "target2");
}
} // namespace safe_browsing
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment