Commit 3125314c authored by Daniel Rubery's avatar Daniel Rubery Committed by Commit Bot

Reland "Move visual feature scoring to helper thread"

This is a reland of a4418459. The
original CL was causing crashes due to UaF on the Scorer object.
This CL fixes that issue by performing the thread hop within Scorer.
Since the returning callback on the PhishingClassifier is guarded by
a WeakPtr, this should be more safe.

Original change's description:
> Move visual feature scoring to helper thread
>
> This CL calls the GetMatchingVisualTargets in a worker thread, since
> the process of scoring visual features takes longer than expected
> and was blocking interaction.
>
> Fixed: 1121375
> Change-Id: I8f3e33d20f812645a72adfc16fe33241b49a6ff4
> Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2372869
> Commit-Queue: Daniel Rubery <drubery@chromium.org>
> Reviewed-by: Bettina Dea <bdea@chromium.org>
> Reviewed-by: Varun Khaneja <vakh@chromium.org>
> Cr-Commit-Position: refs/heads/master@{#801907}

BUG=1121375,1122534

Change-Id: Ifedb94ad8b613a2ba7377bb5bc7795e657a94512
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2386380Reviewed-by: default avatarVarun Khaneja <vakh@chromium.org>
Reviewed-by: default avatarBettina Dea <bdea@chromium.org>
Commit-Queue: Daniel Rubery <drubery@chromium.org>
Cr-Commit-Position: refs/heads/master@{#805408}
parent 091c102c
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include "chrome/renderer/safe_browsing/phishing_classifier.h" #include "chrome/renderer/safe_browsing/phishing_classifier.h"
#include <memory>
#include <string> #include <string>
#include <utility> #include <utility>
...@@ -15,6 +16,8 @@ ...@@ -15,6 +16,8 @@
#include "base/metrics/histogram_macros.h" #include "base/metrics/histogram_macros.h"
#include "base/single_thread_task_runner.h" #include "base/single_thread_task_runner.h"
#include "base/strings/string_util.h" #include "base/strings/string_util.h"
#include "base/task/task_traits.h"
#include "base/task/thread_pool.h"
#include "base/threading/thread_task_runner_handle.h" #include "base/threading/thread_task_runner_handle.h"
#include "cc/paint/skia_paint_canvas.h" #include "cc/paint/skia_paint_canvas.h"
#include "chrome/common/url_constants.h" #include "chrome/common/url_constants.h"
...@@ -26,6 +29,7 @@ ...@@ -26,6 +29,7 @@
#include "components/paint_preview/common/paint_preview_tracker.h" #include "components/paint_preview/common/paint_preview_tracker.h"
#include "components/safe_browsing/core/proto/csd.pb.h" #include "components/safe_browsing/core/proto/csd.pb.h"
#include "content/public/renderer/render_frame.h" #include "content/public/renderer/render_frame.h"
#include "content/public/renderer/render_thread.h"
#include "crypto/sha2.h" #include "crypto/sha2.h"
#include "third_party/blink/public/platform/web_url.h" #include "third_party/blink/public/platform/web_url.h"
#include "third_party/blink/public/platform/web_url_request.h" #include "third_party/blink/public/platform/web_url_request.h"
...@@ -165,6 +169,7 @@ void PhishingClassifier::TermExtractionFinished(bool success) { ...@@ -165,6 +169,7 @@ void PhishingClassifier::TermExtractionFinished(bool success) {
} }
void PhishingClassifier::ExtractVisualFeatures() { void PhishingClassifier::ExtractVisualFeatures() {
DCHECK(content::RenderThread::IsMainThread());
base::TimeTicks start_time = base::TimeTicks::Now(); base::TimeTicks start_time = base::TimeTicks::Now();
blink::WebLocalFrame* frame = render_frame_->GetWebFrame(); blink::WebLocalFrame* frame = render_frame_->GetWebFrame();
...@@ -193,6 +198,7 @@ void PhishingClassifier::ExtractVisualFeatures() { ...@@ -193,6 +198,7 @@ void PhishingClassifier::ExtractVisualFeatures() {
} }
void PhishingClassifier::VisualExtractionFinished(bool success) { void PhishingClassifier::VisualExtractionFinished(bool success) {
DCHECK(content::RenderThread::IsMainThread());
if (!success) { if (!success) {
RunFailureCallback(); RunFailureCallback();
return; return;
...@@ -203,32 +209,43 @@ void PhishingClassifier::VisualExtractionFinished(bool success) { ...@@ -203,32 +209,43 @@ void PhishingClassifier::VisualExtractionFinished(bool success) {
// Hash all of the features so that they match the model, then compute // Hash all of the features so that they match the model, then compute
// the score. // the score.
FeatureMap hashed_features; FeatureMap hashed_features;
ClientPhishingRequest verdict; std::unique_ptr<ClientPhishingRequest> verdict =
verdict.set_model_version(scorer_->model_version()); std::make_unique<ClientPhishingRequest>();
verdict.set_url(main_frame->GetDocument().Url().GetString().Utf8()); verdict->set_model_version(scorer_->model_version());
verdict->set_url(main_frame->GetDocument().Url().GetString().Utf8());
for (const auto& it : features_->features()) { for (const auto& it : features_->features()) {
bool result = hashed_features.AddRealFeature( bool result = hashed_features.AddRealFeature(
crypto::SHA256HashString(it.first), it.second); crypto::SHA256HashString(it.first), it.second);
DCHECK(result); DCHECK(result);
ClientPhishingRequest::Feature* feature = verdict.add_feature_map(); ClientPhishingRequest::Feature* feature = verdict->add_feature_map();
feature->set_name(it.first); feature->set_name(it.first);
feature->set_value(it.second); feature->set_value(it.second);
} }
for (const auto& it : *shingle_hashes_) { for (const auto& it : *shingle_hashes_) {
verdict.add_shingle_hashes(it); verdict->add_shingle_hashes(it);
} }
float score = static_cast<float>(scorer_->ComputeScore(hashed_features)); float score = static_cast<float>(scorer_->ComputeScore(hashed_features));
verdict.set_client_score(score); verdict->set_client_score(score);
verdict.set_is_phishing(score >= scorer_->threshold_probability()); verdict->set_is_phishing(score >= scorer_->threshold_probability());
visual_matching_start_ = base::TimeTicks::Now();
scorer_->GetMatchingVisualTargets(
*bitmap_, std::move(verdict),
base::BindOnce(&PhishingClassifier::OnVisualTargetsMatched,
weak_factory_.GetWeakPtr()));
}
base::TimeTicks visual_matching_start = base::TimeTicks::Now(); void PhishingClassifier::OnVisualTargetsMatched(
if (scorer_->GetMatchingVisualTargets(*bitmap_, &verdict)) { std::unique_ptr<ClientPhishingRequest> verdict) {
verdict.set_is_phishing(true); DCHECK(content::RenderThread::IsMainThread());
if (!verdict->vision_match().empty()) {
verdict->set_is_phishing(true);
} }
base::UmaHistogramTimes("SBClientPhishing.VisualComparisonTime", base::UmaHistogramTimes("SBClientPhishing.VisualComparisonTime",
base::TimeTicks::Now() - visual_matching_start); base::TimeTicks::Now() - visual_matching_start_);
RunCallback(verdict); RunCallback(*verdict);
} }
void PhishingClassifier::RunCallback(const ClientPhishingRequest& verdict) { void PhishingClassifier::RunCallback(const ClientPhishingRequest& verdict) {
......
...@@ -121,11 +121,16 @@ class PhishingClassifier { ...@@ -121,11 +121,16 @@ class PhishingClassifier {
// non-phishy verdict. // non-phishy verdict.
void VisualExtractionFinished(bool success); void VisualExtractionFinished(bool success);
// Callback when visual features have been scored and compared against the
// model.
void OnVisualTargetsMatched(std::unique_ptr<ClientPhishingRequest> verdict);
// Helper method to run the DoneCallback and clear the state. // Helper method to run the DoneCallback and clear the state.
void RunCallback(const ClientPhishingRequest& verdict); void RunCallback(const ClientPhishingRequest& verdict);
// Helper to run the DoneCallback when feature extraction has failed. // Helper to run the DoneCallback when feature extraction has failed.
// This always signals a non-phishy verdict for the page, with kInvalidScore. // This always signals a non-phishy verdict for the page, with
// |kInvalidScore|.
void RunFailureCallback(); void RunFailureCallback();
// Clears the current state of the PhishingClassifier. // Clears the current state of the PhishingClassifier.
...@@ -144,6 +149,9 @@ class PhishingClassifier { ...@@ -144,6 +149,9 @@ class PhishingClassifier {
std::unique_ptr<SkBitmap> bitmap_; std::unique_ptr<SkBitmap> bitmap_;
DoneCallback done_callback_; DoneCallback done_callback_;
// Used to record the duration of visual feature scoring.
base::TimeTicks visual_matching_start_;
// Used in scheduling BeginFeatureExtraction tasks. // Used in scheduling BeginFeatureExtraction tasks.
// These pointers are invalidated if classification is cancelled. // These pointers are invalidated if classification is cancelled.
base::WeakPtrFactory<PhishingClassifier> weak_factory_{this}; base::WeakPtrFactory<PhishingClassifier> weak_factory_{this};
......
...@@ -127,7 +127,7 @@ class PhishingClassifierTest : public ChromeRenderViewTest { ...@@ -127,7 +127,7 @@ class PhishingClassifierTest : public ChromeRenderViewTest {
page_text, page_text,
base::BindOnce(&PhishingClassifierTest::ClassificationFinished, base::BindOnce(&PhishingClassifierTest::ClassificationFinished,
base::Unretained(this))); base::Unretained(this)));
base::RunLoop().RunUntilIdle(); run_loop_.Run();
} }
// Completion callback for classification. // Completion callback for classification.
...@@ -141,6 +141,8 @@ class PhishingClassifierTest : public ChromeRenderViewTest { ...@@ -141,6 +141,8 @@ class PhishingClassifierTest : public ChromeRenderViewTest {
screenshot_digest_ = verdict.screenshot_digest(); screenshot_digest_ = verdict.screenshot_digest();
screenshot_phash_ = verdict.screenshot_phash(); screenshot_phash_ = verdict.screenshot_phash();
phash_dimension_size_ = verdict.phash_dimension_size(); phash_dimension_size_ = verdict.phash_dimension_size();
run_loop_.Quit();
} }
void LoadHtml(const GURL& url, const std::string& content) { void LoadHtml(const GURL& url, const std::string& content) {
...@@ -156,6 +158,7 @@ class PhishingClassifierTest : public ChromeRenderViewTest { ...@@ -156,6 +158,7 @@ class PhishingClassifierTest : public ChromeRenderViewTest {
std::string response_content_; std::string response_content_;
std::unique_ptr<Scorer> scorer_; std::unique_ptr<Scorer> scorer_;
std::unique_ptr<PhishingClassifier> classifier_; std::unique_ptr<PhishingClassifier> classifier_;
base::RunLoop run_loop_;
// Features that are in the model. // Features that are in the model.
const std::string url_tld_token_net_; const std::string url_tld_token_net_;
......
...@@ -13,10 +13,17 @@ ...@@ -13,10 +13,17 @@
#include "base/metrics/histogram_macros.h" #include "base/metrics/histogram_macros.h"
#include "base/strings/string_number_conversions.h" #include "base/strings/string_number_conversions.h"
#include "base/strings/string_piece.h" #include "base/strings/string_piece.h"
#include "base/task/task_traits.h"
#include "base/task/thread_pool.h"
#include "chrome/renderer/safe_browsing/features.h" #include "chrome/renderer/safe_browsing/features.h"
#include "components/safe_browsing/content/password_protection/visual_utils.h" #include "components/safe_browsing/content/password_protection/visual_utils.h"
#include "components/safe_browsing/core/proto/client_model.pb.h" #include "components/safe_browsing/core/proto/client_model.pb.h"
#include "components/safe_browsing/core/proto/csd.pb.h"
#include "content/public/renderer/render_thread.h"
#include "crypto/sha2.h" #include "crypto/sha2.h"
#include "third_party/skia/include/core/SkBitmap.h"
namespace safe_browsing {
namespace { namespace {
// Enum used to keep stats about the status of the Scorer creation. // Enum used to keep stats about the status of the Scorer creation.
...@@ -35,9 +42,38 @@ void RecordScorerCreationStatus(ScorerCreationStatus status) { ...@@ -35,9 +42,38 @@ void RecordScorerCreationStatus(ScorerCreationStatus status) {
status, status,
SCORER_STATUS_MAX); SCORER_STATUS_MAX);
} }
} // namespace
namespace safe_browsing { std::unique_ptr<ClientPhishingRequest> GetMatchingVisualTargetsHelper(
const SkBitmap& bitmap,
const ClientSideModel& model,
std::unique_ptr<ClientPhishingRequest> request) {
DCHECK(!content::RenderThread::IsMainThread());
for (const VisualTarget& target : model.vision_model().targets()) {
base::Optional<VisionMatchResult> result =
visual_utils::IsVisualMatch(bitmap, target);
if (result.has_value()) {
*request->add_vision_match() = result.value();
}
}
if (model.has_vision_model()) {
// Populate these fields for telementry purposes. They will be filtered in
// the browser process if they are not needed.
VisualFeatures::BlurredImage blurred_image;
if (visual_utils::GetBlurredImage(bitmap, &blurred_image)) {
std::string raw_digest = crypto::SHA256HashString(blurred_image.data());
request->set_screenshot_digest(
base::HexEncode(raw_digest.data(), raw_digest.size()));
request->set_screenshot_phash(
visual_utils::GetHashFromBlurredImage(blurred_image));
request->set_phash_dimension_size(48);
}
}
return request;
}
} // namespace
// Helper function which converts log odds to a probability in the range // Helper function which converts log odds to a probability in the range
// [0.0,1.0]. // [0.0,1.0].
...@@ -86,33 +122,19 @@ double Scorer::ComputeScore(const FeatureMap& features) const { ...@@ -86,33 +122,19 @@ double Scorer::ComputeScore(const FeatureMap& features) const {
return LogOdds2Prob(logodds); return LogOdds2Prob(logodds);
} }
bool Scorer::GetMatchingVisualTargets(const SkBitmap& bitmap, void Scorer::GetMatchingVisualTargets(
ClientPhishingRequest* request) const { const SkBitmap& bitmap,
bool has_match = false; std::unique_ptr<ClientPhishingRequest> request,
for (const VisualTarget& target : model_.vision_model().targets()) { base::OnceCallback<void(std::unique_ptr<ClientPhishingRequest>)> callback)
base::Optional<VisionMatchResult> result = const {
visual_utils::IsVisualMatch(bitmap, target); DCHECK(content::RenderThread::IsMainThread());
if (result.has_value()) {
*request->add_vision_match() = result.value(); // Perform scoring off the main thread to avoid blocking.
has_match = true; base::ThreadPool::PostTaskAndReplyWithResult(
} FROM_HERE, {base::WithBaseSyncPrimitives()},
} base::BindOnce(&GetMatchingVisualTargetsHelper, bitmap, model_,
std::move(request)),
if (model_.has_vision_model()) { std::move(callback));
// Populate these fields for telementry purposes. They will be filtered in
// the browser process if they are not needed.
VisualFeatures::BlurredImage blurred_image;
if (visual_utils::GetBlurredImage(bitmap, &blurred_image)) {
std::string raw_digest = crypto::SHA256HashString(blurred_image.data());
request->set_screenshot_digest(
base::HexEncode(raw_digest.data(), raw_digest.size()));
request->set_screenshot_phash(
visual_utils::GetHashFromBlurredImage(blurred_image));
request->set_phash_dimension_size(48);
}
}
return has_match;
} }
int Scorer::model_version() const { int Scorer::model_version() const {
......
...@@ -20,9 +20,12 @@ ...@@ -20,9 +20,12 @@
#include <string> #include <string>
#include <unordered_set> #include <unordered_set>
#include "base/callback.h"
#include "base/macros.h" #include "base/macros.h"
#include "base/memory/weak_ptr.h"
#include "base/strings/string_piece.h" #include "base/strings/string_piece.h"
#include "components/safe_browsing/core/proto/client_model.pb.h" #include "components/safe_browsing/core/proto/client_model.pb.h"
#include "components/safe_browsing/core/proto/csd.pb.h"
#include "third_party/skia/include/core/SkBitmap.h" #include "third_party/skia/include/core/SkBitmap.h"
namespace safe_browsing { namespace safe_browsing {
...@@ -42,10 +45,15 @@ class Scorer { ...@@ -42,10 +45,15 @@ class Scorer {
// (range is inclusive on both ends). // (range is inclusive on both ends).
virtual double ComputeScore(const FeatureMap& features) const; virtual double ComputeScore(const FeatureMap& features) const;
// This method matches the given |bitmap| against the visual model. It returns // This method matches the given |bitmap| against the visual model. It
// true if any visual target matches, and populates |request| appropriately. // modifies |request| appropriately, and returns the new request. This expects
virtual bool GetMatchingVisualTargets(const SkBitmap& bitmap, // to be called on the renderer main thread, but will perform scoring
ClientPhishingRequest* request) const; // asynchronously on a worker thread.
virtual void GetMatchingVisualTargets(
const SkBitmap& bitmap,
std::unique_ptr<ClientPhishingRequest> request,
base::OnceCallback<void(std::unique_ptr<ClientPhishingRequest>)> callback)
const;
// Returns the version number of the loaded client model. // Returns the version number of the loaded client model.
int model_version() const; int model_version() const;
...@@ -95,6 +103,8 @@ class Scorer { ...@@ -95,6 +103,8 @@ class Scorer {
std::unordered_set<std::string> page_terms_; std::unordered_set<std::string> page_terms_;
std::unordered_set<uint32_t> page_words_; std::unordered_set<uint32_t> page_words_;
base::WeakPtrFactory<Scorer> weak_ptr_factory_{this};
DISALLOW_COPY_AND_ASSIGN(Scorer); DISALLOW_COPY_AND_ASSIGN(Scorer);
}; };
} // namespace safe_browsing } // namespace safe_browsing
......
...@@ -12,9 +12,13 @@ ...@@ -12,9 +12,13 @@
#include "base/files/file_path.h" #include "base/files/file_path.h"
#include "base/files/scoped_temp_dir.h" #include "base/files/scoped_temp_dir.h"
#include "base/format_macros.h" #include "base/format_macros.h"
#include "base/run_loop.h"
#include "base/test/bind_test_util.h"
#include "base/test/task_environment.h"
#include "base/threading/thread.h" #include "base/threading/thread.h"
#include "chrome/renderer/safe_browsing/features.h" #include "chrome/renderer/safe_browsing/features.h"
#include "components/safe_browsing/core/proto/client_model.pb.h" #include "components/safe_browsing/core/proto/client_model.pb.h"
#include "components/safe_browsing/core/proto/csd.pb.h"
#include "testing/gmock/include/gmock/gmock.h" #include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h" #include "testing/gtest/include/gtest/gtest.h"
...@@ -189,10 +193,20 @@ TEST_F(PhishingScorerTest, GetMatchingVisualTargetsMatchOne) { ...@@ -189,10 +193,20 @@ TEST_F(PhishingScorerTest, GetMatchingVisualTargetsMatchOne) {
for (int x = 0; x < 164; x++) for (int x = 0; x < 164; x++)
*bitmap_.getAddr32(x, 0) = 0xff000000; *bitmap_.getAddr32(x, 0) = 0xff000000;
ClientPhishingRequest request; base::test::TaskEnvironment task_environment;
scorer->GetMatchingVisualTargets(bitmap_, &request); base::RunLoop run_loop;
ASSERT_EQ(request.vision_match_size(), 1); std::unique_ptr<ClientPhishingRequest> request =
EXPECT_EQ(request.vision_match(0).matched_target_digest(), "target1"); std::make_unique<ClientPhishingRequest>();
scorer->GetMatchingVisualTargets(
bitmap_, std::move(request),
base::BindLambdaForTesting(
[&](std::unique_ptr<ClientPhishingRequest> request) {
ASSERT_EQ(request->vision_match_size(), 1);
EXPECT_EQ(request->vision_match(0).matched_target_digest(),
"target1");
run_loop.Quit();
}));
run_loop.Run();
} }
TEST_F(PhishingScorerTest, GetMatchingVisualTargetsMatchBoth) { TEST_F(PhishingScorerTest, GetMatchingVisualTargetsMatchBoth) {
...@@ -212,11 +226,22 @@ TEST_F(PhishingScorerTest, GetMatchingVisualTargetsMatchBoth) { ...@@ -212,11 +226,22 @@ TEST_F(PhishingScorerTest, GetMatchingVisualTargetsMatchBoth) {
for (int x = 168; x < 248; x++) for (int x = 168; x < 248; x++)
*bitmap_.getAddr32(x, 0) = 0xff000000; *bitmap_.getAddr32(x, 0) = 0xff000000;
ClientPhishingRequest request; base::test::TaskEnvironment task_environment;
scorer->GetMatchingVisualTargets(bitmap_, &request); base::RunLoop run_loop;
ASSERT_EQ(request.vision_match_size(), 2); std::unique_ptr<ClientPhishingRequest> request =
EXPECT_EQ(request.vision_match(0).matched_target_digest(), "target1"); std::make_unique<ClientPhishingRequest>();
EXPECT_EQ(request.vision_match(1).matched_target_digest(), "target2"); scorer->GetMatchingVisualTargets(
bitmap_, std::move(request),
base::BindLambdaForTesting(
[&](std::unique_ptr<ClientPhishingRequest> request) {
ASSERT_EQ(request->vision_match_size(), 2);
EXPECT_EQ(request->vision_match(0).matched_target_digest(),
"target1");
EXPECT_EQ(request->vision_match(1).matched_target_digest(),
"target2");
run_loop.Quit();
}));
run_loop.Run();
} }
} // namespace safe_browsing } // namespace safe_browsing
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment