Commit edfd43fe authored by Michael Giuffrida's avatar Michael Giuffrida Committed by Chromium LUCI CQ

Fix URLChecker DCHECK/segfaults

Fixes use-after-free issues in URLChecker and
KidsManagementURLCheckerClient.

Asynchronous callbacks passed to unowned objects should be bound with
weak pointers, so if `this` is destroyed, the callback won't fire.

Test:
* As a parent, set a child's Safe Site setting to "Allow all sites"
* Add the child to a Chromebook
This CL fixes DCHECKs and (in prod) potential use-after-free segfaults
that could otherwise occur when running the steps above.

Bug: 1159898
Change-Id: I67ba8f924ded3b6316813529e9b77a3d3735b483
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2597547Reviewed-by: default avatarAga Wronska <agawronska@chromium.org>
Reviewed-by: default avatarToby Huang <tobyhuang@chromium.org>
Commit-Queue: Aga Wronska <agawronska@chromium.org>
Cr-Commit-Position: refs/heads/master@{#843171}
parent af4d8a1c
...@@ -51,7 +51,7 @@ void KidsManagementURLCheckerClient::CheckURL(const GURL& url, ...@@ -51,7 +51,7 @@ void KidsManagementURLCheckerClient::CheckURL(const GURL& url,
kids_chrome_management_client->ClassifyURL( kids_chrome_management_client->ClassifyURL(
std::move(classify_url_request), std::move(classify_url_request),
base::BindOnce(&KidsManagementURLCheckerClient::ConvertResponseCallback, base::BindOnce(&KidsManagementURLCheckerClient::ConvertResponseCallback,
base::Unretained(this), url, std::move(callback))); weak_factory_.GetWeakPtr(), url, std::move(callback)));
} }
void KidsManagementURLCheckerClient::ConvertResponseCallback( void KidsManagementURLCheckerClient::ConvertResponseCallback(
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <string> #include <string>
#include "base/macros.h" #include "base/macros.h"
#include "base/memory/weak_ptr.h"
#include "chrome/browser/supervised_user/kids_chrome_management/kids_chrome_management_client.h" #include "chrome/browser/supervised_user/kids_chrome_management/kids_chrome_management_client.h"
#include "components/safe_search_api/url_checker_client.h" #include "components/safe_search_api/url_checker_client.h"
...@@ -44,6 +45,8 @@ class KidsManagementURLCheckerClient ...@@ -44,6 +45,8 @@ class KidsManagementURLCheckerClient
const std::string country_; const std::string country_;
base::WeakPtrFactory<KidsManagementURLCheckerClient> weak_factory_{this};
DISALLOW_COPY_AND_ASSIGN(KidsManagementURLCheckerClient); DISALLOW_COPY_AND_ASSIGN(KidsManagementURLCheckerClient);
}; };
......
...@@ -30,6 +30,8 @@ ...@@ -30,6 +30,8 @@
#include "components/user_manager/scoped_user_manager.h" #include "components/user_manager/scoped_user_manager.h"
#endif #endif
using testing::_;
namespace { namespace {
using kids_chrome_management::ClassifyUrlResponse; using kids_chrome_management::ClassifyUrlResponse;
...@@ -68,7 +70,9 @@ class KidsChromeManagementClientForTesting : public KidsChromeManagementClient { ...@@ -68,7 +70,9 @@ class KidsChromeManagementClientForTesting : public KidsChromeManagementClient {
std::unique_ptr<kids_chrome_management::ClassifyUrlRequest> request_proto, std::unique_ptr<kids_chrome_management::ClassifyUrlRequest> request_proto,
KidsChromeManagementClient::KidsChromeManagementCallback callback) KidsChromeManagementClient::KidsChromeManagementCallback callback)
override { override {
std::move(callback).Run(std::move(response_proto_), error_code_); base::ThreadTaskRunnerHandle::Get()->PostTask(
FROM_HERE, base::BindOnce(std::move(callback),
std::move(response_proto_), error_code_));
} }
void SetupResponse(std::unique_ptr<ClassifyUrlResponse> response_proto, void SetupResponse(std::unique_ptr<ClassifyUrlResponse> response_proto,
...@@ -134,12 +138,15 @@ class KidsManagementURLCheckerClientTest : public testing::Test { ...@@ -134,12 +138,15 @@ class KidsManagementURLCheckerClientTest : public testing::Test {
->SetupResponse(std::move(response_proto), error_code); ->SetupResponse(std::move(response_proto), error_code);
} }
// Asynchronously checks the URL and waits until finished.
void CheckURL(const GURL& url) { void CheckURL(const GURL& url) {
url_classifier_->CheckURL( StartCheckURL(url);
url, base::BindOnce(&KidsManagementURLCheckerClientTest::OnCheckDone, task_environment_.RunUntilIdle();
base::Unretained(this)));
} }
// Starts a URL check, but doesn't wait for ClassifyURL() to finish.
void CheckURLWithoutResponse(const GURL& url) { StartCheckURL(url); }
MOCK_METHOD2(OnCheckDone, MOCK_METHOD2(OnCheckDone,
void(const GURL& url, void(const GURL& url,
safe_search_api::ClientClassification classification)); safe_search_api::ClientClassification classification));
...@@ -154,6 +161,12 @@ class KidsManagementURLCheckerClientTest : public testing::Test { ...@@ -154,6 +161,12 @@ class KidsManagementURLCheckerClientTest : public testing::Test {
#endif #endif
private: private:
void StartCheckURL(const GURL& url) {
url_classifier_->CheckURL(
url, base::BindOnce(&KidsManagementURLCheckerClientTest::OnCheckDone,
base::Unretained(this)));
}
DISALLOW_COPY_AND_ASSIGN(KidsManagementURLCheckerClientTest); DISALLOW_COPY_AND_ASSIGN(KidsManagementURLCheckerClientTest);
}; };
...@@ -240,3 +253,16 @@ TEST_F(KidsManagementURLCheckerClientTest, ServiceError) { ...@@ -240,3 +253,16 @@ TEST_F(KidsManagementURLCheckerClientTest, ServiceError) {
EXPECT_CALL(*this, OnCheckDone(url, classification)); EXPECT_CALL(*this, OnCheckDone(url, classification));
CheckURL(url); CheckURL(url);
} }
TEST_F(KidsManagementURLCheckerClientTest, DestroyClientBeforeCallback) {
GURL url("http://randomurl7.com");
EXPECT_CALL(*this, OnCheckDone(_, _)).Times(0);
CheckURLWithoutResponse(url);
// Destroy the URLCheckerClient.
url_classifier_.reset();
// Now run the callback.
task_environment_.RunUntilIdle();
}
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <utility> #include <utility>
#include "base/callback.h" #include "base/callback.h"
#include "base/threading/sequenced_task_runner_handle.h"
namespace safe_search_api { namespace safe_search_api {
...@@ -25,4 +26,10 @@ void FakeURLCheckerClient::RunCallback(ClientClassification classification) { ...@@ -25,4 +26,10 @@ void FakeURLCheckerClient::RunCallback(ClientClassification classification) {
std::move(callback_).Run(url_, classification); std::move(callback_).Run(url_, classification);
} }
void FakeURLCheckerClient::RunCallbackAsync(
ClientClassification classification) {
base::SequencedTaskRunnerHandle::Get()->PostTask(
FROM_HERE, base::BindOnce(std::move(callback_), url_, classification));
}
} // namespace safe_search_api } // namespace safe_search_api
...@@ -23,10 +23,13 @@ class FakeURLCheckerClient : public URLCheckerClient { ...@@ -23,10 +23,13 @@ class FakeURLCheckerClient : public URLCheckerClient {
// See RunCallback() method documentation below on how to run the callback. // See RunCallback() method documentation below on how to run the callback.
void CheckURL(const GURL& url, ClientCheckCallback callback) override; void CheckURL(const GURL& url, ClientCheckCallback callback) override;
// Run the callback function input by the last call of CheckURL() with the // Runs the callback function input by the last call of CheckURL() with the
// result input with the last call of SetResult(). // result input with the last call of SetResult().
void RunCallback(ClientClassification classification); void RunCallback(ClientClassification classification);
// Asynchronous version of RunCallback().
void RunCallbackAsync(ClientClassification classification);
private: private:
ClientCheckCallback callback_; ClientCheckCallback callback_;
GURL url_; GURL url_;
......
...@@ -47,11 +47,7 @@ URLChecker::Check::Check(const GURL& url, CheckCallback callback) : url(url) { ...@@ -47,11 +47,7 @@ URLChecker::Check::Check(const GURL& url, CheckCallback callback) : url(url) {
callbacks.push_back(std::move(callback)); callbacks.push_back(std::move(callback));
} }
URLChecker::Check::~Check() { URLChecker::Check::~Check() = default;
for (const CheckCallback& callback : callbacks) {
DCHECK(!callback);
}
}
URLChecker::CheckResult::CheckResult(Classification classification, URLChecker::CheckResult::CheckResult(Classification classification,
bool uncertain) bool uncertain)
...@@ -118,7 +114,7 @@ bool URLChecker::CheckURL(const GURL& url, CheckCallback callback) { ...@@ -118,7 +114,7 @@ bool URLChecker::CheckURL(const GURL& url, CheckCallback callback) {
std::make_unique<Check>(url, std::move(callback))); std::make_unique<Check>(url, std::move(callback)));
async_checker_->CheckURL(url, async_checker_->CheckURL(url,
base::BindOnce(&URLChecker::OnAsyncCheckComplete, base::BindOnce(&URLChecker::OnAsyncCheckComplete,
base::Unretained(this), it)); weak_factory_.GetWeakPtr(), it));
return false; return false;
} }
......
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include "base/callback_forward.h" #include "base/callback_forward.h"
#include "base/containers/mru_cache.h" #include "base/containers/mru_cache.h"
#include "base/memory/weak_ptr.h"
#include "base/time/time.h" #include "base/time/time.h"
#include "components/safe_search_api/url_checker_client.h" #include "components/safe_search_api/url_checker_client.h"
#include "url/gurl.h" #include "url/gurl.h"
...@@ -71,6 +72,8 @@ class URLChecker { ...@@ -71,6 +72,8 @@ class URLChecker {
base::MRUCache<GURL, CheckResult> cache_; base::MRUCache<GURL, CheckResult> cache_;
base::TimeDelta cache_timeout_; base::TimeDelta cache_timeout_;
base::WeakPtrFactory<URLChecker> weak_factory_{this};
DISALLOW_COPY_AND_ASSIGN(URLChecker); DISALLOW_COPY_AND_ASSIGN(URLChecker);
}; };
......
...@@ -91,6 +91,7 @@ class SafeSearchURLCheckerTest : public testing::Test { ...@@ -91,6 +91,7 @@ class SafeSearchURLCheckerTest : public testing::Test {
size_t next_url_; size_t next_url_;
FakeURLCheckerClient* fake_client_; FakeURLCheckerClient* fake_client_;
std::unique_ptr<URLChecker> checker_; std::unique_ptr<URLChecker> checker_;
base::test::SingleThreadTaskEnvironment task_environment_;
}; };
TEST_F(SafeSearchURLCheckerTest, Simple) { TEST_F(SafeSearchURLCheckerTest, Simple) {
...@@ -196,4 +197,20 @@ TEST_F(SafeSearchURLCheckerTest, NoAllowAllGoogleURLs) { ...@@ -196,4 +197,20 @@ TEST_F(SafeSearchURLCheckerTest, NoAllowAllGoogleURLs) {
} }
} }
TEST_F(SafeSearchURLCheckerTest, DestroyURLCheckerBeforeCallback) {
GURL url(GetNewURL());
EXPECT_CALL(*this, OnCheckDone(_, _, _)).Times(0);
// Start a URL check.
ASSERT_FALSE(CheckURL(url));
fake_client_->RunCallbackAsync(
ToAPIClassification(Classification::SAFE, false));
// Reset the URLChecker before the callback occurs.
checker_.reset();
// The callback should now be invalid.
task_environment_.RunUntilIdle();
}
} // namespace safe_search_api } // namespace safe_search_api
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment