Commit ad24b182 authored by szym@chromium.org's avatar szym@chromium.org

Isolates generic DnsClient from AsyncHostResolver.

DnsClient provides a generic DNS client that allows fetching resource records. 
DnsClient is very lightweight and does not support aggregation, queuing or 
prioritization of requests. 

This is the first CL in a series to merge AsyncHostResolver into 
HostResolverImpl. 

Also introduces general-purpose BigEndianReader/Writer. 

BUG=90881
TEST=./net_unittests


Review URL: http://codereview.chromium.org/8762001

git-svn-id: svn://svn.chromium.org/chrome/trunk/src@113282 0039d316-1c4b-4281-b951-d872f2087c98
parent d7de5787
// Copyright (c) 2011 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net/base/big_endian.h"
#include "base/string_piece.h"
namespace net {
BigEndianReader::BigEndianReader(const void* buf, size_t len)
: ptr_(reinterpret_cast<const char*>(buf)), end_(ptr_ + len) {}
bool BigEndianReader::Skip(size_t len) {
if (ptr_ + len > end_)
return false;
ptr_ += len;
return true;
}
bool BigEndianReader::ReadBytes(void* out, size_t len) {
if (ptr_ + len > end_)
return false;
memcpy(out, ptr_, len);
ptr_ += len;
return true;
}
bool BigEndianReader::ReadPiece(base::StringPiece* out, size_t len) {
if (ptr_ + len > end_)
return false;
*out = base::StringPiece(ptr_, len);
ptr_ += len;
return true;
}
template<typename T>
bool BigEndianReader::Read(T* value) {
if (ptr_ + sizeof(T) > end_)
return false;
ReadBigEndian<T>(ptr_, value);
ptr_ += sizeof(T);
return true;
}
bool BigEndianReader::ReadU8(uint8* value) {
return Read(value);
}
bool BigEndianReader::ReadU16(uint16* value) {
return Read(value);
}
bool BigEndianReader::ReadU32(uint32* value) {
return Read(value);
}
BigEndianWriter::BigEndianWriter(void* buf, size_t len)
: ptr_(reinterpret_cast<char*>(buf)), end_(ptr_ + len) {}
bool BigEndianWriter::Skip(size_t len) {
if (ptr_ + len > end_)
return false;
ptr_ += len;
return true;
}
bool BigEndianWriter::WriteBytes(const void* buf, size_t len) {
if (ptr_ + len > end_)
return false;
memcpy(ptr_, buf, len);
ptr_ += len;
return true;
}
template<typename T>
bool BigEndianWriter::Write(T value) {
if (ptr_ + sizeof(T) > end_)
return false;
WriteBigEndian<T>(ptr_, value);
ptr_ += sizeof(T);
return true;
}
bool BigEndianWriter::WriteU8(uint8 value) {
return Write(value);
}
bool BigEndianWriter::WriteU16(uint16 value) {
return Write(value);
}
bool BigEndianWriter::WriteU32(uint32 value) {
return Write(value);
}
} // namespace net
// Copyright (c) 2011 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef NET_BASE_BIG_ENDIAN_H_
#define NET_BASE_BIG_ENDIAN_H_
#pragma once
#include "base/basictypes.h"
#include "net/base/net_export.h"
namespace base {
class StringPiece;
}
namespace net {
// Read an integer (signed or unsigned) from |buf| in Big Endian order.
// Note: this loop is unrolled with -O1 and above.
// NOTE(szym): glibc dns-canon.c and SpdyFrameBuilder use
// ntohs(*(uint16_t*)ptr) which is potentially unaligned.
// This would cause SIGBUS on ARMv5 or earlier and ARMv6-M.
template<typename T>
inline void ReadBigEndian(const char buf[], T* out) {
*out = buf[0];
for (size_t i = 1; i < sizeof(T); ++i) {
*out <<= 8;
// Must cast to uint8 to avoid clobbering by sign extension.
*out |= static_cast<uint8>(buf[i]);
}
}
// Write an integer (signed or unsigned) |val| to |buf| in Big Endian order.
// Note: this loop is unrolled with -O1 and above.
template<typename T>
inline void WriteBigEndian(char buf[], T val) {
for (size_t i = 0; i < sizeof(T); ++i) {
buf[sizeof(T)-i-1] = static_cast<char>(val & 0xFF);
val >>= 8;
}
}
// Specializations to make clang happy about the (dead code) shifts above.
template<>
inline void ReadBigEndian<uint8>(const char buf[], uint8* out) {
*out = buf[0];
}
template<>
inline void WriteBigEndian<uint8>(char buf[], uint8 val) {
buf[0] = static_cast<char>(val);
}
// Allows reading integers in network order (big endian) while iterating over
// an underlying buffer. All the reading functions advance the internal pointer.
class NET_EXPORT BigEndianReader {
public:
BigEndianReader(const void* buf, size_t len);
const char* ptr() const { return ptr_; }
int remaining() const { return end_ - ptr_; }
bool Skip(size_t len);
bool ReadBytes(void* out, size_t len);
// Creates a StringPiece in |out| that points to the underlying buffer.
bool ReadPiece(base::StringPiece* out, size_t len);
bool ReadU8(uint8* value);
bool ReadU16(uint16* value);
bool ReadU32(uint32* value);
private:
// Hidden to promote type safety.
template<typename T>
bool Read(T* v);
const char* ptr_;
const char* end_;
};
// Allows writing integers in network order (big endian) while iterating over
// an underlying buffer. All the writing functions advance the internal pointer.
class NET_EXPORT BigEndianWriter {
public:
BigEndianWriter(void* buf, size_t len);
char* ptr() const { return ptr_; }
int remaining() const { return end_ - ptr_; }
bool Skip(size_t len);
bool WriteBytes(const void* buf, size_t len);
bool WriteU8(uint8 value);
bool WriteU16(uint16 value);
bool WriteU32(uint32 value);
private:
// Hidden to promote type safety.
template<typename T>
bool Write(T v);
char* ptr_;
char* end_;
};
} // namespace net
#endif // NET_BASE_BIG_ENDIAN_H_
// Copyright (c) 2011 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "base/string_piece.h"
#include "net/base/big_endian.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace net {
TEST(BigEndianReaderTest, ReadsValues) {
char data[] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xA, 0xB, 0xC };
char buf[2];
uint8 u8;
uint16 u16;
uint32 u32;
base::StringPiece piece;
BigEndianReader reader(data, sizeof(data));
EXPECT_TRUE(reader.Skip(2));
EXPECT_EQ(data + 2, reader.ptr());
EXPECT_EQ(reader.remaining(), static_cast<int>(sizeof(data)) - 2);
EXPECT_TRUE(reader.ReadBytes(buf, sizeof(buf)));
EXPECT_EQ(0x2, buf[0]);
EXPECT_EQ(0x3, buf[1]);
EXPECT_TRUE(reader.ReadU8(&u8));
EXPECT_EQ(0x4, u8);
EXPECT_TRUE(reader.ReadU16(&u16));
EXPECT_EQ(0x0506, u16);
EXPECT_TRUE(reader.ReadU32(&u32));
EXPECT_EQ(0x0708090Au, u32);
base::StringPiece expected(reader.ptr(), 2);
EXPECT_TRUE(reader.ReadPiece(&piece, 2));
EXPECT_EQ(2u, piece.size());
EXPECT_EQ(expected.data(), piece.data());
}
TEST(BigEndianReaderTest, RespectsLength) {
char data[4];
char buf[2];
uint8 u8;
uint16 u16;
uint32 u32;
base::StringPiece piece;
BigEndianReader reader(data, sizeof(data));
// 4 left
EXPECT_FALSE(reader.Skip(6));
EXPECT_TRUE(reader.Skip(1));
// 3 left
EXPECT_FALSE(reader.ReadU32(&u32));
EXPECT_FALSE(reader.ReadPiece(&piece, 4));
EXPECT_TRUE(reader.Skip(2));
// 1 left
EXPECT_FALSE(reader.ReadU16(&u16));
EXPECT_FALSE(reader.ReadBytes(buf, 2));
EXPECT_TRUE(reader.Skip(1));
// 0 left
EXPECT_FALSE(reader.ReadU8(&u8));
EXPECT_EQ(0, reader.remaining());
}
TEST(BigEndianWriterTest, WritesValues) {
char expected[] = { 0, 0, 2, 3, 4, 5, 6, 7, 8, 9, 0xA };
char data[sizeof(expected)];
char buf[] = { 0x2, 0x3 };
memset(data, 0, sizeof(data));
BigEndianWriter writer(data, sizeof(data));
EXPECT_TRUE(writer.Skip(2));
EXPECT_TRUE(writer.WriteBytes(buf, sizeof(buf)));
EXPECT_TRUE(writer.WriteU8(0x4));
EXPECT_TRUE(writer.WriteU16(0x0506));
EXPECT_TRUE(writer.WriteU32(0x0708090A));
EXPECT_EQ(0, memcmp(expected, data, sizeof(expected)));
}
TEST(BigEndianWriterTest, RespectsLength) {
char data[4];
char buf[2];
uint8 u8 = 0;
uint16 u16 = 0;
uint32 u32 = 0;
BigEndianWriter writer(data, sizeof(data));
// 4 left
EXPECT_FALSE(writer.Skip(6));
EXPECT_TRUE(writer.Skip(1));
// 3 left
EXPECT_FALSE(writer.WriteU32(u32));
EXPECT_TRUE(writer.Skip(2));
// 1 left
EXPECT_FALSE(writer.WriteU16(u16));
EXPECT_FALSE(writer.WriteBytes(buf, 2));
EXPECT_TRUE(writer.Skip(1));
// 0 left
EXPECT_FALSE(writer.WriteU8(u8));
EXPECT_EQ(0, writer.remaining());
}
} // namespace net
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
namespace net { namespace net {
// Based on DJB's public domain code. // Based on DJB's public domain code.
bool DNSDomainFromDot(const std::string& dotted, std::string* out) { bool DNSDomainFromDot(const base::StringPiece& dotted, std::string* out) {
const char* buf = dotted.data(); const char* buf = dotted.data();
unsigned n = dotted.size(); unsigned n = dotted.size();
char label[63]; char label[63];
...@@ -56,7 +56,7 @@ bool DNSDomainFromDot(const std::string& dotted, std::string* out) { ...@@ -56,7 +56,7 @@ bool DNSDomainFromDot(const std::string& dotted, std::string* out) {
return true; return true;
} }
std::string DNSDomainToString(const std::string& domain) { std::string DNSDomainToString(const base::StringPiece& domain) {
std::string ret; std::string ret;
for (unsigned i = 0; i < domain.size() && domain[i]; i += domain[i] + 1) { for (unsigned i = 0; i < domain.size() && domain[i]; i += domain[i] + 1) {
...@@ -73,7 +73,7 @@ std::string DNSDomainToString(const std::string& domain) { ...@@ -73,7 +73,7 @@ std::string DNSDomainToString(const std::string& domain) {
if (static_cast<unsigned>(domain[i]) + i + 1 > domain.size()) if (static_cast<unsigned>(domain[i]) + i + 1 > domain.size())
return ""; return "";
ret += domain.substr(i + 1, domain[i]); domain.substr(i + 1, domain[i]).AppendToString(&ret);
} }
return ret; return ret;
} }
...@@ -92,12 +92,13 @@ bool IsSTD3ASCIIValidCharacter(char c) { ...@@ -92,12 +92,13 @@ bool IsSTD3ASCIIValidCharacter(char c) {
return true; return true;
} }
std::string TrimEndingDot(const std::string& host) { std::string TrimEndingDot(const base::StringPiece& host) {
std::string host_trimmed = host; base::StringPiece host_trimmed = host;
size_t len = host_trimmed.length(); size_t len = host_trimmed.length();
if (len > 1 && host_trimmed[len - 1] == '.') if (len > 1 && host_trimmed[len - 1] == '.') {
host_trimmed.erase(len - 1); host_trimmed.remove_suffix(1);
return host_trimmed; }
return host_trimmed.as_string();
} }
bool DnsResponseBuffer::U8(uint8* v) { bool DnsResponseBuffer::U8(uint8* v) {
......
...@@ -19,18 +19,22 @@ namespace net { ...@@ -19,18 +19,22 @@ namespace net {
// //
// dotted: a string in dotted form: "www.google.com" // dotted: a string in dotted form: "www.google.com"
// out: a result in DNS form: "\x03www\x06google\x03com\x00" // out: a result in DNS form: "\x03www\x06google\x03com\x00"
NET_EXPORT_PRIVATE bool DNSDomainFromDot(const std::string& dotted, NET_EXPORT_PRIVATE bool DNSDomainFromDot(const base::StringPiece& dotted,
std::string* out); std::string* out);
// DNSDomainToString coverts a domain in DNS format to a dotted string. // DNSDomainToString coverts a domain in DNS format to a dotted string.
NET_EXPORT_PRIVATE std::string DNSDomainToString(const std::string& domain); NET_EXPORT_PRIVATE std::string DNSDomainToString(
const base::StringPiece& domain);
// Returns true iff the given character is in the set of valid DNS label // Returns true iff the given character is in the set of valid DNS label
// characters as given in RFC 3490, 4.1, 3(a) // characters as given in RFC 3490, 4.1, 3(a)
NET_EXPORT_PRIVATE bool IsSTD3ASCIIValidCharacter(char c); NET_EXPORT_PRIVATE bool IsSTD3ASCIIValidCharacter(char c);
// Returns the hostname by trimming the ending dot, if one exists. // Returns the hostname by trimming the ending dot, if one exists.
NET_EXPORT std::string TrimEndingDot(const std::string& host); NET_EXPORT std::string TrimEndingDot(const base::StringPiece& host);
// TODO(szym): remove all definitions below once DnsRRResolver migrates to
// DnsClient
// DNS class types. // DNS class types.
static const uint16 kClassIN = 1; static const uint16 kClassIN = 1;
......
...@@ -8,12 +8,16 @@ ...@@ -8,12 +8,16 @@
#include "base/bind.h" #include "base/bind.h"
#include "base/logging.h" #include "base/logging.h"
#include "base/message_loop.h"
#include "base/rand_util.h" #include "base/rand_util.h"
#include "base/stl_util.h" #include "base/stl_util.h"
#include "base/values.h" #include "base/values.h"
#include "net/base/address_list.h" #include "net/base/address_list.h"
#include "net/base/dns_util.h" #include "net/base/dns_util.h"
#include "net/base/net_errors.h" #include "net/base/net_errors.h"
#include "net/dns/dns_protocol.h"
#include "net/dns/dns_response.h"
#include "net/dns/dns_session.h"
#include "net/socket/client_socket_factory.h" #include "net/socket/client_socket_factory.h"
namespace net { namespace net {
...@@ -22,7 +26,7 @@ namespace { ...@@ -22,7 +26,7 @@ namespace {
// TODO(agayev): fix this when IPv6 support is added. // TODO(agayev): fix this when IPv6 support is added.
uint16 QueryTypeFromAddressFamily(AddressFamily address_family) { uint16 QueryTypeFromAddressFamily(AddressFamily address_family) {
return kDNS_A; return dns_protocol::kTypeA;
} }
class RequestParameters : public NetLog::EventParameters { class RequestParameters : public NetLog::EventParameters {
...@@ -56,17 +60,22 @@ class RequestParameters : public NetLog::EventParameters { ...@@ -56,17 +60,22 @@ class RequestParameters : public NetLog::EventParameters {
HostResolver* CreateAsyncHostResolver(size_t max_concurrent_resolves, HostResolver* CreateAsyncHostResolver(size_t max_concurrent_resolves,
const IPAddressNumber& dns_ip, const IPAddressNumber& dns_ip,
NetLog* net_log) { NetLog* net_log) {
size_t max_transactions = max_concurrent_resolves; size_t max_dns_requests = max_concurrent_resolves;
if (max_transactions == 0) if (max_dns_requests == 0)
max_transactions = 20; max_dns_requests = 20;
size_t max_pending_requests = max_transactions * 100; size_t max_pending_requests = max_dns_requests * 100;
DnsConfig config;
config.nameservers.push_back(IPEndPoint(dns_ip, 53));
DnsSession* session = new DnsSession(
config,
ClientSocketFactory::GetDefaultFactory(),
base::Bind(&base::RandInt),
net_log);
HostResolver* resolver = new AsyncHostResolver( HostResolver* resolver = new AsyncHostResolver(
IPEndPoint(dns_ip, 53), max_dns_requests,
max_transactions,
max_pending_requests, max_pending_requests,
base::Bind(&base::RandInt),
HostCache::CreateDefaultCache(), HostCache::CreateDefaultCache(),
NULL, DnsClient::CreateClient(session),
net_log); net_log);
return resolver; return resolver;
} }
...@@ -193,19 +202,15 @@ class AsyncHostResolver::Request { ...@@ -193,19 +202,15 @@ class AsyncHostResolver::Request {
}; };
//----------------------------------------------------------------------------- //-----------------------------------------------------------------------------
AsyncHostResolver::AsyncHostResolver(const IPEndPoint& dns_server, AsyncHostResolver::AsyncHostResolver(size_t max_dns_requests,
size_t max_transactions,
size_t max_pending_requests, size_t max_pending_requests,
const RandIntCallback& rand_int_cb,
HostCache* cache, HostCache* cache,
ClientSocketFactory* factory, DnsClient* client,
NetLog* net_log) NetLog* net_log)
: max_transactions_(max_transactions), : max_dns_requests_(max_dns_requests),
max_pending_requests_(max_pending_requests), max_pending_requests_(max_pending_requests),
dns_server_(dns_server),
rand_int_cb_(rand_int_cb),
cache_(cache), cache_(cache),
factory_(factory), client_(client),
net_log_(net_log) { net_log_(net_log) {
} }
...@@ -215,8 +220,8 @@ AsyncHostResolver::~AsyncHostResolver() { ...@@ -215,8 +220,8 @@ AsyncHostResolver::~AsyncHostResolver() {
it != requestlist_map_.end(); ++it) it != requestlist_map_.end(); ++it)
STLDeleteElements(&it->second); STLDeleteElements(&it->second);
// Destroy transactions. // Destroy DNS requests.
STLDeleteElements(&transactions_); STLDeleteElements(&dns_requests_);
// Destroy pending requests. // Destroy pending requests.
for (size_t i = 0; i < arraysize(pending_requests_); ++i) for (size_t i = 0; i < arraysize(pending_requests_); ++i)
...@@ -240,8 +245,8 @@ int AsyncHostResolver::Resolve(const RequestInfo& info, ...@@ -240,8 +245,8 @@ int AsyncHostResolver::Resolve(const RequestInfo& info,
rv = request->result(); rv = request->result();
else if (AttachToRequestList(request.get())) else if (AttachToRequestList(request.get()))
rv = ERR_IO_PENDING; rv = ERR_IO_PENDING;
else if (transactions_.size() < max_transactions_) else if (dns_requests_.size() < max_dns_requests_)
rv = StartNewTransactionFor(request.get()); rv = StartNewDnsRequestFor(request.get());
else else
rv = Enqueue(request.get()); rv = Enqueue(request.get());
...@@ -327,39 +332,56 @@ HostCache* AsyncHostResolver::GetHostCache() { ...@@ -327,39 +332,56 @@ HostCache* AsyncHostResolver::GetHostCache() {
return cache_.get(); return cache_.get();
} }
void AsyncHostResolver::OnTransactionComplete( void AsyncHostResolver::OnDnsRequestComplete(
DnsClient::Request* dns_req,
int result, int result,
const DnsTransaction* transaction, const DnsResponse* response) {
const IPAddressList& ip_addresses) { DCHECK(std::find(dns_requests_.begin(), dns_requests_.end(), dns_req)
DCHECK(std::find(transactions_.begin(), transactions_.end(), transaction) != dns_requests_.end());
!= transactions_.end());
DCHECK(requestlist_map_.find(transaction->key()) != requestlist_map_.end());
// If by the time requests that caused |transaction| are cancelled, we do // If by the time requests that caused |dns_req| are cancelled, we do
// not have a port number to associate with the result, therefore, we // not have a port number to associate with the result, therefore, we
// assume the most common port, otherwise we use the port number of the // assume the most common port, otherwise we use the port number of the
// first request. // first request.
RequestList& requests = requestlist_map_[transaction->key()]; KeyRequestListMap::iterator rit = requestlist_map_.find(
std::make_pair(dns_req->qname(), dns_req->qtype()));
DCHECK(rit != requestlist_map_.end());
RequestList& requests = rit->second;
int port = requests.empty() ? 80 : requests.front()->info().port(); int port = requests.empty() ? 80 : requests.front()->info().port();
// Run callback of every request that was depending on this transaction, // Extract AddressList out of DnsResponse.
AddressList addr_list;
if (result == OK) {
IPAddressList ip_addresses;
DnsRecordParser parser = response->Parser();
DnsResourceRecord record;
// TODO(szym): Add stricter checking of names, aliases and address lengths.
while (parser.ParseRecord(&record)) {
if (record.type == dns_req->qtype() &&
(record.rdata.size() == kIPv4AddressSize ||
record.rdata.size() == kIPv6AddressSize)) {
ip_addresses.push_back(IPAddressNumber(record.rdata.begin(),
record.rdata.end()));
}
}
if (!ip_addresses.empty())
addr_list = AddressList::CreateFromIPAddressList(ip_addresses, port);
else
result = ERR_NAME_NOT_RESOLVED;
}
// Run callback of every request that was depending on this DNS request,
// also notify observers. // also notify observers.
AddressList addrlist; for (RequestList::iterator it = requests.begin(); it != requests.end(); ++it)
if (result == OK) (*it)->OnAsyncComplete(result, addr_list);
addrlist = AddressList::CreateFromIPAddressList(ip_addresses, port);
for (RequestList::iterator it = requests.begin(); it != requests.end(); // It is possible that the requests that caused |dns_req| to be
++it) // created are cancelled by the time |dns_req| completes. In that
(*it)->OnAsyncComplete(result, addrlist);
// It is possible that the requests that caused |transaction| to be
// created are cancelled by the time |transaction| completes. In that
// case |requests| would be empty. We are knowingly throwing away the // case |requests| would be empty. We are knowingly throwing away the
// result of a DNS resolution in that case, because (a) if there are no // result of a DNS resolution in that case, because (a) if there are no
// requests, we do not have info to obtain a key from, (b) DnsTransaction // requests, we do not have info to obtain a key from, (b) DnsTransaction
// does not have info(), adding one into it just temporarily doesn't make // does not have info(), adding one into it just temporarily doesn't make
// sense, since HostCache will be replaced with RR cache soon, (c) // sense, since HostCache will be replaced with RR cache soon.
// recreating info from DnsTransaction::Key adds a lot of temporary
// code/functions (like converting back from qtype to AddressFamily.)
// Also, we only cache positive results. All of this will change when RR // Also, we only cache positive results. All of this will change when RR
// cache is added. // cache is added.
if (result == OK && cache_.get() && !requests.empty()) { if (result == OK && cache_.get() && !requests.empty()) {
...@@ -367,16 +389,16 @@ void AsyncHostResolver::OnTransactionComplete( ...@@ -367,16 +389,16 @@ void AsyncHostResolver::OnTransactionComplete(
HostResolver::RequestInfo info = request->info(); HostResolver::RequestInfo info = request->info();
HostCache::Key key( HostCache::Key key(
info.hostname(), info.address_family(), info.host_resolver_flags()); info.hostname(), info.address_family(), info.host_resolver_flags());
cache_->Set(key, result, addrlist, base::TimeTicks::Now()); cache_->Set(key, result, addr_list, base::TimeTicks::Now());
} }
// Cleanup requests. // Cleanup requests.
STLDeleteElements(&requests); STLDeleteElements(&requests);
requestlist_map_.erase(transaction->key()); requestlist_map_.erase(rit);
// Cleanup transaction and start a new one if there are pending requests. // Cleanup |dns_req| and start a new one if there are pending requests.
delete transaction; delete dns_req;
transactions_.remove(transaction); dns_requests_.remove(dns_req);
ProcessPending(); ProcessPending();
} }
...@@ -399,25 +421,22 @@ bool AsyncHostResolver::AttachToRequestList(Request* request) { ...@@ -399,25 +421,22 @@ bool AsyncHostResolver::AttachToRequestList(Request* request) {
return true; return true;
} }
int AsyncHostResolver::StartNewTransactionFor(Request* request) { int AsyncHostResolver::StartNewDnsRequestFor(Request* request) {
DCHECK(requestlist_map_.find(request->key()) == requestlist_map_.end()); DCHECK(requestlist_map_.find(request->key()) == requestlist_map_.end());
DCHECK(transactions_.size() < max_transactions_); DCHECK(dns_requests_.size() < max_dns_requests_);
request->request_net_log().AddEvent( request->request_net_log().AddEvent(
NetLog::TYPE_ASYNC_HOST_RESOLVER_CREATE_DNS_TRANSACTION, NULL); NetLog::TYPE_ASYNC_HOST_RESOLVER_CREATE_DNS_TRANSACTION, NULL);
requestlist_map_[request->key()].push_back(request); requestlist_map_[request->key()].push_back(request);
DnsTransaction* transaction = new DnsTransaction( DnsClient::Request* dns_req = client_->CreateRequest(
dns_server_,
request->key().first, request->key().first,
request->key().second, request->key().second,
rand_int_cb_, base::Bind(&AsyncHostResolver::OnDnsRequestComplete,
factory_, base::Unretained(this)),
request->request_net_log(), request->request_net_log());
net_log_); dns_requests_.push_back(dns_req);
transaction->SetDelegate(this); return dns_req->Start();
transactions_.push_back(transaction);
return transaction->Start();
} }
int AsyncHostResolver::Enqueue(Request* request) { int AsyncHostResolver::Enqueue(Request* request) {
...@@ -490,7 +509,7 @@ void AsyncHostResolver::ProcessPending() { ...@@ -490,7 +509,7 @@ void AsyncHostResolver::ProcessPending() {
} }
} }
} }
StartNewTransactionFor(request); StartNewDnsRequestFor(request);
} }
} // namespace net } // namespace net
...@@ -8,6 +8,8 @@ ...@@ -8,6 +8,8 @@
#include <list> #include <list>
#include <map> #include <map>
#include <string>
#include <utility>
#include "base/threading/non_thread_safe.h" #include "base/threading/non_thread_safe.h"
#include "net/base/address_family.h" #include "net/base/address_family.h"
...@@ -15,24 +17,18 @@ ...@@ -15,24 +17,18 @@
#include "net/base/host_resolver.h" #include "net/base/host_resolver.h"
#include "net/base/ip_endpoint.h" #include "net/base/ip_endpoint.h"
#include "net/base/net_log.h" #include "net/base/net_log.h"
#include "net/base/rand_callback.h" #include "net/dns/dns_client.h"
#include "net/dns/dns_transaction.h"
namespace net { namespace net {
class ClientSocketFactory;
class NET_EXPORT AsyncHostResolver class NET_EXPORT AsyncHostResolver
: public HostResolver, : public HostResolver,
public DnsTransaction::Delegate,
NON_EXPORTED_BASE(public base::NonThreadSafe) { NON_EXPORTED_BASE(public base::NonThreadSafe) {
public: public:
AsyncHostResolver(const IPEndPoint& dns_server, AsyncHostResolver(size_t max_dns_requests,
size_t max_transactions, size_t max_pending_requests,
size_t max_pending_requests_,
const RandIntCallback& rand_int,
HostCache* cache, HostCache* cache,
ClientSocketFactory* factory, DnsClient* client,
NetLog* net_log); NetLog* net_log);
virtual ~AsyncHostResolver(); virtual ~AsyncHostResolver();
...@@ -50,11 +46,9 @@ class NET_EXPORT AsyncHostResolver ...@@ -50,11 +46,9 @@ class NET_EXPORT AsyncHostResolver
virtual AddressFamily GetDefaultAddressFamily() const OVERRIDE; virtual AddressFamily GetDefaultAddressFamily() const OVERRIDE;
virtual HostCache* GetHostCache() OVERRIDE; virtual HostCache* GetHostCache() OVERRIDE;
// DnsTransaction::Delegate interface void OnDnsRequestComplete(DnsClient::Request* request,
virtual void OnTransactionComplete( int result,
int result, const DnsResponse* transaction);
const DnsTransaction* transaction,
const IPAddressList& ip_addresses) OVERRIDE;
private: private:
FRIEND_TEST_ALL_PREFIXES(AsyncHostResolverTest, QueuedLookup); FRIEND_TEST_ALL_PREFIXES(AsyncHostResolverTest, QueuedLookup);
...@@ -68,9 +62,9 @@ class NET_EXPORT AsyncHostResolver ...@@ -68,9 +62,9 @@ class NET_EXPORT AsyncHostResolver
class Request; class Request;
typedef DnsTransaction::Key Key; typedef std::pair<std::string, uint16> Key;
typedef std::list<Request*> RequestList; typedef std::list<Request*> RequestList;
typedef std::list<const DnsTransaction*> TransactionList; typedef std::list<const DnsClient::Request*> DnsRequestList;
typedef std::map<Key, RequestList> KeyRequestListMap; typedef std::map<Key, RequestList> KeyRequestListMap;
// Create a new request for the incoming Resolve() call. // Create a new request for the incoming Resolve() call.
...@@ -92,9 +86,9 @@ class NET_EXPORT AsyncHostResolver ...@@ -92,9 +86,9 @@ class NET_EXPORT AsyncHostResolver
// attach |request| to the respective list. // attach |request| to the respective list.
bool AttachToRequestList(Request* request); bool AttachToRequestList(Request* request);
// Will start a new transaction for |request|, will insert a new key in // Will start a new DNS request for |request|, will insert a new key in
// |requestlist_map_| and append |request| to the respective list. // |requestlist_map_| and append |request| to the respective list.
int StartNewTransactionFor(Request* request); int StartNewDnsRequestFor(Request* request);
// Will enqueue |request| in |pending_requests_|. // Will enqueue |request| in |pending_requests_|.
int Enqueue(Request* request); int Enqueue(Request* request);
...@@ -114,11 +108,11 @@ class NET_EXPORT AsyncHostResolver ...@@ -114,11 +108,11 @@ class NET_EXPORT AsyncHostResolver
// there are pending requests. // there are pending requests.
void ProcessPending(); void ProcessPending();
// Maximum number of concurrent transactions. // Maximum number of concurrent DNS requests.
size_t max_transactions_; size_t max_dns_requests_;
// List of current transactions. // List of current DNS requests.
TransactionList transactions_; DnsRequestList dns_requests_;
// A map from Key to a list of requests waiting for the Key to resolve. // A map from Key to a list of requests waiting for the Key to resolve.
KeyRequestListMap requestlist_map_; KeyRequestListMap requestlist_map_;
...@@ -129,18 +123,10 @@ class NET_EXPORT AsyncHostResolver ...@@ -129,18 +123,10 @@ class NET_EXPORT AsyncHostResolver
// Queues based on priority for putting pending requests. // Queues based on priority for putting pending requests.
RequestList pending_requests_[NUM_PRIORITIES]; RequestList pending_requests_[NUM_PRIORITIES];
// DNS server to which queries will be setn.
IPEndPoint dns_server_;
// Callback to be passed to DnsTransaction for generating DNS query ids.
RandIntCallback rand_int_cb_;
// Cache of host resolution results. // Cache of host resolution results.
scoped_ptr<HostCache> cache_; scoped_ptr<HostCache> cache_;
// Also passed to DnsTransaction; it's a dependency injection to aid DnsClient* client_;
// testing, outside of unit tests, its value is always NULL.
ClientSocketFactory* factory_;
NetLog* net_log_; NetLog* net_log_;
......
...@@ -6,18 +6,27 @@ ...@@ -6,18 +6,27 @@
#include "base/bind.h" #include "base/bind.h"
#include "base/memory/scoped_ptr.h" #include "base/memory/scoped_ptr.h"
#include "base/message_loop.h"
#include "base/stl_util.h"
#include "net/base/host_cache.h" #include "net/base/host_cache.h"
#include "net/base/net_errors.h"
#include "net/base/net_log.h" #include "net/base/net_log.h"
#include "net/base/rand_callback.h"
#include "net/base/sys_addrinfo.h" #include "net/base/sys_addrinfo.h"
#include "net/base/test_completion_callback.h"
#include "net/dns/dns_client.h"
#include "net/dns/dns_query.h"
#include "net/dns/dns_response.h"
#include "net/dns/dns_test_util.h" #include "net/dns/dns_test_util.h"
#include "net/socket/socket_test_util.h"
#include "testing/gtest/include/gtest/gtest.h" #include "testing/gtest/include/gtest/gtest.h"
namespace net { namespace net {
namespace { namespace {
const int kPortNum = 80;
const size_t kMaxTransactions = 2;
const size_t kMaxPendingRequests = 1;
void VerifyAddressList(const std::vector<const char*>& ip_addresses, void VerifyAddressList(const std::vector<const char*>& ip_addresses,
int port, int port,
const AddressList& addrlist) { const AddressList& addrlist) {
...@@ -39,12 +48,64 @@ void VerifyAddressList(const std::vector<const char*>& ip_addresses, ...@@ -39,12 +48,64 @@ void VerifyAddressList(const std::vector<const char*>& ip_addresses,
ASSERT_EQ(static_cast<addrinfo*>(NULL), ainfo); ASSERT_EQ(static_cast<addrinfo*>(NULL), ainfo);
} }
class MockDnsClient : public DnsClient,
public base::SupportsWeakPtr<MockDnsClient> {
public:
// Using WeakPtr to support cancellation.
// All MockRequests succeed unless canceled or MockDnsClient is destroyed.
class MockRequest : public DnsClient::Request,
public base::SupportsWeakPtr<MockRequest> {
public:
MockRequest(const base::StringPiece& qname,
uint16 qtype,
const RequestCallback& callback,
const base::WeakPtr<MockDnsClient>& client)
: Request(qname, qtype, callback), started_(false), client_(client) {
}
virtual int Start() OVERRIDE {
EXPECT_FALSE(started_);
started_ = true;
MessageLoop::current()->PostTask(
FROM_HERE,
base::Bind(&MockRequest::Finish, AsWeakPtr()));
return ERR_IO_PENDING;
}
private:
void Finish() {
if (!client_) {
DoCallback(ERR_DNS_SERVER_FAILED, NULL);
return;
}
DoCallback(OK, client_->responses[Key(qname(), qtype())]);
}
bool started_;
base::WeakPtr<MockDnsClient> client_;
};
typedef std::pair<std::string, uint16> Key;
MockDnsClient() : num_requests(0) {}
~MockDnsClient() {
STLDeleteValues(&responses);
}
Request* CreateRequest(const base::StringPiece& qname,
uint16 qtype,
const RequestCallback& callback,
const BoundNetLog&) {
++num_requests;
return new MockRequest(qname, qtype, callback, AsWeakPtr());
}
int num_requests;
std::map<Key, DnsResponse*> responses;
};
} // namespace } // namespace
static const int kPortNum = 80;
static const size_t kMaxTransactions = 2;
static const size_t kMaxPendingRequests = 1;
static int transaction_ids[] = {0, 1, 2, 3};
// The following fixture sets up an environment for four different lookups // The following fixture sets up an environment for four different lookups
// with their data defined in dns_test_util.h. All tests make use of these // with their data defined in dns_test_util.h. All tests make use of these
...@@ -69,84 +130,52 @@ class AsyncHostResolverTest : public testing::Test { ...@@ -69,84 +130,52 @@ class AsyncHostResolverTest : public testing::Test {
ip_addresses2_(kT2IpAddresses, ip_addresses2_(kT2IpAddresses,
kT2IpAddresses + arraysize(kT2IpAddresses)), kT2IpAddresses + arraysize(kT2IpAddresses)),
ip_addresses3_(kT3IpAddresses, ip_addresses3_(kT3IpAddresses,
kT3IpAddresses + arraysize(kT3IpAddresses)), kT3IpAddresses + arraysize(kT3IpAddresses)) {
test_prng_(std::deque<int>(
transaction_ids, transaction_ids + arraysize(transaction_ids))) {
rand_int_cb_ = base::Bind(&TestPrng::GetNext,
base::Unretained(&test_prng_));
// AF_INET only for now. // AF_INET only for now.
info0_.set_address_family(ADDRESS_FAMILY_IPV4); info0_.set_address_family(ADDRESS_FAMILY_IPV4);
info1_.set_address_family(ADDRESS_FAMILY_IPV4); info1_.set_address_family(ADDRESS_FAMILY_IPV4);
info2_.set_address_family(ADDRESS_FAMILY_IPV4); info2_.set_address_family(ADDRESS_FAMILY_IPV4);
info3_.set_address_family(ADDRESS_FAMILY_IPV4); info3_.set_address_family(ADDRESS_FAMILY_IPV4);
// Setup socket read/writes for transaction 0. client_.reset(new MockDnsClient());
writes0_.push_back(
MockWrite(true, reinterpret_cast<const char*>(kT0QueryDatagram), AddResponse(std::string(kT0DnsName, arraysize(kT0DnsName)), kT0Qtype,
arraysize(kT0QueryDatagram))); new DnsResponse(reinterpret_cast<const char*>(kT0ResponseDatagram),
reads0_.push_back( arraysize(kT0ResponseDatagram),
MockRead(true, reinterpret_cast<const char*>(kT0ResponseDatagram), arraysize(kT0QueryDatagram)));
arraysize(kT0ResponseDatagram)));
data0_.reset(new StaticSocketDataProvider(&reads0_[0], reads0_.size(), AddResponse(std::string(kT1DnsName, arraysize(kT1DnsName)), kT1Qtype,
&writes0_[0], writes0_.size())); new DnsResponse(reinterpret_cast<const char*>(kT1ResponseDatagram),
arraysize(kT1ResponseDatagram),
// Setup socket read/writes for transaction 1. arraysize(kT1QueryDatagram)));
writes1_.push_back(
MockWrite(true, reinterpret_cast<const char*>(kT1QueryDatagram), AddResponse(std::string(kT2DnsName, arraysize(kT2DnsName)), kT2Qtype,
arraysize(kT1QueryDatagram))); new DnsResponse(reinterpret_cast<const char*>(kT2ResponseDatagram),
reads1_.push_back( arraysize(kT2ResponseDatagram),
MockRead(true, reinterpret_cast<const char*>(kT1ResponseDatagram), arraysize(kT2QueryDatagram)));
arraysize(kT1ResponseDatagram)));
data1_.reset(new StaticSocketDataProvider(&reads1_[0], reads1_.size(), AddResponse(std::string(kT3DnsName, arraysize(kT3DnsName)), kT3Qtype,
&writes1_[0], writes1_.size())); new DnsResponse(reinterpret_cast<const char*>(kT3ResponseDatagram),
arraysize(kT3ResponseDatagram),
// Setup socket read/writes for transaction 2. arraysize(kT3QueryDatagram)));
writes2_.push_back(
MockWrite(true, reinterpret_cast<const char*>(kT2QueryDatagram),
arraysize(kT2QueryDatagram)));
reads2_.push_back(
MockRead(true, reinterpret_cast<const char*>(kT2ResponseDatagram),
arraysize(kT2ResponseDatagram)));
data2_.reset(new StaticSocketDataProvider(&reads2_[0], reads2_.size(),
&writes2_[0], writes2_.size()));
// Setup socket read/writes for transaction 3.
writes3_.push_back(
MockWrite(true, reinterpret_cast<const char*>(kT3QueryDatagram),
arraysize(kT3QueryDatagram)));
reads3_.push_back(
MockRead(true, reinterpret_cast<const char*>(kT3ResponseDatagram),
arraysize(kT3ResponseDatagram)));
data3_.reset(new StaticSocketDataProvider(&reads3_[0], reads3_.size(),
&writes3_[0], writes3_.size()));
factory_.AddSocketDataProvider(data0_.get());
factory_.AddSocketDataProvider(data1_.get());
factory_.AddSocketDataProvider(data2_.get());
factory_.AddSocketDataProvider(data3_.get());
IPEndPoint dns_server;
bool rv0 = CreateDnsAddress(kDnsIp, kDnsPort, &dns_server);
DCHECK(rv0);
resolver_.reset( resolver_.reset(
new AsyncHostResolver( new AsyncHostResolver(kMaxTransactions, kMaxPendingRequests,
dns_server, kMaxTransactions, kMaxPendingRequests, rand_int_cb_, HostCache::CreateDefaultCache(),
HostCache::CreateDefaultCache(), &factory_, NULL)); client_.get(), NULL));
}
void AddResponse(const std::string& name, uint8 type, DnsResponse* response) {
client_->responses[MockDnsClient::Key(name, type)] = response;
} }
protected: protected:
AddressList addrlist0_, addrlist1_, addrlist2_, addrlist3_; AddressList addrlist0_, addrlist1_, addrlist2_, addrlist3_;
HostResolver::RequestInfo info0_, info1_, info2_, info3_; HostResolver::RequestInfo info0_, info1_, info2_, info3_;
std::vector<MockWrite> writes0_, writes1_, writes2_, writes3_;
std::vector<MockRead> reads0_, reads1_, reads2_, reads3_;
scoped_ptr<StaticSocketDataProvider> data0_, data1_, data2_, data3_;
std::vector<const char*> ip_addresses0_, ip_addresses1_, std::vector<const char*> ip_addresses0_, ip_addresses1_,
ip_addresses2_, ip_addresses3_; ip_addresses2_, ip_addresses3_;
MockClientSocketFactory factory_;
TestPrng test_prng_;
RandIntCallback rand_int_cb_;
scoped_ptr<HostResolver> resolver_; scoped_ptr<HostResolver> resolver_;
scoped_ptr<MockDnsClient> client_;
TestCompletionCallback callback0_, callback1_, callback2_, callback3_; TestCompletionCallback callback0_, callback1_, callback2_, callback3_;
}; };
...@@ -242,7 +271,7 @@ TEST_F(AsyncHostResolverTest, ConcurrentLookup) { ...@@ -242,7 +271,7 @@ TEST_F(AsyncHostResolverTest, ConcurrentLookup) {
EXPECT_EQ(OK, rv2); EXPECT_EQ(OK, rv2);
VerifyAddressList(ip_addresses2_, kPortNum, addrlist2_); VerifyAddressList(ip_addresses2_, kPortNum, addrlist2_);
EXPECT_EQ(3u, factory_.udp_client_sockets().size()); EXPECT_EQ(3, client_->num_requests);
} }
TEST_F(AsyncHostResolverTest, SameHostLookupsConsumeSingleTransaction) { TEST_F(AsyncHostResolverTest, SameHostLookupsConsumeSingleTransaction) {
...@@ -270,7 +299,7 @@ TEST_F(AsyncHostResolverTest, SameHostLookupsConsumeSingleTransaction) { ...@@ -270,7 +299,7 @@ TEST_F(AsyncHostResolverTest, SameHostLookupsConsumeSingleTransaction) {
VerifyAddressList(ip_addresses0_, kPortNum, addrlist2_); VerifyAddressList(ip_addresses0_, kPortNum, addrlist2_);
// Although we have three lookups, a single UDP socket was used. // Although we have three lookups, a single UDP socket was used.
EXPECT_EQ(1u, factory_.udp_client_sockets().size()); EXPECT_EQ(1, client_->num_requests);
} }
TEST_F(AsyncHostResolverTest, CancelLookup) { TEST_F(AsyncHostResolverTest, CancelLookup) {
...@@ -319,7 +348,7 @@ TEST_F(AsyncHostResolverTest, CancelSameHostLookup) { ...@@ -319,7 +348,7 @@ TEST_F(AsyncHostResolverTest, CancelSameHostLookup) {
EXPECT_EQ(OK, rv1); EXPECT_EQ(OK, rv1);
VerifyAddressList(ip_addresses0_, kPortNum, addrlist1_); VerifyAddressList(ip_addresses0_, kPortNum, addrlist1_);
EXPECT_EQ(1u, factory_.udp_client_sockets().size()); EXPECT_EQ(1, client_->num_requests);
} }
// Test that a queued lookup completes. // Test that a queued lookup completes.
......
// Copyright (c) 2011 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net/dns/dns_client.h"
#include "base/bind.h"
#include "base/string_piece.h"
#include "net/base/net_errors.h"
#include "net/dns/dns_response.h"
#include "net/dns/dns_session.h"
#include "net/dns/dns_transaction.h"
#include "net/socket/client_socket_factory.h"
namespace net {
DnsClient::Request::Request(const base::StringPiece& qname,
uint16 qtype,
const RequestCallback& callback)
: qname_(qname.data(), qname.size()),
qtype_(qtype),
callback_(callback) {
}
DnsClient::Request::~Request() {}
// Implementation of DnsClient that uses DnsTransaction to serve requests.
class DnsClientImpl : public DnsClient {
public:
class RequestImpl : public Request {
public:
RequestImpl(const base::StringPiece& qname,
uint16 qtype,
const RequestCallback& callback,
DnsSession* session,
const BoundNetLog& net_log)
: Request(qname, qtype, callback),
session_(session),
net_log_(net_log) {
}
virtual int Start() OVERRIDE {
transaction_.reset(new DnsTransaction(
session_,
qname(),
qtype(),
base::Bind(&RequestImpl::OnComplete, base::Unretained(this)),
net_log_));
return transaction_->Start();
}
void OnComplete(DnsTransaction* transaction, int rv) {
DCHECK_EQ(transaction_.get(), transaction);
// TODO(szym):
// - handle retransmissions here instead of DnsTransaction
// - handle rcode and flags here instead of DnsTransaction
// - update RTT in DnsSession
// - perform suffix search
// - handle DNS over TCP
DoCallback(rv, (rv == OK) ? transaction->response() : NULL);
}
private:
scoped_refptr<DnsSession> session_;
BoundNetLog net_log_;
scoped_ptr<DnsTransaction> transaction_;
};
explicit DnsClientImpl(DnsSession* session) {
session_ = session;
}
virtual Request* CreateRequest(
const base::StringPiece& qname,
uint16 qtype,
const RequestCallback& callback,
const BoundNetLog& source_net_log) OVERRIDE {
return new RequestImpl(qname, qtype, callback, session_, source_net_log);
}
private:
scoped_refptr<DnsSession> session_;
};
// static
DnsClient* DnsClient::CreateClient(DnsSession* session) {
return new DnsClientImpl(session);
}
} // namespace net
// Copyright (c) 2011 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef NET_DNS_DNS_CLIENT_H_
#define NET_DNS_DNS_CLIENT_H_
#pragma once
#include <string>
#include "base/basictypes.h"
#include "base/callback.h"
#include "base/memory/weak_ptr.h"
#include "net/base/net_export.h"
namespace base {
class StringPiece;
}
namespace net {
class BoundNetLog;
class ClientSocketFactory;
class DnsResponse;
class DnsSession;
// DnsClient performs asynchronous DNS queries. DnsClient takes care of
// retransmissions, DNS server fallback (or round-robin), suffix search, and
// simple response validation ("does it match the query") to fight poisoning.
// It does NOT perform caching, aggregation or prioritization of requests.
//
// Destroying DnsClient does NOT affect any already created Requests.
//
// TODO(szym): consider adding flags to MakeRequest to indicate options:
// -- don't perform suffix search
// -- query both A and AAAA at once
// -- accept truncated response (and/or forbid TCP)
class NET_EXPORT_PRIVATE DnsClient {
public:
class Request;
// Callback for complete requests. Note that DnsResponse might be NULL if
// the DNS server(s) could not be reached.
typedef base::Callback<void(Request* req,
int result,
const DnsResponse* resp)> RequestCallback;
// A data-holder for a request made to the DnsClient.
// Destroying the request cancels the underlying network effort.
class NET_EXPORT_PRIVATE Request {
public:
Request(const base::StringPiece& qname,
uint16 qtype,
const RequestCallback& callback);
virtual ~Request();
const std::string& qname() const { return qname_; }
uint16 qtype() const { return qtype_; }
virtual int Start() = 0;
void DoCallback(int result, const DnsResponse* response) {
callback_.Run(this, result, response);
}
private:
std::string qname_;
uint16 qtype_;
RequestCallback callback_;
DISALLOW_COPY_AND_ASSIGN(Request);
};
virtual ~DnsClient() {}
// Makes asynchronous DNS query for the given |qname| and |qtype| (assuming
// QCLASS == IN). The caller is responsible for destroying the returned
// request whether to cancel it or after its completion.
// (Destroying DnsClient does not abort the requests.)
virtual Request* CreateRequest(
const base::StringPiece& qname,
uint16 qtype,
const RequestCallback& callback,
const BoundNetLog& source_net_log) WARN_UNUSED_RESULT = 0;
// Creates a socket-based DnsClient using the |session|.
static DnsClient* CreateClient(DnsSession* session) WARN_UNUSED_RESULT;
};
} // namespace net
#endif // NET_DNS_DNS_CLIENT_H_
// Copyright (c) 2011 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net/dns/dns_client.h"
#include "base/bind.h"
#include "base/memory/scoped_ptr.h"
#include "net/base/big_endian.h"
#include "net/base/net_log.h"
#include "net/base/sys_addrinfo.h"
#include "net/dns/dns_response.h"
#include "net/dns/dns_session.h"
#include "net/dns/dns_test_util.h"
#include "net/socket/socket_test_util.h"
#include "testing/gtest/include/gtest/gtest.h"
// TODO(szym): test DnsClient::Request::Start with synchronous failure
// TODO(szym): test suffix search and server fallback once implemented
namespace net {
namespace {
class DnsClientTest : public testing::Test {
public:
class TestRequestHelper {
public:
// If |answer_count| < 0, it is the expected error code.
TestRequestHelper(const char* name,
uint16 type,
const MockWrite& write,
const MockRead& read,
int answer_count) {
// Must include the terminating \x00.
qname = std::string(name, strlen(name) + 1);
qtype = type;
expected_answer_count = answer_count;
completed = false;
writes.push_back(write);
reads.push_back(read);
ReadBigEndian<uint16>(write.data, &transaction_id);
data.reset(new StaticSocketDataProvider(&reads[0], reads.size(),
&writes[0], writes.size()));
}
void MakeRequest(DnsClient* client) {
EXPECT_EQ(NULL, request.get());
request.reset(client->CreateRequest(
qname,
qtype,
base::Bind(&TestRequestHelper::OnRequestComplete,
base::Unretained(this)),
BoundNetLog()));
EXPECT_EQ(qname, request->qname());
EXPECT_EQ(qtype, request->qtype());
EXPECT_EQ(ERR_IO_PENDING, request->Start());
}
void Cancel() {
ASSERT_TRUE(request.get() != NULL);
request.reset(NULL);
}
void OnRequestComplete(DnsClient::Request* req,
int rv,
const DnsResponse* response) {
EXPECT_FALSE(completed);
EXPECT_EQ(request.get(), req);
if (expected_answer_count >= 0) {
EXPECT_EQ(OK, rv);
EXPECT_EQ(expected_answer_count, response->answer_count());
DnsRecordParser parser = response->Parser();
DnsResourceRecord record;
for (int i = 0; i < expected_answer_count; ++i) {
EXPECT_TRUE(parser.ParseRecord(&record));
}
EXPECT_TRUE(parser.AtEnd());
} else {
EXPECT_EQ(expected_answer_count, rv);
EXPECT_EQ(NULL, response);
}
completed = true;
}
void CancelOnRequestComplete(DnsClient::Request* req,
int rv,
const DnsResponse* response) {
EXPECT_FALSE(completed);
Cancel();
}
std::string qname;
uint16 qtype;
std::vector<MockWrite> writes;
std::vector<MockRead> reads;
uint16 transaction_id; // Id from first write.
scoped_ptr<StaticSocketDataProvider> data;
scoped_ptr<DnsClient::Request> request;
int expected_answer_count;
bool completed;
};
virtual void SetUp() OVERRIDE {
helpers_.push_back(new TestRequestHelper(
kT0DnsName,
kT0Qtype,
MockWrite(true, reinterpret_cast<const char*>(kT0QueryDatagram),
arraysize(kT0QueryDatagram)),
MockRead(true, reinterpret_cast<const char*>(kT0ResponseDatagram),
arraysize(kT0ResponseDatagram)),
arraysize(kT0IpAddresses) + 1)); // +1 for CNAME RR
helpers_.push_back(new TestRequestHelper(
kT1DnsName,
kT1Qtype,
MockWrite(true, reinterpret_cast<const char*>(kT1QueryDatagram),
arraysize(kT1QueryDatagram)),
MockRead(true, reinterpret_cast<const char*>(kT1ResponseDatagram),
arraysize(kT1ResponseDatagram)),
arraysize(kT1IpAddresses) + 1)); // +1 for CNAME RR
helpers_.push_back(new TestRequestHelper(
kT2DnsName,
kT2Qtype,
MockWrite(true, reinterpret_cast<const char*>(kT2QueryDatagram),
arraysize(kT2QueryDatagram)),
MockRead(true, reinterpret_cast<const char*>(kT2ResponseDatagram),
arraysize(kT2ResponseDatagram)),
arraysize(kT2IpAddresses) + 1)); // +1 for CNAME RR
helpers_.push_back(new TestRequestHelper(
kT3DnsName,
kT3Qtype,
MockWrite(true, reinterpret_cast<const char*>(kT3QueryDatagram),
arraysize(kT3QueryDatagram)),
MockRead(true, reinterpret_cast<const char*>(kT3ResponseDatagram),
arraysize(kT3ResponseDatagram)),
arraysize(kT3IpAddresses) + 2)); // +2 for CNAME RR
CreateClient();
}
void CreateClient() {
MockClientSocketFactory* factory = new MockClientSocketFactory();
transaction_ids_.clear();
for (unsigned i = 0; i < helpers_.size(); ++i) {
factory->AddSocketDataProvider(helpers_[i]->data.get());
transaction_ids_.push_back(static_cast<int>(helpers_[i]->transaction_id));
}
DnsConfig config;
IPEndPoint dns_server;
{
bool rv = CreateDnsAddress(kDnsIp, kDnsPort, &dns_server);
EXPECT_TRUE(rv);
}
config.nameservers.push_back(dns_server);
DnsSession* session = new DnsSession(
config,
factory,
base::Bind(&DnsClientTest::GetNextId, base::Unretained(this)),
NULL /* NetLog */);
client_.reset(DnsClient::CreateClient(session));
}
virtual void TearDown() OVERRIDE {
STLDeleteElements(&helpers_);
}
int GetNextId(int min, int max) {
EXPECT_FALSE(transaction_ids_.empty());
int id = transaction_ids_.front();
transaction_ids_.pop_front();
EXPECT_GE(id, min);
EXPECT_LE(id, max);
return id;
}
protected:
std::vector<TestRequestHelper*> helpers_;
std::deque<int> transaction_ids_;
scoped_ptr<DnsClient> client_;
};
TEST_F(DnsClientTest, Lookup) {
helpers_[0]->MakeRequest(client_.get());
// Wait until result.
MessageLoop::current()->RunAllPending();
EXPECT_TRUE(helpers_[0]->completed);
}
TEST_F(DnsClientTest, ConcurrentLookup) {
for (unsigned i = 0; i < helpers_.size(); ++i) {
helpers_[i]->MakeRequest(client_.get());
}
MessageLoop::current()->RunAllPending();
for (unsigned i = 0; i < helpers_.size(); ++i) {
EXPECT_TRUE(helpers_[i]->completed);
}
}
TEST_F(DnsClientTest, CancelLookup) {
for (unsigned i = 0; i < helpers_.size(); ++i) {
helpers_[i]->MakeRequest(client_.get());
}
helpers_[0]->Cancel();
helpers_[2]->Cancel();
MessageLoop::current()->RunAllPending();
EXPECT_FALSE(helpers_[0]->completed);
EXPECT_TRUE(helpers_[1]->completed);
EXPECT_FALSE(helpers_[2]->completed);
EXPECT_TRUE(helpers_[3]->completed);
}
TEST_F(DnsClientTest, DestroyClient) {
for (unsigned i = 0; i < helpers_.size(); ++i) {
helpers_[i]->MakeRequest(client_.get());
}
// Destroying the client does not affect running requests.
client_.reset(NULL);
MessageLoop::current()->RunAllPending();
for (unsigned i = 0; i < helpers_.size(); ++i) {
EXPECT_TRUE(helpers_[i]->completed);
}
}
TEST_F(DnsClientTest, DestroyRequestFromCallback) {
// Custom callback to cancel the completing request.
helpers_[0]->request.reset(client_->CreateRequest(
helpers_[0]->qname,
helpers_[0]->qtype,
base::Bind(&TestRequestHelper::CancelOnRequestComplete,
base::Unretained(helpers_[0])),
BoundNetLog()));
helpers_[0]->request->Start();
for (unsigned i = 1; i < helpers_.size(); ++i) {
helpers_[i]->MakeRequest(client_.get());
}
MessageLoop::current()->RunAllPending();
EXPECT_FALSE(helpers_[0]->completed);
for (unsigned i = 1; i < helpers_.size(); ++i) {
EXPECT_TRUE(helpers_[i]->completed);
}
}
TEST_F(DnsClientTest, HandleFailure) {
STLDeleteElements(&helpers_);
// Wrong question.
helpers_.push_back(new TestRequestHelper(
kT0DnsName,
kT0Qtype,
MockWrite(true, reinterpret_cast<const char*>(kT0QueryDatagram),
arraysize(kT0QueryDatagram)),
MockRead(true, reinterpret_cast<const char*>(kT1ResponseDatagram),
arraysize(kT1ResponseDatagram)),
ERR_DNS_MALFORMED_RESPONSE));
// Response with NXDOMAIN.
uint8 nxdomain_response[arraysize(kT0QueryDatagram)];
memcpy(nxdomain_response, kT0QueryDatagram, arraysize(nxdomain_response));
nxdomain_response[2] &= 0x80; // Response bit.
nxdomain_response[3] &= 0x03; // NXDOMAIN bit.
helpers_.push_back(new TestRequestHelper(
kT0DnsName,
kT0Qtype,
MockWrite(true, reinterpret_cast<const char*>(kT0QueryDatagram),
arraysize(kT0QueryDatagram)),
MockRead(true, reinterpret_cast<const char*>(nxdomain_response),
arraysize(nxdomain_response)),
ERR_NAME_NOT_RESOLVED));
CreateClient();
for (unsigned i = 0; i < helpers_.size(); ++i) {
helpers_[i]->MakeRequest(client_.get());
}
MessageLoop::current()->RunAllPending();
for (unsigned i = 0; i < helpers_.size(); ++i) {
EXPECT_TRUE(helpers_[i]->completed);
}
}
} // namespace
} // namespace net
// Copyright (c) 2011 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef NET_DNS_DNS_PROTOCOL_H_
#define NET_DNS_DNS_PROTOCOL_H_
#pragma once
#include "base/basictypes.h"
#include "net/base/net_export.h"
namespace net {
namespace dns_protocol {
// DNS packet consists of a header followed by questions and/or answers.
// For the meaning of specific fields, please see RFC 1035 and 2535
// Header format.
// 1 1 1 1 1 1
// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
// | ID |
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
// |QR| Opcode |AA|TC|RD|RA| Z|AD|CD| RCODE |
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
// | QDCOUNT |
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
// | ANCOUNT |
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
// | NSCOUNT |
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
// | ARCOUNT |
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
// Question format.
// 1 1 1 1 1 1
// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
// | |
// / QNAME /
// / /
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
// | QTYPE |
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
// | QCLASS |
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
// Answer format.
// 1 1 1 1 1 1
// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
// | |
// / /
// / NAME /
// | |
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
// | TYPE |
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
// | CLASS |
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
// | TTL |
// | |
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
// | RDLENGTH |
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--|
// / RDATA /
// / /
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
#pragma pack(push)
#pragma pack(1)
// On-the-wire header. All uint16 are in network order.
// Used internally in DnsQuery and DnsResponseParser.
struct NET_EXPORT_PRIVATE Header {
uint16 id;
uint8 flags[2];
uint16 qdcount;
uint16 ancount;
uint16 nscount;
uint16 arcount;
};
#pragma pack(pop)
static const uint8 kLabelMask = 0xc0;
static const uint8 kLabelPointer = 0xc0;
static const uint8 kLabelDirect = 0x0;
static const uint16 kOffsetMask = 0x3fff;
static const int kMaxNameLength = 255;
// RFC 1035, section 4.2.1: Messages carried by UDP are restricted to 512
// bytes (not counting the IP nor UDP headers).
static const int kMaxUDPSize = 512;
// DNS class types.
static const uint16 kClassIN = 1;
// DNS resource record types. See
// http://www.iana.org/assignments/dns-parameters
static const uint16 kTypeA = 1;
static const uint16 kTypeCNAME = 5;
static const uint16 kTypeTXT = 16;
static const uint16 kTypeAAAA = 28;
// DNS rcode values.
static const uint8 kRcodeMask = 0xf;
static const uint8 kRcodeNOERROR = 0;
static const uint8 kRcodeFORMERR = 1;
static const uint8 kRcodeSERVFAIL = 2;
static const uint8 kRcodeNXDOMAIN = 3;
static const uint8 kRcodeNOTIMP = 4;
static const uint8 kRcodeREFUSED = 5;
} // namespace dns_protocol
} // namespace net
#endif // NET_DNS_DNS_PROTOCOL_H_
...@@ -6,91 +6,78 @@ ...@@ -6,91 +6,78 @@
#include <limits> #include <limits>
#include "net/base/big_endian.h"
#include "net/base/dns_util.h" #include "net/base/dns_util.h"
#include "net/base/io_buffer.h" #include "net/base/io_buffer.h"
#include "net/base/sys_byteorder.h"
#include "net/dns/dns_protocol.h"
namespace net { namespace net {
namespace {
void PackUint16BE(char buf[2], uint16 v) {
buf[0] = v >> 8;
buf[1] = v & 0xff;
}
uint16 UnpackUint16BE(char buf[2]) {
return static_cast<uint8>(buf[0]) << 8 | static_cast<uint8>(buf[1]);
}
} // namespace
// DNS query consists of a 12-byte header followed by a question section. // DNS query consists of a 12-byte header followed by a question section.
// For details, see RFC 1035 section 4.1.1. This header template sets RD // For details, see RFC 1035 section 4.1.1. This header template sets RD
// bit, which directs the name server to pursue query recursively, and sets // bit, which directs the name server to pursue query recursively, and sets
// the QDCOUNT to 1, meaning the question section has a single entry. The // the QDCOUNT to 1, meaning the question section has a single entry.
// first two bytes of the header form a 16-bit random query ID to be copied DnsQuery::DnsQuery(uint16 id, const base::StringPiece& qname, uint16 qtype)
// in the corresponding reply by the name server -- randomized during : qname_size_(qname.size()) {
// DnsQuery construction. DCHECK(!DNSDomainToString(qname).empty());
static const char kHeader[] = {0x00, 0x00, 0x01, 0x00, 0x00, 0x01, // QNAME + QTYPE + QCLASS
0x00, 0x00, 0x00, 0x00, 0x00, 0x00}; size_t question_size = qname_size_ + sizeof(uint16) + sizeof(uint16);
static const size_t kHeaderSize = arraysize(kHeader); io_buffer_ = new IOBufferWithSize(sizeof(dns_protocol::Header) +
question_size);
DnsQuery::DnsQuery(const std::string& qname, dns_protocol::Header* header =
uint16 qtype, reinterpret_cast<dns_protocol::Header*>(io_buffer_->data());
const RandIntCallback& rand_int_cb) memset(header, 0, sizeof(dns_protocol::Header));
: qname_size_(qname.size()), header->id = htons(id);
rand_int_cb_(rand_int_cb) { header->flags[0] = 0x1; // RD bit
DCHECK(DnsResponseBuffer(reinterpret_cast<const uint8*>(qname.c_str()), header->qdcount = htons(1);
qname.size()).DNSName(NULL));
DCHECK(qtype == kDNS_A || qtype == kDNS_AAAA); // Write question section after the header.
BigEndianWriter writer(reinterpret_cast<char*>(header + 1), question_size);
io_buffer_ = new IOBufferWithSize(kHeaderSize + question_size()); writer.WriteBytes(qname.data(), qname.size());
writer.WriteU16(qtype);
int byte_offset = 0; writer.WriteU16(dns_protocol::kClassIN);
char* buffer_head = io_buffer_->data();
memcpy(&buffer_head[byte_offset], kHeader, kHeaderSize);
byte_offset += kHeaderSize;
memcpy(&buffer_head[byte_offset], &qname[0], qname_size_);
byte_offset += qname_size_;
PackUint16BE(&buffer_head[byte_offset], qtype);
byte_offset += sizeof(qtype);
PackUint16BE(&buffer_head[byte_offset], kClassIN);
RandomizeId();
} }
DnsQuery::~DnsQuery() { DnsQuery::~DnsQuery() {
} }
uint16 DnsQuery::id() const { DnsQuery* DnsQuery::CloneWithNewId(uint16 id) const {
return UnpackUint16BE(&io_buffer_->data()[0]); return new DnsQuery(*this, id);
} }
uint16 DnsQuery::qtype() const { uint16 DnsQuery::id() const {
return UnpackUint16BE(&io_buffer_->data()[kHeaderSize + qname_size_]); const dns_protocol::Header* header =
} reinterpret_cast<const dns_protocol::Header*>(io_buffer_->data());
return ntohs(header->id);
DnsQuery* DnsQuery::CloneWithNewId() const {
return new DnsQuery(qname(), qtype(), rand_int_cb_);
} }
size_t DnsQuery::question_size() const { base::StringPiece DnsQuery::qname() const {
return qname_size_ // QNAME return base::StringPiece(io_buffer_->data() + sizeof(dns_protocol::Header),
+ sizeof(uint16) // QTYPE qname_size_);
+ sizeof(uint16); // QCLASS
} }
const char* DnsQuery::question_data() const { uint16 DnsQuery::qtype() const {
return &io_buffer_->data()[kHeaderSize]; uint16 type;
ReadBigEndian<uint16>(io_buffer_->data() +
sizeof(dns_protocol::Header) +
qname_size_, &type);
return type;
} }
const std::string DnsQuery::qname() const { base::StringPiece DnsQuery::question() const {
return std::string(question_data(), qname_size_); return base::StringPiece(io_buffer_->data() + sizeof(dns_protocol::Header),
qname_size_ + sizeof(uint16) + sizeof(uint16));
} }
void DnsQuery::RandomizeId() { DnsQuery::DnsQuery(const DnsQuery& orig, uint16 id) {
PackUint16BE(&io_buffer_->data()[0], rand_int_cb_.Run( qname_size_ = orig.qname_size_;
std::numeric_limits<uint16>::min(), io_buffer_ = new IOBufferWithSize(orig.io_buffer()->size());
std::numeric_limits<uint16>::max())); memcpy(io_buffer_.get()->data(), orig.io_buffer()->data(),
io_buffer_.get()->size());
dns_protocol::Header* header =
reinterpret_cast<dns_protocol::Header*>(io_buffer_->data());
header->id = htons(id);
} }
} // namespace net } // namespace net
...@@ -6,53 +6,41 @@ ...@@ -6,53 +6,41 @@
#define NET_DNS_DNS_QUERY_H_ #define NET_DNS_DNS_QUERY_H_
#pragma once #pragma once
#include <string> #include "base/basictypes.h"
#include "base/memory/ref_counted.h" #include "base/memory/ref_counted.h"
#include "base/string_piece.h"
#include "net/base/net_export.h" #include "net/base/net_export.h"
#include "net/base/rand_callback.h"
namespace net { namespace net {
class IOBufferWithSize; class IOBufferWithSize;
// Represents on-the-wire DNS query message as an object. // Represents on-the-wire DNS query message as an object.
// TODO(szym): add support for the OPT pseudo-RR (EDNS0/DNSSEC).
class NET_EXPORT_PRIVATE DnsQuery { class NET_EXPORT_PRIVATE DnsQuery {
public: public:
// Constructs a query message from |qname| which *MUST* be in a valid // Constructs a query message from |qname| which *MUST* be in a valid
// DNS name format, and |qtype| which must be either kDNS_A or kDNS_AAAA. // DNS name format, and |qtype|. The qclass is set to IN.
DnsQuery(uint16 id, const base::StringPiece& qname, uint16 qtype);
// Every generated object has a random ID, hence two objects generated
// with the same set of constructor arguments are generally not equal;
// there is a 1/2^16 chance of them being equal due to size of |id_|.
DnsQuery(const std::string& qname,
uint16 qtype,
const RandIntCallback& rand_int_cb);
~DnsQuery(); ~DnsQuery();
// Clones |this| verbatim, with ID field of the header regenerated. // Clones |this| verbatim, with ID field of the header set to |id|.
DnsQuery* CloneWithNewId() const; DnsQuery* CloneWithNewId(uint16 id) const;
// DnsQuery field accessors. // DnsQuery field accessors.
uint16 id() const; uint16 id() const;
base::StringPiece qname() const;
uint16 qtype() const; uint16 qtype() const;
// Returns the size of the Question section of the query. Used when // Returns the Question section of the query. Used when matching the
// matching the response. // response.
size_t question_size() const; base::StringPiece question() const;
// Returns pointer to the Question section of the query. Used when
// matching the response.
const char* question_data() const;
// IOBuffer accessor to be used for writing out the query. // IOBuffer accessor to be used for writing out the query.
IOBufferWithSize* io_buffer() const { return io_buffer_; } IOBufferWithSize* io_buffer() const { return io_buffer_; }
private: private:
const std::string qname() const; DnsQuery(const DnsQuery& orig, uint16 id);
// Randomizes ID field of the query message.
void RandomizeId();
// Size of the DNS name (*NOT* hostname) we are trying to resolve; used // Size of the DNS name (*NOT* hostname) we are trying to resolve; used
// to calculate offsets. // to calculate offsets.
...@@ -61,9 +49,6 @@ class NET_EXPORT_PRIVATE DnsQuery { ...@@ -61,9 +49,6 @@ class NET_EXPORT_PRIVATE DnsQuery {
// Contains query bytes to be consumed by higher level Write() call. // Contains query bytes to be consumed by higher level Write() call.
scoped_refptr<IOBufferWithSize> io_buffer_; scoped_refptr<IOBufferWithSize> io_buffer_;
// PRNG function for generating IDs.
RandIntCallback rand_int_cb_;
DISALLOW_COPY_AND_ASSIGN(DnsQuery); DISALLOW_COPY_AND_ASSIGN(DnsQuery);
}; };
......
...@@ -5,58 +5,21 @@ ...@@ -5,58 +5,21 @@
#include "net/dns/dns_query.h" #include "net/dns/dns_query.h"
#include "base/bind.h" #include "base/bind.h"
#include "base/rand_util.h"
#include "net/base/dns_util.h" #include "net/base/dns_util.h"
#include "net/base/io_buffer.h" #include "net/base/io_buffer.h"
#include "net/dns/dns_protocol.h"
#include "testing/gtest/include/gtest/gtest.h" #include "testing/gtest/include/gtest/gtest.h"
namespace net { namespace net {
// DNS query consists of a header followed by a question. Header format namespace {
// and question format are described below. For the meaning of specific
// fields, please see RFC 1035.
// Header format. TEST(DnsQueryTest, Constructor) {
// 1 1 1 1 1 1 // This includes \0 at the end.
// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 const char qname_data[] = "\x03""www""\x07""example""\x03""com";
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
// | ID |
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
// |QR| Opcode |AA|TC|RD|RA| Z | RCODE |
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
// | QDCOUNT |
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
// | ANCOUNT |
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
// | NSCOUNT |
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
// | ARCOUNT |
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
// Question format.
// 1 1 1 1 1 1
// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
// | |
// / QNAME /
// / /
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
// | QTYPE |
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
// | QCLASS |
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
TEST(DnsQueryTest, ConstructorTest) {
std::string kQname("\003www\006google\003com", 16);
DnsQuery q1(kQname, kDNS_A, base::Bind(&base::RandInt));
EXPECT_EQ(kDNS_A, q1.qtype());
uint8 id_hi = q1.id() >> 8, id_lo = q1.id() & 0xff;
// See the top of the file for the description of a DNS query.
const uint8 query_data[] = { const uint8 query_data[] = {
// Header // Header
id_hi, id_lo, 0xbe, 0xef,
0x01, 0x00, // Flags -- set RD (recursion desired) bit. 0x01, 0x00, // Flags -- set RD (recursion desired) bit.
0x00, 0x01, // Set QDCOUNT (question count) to 1, all the 0x00, 0x01, // Set QDCOUNT (question count) to 1, all the
// rest are 0 for a query. // rest are 0 for a query.
...@@ -65,46 +28,42 @@ TEST(DnsQueryTest, ConstructorTest) { ...@@ -65,46 +28,42 @@ TEST(DnsQueryTest, ConstructorTest) {
0x00, 0x00, 0x00, 0x00,
// Question // Question
0x03, 0x77, 0x77, 0x77, // QNAME: www.google.com in DNS format. 0x03, 'w', 'w', 'w', // QNAME: www.example.com in DNS format.
0x06, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e',
0x03, 0x63, 0x6f, 0x6d, 0x00, 0x03, 'c', 'o', 'm',
0x00,
0x00, 0x01, // QTYPE: A query. 0x00, 0x01, // QTYPE: A query.
0x00, 0x01, // QCLASS: IN class. 0x00, 0x01, // QCLASS: IN class.
}; };
int expected_size = arraysize(query_data); base::StringPiece qname(qname_data, sizeof(qname_data));
EXPECT_EQ(expected_size, q1.io_buffer()->size()); DnsQuery q1(0xbeef, qname, dns_protocol::kTypeA);
EXPECT_EQ(0, memcmp(q1.io_buffer()->data(), query_data, expected_size)); EXPECT_EQ(dns_protocol::kTypeA, q1.qtype());
ASSERT_EQ(static_cast<int>(sizeof(query_data)), q1.io_buffer()->size());
EXPECT_EQ(0, memcmp(q1.io_buffer()->data(), query_data, sizeof(query_data)));
EXPECT_EQ(qname, q1.qname());
base::StringPiece question(reinterpret_cast<const char*>(query_data) + 12,
21);
EXPECT_EQ(question, q1.question());
} }
TEST(DnsQueryTest, CloneTest) { TEST(DnsQueryTest, Clone) {
std::string kQname("\003www\006google\003com", 16); // This includes \0 at the end.
DnsQuery q1(kQname, kDNS_A, base::Bind(&base::RandInt)); const char qname_data[] = "\x03""www""\x07""example""\x03""com";
base::StringPiece qname(qname_data, sizeof(qname_data));
scoped_ptr<DnsQuery> q2(q1.CloneWithNewId()); DnsQuery q1(0, qname, dns_protocol::kTypeA);
EXPECT_EQ(0, q1.id());
scoped_ptr<DnsQuery> q2(q1.CloneWithNewId(42));
EXPECT_EQ(42, q2->id());
EXPECT_EQ(q1.io_buffer()->size(), q2->io_buffer()->size()); EXPECT_EQ(q1.io_buffer()->size(), q2->io_buffer()->size());
EXPECT_EQ(q1.qtype(), q2->qtype()); EXPECT_EQ(q1.qtype(), q2->qtype());
EXPECT_EQ(q1.question_size(), q2->question_size()); EXPECT_EQ(q1.question(), q2->question());
EXPECT_EQ(0, memcmp(q1.question_data(), q2->question_data(),
q1.question_size()));
} }
TEST(DnsQueryTest, RandomIdTest) { } // namespace
std::string kQname("\003www\006google\003com", 16);
// Since id fields are 16-bit values, we iterate to reduce the
// probability of collision, to avoid a flaky test.
bool ids_are_random = false;
for (int i = 0; i < 1000; ++i) {
DnsQuery q1(kQname, kDNS_A, base::Bind(&base::RandInt));
DnsQuery q2(kQname, kDNS_A, base::Bind(&base::RandInt));
scoped_ptr<DnsQuery> q3(q1.CloneWithNewId());
ids_are_random = q1.id () != q2.id() && q1.id() != q3->id();
if (ids_are_random)
break;
}
EXPECT_TRUE(ids_are_random);
}
} // namespace net } // namespace net
...@@ -4,97 +4,185 @@ ...@@ -4,97 +4,185 @@
#include "net/dns/dns_response.h" #include "net/dns/dns_response.h"
#include "net/base/dns_util.h" #include "net/base/big_endian.h"
#include "net/base/io_buffer.h" #include "net/base/io_buffer.h"
#include "net/base/net_errors.h" #include "net/base/net_errors.h"
#include "net/base/sys_byteorder.h"
#include "net/dns/dns_protocol.h"
#include "net/dns/dns_query.h" #include "net/dns/dns_query.h"
namespace net { namespace net {
// RFC 1035, section 4.2.1: Messages carried by UDP are restricted to 512 DnsRecordParser::DnsRecordParser() : packet_(NULL), length_(0), cur_(0) {
// bytes (not counting the IP nor UDP headers). }
static const int kMaxResponseSize = 512;
DnsRecordParser::DnsRecordParser(const void* packet,
size_t length,
size_t offset)
: packet_(reinterpret_cast<const char*>(packet)),
length_(length),
cur_(packet_ + offset) {
DCHECK_LE(offset, length);
}
DnsResponse::DnsResponse(DnsQuery* query) int DnsRecordParser::ParseName(const void* const vpos, std::string* out) const {
: query_(query), const char* const pos = reinterpret_cast<const char*>(vpos);
io_buffer_(new IOBufferWithSize(kMaxResponseSize + 1)) { DCHECK(packet_);
DCHECK(query_); DCHECK_LE(packet_, pos);
DCHECK_LE(pos, packet_ + length_);
const char* p = pos;
const char* end = packet_ + length_;
// Count number of seen bytes to detect loops.
size_t seen = 0;
// Remember how many bytes were consumed before first jump.
size_t consumed = 0;
if (pos >= end)
return 0;
if (out) {
out->clear();
out->reserve(dns_protocol::kMaxNameLength);
}
for (;;) {
// The two couple of bits of the length give the type of the length. It's
// either a direct length or a pointer to the remainder of the name.
switch (*p & dns_protocol::kLabelMask) {
case dns_protocol::kLabelPointer: {
if (p + sizeof(uint16) > end)
return 0;
if (consumed == 0) {
consumed = p - pos + sizeof(uint16);
if (!out)
return consumed; // If name is not stored, that's all we need.
}
seen += sizeof(uint16);
// If seen the whole packet, then we must be in a loop.
if (seen > length_)
return 0;
uint16 offset;
ReadBigEndian<uint16>(p, &offset);
offset &= dns_protocol::kOffsetMask;
p = packet_ + offset;
if (p >= end)
return 0;
break;
}
case dns_protocol::kLabelDirect: {
uint8 label_len = *p;
++p;
// Note: root domain (".") is NOT included.
if (label_len == 0) {
if (consumed == 0) {
consumed = p - pos;
} // else we set |consumed| before first jump
return consumed;
}
if (p + label_len >= end)
return 0; // Truncated or missing label.
if (out) {
if (!out->empty())
out->append(".");
out->append(p, label_len);
}
p += label_len;
seen += 1 + label_len;
break;
}
default:
// unhandled label type
return 0;
}
}
}
bool DnsRecordParser::ParseRecord(DnsResourceRecord* out) {
DCHECK(packet_);
size_t consumed = ParseName(cur_, &out->name);
if (!consumed)
return false;
BigEndianReader reader(cur_ + consumed,
packet_ + length_ - (cur_ + consumed));
uint16 rdlen;
if (reader.ReadU16(&out->type) &&
reader.ReadU16(&out->klass) &&
reader.ReadU32(&out->ttl) &&
reader.ReadU16(&rdlen) &&
reader.ReadPiece(&out->rdata, rdlen)) {
cur_ = reader.ptr();
return true;
}
return false;
}
DnsResponse::DnsResponse()
: io_buffer_(new IOBufferWithSize(dns_protocol::kMaxUDPSize + 1)) {
}
DnsResponse::DnsResponse(const void* data,
size_t length,
size_t answer_offset)
: io_buffer_(new IOBufferWithSize(length)),
parser_(io_buffer_->data(), length, answer_offset) {
memcpy(io_buffer_->data(), data, length);
} }
DnsResponse::~DnsResponse() { DnsResponse::~DnsResponse() {
} }
int DnsResponse::Parse(int nbytes, IPAddressList* ip_addresses) { bool DnsResponse::InitParse(int nbytes, const DnsQuery& query) {
// Response includes query, it should be at least that size. // Response includes query, it should be at least that size.
if (nbytes < query_->io_buffer()->size() || nbytes > kMaxResponseSize) if (nbytes < query.io_buffer()->size() || nbytes > dns_protocol::kMaxUDPSize)
return ERR_DNS_MALFORMED_RESPONSE; return false;
DnsResponseBuffer response(reinterpret_cast<uint8*>(io_buffer_->data()), // Match the query id.
io_buffer_->size()); if (ntohs(header()->id) != query.id())
uint16 id; return false;
if (!response.U16(&id) || id != query_->id()) // Make sure IDs match.
return ERR_DNS_MALFORMED_RESPONSE; // Match question count.
if (ntohs(header()->qdcount) != 1)
uint8 flags, rcode; return false;
if (!response.U8(&flags) || !response.U8(&rcode))
return ERR_DNS_MALFORMED_RESPONSE; // Match the question section.
const size_t hdr_size = sizeof(dns_protocol::Header);
if (flags & 2) // TC is set -- server wants TCP, we don't support it (yet?). const base::StringPiece question = query.question();
return ERR_DNS_SERVER_REQUIRES_TCP; if (question != base::StringPiece(io_buffer_->data() + hdr_size,
question.size())) {
rcode &= 0x0f; // 3 means NXDOMAIN, the rest means server failed. return false;
if (rcode && (rcode != 3))
return ERR_DNS_SERVER_FAILED;
uint16 query_count, answer_count, authority_count, additional_count;
if (!response.U16(&query_count) ||
!response.U16(&answer_count) ||
!response.U16(&authority_count) ||
!response.U16(&additional_count)) {
return ERR_DNS_MALFORMED_RESPONSE;
} }
if (query_count != 1) // Sent a single question, shouldn't have changed. // Construct the parser.
return ERR_DNS_MALFORMED_RESPONSE; parser_ = DnsRecordParser(io_buffer_->data(),
nbytes,
hdr_size + question.size());
return true;
}
base::StringPiece question; // Make sure question section is echoed back. uint8 DnsResponse::flags0() const {
if (!response.Block(&question, query_->question_size()) || return header()->flags[0];
memcmp(question.data(), query_->question_data(), }
query_->question_size())) {
return ERR_DNS_MALFORMED_RESPONSE;
}
if (answer_count < 1) uint8 DnsResponse::flags1() const {
return ERR_NAME_NOT_RESOLVED; return header()->flags[1] & ~(dns_protocol::kRcodeMask);
}
IPAddressList rdatas;
while (answer_count--) { uint8 DnsResponse::rcode() const {
uint32 ttl; return header()->flags[1] & dns_protocol::kRcodeMask;
uint16 rdlength, qtype, qclass; }
if (!response.DNSName(NULL) ||
!response.U16(&qtype) ||
!response.U16(&qclass) ||
!response.U32(&ttl) ||
!response.U16(&rdlength)) {
return ERR_DNS_MALFORMED_RESPONSE;
}
if (qtype == query_->qtype() &&
qclass == kClassIN &&
(rdlength == kIPv4AddressSize || rdlength == kIPv6AddressSize)) {
base::StringPiece rdata;
if (!response.Block(&rdata, rdlength))
return ERR_DNS_MALFORMED_RESPONSE;
rdatas.push_back(IPAddressNumber(rdata.begin(), rdata.end()));
} else if (!response.Skip(rdlength))
return ERR_DNS_MALFORMED_RESPONSE;
}
if (rdatas.empty()) int DnsResponse::answer_count() const {
return ERR_NAME_NOT_RESOLVED; return ntohs(header()->ancount);
}
DnsRecordParser DnsResponse::Parser() const {
DCHECK(parser_.IsValid());
return parser_;
}
if (ip_addresses) const dns_protocol::Header* DnsResponse::header() const {
ip_addresses->swap(rdatas); return reinterpret_cast<const dns_protocol::Header*>(io_buffer_->data());
return OK;
} }
} // namespace net } // namespace net
...@@ -6,43 +6,109 @@ ...@@ -6,43 +6,109 @@
#define NET_DNS_DNS_RESPONSE_H_ #define NET_DNS_DNS_RESPONSE_H_
#pragma once #pragma once
#include <string>
#include "base/basictypes.h"
#include "base/memory/ref_counted.h" #include "base/memory/ref_counted.h"
#include "base/string_piece.h"
#include "net/base/net_export.h" #include "net/base/net_export.h"
#include "net/base/net_util.h" #include "net/base/net_util.h"
namespace net{ namespace net {
class DnsQuery; class DnsQuery;
class IOBufferWithSize; class IOBufferWithSize;
// Represents on-the-wire DNS response as an object; allows extracting namespace dns_protocol {
// records. struct Header;
}
// Parsed resource record.
struct NET_EXPORT_PRIVATE DnsResourceRecord {
std::string name; // in dotted form
uint16 type;
uint16 klass;
uint32 ttl;
base::StringPiece rdata; // points to the original response buffer
};
// Iterator to walk over resource records of the DNS response packet.
class NET_EXPORT_PRIVATE DnsRecordParser {
public:
// Construct an uninitialized iterator.
DnsRecordParser();
// Construct an iterator to process the |packet| of given |length|.
// |offset| points to the beginning of the answer section.
DnsRecordParser(const void* packet, size_t length, size_t offset);
// Returns |true| if initialized.
bool IsValid() const { return packet_ != NULL; }
// Returns |true| if no more bytes remain in the packet.
bool AtEnd() const { return cur_ == packet_ + length_; }
// Parses a (possibly compressed) DNS name from the packet starting at
// |pos|. Stores output (even partial) in |out| unless |out| is NULL. |out|
// is stored in the dotted form, e.g., "example.com". Returns number of bytes
// consumed or 0 on failure.
// This is exposed to allow parsing compressed names within RRDATA for TYPEs
// such as NS, CNAME, PTR, MX, SOA.
// See RFC 1035 section 4.1.4.
int ParseName(const void* pos, std::string* out) const;
// Parses the next resource record. Returns true if succeeded.
bool ParseRecord(DnsResourceRecord* record);
private:
const char* packet_;
size_t length_;
// Current offset within the packet.
const char* cur_;
};
// Buffer-holder for the DNS response allowing easy access to the header fields
// and resource records. After reading into |io_buffer| must call InitParse to
// position the RR parser.
class NET_EXPORT_PRIVATE DnsResponse { class NET_EXPORT_PRIVATE DnsResponse {
public: public:
// Constructs an object with an IOBuffer large enough to read // Constructs an object with an IOBuffer large enough to read
// one byte more than largest possible response, to detect malformed // one byte more than largest possible response, to detect malformed
// responses; |query| is a pointer to the DnsQuery for which |this| // responses.
// is supposed to be a response. DnsResponse();
explicit DnsResponse(DnsQuery* query); // Constructs response from |data|. Used for testing purposes only!
DnsResponse(const void* data, size_t length, size_t answer_offset);
~DnsResponse(); ~DnsResponse();
// Internal buffer accessor into which actual bytes of response will be // Internal buffer accessor into which actual bytes of response will be
// read. // read.
IOBufferWithSize* io_buffer() { return io_buffer_.get(); } IOBufferWithSize* io_buffer() { return io_buffer_.get(); }
// Parses response of size nbytes and puts address into |ip_addresses|, // Returns false if the packet is shorter than the header or does not match
// returns net_error code in case of failure. // |query| id or question.
int Parse(int nbytes, IPAddressList* ip_addresses); bool InitParse(int nbytes, const DnsQuery& query);
// Accessors for the header.
uint8 flags0() const; // first byte of flags
uint8 flags1() const; // second byte of flags excluding rcode
uint8 rcode() const;
int answer_count() const;
// Returns an iterator to the resource records in the answer section. Must be
// called after InitParse. The iterator is valid only in the scope of the
// DnsResponse.
DnsRecordParser Parser() const;
private: private:
// The matching query; |this| is the response for |query_|. We do not // Convenience for header access.
// own it, lifetime of |this| should be within the limits of lifetime of const dns_protocol::Header* header() const;
// |query_|.
const DnsQuery* const query_;
// Buffer into which response bytes are read. // Buffer into which response bytes are read.
scoped_refptr<IOBufferWithSize> io_buffer_; scoped_refptr<IOBufferWithSize> io_buffer_;
// Iterator constructed after InitParse positioned at the answer section.
DnsRecordParser parser_;
DISALLOW_COPY_AND_ASSIGN(DnsResponse); DISALLOW_COPY_AND_ASSIGN(DnsResponse);
}; };
......
This diff is collapsed.
// Copyright (c) 2011 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net/dns/dns_session.h"
#include "base/basictypes.h"
#include "base/bind.h"
#include "base/time.h"
#include "net/base/ip_endpoint.h"
#include "net/dns/dns_config_service.h"
#include "net/socket/client_socket_factory.h"
namespace net {
DnsSession::DnsSession(const DnsConfig& config,
ClientSocketFactory* factory,
const RandIntCallback& rand_int_callback,
NetLog* net_log)
: config_(config),
socket_factory_(factory),
rand_callback_(base::Bind(rand_int_callback, 0, kuint16max)),
net_log_(net_log),
server_index_(0) {
}
int DnsSession::NextId() const {
return rand_callback_.Run();
}
const IPEndPoint& DnsSession::NextServer() {
// TODO(szym): Rotate servers on failures.
const IPEndPoint& ipe = config_.nameservers[server_index_];
if (config_.rotate)
server_index_ = (server_index_ + 1) % config_.nameservers.size();
return ipe;
}
base::TimeDelta DnsSession::NextTimeout(int attempt) {
// TODO(szym): Adapt timeout to observed RTT.
return config_.timeout * (attempt + 1);
}
DnsSession::~DnsSession() {}
} // namespace net
// Copyright (c) 2011 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef NET_DNS_DNS_SESSION_H_
#define NET_DNS_DNS_SESSION_H_
#pragma once
#include "base/memory/ref_counted.h"
#include "base/memory/scoped_ptr.h"
#include "base/time.h"
#include "net/base/net_export.h"
#include "net/base/rand_callback.h"
#include "net/dns/dns_config_service.h"
namespace net {
class ClientSocketFactory;
class NetLog;
// Session parameters and state shared between DNS transactions.
// Ref-counted so that DnsClient::Request can keep working in absence of
// DnsClient. A DnsSession must be recreated when DnsConfig changes.
class NET_EXPORT_PRIVATE DnsSession
: NON_EXPORTED_BASE(public base::RefCounted<DnsSession>) {
public:
typedef base::Callback<int()> RandCallback;
DnsSession(const DnsConfig& config,
ClientSocketFactory* factory,
const RandIntCallback& rand_int_callback,
NetLog* net_log);
ClientSocketFactory* socket_factory() const { return socket_factory_.get(); }
const DnsConfig& config() const { return config_; }
NetLog* net_log() const { return net_log_; }
// Return the next random query ID.
int NextId() const;
// Return the next server address.
const IPEndPoint& NextServer();
// Return the timeout for the next transaction.
base::TimeDelta NextTimeout(int attempt);
private:
friend class base::RefCounted<DnsSession>;
~DnsSession();
const DnsConfig config_;
scoped_ptr<ClientSocketFactory> socket_factory_;
RandCallback rand_callback_;
NetLog* net_log_;
// Current index into |config_.nameservers|.
int server_index_;
// TODO(szym): add current RTT estimate
// TODO(szym): add flag to indicate DNSSEC is supported
// TODO(szym): add TCP connection pool to support DNS over TCP
// TODO(szym): add UDP socket pool ?
DISALLOW_COPY_AND_ASSIGN(DnsSession);
};
} // namespace net
#endif // NET_DNS_DNS_SESSION_H_
...@@ -8,20 +8,6 @@ ...@@ -8,20 +8,6 @@
namespace net { namespace net {
TestPrng::TestPrng(const std::deque<int>& numbers) : numbers_(numbers) {
}
TestPrng::~TestPrng() {
}
int TestPrng::GetNext(int min, int max) {
DCHECK(!numbers_.empty());
int rv = numbers_.front();
numbers_.pop_front();
DCHECK(rv >= min && rv <= max);
return rv;
}
bool ConvertStringsToIPAddressList( bool ConvertStringsToIPAddressList(
const char* const ip_strings[], size_t size, IPAddressList* address_list) { const char* const ip_strings[], size_t size, IPAddressList* address_list) {
DCHECK(address_list); DCHECK(address_list);
......
...@@ -14,26 +14,10 @@ ...@@ -14,26 +14,10 @@
#include "net/base/host_resolver.h" #include "net/base/host_resolver.h"
#include "net/base/ip_endpoint.h" #include "net/base/ip_endpoint.h"
#include "net/base/net_util.h" #include "net/base/net_util.h"
#include "net/dns/dns_protocol.h"
namespace net { namespace net {
// DNS related classes make use of PRNG for various tasks. This class is
// used as a PRNG for unit testing those tasks. It takes a deque of
// integers |numbers| which should be returned by calls to GetNext.
class TestPrng {
public:
explicit TestPrng(const std::deque<int>& numbers);
~TestPrng();
// Pops and returns the next number from |numbers_| deque.
int GetNext(int min, int max);
private:
std::deque<int> numbers_;
DISALLOW_COPY_AND_ASSIGN(TestPrng);
};
// A utility function for tests that given an array of IP literals, // A utility function for tests that given an array of IP literals,
// converts it to an IPAddressList. // converts it to an IPAddressList.
bool ConvertStringsToIPAddressList( bool ConvertStringsToIPAddressList(
...@@ -49,7 +33,7 @@ static const uint16 kDnsPort = 53; ...@@ -49,7 +33,7 @@ static const uint16 kDnsPort = 53;
//----------------------------------------------------------------------------- //-----------------------------------------------------------------------------
// Query/response set for www.google.com, ID is fixed to 0. // Query/response set for www.google.com, ID is fixed to 0.
static const char kT0HostName[] = "www.google.com"; static const char kT0HostName[] = "www.google.com";
static const uint16 kT0Qtype = kDNS_A; static const uint16 kT0Qtype = dns_protocol::kTypeA;
static const char kT0DnsName[] = { static const char kT0DnsName[] = {
0x03, 'w', 'w', 'w', 0x03, 'w', 'w', 'w',
0x06, 'g', 'o', 'o', 'g', 'l', 'e', 0x06, 'g', 'o', 'o', 'g', 'l', 'e',
...@@ -92,10 +76,10 @@ static const char* const kT0IpAddresses[] = { ...@@ -92,10 +76,10 @@ static const char* const kT0IpAddresses[] = {
//----------------------------------------------------------------------------- //-----------------------------------------------------------------------------
// Query/response set for codereview.chromium.org, ID is fixed to 1. // Query/response set for codereview.chromium.org, ID is fixed to 1.
static const char kT1HostName[] = "codereview.chromium.org"; static const char kT1HostName[] = "codereview.chromium.org";
static const uint16 kT1Qtype = kDNS_A; static const uint16 kT1Qtype = dns_protocol::kTypeA;
static const char kT1DnsName[] = { static const char kT1DnsName[] = {
0x12, 'c', 'o', 'd', 'e', 'r', 'e', 'v', 'i', 'e', 'w', 0x0a, 'c', 'o', 'd', 'e', 'r', 'e', 'v', 'i', 'e', 'w',
0x10, 'c', 'h', 'r', 'o', 'm', 'i', 'u', 'm', 0x08, 'c', 'h', 'r', 'o', 'm', 'i', 'u', 'm',
0x03, 'o', 'r', 'g', 0x03, 'o', 'r', 'g',
0x00 0x00
}; };
...@@ -130,10 +114,10 @@ static const char* const kT1IpAddresses[] = { ...@@ -130,10 +114,10 @@ static const char* const kT1IpAddresses[] = {
//----------------------------------------------------------------------------- //-----------------------------------------------------------------------------
// Query/response set for www.ccs.neu.edu, ID is fixed to 2. // Query/response set for www.ccs.neu.edu, ID is fixed to 2.
static const char kT2HostName[] = "www.ccs.neu.edu"; static const char kT2HostName[] = "www.ccs.neu.edu";
static const uint16 kT2Qtype = kDNS_A; static const uint16 kT2Qtype = dns_protocol::kTypeA;
static const char kT2DnsName[] = { static const char kT2DnsName[] = {
0x03, 'w', 'w', 'w', 0x03, 'w', 'w', 'w',
0x03, 'c', 'c', 'c', 0x03, 'c', 'c', 's',
0x03, 'n', 'e', 'u', 0x03, 'n', 'e', 'u',
0x03, 'e', 'd', 'u', 0x03, 'e', 'd', 'u',
0x00 0x00
...@@ -166,7 +150,7 @@ static const char* const kT2IpAddresses[] = { ...@@ -166,7 +150,7 @@ static const char* const kT2IpAddresses[] = {
//----------------------------------------------------------------------------- //-----------------------------------------------------------------------------
// Query/response set for www.google.az, ID is fixed to 3. // Query/response set for www.google.az, ID is fixed to 3.
static const char kT3HostName[] = "www.google.az"; static const char kT3HostName[] = "www.google.az";
static const uint16 kT3Qtype = kDNS_A; static const uint16 kT3Qtype = dns_protocol::kTypeA;
static const char kT3DnsName[] = { static const char kT3DnsName[] = {
0x03, 'w', 'w', 'w', 0x03, 'w', 'w', 'w',
0x06, 'g', 'o', 'o', 'g', 'l', 'e', 0x06, 'g', 'o', 'o', 'g', 'l', 'e',
......
This diff is collapsed.
...@@ -6,12 +6,10 @@ ...@@ -6,12 +6,10 @@
#define NET_DNS_DNS_TRANSACTION_H_ #define NET_DNS_DNS_TRANSACTION_H_
#pragma once #pragma once
#include <set>
#include <string> #include <string>
#include <utility>
#include <vector> #include <vector>
#include "base/gtest_prod_util.h" #include "base/memory/ref_counted.h"
#include "base/memory/scoped_ptr.h" #include "base/memory/scoped_ptr.h"
#include "base/timer.h" #include "base/timer.h"
#include "base/threading/non_thread_safe.h" #include "base/threading/non_thread_safe.h"
...@@ -23,70 +21,39 @@ ...@@ -23,70 +21,39 @@
namespace net { namespace net {
class ClientSocketFactory;
class DatagramClientSocket; class DatagramClientSocket;
class DnsQuery; class DnsQuery;
class DnsResponse; class DnsResponse;
class DnsSession;
// Performs (with fixed retries) a single asynchronous DNS transaction, // Performs a single asynchronous DNS transaction over UDP,
// which consists of sending out a DNS query, waiting for a response, and // which consists of sending out a DNS query, waiting for a response, and
// parsing and returning the IP addresses that it matches. // returning the response that it matches.
class NET_EXPORT_PRIVATE DnsTransaction : class NET_EXPORT_PRIVATE DnsTransaction :
NON_EXPORTED_BASE(public base::NonThreadSafe) { NON_EXPORTED_BASE(public base::NonThreadSafe) {
public: public:
typedef std::pair<std::string, uint16> Key; typedef base::Callback<void(DnsTransaction*, int)> ResultCallback;
// Interface that should be implemented by DnsTransaction consumers and // Create new transaction using the parameters and state in |session|.
// passed to the |Start| method to be notified when the transaction has // Issues query for name |qname| (in DNS format) type |qtype| and class IN.
// completed. // Calls |callback| on completion or timeout.
class NET_EXPORT_PRIVATE Delegate { // TODO(szym): change dependency to (IPEndPoint, Socket, DnsQuery, callback)
public: DnsTransaction(DnsSession* session,
Delegate(); const base::StringPiece& qname,
virtual ~Delegate(); uint16 qtype,
const ResultCallback& callback,
// A consumer of DnsTransaction should override |OnTransactionComplete| const BoundNetLog& source_net_log);
// and call |set_delegate(this)|. The method will be called once the ~DnsTransaction();
// resolution has completed, results passed in as arguments.
virtual void OnTransactionComplete(
int result,
const DnsTransaction* transaction,
const IPAddressList& ip_addresses);
private:
friend class DnsTransaction;
void Attach(DnsTransaction* transaction);
void Detach(DnsTransaction* transaction);
std::set<DnsTransaction*> registered_transactions_;
DISALLOW_COPY_AND_ASSIGN(Delegate); const DnsQuery* query() const { return query_.get(); }
};
// |dns_server| is the address of the DNS server, |dns_name| is the const DnsResponse* response() const { return response_.get(); }
// hostname (in DNS format) to be resolved, |query_type| is the type of
// the query, either kDNS_A or kDNS_AAAA, |rand_int| is the PRNG used for
// generating DNS query.
DnsTransaction(const IPEndPoint& dns_server,
const std::string& dns_name,
uint16 query_type,
const RandIntCallback& rand_int,
ClientSocketFactory* socket_factory,
const BoundNetLog& source_net_log,
NetLog* net_log);
~DnsTransaction();
void SetDelegate(Delegate* delegate);
const Key& key() const { return key_; }
// Starts the resolution process. Will return ERR_IO_PENDING and will // Starts the resolution process. Will return ERR_IO_PENDING and will
// notify the caller via |delegate|. Should only be called once. // notify the caller via |delegate|. Should only be called once.
int Start(); int Start();
private: private:
FRIEND_TEST_ALL_PREFIXES(DnsTransactionTest, FirstTimeoutTest);
FRIEND_TEST_ALL_PREFIXES(DnsTransactionTest, SecondTimeoutTest);
FRIEND_TEST_ALL_PREFIXES(DnsTransactionTest, ThirdTimeoutTest);
enum State { enum State {
STATE_CONNECT, STATE_CONNECT,
STATE_CONNECT_COMPLETE, STATE_CONNECT_COMPLETE,
...@@ -114,26 +81,17 @@ class NET_EXPORT_PRIVATE DnsTransaction : ...@@ -114,26 +81,17 @@ class NET_EXPORT_PRIVATE DnsTransaction :
void RevokeTimer(); void RevokeTimer();
void OnTimeout(); void OnTimeout();
// This is to be used by unit tests only. scoped_refptr<DnsSession> session_;
void set_timeouts_ms(const std::vector<base::TimeDelta>& timeouts_ms); IPEndPoint dns_server_;
const IPEndPoint dns_server_;
Key key_;
IPAddressList ip_addresses_;
Delegate* delegate_;
scoped_ptr<DnsQuery> query_; scoped_ptr<DnsQuery> query_;
ResultCallback callback_;
scoped_ptr<DnsResponse> response_; scoped_ptr<DnsResponse> response_;
scoped_ptr<DatagramClientSocket> socket_; scoped_ptr<DatagramClientSocket> socket_;
// Number of retry attempts so far. // Number of retry attempts so far.
size_t attempts_; int attempts_;
// Timeouts in milliseconds.
std::vector<base::TimeDelta> timeouts_ms_;
State next_state_; State next_state_;
ClientSocketFactory* socket_factory_;
base::OneShotTimer<DnsTransaction> timer_; base::OneShotTimer<DnsTransaction> timer_;
OldCompletionCallbackImpl<DnsTransaction> io_callback_; OldCompletionCallbackImpl<DnsTransaction> io_callback_;
......
This diff is collapsed.
...@@ -55,6 +55,8 @@ ...@@ -55,6 +55,8 @@
'base/backoff_entry.h', 'base/backoff_entry.h',
'base/bandwidth_metrics.cc', 'base/bandwidth_metrics.cc',
'base/bandwidth_metrics.h', 'base/bandwidth_metrics.h',
'base/big_endian.cc',
'base/big_endian.h',
'base/cache_type.h', 'base/cache_type.h',
'base/capturing_net_log.cc', 'base/capturing_net_log.cc',
'base/capturing_net_log.h', 'base/capturing_net_log.h',
...@@ -323,6 +325,8 @@ ...@@ -323,6 +325,8 @@
'disk_cache/trace.h', 'disk_cache/trace.h',
'dns/async_host_resolver.cc', 'dns/async_host_resolver.cc',
'dns/async_host_resolver.h', 'dns/async_host_resolver.h',
'dns/dns_client.cc',
'dns/dns_client.h',
'dns/dns_config_service.cc', 'dns/dns_config_service.cc',
'dns/dns_config_service.h', 'dns/dns_config_service.h',
'dns/dns_config_service_posix.cc', 'dns/dns_config_service_posix.cc',
...@@ -335,6 +339,8 @@ ...@@ -335,6 +339,8 @@
'dns/dns_query.h', 'dns/dns_query.h',
'dns/dns_response.cc', 'dns/dns_response.cc',
'dns/dns_response.h', 'dns/dns_response.h',
'dns/dns_session.cc',
'dns/dns_session.h',
'dns/dns_transaction.cc', 'dns/dns_transaction.cc',
'dns/dns_transaction.h', 'dns/dns_transaction.h',
'dns/serial_worker.cc', 'dns/serial_worker.cc',
...@@ -991,6 +997,7 @@ ...@@ -991,6 +997,7 @@
'sources': [ 'sources': [
'base/address_list_unittest.cc', 'base/address_list_unittest.cc',
'base/backoff_entry_unittest.cc', 'base/backoff_entry_unittest.cc',
'base/big_endian_unittest.cc',
'base/cert_database_nss_unittest.cc', 'base/cert_database_nss_unittest.cc',
'base/cert_verifier_unittest.cc', 'base/cert_verifier_unittest.cc',
'base/cookie_monster_unittest.cc', 'base/cookie_monster_unittest.cc',
...@@ -1051,6 +1058,7 @@ ...@@ -1051,6 +1058,7 @@
'disk_cache/mapped_file_unittest.cc', 'disk_cache/mapped_file_unittest.cc',
'disk_cache/storage_block_unittest.cc', 'disk_cache/storage_block_unittest.cc',
'dns/async_host_resolver_unittest.cc', 'dns/async_host_resolver_unittest.cc',
'dns/dns_client_unittest.cc',
'dns/dns_config_service_posix_unittest.cc', 'dns/dns_config_service_posix_unittest.cc',
'dns/dns_config_service_unittest.cc', 'dns/dns_config_service_unittest.cc',
'dns/dns_config_service_win_unittest.cc', 'dns/dns_config_service_win_unittest.cc',
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment