Commit 0bc461bf authored by Daniel Rubery's avatar Daniel Rubery Committed by Commit Bot

Fix lifetime of CheckClientDownloadRequest

CheckClientDownloadRequest was using 3 methods to prevent use-after-free
(weak pointers, ref counting, and cancelable tasks). This CL replaces
all of those with weak pointers, so that CheckClientDownloadRequest can
be singly owned by the DownloadProtectionService.

Bug: 889986
Change-Id: I098d207df9fccd7842c7fdc65dba94c75ec8b062
Reviewed-on: https://chromium-review.googlesource.com/c/1289956
Commit-Queue: Daniel Rubery <drubery@chromium.org>
Reviewed-by: default avatarNathan Parker <nparker@chromium.org>
Cr-Commit-Position: refs/heads/master@{#604652}
parent 2b7d5477
......@@ -16,7 +16,6 @@
#include "base/gtest_prod_util.h"
#include "base/memory/ref_counted.h"
#include "base/supports_user_data.h"
#include "base/task/cancelable_task_tracker.h"
#include "build/build_config.h"
#include "chrome/browser/safe_browsing/download_protection/download_protection_util.h"
#include "chrome/browser/safe_browsing/download_protection/file_analyzer.h"
......@@ -44,10 +43,7 @@ class SimpleURLLoader;
namespace safe_browsing {
class CheckClientDownloadRequest
: public base::RefCountedThreadSafe<CheckClientDownloadRequest,
BrowserThread::DeleteOnUIThread>,
public download::DownloadItem::Observer {
class CheckClientDownloadRequest : public download::DownloadItem::Observer {
public:
CheckClientDownloadRequest(
download::DownloadItem* item,
......@@ -55,13 +51,12 @@ class CheckClientDownloadRequest
DownloadProtectionService* service,
const scoped_refptr<SafeBrowsingDatabaseManager>& database_manager,
BinaryFeatureExtractor* binary_feature_extractor);
~CheckClientDownloadRequest() override;
bool ShouldSampleUnsupportedFile(const base::FilePath& filename);
void Start();
void StartTimeout();
// |download_destroyed| indicates if cancellation is due to the destruction of
// the download item.
void Cancel(bool download_destroyed);
void OnDownloadDestroyed(download::DownloadItem* download) override;
void OnURLLoaderComplete(std::unique_ptr<std::string> response_body);
static bool IsSupportedDownload(const download::DownloadItem& item,
......@@ -76,40 +71,20 @@ class CheckClientDownloadRequest
using ArchivedBinaries =
google::protobuf::RepeatedPtrField<ClientDownloadRequest_ArchivedBinary>;
~CheckClientDownloadRequest() override;
// Performs file feature extraction and SafeBrowsing ping for downloads that
// don't match the URL whitelist.
void AnalyzeFile();
void OnFileFeatureExtractionDone(FileAnalyzer::Results results);
void StartExtractFileFeatures();
void ExtractFileFeatures(const base::FilePath& file_path);
void StartExtractRarFeatures();
void OnRarAnalysisFinished(const ArchiveAnalyzerResults& results);
void StartExtractZipFeatures();
void OnZipAnalysisFinished(const ArchiveAnalyzerResults& results);
static void CopyArchivedBinaries(const ArchivedBinaries& src_binaries,
ArchivedBinaries* dest_binaries);
#if defined(OS_MACOSX)
void StartExtractDmgFeatures();
void ExtractFileOrDmgFeatures(bool download_file_has_koly_signature);
void OnDmgAnalysisFinished(const ArchiveAnalyzerResults& results);
#endif // defined(OS_MACOSX)
bool ShouldSampleWhitelistedDownload();
// Checks the download URL against SafeBrowsing whitelist. If download URL is
// on whitelist, file feature extraction and download ping are skipped.
void CheckUrlAgainstWhitelist();
void CheckCertificateChainAgainstWhitelist();
void OnUrlWhitelistCheckDone(bool is_whitelisted);
void OnCertificateWhitelistCheckDone(bool is_whitelisted);
void GetTabRedirects();
void OnGotTabRedirects(const GURL& url,
const history::RedirectList* redirect_list);
bool IsDownloadManuallyBlacklisted(const ClientDownloadRequest& request);
std::string SanitizeUrl(const GURL& url) const;
void SendRequest();
void PostFinishTask(DownloadCheckResult result,
DownloadCheckResultReason reason);
void FinishRequest(DownloadCheckResult result,
DownloadCheckResultReason reason);
bool CertificateChainIsWhitelisted(
......@@ -148,7 +123,6 @@ class CheckClientDownloadRequest
const bool pingback_enabled_;
std::unique_ptr<network::SimpleURLLoader> loader_;
std::unique_ptr<FileAnalyzer> file_analyzer_;
bool finished_;
ClientDownloadRequest::DownloadType type_;
std::string client_download_request_data_;
base::CancelableTaskTracker request_tracker_; // For HistoryService lookup.
......@@ -160,9 +134,7 @@ class CheckClientDownloadRequest
bool is_extended_reporting_;
bool is_incognito_;
bool is_under_advanced_protection_;
// This task tracker is used for posting the URL whitelist check to the IO
// thread. The posted task will be cancelled if DownloadItem gets destroyed.
base::CancelableTaskTracker cancelable_task_tracker_;
base::WeakPtrFactory<CheckClientDownloadRequest> weakptr_factory_;
FRIEND_TEST_ALL_PREFIXES(CheckClientDownloadRequestTest,
......
......@@ -172,11 +172,11 @@ void DownloadProtectionService::CheckClientDownload(
callback.Run(DownloadCheckResult::WHITELISTED_BY_POLICY);
return;
}
scoped_refptr<CheckClientDownloadRequest> request(
new CheckClientDownloadRequest(item, callback, this, database_manager_,
binary_feature_extractor_.get()));
download_requests_.insert(request);
request->Start();
auto request = std::make_unique<CheckClientDownloadRequest>(
item, callback, this, database_manager_, binary_feature_extractor_.get());
CheckClientDownloadRequest* request_copy = request.get();
download_requests_[request_copy] = std::move(request);
request_copy->Start();
}
void DownloadProtectionService::CheckDownloadUrl(
......@@ -257,13 +257,8 @@ DownloadProtectionService::RegisterPPAPIDownloadRequestCallback(
void DownloadProtectionService::CancelPendingRequests() {
DCHECK_CURRENTLY_ON(BrowserThread::UI);
for (auto it = download_requests_.begin(); it != download_requests_.end();) {
// We need to advance the iterator before we cancel because canceling
// the request will invalidate it when RequestFinished is called below.
scoped_refptr<CheckClientDownloadRequest> tmp = *it++;
tmp->Cancel(/*download_destropyed=*/false);
}
DCHECK(download_requests_.empty());
// It is sufficient to delete the list of CheckClientDownloadRequests.
download_requests_.clear();
// It is sufficient to delete the list of PPAPI download requests.
ppapi_download_requests_.clear();
......@@ -274,7 +269,7 @@ void DownloadProtectionService::RequestFinished(
DCHECK_CURRENTLY_ON(BrowserThread::UI);
auto it = download_requests_.find(request);
DCHECK(it != download_requests_.end());
download_requests_.erase(*it);
download_requests_.erase(it);
}
void DownloadProtectionService::PPAPIDownloadCheckRequestFinished(
......@@ -352,81 +347,6 @@ void DownloadProtectionService::MaybeSendDangerousDownloadOpenedReport(
}
}
namespace {
// Escapes a certificate attribute so that it can be used in a whitelist
// entry. Currently, we only escape slashes, since they are used as a
// separator between attributes.
std::string EscapeCertAttribute(const std::string& attribute) {
std::string escaped;
for (size_t i = 0; i < attribute.size(); ++i) {
if (attribute[i] == '%') {
escaped.append("%25");
} else if (attribute[i] == '/') {
escaped.append("%2F");
} else {
escaped.push_back(attribute[i]);
}
}
return escaped;
}
} // namespace
// static
void DownloadProtectionService::GetCertificateWhitelistStrings(
const net::X509Certificate& certificate,
const net::X509Certificate& issuer,
std::vector<std::string>* whitelist_strings) {
// The whitelist paths are in the format:
// cert/<ascii issuer fingerprint>[/CN=common_name][/O=org][/OU=unit]
//
// Any of CN, O, or OU may be omitted from the whitelist entry, in which
// case they match anything. However, the attributes that do appear will
// always be in the order shown above. At least one attribute will always
// be present.
const net::CertPrincipal& subject = certificate.subject();
std::vector<std::string> ou_tokens;
for (size_t i = 0; i < subject.organization_unit_names.size(); ++i) {
ou_tokens.push_back(
"/OU=" + EscapeCertAttribute(subject.organization_unit_names[i]));
}
std::vector<std::string> o_tokens;
for (size_t i = 0; i < subject.organization_names.size(); ++i) {
o_tokens.push_back("/O=" +
EscapeCertAttribute(subject.organization_names[i]));
}
std::string cn_token;
if (!subject.common_name.empty()) {
cn_token = "/CN=" + EscapeCertAttribute(subject.common_name);
}
std::set<std::string> paths_to_check;
if (!cn_token.empty()) {
paths_to_check.insert(cn_token);
}
for (size_t i = 0; i < o_tokens.size(); ++i) {
paths_to_check.insert(cn_token + o_tokens[i]);
paths_to_check.insert(o_tokens[i]);
for (size_t j = 0; j < ou_tokens.size(); ++j) {
paths_to_check.insert(cn_token + o_tokens[i] + ou_tokens[j]);
paths_to_check.insert(o_tokens[i] + ou_tokens[j]);
}
}
for (size_t i = 0; i < ou_tokens.size(); ++i) {
paths_to_check.insert(cn_token + ou_tokens[i]);
paths_to_check.insert(ou_tokens[i]);
}
std::string hashed = base::SHA1HashString(std::string(
net::x509_util::CryptoBufferAsStringPiece(issuer.cert_buffer())));
std::string issuer_fp = base::HexEncode(hashed.data(), hashed.size());
for (auto it = paths_to_check.begin(); it != paths_to_check.end(); ++it) {
whitelist_strings->push_back("cert/" + issuer_fp + *it);
}
}
std::unique_ptr<ReferrerChainData>
DownloadProtectionService::IdentifyReferrerChain(
const download::DownloadItem& item) {
......
......@@ -39,10 +39,6 @@ namespace download {
class DownloadItem;
}
namespace net {
class X509Certificate;
} // namespace net
namespace network {
class SharedURLLoaderFactory;
}
......@@ -204,14 +200,6 @@ class DownloadProtectionService {
void PPAPIDownloadCheckRequestFinished(PPAPIDownloadRequest* request);
// Given a certificate and its immediate issuer certificate, generates the
// list of strings that need to be checked against the download whitelist to
// determine whether the certificate is whitelisted.
static void GetCertificateWhitelistStrings(
const net::X509Certificate& certificate,
const net::X509Certificate& issuer,
std::vector<std::string>* whitelist_strings);
// Identify referrer chain info of a download. This function also records UMA
// stats of download attribution result.
std::unique_ptr<ReferrerChainData> IdentifyReferrerChain(
......@@ -242,7 +230,9 @@ class DownloadProtectionService {
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory_;
// Set of pending server requests for DownloadManager mediated downloads.
std::set<scoped_refptr<CheckClientDownloadRequest>> download_requests_;
std::unordered_map<CheckClientDownloadRequest*,
std::unique_ptr<CheckClientDownloadRequest>>
download_requests_;
// Set of pending server requests for PPAPI mediated downloads. Using a map
// because heterogeneous lookups aren't available yet in std::unordered_map.
......
......@@ -344,15 +344,6 @@ class DownloadProtectionServiceTest : public ChromeRenderViewHostTestHarness {
RunLoop().RunUntilIdle();
}
// Proxy for private method.
static void GetCertificateWhitelistStrings(
const net::X509Certificate& certificate,
const net::X509Certificate& issuer,
std::vector<std::string>* whitelist_strings) {
DownloadProtectionService::GetCertificateWhitelistStrings(
certificate, issuer, whitelist_strings);
}
// Reads a single PEM-encoded certificate from the testdata directory.
// Returns NULL on failure.
scoped_refptr<net::X509Certificate> ReadTestCertificate(
......@@ -2235,87 +2226,6 @@ TEST_F(DownloadProtectionServiceTest,
EXPECT_FALSE(HasClientDownloadRequest());
}
TEST_F(DownloadProtectionServiceTest, GetCertificateWhitelistStrings) {
// We'll pass this cert in as the "issuer", even though it isn't really
// used to sign the certs below. GetCertificateWhitelistStirngs doesn't care
// about this.
scoped_refptr<net::X509Certificate> issuer_cert(
ReadTestCertificate("issuer.pem"));
ASSERT_TRUE(issuer_cert.get());
std::string hashed = base::SHA1HashString(std::string(
net::x509_util::CryptoBufferAsStringPiece(issuer_cert->cert_buffer())));
std::string cert_base =
"cert/" + base::HexEncode(hashed.data(), hashed.size());
scoped_refptr<net::X509Certificate> cert(ReadTestCertificate("test_cn.pem"));
ASSERT_TRUE(cert.get());
std::vector<std::string> whitelist_strings;
GetCertificateWhitelistStrings(*cert.get(), *issuer_cert.get(),
&whitelist_strings);
// This also tests escaping of characters in the certificate attributes.
EXPECT_THAT(whitelist_strings, ElementsAre(cert_base + "/CN=subject%2F%251"));
cert = ReadTestCertificate("test_cn_o.pem");
ASSERT_TRUE(cert.get());
whitelist_strings.clear();
GetCertificateWhitelistStrings(*cert.get(), *issuer_cert.get(),
&whitelist_strings);
EXPECT_THAT(whitelist_strings, ElementsAre(cert_base + "/CN=subject",
cert_base + "/CN=subject/O=org",
cert_base + "/O=org"));
cert = ReadTestCertificate("test_cn_o_ou.pem");
ASSERT_TRUE(cert.get());
whitelist_strings.clear();
GetCertificateWhitelistStrings(*cert.get(), *issuer_cert.get(),
&whitelist_strings);
EXPECT_THAT(
whitelist_strings,
ElementsAre(cert_base + "/CN=subject", cert_base + "/CN=subject/O=org",
cert_base + "/CN=subject/O=org/OU=unit",
cert_base + "/CN=subject/OU=unit", cert_base + "/O=org",
cert_base + "/O=org/OU=unit", cert_base + "/OU=unit"));
cert = ReadTestCertificate("test_cn_ou.pem");
ASSERT_TRUE(cert.get());
whitelist_strings.clear();
GetCertificateWhitelistStrings(*cert.get(), *issuer_cert.get(),
&whitelist_strings);
EXPECT_THAT(whitelist_strings, ElementsAre(cert_base + "/CN=subject",
cert_base + "/CN=subject/OU=unit",
cert_base + "/OU=unit"));
cert = ReadTestCertificate("test_o.pem");
ASSERT_TRUE(cert.get());
whitelist_strings.clear();
GetCertificateWhitelistStrings(*cert.get(), *issuer_cert.get(),
&whitelist_strings);
EXPECT_THAT(whitelist_strings, ElementsAre(cert_base + "/O=org"));
cert = ReadTestCertificate("test_o_ou.pem");
ASSERT_TRUE(cert.get());
whitelist_strings.clear();
GetCertificateWhitelistStrings(*cert.get(), *issuer_cert.get(),
&whitelist_strings);
EXPECT_THAT(whitelist_strings,
ElementsAre(cert_base + "/O=org", cert_base + "/O=org/OU=unit",
cert_base + "/OU=unit"));
cert = ReadTestCertificate("test_ou.pem");
ASSERT_TRUE(cert.get());
whitelist_strings.clear();
GetCertificateWhitelistStrings(*cert.get(), *issuer_cert.get(),
&whitelist_strings);
EXPECT_THAT(whitelist_strings, ElementsAre(cert_base + "/OU=unit"));
cert = ReadTestCertificate("test_c.pem");
ASSERT_TRUE(cert.get());
whitelist_strings.clear();
GetCertificateWhitelistStrings(*cert.get(), *issuer_cert.get(),
&whitelist_strings);
EXPECT_THAT(whitelist_strings, ElementsAre());
}
namespace {
class MockPageNavigator : public content::PageNavigator {
......
......@@ -5,12 +5,91 @@
#include "chrome/browser/safe_browsing/download_protection/download_protection_util.h"
#include "base/metrics/histogram_macros.h"
#include "base/sha1.h"
#include "base/strings/string_number_conversions.h"
#include "net/cert/x509_util.h"
namespace safe_browsing {
namespace {
// Escapes a certificate attribute so that it can be used in a whitelist
// entry. Currently, we only escape slashes, since they are used as a
// separator between attributes.
std::string EscapeCertAttribute(const std::string& attribute) {
std::string escaped;
for (size_t i = 0; i < attribute.size(); ++i) {
if (attribute[i] == '%') {
escaped.append("%25");
} else if (attribute[i] == '/') {
escaped.append("%2F");
} else {
escaped.push_back(attribute[i]);
}
}
return escaped;
}
} // namespace
void RecordCountOfWhitelistedDownload(WhitelistType type) {
UMA_HISTOGRAM_ENUMERATION("SBClientDownload.CheckWhitelistResult", type,
WHITELIST_TYPE_MAX);
}
void GetCertificateWhitelistStrings(
const net::X509Certificate& certificate,
const net::X509Certificate& issuer,
std::vector<std::string>* whitelist_strings) {
// The whitelist paths are in the format:
// cert/<ascii issuer fingerprint>[/CN=common_name][/O=org][/OU=unit]
//
// Any of CN, O, or OU may be omitted from the whitelist entry, in which
// case they match anything. However, the attributes that do appear will
// always be in the order shown above. At least one attribute will always
// be present.
const net::CertPrincipal& subject = certificate.subject();
std::vector<std::string> ou_tokens;
for (size_t i = 0; i < subject.organization_unit_names.size(); ++i) {
ou_tokens.push_back(
"/OU=" + EscapeCertAttribute(subject.organization_unit_names[i]));
}
std::vector<std::string> o_tokens;
for (size_t i = 0; i < subject.organization_names.size(); ++i) {
o_tokens.push_back("/O=" +
EscapeCertAttribute(subject.organization_names[i]));
}
std::string cn_token;
if (!subject.common_name.empty()) {
cn_token = "/CN=" + EscapeCertAttribute(subject.common_name);
}
std::set<std::string> paths_to_check;
if (!cn_token.empty()) {
paths_to_check.insert(cn_token);
}
for (size_t i = 0; i < o_tokens.size(); ++i) {
paths_to_check.insert(cn_token + o_tokens[i]);
paths_to_check.insert(o_tokens[i]);
for (size_t j = 0; j < ou_tokens.size(); ++j) {
paths_to_check.insert(cn_token + o_tokens[i] + ou_tokens[j]);
paths_to_check.insert(o_tokens[i] + ou_tokens[j]);
}
}
for (size_t i = 0; i < ou_tokens.size(); ++i) {
paths_to_check.insert(cn_token + ou_tokens[i]);
paths_to_check.insert(ou_tokens[i]);
}
std::string hashed = base::SHA1HashString(std::string(
net::x509_util::CryptoBufferAsStringPiece(issuer.cert_buffer())));
std::string issuer_fp = base::HexEncode(hashed.data(), hashed.size());
for (auto it = paths_to_check.begin(); it != paths_to_check.end(); ++it) {
whitelist_strings->push_back("cert/" + issuer_fp + *it);
}
}
} // namespace safe_browsing
......@@ -10,6 +10,7 @@
#include "base/callback_list.h"
#include "components/download/public/common/download_item.h"
#include "components/safe_browsing/proto/csd.pb.h"
#include "net/cert/x509_certificate.h"
namespace safe_browsing {
......@@ -115,6 +116,14 @@ typedef std::unique_ptr<PPAPIDownloadRequestCallbackList::Subscription>
void RecordCountOfWhitelistedDownload(WhitelistType type);
// Given a certificate and its immediate issuer certificate, generates the
// list of strings that need to be checked against the download whitelist to
// determine whether the certificate is whitelisted.
void GetCertificateWhitelistStrings(
const net::X509Certificate& certificate,
const net::X509Certificate& issuer,
std::vector<std::string>* whitelist_strings);
} // namespace safe_browsing
#endif // CHROME_BROWSER_SAFE_BROWSING_DOWNLOAD_PROTECTION_DOWNLOAD_PROTECTION_UTIL_H_
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/browser/safe_browsing/download_protection/download_protection_util.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace safe_browsing {
TEST(DownloadProtectionUtilTest, GetCertificateWhitelistStrings) {
// We'll pass this cert in as the "issuer", even though it isn't really
// used to sign the certs below. GetCertificateWhitelistStirngs doesn't care
// about this.
scoped_refptr<net::X509Certificate> issuer_cert(
ReadTestCertificate("issuer.pem"));
ASSERT_TRUE(issuer_cert.get());
std::string hashed = base::SHA1HashString(std::string(
net::x509_util::CryptoBufferAsStringPiece(issuer_cert->cert_buffer())));
std::string cert_base =
"cert/" + base::HexEncode(hashed.data(), hashed.size());
scoped_refptr<net::X509Certificate> cert(ReadTestCertificate("test_cn.pem"));
ASSERT_TRUE(cert.get());
std::vector<std::string> whitelist_strings;
GetCertificateWhitelistStrings(*cert.get(), *issuer_cert.get(),
&whitelist_strings);
// This also tests escaping of characters in the certificate attributes.
EXPECT_THAT(whitelist_strings, ElementsAre(cert_base + "/CN=subject%2F%251"));
cert = ReadTestCertificate("test_cn_o.pem");
ASSERT_TRUE(cert.get());
whitelist_strings.clear();
GetCertificateWhitelistStrings(*cert.get(), *issuer_cert.get(),
&whitelist_strings);
EXPECT_THAT(whitelist_strings, ElementsAre(cert_base + "/CN=subject",
cert_base + "/CN=subject/O=org",
cert_base + "/O=org"));
cert = ReadTestCertificate("test_cn_o_ou.pem");
ASSERT_TRUE(cert.get());
whitelist_strings.clear();
GetCertificateWhitelistStrings(*cert.get(), *issuer_cert.get(),
&whitelist_strings);
EXPECT_THAT(
whitelist_strings,
ElementsAre(cert_base + "/CN=subject", cert_base + "/CN=subject/O=org",
cert_base + "/CN=subject/O=org/OU=unit",
cert_base + "/CN=subject/OU=unit", cert_base + "/O=org",
cert_base + "/O=org/OU=unit", cert_base + "/OU=unit"));
cert = ReadTestCertificate("test_cn_ou.pem");
ASSERT_TRUE(cert.get());
whitelist_strings.clear();
GetCertificateWhitelistStrings(*cert.get(), *issuer_cert.get(),
&whitelist_strings);
EXPECT_THAT(whitelist_strings, ElementsAre(cert_base + "/CN=subject",
cert_base + "/CN=subject/OU=unit",
cert_base + "/OU=unit"));
cert = ReadTestCertificate("test_o.pem");
ASSERT_TRUE(cert.get());
whitelist_strings.clear();
GetCertificateWhitelistStrings(*cert.get(), *issuer_cert.get(),
&whitelist_strings);
EXPECT_THAT(whitelist_strings, ElementsAre(cert_base + "/O=org"));
cert = ReadTestCertificate("test_o_ou.pem");
ASSERT_TRUE(cert.get());
whitelist_strings.clear();
GetCertificateWhitelistStrings(*cert.get(), *issuer_cert.get(),
&whitelist_strings);
EXPECT_THAT(whitelist_strings,
ElementsAre(cert_base + "/O=org", cert_base + "/O=org/OU=unit",
cert_base + "/OU=unit"));
cert = ReadTestCertificate("test_ou.pem");
ASSERT_TRUE(cert.get());
whitelist_strings.clear();
GetCertificateWhitelistStrings(*cert.get(), *issuer_cert.get(),
&whitelist_strings);
EXPECT_THAT(whitelist_strings, ElementsAre(cert_base + "/OU=unit"));
cert = ReadTestCertificate("test_c.pem");
ASSERT_TRUE(cert.get());
whitelist_strings.clear();
GetCertificateWhitelistStrings(*cert.get(), *issuer_cert.get(),
&whitelist_strings);
EXPECT_THAT(whitelist_strings, ElementsAre());
}
} // namespace safe_browsing
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment