Commit 9871aafa authored by Eric Orth's avatar Eric Orth Committed by Commit Bot

Implement MULTICAST_DNS source for new HostResolver API.

Only used when explicitly requested. The MDnsClient is created and
starts listening for MDNS messages at first use in HostResolverImpl.

For now supports only a basic feature set, eg we'll DCHECK if canonname
is requested (in contrast to the async resolver behavior that was
unnoticed until recently due to the lack of a DCHECK).

Bug: 846423
Cq-Include-Trybots: luci.chromium.try:linux_mojo
Change-Id: Ib296ac432d392baee2454587693b2537513f461c
Reviewed-on: https://chromium-review.googlesource.com/1211542Reviewed-by: default avatarMatt Menke <mmenke@chromium.org>
Reviewed-by: default avatarTom Sepez <tsepez@chromium.org>
Commit-Queue: Eric Orth <ericorth@chromium.org>
Cr-Commit-Position: refs/heads/master@{#595956}
parent ee7c6129
...@@ -53,6 +53,8 @@ source_set("dns") { ...@@ -53,6 +53,8 @@ source_set("dns") {
"host_cache.cc", "host_cache.cc",
"host_resolver.cc", "host_resolver.cc",
"host_resolver_impl.cc", "host_resolver_impl.cc",
"host_resolver_mdns_task.cc",
"host_resolver_mdns_task.h",
"host_resolver_proc.cc", "host_resolver_proc.cc",
"host_resolver_proc.h", "host_resolver_proc.h",
"host_resolver_source.h", "host_resolver_source.h",
...@@ -459,8 +461,14 @@ source_set("test_support") { ...@@ -459,8 +461,14 @@ source_set("test_support") {
] ]
if (enable_mdns) { if (enable_mdns) {
sources += [ "mock_mdns_socket_factory.cc" ] sources += [
public += [ "mock_mdns_socket_factory.h" ] "mock_mdns_client.cc",
"mock_mdns_socket_factory.cc",
]
public += [
"mock_mdns_client.h",
"mock_mdns_socket_factory.h",
]
} }
deps = [ deps = [
......
...@@ -64,7 +64,9 @@ ...@@ -64,7 +64,9 @@
#include "net/dns/dns_response.h" #include "net/dns/dns_response.h"
#include "net/dns/dns_transaction.h" #include "net/dns/dns_transaction.h"
#include "net/dns/dns_util.h" #include "net/dns/dns_util.h"
#include "net/dns/host_resolver_mdns_task.h"
#include "net/dns/host_resolver_proc.h" #include "net/dns/host_resolver_proc.h"
#include "net/dns/mdns_client.h"
#include "net/log/net_log.h" #include "net/log/net_log.h"
#include "net/log/net_log_capture_mode.h" #include "net/log/net_log_capture_mode.h"
#include "net/log/net_log_event_type.h" #include "net/log/net_log_event_type.h"
...@@ -76,6 +78,10 @@ ...@@ -76,6 +78,10 @@
#include "net/socket/datagram_client_socket.h" #include "net/socket/datagram_client_socket.h"
#include "url/url_canon_ip.h" #include "url/url_canon_ip.h"
#if BUILDFLAG(ENABLE_MDNS)
#include "net/dns/mdns_client_impl.h"
#endif
#if defined(OS_WIN) #if defined(OS_WIN)
#include "net/base/winsock_init.h" #include "net/base/winsock_init.h"
#endif #endif
...@@ -1512,7 +1518,7 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job, ...@@ -1512,7 +1518,7 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job,
// This will destroy the Job. // This will destroy the Job.
CompleteRequests( CompleteRequests(
MakeCacheEntry(OK, addr_list, HostCache::Entry::SOURCE_HOSTS), MakeCacheEntry(OK, addr_list, HostCache::Entry::SOURCE_HOSTS),
base::TimeDelta()); base::TimeDelta(), true /* allow_cache */);
return true; return true;
} }
return false; return false;
...@@ -1525,7 +1531,7 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job, ...@@ -1525,7 +1531,7 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job,
} }
bool is_running() const { bool is_running() const {
return is_dns_running() || is_proc_running(); return is_dns_running() || is_mdns_running() || is_proc_running();
} }
private: private:
...@@ -1617,7 +1623,8 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job, ...@@ -1617,7 +1623,8 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job,
switch (key_.host_resolver_source) { switch (key_.host_resolver_source) {
case HostResolverSource::ANY: case HostResolverSource::ANY:
if (resolver_->HaveDnsConfig() && if (resolver_->HaveDnsConfig() &&
!ResemblesMulticastDNSName(key_.hostname)) { !ResemblesMulticastDNSName(key_.hostname) &&
!(key_.host_resolver_flags & HOST_RESOLVER_CANONNAME)) {
StartDnsTask(); StartDnsTask();
} else { } else {
StartProcTask(); StartProcTask();
...@@ -1633,6 +1640,9 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job, ...@@ -1633,6 +1640,9 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job,
StartDnsTask(); StartDnsTask();
break; break;
case HostResolverSource::MULTICAST_DNS:
StartMdnsTask();
break;
} }
// Caution: Job::Start must not complete synchronously. // Caution: Job::Start must not complete synchronously.
...@@ -1643,7 +1653,7 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job, ...@@ -1643,7 +1653,7 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job,
// TaskScheduler threads low, we will need to use an "inner" // TaskScheduler threads low, we will need to use an "inner"
// PrioritizedDispatcher with tighter limits. // PrioritizedDispatcher with tighter limits.
void StartProcTask() { void StartProcTask() {
DCHECK(!is_dns_running()); DCHECK(!is_running());
proc_task_ = std::make_unique<ProcTask>( proc_task_ = std::make_unique<ProcTask>(
key_, resolver_->proc_params_, key_, resolver_->proc_params_,
base::BindOnce(&Job::OnProcTaskComplete, base::Unretained(this), base::BindOnce(&Job::OnProcTaskComplete, base::Unretained(this),
...@@ -1693,7 +1703,7 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job, ...@@ -1693,7 +1703,7 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job,
// Don't store the |ttl| in cache since it's not obtained from the server. // Don't store the |ttl| in cache since it's not obtained from the server.
CompleteRequests( CompleteRequests(
MakeCacheEntry(net_error, addr_list, HostCache::Entry::SOURCE_UNKNOWN), MakeCacheEntry(net_error, addr_list, HostCache::Entry::SOURCE_UNKNOWN),
ttl); ttl, true /* allow_cache */);
} }
void StartDnsTask() { void StartDnsTask() {
...@@ -1751,7 +1761,7 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job, ...@@ -1751,7 +1761,7 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job,
CompleteRequests( CompleteRequests(
HostCache::Entry(net_error, AddressList(), HostCache::Entry(net_error, AddressList(),
HostCache::Entry::Source::SOURCE_UNKNOWN, ttl), HostCache::Entry::Source::SOURCE_UNKNOWN, ttl),
ttl); ttl, true /* allow_cache */);
} }
} }
...@@ -1784,7 +1794,7 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job, ...@@ -1784,7 +1794,7 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job,
} else { } else {
CompleteRequests(MakeCacheEntryWithTTL(net_error, addr_list, CompleteRequests(MakeCacheEntryWithTTL(net_error, addr_list,
HostCache::Entry::SOURCE_DNS, ttl), HostCache::Entry::SOURCE_DNS, ttl),
bounded_ttl); bounded_ttl, true /* allow_cache */);
} }
} }
...@@ -1801,6 +1811,50 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job, ...@@ -1801,6 +1811,50 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job,
dns_task_->StartSecondTransaction(); dns_task_->StartSecondTransaction();
} }
void StartMdnsTask() {
DCHECK(!is_running());
// No flags are supported for MDNS except
// HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6 (which is not actually an
// input flag).
DCHECK_EQ(0, key_.host_resolver_flags &
~HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6);
std::vector<HostResolver::DnsQueryType> query_types;
switch (key_.address_family) {
case ADDRESS_FAMILY_UNSPECIFIED:
query_types.push_back(HostResolver::DnsQueryType::A);
query_types.push_back(HostResolver::DnsQueryType::AAAA);
break;
case ADDRESS_FAMILY_IPV4:
query_types.push_back(HostResolver::DnsQueryType::A);
break;
case ADDRESS_FAMILY_IPV6:
query_types.push_back(HostResolver::DnsQueryType::AAAA);
break;
}
mdns_task_ = std::make_unique<HostResolverMdnsTask>(
resolver_->GetOrCreateMdnsClient(), key_.hostname, query_types);
mdns_task_->Start(
base::BindOnce(&Job::OnMdnsTaskComplete, base::Unretained(this)));
}
void OnMdnsTaskComplete(int error) {
DCHECK(is_mdns_running());
// TODO(crbug.com/846423): Consider adding MDNS-specific logging.
if (error != OK) {
CompleteRequestsWithError(error);
} else if (ContainsIcannNameCollisionIp(mdns_task_->result_addresses())) {
CompleteRequestsWithError(ERR_ICANN_NAME_COLLISION);
} else {
// MDNS uses a separate cache, so skip saving result to cache.
// TODO(crbug.com/846423): Consider merging caches.
CompleteRequestsWithoutCache(error, mdns_task_->result_addresses());
}
}
URLRequestContext* url_request_context() override { URLRequestContext* url_request_context() override {
return resolver_->url_request_context_; return resolver_->url_request_context_;
} }
...@@ -1880,8 +1934,12 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job, ...@@ -1880,8 +1934,12 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job,
} }
// Performs Job's last rites. Completes all Requests. Deletes this. // Performs Job's last rites. Completes all Requests. Deletes this.
//
// If not |allow_cache|, result will not be stored in the host cache, even if
// result would otherwise allow doing so.
void CompleteRequests(const HostCache::Entry& entry, void CompleteRequests(const HostCache::Entry& entry,
base::TimeDelta ttl) { base::TimeDelta ttl,
bool allow_cache) {
CHECK(resolver_.get()); CHECK(resolver_.get());
// This job must be removed from resolver's |jobs_| now to make room for a // This job must be removed from resolver's |jobs_| now to make room for a
...@@ -1893,6 +1951,7 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job, ...@@ -1893,6 +1951,7 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job,
if (is_running()) { if (is_running()) {
proc_task_ = nullptr; proc_task_ = nullptr;
KillDnsTask(); KillDnsTask();
mdns_task_ = nullptr;
// Signal dispatcher that a slot has opened. // Signal dispatcher that a slot has opened.
resolver_->dispatcher_->OnJobFinished(); resolver_->dispatcher_->OnJobFinished();
...@@ -1922,7 +1981,7 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job, ...@@ -1922,7 +1981,7 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job,
bool did_complete = (entry.error() != ERR_NETWORK_CHANGED) && bool did_complete = (entry.error() != ERR_NETWORK_CHANGED) &&
(entry.error() != ERR_HOST_RESOLVER_QUEUE_TOO_LARGE); (entry.error() != ERR_HOST_RESOLVER_QUEUE_TOO_LARGE);
if (did_complete) if (did_complete && allow_cache)
resolver_->CacheResult(key_, entry, ttl); resolver_->CacheResult(key_, entry, ttl);
RecordJobHistograms(entry.error()); RecordJobHistograms(entry.error());
...@@ -1954,11 +2013,17 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job, ...@@ -1954,11 +2013,17 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job,
} }
} }
void CompleteRequestsWithoutCache(int error, const AddressList& addresses) {
CompleteRequests(
MakeCacheEntry(error, addresses, HostCache::Entry::SOURCE_UNKNOWN),
base::TimeDelta(), false /* allow_cache */);
}
// Convenience wrapper for CompleteRequests in case of failure. // Convenience wrapper for CompleteRequests in case of failure.
void CompleteRequestsWithError(int net_error) { void CompleteRequestsWithError(int net_error) {
CompleteRequests(HostCache::Entry(net_error, AddressList(), CompleteRequests(HostCache::Entry(net_error, AddressList(),
HostCache::Entry::SOURCE_UNKNOWN), HostCache::Entry::SOURCE_UNKNOWN),
base::TimeDelta()); base::TimeDelta(), true /* allow_cache */);
} }
RequestPriority priority() const override { RequestPriority priority() const override {
...@@ -1972,6 +2037,8 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job, ...@@ -1972,6 +2037,8 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job,
bool is_dns_running() const { return !!dns_task_; } bool is_dns_running() const { return !!dns_task_; }
bool is_mdns_running() const { return !!mdns_task_; }
bool is_proc_running() const { return !!proc_task_; } bool is_proc_running() const { return !!proc_task_; }
base::WeakPtr<HostResolverImpl> resolver_; base::WeakPtr<HostResolverImpl> resolver_;
...@@ -2005,6 +2072,9 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job, ...@@ -2005,6 +2072,9 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job,
// Resolves the host using a DnsTransaction. // Resolves the host using a DnsTransaction.
std::unique_ptr<DnsTask> dns_task_; std::unique_ptr<DnsTask> dns_task_;
// Resolves the host using MDnsClient.
std::unique_ptr<HostResolverMdnsTask> mdns_task_;
// All Requests waiting for the result of this Job. Some can be canceled. // All Requests waiting for the result of this Job. Some can be canceled.
base::LinkedList<RequestImpl> requests_; base::LinkedList<RequestImpl> requests_;
...@@ -2311,6 +2381,17 @@ void HostResolverImpl::SetHaveOnlyLoopbackAddresses(bool result) { ...@@ -2311,6 +2381,17 @@ void HostResolverImpl::SetHaveOnlyLoopbackAddresses(bool result) {
} }
} }
void HostResolverImpl::SetMdnsSocketFactoryForTesting(
std::unique_ptr<MDnsSocketFactory> socket_factory) {
DCHECK(!mdns_client_);
mdns_socket_factory_ = std::move(socket_factory);
}
void HostResolverImpl::SetMdnsClientForTesting(
std::unique_ptr<MDnsClient> client) {
mdns_client_ = std::move(client);
}
void HostResolverImpl::SetTaskRunnerForTesting( void HostResolverImpl::SetTaskRunnerForTesting(
scoped_refptr<base::TaskRunner> task_runner) { scoped_refptr<base::TaskRunner> task_runner) {
proc_task_runner_ = std::move(task_runner); proc_task_runner_ = std::move(task_runner);
...@@ -2321,6 +2402,11 @@ int HostResolverImpl::Resolve(RequestImpl* request) { ...@@ -2321,6 +2402,11 @@ int HostResolverImpl::Resolve(RequestImpl* request) {
DCHECK(!request->job()); DCHECK(!request->job());
// Request may only be resolved once. // Request may only be resolved once.
DCHECK(!request->complete()); DCHECK(!request->complete());
// MDNS requests do not support skipping cache.
// TODO(crbug.com/846423): Either add support for skipping the MDNS cache, or
// merge to use the normal host cache for MDNS requests.
DCHECK(request->parameters().source != HostResolverSource::MULTICAST_DNS ||
request->parameters().allow_cached_response);
request->set_request_time(tick_clock_->NowTicks()); request->set_request_time(tick_clock_->NowTicks());
...@@ -2866,6 +2952,25 @@ void HostResolverImpl::OnDnsTaskResolve(int net_error) { ...@@ -2866,6 +2952,25 @@ void HostResolverImpl::OnDnsTaskResolve(int net_error) {
std::abs(net_error)); std::abs(net_error));
} }
MDnsClient* HostResolverImpl::GetOrCreateMdnsClient() {
#if BUILDFLAG(ENABLE_MDNS)
if (!mdns_client_) {
if (!mdns_socket_factory_)
mdns_socket_factory_ = std::make_unique<MDnsSocketFactoryImpl>(net_log_);
mdns_client_ = MDnsClient::CreateDefault();
mdns_client_->StartListening(mdns_socket_factory_.get());
}
DCHECK(mdns_client_->IsListening());
return mdns_client_.get();
#else
// Should not request MDNS resoltuion unless MDNS is enabled.
NOTREACHED();
return nullptr;
#endif
}
HostResolverImpl::RequestImpl::~RequestImpl() { HostResolverImpl::RequestImpl::~RequestImpl() {
if (job_) if (job_)
job_->CancelRequest(this); job_->CancelRequest(this);
......
...@@ -35,6 +35,8 @@ namespace net { ...@@ -35,6 +35,8 @@ namespace net {
class AddressList; class AddressList;
class DnsClient; class DnsClient;
class IPAddress; class IPAddress;
class MDnsClient;
class MDnsSocketFactory;
class NetLog; class NetLog;
class NetLogWithSource; class NetLogWithSource;
...@@ -187,6 +189,10 @@ class NET_EXPORT HostResolverImpl ...@@ -187,6 +189,10 @@ class NET_EXPORT HostResolverImpl
// Only allowed when the queue is empty. // Only allowed when the queue is empty.
void SetMaxQueuedJobsForTesting(size_t value); void SetMaxQueuedJobsForTesting(size_t value);
void SetMdnsSocketFactoryForTesting(
std::unique_ptr<MDnsSocketFactory> socket_factory);
void SetMdnsClientForTesting(std::unique_ptr<MDnsClient> client);
protected: protected:
// Callback from HaveOnlyLoopbackAddresses probe. // Callback from HaveOnlyLoopbackAddresses probe.
void SetHaveOnlyLoopbackAddresses(bool result); void SetHaveOnlyLoopbackAddresses(bool result);
...@@ -345,6 +351,8 @@ class NET_EXPORT HostResolverImpl ...@@ -345,6 +351,8 @@ class NET_EXPORT HostResolverImpl
// and resulted in |net_error|. // and resulted in |net_error|.
void OnDnsTaskResolve(int net_error); void OnDnsTaskResolve(int net_error);
MDnsClient* GetOrCreateMdnsClient();
// Allows the tests to catch slots leaking out of the dispatcher. One // Allows the tests to catch slots leaking out of the dispatcher. One
// HostResolverImpl::Job could occupy multiple PrioritizedDispatcher job // HostResolverImpl::Job could occupy multiple PrioritizedDispatcher job
// slots. // slots.
...@@ -355,6 +363,11 @@ class NET_EXPORT HostResolverImpl ...@@ -355,6 +363,11 @@ class NET_EXPORT HostResolverImpl
// Cache of host resolution results. // Cache of host resolution results.
std::unique_ptr<HostCache> cache_; std::unique_ptr<HostCache> cache_;
// Used for multicast DNS tasks. Created on first use using
// GetOrCreateMndsClient().
std::unique_ptr<MDnsSocketFactory> mdns_socket_factory_;
std::unique_ptr<MDnsClient> mdns_client_;
// Map from HostCache::Key to a Job. // Map from HostCache::Key to a Job.
JobMap jobs_; JobMap jobs_;
......
This diff is collapsed.
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net/dns/host_resolver_mdns_task.h"
#include <algorithm>
#include <utility>
#include "base/logging.h"
#include "net/base/ip_endpoint.h"
#include "net/base/net_errors.h"
#include "net/dns/dns_protocol.h"
#include "net/dns/record_parsed.h"
#include "net/dns/record_rdata.h"
namespace net {
class HostResolverMdnsTask::Transaction {
public:
Transaction(HostResolver::DnsQueryType query_type, HostResolverMdnsTask* task)
: query_type_(query_type), result_(ERR_IO_PENDING), task_(task) {}
void Start() {
DCHECK_CALLED_ON_VALID_SEQUENCE(task_->sequence_checker_);
// Should not be completed or running yet.
DCHECK_EQ(ERR_IO_PENDING, result_);
DCHECK(!async_transaction_);
uint16_t rrtype;
switch (query_type_) {
case net::HostResolver::DnsQueryType::A:
rrtype = net::dns_protocol::kTypeA;
break;
case net::HostResolver::DnsQueryType::AAAA:
rrtype = net::dns_protocol::kTypeAAAA;
break;
default:
// Type not supported for MDNS.
NOTREACHED();
return;
}
// TODO(crbug.com/846423): Use |allow_cached_response| to set the
// QUERY_CACHE flag or not.
int flags = MDnsTransaction::SINGLE_RESULT | MDnsTransaction::QUERY_CACHE |
MDnsTransaction::QUERY_NETWORK;
// If |this| is destroyed, destruction of |internal_transaction_| should
// cancel and prevent invocation of OnComplete.
std::unique_ptr<MDnsTransaction> inner_transaction =
task_->mdns_client_->CreateTransaction(
rrtype, task_->hostname_, flags,
base::BindRepeating(&HostResolverMdnsTask::Transaction::OnComplete,
base::Unretained(this)));
// Side effect warning: Start() may finish and invoke callbacks inline.
bool start_result = inner_transaction->Start();
if (!start_result)
task_->CompleteWithResult(ERR_FAILED, true /* post_needed */);
else if (result_ == ERR_IO_PENDING)
async_transaction_ = std::move(inner_transaction);
}
bool IsDone() const { return result_ != ERR_IO_PENDING; }
bool IsError() const {
return IsDone() && result_ != OK && result_ != ERR_NAME_NOT_RESOLVED;
}
int result() const { return result_; }
void Cancel() {
DCHECK_CALLED_ON_VALID_SEQUENCE(task_->sequence_checker_);
DCHECK_EQ(ERR_IO_PENDING, result_);
result_ = ERR_FAILED;
async_transaction_ = nullptr;
}
private:
void OnComplete(MDnsTransaction::Result result, const RecordParsed* parsed) {
DCHECK_CALLED_ON_VALID_SEQUENCE(task_->sequence_checker_);
DCHECK_EQ(ERR_IO_PENDING, result_);
switch (result) {
case MDnsTransaction::RESULT_RECORD:
result_ = OK;
break;
case MDnsTransaction::RESULT_NO_RESULTS:
case MDnsTransaction::RESULT_NSEC:
result_ = ERR_NAME_NOT_RESOLVED;
break;
default:
// No other results should be possible with the request flags used.
NOTREACHED();
}
if (result_ == net::OK) {
switch (query_type_) {
case net::HostResolver::DnsQueryType::A:
task_->result_addresses_.push_back(
IPEndPoint(parsed->rdata<net::ARecordRdata>()->address(), 0));
break;
case net::HostResolver::DnsQueryType::AAAA:
task_->result_addresses_.push_back(
IPEndPoint(parsed->rdata<net::AAAARecordRdata>()->address(), 0));
break;
default:
NOTREACHED();
}
}
// If we don't have a saved async_transaction, it means OnComplete was
// invoked inline in MDnsTransaction::Start. Callbacks will need to be
// invoked via post.
task_->CheckCompletion(!async_transaction_);
}
const HostResolver::DnsQueryType query_type_;
// ERR_IO_PENDING until transaction completes (or is cancelled).
int result_;
// Not saved until MDnsTransaction::Start completes to differentiate inline
// completion.
std::unique_ptr<MDnsTransaction> async_transaction_;
// Back pointer. Expected to destroy |this| before destroying itself.
HostResolverMdnsTask* const task_;
};
HostResolverMdnsTask::HostResolverMdnsTask(
MDnsClient* mdns_client,
const std::string& hostname,
const std::vector<HostResolver::DnsQueryType>& query_types)
: mdns_client_(mdns_client), hostname_(hostname), weak_ptr_factory_(this) {
DCHECK(!query_types.empty());
for (HostResolver::DnsQueryType query_type : query_types) {
transactions_.emplace_back(query_type, this);
}
}
HostResolverMdnsTask::~HostResolverMdnsTask() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
transactions_.clear();
}
void HostResolverMdnsTask::Start(CompletionOnceCallback completion_callback) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
DCHECK(!completion_callback_);
completion_callback_ = std::move(completion_callback);
for (auto& transaction : transactions_) {
// Only start transaction if it is not already marked done. A transaction
// could be marked done before starting if it is preemptively canceled by
// a previously started transaction finishing with an error.
if (!transaction.IsDone())
transaction.Start();
}
}
void HostResolverMdnsTask::CheckCompletion(bool post_needed) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
// Finish immediately if any transactions completed with an error.
auto found_error =
std::find_if(transactions_.begin(), transactions_.end(),
[](const Transaction& t) { return t.IsError(); });
if (found_error != transactions_.end()) {
CompleteWithResult(found_error->result(), post_needed);
return;
}
if (std::all_of(transactions_.begin(), transactions_.end(),
[](const Transaction& t) { return t.IsDone(); })) {
// Task is overall successful if any of the transactions found results.
int result = result_addresses_.empty() ? ERR_NAME_NOT_RESOLVED : OK;
CompleteWithResult(result, post_needed);
return;
}
}
void HostResolverMdnsTask::CompleteWithResult(int result, bool post_needed) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
// Cancel any incomplete async transactions.
for (auto& transaction : transactions_) {
if (!transaction.IsDone())
transaction.Cancel();
}
if (post_needed) {
base::SequencedTaskRunnerHandle::Get()->PostTask(
FROM_HERE,
base::BindOnce(
[](base::WeakPtr<HostResolverMdnsTask> task, int result) {
if (task)
std::move(task->completion_callback_).Run(result);
},
weak_ptr_factory_.GetWeakPtr(), result));
} else {
std::move(completion_callback_).Run(result);
}
}
} // namespace net
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef NET_DNS_HOST_RESOLVER_MDNS_TASK_H_
#define NET_DNS_HOST_RESOLVER_MDNS_TASK_H_
#include <memory>
#include <string>
#include <vector>
#include "base/containers/unique_ptr_adapters.h"
#include "base/macros.h"
#include "base/memory/weak_ptr.h"
#include "base/sequence_checker.h"
#include "net/base/completion_once_callback.h"
#include "net/dns/host_resolver.h"
#include "net/dns/mdns_client.h"
namespace net {
// Representation of a single HostResolverImpl::Job task to resolve the hostname
// using multicast DNS transactions. Destruction cancels the task and prevents
// any callbacks from being invoked.
class HostResolverMdnsTask {
public:
// |mdns_client| must outlive |this|.
HostResolverMdnsTask(
MDnsClient* mdns_client,
const std::string& hostname,
const std::vector<HostResolver::DnsQueryType>& query_types);
~HostResolverMdnsTask();
// Starts the task. |completion_callback| will be called asynchronously with
// results.
//
// Should only be called once.
void Start(CompletionOnceCallback completion_callback);
const AddressList& result_addresses() { return result_addresses_; }
private:
class Transaction;
void CheckCompletion(bool post_needed);
void CompleteWithResult(int result, bool post_needed);
MDnsClient* const mdns_client_;
const std::string hostname_;
AddressList result_addresses_;
std::vector<Transaction> transactions_;
CompletionOnceCallback completion_callback_;
SEQUENCE_CHECKER(sequence_checker_);
base::WeakPtrFactory<HostResolverMdnsTask> weak_ptr_factory_;
DISALLOW_COPY_AND_ASSIGN(HostResolverMdnsTask);
};
} // namespace net
#endif // NET_DNS_HOST_RESOLVER_MDNS_TASK_H_
...@@ -21,7 +21,8 @@ enum class HostResolverSource { ...@@ -21,7 +21,8 @@ enum class HostResolverSource {
// Results will only come from DNS queries. // Results will only come from DNS queries.
DNS, DNS,
// TODO(crbug.com/846423): Add MDNS support. // Results will only come from Multicast DNS queries.
MULTICAST_DNS,
}; };
} // namespace net } // namespace net
......
...@@ -45,6 +45,9 @@ int Bind(const IPEndPoint& multicast_addr, ...@@ -45,6 +45,9 @@ int Bind(const IPEndPoint& multicast_addr,
} // namespace } // namespace
const base::TimeDelta MDnsTransaction::kTransactionTimeout =
base::TimeDelta::FromSeconds(3);
// static // static
std::unique_ptr<MDnsSocketFactory> MDnsSocketFactory::CreateDefault() { std::unique_ptr<MDnsSocketFactory> MDnsSocketFactory::CreateDefault() {
return std::unique_ptr<MDnsSocketFactory>(new MDnsSocketFactoryImpl); return std::unique_ptr<MDnsSocketFactory>(new MDnsSocketFactoryImpl);
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include <vector> #include <vector>
#include "base/callback.h" #include "base/callback.h"
#include "base/time/time.h"
#include "net/base/ip_endpoint.h" #include "net/base/ip_endpoint.h"
#include "net/base/net_export.h" #include "net/base/net_export.h"
#include "net/dns/dns_query.h" #include "net/dns/dns_query.h"
...@@ -32,6 +33,8 @@ class RecordParsed; ...@@ -32,6 +33,8 @@ class RecordParsed;
// time out after a reasonable number of seconds. // time out after a reasonable number of seconds.
class NET_EXPORT MDnsTransaction { class NET_EXPORT MDnsTransaction {
public: public:
static const base::TimeDelta kTransactionTimeout;
// Used to signify what type of result the transaction has received. // Used to signify what type of result the transaction has received.
enum Result { enum Result {
// Passed whenever a record is found. // Passed whenever a record is found.
......
...@@ -31,7 +31,6 @@ namespace net { ...@@ -31,7 +31,6 @@ namespace net {
namespace { namespace {
const unsigned MDnsTransactionTimeoutSeconds = 3;
// The fractions of the record's original TTL after which an active listener // The fractions of the record's original TTL after which an active listener
// (one that had |SetActiveRefresh(true)| called) will send a query to refresh // (one that had |SetActiveRefresh(true)| called) will send a query to refresh
// its cache. This happens both at 85% of the original TTL and again at 95% of // its cache. This happens both at 85% of the original TTL and again at 95% of
...@@ -48,7 +47,7 @@ void MDnsSocketFactoryImpl::CreateSockets( ...@@ -48,7 +47,7 @@ void MDnsSocketFactoryImpl::CreateSockets(
DCHECK(interfaces[i].second == ADDRESS_FAMILY_IPV4 || DCHECK(interfaces[i].second == ADDRESS_FAMILY_IPV4 ||
interfaces[i].second == ADDRESS_FAMILY_IPV6); interfaces[i].second == ADDRESS_FAMILY_IPV6);
std::unique_ptr<DatagramServerSocket> socket(CreateAndBindMDnsSocket( std::unique_ptr<DatagramServerSocket> socket(CreateAndBindMDnsSocket(
interfaces[i].second, interfaces[i].first, nullptr)); interfaces[i].second, interfaces[i].first, net_log_));
if (socket) if (socket)
sockets->push_back(std::move(socket)); sockets->push_back(std::move(socket));
} }
...@@ -723,8 +722,7 @@ bool MDnsTransactionImpl::QueryAndListen() { ...@@ -723,8 +722,7 @@ bool MDnsTransactionImpl::QueryAndListen() {
timeout_.Reset(base::Bind(&MDnsTransactionImpl::SignalTransactionOver, timeout_.Reset(base::Bind(&MDnsTransactionImpl::SignalTransactionOver,
AsWeakPtr())); AsWeakPtr()));
base::ThreadTaskRunnerHandle::Get()->PostDelayedTask( base::ThreadTaskRunnerHandle::Get()->PostDelayedTask(
FROM_HERE, timeout_.callback(), FROM_HERE, timeout_.callback(), kTransactionTimeout);
base::TimeDelta::FromSeconds(MDnsTransactionTimeoutSeconds));
return true; return true;
} }
......
...@@ -34,15 +34,20 @@ class OneShotTimer; ...@@ -34,15 +34,20 @@ class OneShotTimer;
namespace net { namespace net {
class NetLog;
class MDnsSocketFactoryImpl : public MDnsSocketFactory { class MDnsSocketFactoryImpl : public MDnsSocketFactory {
public: public:
MDnsSocketFactoryImpl() {} MDnsSocketFactoryImpl() : net_log_(nullptr) {}
explicit MDnsSocketFactoryImpl(NetLog* net_log) : net_log_(net_log) {}
~MDnsSocketFactoryImpl() override {} ~MDnsSocketFactoryImpl() override {}
void CreateSockets( void CreateSockets(
std::vector<std::unique_ptr<DatagramServerSocket>>* sockets) override; std::vector<std::unique_ptr<DatagramServerSocket>>* sockets) override;
private: private:
NetLog* const net_log_;
DISALLOW_COPY_AND_ASSIGN(MDnsSocketFactoryImpl); DISALLOW_COPY_AND_ASSIGN(MDnsSocketFactoryImpl);
}; };
......
...@@ -331,6 +331,8 @@ MockHostResolverBase::MockHostResolverBase(bool use_caching) ...@@ -331,6 +331,8 @@ MockHostResolverBase::MockHostResolverBase(bool use_caching)
rules_map_[HostResolverSource::ANY] = CreateCatchAllHostResolverProc(); rules_map_[HostResolverSource::ANY] = CreateCatchAllHostResolverProc();
rules_map_[HostResolverSource::SYSTEM] = CreateCatchAllHostResolverProc(); rules_map_[HostResolverSource::SYSTEM] = CreateCatchAllHostResolverProc();
rules_map_[HostResolverSource::DNS] = CreateCatchAllHostResolverProc(); rules_map_[HostResolverSource::DNS] = CreateCatchAllHostResolverProc();
rules_map_[HostResolverSource::MULTICAST_DNS] =
CreateCatchAllHostResolverProc();
if (use_caching) { if (use_caching) {
cache_.reset(new HostCache(kMaxCacheEntries)); cache_.reset(new HostCache(kMaxCacheEntries));
......
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net/dns/mock_mdns_client.h"
namespace net {
MockMDnsTransaction::MockMDnsTransaction() = default;
MockMDnsTransaction::~MockMDnsTransaction() = default;
MockMDnsClient::MockMDnsClient() = default;
MockMDnsClient::~MockMDnsClient() = default;
} // namespace net
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef NET_DNS_MOCK_MDNS_CLIENT_H_
#define NET_DNS_MOCK_MDNS_CLIENT_H_
#include <memory>
#include <string>
#include "net/dns/mdns_client.h"
#include "testing/gmock/include/gmock/gmock.h"
namespace net {
class MockMDnsTransaction : public MDnsTransaction {
public:
MockMDnsTransaction();
~MockMDnsTransaction();
MOCK_METHOD0(Start, bool());
MOCK_CONST_METHOD0(GetName, const std::string&());
MOCK_CONST_METHOD0(GetType, uint16_t());
};
class MockMDnsClient : public MDnsClient {
public:
MockMDnsClient();
~MockMDnsClient();
MOCK_METHOD3(CreateListener,
std::unique_ptr<MDnsListener>(uint16_t,
const std::string&,
MDnsListener::Delegate*));
MOCK_METHOD4(
CreateTransaction,
std::unique_ptr<MDnsTransaction>(uint16_t,
const std::string&,
int,
const MDnsTransaction::ResultCallback&));
MOCK_METHOD1(StartListening, bool(MDnsSocketFactory*));
MOCK_METHOD0(StopListening, void());
MOCK_CONST_METHOD0(IsListening, bool());
};
} // namespace net
#endif // NET_DNS_MOCK_MDNS_CLIENT_H_
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include "net/base/host_port_pair.h" #include "net/base/host_port_pair.h"
#include "net/base/net_errors.h" #include "net/base/net_errors.h"
#include "net/dns/host_resolver.h" #include "net/dns/host_resolver.h"
#include "net/dns/host_resolver_source.h"
#include "net/log/net_log.h" #include "net/log/net_log.h"
#include "services/network/resolve_host_request.h" #include "services/network/resolve_host_request.h"
...@@ -68,6 +69,13 @@ void HostResolver::ResolveHost( ...@@ -68,6 +69,13 @@ void HostResolver::ResolveHost(
const net::HostPortPair& host, const net::HostPortPair& host,
mojom::ResolveHostParametersPtr optional_parameters, mojom::ResolveHostParametersPtr optional_parameters,
mojom::ResolveHostClientPtr response_client) { mojom::ResolveHostClientPtr response_client) {
#if !BUILDFLAG(ENABLE_MDNS)
// TODO(crbug.com/821021): Handle without crashing if we create restricted
// HostResolvers for passing to untrusted processes.
DCHECK(!optional_parameters ||
optional_parameters->source != net::HostResolverSource::MULTICAST_DNS);
#endif // !BUILDFLAG(ENABLE_MDNS)
if (resolve_host_callback.Get()) if (resolve_host_callback.Get())
resolve_host_callback.Get().Run(host.host()); resolve_host_callback.Get().Run(host.host());
......
...@@ -206,6 +206,7 @@ TEST_F(HostResolverTest, Source) { ...@@ -206,6 +206,7 @@ TEST_F(HostResolverTest, Source) {
constexpr char kAnyResult[] = "1.2.3.4"; constexpr char kAnyResult[] = "1.2.3.4";
constexpr char kSystemResult[] = "127.0.0.1"; constexpr char kSystemResult[] = "127.0.0.1";
constexpr char kDnsResult[] = "168.100.12.23"; constexpr char kDnsResult[] = "168.100.12.23";
constexpr char kMdnsResult[] = "200.1.2.3";
auto inner_resolver = std::make_unique<net::MockHostResolver>(); auto inner_resolver = std::make_unique<net::MockHostResolver>();
inner_resolver->rules_map()[net::HostResolverSource::ANY]->AddRule( inner_resolver->rules_map()[net::HostResolverSource::ANY]->AddRule(
kDomain, kAnyResult); kDomain, kAnyResult);
...@@ -213,6 +214,8 @@ TEST_F(HostResolverTest, Source) { ...@@ -213,6 +214,8 @@ TEST_F(HostResolverTest, Source) {
kDomain, kSystemResult); kDomain, kSystemResult);
inner_resolver->rules_map()[net::HostResolverSource::DNS]->AddRule( inner_resolver->rules_map()[net::HostResolverSource::DNS]->AddRule(
kDomain, kDnsResult); kDomain, kDnsResult);
inner_resolver->rules_map()[net::HostResolverSource::MULTICAST_DNS]->AddRule(
kDomain, kMdnsResult);
net::NetLog net_log; net::NetLog net_log;
HostResolver resolver(inner_resolver.get(), &net_log); HostResolver resolver(inner_resolver.get(), &net_log);
...@@ -258,6 +261,23 @@ TEST_F(HostResolverTest, Source) { ...@@ -258,6 +261,23 @@ TEST_F(HostResolverTest, Source) {
EXPECT_EQ(net::OK, dns_client.result_error()); EXPECT_EQ(net::OK, dns_client.result_error());
EXPECT_THAT(dns_client.result_addresses().value().endpoints(), EXPECT_THAT(dns_client.result_addresses().value().endpoints(),
testing::ElementsAre(CreateExpectedEndPoint(kDnsResult, 80))); testing::ElementsAre(CreateExpectedEndPoint(kDnsResult, 80)));
#if BUILDFLAG(ENABLE_MDNS)
base::RunLoop mdns_run_loop;
mojom::ResolveHostClientPtr mdns_client_ptr;
TestResolveHostClient mdns_client(&mdns_client_ptr, &mdns_run_loop);
mojom::ResolveHostParametersPtr mdns_parameters =
mojom::ResolveHostParameters::New();
mdns_parameters->source = net::HostResolverSource::MULTICAST_DNS;
resolver.ResolveHost(net::HostPortPair(kDomain, 80),
std::move(mdns_parameters), std::move(mdns_client_ptr));
mdns_run_loop.Run();
EXPECT_EQ(net::OK, mdns_client.result_error());
EXPECT_THAT(mdns_client.result_addresses().value().endpoints(),
testing::ElementsAre(CreateExpectedEndPoint(kMdnsResult, 80)));
#endif // BUILDFLAG(ENABLE_MDNS)
} }
// Test that cached results are properly keyed by requested source. // Test that cached results are properly keyed by requested source.
......
...@@ -52,6 +52,8 @@ EnumTraits<ResolveHostParameters::Source, net::HostResolverSource>::ToMojom( ...@@ -52,6 +52,8 @@ EnumTraits<ResolveHostParameters::Source, net::HostResolverSource>::ToMojom(
return ResolveHostParameters::Source::SYSTEM; return ResolveHostParameters::Source::SYSTEM;
case net::HostResolverSource::DNS: case net::HostResolverSource::DNS:
return ResolveHostParameters::Source::DNS; return ResolveHostParameters::Source::DNS;
case net::HostResolverSource::MULTICAST_DNS:
return ResolveHostParameters::Source::MULTICAST_DNS;
} }
} }
...@@ -69,6 +71,9 @@ bool EnumTraits<ResolveHostParameters::Source, net::HostResolverSource>:: ...@@ -69,6 +71,9 @@ bool EnumTraits<ResolveHostParameters::Source, net::HostResolverSource>::
case ResolveHostParameters::Source::DNS: case ResolveHostParameters::Source::DNS:
*output = net::HostResolverSource::DNS; *output = net::HostResolverSource::DNS;
return true; return true;
case ResolveHostParameters::Source::MULTICAST_DNS:
*output = net::HostResolverSource::MULTICAST_DNS;
return true;
} }
} }
......
...@@ -65,7 +65,8 @@ struct ResolveHostParameters { ...@@ -65,7 +65,8 @@ struct ResolveHostParameters {
// Results will only come from DNS queries. // Results will only come from DNS queries.
DNS, DNS,
// TODO(crbug.com/846423): Add MDNS support. // Results will only come from Multicast DNS queries.
MULTICAST_DNS,
}; };
// The source to use for resolved addresses. Default allows the resolver to // The source to use for resolved addresses. Default allows the resolver to
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment