Commit 0bc461bf authored by Daniel Rubery's avatar Daniel Rubery Committed by Commit Bot

Fix lifetime of CheckClientDownloadRequest

CheckClientDownloadRequest was using 3 methods to prevent use-after-free
(weak pointers, ref counting, and cancelable tasks). This CL replaces
all of those with weak pointers, so that CheckClientDownloadRequest can
be singly owned by the DownloadProtectionService.

Bug: 889986
Change-Id: I098d207df9fccd7842c7fdc65dba94c75ec8b062
Reviewed-on: https://chromium-review.googlesource.com/c/1289956
Commit-Queue: Daniel Rubery <drubery@chromium.org>
Reviewed-by: default avatarNathan Parker <nparker@chromium.org>
Cr-Commit-Position: refs/heads/master@{#604652}
parent 2b7d5477
...@@ -16,7 +16,6 @@ ...@@ -16,7 +16,6 @@
#include "base/gtest_prod_util.h" #include "base/gtest_prod_util.h"
#include "base/memory/ref_counted.h" #include "base/memory/ref_counted.h"
#include "base/supports_user_data.h" #include "base/supports_user_data.h"
#include "base/task/cancelable_task_tracker.h"
#include "build/build_config.h" #include "build/build_config.h"
#include "chrome/browser/safe_browsing/download_protection/download_protection_util.h" #include "chrome/browser/safe_browsing/download_protection/download_protection_util.h"
#include "chrome/browser/safe_browsing/download_protection/file_analyzer.h" #include "chrome/browser/safe_browsing/download_protection/file_analyzer.h"
...@@ -44,10 +43,7 @@ class SimpleURLLoader; ...@@ -44,10 +43,7 @@ class SimpleURLLoader;
namespace safe_browsing { namespace safe_browsing {
class CheckClientDownloadRequest class CheckClientDownloadRequest : public download::DownloadItem::Observer {
: public base::RefCountedThreadSafe<CheckClientDownloadRequest,
BrowserThread::DeleteOnUIThread>,
public download::DownloadItem::Observer {
public: public:
CheckClientDownloadRequest( CheckClientDownloadRequest(
download::DownloadItem* item, download::DownloadItem* item,
...@@ -55,13 +51,12 @@ class CheckClientDownloadRequest ...@@ -55,13 +51,12 @@ class CheckClientDownloadRequest
DownloadProtectionService* service, DownloadProtectionService* service,
const scoped_refptr<SafeBrowsingDatabaseManager>& database_manager, const scoped_refptr<SafeBrowsingDatabaseManager>& database_manager,
BinaryFeatureExtractor* binary_feature_extractor); BinaryFeatureExtractor* binary_feature_extractor);
~CheckClientDownloadRequest() override;
bool ShouldSampleUnsupportedFile(const base::FilePath& filename); bool ShouldSampleUnsupportedFile(const base::FilePath& filename);
void Start(); void Start();
void StartTimeout(); void StartTimeout();
// |download_destroyed| indicates if cancellation is due to the destruction of
// the download item.
void Cancel(bool download_destroyed);
void OnDownloadDestroyed(download::DownloadItem* download) override; void OnDownloadDestroyed(download::DownloadItem* download) override;
void OnURLLoaderComplete(std::unique_ptr<std::string> response_body); void OnURLLoaderComplete(std::unique_ptr<std::string> response_body);
static bool IsSupportedDownload(const download::DownloadItem& item, static bool IsSupportedDownload(const download::DownloadItem& item,
...@@ -76,40 +71,20 @@ class CheckClientDownloadRequest ...@@ -76,40 +71,20 @@ class CheckClientDownloadRequest
using ArchivedBinaries = using ArchivedBinaries =
google::protobuf::RepeatedPtrField<ClientDownloadRequest_ArchivedBinary>; google::protobuf::RepeatedPtrField<ClientDownloadRequest_ArchivedBinary>;
~CheckClientDownloadRequest() override;
// Performs file feature extraction and SafeBrowsing ping for downloads that // Performs file feature extraction and SafeBrowsing ping for downloads that
// don't match the URL whitelist. // don't match the URL whitelist.
void AnalyzeFile(); void AnalyzeFile();
void OnFileFeatureExtractionDone(FileAnalyzer::Results results); void OnFileFeatureExtractionDone(FileAnalyzer::Results results);
void StartExtractFileFeatures();
void ExtractFileFeatures(const base::FilePath& file_path);
void StartExtractRarFeatures();
void OnRarAnalysisFinished(const ArchiveAnalyzerResults& results);
void StartExtractZipFeatures();
void OnZipAnalysisFinished(const ArchiveAnalyzerResults& results);
static void CopyArchivedBinaries(const ArchivedBinaries& src_binaries,
ArchivedBinaries* dest_binaries);
#if defined(OS_MACOSX)
void StartExtractDmgFeatures();
void ExtractFileOrDmgFeatures(bool download_file_has_koly_signature);
void OnDmgAnalysisFinished(const ArchiveAnalyzerResults& results);
#endif // defined(OS_MACOSX)
bool ShouldSampleWhitelistedDownload(); bool ShouldSampleWhitelistedDownload();
// Checks the download URL against SafeBrowsing whitelist. If download URL is void OnUrlWhitelistCheckDone(bool is_whitelisted);
// on whitelist, file feature extraction and download ping are skipped. void OnCertificateWhitelistCheckDone(bool is_whitelisted);
void CheckUrlAgainstWhitelist();
void CheckCertificateChainAgainstWhitelist();
void GetTabRedirects(); void GetTabRedirects();
void OnGotTabRedirects(const GURL& url, void OnGotTabRedirects(const GURL& url,
const history::RedirectList* redirect_list); const history::RedirectList* redirect_list);
bool IsDownloadManuallyBlacklisted(const ClientDownloadRequest& request); bool IsDownloadManuallyBlacklisted(const ClientDownloadRequest& request);
std::string SanitizeUrl(const GURL& url) const; std::string SanitizeUrl(const GURL& url) const;
void SendRequest(); void SendRequest();
void PostFinishTask(DownloadCheckResult result,
DownloadCheckResultReason reason);
void FinishRequest(DownloadCheckResult result, void FinishRequest(DownloadCheckResult result,
DownloadCheckResultReason reason); DownloadCheckResultReason reason);
bool CertificateChainIsWhitelisted( bool CertificateChainIsWhitelisted(
...@@ -148,7 +123,6 @@ class CheckClientDownloadRequest ...@@ -148,7 +123,6 @@ class CheckClientDownloadRequest
const bool pingback_enabled_; const bool pingback_enabled_;
std::unique_ptr<network::SimpleURLLoader> loader_; std::unique_ptr<network::SimpleURLLoader> loader_;
std::unique_ptr<FileAnalyzer> file_analyzer_; std::unique_ptr<FileAnalyzer> file_analyzer_;
bool finished_;
ClientDownloadRequest::DownloadType type_; ClientDownloadRequest::DownloadType type_;
std::string client_download_request_data_; std::string client_download_request_data_;
base::CancelableTaskTracker request_tracker_; // For HistoryService lookup. base::CancelableTaskTracker request_tracker_; // For HistoryService lookup.
...@@ -160,9 +134,7 @@ class CheckClientDownloadRequest ...@@ -160,9 +134,7 @@ class CheckClientDownloadRequest
bool is_extended_reporting_; bool is_extended_reporting_;
bool is_incognito_; bool is_incognito_;
bool is_under_advanced_protection_; bool is_under_advanced_protection_;
// This task tracker is used for posting the URL whitelist check to the IO
// thread. The posted task will be cancelled if DownloadItem gets destroyed.
base::CancelableTaskTracker cancelable_task_tracker_;
base::WeakPtrFactory<CheckClientDownloadRequest> weakptr_factory_; base::WeakPtrFactory<CheckClientDownloadRequest> weakptr_factory_;
FRIEND_TEST_ALL_PREFIXES(CheckClientDownloadRequestTest, FRIEND_TEST_ALL_PREFIXES(CheckClientDownloadRequestTest,
......
...@@ -172,11 +172,11 @@ void DownloadProtectionService::CheckClientDownload( ...@@ -172,11 +172,11 @@ void DownloadProtectionService::CheckClientDownload(
callback.Run(DownloadCheckResult::WHITELISTED_BY_POLICY); callback.Run(DownloadCheckResult::WHITELISTED_BY_POLICY);
return; return;
} }
scoped_refptr<CheckClientDownloadRequest> request( auto request = std::make_unique<CheckClientDownloadRequest>(
new CheckClientDownloadRequest(item, callback, this, database_manager_, item, callback, this, database_manager_, binary_feature_extractor_.get());
binary_feature_extractor_.get())); CheckClientDownloadRequest* request_copy = request.get();
download_requests_.insert(request); download_requests_[request_copy] = std::move(request);
request->Start(); request_copy->Start();
} }
void DownloadProtectionService::CheckDownloadUrl( void DownloadProtectionService::CheckDownloadUrl(
...@@ -257,13 +257,8 @@ DownloadProtectionService::RegisterPPAPIDownloadRequestCallback( ...@@ -257,13 +257,8 @@ DownloadProtectionService::RegisterPPAPIDownloadRequestCallback(
void DownloadProtectionService::CancelPendingRequests() { void DownloadProtectionService::CancelPendingRequests() {
DCHECK_CURRENTLY_ON(BrowserThread::UI); DCHECK_CURRENTLY_ON(BrowserThread::UI);
for (auto it = download_requests_.begin(); it != download_requests_.end();) { // It is sufficient to delete the list of CheckClientDownloadRequests.
// We need to advance the iterator before we cancel because canceling download_requests_.clear();
// the request will invalidate it when RequestFinished is called below.
scoped_refptr<CheckClientDownloadRequest> tmp = *it++;
tmp->Cancel(/*download_destropyed=*/false);
}
DCHECK(download_requests_.empty());
// It is sufficient to delete the list of PPAPI download requests. // It is sufficient to delete the list of PPAPI download requests.
ppapi_download_requests_.clear(); ppapi_download_requests_.clear();
...@@ -274,7 +269,7 @@ void DownloadProtectionService::RequestFinished( ...@@ -274,7 +269,7 @@ void DownloadProtectionService::RequestFinished(
DCHECK_CURRENTLY_ON(BrowserThread::UI); DCHECK_CURRENTLY_ON(BrowserThread::UI);
auto it = download_requests_.find(request); auto it = download_requests_.find(request);
DCHECK(it != download_requests_.end()); DCHECK(it != download_requests_.end());
download_requests_.erase(*it); download_requests_.erase(it);
} }
void DownloadProtectionService::PPAPIDownloadCheckRequestFinished( void DownloadProtectionService::PPAPIDownloadCheckRequestFinished(
...@@ -352,81 +347,6 @@ void DownloadProtectionService::MaybeSendDangerousDownloadOpenedReport( ...@@ -352,81 +347,6 @@ void DownloadProtectionService::MaybeSendDangerousDownloadOpenedReport(
} }
} }
namespace {
// Escapes a certificate attribute so that it can be used in a whitelist
// entry. Currently, we only escape slashes, since they are used as a
// separator between attributes.
std::string EscapeCertAttribute(const std::string& attribute) {
std::string escaped;
for (size_t i = 0; i < attribute.size(); ++i) {
if (attribute[i] == '%') {
escaped.append("%25");
} else if (attribute[i] == '/') {
escaped.append("%2F");
} else {
escaped.push_back(attribute[i]);
}
}
return escaped;
}
} // namespace
// static
void DownloadProtectionService::GetCertificateWhitelistStrings(
const net::X509Certificate& certificate,
const net::X509Certificate& issuer,
std::vector<std::string>* whitelist_strings) {
// The whitelist paths are in the format:
// cert/<ascii issuer fingerprint>[/CN=common_name][/O=org][/OU=unit]
//
// Any of CN, O, or OU may be omitted from the whitelist entry, in which
// case they match anything. However, the attributes that do appear will
// always be in the order shown above. At least one attribute will always
// be present.
const net::CertPrincipal& subject = certificate.subject();
std::vector<std::string> ou_tokens;
for (size_t i = 0; i < subject.organization_unit_names.size(); ++i) {
ou_tokens.push_back(
"/OU=" + EscapeCertAttribute(subject.organization_unit_names[i]));
}
std::vector<std::string> o_tokens;
for (size_t i = 0; i < subject.organization_names.size(); ++i) {
o_tokens.push_back("/O=" +
EscapeCertAttribute(subject.organization_names[i]));
}
std::string cn_token;
if (!subject.common_name.empty()) {
cn_token = "/CN=" + EscapeCertAttribute(subject.common_name);
}
std::set<std::string> paths_to_check;
if (!cn_token.empty()) {
paths_to_check.insert(cn_token);
}
for (size_t i = 0; i < o_tokens.size(); ++i) {
paths_to_check.insert(cn_token + o_tokens[i]);
paths_to_check.insert(o_tokens[i]);
for (size_t j = 0; j < ou_tokens.size(); ++j) {
paths_to_check.insert(cn_token + o_tokens[i] + ou_tokens[j]);
paths_to_check.insert(o_tokens[i] + ou_tokens[j]);
}
}
for (size_t i = 0; i < ou_tokens.size(); ++i) {
paths_to_check.insert(cn_token + ou_tokens[i]);
paths_to_check.insert(ou_tokens[i]);
}
std::string hashed = base::SHA1HashString(std::string(
net::x509_util::CryptoBufferAsStringPiece(issuer.cert_buffer())));
std::string issuer_fp = base::HexEncode(hashed.data(), hashed.size());
for (auto it = paths_to_check.begin(); it != paths_to_check.end(); ++it) {
whitelist_strings->push_back("cert/" + issuer_fp + *it);
}
}
std::unique_ptr<ReferrerChainData> std::unique_ptr<ReferrerChainData>
DownloadProtectionService::IdentifyReferrerChain( DownloadProtectionService::IdentifyReferrerChain(
const download::DownloadItem& item) { const download::DownloadItem& item) {
......
...@@ -39,10 +39,6 @@ namespace download { ...@@ -39,10 +39,6 @@ namespace download {
class DownloadItem; class DownloadItem;
} }
namespace net {
class X509Certificate;
} // namespace net
namespace network { namespace network {
class SharedURLLoaderFactory; class SharedURLLoaderFactory;
} }
...@@ -204,14 +200,6 @@ class DownloadProtectionService { ...@@ -204,14 +200,6 @@ class DownloadProtectionService {
void PPAPIDownloadCheckRequestFinished(PPAPIDownloadRequest* request); void PPAPIDownloadCheckRequestFinished(PPAPIDownloadRequest* request);
// Given a certificate and its immediate issuer certificate, generates the
// list of strings that need to be checked against the download whitelist to
// determine whether the certificate is whitelisted.
static void GetCertificateWhitelistStrings(
const net::X509Certificate& certificate,
const net::X509Certificate& issuer,
std::vector<std::string>* whitelist_strings);
// Identify referrer chain info of a download. This function also records UMA // Identify referrer chain info of a download. This function also records UMA
// stats of download attribution result. // stats of download attribution result.
std::unique_ptr<ReferrerChainData> IdentifyReferrerChain( std::unique_ptr<ReferrerChainData> IdentifyReferrerChain(
...@@ -242,7 +230,9 @@ class DownloadProtectionService { ...@@ -242,7 +230,9 @@ class DownloadProtectionService {
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory_; scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory_;
// Set of pending server requests for DownloadManager mediated downloads. // Set of pending server requests for DownloadManager mediated downloads.
std::set<scoped_refptr<CheckClientDownloadRequest>> download_requests_; std::unordered_map<CheckClientDownloadRequest*,
std::unique_ptr<CheckClientDownloadRequest>>
download_requests_;
// Set of pending server requests for PPAPI mediated downloads. Using a map // Set of pending server requests for PPAPI mediated downloads. Using a map
// because heterogeneous lookups aren't available yet in std::unordered_map. // because heterogeneous lookups aren't available yet in std::unordered_map.
......
...@@ -344,15 +344,6 @@ class DownloadProtectionServiceTest : public ChromeRenderViewHostTestHarness { ...@@ -344,15 +344,6 @@ class DownloadProtectionServiceTest : public ChromeRenderViewHostTestHarness {
RunLoop().RunUntilIdle(); RunLoop().RunUntilIdle();
} }
// Proxy for private method.
static void GetCertificateWhitelistStrings(
const net::X509Certificate& certificate,
const net::X509Certificate& issuer,
std::vector<std::string>* whitelist_strings) {
DownloadProtectionService::GetCertificateWhitelistStrings(
certificate, issuer, whitelist_strings);
}
// Reads a single PEM-encoded certificate from the testdata directory. // Reads a single PEM-encoded certificate from the testdata directory.
// Returns NULL on failure. // Returns NULL on failure.
scoped_refptr<net::X509Certificate> ReadTestCertificate( scoped_refptr<net::X509Certificate> ReadTestCertificate(
...@@ -2235,87 +2226,6 @@ TEST_F(DownloadProtectionServiceTest, ...@@ -2235,87 +2226,6 @@ TEST_F(DownloadProtectionServiceTest,
EXPECT_FALSE(HasClientDownloadRequest()); EXPECT_FALSE(HasClientDownloadRequest());
} }
TEST_F(DownloadProtectionServiceTest, GetCertificateWhitelistStrings) {
// We'll pass this cert in as the "issuer", even though it isn't really
// used to sign the certs below. GetCertificateWhitelistStirngs doesn't care
// about this.
scoped_refptr<net::X509Certificate> issuer_cert(
ReadTestCertificate("issuer.pem"));
ASSERT_TRUE(issuer_cert.get());
std::string hashed = base::SHA1HashString(std::string(
net::x509_util::CryptoBufferAsStringPiece(issuer_cert->cert_buffer())));
std::string cert_base =
"cert/" + base::HexEncode(hashed.data(), hashed.size());
scoped_refptr<net::X509Certificate> cert(ReadTestCertificate("test_cn.pem"));
ASSERT_TRUE(cert.get());
std::vector<std::string> whitelist_strings;
GetCertificateWhitelistStrings(*cert.get(), *issuer_cert.get(),
&whitelist_strings);
// This also tests escaping of characters in the certificate attributes.
EXPECT_THAT(whitelist_strings, ElementsAre(cert_base + "/CN=subject%2F%251"));
cert = ReadTestCertificate("test_cn_o.pem");
ASSERT_TRUE(cert.get());
whitelist_strings.clear();
GetCertificateWhitelistStrings(*cert.get(), *issuer_cert.get(),
&whitelist_strings);
EXPECT_THAT(whitelist_strings, ElementsAre(cert_base + "/CN=subject",
cert_base + "/CN=subject/O=org",
cert_base + "/O=org"));
cert = ReadTestCertificate("test_cn_o_ou.pem");
ASSERT_TRUE(cert.get());
whitelist_strings.clear();
GetCertificateWhitelistStrings(*cert.get(), *issuer_cert.get(),
&whitelist_strings);
EXPECT_THAT(
whitelist_strings,
ElementsAre(cert_base + "/CN=subject", cert_base + "/CN=subject/O=org",
cert_base + "/CN=subject/O=org/OU=unit",
cert_base + "/CN=subject/OU=unit", cert_base + "/O=org",
cert_base + "/O=org/OU=unit", cert_base + "/OU=unit"));
cert = ReadTestCertificate("test_cn_ou.pem");
ASSERT_TRUE(cert.get());
whitelist_strings.clear();
GetCertificateWhitelistStrings(*cert.get(), *issuer_cert.get(),
&whitelist_strings);
EXPECT_THAT(whitelist_strings, ElementsAre(cert_base + "/CN=subject",
cert_base + "/CN=subject/OU=unit",
cert_base + "/OU=unit"));
cert = ReadTestCertificate("test_o.pem");
ASSERT_TRUE(cert.get());
whitelist_strings.clear();
GetCertificateWhitelistStrings(*cert.get(), *issuer_cert.get(),
&whitelist_strings);
EXPECT_THAT(whitelist_strings, ElementsAre(cert_base + "/O=org"));
cert = ReadTestCertificate("test_o_ou.pem");
ASSERT_TRUE(cert.get());
whitelist_strings.clear();
GetCertificateWhitelistStrings(*cert.get(), *issuer_cert.get(),
&whitelist_strings);
EXPECT_THAT(whitelist_strings,
ElementsAre(cert_base + "/O=org", cert_base + "/O=org/OU=unit",
cert_base + "/OU=unit"));
cert = ReadTestCertificate("test_ou.pem");
ASSERT_TRUE(cert.get());
whitelist_strings.clear();
GetCertificateWhitelistStrings(*cert.get(), *issuer_cert.get(),
&whitelist_strings);
EXPECT_THAT(whitelist_strings, ElementsAre(cert_base + "/OU=unit"));
cert = ReadTestCertificate("test_c.pem");
ASSERT_TRUE(cert.get());
whitelist_strings.clear();
GetCertificateWhitelistStrings(*cert.get(), *issuer_cert.get(),
&whitelist_strings);
EXPECT_THAT(whitelist_strings, ElementsAre());
}
namespace { namespace {
class MockPageNavigator : public content::PageNavigator { class MockPageNavigator : public content::PageNavigator {
......
...@@ -5,12 +5,91 @@ ...@@ -5,12 +5,91 @@
#include "chrome/browser/safe_browsing/download_protection/download_protection_util.h" #include "chrome/browser/safe_browsing/download_protection/download_protection_util.h"
#include "base/metrics/histogram_macros.h" #include "base/metrics/histogram_macros.h"
#include "base/sha1.h"
#include "base/strings/string_number_conversions.h"
#include "net/cert/x509_util.h"
namespace safe_browsing { namespace safe_browsing {
namespace {
// Escapes a certificate attribute so that it can be used in a whitelist
// entry. Currently, we only escape slashes, since they are used as a
// separator between attributes.
std::string EscapeCertAttribute(const std::string& attribute) {
std::string escaped;
for (size_t i = 0; i < attribute.size(); ++i) {
if (attribute[i] == '%') {
escaped.append("%25");
} else if (attribute[i] == '/') {
escaped.append("%2F");
} else {
escaped.push_back(attribute[i]);
}
}
return escaped;
}
} // namespace
void RecordCountOfWhitelistedDownload(WhitelistType type) { void RecordCountOfWhitelistedDownload(WhitelistType type) {
UMA_HISTOGRAM_ENUMERATION("SBClientDownload.CheckWhitelistResult", type, UMA_HISTOGRAM_ENUMERATION("SBClientDownload.CheckWhitelistResult", type,
WHITELIST_TYPE_MAX); WHITELIST_TYPE_MAX);
} }
void GetCertificateWhitelistStrings(
const net::X509Certificate& certificate,
const net::X509Certificate& issuer,
std::vector<std::string>* whitelist_strings) {
// The whitelist paths are in the format:
// cert/<ascii issuer fingerprint>[/CN=common_name][/O=org][/OU=unit]
//
// Any of CN, O, or OU may be omitted from the whitelist entry, in which
// case they match anything. However, the attributes that do appear will
// always be in the order shown above. At least one attribute will always
// be present.
const net::CertPrincipal& subject = certificate.subject();
std::vector<std::string> ou_tokens;
for (size_t i = 0; i < subject.organization_unit_names.size(); ++i) {
ou_tokens.push_back(
"/OU=" + EscapeCertAttribute(subject.organization_unit_names[i]));
}
std::vector<std::string> o_tokens;
for (size_t i = 0; i < subject.organization_names.size(); ++i) {
o_tokens.push_back("/O=" +
EscapeCertAttribute(subject.organization_names[i]));
}
std::string cn_token;
if (!subject.common_name.empty()) {
cn_token = "/CN=" + EscapeCertAttribute(subject.common_name);
}
std::set<std::string> paths_to_check;
if (!cn_token.empty()) {
paths_to_check.insert(cn_token);
}
for (size_t i = 0; i < o_tokens.size(); ++i) {
paths_to_check.insert(cn_token + o_tokens[i]);
paths_to_check.insert(o_tokens[i]);
for (size_t j = 0; j < ou_tokens.size(); ++j) {
paths_to_check.insert(cn_token + o_tokens[i] + ou_tokens[j]);
paths_to_check.insert(o_tokens[i] + ou_tokens[j]);
}
}
for (size_t i = 0; i < ou_tokens.size(); ++i) {
paths_to_check.insert(cn_token + ou_tokens[i]);
paths_to_check.insert(ou_tokens[i]);
}
std::string hashed = base::SHA1HashString(std::string(
net::x509_util::CryptoBufferAsStringPiece(issuer.cert_buffer())));
std::string issuer_fp = base::HexEncode(hashed.data(), hashed.size());
for (auto it = paths_to_check.begin(); it != paths_to_check.end(); ++it) {
whitelist_strings->push_back("cert/" + issuer_fp + *it);
}
}
} // namespace safe_browsing } // namespace safe_browsing
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include "base/callback_list.h" #include "base/callback_list.h"
#include "components/download/public/common/download_item.h" #include "components/download/public/common/download_item.h"
#include "components/safe_browsing/proto/csd.pb.h" #include "components/safe_browsing/proto/csd.pb.h"
#include "net/cert/x509_certificate.h"
namespace safe_browsing { namespace safe_browsing {
...@@ -115,6 +116,14 @@ typedef std::unique_ptr<PPAPIDownloadRequestCallbackList::Subscription> ...@@ -115,6 +116,14 @@ typedef std::unique_ptr<PPAPIDownloadRequestCallbackList::Subscription>
void RecordCountOfWhitelistedDownload(WhitelistType type); void RecordCountOfWhitelistedDownload(WhitelistType type);
// Given a certificate and its immediate issuer certificate, generates the
// list of strings that need to be checked against the download whitelist to
// determine whether the certificate is whitelisted.
void GetCertificateWhitelistStrings(
const net::X509Certificate& certificate,
const net::X509Certificate& issuer,
std::vector<std::string>* whitelist_strings);
} // namespace safe_browsing } // namespace safe_browsing
#endif // CHROME_BROWSER_SAFE_BROWSING_DOWNLOAD_PROTECTION_DOWNLOAD_PROTECTION_UTIL_H_ #endif // CHROME_BROWSER_SAFE_BROWSING_DOWNLOAD_PROTECTION_DOWNLOAD_PROTECTION_UTIL_H_
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/browser/safe_browsing/download_protection/download_protection_util.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace safe_browsing {
TEST(DownloadProtectionUtilTest, GetCertificateWhitelistStrings) {
// We'll pass this cert in as the "issuer", even though it isn't really
// used to sign the certs below. GetCertificateWhitelistStirngs doesn't care
// about this.
scoped_refptr<net::X509Certificate> issuer_cert(
ReadTestCertificate("issuer.pem"));
ASSERT_TRUE(issuer_cert.get());
std::string hashed = base::SHA1HashString(std::string(
net::x509_util::CryptoBufferAsStringPiece(issuer_cert->cert_buffer())));
std::string cert_base =
"cert/" + base::HexEncode(hashed.data(), hashed.size());
scoped_refptr<net::X509Certificate> cert(ReadTestCertificate("test_cn.pem"));
ASSERT_TRUE(cert.get());
std::vector<std::string> whitelist_strings;
GetCertificateWhitelistStrings(*cert.get(), *issuer_cert.get(),
&whitelist_strings);
// This also tests escaping of characters in the certificate attributes.
EXPECT_THAT(whitelist_strings, ElementsAre(cert_base + "/CN=subject%2F%251"));
cert = ReadTestCertificate("test_cn_o.pem");
ASSERT_TRUE(cert.get());
whitelist_strings.clear();
GetCertificateWhitelistStrings(*cert.get(), *issuer_cert.get(),
&whitelist_strings);
EXPECT_THAT(whitelist_strings, ElementsAre(cert_base + "/CN=subject",
cert_base + "/CN=subject/O=org",
cert_base + "/O=org"));
cert = ReadTestCertificate("test_cn_o_ou.pem");
ASSERT_TRUE(cert.get());
whitelist_strings.clear();
GetCertificateWhitelistStrings(*cert.get(), *issuer_cert.get(),
&whitelist_strings);
EXPECT_THAT(
whitelist_strings,
ElementsAre(cert_base + "/CN=subject", cert_base + "/CN=subject/O=org",
cert_base + "/CN=subject/O=org/OU=unit",
cert_base + "/CN=subject/OU=unit", cert_base + "/O=org",
cert_base + "/O=org/OU=unit", cert_base + "/OU=unit"));
cert = ReadTestCertificate("test_cn_ou.pem");
ASSERT_TRUE(cert.get());
whitelist_strings.clear();
GetCertificateWhitelistStrings(*cert.get(), *issuer_cert.get(),
&whitelist_strings);
EXPECT_THAT(whitelist_strings, ElementsAre(cert_base + "/CN=subject",
cert_base + "/CN=subject/OU=unit",
cert_base + "/OU=unit"));
cert = ReadTestCertificate("test_o.pem");
ASSERT_TRUE(cert.get());
whitelist_strings.clear();
GetCertificateWhitelistStrings(*cert.get(), *issuer_cert.get(),
&whitelist_strings);
EXPECT_THAT(whitelist_strings, ElementsAre(cert_base + "/O=org"));
cert = ReadTestCertificate("test_o_ou.pem");
ASSERT_TRUE(cert.get());
whitelist_strings.clear();
GetCertificateWhitelistStrings(*cert.get(), *issuer_cert.get(),
&whitelist_strings);
EXPECT_THAT(whitelist_strings,
ElementsAre(cert_base + "/O=org", cert_base + "/O=org/OU=unit",
cert_base + "/OU=unit"));
cert = ReadTestCertificate("test_ou.pem");
ASSERT_TRUE(cert.get());
whitelist_strings.clear();
GetCertificateWhitelistStrings(*cert.get(), *issuer_cert.get(),
&whitelist_strings);
EXPECT_THAT(whitelist_strings, ElementsAre(cert_base + "/OU=unit"));
cert = ReadTestCertificate("test_c.pem");
ASSERT_TRUE(cert.get());
whitelist_strings.clear();
GetCertificateWhitelistStrings(*cert.get(), *issuer_cert.get(),
&whitelist_strings);
EXPECT_THAT(whitelist_strings, ElementsAre());
}
} // namespace safe_browsing
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment