Commit 3318cc99 authored by Piotr Pawliczek's avatar Piotr Pawliczek Committed by Commit Bot

mDNS cache: Switch index to case-insensitive order

In mDNS cache, names are compared in standard, case-sensitive way.
However, according to RFC 6762 mDNS names are case-insensitive. As a
result, chrome cannot resolve mDNS name if it has different
capitalization than a corresponding field in mDNS record returned by a
device. The only solution seems to be switching mDNS cache to a
case-insensitive index.

BUG=chromium:1108807
TEST=tested on atlas

Change-Id: Ia641e82fd4f175c39d1a80f52b1111238bf50b4e
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2315355
Auto-Submit: Piotr Pawliczek <pawliczek@chromium.org>
Commit-Queue: Eric Orth <ericorth@chromium.org>
Reviewed-by: default avatarEric Orth <ericorth@chromium.org>
Cr-Commit-Position: refs/heads/master@{#791348}
parent 02fac647
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <utility> #include <utility>
#include "base/strings/string_number_conversions.h" #include "base/strings/string_number_conversions.h"
#include "base/strings/string_util.h"
#include "net/dns/public/dns_protocol.h" #include "net/dns/public/dns_protocol.h"
#include "net/dns/record_parsed.h" #include "net/dns/record_parsed.h"
#include "net/dns/record_rdata.h" #include "net/dns/record_rdata.h"
...@@ -26,10 +27,12 @@ constexpr size_t kDefaultEntryLimit = 100'000; ...@@ -26,10 +27,12 @@ constexpr size_t kDefaultEntryLimit = 100'000;
// Section 10.1. // Section 10.1.
static const unsigned kZeroTTLSeconds = 1; static const unsigned kZeroTTLSeconds = 1;
MDnsCache::Key::Key(unsigned type, const std::string& name, MDnsCache::Key::Key(unsigned type,
const std::string& name,
const std::string& optional) const std::string& optional)
: type_(type), name_(name), optional_(optional) { : type_(type),
} name_lowercase_(base::ToLowerASCII(name)),
optional_(optional) {}
MDnsCache::Key::Key(const MDnsCache::Key& other) = default; MDnsCache::Key::Key(const MDnsCache::Key& other) = default;
...@@ -39,12 +42,13 @@ MDnsCache::Key& MDnsCache::Key::operator=(const MDnsCache::Key& other) = ...@@ -39,12 +42,13 @@ MDnsCache::Key& MDnsCache::Key::operator=(const MDnsCache::Key& other) =
MDnsCache::Key::~Key() = default; MDnsCache::Key::~Key() = default;
bool MDnsCache::Key::operator<(const MDnsCache::Key& other) const { bool MDnsCache::Key::operator<(const MDnsCache::Key& other) const {
return std::tie(name_, type_, optional_) < return std::tie(name_lowercase_, type_, optional_) <
std::tie(other.name_, other.type_, other.optional_); std::tie(other.name_lowercase_, other.type_, other.optional_);
} }
bool MDnsCache::Key::operator==(const MDnsCache::Key& key) const { bool MDnsCache::Key::operator==(const MDnsCache::Key& key) const {
return type_ == key.type_ && name_ == key.name_ && optional_ == key.optional_; return type_ == key.type_ && name_lowercase_ == key.name_lowercase_ &&
optional_ == key.optional_;
} }
// static // static
...@@ -133,9 +137,10 @@ void MDnsCache::FindDnsRecords(unsigned type, ...@@ -133,9 +137,10 @@ void MDnsCache::FindDnsRecords(unsigned type,
DCHECK(results); DCHECK(results);
results->clear(); results->clear();
const std::string name_lowercase = base::ToLowerASCII(name);
auto i = mdns_cache_.lower_bound(Key(type, name, "")); auto i = mdns_cache_.lower_bound(Key(type, name, ""));
for (; i != mdns_cache_.end(); ++i) { for (; i != mdns_cache_.end(); ++i) {
if (i->first.name() != name || if (i->first.name_lowercase() != name_lowercase ||
(type != 0 && i->first.type() != type)) { (type != 0 && i->first.type() != type)) {
break; break;
} }
......
...@@ -39,14 +39,14 @@ class NET_EXPORT_PRIVATE MDnsCache { ...@@ -39,14 +39,14 @@ class NET_EXPORT_PRIVATE MDnsCache {
bool operator==(const Key& key) const; bool operator==(const Key& key) const;
unsigned type() const { return type_; } unsigned type() const { return type_; }
const std::string& name() const { return name_; } const std::string& name_lowercase() const { return name_lowercase_; }
const std::string& optional() const { return optional_; } const std::string& optional() const { return optional_; }
// Create the cache key corresponding to |record|. // Create the cache key corresponding to |record|.
static Key CreateFor(const RecordParsed* record); static Key CreateFor(const RecordParsed* record);
private: private:
unsigned type_; unsigned type_;
std::string name_; std::string name_lowercase_;
std::string optional_; std::string optional_;
}; };
......
...@@ -99,6 +99,26 @@ static const uint8_t kTestResponsesGoodbyePacket[] = { ...@@ -99,6 +99,26 @@ static const uint8_t kTestResponsesGoodbyePacket[] = {
74, 125, 95, 121, // RDATA is the IP: 74.125.95.121 74, 125, 95, 121, // RDATA is the IP: 74.125.95.121
}; };
static const uint8_t kTestResponsesDifferentCapitalization[] = {
// Answer 1
// GHS.l.google.com in DNS format.
3, 'G', 'H', 'S', 1, 'l', 6, 'g', 'o', 'o', 'g', 'l', 'e', 3, 'c', 'o', 'm',
0x00, 0x00, 0x01, // TYPE is A.
0x00, 0x01, // CLASS is IN.
0, 0, 0, 53, // TTL (4 bytes) is 53 seconds.
0, 4, // RDLENGTH is 4 bytes.
74, 125, 95, 121, // RDATA is the IP: 74.125.95.121
// Answer 2
// ghs.l.GOOGLE.com in DNS format.
3, 'g', 'h', 's', 1, 'l', 6, 'G', 'O', 'O', 'G', 'L', 'E', 3, 'c', 'o', 'm',
0x00, 0x00, 0x01, // TYPE is A.
0x00, 0x01, // CLASS is IN.
0, 0, 0, 53, // TTL (4 bytes) is 53 seconds.
0, 4, // RDLENGTH is 4 bytes.
74, 125, 95, 122, // RDATA is the IP: 74.125.95.122
};
class RecordRemovalMock { class RecordRemovalMock {
public: public:
MOCK_METHOD1(OnRecordRemoved, void(const RecordParsed*)); MOCK_METHOD1(OnRecordRemoved, void(const RecordParsed*));
...@@ -402,4 +422,28 @@ TEST_F(MDnsCacheTest, ClearOnOverfilledCleanup) { ...@@ -402,4 +422,28 @@ TEST_F(MDnsCacheTest, ClearOnOverfilledCleanup) {
EXPECT_TRUE(results.empty()); EXPECT_TRUE(results.empty());
} }
TEST_F(MDnsCacheTest, CaseInsensitive) {
DnsRecordParser parser(kTestResponsesDifferentCapitalization,
sizeof(kTestResponsesDifferentCapitalization), 0);
std::unique_ptr<const RecordParsed> record1;
std::unique_ptr<const RecordParsed> record2;
std::vector<const RecordParsed*> results;
record1 = RecordParsed::CreateFrom(&parser, default_time_);
record2 = RecordParsed::CreateFrom(&parser, default_time_);
EXPECT_EQ(MDnsCache::RecordAdded, cache_.UpdateDnsRecord(std::move(record1)));
EXPECT_EQ(MDnsCache::RecordChanged,
cache_.UpdateDnsRecord(std::move(record2)));
cache_.FindDnsRecords(0, "ghs.l.google.com", &results, default_time_);
EXPECT_EQ(1u, results.size());
EXPECT_EQ("ghs.l.GOOGLE.com", results[0]->name());
std::vector<const RecordParsed*> results2;
cache_.FindDnsRecords(0, "GHS.L.google.COM", &results2, default_time_);
EXPECT_EQ(results, results2);
}
} // namespace net } // namespace net
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include "base/bind.h" #include "base/bind.h"
#include "base/location.h" #include "base/location.h"
#include "base/single_thread_task_runner.h" #include "base/single_thread_task_runner.h"
#include "base/strings/string_util.h"
#include "base/threading/thread_task_runner_handle.h" #include "base/threading/thread_task_runner_handle.h"
#include "base/time/clock.h" #include "base/time/clock.h"
#include "base/time/default_clock.h" #include "base/time/default_clock.h"
...@@ -316,9 +317,12 @@ void MDnsClientImpl::Core::NotifyNsecRecord(const RecordParsed* record) { ...@@ -316,9 +317,12 @@ void MDnsClientImpl::Core::NotifyNsecRecord(const RecordParsed* record) {
} }
// Alert all listeners waiting for the nonexistent RR types. // Alert all listeners waiting for the nonexistent RR types.
auto i = listeners_.upper_bound(ListenerKey(record->name(), 0)); ListenerKey key(record->name(), 0);
for (; i != listeners_.end() && i->first.first == record->name(); i++) { auto i = listeners_.upper_bound(key);
if (!rdata->GetBit(i->first.second)) { for (; i != listeners_.end() &&
i->first.name_lowercase() == key.name_lowercase();
i++) {
if (!rdata->GetBit(i->first.type())) {
for (auto& observer : *i->second) for (auto& observer : *i->second)
observer.AlertNsecRecord(); observer.AlertNsecRecord();
} }
...@@ -330,6 +334,17 @@ void MDnsClientImpl::Core::OnConnectionError(int error) { ...@@ -330,6 +334,17 @@ void MDnsClientImpl::Core::OnConnectionError(int error) {
VLOG(1) << "MDNS OnConnectionError (code: " << error << ")"; VLOG(1) << "MDNS OnConnectionError (code: " << error << ")";
} }
MDnsClientImpl::Core::ListenerKey::ListenerKey(const std::string& name,
uint16_t type)
: name_lowercase_(base::ToLowerASCII(name)), type_(type) {}
bool MDnsClientImpl::Core::ListenerKey::operator<(
const MDnsClientImpl::Core::ListenerKey& key) const {
if (name_lowercase_ == key.name_lowercase_)
return type_ < key.type_;
return name_lowercase_ < key.name_lowercase_;
}
void MDnsClientImpl::Core::AlertListeners( void MDnsClientImpl::Core::AlertListeners(
MDnsCache::UpdateType update_type, MDnsCache::UpdateType update_type,
const ListenerKey& key, const ListenerKey& key,
......
...@@ -154,7 +154,19 @@ class NET_EXPORT_PRIVATE MDnsClientImpl : public MDnsClient { ...@@ -154,7 +154,19 @@ class NET_EXPORT_PRIVATE MDnsClientImpl : public MDnsClient {
private: private:
FRIEND_TEST_ALL_PREFIXES(MDnsTest, CacheCleanupWithShortTTL); FRIEND_TEST_ALL_PREFIXES(MDnsTest, CacheCleanupWithShortTTL);
typedef std::pair<std::string, uint16_t> ListenerKey; class ListenerKey {
public:
ListenerKey(const std::string& name, uint16_t type);
ListenerKey(const ListenerKey&) = default;
ListenerKey(ListenerKey&&) = default;
bool operator<(const ListenerKey& key) const;
const std::string& name_lowercase() const { return name_lowercase_; }
uint16_t type() const { return type_; }
private:
std::string name_lowercase_;
uint16_t type_;
};
typedef base::ObserverList<MDnsListenerImpl>::Unchecked ObserverListType; typedef base::ObserverList<MDnsListenerImpl>::Unchecked ObserverListType;
typedef std::map<ListenerKey, std::unique_ptr<ObserverListType>> typedef std::map<ListenerKey, std::unique_ptr<ObserverListType>>
ListenerMap; ListenerMap;
......
...@@ -76,6 +76,32 @@ const uint8_t kSamplePacket1[] = { ...@@ -76,6 +76,32 @@ const uint8_t kSamplePacket1[] = {
0x24, 0x75, 0x00, 0x08, // RDLENGTH is 8 bytes. 0x24, 0x75, 0x00, 0x08, // RDLENGTH is 8 bytes.
0x05, 'h', 'e', 'l', 'l', 'o', 0xc0, 0x32}; 0x05, 'h', 'e', 'l', 'l', 'o', 0xc0, 0x32};
const uint8_t kSamplePacket1WithCapitalization[] = {
// Header
0x00, 0x00, // ID is zeroed out
0x81, 0x80, // Standard query response, RA, no error
0x00, 0x00, // No questions (for simplicity)
0x00, 0x02, // 2 RRs (answers)
0x00, 0x00, // 0 authority RRs
0x00, 0x00, // 0 additional RRs
// Answer 1
0x07, '_', 'p', 'r', 'i', 'v', 'e', 't', 0x04, '_', 'T', 'C', 'P', 0x05,
'l', 'o', 'c', 'a', 'l', 0x00, 0x00, 0x0c, // TYPE is PTR.
0x00, 0x01, // CLASS is IN.
0x00, 0x00, // TTL (4 bytes) is 1 second;
0x00, 0x01, 0x00, 0x08, // RDLENGTH is 8 bytes.
0x05, 'h', 'e', 'l', 'l', 'o', 0xc0, 0x0c,
// Answer 2
0x08, '_', 'P', 'r', 'i', 'n', 't', 'e', 'R', 0xc0,
0x14, // Pointer to "._tcp.local"
0x00, 0x0c, // TYPE is PTR.
0x00, 0x01, // CLASS is IN.
0x00, 0x01, // TTL (4 bytes) is 20 hours, 47 minutes, 49 seconds.
0x24, 0x75, 0x00, 0x08, // RDLENGTH is 8 bytes.
0x05, 'h', 'e', 'l', 'l', 'o', 0xc0, 0x32};
const uint8_t kCorruptedPacketBadQuestion[] = { const uint8_t kCorruptedPacketBadQuestion[] = {
// Header // Header
0x00, 0x00, // ID is zeroed out 0x00, 0x00, // ID is zeroed out
...@@ -248,6 +274,22 @@ const uint8_t kQueryPacketPrivet[] = { ...@@ -248,6 +274,22 @@ const uint8_t kQueryPacketPrivet[] = {
0x00, 0x01, // CLASS is IN. 0x00, 0x01, // CLASS is IN.
}; };
const uint8_t kQueryPacketPrivetWithCapitalization[] = {
// Header
0x00, 0x00, // ID is zeroed out
0x00, 0x00, // No flags.
0x00, 0x01, // One question.
0x00, 0x00, // 0 RRs (answers)
0x00, 0x00, // 0 authority RRs
0x00, 0x00, // 0 additional RRs
// Question
// This part is echoed back from the respective query.
0x07, '_', 'P', 'R', 'I', 'V', 'E', 'T', 0x04, '_', 't', 'c', 'p', 0x05,
'l', 'o', 'c', 'a', 'l', 0x00, 0x00, 0x0c, // TYPE is PTR.
0x00, 0x01, // CLASS is IN.
};
const uint8_t kQueryPacketPrivetA[] = { const uint8_t kQueryPacketPrivetA[] = {
// Header // Header
0x00, 0x00, // ID is zeroed out 0x00, 0x00, // ID is zeroed out
...@@ -525,6 +567,48 @@ TEST_F(MDnsTest, PassiveListeners) { ...@@ -525,6 +567,48 @@ TEST_F(MDnsTest, PassiveListeners) {
listener_printer.reset(); listener_printer.reset();
} }
TEST_F(MDnsTest, PassiveListenersWithCapitalization) {
StrictMock<MockListenerDelegate> delegate_privet;
StrictMock<MockListenerDelegate> delegate_printer;
PtrRecordCopyContainer record_privet;
PtrRecordCopyContainer record_printer;
std::unique_ptr<MDnsListener> listener_privet = test_client_->CreateListener(
dns_protocol::kTypePTR, "_privet._tcp.LOCAL", &delegate_privet);
std::unique_ptr<MDnsListener> listener_printer = test_client_->CreateListener(
dns_protocol::kTypePTR, "_prinTER._Tcp.Local", &delegate_printer);
ASSERT_TRUE(listener_privet->Start());
ASSERT_TRUE(listener_printer->Start());
// Send the same packet twice to ensure no records are double-counted.
EXPECT_CALL(delegate_privet, OnRecordUpdate(MDnsListener::RECORD_ADDED, _))
.Times(Exactly(1))
.WillOnce(
Invoke(&record_privet, &PtrRecordCopyContainer::SaveWithDummyArg));
EXPECT_CALL(delegate_printer, OnRecordUpdate(MDnsListener::RECORD_ADDED, _))
.Times(Exactly(1))
.WillOnce(
Invoke(&record_printer, &PtrRecordCopyContainer::SaveWithDummyArg));
SimulatePacketReceive(kSamplePacket1WithCapitalization,
sizeof(kSamplePacket1WithCapitalization));
SimulatePacketReceive(kSamplePacket1WithCapitalization,
sizeof(kSamplePacket1WithCapitalization));
EXPECT_TRUE(record_privet.IsRecordWith("_privet._TCP.local",
"hello._privet._TCP.local"));
EXPECT_TRUE(record_printer.IsRecordWith("_PrinteR._TCP.local",
"hello._PrinteR._TCP.local"));
listener_privet.reset();
listener_printer.reset();
}
TEST_F(MDnsTest, PassiveListenersCacheCleanup) { TEST_F(MDnsTest, PassiveListenersCacheCleanup) {
StrictMock<MockListenerDelegate> delegate_privet; StrictMock<MockListenerDelegate> delegate_privet;
...@@ -709,6 +793,34 @@ TEST_F(MDnsTest, TransactionWithEmptyCache) { ...@@ -709,6 +793,34 @@ TEST_F(MDnsTest, TransactionWithEmptyCache) {
"hello._privet._tcp.local")); "hello._privet._tcp.local"));
} }
TEST_F(MDnsTest, TransactionWithEmptyCacheAndCapitalization) {
ExpectPacket(kQueryPacketPrivetWithCapitalization,
sizeof(kQueryPacketPrivetWithCapitalization));
std::unique_ptr<MDnsTransaction> transaction_privet =
test_client_->CreateTransaction(
dns_protocol::kTypePTR, "_PRIVET._tcp.local",
MDnsTransaction::QUERY_NETWORK | MDnsTransaction::QUERY_CACHE |
MDnsTransaction::SINGLE_RESULT,
base::BindRepeating(&MDnsTest::MockableRecordCallback,
base::Unretained(this)));
ASSERT_TRUE(transaction_privet->Start());
PtrRecordCopyContainer record_privet;
EXPECT_CALL(*this, MockableRecordCallback(MDnsTransaction::RESULT_RECORD, _))
.Times(Exactly(1))
.WillOnce(
Invoke(&record_privet, &PtrRecordCopyContainer::SaveWithDummyArg));
SimulatePacketReceive(kSamplePacket1WithCapitalization,
sizeof(kSamplePacket1WithCapitalization));
EXPECT_TRUE(record_privet.IsRecordWith("_privet._TCP.local",
"hello._privet._TCP.local"));
}
TEST_F(MDnsTest, TransactionCacheOnlyNoResult) { TEST_F(MDnsTest, TransactionCacheOnlyNoResult) {
std::unique_ptr<MDnsTransaction> transaction_privet = std::unique_ptr<MDnsTransaction> transaction_privet =
test_client_->CreateTransaction( test_client_->CreateTransaction(
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment