Commit 3421db38 authored by Eric Orth's avatar Eric Orth Committed by Commit Bot

Use a cancellable timer for delayed DoH probe start

Allows DnsClient::CancelProbesForContext() to cleanly cancel probes
when still waiting on initial startup delay. Nulling the
|url_request_context_for_probes_| was not sufficient for this case
because the context was already bound in the delayed task.

For testability, moved MockDnsTransactionFactory to be visible from
dns_test_util.h instead of just being a private internal class of
MockDnsClient.  Allows use with a real DnsClient for DnsClientTests.
Bunch of code churn from moving, but no real change to
MockDnsTransactionFactory other than adding a |doh_probes_running_| to
track when the probes have been started or canceled.

Fixed: 1015555
Change-Id: I09a23bc3e90267723d76f727be3bf5390ad6ca37
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/1867443Reviewed-by: default avatarMatt Menke <mmenke@chromium.org>
Commit-Queue: Eric Orth <ericorth@chromium.org>
Cr-Commit-Position: refs/heads/master@{#707664}
parent 6792b39a
...@@ -80,9 +80,6 @@ void UpdateConfigForDohUpgrade(DnsConfig* config) { ...@@ -80,9 +80,6 @@ void UpdateConfigForDohUpgrade(DnsConfig* config) {
} }
} }
constexpr base::TimeDelta kInitialDoHTimeout =
base::TimeDelta::FromMilliseconds(5000);
class DnsClientImpl : public DnsClient, class DnsClientImpl : public DnsClient,
public NetworkChangeNotifier::ConnectionTypeObserver { public NetworkChangeNotifier::ConnectionTypeObserver {
public: public:
...@@ -96,7 +93,7 @@ class DnsClientImpl : public DnsClient, ...@@ -96,7 +93,7 @@ class DnsClientImpl : public DnsClient,
rand_int_callback_(rand_int_callback) { rand_int_callback_(rand_int_callback) {
NetworkChangeNotifier::AddConnectionTypeObserver(this); NetworkChangeNotifier::AddConnectionTypeObserver(this);
delayed_probes_allowed_timer_.Start( delayed_probes_allowed_timer_.Start(
FROM_HERE, kInitialDoHTimeout, FROM_HERE, kInitialDohTimeout,
base::Bind(&DnsClientImpl::SetProbesAllowed, base::Unretained(this))); base::Bind(&DnsClientImpl::SetProbesAllowed, base::Unretained(this)));
} }
...@@ -166,6 +163,10 @@ class DnsClientImpl : public DnsClient, ...@@ -166,6 +163,10 @@ class DnsClientImpl : public DnsClient,
void SetRequestContextForProbes( void SetRequestContextForProbes(
URLRequestContext* url_request_context) override { URLRequestContext* url_request_context) override {
DCHECK(url_request_context);
DCHECK(!url_request_context_for_probes_ ||
url_request_context == url_request_context_for_probes_);
url_request_context_for_probes_ = url_request_context; url_request_context_for_probes_ = url_request_context;
} }
...@@ -174,6 +175,7 @@ class DnsClientImpl : public DnsClient, ...@@ -174,6 +175,7 @@ class DnsClientImpl : public DnsClient,
return; return;
factory_->CancelDohProbes(); factory_->CancelDohProbes();
delayed_probes_start_timer_.Stop();
url_request_context_for_probes_ = nullptr; url_request_context_for_probes_ = nullptr;
} }
...@@ -203,6 +205,15 @@ class DnsClientImpl : public DnsClient, ...@@ -203,6 +205,15 @@ class DnsClientImpl : public DnsClient,
session_->SetProbeSuccess(index, success); session_->SetProbeSuccess(index, success);
} }
void SetTransactionFactoryForTesting(
std::unique_ptr<DnsTransactionFactory> factory) override {
factory_ = std::move(factory);
}
void StartDohProbesForTesting() override {
StartDohProbes(false /* network_change */);
}
private: private:
base::Optional<DnsConfig> BuildEffectiveConfig() const { base::Optional<DnsConfig> BuildEffectiveConfig() const {
DnsConfig config; DnsConfig config;
...@@ -279,14 +290,14 @@ class DnsClientImpl : public DnsClient, ...@@ -279,14 +290,14 @@ class DnsClientImpl : public DnsClient,
return; return;
if (probes_allowed_) { if (probes_allowed_) {
delayed_probes_start_timer_.Stop();
factory_->StartDohProbes(url_request_context_for_probes_, network_change); factory_->StartDohProbes(url_request_context_for_probes_, network_change);
} else { } else {
base::SequencedTaskRunnerHandle::Get()->PostDelayedTask( delayed_probes_start_timer_.Start(
FROM_HERE, FROM_HERE, delayed_probes_allowed_timer_.GetCurrentDelay(),
base::BindOnce(&DnsTransactionFactory::StartDohProbes, base::BindOnce(&DnsTransactionFactory::StartDohProbes,
factory_->weak_factory_.GetWeakPtr(), factory_->weak_factory_.GetWeakPtr(),
url_request_context_for_probes_, network_change), url_request_context_for_probes_, network_change));
delayed_probes_allowed_timer_.GetCurrentDelay());
} }
} }
...@@ -306,6 +317,7 @@ class DnsClientImpl : public DnsClient, ...@@ -306,6 +317,7 @@ class DnsClientImpl : public DnsClient,
// prevent interference with startup tasks. // prevent interference with startup tasks.
bool probes_allowed_; bool probes_allowed_;
base::OneShotTimer delayed_probes_allowed_timer_; base::OneShotTimer delayed_probes_allowed_timer_;
base::OneShotTimer delayed_probes_start_timer_;
URLRequestContext* url_request_context_for_probes_; URLRequestContext* url_request_context_for_probes_;
NetLog* net_log_; NetLog* net_log_;
...@@ -318,6 +330,10 @@ class DnsClientImpl : public DnsClient, ...@@ -318,6 +330,10 @@ class DnsClientImpl : public DnsClient,
} // namespace } // namespace
// static
const base::TimeDelta DnsClient::kInitialDohTimeout =
base::TimeDelta::FromSeconds(5);
// static // static
std::unique_ptr<DnsClient> DnsClient::CreateClient(NetLog* net_log) { std::unique_ptr<DnsClient> DnsClient::CreateClient(NetLog* net_log) {
return std::make_unique<DnsClientImpl>( return std::make_unique<DnsClientImpl>(
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <memory> #include <memory>
#include "base/optional.h" #include "base/optional.h"
#include "base/time/time.h"
#include "net/base/net_export.h" #include "net/base/net_export.h"
#include "net/base/rand_callback.h" #include "net/base/rand_callback.h"
#include "net/dns/dns_config.h" #include "net/dns/dns_config.h"
...@@ -28,6 +29,7 @@ class NetLog; ...@@ -28,6 +29,7 @@ class NetLog;
class NET_EXPORT DnsClient { class NET_EXPORT DnsClient {
public: public:
static const int kMaxInsecureFallbackFailures = 16; static const int kMaxInsecureFallbackFailures = 16;
static const base::TimeDelta kInitialDohTimeout;
virtual ~DnsClient() {} virtual ~DnsClient() {}
...@@ -83,6 +85,10 @@ class NET_EXPORT DnsClient { ...@@ -83,6 +85,10 @@ class NET_EXPORT DnsClient {
virtual void SetProbeSuccessForTest(unsigned index, bool success) = 0; virtual void SetProbeSuccessForTest(unsigned index, bool success) = 0;
virtual void SetTransactionFactoryForTesting(
std::unique_ptr<DnsTransactionFactory> factory) = 0;
virtual void StartDohProbesForTesting() = 0;
// Creates default client. // Creates default client.
static std::unique_ptr<DnsClient> CreateClient(NetLog* net_log); static std::unique_ptr<DnsClient> CreateClient(NetLog* net_log);
......
...@@ -8,9 +8,11 @@ ...@@ -8,9 +8,11 @@
#include "base/bind.h" #include "base/bind.h"
#include "base/rand_util.h" #include "base/rand_util.h"
#include "base/test/task_environment.h"
#include "net/base/ip_address.h" #include "net/base/ip_address.h"
#include "net/base/ip_endpoint.h" #include "net/base/ip_endpoint.h"
#include "net/dns/dns_config.h" #include "net/dns/dns_config.h"
#include "net/dns/dns_test_util.h"
#include "net/socket/socket_test_util.h" #include "net/socket/socket_test_util.h"
#include "net/test/test_with_task_environment.h" #include "net/test/test_with_task_environment.h"
#include "testing/gmock/include/gmock/gmock.h" #include "testing/gmock/include/gmock/gmock.h"
...@@ -18,6 +20,8 @@ ...@@ -18,6 +20,8 @@
namespace net { namespace net {
class ClientSocketFactory;
namespace { namespace {
class AlwaysFailSocketFactory : public MockClientSocketFactory { class AlwaysFailSocketFactory : public MockClientSocketFactory {
...@@ -32,6 +36,10 @@ class AlwaysFailSocketFactory : public MockClientSocketFactory { ...@@ -32,6 +36,10 @@ class AlwaysFailSocketFactory : public MockClientSocketFactory {
class DnsClientTest : public TestWithTaskEnvironment { class DnsClientTest : public TestWithTaskEnvironment {
protected: protected:
DnsClientTest()
: TestWithTaskEnvironment(
base::test::TaskEnvironment::TimeSource::MOCK_TIME) {}
void SetUp() override { void SetUp() override {
client_ = DnsClient::CreateClientForTesting( client_ = DnsClient::CreateClientForTesting(
nullptr /* net_log */, &socket_factory_, base::Bind(&base::RandInt)); nullptr /* net_log */, &socket_factory_, base::Bind(&base::RandInt));
...@@ -58,8 +66,6 @@ class DnsClientTest : public TestWithTaskEnvironment { ...@@ -58,8 +66,6 @@ class DnsClientTest : public TestWithTaskEnvironment {
std::unique_ptr<DnsClient> client_; std::unique_ptr<DnsClient> client_;
AlwaysFailSocketFactory socket_factory_; AlwaysFailSocketFactory socket_factory_;
private:
}; };
TEST_F(DnsClientTest, NoConfig) { TEST_F(DnsClientTest, NoConfig) {
...@@ -240,6 +246,58 @@ TEST_F(DnsClientTest, OverrideToInvalid) { ...@@ -240,6 +246,58 @@ TEST_F(DnsClientTest, OverrideToInvalid) {
EXPECT_FALSE(client_->GetEffectiveConfig()); EXPECT_FALSE(client_->GetEffectiveConfig());
} }
TEST_F(DnsClientTest, DohProbes) {
URLRequestContext context;
client_->SetRequestContextForProbes(&context);
client_->SetSystemConfig(ValidConfigWithDoh());
auto transaction_factory =
std::make_unique<MockDnsTransactionFactory>(MockDnsClientRuleList());
auto* transaction_factory_ptr = transaction_factory.get();
client_->SetTransactionFactoryForTesting(std::move(transaction_factory));
client_->StartDohProbesForTesting();
EXPECT_FALSE(transaction_factory_ptr->doh_probes_running());
FastForwardBy(DnsClient::kInitialDohTimeout);
EXPECT_TRUE(transaction_factory_ptr->doh_probes_running());
}
TEST_F(DnsClientTest, CancelDohProbesBeforeEnabled) {
URLRequestContext context;
client_->SetRequestContextForProbes(&context);
client_->SetSystemConfig(ValidConfigWithDoh());
auto transaction_factory =
std::make_unique<MockDnsTransactionFactory>(MockDnsClientRuleList());
auto* transaction_factory_ptr = transaction_factory.get();
client_->SetTransactionFactoryForTesting(std::move(transaction_factory));
client_->StartDohProbesForTesting();
EXPECT_FALSE(transaction_factory_ptr->doh_probes_running());
client_->CancelProbesForContext(&context);
FastForwardUntilNoTasksRemain();
EXPECT_FALSE(transaction_factory_ptr->doh_probes_running());
}
TEST_F(DnsClientTest, CancelDohProbesAfterEnabled) {
URLRequestContext context;
client_->SetRequestContextForProbes(&context);
client_->SetSystemConfig(ValidConfigWithDoh());
auto transaction_factory =
std::make_unique<MockDnsTransactionFactory>(MockDnsClientRuleList());
auto* transaction_factory_ptr = transaction_factory.get();
client_->SetTransactionFactoryForTesting(std::move(transaction_factory));
client_->StartDohProbesForTesting();
FastForwardUntilNoTasksRemain();
EXPECT_TRUE(transaction_factory_ptr->doh_probes_running());
client_->CancelProbesForContext(&context);
EXPECT_FALSE(transaction_factory_ptr->doh_probes_running());
}
} // namespace } // namespace
} // namespace net } // namespace net
This diff is collapsed.
...@@ -13,10 +13,13 @@ ...@@ -13,10 +13,13 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "base/memory/weak_ptr.h"
#include "base/stl_util.h" #include "base/stl_util.h"
#include "base/time/time.h"
#include "net/dns/dns_client.h" #include "net/dns/dns_client.h"
#include "net/dns/dns_config.h" #include "net/dns/dns_config.h"
#include "net/dns/dns_response.h" #include "net/dns/dns_response.h"
#include "net/dns/dns_transaction.h"
#include "net/dns/dns_util.h" #include "net/dns/dns_util.h"
#include "net/dns/public/dns_protocol.h" #include "net/dns/public/dns_protocol.h"
...@@ -259,7 +262,47 @@ struct MockDnsClientRule { ...@@ -259,7 +262,47 @@ struct MockDnsClientRule {
typedef std::vector<MockDnsClientRule> MockDnsClientRuleList; typedef std::vector<MockDnsClientRule> MockDnsClientRuleList;
// MockDnsClient provides MockTransactionFactory. // A DnsTransactionFactory which creates MockTransaction.
class MockDnsTransactionFactory : public DnsTransactionFactory {
public:
explicit MockDnsTransactionFactory(MockDnsClientRuleList rules);
~MockDnsTransactionFactory() override;
std::unique_ptr<DnsTransaction> CreateTransaction(
const std::string& hostname,
uint16_t qtype,
DnsTransactionFactory::CallbackType callback,
const NetLogWithSource&,
bool secure,
DnsConfig::SecureDnsMode secure_dns_mode,
URLRequestContext* url_request_context) override;
void AddEDNSOption(const OptRecordRdata::Opt& opt) override;
base::TimeDelta GetDelayUntilNextProbeForTest(
unsigned doh_server_index) override;
void StartDohProbes(URLRequestContext* url_request_context,
bool network_change) override;
void CancelDohProbes() override;
DnsConfig::SecureDnsMode GetSecureDnsModeForTest() override;
void CompleteDelayedTransactions();
bool doh_probes_running() { return doh_probes_running_; }
private:
class MockTransaction;
using DelayedTransactionList = std::vector<base::WeakPtr<MockTransaction>>;
MockDnsClientRuleList rules_;
DelayedTransactionList delayed_transactions_;
bool doh_probes_running_ = false;
};
// MockDnsClient provides MockDnsTransactionFactory.
class MockDnsClient : public DnsClient { class MockDnsClient : public DnsClient {
public: public:
MockDnsClient(DnsConfig config, MockDnsClientRuleList rules); MockDnsClient(DnsConfig config, MockDnsClientRuleList rules);
...@@ -285,6 +328,9 @@ class MockDnsClient : public DnsClient { ...@@ -285,6 +328,9 @@ class MockDnsClient : public DnsClient {
base::Optional<DnsConfig> GetSystemConfigForTesting() const override; base::Optional<DnsConfig> GetSystemConfigForTesting() const override;
DnsConfigOverrides GetConfigOverridesForTesting() const override; DnsConfigOverrides GetConfigOverridesForTesting() const override;
void SetProbeSuccessForTest(unsigned index, bool success) override; void SetProbeSuccessForTest(unsigned index, bool success) override;
void SetTransactionFactoryForTesting(
std::unique_ptr<DnsTransactionFactory> factory) override;
void StartDohProbesForTesting() override;
// Completes all DnsTransactions that were delayed by a rule. // Completes all DnsTransactions that were delayed by a rule.
void CompleteDelayedTransactions(); void CompleteDelayedTransactions();
...@@ -302,8 +348,6 @@ class MockDnsClient : public DnsClient { ...@@ -302,8 +348,6 @@ class MockDnsClient : public DnsClient {
} }
private: private:
class MockTransactionFactory;
base::Optional<DnsConfig> BuildEffectiveConfig(); base::Optional<DnsConfig> BuildEffectiveConfig();
bool insecure_enabled_ = false; bool insecure_enabled_ = false;
...@@ -315,7 +359,7 @@ class MockDnsClient : public DnsClient { ...@@ -315,7 +359,7 @@ class MockDnsClient : public DnsClient {
base::Optional<DnsConfig> config_; base::Optional<DnsConfig> config_;
DnsConfigOverrides overrides_; DnsConfigOverrides overrides_;
base::Optional<DnsConfig> effective_config_; base::Optional<DnsConfig> effective_config_;
std::unique_ptr<MockTransactionFactory> factory_; std::unique_ptr<MockDnsTransactionFactory> factory_;
std::unique_ptr<AddressSorter> address_sorter_; std::unique_ptr<AddressSorter> address_sorter_;
}; };
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment