Commit cbd7db17 authored by cfredric's avatar cfredric Committed by Commit Bot

Use SchemefulSite instead of std::strings in PreloadedFirstPartySets.

This CL also adds a PrintTo(SchemefulSite, ostream) function overload,
to allow tests to produce better error messages.

Bug: 1143756
Change-Id: I90e594eb5c355deb28344f75c1aa5e91402b2044
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2519753
Commit-Queue: Chris Fredrickson <cfredric@chromium.org>
Reviewed-by: default avatarMatt Menke <mmenke@chromium.org>
Reviewed-by: default avatarJohn Delaney <johnidel@chromium.org>
Reviewed-by: default avatarEric Orth <ericorth@chromium.org>
Reviewed-by: default avatarLily Chen <chlily@chromium.org>
Cr-Commit-Position: refs/heads/master@{#825063}
parent fa8a889b
...@@ -12,12 +12,16 @@ ...@@ -12,12 +12,16 @@
namespace net { namespace net {
namespace { // Return a tuple containing:
// * a new origin using the registerable domain of `origin` if possible and
// Return a new origin using the registerable domain of `origin` if possible and // a port of 0; otherwise, the passed-in origin.
// a port of 0. Otherwise, returns the passed in origin. Follows steps specified // * a bool indicating whether `origin` had a non-null registerable domain.
// in https://html.spec.whatwg.org/multipage/origin.html#obtain-a-site // (False if `origin` was opaque.)
url::Origin SwitchToRegistrableDomainAndRemovePort(const url::Origin& origin) { //
// Follows steps specified in
// https://html.spec.whatwg.org/multipage/origin.html#obtain-a-site
SchemefulSite::ObtainASiteResult SchemefulSite::ObtainASite(
const url::Origin& origin) {
// There is currently no reason for getting the schemeful site of a web // There is currently no reason for getting the schemeful site of a web
// socket, so disallow passing in websocket origins. // socket, so disallow passing in websocket origins.
DCHECK_NE(origin.scheme(), url::kWsScheme); DCHECK_NE(origin.scheme(), url::kWsScheme);
...@@ -25,7 +29,7 @@ url::Origin SwitchToRegistrableDomainAndRemovePort(const url::Origin& origin) { ...@@ -25,7 +29,7 @@ url::Origin SwitchToRegistrableDomainAndRemovePort(const url::Origin& origin) {
// 1. If origin is an opaque origin, then return origin. // 1. If origin is an opaque origin, then return origin.
if (origin.opaque()) if (origin.opaque())
return origin; return {origin, false /* used_registerable_domain */};
std::string registerable_domain; std::string registerable_domain;
...@@ -48,7 +52,8 @@ url::Origin SwitchToRegistrableDomainAndRemovePort(const url::Origin& origin) { ...@@ -48,7 +52,8 @@ url::Origin SwitchToRegistrableDomainAndRemovePort(const url::Origin& origin) {
// //
// Note that `registerable_domain` could still end up empty, since the // Note that `registerable_domain` could still end up empty, since the
// `origin` might have a scheme that permits empty hostnames, such as "file". // `origin` might have a scheme that permits empty hostnames, such as "file".
if (registerable_domain.empty()) bool used_registerable_domain = !registerable_domain.empty();
if (!used_registerable_domain)
registerable_domain = origin.host(); registerable_domain = origin.host();
int port = url::DefaultPortForScheme(origin.scheme().c_str(), int port = url::DefaultPortForScheme(origin.scheme().c_str(),
...@@ -58,18 +63,19 @@ url::Origin SwitchToRegistrableDomainAndRemovePort(const url::Origin& origin) { ...@@ -58,18 +63,19 @@ url::Origin SwitchToRegistrableDomainAndRemovePort(const url::Origin& origin) {
if (port == url::PORT_UNSPECIFIED) if (port == url::PORT_UNSPECIFIED)
port = 0; port = 0;
return url::Origin::CreateFromNormalizedTuple(origin.scheme(), return {url::Origin::CreateFromNormalizedTuple(origin.scheme(),
registerable_domain, port); registerable_domain, port),
used_registerable_domain};
} }
} // namespace SchemefulSite::SchemefulSite(ObtainASiteResult result)
: site_as_origin_(std::move(result.origin)) {}
SchemefulSite::SchemefulSite(const url::Origin& origin) SchemefulSite::SchemefulSite(const url::Origin& origin)
: site_as_origin_(SwitchToRegistrableDomainAndRemovePort(origin)) {} : SchemefulSite(ObtainASite(origin)) {}
SchemefulSite::SchemefulSite(const GURL& url) SchemefulSite::SchemefulSite(const GURL& url)
: site_as_origin_( : SchemefulSite(url::Origin::Create(url)) {}
SwitchToRegistrableDomainAndRemovePort(url::Origin::Create(url))) {}
SchemefulSite::SchemefulSite(const SchemefulSite& other) = default; SchemefulSite::SchemefulSite(const SchemefulSite& other) = default;
SchemefulSite::SchemefulSite(SchemefulSite&& other) = default; SchemefulSite::SchemefulSite(SchemefulSite&& other) = default;
...@@ -77,6 +83,14 @@ SchemefulSite::SchemefulSite(SchemefulSite&& other) = default; ...@@ -77,6 +83,14 @@ SchemefulSite::SchemefulSite(SchemefulSite&& other) = default;
SchemefulSite& SchemefulSite::operator=(const SchemefulSite& other) = default; SchemefulSite& SchemefulSite::operator=(const SchemefulSite& other) = default;
SchemefulSite& SchemefulSite::operator=(SchemefulSite&& other) = default; SchemefulSite& SchemefulSite::operator=(SchemefulSite&& other) = default;
base::Optional<SchemefulSite> SchemefulSite::CreateIfHasRegisterableDomain(
const url::Origin& origin) {
ObtainASiteResult result = ObtainASite(origin);
if (!result.used_registerable_domain)
return base::nullopt;
return SchemefulSite(std::move(result));
}
// static // static
SchemefulSite SchemefulSite::Deserialize(const std::string& value) { SchemefulSite SchemefulSite::Deserialize(const std::string& value) {
return SchemefulSite(GURL(value)); return SchemefulSite(GURL(value));
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#ifndef NET_BASE_SCHEMEFUL_SITE_H_ #ifndef NET_BASE_SCHEMEFUL_SITE_H_
#define NET_BASE_SCHEMEFUL_SITE_H_ #define NET_BASE_SCHEMEFUL_SITE_H_
#include <ostream>
#include <string> #include <string>
#include "base/gtest_prod_util.h" #include "base/gtest_prod_util.h"
...@@ -58,6 +59,10 @@ class NET_EXPORT SchemefulSite { ...@@ -58,6 +59,10 @@ class NET_EXPORT SchemefulSite {
SchemefulSite& operator=(const SchemefulSite& other); SchemefulSite& operator=(const SchemefulSite& other);
SchemefulSite& operator=(SchemefulSite&& other); SchemefulSite& operator=(SchemefulSite&& other);
// Creates a SchemefulSite iff the passed-in origin has a registerable domain.
static base::Optional<SchemefulSite> CreateIfHasRegisterableDomain(
const url::Origin&);
// Deserializes a string obtained from `Serialize()` to a `SchemefulSite`. // Deserializes a string obtained from `Serialize()` to a `SchemefulSite`.
// Returns an opaque `SchemefulSite` if the value was invalid in any way. // Returns an opaque `SchemefulSite` if the value was invalid in any way.
static SchemefulSite Deserialize(const std::string& value); static SchemefulSite Deserialize(const std::string& value);
...@@ -88,6 +93,15 @@ class NET_EXPORT SchemefulSite { ...@@ -88,6 +93,15 @@ class NET_EXPORT SchemefulSite {
FRIEND_TEST_ALL_PREFIXES(SchemefulSiteTest, OpaqueSerialization); FRIEND_TEST_ALL_PREFIXES(SchemefulSiteTest, OpaqueSerialization);
struct ObtainASiteResult {
url::Origin origin;
bool used_registerable_domain;
};
static ObtainASiteResult ObtainASite(const url::Origin&);
explicit SchemefulSite(ObtainASiteResult);
// Deserializes a string obtained from `SerializeWithNonce()` to a // Deserializes a string obtained from `SerializeWithNonce()` to a
// `SchemefulSite`. Returns nullopt if the value was invalid in any way. // `SchemefulSite`. Returns nullopt if the value was invalid in any way.
static base::Optional<SchemefulSite> DeserializeWithNonce( static base::Optional<SchemefulSite> DeserializeWithNonce(
...@@ -114,6 +128,12 @@ class NET_EXPORT SchemefulSite { ...@@ -114,6 +128,12 @@ class NET_EXPORT SchemefulSite {
url::Origin site_as_origin_; url::Origin site_as_origin_;
}; };
// Provided to allow gtest to create more helpful error messages, instead of
// printing hex.
inline void PrintTo(const SchemefulSite& ss, std::ostream* os) {
*os << ss.Serialize();
}
} // namespace net } // namespace net
#endif // NET_BASE_SCHEMEFUL_SITE_H_ #endif // NET_BASE_SCHEMEFUL_SITE_H_
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include "net/base/schemeful_site.h" #include "net/base/schemeful_site.h"
#include "testing/gmock/include/gmock/gmock-matchers.h"
#include "testing/gtest/include/gtest/gtest.h" #include "testing/gtest/include/gtest/gtest.h"
#include "url/gurl.h" #include "url/gurl.h"
#include "url/origin.h" #include "url/origin.h"
...@@ -186,4 +187,33 @@ TEST(SchemefulSiteTest, OpaqueSerialization) { ...@@ -186,4 +187,33 @@ TEST(SchemefulSiteTest, OpaqueSerialization) {
} }
} }
TEST(SchemefulSiteTest, CreateIfHasRegisterableDomain) {
for (const auto& site : std::initializer_list<std::string>{
"http://a.bar.test",
"http://c.test",
"http://a.foo.test",
"https://a.bar.test",
"https://c.test",
"https://a.foo.test",
}) {
url::Origin origin = url::Origin::Create(GURL(site));
EXPECT_THAT(SchemefulSite::CreateIfHasRegisterableDomain(origin),
testing::Optional(SchemefulSite(origin)))
<< "site = \"" << site << "\"";
}
for (const auto& site : std::initializer_list<std::string>{
"data:text/html,<body>Hello World</body>",
"file:///",
"file://foo",
"http://127.0.0.1:1234",
"https://127.0.0.1:1234",
}) {
url::Origin origin = url::Origin::Create(GURL(site));
EXPECT_EQ(SchemefulSite::CreateIfHasRegisterableDomain(origin),
base::nullopt)
<< "site = \"" << site << "\"";
}
}
} // namespace net } // namespace net
...@@ -88,6 +88,7 @@ fuzzer_test("first_party_set_parser_fuzzer") { ...@@ -88,6 +88,7 @@ fuzzer_test("first_party_set_parser_fuzzer") {
":first_party_sets", ":first_party_sets",
"//base", "//base",
"//base:i18n", "//base:i18n",
"//net:net",
"//net:net_fuzzer_test_support", "//net:net_fuzzer_test_support",
] ]
dict = "test/first_party_set_parser_fuzzer.dict" dict = "test/first_party_set_parser_fuzzer.dict"
...@@ -100,6 +101,7 @@ fuzzer_test("first_party_set_parser_json_fuzzer") { ...@@ -100,6 +101,7 @@ fuzzer_test("first_party_set_parser_json_fuzzer") {
":first_party_sets", ":first_party_sets",
"//base", "//base",
"//base:i18n", "//base:i18n",
"//net:net",
"//testing/libfuzzer/proto:json_proto", "//testing/libfuzzer/proto:json_proto",
"//testing/libfuzzer/proto:json_proto_converter", "//testing/libfuzzer/proto:json_proto_converter",
"//third_party/libprotobuf-mutator", "//third_party/libprotobuf-mutator",
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "base/logging.h" #include "base/logging.h"
#include "base/optional.h" #include "base/optional.h"
#include "net/base/registry_controlled_domains/registry_controlled_domain.h" #include "net/base/registry_controlled_domains/registry_controlled_domain.h"
#include "net/base/schemeful_site.h"
#include "url/gurl.h" #include "url/gurl.h"
#include "url/origin.h" #include "url/origin.h"
...@@ -24,8 +25,8 @@ namespace { ...@@ -24,8 +25,8 @@ namespace {
// Ensures that the string represents an origin that is non-opaque and HTTPS. // Ensures that the string represents an origin that is non-opaque and HTTPS.
// Returns the registered domain. // Returns the registered domain.
base::Optional<std::string> Canonicalize(const base::StringPiece origin_string, base::Optional<net::SchemefulSite> Canonicalize(base::StringPiece origin_string,
bool emit_errors) { bool emit_errors) {
const url::Origin origin(url::Origin::Create(GURL(origin_string))); const url::Origin origin(url::Origin::Create(GURL(origin_string)));
if (origin.opaque()) { if (origin.opaque()) {
if (emit_errors) { if (emit_errors) {
...@@ -41,11 +42,9 @@ base::Optional<std::string> Canonicalize(const base::StringPiece origin_string, ...@@ -41,11 +42,9 @@ base::Optional<std::string> Canonicalize(const base::StringPiece origin_string,
} }
return base::nullopt; return base::nullopt;
} }
const std::string domain_and_registry = base::Optional<net::SchemefulSite> site =
net::registry_controlled_domains::GetDomainAndRegistry( net::SchemefulSite::CreateIfHasRegisterableDomain(origin);
origin, net::registry_controlled_domains::INCLUDE_PRIVATE_REGISTRIES); if (!site.has_value()) {
if (domain_and_registry.empty()) {
if (emit_errors) { if (emit_errors) {
LOG(ERROR) << "First-Party Set origin" << origin_string LOG(ERROR) << "First-Party Set origin" << origin_string
<< " does not have a valid registered domain; ignoring."; << " does not have a valid registered domain; ignoring.";
...@@ -53,7 +52,7 @@ base::Optional<std::string> Canonicalize(const base::StringPiece origin_string, ...@@ -53,7 +52,7 @@ base::Optional<std::string> Canonicalize(const base::StringPiece origin_string,
return base::nullopt; return base::nullopt;
} }
return domain_and_registry; return site;
} }
const char kFirstPartySetOwnerField[] = "owner"; const char kFirstPartySetOwnerField[] = "owner";
...@@ -70,9 +69,10 @@ const char kFirstPartySetMembersField[] = "members"; ...@@ -70,9 +69,10 @@ const char kFirstPartySetMembersField[] = "members";
// and augments `elements` to include the elements of the set that was parsed. // and augments `elements` to include the elements of the set that was parsed.
// //
// Returns true if parsing and validation were successful, false otherwise. // Returns true if parsing and validation were successful, false otherwise.
bool ParsePreloadedSet(const base::Value& value, bool ParsePreloadedSet(
base::flat_map<std::string, std::string>& map, const base::Value& value,
base::flat_set<std::string>& elements) { base::flat_map<net::SchemefulSite, net::SchemefulSite>& map,
base::flat_set<net::SchemefulSite>& elements) {
if (!value.is_dict()) if (!value.is_dict())
return false; return false;
...@@ -82,7 +82,7 @@ bool ParsePreloadedSet(const base::Value& value, ...@@ -82,7 +82,7 @@ bool ParsePreloadedSet(const base::Value& value,
if (!maybe_owner) if (!maybe_owner)
return false; return false;
base::Optional<std::string> canonical_owner = base::Optional<net::SchemefulSite> canonical_owner =
Canonicalize(std::move(*maybe_owner), false /* emit_errors */); Canonicalize(std::move(*maybe_owner), false /* emit_errors */);
if (!canonical_owner.has_value()) if (!canonical_owner.has_value())
return false; return false;
...@@ -105,7 +105,7 @@ bool ParsePreloadedSet(const base::Value& value, ...@@ -105,7 +105,7 @@ bool ParsePreloadedSet(const base::Value& value,
// another set. // another set.
if (!item.is_string()) if (!item.is_string())
return false; return false;
base::Optional<std::string> member = base::Optional<net::SchemefulSite> member =
Canonicalize(item.GetString(), false /* emit_errors */); Canonicalize(item.GetString(), false /* emit_errors */);
if (!member.has_value() || elements.contains(*member)) if (!member.has_value() || elements.contains(*member))
return false; return false;
...@@ -117,13 +117,14 @@ bool ParsePreloadedSet(const base::Value& value, ...@@ -117,13 +117,14 @@ bool ParsePreloadedSet(const base::Value& value,
} // namespace } // namespace
base::Optional<std::string> FirstPartySetParser::CanonicalizeRegisteredDomain( base::Optional<net::SchemefulSite>
FirstPartySetParser::CanonicalizeRegisteredDomain(
const base::StringPiece origin_string, const base::StringPiece origin_string,
bool emit_errors) { bool emit_errors) {
return Canonicalize(origin_string, emit_errors); return Canonicalize(origin_string, emit_errors);
} }
std::unique_ptr<base::flat_map<std::string, std::string>> std::unique_ptr<base::flat_map<net::SchemefulSite, net::SchemefulSite>>
FirstPartySetParser::ParsePreloadedSets(base::StringPiece raw_sets) { FirstPartySetParser::ParsePreloadedSets(base::StringPiece raw_sets) {
base::Optional<base::Value> maybe_value = base::JSONReader::Read( base::Optional<base::Value> maybe_value = base::JSONReader::Read(
raw_sets, base::JSONParserOptions::JSON_ALLOW_TRAILING_COMMAS); raw_sets, base::JSONParserOptions::JSON_ALLOW_TRAILING_COMMAS);
...@@ -132,8 +133,9 @@ FirstPartySetParser::ParsePreloadedSets(base::StringPiece raw_sets) { ...@@ -132,8 +133,9 @@ FirstPartySetParser::ParsePreloadedSets(base::StringPiece raw_sets) {
if (!maybe_value->is_list()) if (!maybe_value->is_list())
return nullptr; return nullptr;
auto map = std::make_unique<base::flat_map<std::string, std::string>>(); auto map = std::make_unique<
base::flat_set<std::string> elements; base::flat_map<net::SchemefulSite, net::SchemefulSite>>();
base::flat_set<net::SchemefulSite> elements;
for (const auto& value : maybe_value->GetList()) { for (const auto& value : maybe_value->GetList()) {
if (!ParsePreloadedSet(value, *map, elements)) if (!ParsePreloadedSet(value, *map, elements))
return nullptr; return nullptr;
......
...@@ -15,6 +15,10 @@ ...@@ -15,6 +15,10 @@
#include "base/strings/string_piece_forward.h" #include "base/strings/string_piece_forward.h"
#include "base/values.h" #include "base/values.h"
namespace net {
class SchemefulSite;
}
namespace network { namespace network {
class FirstPartySetParser { class FirstPartySetParser {
...@@ -32,13 +36,13 @@ class FirstPartySetParser { ...@@ -32,13 +36,13 @@ class FirstPartySetParser {
// only for *preloaded* sets. // only for *preloaded* sets.
// //
// Returns nullptr if parsing or validation of any set failed. // Returns nullptr if parsing or validation of any set failed.
static std::unique_ptr<base::flat_map<std::string, std::string>> static std::unique_ptr<base::flat_map<net::SchemefulSite, net::SchemefulSite>>
ParsePreloadedSets(base::StringPiece raw_sets); ParsePreloadedSets(base::StringPiece raw_sets);
// Canonicalizes the passed in origin to a registered domain. In particular, // Canonicalizes the passed in origin to a registered domain. In particular,
// this ensures that the origin is non-opaque, is HTTPS, and has a registered // this ensures that the origin is non-opaque, is HTTPS, and has a registered
// domain. Returns base::nullopt in case of any error. // domain. Returns base::nullopt in case of any error.
static base::Optional<std::string> CanonicalizeRegisteredDomain( static base::Optional<net::SchemefulSite> CanonicalizeRegisteredDomain(
const base::StringPiece origin_string, const base::StringPiece origin_string,
bool emit_errors); bool emit_errors);
}; };
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include "services/network/first_party_sets/first_party_set_parser.h" #include "services/network/first_party_sets/first_party_set_parser.h"
#include "base/json/json_reader.h" #include "base/json/json_reader.h"
#include "net/base/schemeful_site.h"
#include "testing/gmock/include/gmock/gmock.h" #include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h" #include "testing/gtest/include/gtest/gtest.h"
...@@ -16,6 +17,11 @@ using ::testing::UnorderedElementsAre; ...@@ -16,6 +17,11 @@ using ::testing::UnorderedElementsAre;
namespace network { namespace network {
MATCHER_P(SerializesTo, want, "") {
const std::string got = arg.Serialize();
return testing::ExplainMatchResult(testing::Eq(want), got, result_listener);
}
TEST(FirstPartySetParser_Preloaded, RejectsEmpty) { TEST(FirstPartySetParser_Preloaded, RejectsEmpty) {
// If the input isn't valid JSON, we should // If the input isn't valid JSON, we should
// reject it. In particular, we should reject // reject it. In particular, we should reject
...@@ -64,7 +70,9 @@ TEST(FirstPartySetParser, AcceptsMinimal) { ...@@ -64,7 +70,9 @@ TEST(FirstPartySetParser, AcceptsMinimal) {
ASSERT_TRUE(base::JSONReader::Read(input)); ASSERT_TRUE(base::JSONReader::Read(input));
EXPECT_THAT(FirstPartySetParser::ParsePreloadedSets(input), EXPECT_THAT(FirstPartySetParser::ParsePreloadedSets(input),
Pointee(UnorderedElementsAre(Pair("aaaa.test", "example.test")))); Pointee(UnorderedElementsAre(
Pair(SerializesTo("https://aaaa.test"),
SerializesTo("https://example.test")))));
} }
TEST(FirstPartySetParser, RejectsMissingOwner) { TEST(FirstPartySetParser, RejectsMissingOwner) {
...@@ -197,7 +205,9 @@ TEST(FirstPartySetParser, TruncatesSubdomain_Owner) { ...@@ -197,7 +205,9 @@ TEST(FirstPartySetParser, TruncatesSubdomain_Owner) {
ASSERT_TRUE(base::JSONReader::Read(input)); ASSERT_TRUE(base::JSONReader::Read(input));
EXPECT_THAT(FirstPartySetParser::ParsePreloadedSets(input), EXPECT_THAT(FirstPartySetParser::ParsePreloadedSets(input),
Pointee(UnorderedElementsAre(Pair("aaaa.test", "example.test")))); Pointee(UnorderedElementsAre(
Pair(SerializesTo("https://aaaa.test"),
SerializesTo("https://example.test")))));
} }
TEST(FirstPartySetParser, TruncatesSubdomain_Member) { TEST(FirstPartySetParser, TruncatesSubdomain_Member) {
...@@ -211,7 +221,9 @@ TEST(FirstPartySetParser, TruncatesSubdomain_Member) { ...@@ -211,7 +221,9 @@ TEST(FirstPartySetParser, TruncatesSubdomain_Member) {
ASSERT_TRUE(base::JSONReader::Read(input)); ASSERT_TRUE(base::JSONReader::Read(input));
EXPECT_THAT(FirstPartySetParser::ParsePreloadedSets(input), EXPECT_THAT(FirstPartySetParser::ParsePreloadedSets(input),
Pointee(UnorderedElementsAre(Pair("aaaa.test", "example.test")))); Pointee(UnorderedElementsAre(
Pair(SerializesTo("https://aaaa.test"),
SerializesTo("https://example.test")))));
} }
TEST(FirstPartySetParser, AcceptsMultipleSets) { TEST(FirstPartySetParser, AcceptsMultipleSets) {
...@@ -231,9 +243,12 @@ TEST(FirstPartySetParser, AcceptsMultipleSets) { ...@@ -231,9 +243,12 @@ TEST(FirstPartySetParser, AcceptsMultipleSets) {
// Sanity check that the input is actually valid JSON. // Sanity check that the input is actually valid JSON.
ASSERT_TRUE(base::JSONReader::Read(input)); ASSERT_TRUE(base::JSONReader::Read(input));
EXPECT_THAT(FirstPartySetParser::ParsePreloadedSets(input), EXPECT_THAT(
Pointee(UnorderedElementsAre(Pair("member1.test", "example.test"), FirstPartySetParser::ParsePreloadedSets(input),
Pair("member2.test", "foo.test")))); Pointee(UnorderedElementsAre(Pair(SerializesTo("https://member1.test"),
SerializesTo("https://example.test")),
Pair(SerializesTo("https://member2.test"),
SerializesTo("https://foo.test")))));
} }
TEST(FirstPartySetParser, RejectsInvalidSets_InvalidOwner) { TEST(FirstPartySetParser, RejectsInvalidSets_InvalidOwner) {
...@@ -290,9 +305,10 @@ TEST(FirstPartySetParser, AllowsTrailingCommas) { ...@@ -290,9 +305,10 @@ TEST(FirstPartySetParser, AllowsTrailingCommas) {
ASSERT_TRUE(base::JSONReader::Read( ASSERT_TRUE(base::JSONReader::Read(
input, base::JSONParserOptions::JSON_ALLOW_TRAILING_COMMAS)); input, base::JSONParserOptions::JSON_ALLOW_TRAILING_COMMAS));
EXPECT_THAT( EXPECT_THAT(FirstPartySetParser::ParsePreloadedSets(input),
FirstPartySetParser::ParsePreloadedSets(input), Pointee(UnorderedElementsAre(
Pointee(UnorderedElementsAre(Pair("member1.test", "example.test")))); Pair(SerializesTo("https://member1.test"),
SerializesTo("https://example.test")))));
} }
TEST(FirstPartySetParser, Rejects_SameOwner) { TEST(FirstPartySetParser, Rejects_SameOwner) {
......
...@@ -10,18 +10,20 @@ ...@@ -10,18 +10,20 @@
#include "base/optional.h" #include "base/optional.h"
#include "base/ranges/algorithm.h" #include "base/ranges/algorithm.h"
#include "base/strings/string_split.h" #include "base/strings/string_split.h"
#include "net/base/schemeful_site.h"
#include "services/network/first_party_sets/first_party_set_parser.h" #include "services/network/first_party_sets/first_party_set_parser.h"
namespace network { namespace network {
namespace { namespace {
base::Optional<std::pair<std::string, base::flat_set<std::string>>> base::Optional<
std::pair<net::SchemefulSite, base::flat_set<net::SchemefulSite>>>
CanonicalizeSet(const std::vector<std::string>& origins) { CanonicalizeSet(const std::vector<std::string>& origins) {
if (origins.empty()) if (origins.empty())
return base::nullopt; return base::nullopt;
const base::Optional<std::string> maybe_owner = const base::Optional<net::SchemefulSite> maybe_owner =
FirstPartySetParser::CanonicalizeRegisteredDomain(origins[0], FirstPartySetParser::CanonicalizeRegisteredDomain(origins[0],
true /* emit_errors */); true /* emit_errors */);
if (!maybe_owner.has_value()) { if (!maybe_owner.has_value()) {
...@@ -29,10 +31,10 @@ CanonicalizeSet(const std::vector<std::string>& origins) { ...@@ -29,10 +31,10 @@ CanonicalizeSet(const std::vector<std::string>& origins) {
return base::nullopt; return base::nullopt;
} }
const std::string& owner = *maybe_owner; const net::SchemefulSite& owner = *maybe_owner;
base::flat_set<std::string> members; base::flat_set<net::SchemefulSite> members;
for (auto it = origins.begin() + 1; it != origins.end(); ++it) { for (auto it = origins.begin() + 1; it != origins.end(); ++it) {
const base::Optional<std::string> maybe_member = const base::Optional<net::SchemefulSite> maybe_member =
FirstPartySetParser::CanonicalizeRegisteredDomain( FirstPartySetParser::CanonicalizeRegisteredDomain(
*it, true /* emit_errors */); *it, true /* emit_errors */);
if (maybe_member.has_value() && maybe_member != owner) if (maybe_member.has_value() && maybe_member != owner)
...@@ -62,10 +64,10 @@ void PreloadedFirstPartySets::SetManuallySpecifiedSet( ...@@ -62,10 +64,10 @@ void PreloadedFirstPartySets::SetManuallySpecifiedSet(
ApplyManuallySpecifiedSet(); ApplyManuallySpecifiedSet();
} }
base::flat_map<std::string, std::string>* PreloadedFirstPartySets::ParseAndSet( base::flat_map<net::SchemefulSite, net::SchemefulSite>*
base::StringPiece raw_sets) { PreloadedFirstPartySets::ParseAndSet(base::StringPiece raw_sets) {
std::unique_ptr<base::flat_map<std::string, std::string>> parsed = std::unique_ptr<base::flat_map<net::SchemefulSite, net::SchemefulSite>>
FirstPartySetParser::ParsePreloadedSets(raw_sets); parsed = FirstPartySetParser::ParsePreloadedSets(raw_sets);
if (parsed) { if (parsed) {
sets_.swap(*parsed); sets_.swap(*parsed);
} else { } else {
...@@ -81,8 +83,8 @@ void PreloadedFirstPartySets::ApplyManuallySpecifiedSet() { ...@@ -81,8 +83,8 @@ void PreloadedFirstPartySets::ApplyManuallySpecifiedSet() {
if (!manually_specified_set_) if (!manually_specified_set_)
return; return;
const std::string& manual_owner = manually_specified_set_->first; const net::SchemefulSite& manual_owner = manually_specified_set_->first;
const base::flat_set<std::string>& manual_members = const base::flat_set<net::SchemefulSite>& manual_members =
manually_specified_set_->second; manually_specified_set_->second;
sets_.erase( sets_.erase(
...@@ -96,7 +98,7 @@ void PreloadedFirstPartySets::ApplyManuallySpecifiedSet() { ...@@ -96,7 +98,7 @@ void PreloadedFirstPartySets::ApplyManuallySpecifiedSet() {
sets_.end()); sets_.end());
// Next, we must add the manually-added set to the parsed value. // Next, we must add the manually-added set to the parsed value.
for (const std::string& member : manual_members) { for (const net::SchemefulSite& member : manual_members) {
sets_.emplace(member, manual_owner); sets_.emplace(member, manual_owner);
} }
} }
......
...@@ -11,6 +11,8 @@ ...@@ -11,6 +11,8 @@
#include "base/callback.h" #include "base/callback.h"
#include "base/containers/flat_map.h" #include "base/containers/flat_map.h"
#include "base/containers/flat_set.h" #include "base/containers/flat_set.h"
#include "net/base/schemeful_site.h"
#include "services/network/first_party_sets/first_party_set_parser.h"
namespace network { namespace network {
...@@ -35,7 +37,7 @@ class PreloadedFirstPartySets { ...@@ -35,7 +37,7 @@ class PreloadedFirstPartySets {
// //
// In case of invalid input, clears the current members-to-owners map, but // In case of invalid input, clears the current members-to-owners map, but
// keeps any manually-specified set (i.e. a set provided on the command line). // keeps any manually-specified set (i.e. a set provided on the command line).
base::flat_map<std::string, std::string>* ParseAndSet( base::flat_map<net::SchemefulSite, net::SchemefulSite>* ParseAndSet(
base::StringPiece raw_sets); base::StringPiece raw_sets);
int64_t size() const { return sets_.size(); } int64_t size() const { return sets_.size(); }
...@@ -48,8 +50,9 @@ class PreloadedFirstPartySets { ...@@ -48,8 +50,9 @@ class PreloadedFirstPartySets {
// `manually_specified_set_`. // `manually_specified_set_`.
void ApplyManuallySpecifiedSet(); void ApplyManuallySpecifiedSet();
base::flat_map<std::string, std::string> sets_; base::flat_map<net::SchemefulSite, net::SchemefulSite> sets_;
base::Optional<std::pair<std::string, base::flat_set<std::string>>> base::Optional<
std::pair<net::SchemefulSite, base::flat_set<net::SchemefulSite>>>
manually_specified_set_; manually_specified_set_;
}; };
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include "base/at_exit.h" #include "base/at_exit.h"
#include "base/i18n/icu_util.h" #include "base/i18n/icu_util.h"
#include "net/base/schemeful_site.h"
struct TestCase { struct TestCase {
TestCase() { CHECK(base::i18n::InitializeICU()); } TestCase() { CHECK(base::i18n::InitializeICU()); }
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include "base/at_exit.h" #include "base/at_exit.h"
#include "base/i18n/icu_util.h" #include "base/i18n/icu_util.h"
#include "net/base/schemeful_site.h"
#include "testing/libfuzzer/proto/json.pb.h" #include "testing/libfuzzer/proto/json.pb.h"
#include "testing/libfuzzer/proto/json_proto_converter.h" #include "testing/libfuzzer/proto/json_proto_converter.h"
#include "testing/libfuzzer/proto/lpm_interface.h" #include "testing/libfuzzer/proto/lpm_interface.h"
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment