Commit 3421db38 authored by Eric Orth's avatar Eric Orth Committed by Commit Bot

Use a cancellable timer for delayed DoH probe start

Allows DnsClient::CancelProbesForContext() to cleanly cancel probes
when still waiting on initial startup delay. Nulling the
|url_request_context_for_probes_| was not sufficient for this case
because the context was already bound in the delayed task.

For testability, moved MockDnsTransactionFactory to be visible from
dns_test_util.h instead of just being a private internal class of
MockDnsClient.  Allows use with a real DnsClient for DnsClientTests.
Bunch of code churn from moving, but no real change to
MockDnsTransactionFactory other than adding a |doh_probes_running_| to
track when the probes have been started or canceled.

Fixed: 1015555
Change-Id: I09a23bc3e90267723d76f727be3bf5390ad6ca37
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/1867443Reviewed-by: default avatarMatt Menke <mmenke@chromium.org>
Commit-Queue: Eric Orth <ericorth@chromium.org>
Cr-Commit-Position: refs/heads/master@{#707664}
parent 6792b39a
...@@ -80,9 +80,6 @@ void UpdateConfigForDohUpgrade(DnsConfig* config) { ...@@ -80,9 +80,6 @@ void UpdateConfigForDohUpgrade(DnsConfig* config) {
} }
} }
constexpr base::TimeDelta kInitialDoHTimeout =
base::TimeDelta::FromMilliseconds(5000);
class DnsClientImpl : public DnsClient, class DnsClientImpl : public DnsClient,
public NetworkChangeNotifier::ConnectionTypeObserver { public NetworkChangeNotifier::ConnectionTypeObserver {
public: public:
...@@ -96,7 +93,7 @@ class DnsClientImpl : public DnsClient, ...@@ -96,7 +93,7 @@ class DnsClientImpl : public DnsClient,
rand_int_callback_(rand_int_callback) { rand_int_callback_(rand_int_callback) {
NetworkChangeNotifier::AddConnectionTypeObserver(this); NetworkChangeNotifier::AddConnectionTypeObserver(this);
delayed_probes_allowed_timer_.Start( delayed_probes_allowed_timer_.Start(
FROM_HERE, kInitialDoHTimeout, FROM_HERE, kInitialDohTimeout,
base::Bind(&DnsClientImpl::SetProbesAllowed, base::Unretained(this))); base::Bind(&DnsClientImpl::SetProbesAllowed, base::Unretained(this)));
} }
...@@ -166,6 +163,10 @@ class DnsClientImpl : public DnsClient, ...@@ -166,6 +163,10 @@ class DnsClientImpl : public DnsClient,
void SetRequestContextForProbes( void SetRequestContextForProbes(
URLRequestContext* url_request_context) override { URLRequestContext* url_request_context) override {
DCHECK(url_request_context);
DCHECK(!url_request_context_for_probes_ ||
url_request_context == url_request_context_for_probes_);
url_request_context_for_probes_ = url_request_context; url_request_context_for_probes_ = url_request_context;
} }
...@@ -174,6 +175,7 @@ class DnsClientImpl : public DnsClient, ...@@ -174,6 +175,7 @@ class DnsClientImpl : public DnsClient,
return; return;
factory_->CancelDohProbes(); factory_->CancelDohProbes();
delayed_probes_start_timer_.Stop();
url_request_context_for_probes_ = nullptr; url_request_context_for_probes_ = nullptr;
} }
...@@ -203,6 +205,15 @@ class DnsClientImpl : public DnsClient, ...@@ -203,6 +205,15 @@ class DnsClientImpl : public DnsClient,
session_->SetProbeSuccess(index, success); session_->SetProbeSuccess(index, success);
} }
void SetTransactionFactoryForTesting(
std::unique_ptr<DnsTransactionFactory> factory) override {
factory_ = std::move(factory);
}
void StartDohProbesForTesting() override {
StartDohProbes(false /* network_change */);
}
private: private:
base::Optional<DnsConfig> BuildEffectiveConfig() const { base::Optional<DnsConfig> BuildEffectiveConfig() const {
DnsConfig config; DnsConfig config;
...@@ -279,14 +290,14 @@ class DnsClientImpl : public DnsClient, ...@@ -279,14 +290,14 @@ class DnsClientImpl : public DnsClient,
return; return;
if (probes_allowed_) { if (probes_allowed_) {
delayed_probes_start_timer_.Stop();
factory_->StartDohProbes(url_request_context_for_probes_, network_change); factory_->StartDohProbes(url_request_context_for_probes_, network_change);
} else { } else {
base::SequencedTaskRunnerHandle::Get()->PostDelayedTask( delayed_probes_start_timer_.Start(
FROM_HERE, FROM_HERE, delayed_probes_allowed_timer_.GetCurrentDelay(),
base::BindOnce(&DnsTransactionFactory::StartDohProbes, base::BindOnce(&DnsTransactionFactory::StartDohProbes,
factory_->weak_factory_.GetWeakPtr(), factory_->weak_factory_.GetWeakPtr(),
url_request_context_for_probes_, network_change), url_request_context_for_probes_, network_change));
delayed_probes_allowed_timer_.GetCurrentDelay());
} }
} }
...@@ -306,6 +317,7 @@ class DnsClientImpl : public DnsClient, ...@@ -306,6 +317,7 @@ class DnsClientImpl : public DnsClient,
// prevent interference with startup tasks. // prevent interference with startup tasks.
bool probes_allowed_; bool probes_allowed_;
base::OneShotTimer delayed_probes_allowed_timer_; base::OneShotTimer delayed_probes_allowed_timer_;
base::OneShotTimer delayed_probes_start_timer_;
URLRequestContext* url_request_context_for_probes_; URLRequestContext* url_request_context_for_probes_;
NetLog* net_log_; NetLog* net_log_;
...@@ -318,6 +330,10 @@ class DnsClientImpl : public DnsClient, ...@@ -318,6 +330,10 @@ class DnsClientImpl : public DnsClient,
} // namespace } // namespace
// static
const base::TimeDelta DnsClient::kInitialDohTimeout =
base::TimeDelta::FromSeconds(5);
// static // static
std::unique_ptr<DnsClient> DnsClient::CreateClient(NetLog* net_log) { std::unique_ptr<DnsClient> DnsClient::CreateClient(NetLog* net_log) {
return std::make_unique<DnsClientImpl>( return std::make_unique<DnsClientImpl>(
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <memory> #include <memory>
#include "base/optional.h" #include "base/optional.h"
#include "base/time/time.h"
#include "net/base/net_export.h" #include "net/base/net_export.h"
#include "net/base/rand_callback.h" #include "net/base/rand_callback.h"
#include "net/dns/dns_config.h" #include "net/dns/dns_config.h"
...@@ -28,6 +29,7 @@ class NetLog; ...@@ -28,6 +29,7 @@ class NetLog;
class NET_EXPORT DnsClient { class NET_EXPORT DnsClient {
public: public:
static const int kMaxInsecureFallbackFailures = 16; static const int kMaxInsecureFallbackFailures = 16;
static const base::TimeDelta kInitialDohTimeout;
virtual ~DnsClient() {} virtual ~DnsClient() {}
...@@ -83,6 +85,10 @@ class NET_EXPORT DnsClient { ...@@ -83,6 +85,10 @@ class NET_EXPORT DnsClient {
virtual void SetProbeSuccessForTest(unsigned index, bool success) = 0; virtual void SetProbeSuccessForTest(unsigned index, bool success) = 0;
virtual void SetTransactionFactoryForTesting(
std::unique_ptr<DnsTransactionFactory> factory) = 0;
virtual void StartDohProbesForTesting() = 0;
// Creates default client. // Creates default client.
static std::unique_ptr<DnsClient> CreateClient(NetLog* net_log); static std::unique_ptr<DnsClient> CreateClient(NetLog* net_log);
......
...@@ -8,9 +8,11 @@ ...@@ -8,9 +8,11 @@
#include "base/bind.h" #include "base/bind.h"
#include "base/rand_util.h" #include "base/rand_util.h"
#include "base/test/task_environment.h"
#include "net/base/ip_address.h" #include "net/base/ip_address.h"
#include "net/base/ip_endpoint.h" #include "net/base/ip_endpoint.h"
#include "net/dns/dns_config.h" #include "net/dns/dns_config.h"
#include "net/dns/dns_test_util.h"
#include "net/socket/socket_test_util.h" #include "net/socket/socket_test_util.h"
#include "net/test/test_with_task_environment.h" #include "net/test/test_with_task_environment.h"
#include "testing/gmock/include/gmock/gmock.h" #include "testing/gmock/include/gmock/gmock.h"
...@@ -18,6 +20,8 @@ ...@@ -18,6 +20,8 @@
namespace net { namespace net {
class ClientSocketFactory;
namespace { namespace {
class AlwaysFailSocketFactory : public MockClientSocketFactory { class AlwaysFailSocketFactory : public MockClientSocketFactory {
...@@ -32,6 +36,10 @@ class AlwaysFailSocketFactory : public MockClientSocketFactory { ...@@ -32,6 +36,10 @@ class AlwaysFailSocketFactory : public MockClientSocketFactory {
class DnsClientTest : public TestWithTaskEnvironment { class DnsClientTest : public TestWithTaskEnvironment {
protected: protected:
DnsClientTest()
: TestWithTaskEnvironment(
base::test::TaskEnvironment::TimeSource::MOCK_TIME) {}
void SetUp() override { void SetUp() override {
client_ = DnsClient::CreateClientForTesting( client_ = DnsClient::CreateClientForTesting(
nullptr /* net_log */, &socket_factory_, base::Bind(&base::RandInt)); nullptr /* net_log */, &socket_factory_, base::Bind(&base::RandInt));
...@@ -58,8 +66,6 @@ class DnsClientTest : public TestWithTaskEnvironment { ...@@ -58,8 +66,6 @@ class DnsClientTest : public TestWithTaskEnvironment {
std::unique_ptr<DnsClient> client_; std::unique_ptr<DnsClient> client_;
AlwaysFailSocketFactory socket_factory_; AlwaysFailSocketFactory socket_factory_;
private:
}; };
TEST_F(DnsClientTest, NoConfig) { TEST_F(DnsClientTest, NoConfig) {
...@@ -240,6 +246,58 @@ TEST_F(DnsClientTest, OverrideToInvalid) { ...@@ -240,6 +246,58 @@ TEST_F(DnsClientTest, OverrideToInvalid) {
EXPECT_FALSE(client_->GetEffectiveConfig()); EXPECT_FALSE(client_->GetEffectiveConfig());
} }
TEST_F(DnsClientTest, DohProbes) {
URLRequestContext context;
client_->SetRequestContextForProbes(&context);
client_->SetSystemConfig(ValidConfigWithDoh());
auto transaction_factory =
std::make_unique<MockDnsTransactionFactory>(MockDnsClientRuleList());
auto* transaction_factory_ptr = transaction_factory.get();
client_->SetTransactionFactoryForTesting(std::move(transaction_factory));
client_->StartDohProbesForTesting();
EXPECT_FALSE(transaction_factory_ptr->doh_probes_running());
FastForwardBy(DnsClient::kInitialDohTimeout);
EXPECT_TRUE(transaction_factory_ptr->doh_probes_running());
}
TEST_F(DnsClientTest, CancelDohProbesBeforeEnabled) {
URLRequestContext context;
client_->SetRequestContextForProbes(&context);
client_->SetSystemConfig(ValidConfigWithDoh());
auto transaction_factory =
std::make_unique<MockDnsTransactionFactory>(MockDnsClientRuleList());
auto* transaction_factory_ptr = transaction_factory.get();
client_->SetTransactionFactoryForTesting(std::move(transaction_factory));
client_->StartDohProbesForTesting();
EXPECT_FALSE(transaction_factory_ptr->doh_probes_running());
client_->CancelProbesForContext(&context);
FastForwardUntilNoTasksRemain();
EXPECT_FALSE(transaction_factory_ptr->doh_probes_running());
}
TEST_F(DnsClientTest, CancelDohProbesAfterEnabled) {
URLRequestContext context;
client_->SetRequestContextForProbes(&context);
client_->SetSystemConfig(ValidConfigWithDoh());
auto transaction_factory =
std::make_unique<MockDnsTransactionFactory>(MockDnsClientRuleList());
auto* transaction_factory_ptr = transaction_factory.get();
client_->SetTransactionFactoryForTesting(std::move(transaction_factory));
client_->StartDohProbesForTesting();
FastForwardUntilNoTasksRemain();
EXPECT_TRUE(transaction_factory_ptr->doh_probes_running());
client_->CancelProbesForContext(&context);
EXPECT_FALSE(transaction_factory_ptr->doh_probes_running());
}
} // namespace } // namespace
} // namespace net } // namespace net
...@@ -7,7 +7,6 @@ ...@@ -7,7 +7,6 @@
#include "base/big_endian.h" #include "base/big_endian.h"
#include "base/bind.h" #include "base/bind.h"
#include "base/location.h" #include "base/location.h"
#include "base/memory/weak_ptr.h"
#include "base/numerics/safe_conversions.h" #include "base/numerics/safe_conversions.h"
#include "base/single_thread_task_runner.h" #include "base/single_thread_task_runner.h"
#include "base/sys_byteorder.h" #include "base/sys_byteorder.h"
...@@ -18,7 +17,6 @@ ...@@ -18,7 +17,6 @@
#include "net/dns/address_sorter.h" #include "net/dns/address_sorter.h"
#include "net/dns/dns_hosts.h" #include "net/dns/dns_hosts.h"
#include "net/dns/dns_query.h" #include "net/dns/dns_query.h"
#include "net/dns/dns_transaction.h"
#include "net/dns/dns_util.h" #include "net/dns/dns_util.h"
#include "testing/gtest/include/gtest/gtest.h" #include "testing/gtest/include/gtest/gtest.h"
...@@ -180,142 +178,6 @@ DnsResourceRecord BuildServiceRecord(std::string name, ...@@ -180,142 +178,6 @@ DnsResourceRecord BuildServiceRecord(std::string name,
return record; return record;
} }
// A DnsTransaction which uses MockDnsClientRuleList to determine the response.
class MockTransaction : public DnsTransaction,
public base::SupportsWeakPtr<MockTransaction> {
public:
MockTransaction(const MockDnsClientRuleList& rules,
const std::string& hostname,
uint16_t qtype,
bool secure,
DnsConfig::SecureDnsMode secure_dns_mode,
URLRequestContext* url_request_context,
DnsTransactionFactory::CallbackType callback)
: result_(MockDnsClientRule::FAIL),
hostname_(hostname),
qtype_(qtype),
callback_(std::move(callback)),
started_(false),
delayed_(false) {
// Find the relevant rule which matches |qtype|, |secure|, prefix of
// |hostname|, and |url_request_context| (iff the rule context is not
// null).
for (size_t i = 0; i < rules.size(); ++i) {
const std::string& prefix = rules[i].prefix;
if ((rules[i].qtype == qtype) && (rules[i].secure == secure) &&
(hostname.size() >= prefix.size()) &&
(hostname.compare(0, prefix.size(), prefix) == 0) &&
(!rules[i].context || rules[i].context == url_request_context)) {
const MockDnsClientRule::Result* result = &rules[i].result;
result_ = MockDnsClientRule::Result(result->type);
delayed_ = rules[i].delay;
// Generate a DnsResponse when not provided with the rule.
std::vector<DnsResourceRecord> authority_records;
std::string dns_name;
CHECK(DNSDomainFromDot(hostname_, &dns_name));
base::Optional<DnsQuery> query(base::in_place, 22 /* id */, dns_name,
qtype_);
switch (result->type) {
case MockDnsClientRule::NODOMAIN:
case MockDnsClientRule::EMPTY:
DCHECK(!result->response); // Not expected to be provided.
authority_records = {BuildSoaRecord(hostname_)};
result_.response = std::make_unique<DnsResponse>(
22 /* id */, false /* is_authoritative */,
std::vector<DnsResourceRecord>() /* answers */,
authority_records,
std::vector<DnsResourceRecord>() /* additional_records */,
query,
result->type == MockDnsClientRule::NODOMAIN
? dns_protocol::kRcodeNXDOMAIN
: 0);
break;
case MockDnsClientRule::FAIL:
case MockDnsClientRule::TIMEOUT:
DCHECK(!result->response); // Not expected to be provided.
break;
case MockDnsClientRule::OK:
if (result->response) {
// Copy response in case |rules| are destroyed before the
// transaction completes.
result_.response = std::make_unique<DnsResponse>(
result->response->io_buffer(),
result->response->io_buffer_size());
CHECK(result_.response->InitParseWithoutQuery(
result->response->io_buffer_size()));
} else {
// Generated response only available for address types.
DCHECK(qtype_ == dns_protocol::kTypeA ||
qtype_ == dns_protocol::kTypeAAAA);
result_.response = BuildTestDnsResponse(
hostname_, qtype_ == dns_protocol::kTypeA
? IPAddress::IPv4Localhost()
: IPAddress::IPv6Localhost());
}
break;
case MockDnsClientRule::MALFORMED:
DCHECK(!result->response); // Not expected to be provided.
result_.response = CreateMalformedResponse(hostname_, qtype_);
break;
}
break;
}
}
}
const std::string& GetHostname() const override { return hostname_; }
uint16_t GetType() const override { return qtype_; }
void Start() override {
EXPECT_FALSE(started_);
started_ = true;
if (delayed_)
return;
// Using WeakPtr to cleanly cancel when transaction is destroyed.
base::ThreadTaskRunnerHandle::Get()->PostTask(
FROM_HERE, base::BindOnce(&MockTransaction::Finish, AsWeakPtr()));
}
void FinishDelayedTransaction() {
EXPECT_TRUE(delayed_);
delayed_ = false;
Finish();
}
bool delayed() const { return delayed_; }
private:
void Finish() {
switch (result_.type) {
case MockDnsClientRule::NODOMAIN:
case MockDnsClientRule::FAIL:
std::move(callback_).Run(this, ERR_NAME_NOT_RESOLVED,
result_.response.get());
break;
case MockDnsClientRule::EMPTY:
case MockDnsClientRule::OK:
case MockDnsClientRule::MALFORMED:
std::move(callback_).Run(this, OK, result_.response.get());
break;
case MockDnsClientRule::TIMEOUT:
std::move(callback_).Run(this, ERR_DNS_TIMED_OUT, nullptr);
break;
}
}
void SetRequestPriority(RequestPriority priority) override {}
MockDnsClientRule::Result result_;
const std::string hostname_;
const uint16_t qtype_;
DnsTransactionFactory::CallbackType callback_;
bool started_;
bool delayed_;
};
} // namespace } // namespace
std::unique_ptr<DnsResponse> BuildTestDnsResponse(std::string name, std::unique_ptr<DnsResponse> BuildTestDnsResponse(std::string name,
...@@ -452,68 +314,201 @@ MockDnsClientRule::MockDnsClientRule(const std::string& prefix, ...@@ -452,68 +314,201 @@ MockDnsClientRule::MockDnsClientRule(const std::string& prefix,
MockDnsClientRule::MockDnsClientRule(MockDnsClientRule&& rule) = default; MockDnsClientRule::MockDnsClientRule(MockDnsClientRule&& rule) = default;
// A DnsTransactionFactory which creates MockTransaction. // A DnsTransaction which uses MockDnsClientRuleList to determine the response.
class MockDnsClient::MockTransactionFactory : public DnsTransactionFactory { class MockDnsTransactionFactory::MockTransaction
: public DnsTransaction,
public base::SupportsWeakPtr<MockTransaction> {
public: public:
explicit MockTransactionFactory(MockDnsClientRuleList rules) MockTransaction(const MockDnsClientRuleList& rules,
: rules_(std::move(rules)) {} const std::string& hostname,
uint16_t qtype,
~MockTransactionFactory() override = default; bool secure,
DnsConfig::SecureDnsMode secure_dns_mode,
std::unique_ptr<DnsTransaction> CreateTransaction( URLRequestContext* url_request_context,
const std::string& hostname, DnsTransactionFactory::CallbackType callback)
uint16_t qtype, : result_(MockDnsClientRule::FAIL),
DnsTransactionFactory::CallbackType callback, hostname_(hostname),
const NetLogWithSource&, qtype_(qtype),
bool secure, callback_(std::move(callback)),
DnsConfig::SecureDnsMode secure_dns_mode, started_(false),
URLRequestContext* url_request_context) override { delayed_(false) {
std::unique_ptr<MockTransaction> transaction = // Find the relevant rule which matches |qtype|, |secure|, prefix of
std::make_unique<MockTransaction>(rules_, hostname, qtype, secure, // |hostname|, and |url_request_context| (iff the rule context is not
secure_dns_mode, url_request_context, // null).
std::move(callback)); for (size_t i = 0; i < rules.size(); ++i) {
if (transaction->delayed()) const std::string& prefix = rules[i].prefix;
delayed_transactions_.push_back(transaction->AsWeakPtr()); if ((rules[i].qtype == qtype) && (rules[i].secure == secure) &&
return transaction; (hostname.size() >= prefix.size()) &&
} (hostname.compare(0, prefix.size(), prefix) == 0) &&
(!rules[i].context || rules[i].context == url_request_context)) {
const MockDnsClientRule::Result* result = &rules[i].result;
result_ = MockDnsClientRule::Result(result->type);
delayed_ = rules[i].delay;
void AddEDNSOption(const OptRecordRdata::Opt& opt) override {} // Generate a DnsResponse when not provided with the rule.
std::vector<DnsResourceRecord> authority_records;
std::string dns_name;
CHECK(DNSDomainFromDot(hostname_, &dns_name));
base::Optional<DnsQuery> query(base::in_place, 22 /* id */, dns_name,
qtype_);
switch (result->type) {
case MockDnsClientRule::NODOMAIN:
case MockDnsClientRule::EMPTY:
DCHECK(!result->response); // Not expected to be provided.
authority_records = {BuildSoaRecord(hostname_)};
result_.response = std::make_unique<DnsResponse>(
22 /* id */, false /* is_authoritative */,
std::vector<DnsResourceRecord>() /* answers */,
authority_records,
std::vector<DnsResourceRecord>() /* additional_records */,
query,
result->type == MockDnsClientRule::NODOMAIN
? dns_protocol::kRcodeNXDOMAIN
: 0);
break;
case MockDnsClientRule::FAIL:
case MockDnsClientRule::TIMEOUT:
DCHECK(!result->response); // Not expected to be provided.
break;
case MockDnsClientRule::OK:
if (result->response) {
// Copy response in case |rules| are destroyed before the
// transaction completes.
result_.response = std::make_unique<DnsResponse>(
result->response->io_buffer(),
result->response->io_buffer_size());
CHECK(result_.response->InitParseWithoutQuery(
result->response->io_buffer_size()));
} else {
// Generated response only available for address types.
DCHECK(qtype_ == dns_protocol::kTypeA ||
qtype_ == dns_protocol::kTypeAAAA);
result_.response = BuildTestDnsResponse(
hostname_, qtype_ == dns_protocol::kTypeA
? IPAddress::IPv4Localhost()
: IPAddress::IPv6Localhost());
}
break;
case MockDnsClientRule::MALFORMED:
DCHECK(!result->response); // Not expected to be provided.
result_.response = CreateMalformedResponse(hostname_, qtype_);
break;
}
base::TimeDelta GetDelayUntilNextProbeForTest( break;
unsigned doh_server_index) override { }
NOTREACHED(); }
return base::TimeDelta();
} }
void StartDohProbes(URLRequestContext* url_request_context, const std::string& GetHostname() const override { return hostname_; }
bool network_change) override {}
void CancelDohProbes() override {} uint16_t GetType() const override { return qtype_; }
DnsConfig::SecureDnsMode GetSecureDnsModeForTest() override { void Start() override {
return DnsConfig::SecureDnsMode::AUTOMATIC; EXPECT_FALSE(started_);
started_ = true;
if (delayed_)
return;
// Using WeakPtr to cleanly cancel when transaction is destroyed.
base::ThreadTaskRunnerHandle::Get()->PostTask(
FROM_HERE, base::BindOnce(&MockTransaction::Finish, AsWeakPtr()));
} }
void CompleteDelayedTransactions() { void FinishDelayedTransaction() {
DelayedTransactionList old_delayed_transactions; EXPECT_TRUE(delayed_);
old_delayed_transactions.swap(delayed_transactions_); delayed_ = false;
for (auto it = old_delayed_transactions.begin(); Finish();
it != old_delayed_transactions.end(); ++it) {
if (it->get())
(*it)->FinishDelayedTransaction();
}
} }
bool delayed() const { return delayed_; }
private: private:
typedef std::vector<base::WeakPtr<MockTransaction>> DelayedTransactionList; void Finish() {
switch (result_.type) {
case MockDnsClientRule::NODOMAIN:
case MockDnsClientRule::FAIL:
std::move(callback_).Run(this, ERR_NAME_NOT_RESOLVED,
result_.response.get());
break;
case MockDnsClientRule::EMPTY:
case MockDnsClientRule::OK:
case MockDnsClientRule::MALFORMED:
std::move(callback_).Run(this, OK, result_.response.get());
break;
case MockDnsClientRule::TIMEOUT:
std::move(callback_).Run(this, ERR_DNS_TIMED_OUT, nullptr);
break;
}
}
MockDnsClientRuleList rules_; void SetRequestPriority(RequestPriority priority) override {}
DelayedTransactionList delayed_transactions_;
MockDnsClientRule::Result result_;
const std::string hostname_;
const uint16_t qtype_;
DnsTransactionFactory::CallbackType callback_;
bool started_;
bool delayed_;
}; };
MockDnsTransactionFactory::MockDnsTransactionFactory(
MockDnsClientRuleList rules)
: rules_(std::move(rules)) {}
MockDnsTransactionFactory::~MockDnsTransactionFactory() = default;
std::unique_ptr<DnsTransaction> MockDnsTransactionFactory::CreateTransaction(
const std::string& hostname,
uint16_t qtype,
DnsTransactionFactory::CallbackType callback,
const NetLogWithSource&,
bool secure,
DnsConfig::SecureDnsMode secure_dns_mode,
URLRequestContext* url_request_context) {
std::unique_ptr<MockTransaction> transaction =
std::make_unique<MockTransaction>(rules_, hostname, qtype, secure,
secure_dns_mode, url_request_context,
std::move(callback));
if (transaction->delayed())
delayed_transactions_.push_back(transaction->AsWeakPtr());
return transaction;
}
void MockDnsTransactionFactory::AddEDNSOption(const OptRecordRdata::Opt& opt) {}
base::TimeDelta MockDnsTransactionFactory::GetDelayUntilNextProbeForTest(
unsigned doh_server_index) {
NOTREACHED();
return base::TimeDelta();
}
void MockDnsTransactionFactory::StartDohProbes(
URLRequestContext* url_request_context,
bool network_change) {
doh_probes_running_ = true;
}
void MockDnsTransactionFactory::CancelDohProbes() {
doh_probes_running_ = false;
}
DnsConfig::SecureDnsMode MockDnsTransactionFactory::GetSecureDnsModeForTest() {
return DnsConfig::SecureDnsMode::AUTOMATIC;
}
void MockDnsTransactionFactory::CompleteDelayedTransactions() {
DelayedTransactionList old_delayed_transactions;
old_delayed_transactions.swap(delayed_transactions_);
for (auto it = old_delayed_transactions.begin();
it != old_delayed_transactions.end(); ++it) {
if (it->get())
(*it)->FinishDelayedTransaction();
}
}
MockDnsClient::MockDnsClient(DnsConfig config, MockDnsClientRuleList rules) MockDnsClient::MockDnsClient(DnsConfig config, MockDnsClientRuleList rules)
: config_(std::move(config)), : config_(std::move(config)),
factory_(new MockTransactionFactory(std::move(rules))), factory_(new MockDnsTransactionFactory(std::move(rules))),
address_sorter_(new MockAddressSorter()) { address_sorter_(new MockAddressSorter()) {
effective_config_ = BuildEffectiveConfig(); effective_config_ = BuildEffectiveConfig();
} }
...@@ -605,6 +600,16 @@ DnsConfigOverrides MockDnsClient::GetConfigOverridesForTesting() const { ...@@ -605,6 +600,16 @@ DnsConfigOverrides MockDnsClient::GetConfigOverridesForTesting() const {
void MockDnsClient::SetProbeSuccessForTest(unsigned index, bool success) {} void MockDnsClient::SetProbeSuccessForTest(unsigned index, bool success) {}
void MockDnsClient::SetTransactionFactoryForTesting(
std::unique_ptr<DnsTransactionFactory> factory) {
NOTREACHED();
}
void MockDnsClient::StartDohProbesForTesting() {
factory_->StartDohProbes(nullptr /* url_request_context */,
false /* network_change */);
}
void MockDnsClient::CompleteDelayedTransactions() { void MockDnsClient::CompleteDelayedTransactions() {
factory_->CompleteDelayedTransactions(); factory_->CompleteDelayedTransactions();
} }
......
...@@ -13,10 +13,13 @@ ...@@ -13,10 +13,13 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "base/memory/weak_ptr.h"
#include "base/stl_util.h" #include "base/stl_util.h"
#include "base/time/time.h"
#include "net/dns/dns_client.h" #include "net/dns/dns_client.h"
#include "net/dns/dns_config.h" #include "net/dns/dns_config.h"
#include "net/dns/dns_response.h" #include "net/dns/dns_response.h"
#include "net/dns/dns_transaction.h"
#include "net/dns/dns_util.h" #include "net/dns/dns_util.h"
#include "net/dns/public/dns_protocol.h" #include "net/dns/public/dns_protocol.h"
...@@ -259,7 +262,47 @@ struct MockDnsClientRule { ...@@ -259,7 +262,47 @@ struct MockDnsClientRule {
typedef std::vector<MockDnsClientRule> MockDnsClientRuleList; typedef std::vector<MockDnsClientRule> MockDnsClientRuleList;
// MockDnsClient provides MockTransactionFactory. // A DnsTransactionFactory which creates MockTransaction.
class MockDnsTransactionFactory : public DnsTransactionFactory {
public:
explicit MockDnsTransactionFactory(MockDnsClientRuleList rules);
~MockDnsTransactionFactory() override;
std::unique_ptr<DnsTransaction> CreateTransaction(
const std::string& hostname,
uint16_t qtype,
DnsTransactionFactory::CallbackType callback,
const NetLogWithSource&,
bool secure,
DnsConfig::SecureDnsMode secure_dns_mode,
URLRequestContext* url_request_context) override;
void AddEDNSOption(const OptRecordRdata::Opt& opt) override;
base::TimeDelta GetDelayUntilNextProbeForTest(
unsigned doh_server_index) override;
void StartDohProbes(URLRequestContext* url_request_context,
bool network_change) override;
void CancelDohProbes() override;
DnsConfig::SecureDnsMode GetSecureDnsModeForTest() override;
void CompleteDelayedTransactions();
bool doh_probes_running() { return doh_probes_running_; }
private:
class MockTransaction;
using DelayedTransactionList = std::vector<base::WeakPtr<MockTransaction>>;
MockDnsClientRuleList rules_;
DelayedTransactionList delayed_transactions_;
bool doh_probes_running_ = false;
};
// MockDnsClient provides MockDnsTransactionFactory.
class MockDnsClient : public DnsClient { class MockDnsClient : public DnsClient {
public: public:
MockDnsClient(DnsConfig config, MockDnsClientRuleList rules); MockDnsClient(DnsConfig config, MockDnsClientRuleList rules);
...@@ -285,6 +328,9 @@ class MockDnsClient : public DnsClient { ...@@ -285,6 +328,9 @@ class MockDnsClient : public DnsClient {
base::Optional<DnsConfig> GetSystemConfigForTesting() const override; base::Optional<DnsConfig> GetSystemConfigForTesting() const override;
DnsConfigOverrides GetConfigOverridesForTesting() const override; DnsConfigOverrides GetConfigOverridesForTesting() const override;
void SetProbeSuccessForTest(unsigned index, bool success) override; void SetProbeSuccessForTest(unsigned index, bool success) override;
void SetTransactionFactoryForTesting(
std::unique_ptr<DnsTransactionFactory> factory) override;
void StartDohProbesForTesting() override;
// Completes all DnsTransactions that were delayed by a rule. // Completes all DnsTransactions that were delayed by a rule.
void CompleteDelayedTransactions(); void CompleteDelayedTransactions();
...@@ -302,8 +348,6 @@ class MockDnsClient : public DnsClient { ...@@ -302,8 +348,6 @@ class MockDnsClient : public DnsClient {
} }
private: private:
class MockTransactionFactory;
base::Optional<DnsConfig> BuildEffectiveConfig(); base::Optional<DnsConfig> BuildEffectiveConfig();
bool insecure_enabled_ = false; bool insecure_enabled_ = false;
...@@ -315,7 +359,7 @@ class MockDnsClient : public DnsClient { ...@@ -315,7 +359,7 @@ class MockDnsClient : public DnsClient {
base::Optional<DnsConfig> config_; base::Optional<DnsConfig> config_;
DnsConfigOverrides overrides_; DnsConfigOverrides overrides_;
base::Optional<DnsConfig> effective_config_; base::Optional<DnsConfig> effective_config_;
std::unique_ptr<MockTransactionFactory> factory_; std::unique_ptr<MockDnsTransactionFactory> factory_;
std::unique_ptr<AddressSorter> address_sorter_; std::unique_ptr<AddressSorter> address_sorter_;
}; };
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment