Commit 0c2e6819 authored by Daniel Rubery's avatar Daniel Rubery Committed by Commit Bot

Revert "Move visual feature scoring to helper thread"

This reverts commit a4418459.

Reason for revert: Causing crashes (https://bugs.chromium.org/p/chromium/issues/detail?id=1122534 and https://bugs.chromium.org/p/chromium/issues/detail?id=1122732)

Original change's description:
> Move visual feature scoring to helper thread
> 
> This CL calls the GetMatchingVisualTargets in a worker thread, since
> the process of scoring visual features takes longer than expected
> and was blocking interaction.
> 
> Fixed: 1121375
> Change-Id: I8f3e33d20f812645a72adfc16fe33241b49a6ff4
> Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2372869
> Commit-Queue: Daniel Rubery <drubery@chromium.org>
> Reviewed-by: Bettina Dea <bdea@chromium.org>
> Reviewed-by: Varun Khaneja <vakh@chromium.org>
> Cr-Commit-Position: refs/heads/master@{#801907}

TBR=vakh@chromium.org,bdea@chromium.org,drubery@chromium.org

# Not skipping CQ checks because original CL landed > 1 day ago.

Change-Id: I02048f19ab8aee3a39a413b6f3cea1e1c4f696de
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2385782Reviewed-by: default avatarDaniel Rubery <drubery@chromium.org>
Reviewed-by: default avatarVarun Khaneja <vakh@chromium.org>
Reviewed-by: default avatarBettina Dea <bdea@chromium.org>
Commit-Queue: Daniel Rubery <drubery@chromium.org>
Cr-Commit-Position: refs/heads/master@{#803219}
parent 7b4ab7f5
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
#include "chrome/renderer/safe_browsing/phishing_classifier.h" #include "chrome/renderer/safe_browsing/phishing_classifier.h"
#include <memory>
#include <string> #include <string>
#include <utility> #include <utility>
...@@ -16,8 +15,6 @@ ...@@ -16,8 +15,6 @@
#include "base/metrics/histogram_macros.h" #include "base/metrics/histogram_macros.h"
#include "base/single_thread_task_runner.h" #include "base/single_thread_task_runner.h"
#include "base/strings/string_util.h" #include "base/strings/string_util.h"
#include "base/task/task_traits.h"
#include "base/task/thread_pool.h"
#include "base/threading/thread_task_runner_handle.h" #include "base/threading/thread_task_runner_handle.h"
#include "cc/paint/skia_paint_canvas.h" #include "cc/paint/skia_paint_canvas.h"
#include "chrome/common/url_constants.h" #include "chrome/common/url_constants.h"
...@@ -29,7 +26,6 @@ ...@@ -29,7 +26,6 @@
#include "components/paint_preview/common/paint_preview_tracker.h" #include "components/paint_preview/common/paint_preview_tracker.h"
#include "components/safe_browsing/core/proto/csd.pb.h" #include "components/safe_browsing/core/proto/csd.pb.h"
#include "content/public/renderer/render_frame.h" #include "content/public/renderer/render_frame.h"
#include "content/public/renderer/render_thread.h"
#include "crypto/sha2.h" #include "crypto/sha2.h"
#include "third_party/blink/public/platform/web_url.h" #include "third_party/blink/public/platform/web_url.h"
#include "third_party/blink/public/platform/web_url_request.h" #include "third_party/blink/public/platform/web_url_request.h"
...@@ -169,7 +165,6 @@ void PhishingClassifier::TermExtractionFinished(bool success) { ...@@ -169,7 +165,6 @@ void PhishingClassifier::TermExtractionFinished(bool success) {
} }
void PhishingClassifier::ExtractVisualFeatures() { void PhishingClassifier::ExtractVisualFeatures() {
DCHECK(content::RenderThread::IsMainThread());
base::TimeTicks start_time = base::TimeTicks::Now(); base::TimeTicks start_time = base::TimeTicks::Now();
blink::WebLocalFrame* frame = render_frame_->GetWebFrame(); blink::WebLocalFrame* frame = render_frame_->GetWebFrame();
...@@ -198,7 +193,6 @@ void PhishingClassifier::ExtractVisualFeatures() { ...@@ -198,7 +193,6 @@ void PhishingClassifier::ExtractVisualFeatures() {
} }
void PhishingClassifier::VisualExtractionFinished(bool success) { void PhishingClassifier::VisualExtractionFinished(bool success) {
DCHECK(content::RenderThread::IsMainThread());
if (!success) { if (!success) {
RunFailureCallback(); RunFailureCallback();
return; return;
...@@ -209,48 +203,32 @@ void PhishingClassifier::VisualExtractionFinished(bool success) { ...@@ -209,48 +203,32 @@ void PhishingClassifier::VisualExtractionFinished(bool success) {
// Hash all of the features so that they match the model, then compute // Hash all of the features so that they match the model, then compute
// the score. // the score.
FeatureMap hashed_features; FeatureMap hashed_features;
std::unique_ptr<ClientPhishingRequest> verdict = ClientPhishingRequest verdict;
std::make_unique<ClientPhishingRequest>(); verdict.set_model_version(scorer_->model_version());
verdict->set_model_version(scorer_->model_version()); verdict.set_url(main_frame->GetDocument().Url().GetString().Utf8());
verdict->set_url(main_frame->GetDocument().Url().GetString().Utf8());
for (const auto& it : features_->features()) { for (const auto& it : features_->features()) {
bool result = hashed_features.AddRealFeature( bool result = hashed_features.AddRealFeature(
crypto::SHA256HashString(it.first), it.second); crypto::SHA256HashString(it.first), it.second);
DCHECK(result); DCHECK(result);
ClientPhishingRequest::Feature* feature = verdict->add_feature_map(); ClientPhishingRequest::Feature* feature = verdict.add_feature_map();
feature->set_name(it.first); feature->set_name(it.first);
feature->set_value(it.second); feature->set_value(it.second);
} }
for (const auto& it : *shingle_hashes_) { for (const auto& it : *shingle_hashes_) {
verdict->add_shingle_hashes(it); verdict.add_shingle_hashes(it);
} }
float score = static_cast<float>(scorer_->ComputeScore(hashed_features)); float score = static_cast<float>(scorer_->ComputeScore(hashed_features));
verdict->set_client_score(score); verdict.set_client_score(score);
verdict->set_is_phishing(score >= scorer_->threshold_probability()); verdict.set_is_phishing(score >= scorer_->threshold_probability());
visual_matching_start_ = base::TimeTicks::Now();
// Perform scoring off the UI thread to avoid blocking.
base::ThreadPool::PostTaskAndReplyWithResult(
FROM_HERE, {base::WithBaseSyncPrimitives()},
base::BindOnce(&Scorer::GetMatchingVisualTargets,
// base::Unretained is safe because the classification will
// be cancelled before the scorer is removed.
base::Unretained(scorer_), *bitmap_, std::move(verdict)),
base::BindOnce(&PhishingClassifier::OnVisualTargetsMatched,
weak_factory_.GetWeakPtr()));
}
void PhishingClassifier::OnVisualTargetsMatched( base::TimeTicks visual_matching_start = base::TimeTicks::Now();
std::unique_ptr<ClientPhishingRequest> verdict) { if (scorer_->GetMatchingVisualTargets(*bitmap_, &verdict)) {
DCHECK(content::RenderThread::IsMainThread()); verdict.set_is_phishing(true);
if (!verdict->vision_match().empty()) {
verdict->set_is_phishing(true);
} }
base::UmaHistogramTimes("SBClientPhishing.VisualComparisonTime", base::UmaHistogramTimes("SBClientPhishing.VisualComparisonTime",
base::TimeTicks::Now() - visual_matching_start_); base::TimeTicks::Now() - visual_matching_start);
RunCallback(*verdict); RunCallback(verdict);
} }
void PhishingClassifier::RunCallback(const ClientPhishingRequest& verdict) { void PhishingClassifier::RunCallback(const ClientPhishingRequest& verdict) {
......
...@@ -121,23 +121,18 @@ class PhishingClassifier { ...@@ -121,23 +121,18 @@ class PhishingClassifier {
// non-phishy verdict. // non-phishy verdict.
void VisualExtractionFinished(bool success); void VisualExtractionFinished(bool success);
// Callback when visual features have been scored and compared against the
// model.
void OnVisualTargetsMatched(std::unique_ptr<ClientPhishingRequest> verdict);
// Helper method to run the DoneCallback and clear the state. // Helper method to run the DoneCallback and clear the state.
void RunCallback(const ClientPhishingRequest& verdict); void RunCallback(const ClientPhishingRequest& verdict);
// Helper to run the DoneCallback when feature extraction has failed. // Helper to run the DoneCallback when feature extraction has failed.
// This always signals a non-phishy verdict for the page, with // This always signals a non-phishy verdict for the page, with kInvalidScore.
// |kInvalidScore|.
void RunFailureCallback(); void RunFailureCallback();
// Clears the current state of the PhishingClassifier. // Clears the current state of the PhishingClassifier.
void Clear(); void Clear();
content::RenderFrame* render_frame_; // owns us content::RenderFrame* render_frame_; // owns us
const Scorer* scorer_; // owned by the caller const Scorer* scorer_; // owned by the caller
std::unique_ptr<PhishingUrlFeatureExtractor> url_extractor_; std::unique_ptr<PhishingUrlFeatureExtractor> url_extractor_;
std::unique_ptr<PhishingDOMFeatureExtractor> dom_extractor_; std::unique_ptr<PhishingDOMFeatureExtractor> dom_extractor_;
std::unique_ptr<PhishingTermFeatureExtractor> term_extractor_; std::unique_ptr<PhishingTermFeatureExtractor> term_extractor_;
...@@ -149,9 +144,6 @@ class PhishingClassifier { ...@@ -149,9 +144,6 @@ class PhishingClassifier {
std::unique_ptr<SkBitmap> bitmap_; std::unique_ptr<SkBitmap> bitmap_;
DoneCallback done_callback_; DoneCallback done_callback_;
// Used to record the duration of visual feature scoring.
base::TimeTicks visual_matching_start_;
// Used in scheduling BeginFeatureExtraction tasks. // Used in scheduling BeginFeatureExtraction tasks.
// These pointers are invalidated if classification is cancelled. // These pointers are invalidated if classification is cancelled.
base::WeakPtrFactory<PhishingClassifier> weak_factory_{this}; base::WeakPtrFactory<PhishingClassifier> weak_factory_{this};
......
...@@ -127,7 +127,7 @@ class PhishingClassifierTest : public ChromeRenderViewTest { ...@@ -127,7 +127,7 @@ class PhishingClassifierTest : public ChromeRenderViewTest {
page_text, page_text,
base::BindOnce(&PhishingClassifierTest::ClassificationFinished, base::BindOnce(&PhishingClassifierTest::ClassificationFinished,
base::Unretained(this))); base::Unretained(this)));
run_loop_.Run(); base::RunLoop().RunUntilIdle();
} }
// Completion callback for classification. // Completion callback for classification.
...@@ -141,8 +141,6 @@ class PhishingClassifierTest : public ChromeRenderViewTest { ...@@ -141,8 +141,6 @@ class PhishingClassifierTest : public ChromeRenderViewTest {
screenshot_digest_ = verdict.screenshot_digest(); screenshot_digest_ = verdict.screenshot_digest();
screenshot_phash_ = verdict.screenshot_phash(); screenshot_phash_ = verdict.screenshot_phash();
phash_dimension_size_ = verdict.phash_dimension_size(); phash_dimension_size_ = verdict.phash_dimension_size();
run_loop_.Quit();
} }
void LoadHtml(const GURL& url, const std::string& content) { void LoadHtml(const GURL& url, const std::string& content) {
...@@ -158,7 +156,6 @@ class PhishingClassifierTest : public ChromeRenderViewTest { ...@@ -158,7 +156,6 @@ class PhishingClassifierTest : public ChromeRenderViewTest {
std::string response_content_; std::string response_content_;
std::unique_ptr<Scorer> scorer_; std::unique_ptr<Scorer> scorer_;
std::unique_ptr<PhishingClassifier> classifier_; std::unique_ptr<PhishingClassifier> classifier_;
base::RunLoop run_loop_;
// Features that are in the model. // Features that are in the model.
const std::string url_tld_token_net_; const std::string url_tld_token_net_;
......
...@@ -86,14 +86,15 @@ double Scorer::ComputeScore(const FeatureMap& features) const { ...@@ -86,14 +86,15 @@ double Scorer::ComputeScore(const FeatureMap& features) const {
return LogOdds2Prob(logodds); return LogOdds2Prob(logodds);
} }
std::unique_ptr<ClientPhishingRequest> Scorer::GetMatchingVisualTargets( bool Scorer::GetMatchingVisualTargets(const SkBitmap& bitmap,
const SkBitmap& bitmap, ClientPhishingRequest* request) const {
std::unique_ptr<ClientPhishingRequest> request) const { bool has_match = false;
for (const VisualTarget& target : model_.vision_model().targets()) { for (const VisualTarget& target : model_.vision_model().targets()) {
base::Optional<VisionMatchResult> result = base::Optional<VisionMatchResult> result =
visual_utils::IsVisualMatch(bitmap, target); visual_utils::IsVisualMatch(bitmap, target);
if (result.has_value()) { if (result.has_value()) {
*request->add_vision_match() = result.value(); *request->add_vision_match() = result.value();
has_match = true;
} }
} }
...@@ -111,7 +112,7 @@ std::unique_ptr<ClientPhishingRequest> Scorer::GetMatchingVisualTargets( ...@@ -111,7 +112,7 @@ std::unique_ptr<ClientPhishingRequest> Scorer::GetMatchingVisualTargets(
} }
} }
return request; return has_match;
} }
int Scorer::model_version() const { int Scorer::model_version() const {
......
...@@ -42,11 +42,10 @@ class Scorer { ...@@ -42,11 +42,10 @@ class Scorer {
// (range is inclusive on both ends). // (range is inclusive on both ends).
virtual double ComputeScore(const FeatureMap& features) const; virtual double ComputeScore(const FeatureMap& features) const;
// This method matches the given |bitmap| against the visual model. It // This method matches the given |bitmap| against the visual model. It returns
// modifies |request| appropriately, and returns the new request. // true if any visual target matches, and populates |request| appropriately.
virtual std::unique_ptr<ClientPhishingRequest> GetMatchingVisualTargets( virtual bool GetMatchingVisualTargets(const SkBitmap& bitmap,
const SkBitmap& bitmap, ClientPhishingRequest* request) const;
std::unique_ptr<ClientPhishingRequest> request) const;
// Returns the version number of the loaded client model. // Returns the version number of the loaded client model.
int model_version() const; int model_version() const;
......
...@@ -189,11 +189,10 @@ TEST_F(PhishingScorerTest, GetMatchingVisualTargetsMatchOne) { ...@@ -189,11 +189,10 @@ TEST_F(PhishingScorerTest, GetMatchingVisualTargetsMatchOne) {
for (int x = 0; x < 164; x++) for (int x = 0; x < 164; x++)
*bitmap_.getAddr32(x, 0) = 0xff000000; *bitmap_.getAddr32(x, 0) = 0xff000000;
std::unique_ptr<ClientPhishingRequest> request = ClientPhishingRequest request;
std::make_unique<ClientPhishingRequest>(); scorer->GetMatchingVisualTargets(bitmap_, &request);
request = scorer->GetMatchingVisualTargets(bitmap_, std::move(request)); ASSERT_EQ(request.vision_match_size(), 1);
ASSERT_EQ(request->vision_match_size(), 1); EXPECT_EQ(request.vision_match(0).matched_target_digest(), "target1");
EXPECT_EQ(request->vision_match(0).matched_target_digest(), "target1");
} }
TEST_F(PhishingScorerTest, GetMatchingVisualTargetsMatchBoth) { TEST_F(PhishingScorerTest, GetMatchingVisualTargetsMatchBoth) {
...@@ -213,12 +212,11 @@ TEST_F(PhishingScorerTest, GetMatchingVisualTargetsMatchBoth) { ...@@ -213,12 +212,11 @@ TEST_F(PhishingScorerTest, GetMatchingVisualTargetsMatchBoth) {
for (int x = 168; x < 248; x++) for (int x = 168; x < 248; x++)
*bitmap_.getAddr32(x, 0) = 0xff000000; *bitmap_.getAddr32(x, 0) = 0xff000000;
std::unique_ptr<ClientPhishingRequest> request = ClientPhishingRequest request;
std::make_unique<ClientPhishingRequest>(); scorer->GetMatchingVisualTargets(bitmap_, &request);
request = scorer->GetMatchingVisualTargets(bitmap_, std::move(request)); ASSERT_EQ(request.vision_match_size(), 2);
ASSERT_EQ(request->vision_match_size(), 2); EXPECT_EQ(request.vision_match(0).matched_target_digest(), "target1");
EXPECT_EQ(request->vision_match(0).matched_target_digest(), "target1"); EXPECT_EQ(request.vision_match(1).matched_target_digest(), "target2");
EXPECT_EQ(request->vision_match(1).matched_target_digest(), "target2");
} }
} // namespace safe_browsing } // namespace safe_browsing
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment