Commit 3421db38 authored by Eric Orth's avatar Eric Orth Committed by Commit Bot

Use a cancellable timer for delayed DoH probe start

Allows DnsClient::CancelProbesForContext() to cleanly cancel probes
when still waiting on initial startup delay. Nulling the
|url_request_context_for_probes_| was not sufficient for this case
because the context was already bound in the delayed task.

For testability, moved MockDnsTransactionFactory to be visible from
dns_test_util.h instead of just being a private internal class of
MockDnsClient.  Allows use with a real DnsClient for DnsClientTests.
Bunch of code churn from moving, but no real change to
MockDnsTransactionFactory other than adding a |doh_probes_running_| to
track when the probes have been started or canceled.

Fixed: 1015555
Change-Id: I09a23bc3e90267723d76f727be3bf5390ad6ca37
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/1867443Reviewed-by: default avatarMatt Menke <mmenke@chromium.org>
Commit-Queue: Eric Orth <ericorth@chromium.org>
Cr-Commit-Position: refs/heads/master@{#707664}
parent 6792b39a
......@@ -80,9 +80,6 @@ void UpdateConfigForDohUpgrade(DnsConfig* config) {
}
}
constexpr base::TimeDelta kInitialDoHTimeout =
base::TimeDelta::FromMilliseconds(5000);
class DnsClientImpl : public DnsClient,
public NetworkChangeNotifier::ConnectionTypeObserver {
public:
......@@ -96,7 +93,7 @@ class DnsClientImpl : public DnsClient,
rand_int_callback_(rand_int_callback) {
NetworkChangeNotifier::AddConnectionTypeObserver(this);
delayed_probes_allowed_timer_.Start(
FROM_HERE, kInitialDoHTimeout,
FROM_HERE, kInitialDohTimeout,
base::Bind(&DnsClientImpl::SetProbesAllowed, base::Unretained(this)));
}
......@@ -166,6 +163,10 @@ class DnsClientImpl : public DnsClient,
void SetRequestContextForProbes(
URLRequestContext* url_request_context) override {
DCHECK(url_request_context);
DCHECK(!url_request_context_for_probes_ ||
url_request_context == url_request_context_for_probes_);
url_request_context_for_probes_ = url_request_context;
}
......@@ -174,6 +175,7 @@ class DnsClientImpl : public DnsClient,
return;
factory_->CancelDohProbes();
delayed_probes_start_timer_.Stop();
url_request_context_for_probes_ = nullptr;
}
......@@ -203,6 +205,15 @@ class DnsClientImpl : public DnsClient,
session_->SetProbeSuccess(index, success);
}
void SetTransactionFactoryForTesting(
std::unique_ptr<DnsTransactionFactory> factory) override {
factory_ = std::move(factory);
}
void StartDohProbesForTesting() override {
StartDohProbes(false /* network_change */);
}
private:
base::Optional<DnsConfig> BuildEffectiveConfig() const {
DnsConfig config;
......@@ -279,14 +290,14 @@ class DnsClientImpl : public DnsClient,
return;
if (probes_allowed_) {
delayed_probes_start_timer_.Stop();
factory_->StartDohProbes(url_request_context_for_probes_, network_change);
} else {
base::SequencedTaskRunnerHandle::Get()->PostDelayedTask(
FROM_HERE,
delayed_probes_start_timer_.Start(
FROM_HERE, delayed_probes_allowed_timer_.GetCurrentDelay(),
base::BindOnce(&DnsTransactionFactory::StartDohProbes,
factory_->weak_factory_.GetWeakPtr(),
url_request_context_for_probes_, network_change),
delayed_probes_allowed_timer_.GetCurrentDelay());
url_request_context_for_probes_, network_change));
}
}
......@@ -306,6 +317,7 @@ class DnsClientImpl : public DnsClient,
// prevent interference with startup tasks.
bool probes_allowed_;
base::OneShotTimer delayed_probes_allowed_timer_;
base::OneShotTimer delayed_probes_start_timer_;
URLRequestContext* url_request_context_for_probes_;
NetLog* net_log_;
......@@ -318,6 +330,10 @@ class DnsClientImpl : public DnsClient,
} // namespace
// static
const base::TimeDelta DnsClient::kInitialDohTimeout =
base::TimeDelta::FromSeconds(5);
// static
std::unique_ptr<DnsClient> DnsClient::CreateClient(NetLog* net_log) {
return std::make_unique<DnsClientImpl>(
......
......@@ -8,6 +8,7 @@
#include <memory>
#include "base/optional.h"
#include "base/time/time.h"
#include "net/base/net_export.h"
#include "net/base/rand_callback.h"
#include "net/dns/dns_config.h"
......@@ -28,6 +29,7 @@ class NetLog;
class NET_EXPORT DnsClient {
public:
static const int kMaxInsecureFallbackFailures = 16;
static const base::TimeDelta kInitialDohTimeout;
virtual ~DnsClient() {}
......@@ -83,6 +85,10 @@ class NET_EXPORT DnsClient {
virtual void SetProbeSuccessForTest(unsigned index, bool success) = 0;
virtual void SetTransactionFactoryForTesting(
std::unique_ptr<DnsTransactionFactory> factory) = 0;
virtual void StartDohProbesForTesting() = 0;
// Creates default client.
static std::unique_ptr<DnsClient> CreateClient(NetLog* net_log);
......
......@@ -8,9 +8,11 @@
#include "base/bind.h"
#include "base/rand_util.h"
#include "base/test/task_environment.h"
#include "net/base/ip_address.h"
#include "net/base/ip_endpoint.h"
#include "net/dns/dns_config.h"
#include "net/dns/dns_test_util.h"
#include "net/socket/socket_test_util.h"
#include "net/test/test_with_task_environment.h"
#include "testing/gmock/include/gmock/gmock.h"
......@@ -18,6 +20,8 @@
namespace net {
class ClientSocketFactory;
namespace {
class AlwaysFailSocketFactory : public MockClientSocketFactory {
......@@ -32,6 +36,10 @@ class AlwaysFailSocketFactory : public MockClientSocketFactory {
class DnsClientTest : public TestWithTaskEnvironment {
protected:
DnsClientTest()
: TestWithTaskEnvironment(
base::test::TaskEnvironment::TimeSource::MOCK_TIME) {}
void SetUp() override {
client_ = DnsClient::CreateClientForTesting(
nullptr /* net_log */, &socket_factory_, base::Bind(&base::RandInt));
......@@ -58,8 +66,6 @@ class DnsClientTest : public TestWithTaskEnvironment {
std::unique_ptr<DnsClient> client_;
AlwaysFailSocketFactory socket_factory_;
private:
};
TEST_F(DnsClientTest, NoConfig) {
......@@ -240,6 +246,58 @@ TEST_F(DnsClientTest, OverrideToInvalid) {
EXPECT_FALSE(client_->GetEffectiveConfig());
}
TEST_F(DnsClientTest, DohProbes) {
URLRequestContext context;
client_->SetRequestContextForProbes(&context);
client_->SetSystemConfig(ValidConfigWithDoh());
auto transaction_factory =
std::make_unique<MockDnsTransactionFactory>(MockDnsClientRuleList());
auto* transaction_factory_ptr = transaction_factory.get();
client_->SetTransactionFactoryForTesting(std::move(transaction_factory));
client_->StartDohProbesForTesting();
EXPECT_FALSE(transaction_factory_ptr->doh_probes_running());
FastForwardBy(DnsClient::kInitialDohTimeout);
EXPECT_TRUE(transaction_factory_ptr->doh_probes_running());
}
TEST_F(DnsClientTest, CancelDohProbesBeforeEnabled) {
URLRequestContext context;
client_->SetRequestContextForProbes(&context);
client_->SetSystemConfig(ValidConfigWithDoh());
auto transaction_factory =
std::make_unique<MockDnsTransactionFactory>(MockDnsClientRuleList());
auto* transaction_factory_ptr = transaction_factory.get();
client_->SetTransactionFactoryForTesting(std::move(transaction_factory));
client_->StartDohProbesForTesting();
EXPECT_FALSE(transaction_factory_ptr->doh_probes_running());
client_->CancelProbesForContext(&context);
FastForwardUntilNoTasksRemain();
EXPECT_FALSE(transaction_factory_ptr->doh_probes_running());
}
TEST_F(DnsClientTest, CancelDohProbesAfterEnabled) {
URLRequestContext context;
client_->SetRequestContextForProbes(&context);
client_->SetSystemConfig(ValidConfigWithDoh());
auto transaction_factory =
std::make_unique<MockDnsTransactionFactory>(MockDnsClientRuleList());
auto* transaction_factory_ptr = transaction_factory.get();
client_->SetTransactionFactoryForTesting(std::move(transaction_factory));
client_->StartDohProbesForTesting();
FastForwardUntilNoTasksRemain();
EXPECT_TRUE(transaction_factory_ptr->doh_probes_running());
client_->CancelProbesForContext(&context);
EXPECT_FALSE(transaction_factory_ptr->doh_probes_running());
}
} // namespace
} // namespace net
This diff is collapsed.
......@@ -13,10 +13,13 @@
#include <utility>
#include <vector>
#include "base/memory/weak_ptr.h"
#include "base/stl_util.h"
#include "base/time/time.h"
#include "net/dns/dns_client.h"
#include "net/dns/dns_config.h"
#include "net/dns/dns_response.h"
#include "net/dns/dns_transaction.h"
#include "net/dns/dns_util.h"
#include "net/dns/public/dns_protocol.h"
......@@ -259,7 +262,47 @@ struct MockDnsClientRule {
typedef std::vector<MockDnsClientRule> MockDnsClientRuleList;
// MockDnsClient provides MockTransactionFactory.
// A DnsTransactionFactory which creates MockTransaction.
class MockDnsTransactionFactory : public DnsTransactionFactory {
public:
explicit MockDnsTransactionFactory(MockDnsClientRuleList rules);
~MockDnsTransactionFactory() override;
std::unique_ptr<DnsTransaction> CreateTransaction(
const std::string& hostname,
uint16_t qtype,
DnsTransactionFactory::CallbackType callback,
const NetLogWithSource&,
bool secure,
DnsConfig::SecureDnsMode secure_dns_mode,
URLRequestContext* url_request_context) override;
void AddEDNSOption(const OptRecordRdata::Opt& opt) override;
base::TimeDelta GetDelayUntilNextProbeForTest(
unsigned doh_server_index) override;
void StartDohProbes(URLRequestContext* url_request_context,
bool network_change) override;
void CancelDohProbes() override;
DnsConfig::SecureDnsMode GetSecureDnsModeForTest() override;
void CompleteDelayedTransactions();
bool doh_probes_running() { return doh_probes_running_; }
private:
class MockTransaction;
using DelayedTransactionList = std::vector<base::WeakPtr<MockTransaction>>;
MockDnsClientRuleList rules_;
DelayedTransactionList delayed_transactions_;
bool doh_probes_running_ = false;
};
// MockDnsClient provides MockDnsTransactionFactory.
class MockDnsClient : public DnsClient {
public:
MockDnsClient(DnsConfig config, MockDnsClientRuleList rules);
......@@ -285,6 +328,9 @@ class MockDnsClient : public DnsClient {
base::Optional<DnsConfig> GetSystemConfigForTesting() const override;
DnsConfigOverrides GetConfigOverridesForTesting() const override;
void SetProbeSuccessForTest(unsigned index, bool success) override;
void SetTransactionFactoryForTesting(
std::unique_ptr<DnsTransactionFactory> factory) override;
void StartDohProbesForTesting() override;
// Completes all DnsTransactions that were delayed by a rule.
void CompleteDelayedTransactions();
......@@ -302,8 +348,6 @@ class MockDnsClient : public DnsClient {
}
private:
class MockTransactionFactory;
base::Optional<DnsConfig> BuildEffectiveConfig();
bool insecure_enabled_ = false;
......@@ -315,7 +359,7 @@ class MockDnsClient : public DnsClient {
base::Optional<DnsConfig> config_;
DnsConfigOverrides overrides_;
base::Optional<DnsConfig> effective_config_;
std::unique_ptr<MockTransactionFactory> factory_;
std::unique_ptr<MockDnsTransactionFactory> factory_;
std::unique_ptr<AddressSorter> address_sorter_;
};
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment