Commit 97276a0e authored by shess@chromium.org

[Safe browsing] Clean up code to scan hash results for threats.

Rewrite safe_browsing_util::GetHashIndex() and GetUrlHashIndex() to use
sensible return types rather than transitioning to/from int.  Pull them
into database_manager.cc, which was the only place they were used.

BUG=none

Review URL: https://codereview.chromium.org/337723004

git-svn-id: svn://svn.chromium.org/chrome/trunk/src@277997 0039d316-1c4b-4281-b951-d872f2087c98
parent c583751d
...@@ -68,30 +68,71 @@ bool IsExpectedThreat( ...@@ -68,30 +68,71 @@ bool IsExpectedThreat(
threat_type); threat_type);
} }
// |list_id| is from |safe_browsing_util::ListType|. // Return the list id from the first result in |full_hashes| which matches
SBThreatType GetThreatTypeFromListId(int list_id) { // |hash|, or INVALID if none match.
if (list_id == safe_browsing_util::PHISH) { safe_browsing_util::ListType GetHashThreatListType(
return SB_THREAT_TYPE_URL_PHISHING; const SBFullHash& hash,
const std::vector<SBFullHashResult>& full_hashes) {
for (size_t i = 0; i < full_hashes.size(); ++i) {
if (SBFullHashEqual(hash, full_hashes[i].hash))
return static_cast<safe_browsing_util::ListType>(full_hashes[i].list_id);
} }
return safe_browsing_util::INVALID;
}
if (list_id == safe_browsing_util::MALWARE) { // Given a URL, compare all the possible host + path full hashes to the set of
return SB_THREAT_TYPE_URL_MALWARE; // provided full hashes. Returns the list id of a matching result from
} // |full_hashes|, or INVALID if none match.
safe_browsing_util::ListType GetUrlThreatListType(
const GURL& url,
const std::vector<SBFullHashResult>& full_hashes) {
if (full_hashes.empty())
return safe_browsing_util::INVALID;
if (list_id == safe_browsing_util::BINURL) { std::vector<std::string> patterns;
return SB_THREAT_TYPE_BINARY_MALWARE_URL; safe_browsing_util::GeneratePatternsToCheck(url, &patterns);
}
if (list_id == safe_browsing_util::EXTENSIONBLACKLIST) { for (size_t i = 0; i < patterns.size(); ++i) {
return SB_THREAT_TYPE_EXTENSION; safe_browsing_util::ListType threat =
GetHashThreatListType(SBFullHashForString(patterns[i]), full_hashes);
if (threat != safe_browsing_util::INVALID)
return threat;
} }
return safe_browsing_util::INVALID;
}
DVLOG(1) << "Unknown safe browsing list id " << list_id; SBThreatType GetThreatTypeFromListType(safe_browsing_util::ListType list_type) {
return SB_THREAT_TYPE_SAFE; switch (list_type) {
case safe_browsing_util::PHISH:
return SB_THREAT_TYPE_URL_PHISHING;
case safe_browsing_util::MALWARE:
return SB_THREAT_TYPE_URL_MALWARE;
case safe_browsing_util::BINURL:
return SB_THREAT_TYPE_BINARY_MALWARE_URL;
case safe_browsing_util::EXTENSIONBLACKLIST:
return SB_THREAT_TYPE_EXTENSION;
default:
DVLOG(1) << "Unknown safe browsing list id " << list_type;
return SB_THREAT_TYPE_SAFE;
}
} }
} // namespace } // namespace
// static
// Looks up |hash| in |full_hashes| and converts the list type of the match
// into the corresponding threat type. A lookup that yields no list type the
// converter recognises ends up as SB_THREAT_TYPE_SAFE.
SBThreatType SafeBrowsingDatabaseManager::GetHashThreatType(
    const SBFullHash& hash,
    const std::vector<SBFullHashResult>& full_hashes) {
  const safe_browsing_util::ListType matched_list =
      GetHashThreatListType(hash, full_hashes);
  return GetThreatTypeFromListType(matched_list);
}
// static
// Resolves |url| against |full_hashes| via GetUrlThreatListType() and
// converts the resulting list type into the corresponding threat type.
SBThreatType SafeBrowsingDatabaseManager::GetUrlThreatType(
    const GURL& url,
    const std::vector<SBFullHashResult>& full_hashes) {
  const safe_browsing_util::ListType matched_list =
      GetUrlThreatListType(url, full_hashes);
  return GetThreatTypeFromListType(matched_list);
}
SafeBrowsingDatabaseManager::SafeBrowsingCheck::SafeBrowsingCheck( SafeBrowsingDatabaseManager::SafeBrowsingCheck::SafeBrowsingCheck(
const std::vector<GURL>& urls, const std::vector<GURL>& urls,
const std::vector<SBFullHash>& full_hashes, const std::vector<SBFullHash>& full_hashes,
...@@ -881,13 +922,19 @@ bool SafeBrowsingDatabaseManager::HandleOneCheck( ...@@ -881,13 +922,19 @@ bool SafeBrowsingDatabaseManager::HandleOneCheck(
bool is_threat = false; bool is_threat = false;
// TODO(shess): GetHashThreatListType() contains a loop,
// GetUrlThreatListType() a loop around that loop. Having another loop out
// here concerns me. It is likely that SAFE is an expected outcome, which
// means all of those loops run to completion. Refactoring this to generate a
// set of sorted items to compare in sequence would probably improve things.
//
// Additionally, the set of patterns generated from the urls is very similar
// to the patterns generated in ContainsBrowseUrl() and other database checks,
// which are called from this code. Refactoring that across the checks could
// interact well with batching the checks here.
for (size_t i = 0; i < check->urls.size(); ++i) { for (size_t i = 0; i < check->urls.size(); ++i) {
int index = SBThreatType threat = GetUrlThreatType(check->urls[i], full_hashes);
safe_browsing_util::GetUrlHashIndex(check->urls[i], full_hashes);
if (index == -1)
continue;
SBThreatType threat =
GetThreatTypeFromListId(full_hashes[index].list_id);
if (threat != SB_THREAT_TYPE_SAFE && if (threat != SB_THREAT_TYPE_SAFE &&
IsExpectedThreat(threat, check->expected_threats)) { IsExpectedThreat(threat, check->expected_threats)) {
check->url_results[i] = threat; check->url_results[i] = threat;
...@@ -896,12 +943,7 @@ bool SafeBrowsingDatabaseManager::HandleOneCheck( ...@@ -896,12 +943,7 @@ bool SafeBrowsingDatabaseManager::HandleOneCheck(
} }
for (size_t i = 0; i < check->full_hashes.size(); ++i) { for (size_t i = 0; i < check->full_hashes.size(); ++i) {
int index = SBThreatType threat = GetHashThreatType(check->full_hashes[i], full_hashes);
safe_browsing_util::GetHashIndex(check->full_hashes[i], full_hashes);
if (index == -1)
continue;
SBThreatType threat =
GetThreatTypeFromListId(full_hashes[index].list_id);
if (threat != SB_THREAT_TYPE_SAFE && if (threat != SB_THREAT_TYPE_SAFE &&
IsExpectedThreat(threat, check->expected_threats)) { IsExpectedThreat(threat, check->expected_threats)) {
check->full_hash_results[i] = threat; check->full_hash_results[i] = threat;
......
...@@ -214,6 +214,7 @@ class SafeBrowsingDatabaseManager ...@@ -214,6 +214,7 @@ class SafeBrowsingDatabaseManager
friend class SafeBrowsingServiceTest; friend class SafeBrowsingServiceTest;
friend class SafeBrowsingServiceTestHelper; friend class SafeBrowsingServiceTestHelper;
friend class SafeBrowsingDatabaseManagerTest; friend class SafeBrowsingDatabaseManagerTest;
FRIEND_TEST_ALL_PREFIXES(SafeBrowsingDatabaseManagerTest, GetUrlThreatType);
typedef std::set<SafeBrowsingCheck*> CurrentChecks; typedef std::set<SafeBrowsingCheck*> CurrentChecks;
typedef std::vector<SafeBrowsingCheck*> GetHashRequestors; typedef std::vector<SafeBrowsingCheck*> GetHashRequestors;
...@@ -234,6 +235,19 @@ class SafeBrowsingDatabaseManager ...@@ -234,6 +235,19 @@ class SafeBrowsingDatabaseManager
base::TimeTicks start; // When check was queued. base::TimeTicks start; // When check was queued.
}; };
// Return the threat type from the first result in |full_hashes| which matches
// |hash|, or SAFE if none match.
static SBThreatType GetHashThreatType(
const SBFullHash& hash,
const std::vector<SBFullHashResult>& full_hashes);
// Given a URL, compare all the possible host + path full hashes to the set of
// provided full hashes. Returns the threat type of the matching result from
// |full_hashes|, or SAFE if none match.
static SBThreatType GetUrlThreatType(
const GURL& url,
const std::vector<SBFullHashResult>& full_hashes);
// Called to stop operations on the io_thread. This may be called multiple // Called to stop operations on the io_thread. This may be called multiple
// times during the life of the DatabaseManager. Should be called on IO // times during the life of the DatabaseManager. Should be called on IO
// thread. // thread.
......
...@@ -78,3 +78,53 @@ TEST_F(SafeBrowsingDatabaseManagerTest, CheckCorrespondsListType) { ...@@ -78,3 +78,53 @@ TEST_F(SafeBrowsingDatabaseManagerTest, CheckCorrespondsListType) {
multiple_threats, multiple_threats,
safe_browsing_util::kMalwareList)); safe_browsing_util::kMalwareList));
} }
// Checks that GetHashThreatType() and GetUrlThreatType() report the threat
// type of the list a matching full hash belongs to, and SAFE when nothing in
// the result set matches.
TEST_F(SafeBrowsingDatabaseManagerTest, GetUrlThreatType) {
  const GURL kMalwareUrl("http://www.malware.com/page.html");
  const GURL kPhishingUrl("http://www.phishing.com/page.html");
  const GURL kSafeUrl("http://www.safe.com/page.html");

  const SBFullHash kMalwareHostHash = SBFullHashForString("malware.com/");
  const SBFullHash kPhishingHostHash = SBFullHashForString("phishing.com/");
  const SBFullHash kSafeHostHash = SBFullHashForString("www.safe.com/");

  // Only the malware and phishing hosts are listed; the safe host's hash is
  // deliberately left out of the result set.
  std::vector<SBFullHashResult> full_hashes;

  SBFullHashResult malware_entry;
  malware_entry.hash = kMalwareHostHash;
  malware_entry.list_id = static_cast<int>(safe_browsing_util::MALWARE);
  full_hashes.push_back(malware_entry);

  SBFullHashResult phishing_entry;
  phishing_entry.hash = kPhishingHostHash;
  phishing_entry.list_id = static_cast<int>(safe_browsing_util::PHISH);
  full_hashes.push_back(phishing_entry);

  // Direct hash lookups.
  EXPECT_EQ(SB_THREAT_TYPE_URL_MALWARE,
            SafeBrowsingDatabaseManager::GetHashThreatType(
                kMalwareHostHash, full_hashes));
  EXPECT_EQ(SB_THREAT_TYPE_URL_PHISHING,
            SafeBrowsingDatabaseManager::GetHashThreatType(
                kPhishingHostHash, full_hashes));
  EXPECT_EQ(SB_THREAT_TYPE_SAFE,
            SafeBrowsingDatabaseManager::GetHashThreatType(
                kSafeHostHash, full_hashes));

  // URL lookups; the listed hashes are for "<domain>/" while the URLs use a
  // "www." host and a page path, so a match presumably relies on the
  // patterns generated for each URL.
  EXPECT_EQ(SB_THREAT_TYPE_URL_MALWARE,
            SafeBrowsingDatabaseManager::GetUrlThreatType(
                kMalwareUrl, full_hashes));
  EXPECT_EQ(SB_THREAT_TYPE_URL_PHISHING,
            SafeBrowsingDatabaseManager::GetUrlThreatType(
                kPhishingUrl, full_hashes));
  EXPECT_EQ(SB_THREAT_TYPE_SAFE,
            SafeBrowsingDatabaseManager::GetUrlThreatType(
                kSafeUrl, full_hashes));
}
...@@ -465,32 +465,6 @@ void GeneratePatternsToCheck(const GURL& url, std::vector<std::string>* urls) { ...@@ -465,32 +465,6 @@ void GeneratePatternsToCheck(const GURL& url, std::vector<std::string>* urls) {
} }
} }
// Returns the position of the first entry in |full_hashes| whose hash equals
// |hash|, or -1 when no entry matches.
int GetHashIndex(const SBFullHash& hash,
                 const std::vector<SBFullHashResult>& full_hashes) {
  size_t pos = 0;
  // Advance until a matching entry is found or the vector is exhausted.
  while (pos < full_hashes.size() &&
         !SBFullHashEqual(hash, full_hashes[pos].hash)) {
    ++pos;
  }
  return pos < full_hashes.size() ? static_cast<int>(pos) : -1;
}
int GetUrlHashIndex(const GURL& url,
const std::vector<SBFullHashResult>& full_hashes) {
if (full_hashes.empty())
return -1;
std::vector<std::string> patterns;
GeneratePatternsToCheck(url, &patterns);
for (size_t i = 0; i < patterns.size(); ++i) {
SBFullHash key = SBFullHashForString(patterns[i]);
int index = GetHashIndex(key, full_hashes);
if (index != -1)
return index;
}
return -1;
}
GURL GeneratePhishingReportUrl(const std::string& report_page, GURL GeneratePhishingReportUrl(const std::string& report_page,
const std::string& url_to_report, const std::string& url_to_report,
bool is_client_side_detection) { bool is_client_side_detection) {
......
...@@ -202,15 +202,6 @@ void GeneratePathsToCheck(const GURL& url, std::vector<std::string>* paths); ...@@ -202,15 +202,6 @@ void GeneratePathsToCheck(const GURL& url, std::vector<std::string>* paths);
// Given a URL, returns all the patterns we need to check. // Given a URL, returns all the patterns we need to check.
void GeneratePatternsToCheck(const GURL& url, std::vector<std::string>* urls); void GeneratePatternsToCheck(const GURL& url, std::vector<std::string>* urls);
int GetHashIndex(const SBFullHash& hash,
const std::vector<SBFullHashResult>& full_hashes);
// Given a URL, compare all the possible host + path full hashes to the set of
// provided full hashes. Returns the index of the match if one is found, or -1
// otherwise.
int GetUrlHashIndex(const GURL& url,
const std::vector<SBFullHashResult>& full_hashes);
GURL GeneratePhishingReportUrl(const std::string& report_page, GURL GeneratePhishingReportUrl(const std::string& report_page,
const std::string& url_to_report, const std::string& url_to_report,
bool is_client_side_detection); bool is_client_side_detection);
......
...@@ -277,19 +277,6 @@ TEST(SafeBrowsingUtilTest, CanonicalizeUrl) { ...@@ -277,19 +277,6 @@ TEST(SafeBrowsingUtilTest, CanonicalizeUrl) {
} }
} }
// A URL whose host + path hash is present in the result set resolves to that
// entry's index; a URL with a different path on the same host yields -1.
TEST(SafeBrowsingUtilTest, GetUrlHashIndex) {
  const GURL listed_url("http://www.evil.com/phish.html");
  const GURL unlisted_url("http://www.evil.com/okay_path.html");

  SBFullHashResult listed_entry;
  listed_entry.hash =
      SBFullHashForString(listed_url.host() + listed_url.path());
  std::vector<SBFullHashResult> full_hashes(1, listed_entry);

  EXPECT_EQ(safe_browsing_util::GetUrlHashIndex(listed_url, full_hashes), 0);
  EXPECT_EQ(safe_browsing_util::GetUrlHashIndex(unlisted_url, full_hashes),
            -1);
}
TEST(SafeBrowsingUtilTest, ListIdListNameConversion) { TEST(SafeBrowsingUtilTest, ListIdListNameConversion) {
std::string list_name; std::string list_name;
EXPECT_FALSE(safe_browsing_util::GetListName(safe_browsing_util::INVALID, EXPECT_FALSE(safe_browsing_util::GetListName(safe_browsing_util::INVALID,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment