Commit 5f3b3230 authored by Rob Percival's avatar Rob Percival Committed by Commit Bot

Fix use-after-free bug triggered when memory pressure reaches critical

When memory pressure reaches critical, SingleTreeTracker clears the
pending_entries_ map. However, if an inclusion check is in progress for
one or more of those pending entries, LogDnsClient will have a pointer to
a MerkleAuditProof held in that map. This results in it trying to access
freed memory.

The fix is to cancel all inclusion checks when this happens. This is done
by changing LogDnsClient to provide a "resource handle" when it starts a
query, which can be deleted in order to abort the query. Storing this
in pending_entries_ assures that all inclusion checks will be aborted
when pending_entries_ is cleared.

Bug: 811566
Change-Id: I86b7ff880c050b790d219fa0cd50b42839bc0d3e
Reviewed-on: https://chromium-review.googlesource.com/939627Reviewed-by: default avatarRyan Sleevi <rsleevi@chromium.org>
Reviewed-by: default avatarMatt Mueller <mattm@chromium.org>
Commit-Queue: Rob Percival <robpercival@chromium.org>
Cr-Commit-Position: refs/heads/master@{#546183}
parent a91d4ae3
...@@ -6,7 +6,6 @@ ...@@ -6,7 +6,6 @@
#include "base/big_endian.h" #include "base/big_endian.h"
#include "base/bind.h" #include "base/bind.h"
#include "base/callback_helpers.h"
#include "base/format_macros.h" #include "base/format_macros.h"
#include "base/location.h" #include "base/location.h"
#include "base/logging.h" #include "base/logging.h"
...@@ -19,6 +18,7 @@ ...@@ -19,6 +18,7 @@
#include "base/time/time.h" #include "base/time/time.h"
#include "components/base32/base32.h" #include "components/base32/base32.h"
#include "crypto/sha2.h" #include "crypto/sha2.h"
#include "net/base/completion_once_callback.h"
#include "net/base/sys_addrinfo.h" #include "net/base/sys_addrinfo.h"
#include "net/cert/merkle_audit_proof.h" #include "net/cert/merkle_audit_proof.h"
#include "net/dns/dns_client.h" #include "net/dns/dns_client.h"
...@@ -150,24 +150,32 @@ bool ParseAuditPath(const net::DnsResponse& response, ...@@ -150,24 +150,32 @@ bool ParseAuditPath(const net::DnsResponse& response,
// Encapsulates the state machine required to get an audit proof from a Merkle // Encapsulates the state machine required to get an audit proof from a Merkle
// leaf hash. This requires a DNS request to obtain the leaf index, then a // leaf hash. This requires a DNS request to obtain the leaf index, then a
// series of DNS requests to get the nodes of the proof. // series of DNS requests to get the nodes of the proof.
class LogDnsClient::AuditProofQuery { class AuditProofQueryImpl : public LogDnsClient::AuditProofQuery {
public: public:
// The LogDnsClient is guaranteed to outlive the AuditProofQuery, so it's safe // The API contract of LogDnsClient requires that callers make sure the
// to leave ownership of |dns_client| with LogDnsClient. // AuditProofQuery does not outlive LogDnsClient, so it's safe to leave
AuditProofQuery(net::DnsClient* dns_client, // ownership of |dns_client| with LogDnsClient.
AuditProofQueryImpl(net::DnsClient* dns_client,
const std::string& domain_for_log, const std::string& domain_for_log,
const net::NetLogWithSource& net_log); const net::NetLogWithSource& net_log);
~AuditProofQueryImpl() override;
// Begins the process of getting an audit proof for the CT log entry with a // Begins the process of getting an audit proof for the CT log entry with a
// leaf hash of |leaf_hash|. The proof will be for a tree of size |tree_size|. // leaf hash of |leaf_hash|. The proof will be for a tree of size |tree_size|.
// If it cannot be obtained synchronously, net::ERR_IO_PENDING will be // If it cannot be obtained synchronously, net::ERR_IO_PENDING will be
// returned and |callback| will be invoked when the operation has completed // returned and |callback| will be invoked when the operation has completed
// asynchronously. Ownership of |proof| remains with the caller, and it must // asynchronously. If the operation is cancelled (by deleting the
// not be deleted until the operation is complete. // AuditProofQueryImpl), |cancellation_callback| will be invoked.
net::Error Start(std::string leaf_hash, net::Error Start(std::string leaf_hash,
uint64_t tree_size, uint64_t tree_size,
const net::CompletionCallback& callback, net::CompletionOnceCallback callback,
net::ct::MerkleAuditProof* out_proof); base::OnceClosure cancellation_callback);
// Returns the proof that is being obtained by this query.
// It is only guaranteed to be populated once either Start() returns net::OK
// or the completion callback is invoked with net::OK.
const net::ct::MerkleAuditProof& GetProof() const override;
private: private:
enum class State { enum class State {
...@@ -228,9 +236,11 @@ class LogDnsClient::AuditProofQuery { ...@@ -228,9 +236,11 @@ class LogDnsClient::AuditProofQuery {
// The Merkle leaf hash of the CT log entry an audit proof is required for. // The Merkle leaf hash of the CT log entry an audit proof is required for.
std::string leaf_hash_; std::string leaf_hash_;
// The audit proof to populate. // The audit proof to populate.
net::ct::MerkleAuditProof* proof_; net::ct::MerkleAuditProof proof_;
// The callback to invoke when the query is complete. // The callback to invoke when the query is complete.
net::CompletionCallback callback_; net::CompletionOnceCallback callback_;
// The callback to invoke when the query is cancelled.
base::OnceClosure cancellation_callback_;
// The DnsClient to use for sending DNS requests to the CT log. // The DnsClient to use for sending DNS requests to the CT log.
net::DnsClient* dns_client_; net::DnsClient* dns_client_;
// The most recent DNS request. Null if no DNS requests have been made. // The most recent DNS request. Null if no DNS requests have been made.
...@@ -243,11 +253,10 @@ class LogDnsClient::AuditProofQuery { ...@@ -243,11 +253,10 @@ class LogDnsClient::AuditProofQuery {
// The time that Start() was last called. Used to measure query duration. // The time that Start() was last called. Used to measure query duration.
base::TimeTicks start_time_; base::TimeTicks start_time_;
// Produces WeakPtrs to |this| for binding callbacks. // Produces WeakPtrs to |this| for binding callbacks.
base::WeakPtrFactory<AuditProofQuery> weak_ptr_factory_; base::WeakPtrFactory<AuditProofQueryImpl> weak_ptr_factory_;
}; };
LogDnsClient::AuditProofQuery::AuditProofQuery( AuditProofQueryImpl::AuditProofQueryImpl(net::DnsClient* dns_client,
net::DnsClient* dns_client,
const std::string& domain_for_log, const std::string& domain_for_log,
const net::NetLogWithSource& net_log) const net::NetLogWithSource& net_log)
: next_state_(State::NONE), : next_state_(State::NONE),
...@@ -259,20 +268,24 @@ LogDnsClient::AuditProofQuery::AuditProofQuery( ...@@ -259,20 +268,24 @@ LogDnsClient::AuditProofQuery::AuditProofQuery(
DCHECK(!domain_for_log_.empty()); DCHECK(!domain_for_log_.empty());
} }
AuditProofQueryImpl::~AuditProofQueryImpl() {
if (next_state_ != State::NONE)
std::move(cancellation_callback_).Run();
}
// |leaf_hash| is not a const-ref to allow callers to std::move that string into // |leaf_hash| is not a const-ref to allow callers to std::move that string into
// the method, avoiding the need to make a copy. // the method, avoiding the need to make a copy.
net::Error LogDnsClient::AuditProofQuery::Start( net::Error AuditProofQueryImpl::Start(std::string leaf_hash,
std::string leaf_hash,
uint64_t tree_size, uint64_t tree_size,
const net::CompletionCallback& callback, net::CompletionOnceCallback callback,
net::ct::MerkleAuditProof* proof) { base::OnceClosure cancellation_callback) {
// It should not already be in progress. // It should not already be in progress.
DCHECK_EQ(State::NONE, next_state_); DCHECK_EQ(State::NONE, next_state_);
start_time_ = base::TimeTicks::Now(); start_time_ = base::TimeTicks::Now();
proof_ = proof; proof_.tree_size = tree_size;
proof_->tree_size = tree_size;
leaf_hash_ = std::move(leaf_hash); leaf_hash_ = std::move(leaf_hash);
callback_ = callback; callback_ = std::move(callback);
cancellation_callback_ = std::move(cancellation_callback);
// The first step in the query is to request the leaf index corresponding to // The first step in the query is to request the leaf index corresponding to
// |leaf_hash| from the CT log. // |leaf_hash| from the CT log.
next_state_ = State::REQUEST_LEAF_INDEX; next_state_ = State::REQUEST_LEAF_INDEX;
...@@ -280,7 +293,11 @@ net::Error LogDnsClient::AuditProofQuery::Start( ...@@ -280,7 +293,11 @@ net::Error LogDnsClient::AuditProofQuery::Start(
return DoLoop(net::OK); return DoLoop(net::OK);
} }
net::Error LogDnsClient::AuditProofQuery::DoLoop(net::Error result) { const net::ct::MerkleAuditProof& AuditProofQueryImpl::GetProof() const {
return proof_;
}
net::Error AuditProofQueryImpl::DoLoop(net::Error result) {
CHECK_NE(State::NONE, next_state_); CHECK_NE(State::NONE, next_state_);
State state; State state;
do { do {
...@@ -339,7 +356,7 @@ net::Error LogDnsClient::AuditProofQuery::DoLoop(net::Error result) { ...@@ -339,7 +356,7 @@ net::Error LogDnsClient::AuditProofQuery::DoLoop(net::Error result) {
return result; return result;
} }
void LogDnsClient::AuditProofQuery::OnDnsTransactionComplete( void AuditProofQueryImpl::OnDnsTransactionComplete(
net::DnsTransaction* transaction, net::DnsTransaction* transaction,
int net_error, int net_error,
const net::DnsResponse* response) { const net::DnsResponse* response) {
...@@ -351,14 +368,14 @@ void LogDnsClient::AuditProofQuery::OnDnsTransactionComplete( ...@@ -351,14 +368,14 @@ void LogDnsClient::AuditProofQuery::OnDnsTransactionComplete(
// callback. OnDnsTransactionComplete() will be invoked again once the I/O // callback. OnDnsTransactionComplete() will be invoked again once the I/O
// is complete, and can invoke the completion callback then if appropriate. // is complete, and can invoke the completion callback then if appropriate.
if (result != net::ERR_IO_PENDING) { if (result != net::ERR_IO_PENDING) {
// The callback will delete this query (now that it has finished), so copy // The callback may delete this query (now that it has finished), so copy
// |callback_| before running it so that it is not deleted along with the // |callback_| before running it so that it is not deleted along with the
// query, mid-callback-execution (which would result in a crash). // query, mid-callback-execution (which would result in a crash).
base::ResetAndReturn(&callback_).Run(result); std::move(callback_).Run(result);
} }
} }
net::Error LogDnsClient::AuditProofQuery::RequestLeafIndex() { net::Error AuditProofQueryImpl::RequestLeafIndex() {
std::string encoded_leaf_hash = base32::Base32Encode( std::string encoded_leaf_hash = base32::Base32Encode(
leaf_hash_, base32::Base32EncodePolicy::OMIT_PADDING); leaf_hash_, base32::Base32EncodePolicy::OMIT_PADDING);
DCHECK_EQ(encoded_leaf_hash.size(), 52u); DCHECK_EQ(encoded_leaf_hash.size(), 52u);
...@@ -376,14 +393,13 @@ net::Error LogDnsClient::AuditProofQuery::RequestLeafIndex() { ...@@ -376,14 +393,13 @@ net::Error LogDnsClient::AuditProofQuery::RequestLeafIndex() {
// Stores the received leaf index in |proof_->leaf_index|. // Stores the received leaf index in |proof_->leaf_index|.
// If successful, the audit proof nodes will be requested next. // If successful, the audit proof nodes will be requested next.
net::Error LogDnsClient::AuditProofQuery::RequestLeafIndexComplete( net::Error AuditProofQueryImpl::RequestLeafIndexComplete(net::Error result) {
net::Error result) {
if (result != net::OK) { if (result != net::OK) {
return result; return result;
} }
DCHECK(last_dns_response_); DCHECK(last_dns_response_);
if (!ParseLeafIndex(*last_dns_response_, &proof_->leaf_index)) { if (!ParseLeafIndex(*last_dns_response_, &proof_.leaf_index)) {
return net::ERR_DNS_MALFORMED_RESPONSE; return net::ERR_DNS_MALFORMED_RESPONSE;
} }
...@@ -393,7 +409,7 @@ net::Error LogDnsClient::AuditProofQuery::RequestLeafIndexComplete( ...@@ -393,7 +409,7 @@ net::Error LogDnsClient::AuditProofQuery::RequestLeafIndexComplete(
// b) the wrong leaf hash was provided. // b) the wrong leaf hash was provided.
// c) there is a bug server-side. // c) there is a bug server-side.
// The first two are more likely, so return ERR_INVALID_ARGUMENT. // The first two are more likely, so return ERR_INVALID_ARGUMENT.
if (proof_->leaf_index >= proof_->tree_size) { if (proof_.leaf_index >= proof_.tree_size) {
return net::ERR_INVALID_ARGUMENT; return net::ERR_INVALID_ARGUMENT;
} }
...@@ -401,17 +417,17 @@ net::Error LogDnsClient::AuditProofQuery::RequestLeafIndexComplete( ...@@ -401,17 +417,17 @@ net::Error LogDnsClient::AuditProofQuery::RequestLeafIndexComplete(
return net::OK; return net::OK;
} }
net::Error LogDnsClient::AuditProofQuery::RequestAuditProofNodes() { net::Error AuditProofQueryImpl::RequestAuditProofNodes() {
// Test pre-conditions (should be guaranteed by DNS response validation). // Test pre-conditions (should be guaranteed by DNS response validation).
if (proof_->leaf_index >= proof_->tree_size || if (proof_.leaf_index >= proof_.tree_size ||
proof_->nodes.size() >= net::ct::CalculateAuditPathLength( proof_.nodes.size() >= net::ct::CalculateAuditPathLength(
proof_->leaf_index, proof_->tree_size)) { proof_.leaf_index, proof_.tree_size)) {
return net::ERR_UNEXPECTED; return net::ERR_UNEXPECTED;
} }
std::string qname = base::StringPrintf( std::string qname = base::StringPrintf(
"%zu.%" PRIu64 ".%" PRIu64 ".tree.%s.", proof_->nodes.size(), "%zu.%" PRIu64 ".%" PRIu64 ".tree.%s.", proof_.nodes.size(),
proof_->leaf_index, proof_->tree_size, domain_for_log_.c_str()); proof_.leaf_index, proof_.tree_size, domain_for_log_.c_str());
if (!StartDnsTransaction(qname)) { if (!StartDnsTransaction(qname)) {
return net::ERR_NAME_RESOLUTION_FAILED; return net::ERR_NAME_RESOLUTION_FAILED;
...@@ -421,35 +437,34 @@ net::Error LogDnsClient::AuditProofQuery::RequestAuditProofNodes() { ...@@ -421,35 +437,34 @@ net::Error LogDnsClient::AuditProofQuery::RequestAuditProofNodes() {
return net::ERR_IO_PENDING; return net::ERR_IO_PENDING;
} }
net::Error LogDnsClient::AuditProofQuery::RequestAuditProofNodesComplete( net::Error AuditProofQueryImpl::RequestAuditProofNodesComplete(
net::Error result) { net::Error result) {
if (result != net::OK) { if (result != net::OK) {
return result; return result;
} }
const uint64_t audit_path_length = const uint64_t audit_path_length =
net::ct::CalculateAuditPathLength(proof_->leaf_index, proof_->tree_size); net::ct::CalculateAuditPathLength(proof_.leaf_index, proof_.tree_size);
// The calculated |audit_path_length| can't ever be greater than 64, so // The calculated |audit_path_length| can't ever be greater than 64, so
// deriving the amount of memory to reserve from the untrusted |leaf_index| // deriving the amount of memory to reserve from the untrusted |leaf_index|
// is safe. // is safe.
proof_->nodes.reserve(audit_path_length); proof_.nodes.reserve(audit_path_length);
DCHECK(last_dns_response_); DCHECK(last_dns_response_);
if (!ParseAuditPath(*last_dns_response_, proof_)) { if (!ParseAuditPath(*last_dns_response_, &proof_)) {
return net::ERR_DNS_MALFORMED_RESPONSE; return net::ERR_DNS_MALFORMED_RESPONSE;
} }
// Keep requesting more proof nodes until all of them are received. // Keep requesting more proof nodes until all of them are received.
if (proof_->nodes.size() < audit_path_length) { if (proof_.nodes.size() < audit_path_length) {
next_state_ = State::REQUEST_AUDIT_PROOF_NODES; next_state_ = State::REQUEST_AUDIT_PROOF_NODES;
} }
return net::OK; return net::OK;
} }
bool LogDnsClient::AuditProofQuery::StartDnsTransaction( bool AuditProofQueryImpl::StartDnsTransaction(const std::string& qname) {
const std::string& qname) {
net::DnsTransactionFactory* factory = dns_client_->GetTransactionFactory(); net::DnsTransactionFactory* factory = dns_client_->GetTransactionFactory();
if (!factory) { if (!factory) {
return false; return false;
...@@ -457,7 +472,7 @@ bool LogDnsClient::AuditProofQuery::StartDnsTransaction( ...@@ -457,7 +472,7 @@ bool LogDnsClient::AuditProofQuery::StartDnsTransaction(
current_dns_transaction_ = factory->CreateTransaction( current_dns_transaction_ = factory->CreateTransaction(
qname, net::dns_protocol::kTypeTXT, qname, net::dns_protocol::kTypeTXT,
base::Bind(&LogDnsClient::AuditProofQuery::OnDnsTransactionComplete, base::BindOnce(&AuditProofQueryImpl::OnDnsTransactionComplete,
weak_ptr_factory_.GetWeakPtr()), weak_ptr_factory_.GetWeakPtr()),
net_log_); net_log_);
...@@ -467,11 +482,11 @@ bool LogDnsClient::AuditProofQuery::StartDnsTransaction( ...@@ -467,11 +482,11 @@ bool LogDnsClient::AuditProofQuery::StartDnsTransaction(
LogDnsClient::LogDnsClient(std::unique_ptr<net::DnsClient> dns_client, LogDnsClient::LogDnsClient(std::unique_ptr<net::DnsClient> dns_client,
const net::NetLogWithSource& net_log, const net::NetLogWithSource& net_log,
size_t max_concurrent_queries) size_t max_in_flight_queries)
: dns_client_(std::move(dns_client)), : dns_client_(std::move(dns_client)),
net_log_(net_log), net_log_(net_log),
max_concurrent_queries_(max_concurrent_queries), in_flight_queries_(0),
weak_ptr_factory_(this) { max_in_flight_queries_(max_in_flight_queries) {
CHECK(dns_client_); CHECK(dns_client_);
net::NetworkChangeNotifier::AddDNSObserver(this); net::NetworkChangeNotifier::AddDNSObserver(this);
UpdateDnsConfig(); UpdateDnsConfig();
...@@ -489,9 +504,9 @@ void LogDnsClient::OnInitialDNSConfigRead() { ...@@ -489,9 +504,9 @@ void LogDnsClient::OnInitialDNSConfigRead() {
UpdateDnsConfig(); UpdateDnsConfig();
} }
void LogDnsClient::NotifyWhenNotThrottled(const base::Closure& callback) { void LogDnsClient::NotifyWhenNotThrottled(base::OnceClosure callback) {
DCHECK(HasMaxConcurrentQueriesInProgress()); DCHECK(HasMaxQueriesInFlight());
not_throttled_callbacks_.push_back(callback); not_throttled_callbacks_.emplace_back(std::move(callback));
} }
// |leaf_hash| is not a const-ref to allow callers to std::move that string into // |leaf_hash| is not a const-ref to allow callers to std::move that string into
...@@ -500,57 +515,66 @@ net::Error LogDnsClient::QueryAuditProof( ...@@ -500,57 +515,66 @@ net::Error LogDnsClient::QueryAuditProof(
base::StringPiece domain_for_log, base::StringPiece domain_for_log,
std::string leaf_hash, std::string leaf_hash,
uint64_t tree_size, uint64_t tree_size,
net::ct::MerkleAuditProof* proof, std::unique_ptr<AuditProofQuery>* out_query,
const net::CompletionCallback& callback) { const net::CompletionCallback& callback) {
DCHECK(proof); DCHECK(out_query);
if (domain_for_log.empty() || leaf_hash.size() != crypto::kSHA256Length) { if (domain_for_log.empty() || leaf_hash.size() != crypto::kSHA256Length) {
return net::ERR_INVALID_ARGUMENT; return net::ERR_INVALID_ARGUMENT;
} }
if (HasMaxConcurrentQueriesInProgress()) { if (HasMaxQueriesInFlight()) {
return net::ERR_TEMPORARILY_THROTTLED; return net::ERR_TEMPORARILY_THROTTLED;
} }
AuditProofQuery* query = new AuditProofQuery( auto* query = new AuditProofQueryImpl(dns_client_.get(),
dns_client_.get(), domain_for_log.as_string(), net_log_); domain_for_log.as_string(), net_log_);
// Transfers ownership of |query| to |audit_proof_queries_|. out_query->reset(query);
audit_proof_queries_.emplace_back(query);
++in_flight_queries_;
return query->Start(std::move(leaf_hash), tree_size, return query->Start(std::move(leaf_hash), tree_size,
base::Bind(&LogDnsClient::QueryAuditProofComplete, base::BindOnce(&LogDnsClient::QueryAuditProofComplete,
weak_ptr_factory_.GetWeakPtr(), base::Unretained(this), callback),
base::Unretained(query), callback), base::BindOnce(&LogDnsClient::QueryAuditProofCancelled,
proof); base::Unretained(this)));
} }
void LogDnsClient::QueryAuditProofComplete( void LogDnsClient::QueryAuditProofComplete(
AuditProofQuery* query, const net::CompletionCallback& completion_callback,
const net::CompletionCallback& callback,
int net_error) { int net_error) {
DCHECK(query); --in_flight_queries_;
// Move the "not throttled" callbacks to a local variable, just in case one of
// the callbacks deletes this LogDnsClient.
std::list<base::OnceClosure> not_throttled_callbacks =
std::move(not_throttled_callbacks_);
completion_callback.Run(net_error);
// Notify interested parties that the next query will not be throttled.
for (auto& callback : not_throttled_callbacks) {
std::move(callback).Run();
}
}
// Finished with the query - destroy it. void LogDnsClient::QueryAuditProofCancelled() {
auto query_iterator = --in_flight_queries_;
std::find_if(audit_proof_queries_.begin(), audit_proof_queries_.end(),
[query](const std::unique_ptr<AuditProofQuery>& p) {
return p.get() == query;
});
DCHECK(query_iterator != audit_proof_queries_.end());
audit_proof_queries_.erase(query_iterator);
callback.Run(net_error); // Move not_throttled_callbacks_ to a local variable, just in case one of the
// callbacks deletes this LogDnsClient.
std::list<base::OnceClosure> not_throttled_callbacks =
std::move(not_throttled_callbacks_);
// Notify interested parties that the next query will not be throttled. // Notify interested parties that the next query will not be throttled.
std::list<base::Closure> callbacks = std::move(not_throttled_callbacks_); for (auto& callback : not_throttled_callbacks) {
for (const base::Closure& callback : callbacks) { std::move(callback).Run();
callback.Run();
} }
} }
bool LogDnsClient::HasMaxConcurrentQueriesInProgress() const { bool LogDnsClient::HasMaxQueriesInFlight() const {
return max_concurrent_queries_ != 0 && return max_in_flight_queries_ != 0 &&
audit_proof_queries_.size() >= max_concurrent_queries_; in_flight_queries_ >= max_in_flight_queries_;
} }
void LogDnsClient::UpdateDnsConfig() { void LogDnsClient::UpdateDnsConfig() {
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include "base/callback.h" #include "base/callback.h"
#include "base/macros.h" #include "base/macros.h"
#include "base/single_thread_task_runner.h"
#include "base/strings/string_piece.h" #include "base/strings/string_piece.h"
#include "net/base/completion_callback.h" #include "net/base/completion_callback.h"
#include "net/base/net_errors.h" #include "net/base/net_errors.h"
...@@ -33,6 +34,12 @@ namespace certificate_transparency { ...@@ -33,6 +34,12 @@ namespace certificate_transparency {
// It must be created and deleted on the same thread. It is not thread-safe. // It must be created and deleted on the same thread. It is not thread-safe.
class LogDnsClient : public net::NetworkChangeNotifier::DNSObserver { class LogDnsClient : public net::NetworkChangeNotifier::DNSObserver {
public: public:
class AuditProofQuery {
public:
virtual ~AuditProofQuery() = default;
virtual const net::ct::MerkleAuditProof& GetProof() const = 0;
};
// Creates a log client that will take ownership of |dns_client| and use it // Creates a log client that will take ownership of |dns_client| and use it
// to perform DNS queries. Queries will be logged to |net_log|. // to perform DNS queries. Queries will be logged to |net_log|.
// The |dns_client| does not need to be configured first - this will be done // The |dns_client| does not need to be configured first - this will be done
...@@ -60,7 +67,8 @@ class LogDnsClient : public net::NetworkChangeNotifier::DNSObserver { ...@@ -60,7 +67,8 @@ class LogDnsClient : public net::NetworkChangeNotifier::DNSObserver {
// constructor of LogDnsClient). This callback will fire once and then be // constructor of LogDnsClient). This callback will fire once and then be
// unregistered. Should only be used if QueryAuditProof() returns // unregistered. Should only be used if QueryAuditProof() returns
// net::ERR_TEMPORARILY_THROTTLED. // net::ERR_TEMPORARILY_THROTTLED.
void NotifyWhenNotThrottled(const base::Closure& callback); // The callback will be run on the same thread that created the LogDnsClient.
void NotifyWhenNotThrottled(base::OnceClosure callback);
// Queries a CT log to retrieve an audit proof for the leaf with |leaf_hash|. // Queries a CT log to retrieve an audit proof for the leaf with |leaf_hash|.
// The log is identified by |domain_for_log|, which is the DNS name used as a // The log is identified by |domain_for_log|, which is the DNS name used as a
...@@ -68,10 +76,12 @@ class LogDnsClient : public net::NetworkChangeNotifier::DNSObserver { ...@@ -68,10 +76,12 @@ class LogDnsClient : public net::NetworkChangeNotifier::DNSObserver {
// The |leaf_hash| is the SHA-256 Merkle leaf hash (see RFC6962, section 2.1). // The |leaf_hash| is the SHA-256 Merkle leaf hash (see RFC6962, section 2.1).
// The size of the CT log tree, for which the proof is requested, must be // The size of the CT log tree, for which the proof is requested, must be
// provided in |tree_size|. // provided in |tree_size|.
// The leaf index and audit proof obtained from the CT log will be placed in // A handle to the query will be placed in |out_query|. The audit proof can be
// |out_proof|. // obtained from that once the query completes. Deleting this handle before
// the query completes will cancel it. It must not outlive the LogDnsClient.
// If the proof cannot be obtained synchronously, this method will return // If the proof cannot be obtained synchronously, this method will return
// net::ERR_IO_PENDING and invoke |callback| once the query is complete. // net::ERR_IO_PENDING and invoke |callback| once the query is complete.
// The callback will be run on the same thread that created the LogDnsClient.
// Returns: // Returns:
// - net::OK if the query was successful. // - net::OK if the query was successful.
// - net::ERR_IO_PENDING if the query was successfully started and is // - net::ERR_IO_PENDING if the query was successfully started and is
...@@ -85,24 +95,23 @@ class LogDnsClient : public net::NetworkChangeNotifier::DNSObserver { ...@@ -85,24 +95,23 @@ class LogDnsClient : public net::NetworkChangeNotifier::DNSObserver {
net::Error QueryAuditProof(base::StringPiece domain_for_log, net::Error QueryAuditProof(base::StringPiece domain_for_log,
std::string leaf_hash, std::string leaf_hash,
uint64_t tree_size, uint64_t tree_size,
net::ct::MerkleAuditProof* out_proof, std::unique_ptr<AuditProofQuery>* out_query,
const net::CompletionCallback& callback); const net::CompletionCallback& callback);
private: private:
class AuditProofQuery;
// Invoked when an audit proof query completes. // Invoked when an audit proof query completes.
// |query| is the query that has completed.
// |callback| is the user-provided callback that should be notified. // |callback| is the user-provided callback that should be notified.
// |net_error| is a net::Error indicating success or failure. // |net_error| is a net::Error indicating success or failure.
void QueryAuditProofComplete(AuditProofQuery* query, void QueryAuditProofComplete(const net::CompletionCallback& callback,
const net::CompletionCallback& callback,
int net_error); int net_error);
// Returns true if the maximum number of queries are currently in flight. // Invoked when an audit proof query is cancelled.
// If the maximum number of concurrency queries is set to 0, this will always void QueryAuditProofCancelled();
// Returns true if the maximum number of queries are currently in-flight.
// If the maximum number of in-flight queries is set to 0, this will always
// return false. // return false.
bool HasMaxConcurrentQueriesInProgress() const; bool HasMaxQueriesInFlight() const;
// Updates the |dns_client_| config using NetworkChangeNotifier. // Updates the |dns_client_| config using NetworkChangeNotifier.
void UpdateDnsConfig(); void UpdateDnsConfig();
...@@ -111,16 +120,12 @@ class LogDnsClient : public net::NetworkChangeNotifier::DNSObserver { ...@@ -111,16 +120,12 @@ class LogDnsClient : public net::NetworkChangeNotifier::DNSObserver {
std::unique_ptr<net::DnsClient> dns_client_; std::unique_ptr<net::DnsClient> dns_client_;
// Passed to the DNS client for logging. // Passed to the DNS client for logging.
net::NetLogWithSource net_log_; net::NetLogWithSource net_log_;
// A FIFO queue of ongoing queries. Since entries will always be appended to // The number of queries that are currently in-flight.
// the end and lookups will typically yield entries at the beginning, size_t in_flight_queries_;
// std::list is an efficient choice. // The maximum number of queries that can be in-flight at one time.
std::list<std::unique_ptr<AuditProofQuery>> audit_proof_queries_; size_t max_in_flight_queries_;
// The maximum number of queries that can be in flight at one time. // Callbacks to invoke when the number of in-flight queries is at its limit.
size_t max_concurrent_queries_; std::list<base::OnceClosure> not_throttled_callbacks_;
// Callbacks to invoke when the number of concurrent queries is at its limit.
std::list<base::Closure> not_throttled_callbacks_;
// Creates weak_ptrs to this, for callback purposes.
base::WeakPtrFactory<LogDnsClient> weak_ptr_factory_;
DISALLOW_COPY_AND_ASSIGN(LogDnsClient); DISALLOW_COPY_AND_ASSIGN(LogDnsClient);
}; };
......
...@@ -133,8 +133,8 @@ TEST_P(LogDnsClientTest, QueryAuditProofReportsThatLogDomainDoesNotExist) { ...@@ -133,8 +133,8 @@ TEST_P(LogDnsClientTest, QueryAuditProofReportsThatLogDomainDoesNotExist) {
ASSERT_TRUE(mock_dns_.ExpectRequestAndErrorResponse( ASSERT_TRUE(mock_dns_.ExpectRequestAndErrorResponse(
kLeafIndexQnames[0], net::dns_protocol::kRcodeNXDOMAIN)); kLeafIndexQnames[0], net::dns_protocol::kRcodeNXDOMAIN));
net::ct::MerkleAuditProof proof; std::unique_ptr<LogDnsClient::AuditProofQuery> query;
ASSERT_THAT(QueryAuditProof("ct.test", kLeafHashes[0], kTreeSizes[0], &proof), ASSERT_THAT(QueryAuditProof("ct.test", kLeafHashes[0], kTreeSizes[0], &query),
IsError(net::ERR_NAME_NOT_RESOLVED)); IsError(net::ERR_NAME_NOT_RESOLVED));
} }
...@@ -143,8 +143,8 @@ TEST_P(LogDnsClientTest, ...@@ -143,8 +143,8 @@ TEST_P(LogDnsClientTest,
ASSERT_TRUE(mock_dns_.ExpectRequestAndErrorResponse( ASSERT_TRUE(mock_dns_.ExpectRequestAndErrorResponse(
kLeafIndexQnames[0], net::dns_protocol::kRcodeSERVFAIL)); kLeafIndexQnames[0], net::dns_protocol::kRcodeSERVFAIL));
net::ct::MerkleAuditProof proof; std::unique_ptr<LogDnsClient::AuditProofQuery> query;
ASSERT_THAT(QueryAuditProof("ct.test", kLeafHashes[0], kTreeSizes[0], &proof), ASSERT_THAT(QueryAuditProof("ct.test", kLeafHashes[0], kTreeSizes[0], &query),
IsError(net::ERR_DNS_SERVER_FAILED)); IsError(net::ERR_DNS_SERVER_FAILED));
} }
...@@ -153,8 +153,8 @@ TEST_P(LogDnsClientTest, ...@@ -153,8 +153,8 @@ TEST_P(LogDnsClientTest,
ASSERT_TRUE(mock_dns_.ExpectRequestAndErrorResponse( ASSERT_TRUE(mock_dns_.ExpectRequestAndErrorResponse(
kLeafIndexQnames[0], net::dns_protocol::kRcodeREFUSED)); kLeafIndexQnames[0], net::dns_protocol::kRcodeREFUSED));
net::ct::MerkleAuditProof proof; std::unique_ptr<LogDnsClient::AuditProofQuery> query;
ASSERT_THAT(QueryAuditProof("ct.test", kLeafHashes[0], kTreeSizes[0], &proof), ASSERT_THAT(QueryAuditProof("ct.test", kLeafHashes[0], kTreeSizes[0], &query),
IsError(net::ERR_DNS_SERVER_FAILED)); IsError(net::ERR_DNS_SERVER_FAILED));
} }
...@@ -164,8 +164,8 @@ TEST_P( ...@@ -164,8 +164,8 @@ TEST_P(
ASSERT_TRUE(mock_dns_.ExpectRequestAndResponse( ASSERT_TRUE(mock_dns_.ExpectRequestAndResponse(
kLeafIndexQnames[0], std::vector<base::StringPiece>())); kLeafIndexQnames[0], std::vector<base::StringPiece>()));
net::ct::MerkleAuditProof proof; std::unique_ptr<LogDnsClient::AuditProofQuery> query;
ASSERT_THAT(QueryAuditProof("ct.test", kLeafHashes[0], kTreeSizes[0], &proof), ASSERT_THAT(QueryAuditProof("ct.test", kLeafHashes[0], kTreeSizes[0], &query),
IsError(net::ERR_DNS_MALFORMED_RESPONSE)); IsError(net::ERR_DNS_MALFORMED_RESPONSE));
} }
...@@ -175,8 +175,8 @@ TEST_P( ...@@ -175,8 +175,8 @@ TEST_P(
ASSERT_TRUE( ASSERT_TRUE(
mock_dns_.ExpectRequestAndResponse(kLeafIndexQnames[0], {"123456", "7"})); mock_dns_.ExpectRequestAndResponse(kLeafIndexQnames[0], {"123456", "7"}));
net::ct::MerkleAuditProof proof; std::unique_ptr<LogDnsClient::AuditProofQuery> query;
ASSERT_THAT(QueryAuditProof("ct.test", kLeafHashes[0], kTreeSizes[0], &proof), ASSERT_THAT(QueryAuditProof("ct.test", kLeafHashes[0], kTreeSizes[0], &query),
IsError(net::ERR_DNS_MALFORMED_RESPONSE)); IsError(net::ERR_DNS_MALFORMED_RESPONSE));
} }
...@@ -184,8 +184,8 @@ TEST_P(LogDnsClientTest, ...@@ -184,8 +184,8 @@ TEST_P(LogDnsClientTest,
QueryAuditProofReportsMalformedResponseIfLeafIndexIsNotNumeric) { QueryAuditProofReportsMalformedResponseIfLeafIndexIsNotNumeric) {
ASSERT_TRUE(mock_dns_.ExpectRequestAndResponse(kLeafIndexQnames[0], {"foo"})); ASSERT_TRUE(mock_dns_.ExpectRequestAndResponse(kLeafIndexQnames[0], {"foo"}));
net::ct::MerkleAuditProof proof; std::unique_ptr<LogDnsClient::AuditProofQuery> query;
ASSERT_THAT(QueryAuditProof("ct.test", kLeafHashes[0], kTreeSizes[0], &proof), ASSERT_THAT(QueryAuditProof("ct.test", kLeafHashes[0], kTreeSizes[0], &query),
IsError(net::ERR_DNS_MALFORMED_RESPONSE)); IsError(net::ERR_DNS_MALFORMED_RESPONSE));
} }
...@@ -194,8 +194,8 @@ TEST_P(LogDnsClientTest, ...@@ -194,8 +194,8 @@ TEST_P(LogDnsClientTest,
ASSERT_TRUE( ASSERT_TRUE(
mock_dns_.ExpectRequestAndResponse(kLeafIndexQnames[0], {"123456.0"})); mock_dns_.ExpectRequestAndResponse(kLeafIndexQnames[0], {"123456.0"}));
net::ct::MerkleAuditProof proof; std::unique_ptr<LogDnsClient::AuditProofQuery> query;
ASSERT_THAT(QueryAuditProof("ct.test", kLeafHashes[0], kTreeSizes[0], &proof), ASSERT_THAT(QueryAuditProof("ct.test", kLeafHashes[0], kTreeSizes[0], &query),
IsError(net::ERR_DNS_MALFORMED_RESPONSE)); IsError(net::ERR_DNS_MALFORMED_RESPONSE));
} }
...@@ -203,8 +203,8 @@ TEST_P(LogDnsClientTest, ...@@ -203,8 +203,8 @@ TEST_P(LogDnsClientTest,
QueryAuditProofReportsMalformedResponseIfLeafIndexIsEmpty) { QueryAuditProofReportsMalformedResponseIfLeafIndexIsEmpty) {
ASSERT_TRUE(mock_dns_.ExpectRequestAndResponse(kLeafIndexQnames[0], {""})); ASSERT_TRUE(mock_dns_.ExpectRequestAndResponse(kLeafIndexQnames[0], {""}));
net::ct::MerkleAuditProof proof; std::unique_ptr<LogDnsClient::AuditProofQuery> query;
ASSERT_THAT(QueryAuditProof("ct.test", kLeafHashes[0], kTreeSizes[0], &proof), ASSERT_THAT(QueryAuditProof("ct.test", kLeafHashes[0], kTreeSizes[0], &query),
IsError(net::ERR_DNS_MALFORMED_RESPONSE)); IsError(net::ERR_DNS_MALFORMED_RESPONSE));
} }
...@@ -213,8 +213,8 @@ TEST_P(LogDnsClientTest, ...@@ -213,8 +213,8 @@ TEST_P(LogDnsClientTest,
ASSERT_TRUE( ASSERT_TRUE(
mock_dns_.ExpectRequestAndResponse(kLeafIndexQnames[0], {"foo123456"})); mock_dns_.ExpectRequestAndResponse(kLeafIndexQnames[0], {"foo123456"}));
net::ct::MerkleAuditProof proof; std::unique_ptr<LogDnsClient::AuditProofQuery> query;
ASSERT_THAT(QueryAuditProof("ct.test", kLeafHashes[0], kTreeSizes[0], &proof), ASSERT_THAT(QueryAuditProof("ct.test", kLeafHashes[0], kTreeSizes[0], &query),
IsError(net::ERR_DNS_MALFORMED_RESPONSE)); IsError(net::ERR_DNS_MALFORMED_RESPONSE));
} }
...@@ -223,26 +223,26 @@ TEST_P(LogDnsClientTest, ...@@ -223,26 +223,26 @@ TEST_P(LogDnsClientTest,
ASSERT_TRUE( ASSERT_TRUE(
mock_dns_.ExpectRequestAndResponse(kLeafIndexQnames[0], {"123456foo"})); mock_dns_.ExpectRequestAndResponse(kLeafIndexQnames[0], {"123456foo"}));
net::ct::MerkleAuditProof proof; std::unique_ptr<LogDnsClient::AuditProofQuery> query;
ASSERT_THAT(QueryAuditProof("ct.test", kLeafHashes[0], kTreeSizes[0], &proof), ASSERT_THAT(QueryAuditProof("ct.test", kLeafHashes[0], kTreeSizes[0], &query),
IsError(net::ERR_DNS_MALFORMED_RESPONSE)); IsError(net::ERR_DNS_MALFORMED_RESPONSE));
} }
TEST_P(LogDnsClientTest, QueryAuditProofReportsInvalidArgIfLogDomainIsEmpty) { TEST_P(LogDnsClientTest, QueryAuditProofReportsInvalidArgIfLogDomainIsEmpty) {
net::ct::MerkleAuditProof proof; std::unique_ptr<LogDnsClient::AuditProofQuery> query;
ASSERT_THAT(QueryAuditProof("", kLeafHashes[0], kTreeSizes[0], &proof), ASSERT_THAT(QueryAuditProof("", kLeafHashes[0], kTreeSizes[0], &query),
IsError(net::ERR_INVALID_ARGUMENT)); IsError(net::ERR_INVALID_ARGUMENT));
} }
TEST_P(LogDnsClientTest, QueryAuditProofReportsInvalidArgIfLeafHashIsInvalid) { TEST_P(LogDnsClientTest, QueryAuditProofReportsInvalidArgIfLeafHashIsInvalid) {
net::ct::MerkleAuditProof proof; std::unique_ptr<LogDnsClient::AuditProofQuery> query;
ASSERT_THAT(QueryAuditProof("ct.test", "foo", kTreeSizes[0], &proof), ASSERT_THAT(QueryAuditProof("ct.test", "foo", kTreeSizes[0], &query),
IsError(net::ERR_INVALID_ARGUMENT)); IsError(net::ERR_INVALID_ARGUMENT));
} }
TEST_P(LogDnsClientTest, QueryAuditProofReportsInvalidArgIfLeafHashIsEmpty) { TEST_P(LogDnsClientTest, QueryAuditProofReportsInvalidArgIfLeafHashIsEmpty) {
net::ct::MerkleAuditProof proof; std::unique_ptr<LogDnsClient::AuditProofQuery> query;
ASSERT_THAT(QueryAuditProof("ct.test", "", kTreeSizes[0], &proof), ASSERT_THAT(QueryAuditProof("ct.test", "", kTreeSizes[0], &query),
IsError(net::ERR_INVALID_ARGUMENT)); IsError(net::ERR_INVALID_ARGUMENT));
} }
...@@ -251,8 +251,8 @@ TEST_P(LogDnsClientTest, ...@@ -251,8 +251,8 @@ TEST_P(LogDnsClientTest,
ASSERT_TRUE(mock_dns_.ExpectRequestAndSocketError( ASSERT_TRUE(mock_dns_.ExpectRequestAndSocketError(
kLeafIndexQnames[0], net::ERR_CONNECTION_REFUSED)); kLeafIndexQnames[0], net::ERR_CONNECTION_REFUSED));
net::ct::MerkleAuditProof proof; std::unique_ptr<LogDnsClient::AuditProofQuery> query;
ASSERT_THAT(QueryAuditProof("ct.test", kLeafHashes[0], kTreeSizes[0], &proof), ASSERT_THAT(QueryAuditProof("ct.test", kLeafHashes[0], kTreeSizes[0], &query),
IsError(net::ERR_CONNECTION_REFUSED)); IsError(net::ERR_CONNECTION_REFUSED));
} }
...@@ -260,8 +260,8 @@ TEST_P(LogDnsClientTest, ...@@ -260,8 +260,8 @@ TEST_P(LogDnsClientTest,
QueryAuditProofReportsTimeoutsDuringLeafIndexRequests) { QueryAuditProofReportsTimeoutsDuringLeafIndexRequests) {
ASSERT_TRUE(mock_dns_.ExpectRequestAndTimeout(kLeafIndexQnames[0])); ASSERT_TRUE(mock_dns_.ExpectRequestAndTimeout(kLeafIndexQnames[0]));
net::ct::MerkleAuditProof proof; std::unique_ptr<LogDnsClient::AuditProofQuery> query;
ASSERT_THAT(QueryAuditProof("ct.test", kLeafHashes[0], kTreeSizes[0], &proof), ASSERT_THAT(QueryAuditProof("ct.test", kLeafHashes[0], kTreeSizes[0], &query),
IsError(net::ERR_DNS_TIMED_OUT)); IsError(net::ERR_DNS_TIMED_OUT));
} }
...@@ -284,9 +284,10 @@ TEST_P(LogDnsClientTest, QueryAuditProof) { ...@@ -284,9 +284,10 @@ TEST_P(LogDnsClientTest, QueryAuditProof) {
audit_proof.begin() + nodes_begin, audit_proof.begin() + nodes_end)); audit_proof.begin() + nodes_begin, audit_proof.begin() + nodes_end));
} }
net::ct::MerkleAuditProof proof; std::unique_ptr<LogDnsClient::AuditProofQuery> query;
ASSERT_THAT(QueryAuditProof("ct.test", kLeafHashes[0], 999999, &proof), ASSERT_THAT(QueryAuditProof("ct.test", kLeafHashes[0], 999999, &query),
IsOk()); IsOk());
const net::ct::MerkleAuditProof& proof = query->GetProof();
EXPECT_THAT(proof.leaf_index, Eq(123456u)); EXPECT_THAT(proof.leaf_index, Eq(123456u));
EXPECT_THAT(proof.tree_size, Eq(999999u)); EXPECT_THAT(proof.tree_size, Eq(999999u));
EXPECT_THAT(proof.nodes, Eq(audit_proof)); EXPECT_THAT(proof.nodes, Eq(audit_proof));
...@@ -319,9 +320,10 @@ TEST_P(LogDnsClientTest, QueryAuditProofHandlesResponsesWithShortAuditPaths) { ...@@ -319,9 +320,10 @@ TEST_P(LogDnsClientTest, QueryAuditProofHandlesResponsesWithShortAuditPaths) {
"13.123456.999999.tree.ct.test.", audit_proof.begin() + 13, "13.123456.999999.tree.ct.test.", audit_proof.begin() + 13,
audit_proof.end())); audit_proof.end()));
net::ct::MerkleAuditProof proof; std::unique_ptr<LogDnsClient::AuditProofQuery> query;
ASSERT_THAT(QueryAuditProof("ct.test", kLeafHashes[0], 999999, &proof), ASSERT_THAT(QueryAuditProof("ct.test", kLeafHashes[0], 999999, &query),
IsOk()); IsOk());
const net::ct::MerkleAuditProof& proof = query->GetProof();
EXPECT_THAT(proof.leaf_index, Eq(123456u)); EXPECT_THAT(proof.leaf_index, Eq(123456u));
EXPECT_THAT(proof.tree_size, Eq(999999u)); EXPECT_THAT(proof.tree_size, Eq(999999u));
EXPECT_THAT(proof.nodes, Eq(audit_proof)); EXPECT_THAT(proof.nodes, Eq(audit_proof));
...@@ -334,8 +336,8 @@ TEST_P(LogDnsClientTest, ...@@ -334,8 +336,8 @@ TEST_P(LogDnsClientTest,
ASSERT_TRUE(mock_dns_.ExpectRequestAndErrorResponse( ASSERT_TRUE(mock_dns_.ExpectRequestAndErrorResponse(
"0.123456.999999.tree.ct.test.", net::dns_protocol::kRcodeNXDOMAIN)); "0.123456.999999.tree.ct.test.", net::dns_protocol::kRcodeNXDOMAIN));
net::ct::MerkleAuditProof proof; std::unique_ptr<LogDnsClient::AuditProofQuery> query;
ASSERT_THAT(QueryAuditProof("ct.test", kLeafHashes[0], 999999, &proof), ASSERT_THAT(QueryAuditProof("ct.test", kLeafHashes[0], 999999, &query),
IsError(net::ERR_NAME_NOT_RESOLVED)); IsError(net::ERR_NAME_NOT_RESOLVED));
} }
...@@ -346,8 +348,8 @@ TEST_P(LogDnsClientTest, ...@@ -346,8 +348,8 @@ TEST_P(LogDnsClientTest,
ASSERT_TRUE(mock_dns_.ExpectRequestAndErrorResponse( ASSERT_TRUE(mock_dns_.ExpectRequestAndErrorResponse(
"0.123456.999999.tree.ct.test.", net::dns_protocol::kRcodeSERVFAIL)); "0.123456.999999.tree.ct.test.", net::dns_protocol::kRcodeSERVFAIL));
net::ct::MerkleAuditProof proof; std::unique_ptr<LogDnsClient::AuditProofQuery> query;
ASSERT_THAT(QueryAuditProof("ct.test", kLeafHashes[0], 999999, &proof), ASSERT_THAT(QueryAuditProof("ct.test", kLeafHashes[0], 999999, &query),
IsError(net::ERR_DNS_SERVER_FAILED)); IsError(net::ERR_DNS_SERVER_FAILED));
} }
...@@ -358,8 +360,8 @@ TEST_P(LogDnsClientTest, ...@@ -358,8 +360,8 @@ TEST_P(LogDnsClientTest,
ASSERT_TRUE(mock_dns_.ExpectRequestAndErrorResponse( ASSERT_TRUE(mock_dns_.ExpectRequestAndErrorResponse(
"0.123456.999999.tree.ct.test.", net::dns_protocol::kRcodeREFUSED)); "0.123456.999999.tree.ct.test.", net::dns_protocol::kRcodeREFUSED));
net::ct::MerkleAuditProof proof; std::unique_ptr<LogDnsClient::AuditProofQuery> query;
ASSERT_THAT(QueryAuditProof("ct.test", kLeafHashes[0], 999999, &proof), ASSERT_THAT(QueryAuditProof("ct.test", kLeafHashes[0], 999999, &query),
IsError(net::ERR_DNS_SERVER_FAILED)); IsError(net::ERR_DNS_SERVER_FAILED));
} }
...@@ -373,8 +375,8 @@ TEST_P( ...@@ -373,8 +375,8 @@ TEST_P(
ASSERT_TRUE(mock_dns_.ExpectRequestAndResponse( ASSERT_TRUE(mock_dns_.ExpectRequestAndResponse(
"0.123456.999999.tree.ct.test.", std::vector<base::StringPiece>())); "0.123456.999999.tree.ct.test.", std::vector<base::StringPiece>()));
net::ct::MerkleAuditProof proof; std::unique_ptr<LogDnsClient::AuditProofQuery> query;
ASSERT_THAT(QueryAuditProof("ct.test", kLeafHashes[0], 999999, &proof), ASSERT_THAT(QueryAuditProof("ct.test", kLeafHashes[0], 999999, &query),
IsError(net::ERR_DNS_MALFORMED_RESPONSE)); IsError(net::ERR_DNS_MALFORMED_RESPONSE));
} }
...@@ -398,8 +400,8 @@ TEST_P( ...@@ -398,8 +400,8 @@ TEST_P(
"0.123456.999999.tree.ct.test.", "0.123456.999999.tree.ct.test.",
{first_chunk_of_proof, second_chunk_of_proof})); {first_chunk_of_proof, second_chunk_of_proof}));
net::ct::MerkleAuditProof proof; std::unique_ptr<LogDnsClient::AuditProofQuery> query;
ASSERT_THAT(QueryAuditProof("ct.test", kLeafHashes[0], 999999, &proof), ASSERT_THAT(QueryAuditProof("ct.test", kLeafHashes[0], 999999, &query),
IsError(net::ERR_DNS_MALFORMED_RESPONSE)); IsError(net::ERR_DNS_MALFORMED_RESPONSE));
} }
...@@ -413,8 +415,8 @@ TEST_P(LogDnsClientTest, ...@@ -413,8 +415,8 @@ TEST_P(LogDnsClientTest,
ASSERT_TRUE(mock_dns_.ExpectAuditProofRequestAndResponse( ASSERT_TRUE(mock_dns_.ExpectAuditProofRequestAndResponse(
"0.123456.999999.tree.ct.test.", audit_proof.begin(), audit_proof.end())); "0.123456.999999.tree.ct.test.", audit_proof.begin(), audit_proof.end()));
net::ct::MerkleAuditProof proof; std::unique_ptr<LogDnsClient::AuditProofQuery> query;
ASSERT_THAT(QueryAuditProof("ct.test", kLeafHashes[0], 999999, &proof), ASSERT_THAT(QueryAuditProof("ct.test", kLeafHashes[0], 999999, &query),
IsError(net::ERR_DNS_MALFORMED_RESPONSE)); IsError(net::ERR_DNS_MALFORMED_RESPONSE));
} }
...@@ -427,8 +429,8 @@ TEST_P(LogDnsClientTest, QueryAuditProofReportsResponseMalformedIfNodeTooLong) { ...@@ -427,8 +429,8 @@ TEST_P(LogDnsClientTest, QueryAuditProofReportsResponseMalformedIfNodeTooLong) {
ASSERT_TRUE(mock_dns_.ExpectAuditProofRequestAndResponse( ASSERT_TRUE(mock_dns_.ExpectAuditProofRequestAndResponse(
"0.123456.999999.tree.ct.test.", audit_proof.begin(), audit_proof.end())); "0.123456.999999.tree.ct.test.", audit_proof.begin(), audit_proof.end()));
net::ct::MerkleAuditProof proof; std::unique_ptr<LogDnsClient::AuditProofQuery> query;
ASSERT_THAT(QueryAuditProof("ct.test", kLeafHashes[0], 999999, &proof), ASSERT_THAT(QueryAuditProof("ct.test", kLeafHashes[0], 999999, &query),
IsError(net::ERR_DNS_MALFORMED_RESPONSE)); IsError(net::ERR_DNS_MALFORMED_RESPONSE));
} }
...@@ -440,8 +442,8 @@ TEST_P(LogDnsClientTest, QueryAuditProofReportsResponseMalformedIfEmpty) { ...@@ -440,8 +442,8 @@ TEST_P(LogDnsClientTest, QueryAuditProofReportsResponseMalformedIfEmpty) {
ASSERT_TRUE(mock_dns_.ExpectAuditProofRequestAndResponse( ASSERT_TRUE(mock_dns_.ExpectAuditProofRequestAndResponse(
"0.123456.999999.tree.ct.test.", audit_proof.begin(), audit_proof.end())); "0.123456.999999.tree.ct.test.", audit_proof.begin(), audit_proof.end()));
net::ct::MerkleAuditProof proof; std::unique_ptr<LogDnsClient::AuditProofQuery> query;
ASSERT_THAT(QueryAuditProof("ct.test", kLeafHashes[0], 999999, &proof), ASSERT_THAT(QueryAuditProof("ct.test", kLeafHashes[0], 999999, &query),
IsError(net::ERR_DNS_MALFORMED_RESPONSE)); IsError(net::ERR_DNS_MALFORMED_RESPONSE));
} }
...@@ -450,8 +452,8 @@ TEST_P(LogDnsClientTest, ...@@ -450,8 +452,8 @@ TEST_P(LogDnsClientTest,
ASSERT_TRUE( ASSERT_TRUE(
mock_dns_.ExpectLeafIndexRequestAndResponse(kLeafIndexQnames[0], 123456)); mock_dns_.ExpectLeafIndexRequestAndResponse(kLeafIndexQnames[0], 123456));
net::ct::MerkleAuditProof proof; std::unique_ptr<LogDnsClient::AuditProofQuery> query;
ASSERT_THAT(QueryAuditProof("ct.test", kLeafHashes[0], 123456, &proof), ASSERT_THAT(QueryAuditProof("ct.test", kLeafHashes[0], 123456, &query),
IsError(net::ERR_INVALID_ARGUMENT)); IsError(net::ERR_INVALID_ARGUMENT));
} }
...@@ -460,8 +462,8 @@ TEST_P(LogDnsClientTest, ...@@ -460,8 +462,8 @@ TEST_P(LogDnsClientTest,
ASSERT_TRUE( ASSERT_TRUE(
mock_dns_.ExpectLeafIndexRequestAndResponse(kLeafIndexQnames[0], 999999)); mock_dns_.ExpectLeafIndexRequestAndResponse(kLeafIndexQnames[0], 999999));
net::ct::MerkleAuditProof proof; std::unique_ptr<LogDnsClient::AuditProofQuery> query;
ASSERT_THAT(QueryAuditProof("ct.test", kLeafHashes[0], 123456, &proof), ASSERT_THAT(QueryAuditProof("ct.test", kLeafHashes[0], 123456, &query),
IsError(net::ERR_INVALID_ARGUMENT)); IsError(net::ERR_INVALID_ARGUMENT));
} }
...@@ -472,8 +474,8 @@ TEST_P(LogDnsClientTest, ...@@ -472,8 +474,8 @@ TEST_P(LogDnsClientTest,
ASSERT_TRUE(mock_dns_.ExpectRequestAndSocketError( ASSERT_TRUE(mock_dns_.ExpectRequestAndSocketError(
"0.123456.999999.tree.ct.test.", net::ERR_CONNECTION_REFUSED)); "0.123456.999999.tree.ct.test.", net::ERR_CONNECTION_REFUSED));
net::ct::MerkleAuditProof proof; std::unique_ptr<LogDnsClient::AuditProofQuery> query;
ASSERT_THAT(QueryAuditProof("ct.test", kLeafHashes[0], 999999, &proof), ASSERT_THAT(QueryAuditProof("ct.test", kLeafHashes[0], 999999, &query),
IsError(net::ERR_CONNECTION_REFUSED)); IsError(net::ERR_CONNECTION_REFUSED));
} }
...@@ -484,8 +486,8 @@ TEST_P(LogDnsClientTest, ...@@ -484,8 +486,8 @@ TEST_P(LogDnsClientTest,
ASSERT_TRUE( ASSERT_TRUE(
mock_dns_.ExpectRequestAndTimeout("0.123456.999999.tree.ct.test.")); mock_dns_.ExpectRequestAndTimeout("0.123456.999999.tree.ct.test."));
net::ct::MerkleAuditProof proof; std::unique_ptr<LogDnsClient::AuditProofQuery> query;
ASSERT_THAT(QueryAuditProof("ct.test", kLeafHashes[0], 999999, &proof), ASSERT_THAT(QueryAuditProof("ct.test", kLeafHashes[0], 999999, &query),
IsError(net::ERR_DNS_TIMED_OUT)); IsError(net::ERR_DNS_TIMED_OUT));
} }
...@@ -547,10 +549,10 @@ TEST_P(LogDnsClientTest, AdoptsLatestDnsConfigMidQuery) { ...@@ -547,10 +549,10 @@ TEST_P(LogDnsClientTest, AdoptsLatestDnsConfigMidQuery) {
LogDnsClient log_client(std::move(tmp), net::NetLogWithSource(), 0); LogDnsClient log_client(std::move(tmp), net::NetLogWithSource(), 0);
// Start query. // Start query.
net::ct::MerkleAuditProof proof; std::unique_ptr<LogDnsClient::AuditProofQuery> query;
net::TestCompletionCallback callback; net::TestCompletionCallback callback;
ASSERT_THAT(log_client.QueryAuditProof("ct.test", kLeafHashes[0], 999999, ASSERT_THAT(log_client.QueryAuditProof("ct.test", kLeafHashes[0], 999999,
&proof, callback.callback()), &query, callback.callback()),
IsError(net::ERR_IO_PENDING)); IsError(net::ERR_IO_PENDING));
// Get the current DNS config, modify it and publish the update. // Get the current DNS config, modify it and publish the update.
...@@ -566,6 +568,7 @@ TEST_P(LogDnsClientTest, AdoptsLatestDnsConfigMidQuery) { ...@@ -566,6 +568,7 @@ TEST_P(LogDnsClientTest, AdoptsLatestDnsConfigMidQuery) {
// Wait for the query to complete, then check that it was successful. // Wait for the query to complete, then check that it was successful.
// The DNS config should be updated during this time. // The DNS config should be updated during this time.
ASSERT_THAT(callback.WaitForResult(), IsOk()); ASSERT_THAT(callback.WaitForResult(), IsOk());
const net::ct::MerkleAuditProof& proof = query->GetProof();
EXPECT_THAT(proof.leaf_index, Eq(123456u)); EXPECT_THAT(proof.leaf_index, Eq(123456u));
EXPECT_THAT(proof.tree_size, Eq(999999u)); EXPECT_THAT(proof.tree_size, Eq(999999u));
EXPECT_THAT(proof.nodes, Eq(audit_proof)); EXPECT_THAT(proof.nodes, Eq(audit_proof));
...@@ -632,13 +635,13 @@ TEST_P(LogDnsClientTest, CanPerformQueriesInParallel) { ...@@ -632,13 +635,13 @@ TEST_P(LogDnsClientTest, CanPerformQueriesInParallel) {
} }
} }
net::ct::MerkleAuditProof proofs[kNumOfParallelQueries]; std::unique_ptr<LogDnsClient::AuditProofQuery> queries[kNumOfParallelQueries];
// Start the queries. // Start the queries.
for (size_t i = 0; i < kNumOfParallelQueries; ++i) { for (size_t i = 0; i < kNumOfParallelQueries; ++i) {
ASSERT_THAT( ASSERT_THAT(
log_client->QueryAuditProof("ct.test", kLeafHashes[i], kTreeSizes[i], log_client->QueryAuditProof("ct.test", kLeafHashes[i], kTreeSizes[i],
&proofs[i], callbacks[i].callback()), &queries[i], callbacks[i].callback()),
IsError(net::ERR_IO_PENDING)) IsError(net::ERR_IO_PENDING))
<< "query #" << i; << "query #" << i;
} }
...@@ -649,9 +652,10 @@ TEST_P(LogDnsClientTest, CanPerformQueriesInParallel) { ...@@ -649,9 +652,10 @@ TEST_P(LogDnsClientTest, CanPerformQueriesInParallel) {
SCOPED_TRACE(testing::Message() << "callbacks[" << i << "]"); SCOPED_TRACE(testing::Message() << "callbacks[" << i << "]");
EXPECT_THAT(callback.WaitForResult(), IsOk()); EXPECT_THAT(callback.WaitForResult(), IsOk());
EXPECT_THAT(proofs[i].leaf_index, Eq(kLeafIndices[i])); const net::ct::MerkleAuditProof& proof = queries[i]->GetProof();
EXPECT_THAT(proofs[i].tree_size, Eq(kTreeSizes[i])); EXPECT_THAT(proof.leaf_index, Eq(kLeafIndices[i]));
EXPECT_THAT(proofs[i].nodes, Eq(audit_proofs[i])); EXPECT_THAT(proof.tree_size, Eq(kTreeSizes[i]));
EXPECT_THAT(proof.nodes, Eq(audit_proofs[i]));
} }
} }
...@@ -684,20 +688,21 @@ TEST_P(LogDnsClientTest, CanBeThrottledToOneQueryAtATime) { ...@@ -684,20 +688,21 @@ TEST_P(LogDnsClientTest, CanBeThrottledToOneQueryAtATime) {
CreateLogDnsClient(kMaxConcurrentQueries); CreateLogDnsClient(kMaxConcurrentQueries);
// Try to start the queries. // Try to start the queries.
net::ct::MerkleAuditProof proof1; std::unique_ptr<LogDnsClient::AuditProofQuery> query1;
net::TestCompletionCallback callback1; net::TestCompletionCallback callback1;
ASSERT_THAT(log_client->QueryAuditProof("ct.test", kLeafHashes[0], 999999, ASSERT_THAT(log_client->QueryAuditProof("ct.test", kLeafHashes[0], 999999,
&proof1, callback1.callback()), &query1, callback1.callback()),
IsError(net::ERR_IO_PENDING)); IsError(net::ERR_IO_PENDING));
net::ct::MerkleAuditProof proof2; std::unique_ptr<LogDnsClient::AuditProofQuery> query2;
net::TestCompletionCallback callback2; net::TestCompletionCallback callback2;
ASSERT_THAT(log_client->QueryAuditProof("ct.test", kLeafHashes[1], 999999, ASSERT_THAT(log_client->QueryAuditProof("ct.test", kLeafHashes[1], 999999,
&proof2, callback2.callback()), &query2, callback2.callback()),
IsError(net::ERR_TEMPORARILY_THROTTLED)); IsError(net::ERR_TEMPORARILY_THROTTLED));
// Check that the first query succeeded. // Check that the first query succeeded.
EXPECT_THAT(callback1.WaitForResult(), IsOk()); EXPECT_THAT(callback1.WaitForResult(), IsOk());
const net::ct::MerkleAuditProof& proof1 = query1->GetProof();
EXPECT_THAT(proof1.leaf_index, Eq(123456u)); EXPECT_THAT(proof1.leaf_index, Eq(123456u));
EXPECT_THAT(proof1.tree_size, Eq(999999u)); EXPECT_THAT(proof1.tree_size, Eq(999999u));
EXPECT_THAT(proof1.nodes, Eq(audit_proof)); EXPECT_THAT(proof1.nodes, Eq(audit_proof));
...@@ -715,14 +720,15 @@ TEST_P(LogDnsClientTest, CanBeThrottledToOneQueryAtATime) { ...@@ -715,14 +720,15 @@ TEST_P(LogDnsClientTest, CanBeThrottledToOneQueryAtATime) {
"14.666.999999.tree.ct.test.", audit_proof.begin() + 14, "14.666.999999.tree.ct.test.", audit_proof.begin() + 14,
audit_proof.end())); audit_proof.end()));
net::ct::MerkleAuditProof proof3; std::unique_ptr<LogDnsClient::AuditProofQuery> query3;
net::TestCompletionCallback callback3; net::TestCompletionCallback callback3;
ASSERT_THAT(log_client->QueryAuditProof("ct.test", kLeafHashes[2], 999999, ASSERT_THAT(log_client->QueryAuditProof("ct.test", kLeafHashes[2], 999999,
&proof3, callback3.callback()), &query3, callback3.callback()),
IsError(net::ERR_IO_PENDING)); IsError(net::ERR_IO_PENDING));
// Check that the third query succeeded. // Check that the third query succeeded.
EXPECT_THAT(callback3.WaitForResult(), IsOk()); EXPECT_THAT(callback3.WaitForResult(), IsOk());
const net::ct::MerkleAuditProof& proof3 = query3->GetProof();
EXPECT_THAT(proof3.leaf_index, Eq(666u)); EXPECT_THAT(proof3.leaf_index, Eq(666u));
EXPECT_THAT(proof3.tree_size, Eq(999999u)); EXPECT_THAT(proof3.tree_size, Eq(999999u));
EXPECT_THAT(proof3.nodes, Eq(audit_proof)); EXPECT_THAT(proof3.nodes, Eq(audit_proof));
...@@ -748,16 +754,16 @@ TEST_P(LogDnsClientTest, NotifiesWhenNoLongerThrottled) { ...@@ -748,16 +754,16 @@ TEST_P(LogDnsClientTest, NotifiesWhenNoLongerThrottled) {
CreateLogDnsClient(kMaxConcurrentQueries); CreateLogDnsClient(kMaxConcurrentQueries);
// Start a query. // Start a query.
net::ct::MerkleAuditProof proof1; std::unique_ptr<LogDnsClient::AuditProofQuery> query1;
net::TestCompletionCallback proof_callback1; net::TestCompletionCallback query_callback1;
ASSERT_THAT(log_client->QueryAuditProof("ct.test", kLeafHashes[0], 999999, ASSERT_THAT(log_client->QueryAuditProof("ct.test", kLeafHashes[0], 999999,
&proof1, proof_callback1.callback()), &query1, query_callback1.callback()),
IsError(net::ERR_IO_PENDING)); IsError(net::ERR_IO_PENDING));
net::TestClosure not_throttled_callback; net::TestClosure not_throttled_callback;
log_client->NotifyWhenNotThrottled(not_throttled_callback.closure()); log_client->NotifyWhenNotThrottled(not_throttled_callback.closure());
ASSERT_THAT(proof_callback1.WaitForResult(), IsOk()); ASSERT_THAT(query_callback1.WaitForResult(), IsOk());
not_throttled_callback.WaitForResult(); not_throttled_callback.WaitForResult();
// Start another query to check |not_throttled_callback| doesn't fire again. // Start another query to check |not_throttled_callback| doesn't fire again.
...@@ -773,19 +779,44 @@ TEST_P(LogDnsClientTest, NotifiesWhenNoLongerThrottled) { ...@@ -773,19 +779,44 @@ TEST_P(LogDnsClientTest, NotifiesWhenNoLongerThrottled) {
"14.666.999999.tree.ct.test.", audit_proof.begin() + 14, "14.666.999999.tree.ct.test.", audit_proof.begin() + 14,
audit_proof.end())); audit_proof.end()));
net::ct::MerkleAuditProof proof2; std::unique_ptr<LogDnsClient::AuditProofQuery> query2;
net::TestCompletionCallback proof_callback2; net::TestCompletionCallback query_callback2;
ASSERT_THAT(log_client->QueryAuditProof("ct.test", kLeafHashes[1], 999999, ASSERT_THAT(log_client->QueryAuditProof("ct.test", kLeafHashes[1], 999999,
&proof2, proof_callback2.callback()), &query2, query_callback2.callback()),
IsError(net::ERR_IO_PENDING)); IsError(net::ERR_IO_PENDING));
// Give the query a chance to run. // Give the query a chance to run.
ASSERT_THAT(proof_callback2.WaitForResult(), IsOk()); ASSERT_THAT(query_callback2.WaitForResult(), IsOk());
// Give |not_throttled_callback| a chance to run - it shouldn't though. // Give |not_throttled_callback| a chance to run - it shouldn't though.
base::RunLoop().RunUntilIdle(); base::RunLoop().RunUntilIdle();
ASSERT_FALSE(not_throttled_callback.have_result()); ASSERT_FALSE(not_throttled_callback.have_result());
} }
TEST_P(LogDnsClientTest, CanCancelQueries) {
const size_t kMaxConcurrentQueries = 1;
std::unique_ptr<LogDnsClient> log_client =
CreateLogDnsClient(kMaxConcurrentQueries);
// Expect the first request of the query to be sent, but not the rest because
// it'll be cancelled before it gets that far.
ASSERT_TRUE(
mock_dns_.ExpectLeafIndexRequestAndResponse(kLeafIndexQnames[0], 123456));
// Start query.
std::unique_ptr<LogDnsClient::AuditProofQuery> query;
net::TestCompletionCallback callback;
ASSERT_THAT(log_client->QueryAuditProof("ct.test", kLeafHashes[0], 999999,
&query, callback.callback()),
IsError(net::ERR_IO_PENDING));
// Cancel the query.
query.reset();
// Give |callback| a chance to run - it shouldn't though.
base::RunLoop().RunUntilIdle();
ASSERT_FALSE(callback.have_result());
}
INSTANTIATE_TEST_CASE_P(ReadMode, INSTANTIATE_TEST_CASE_P(ReadMode,
LogDnsClientTest, LogDnsClientTest,
::testing::Values(net::IoMode::ASYNC, ::testing::Values(net::IoMode::ASYNC,
......
...@@ -202,8 +202,9 @@ struct SingleTreeTracker::EntryAuditState { ...@@ -202,8 +202,9 @@ struct SingleTreeTracker::EntryAuditState {
// Current phase of inclusion check. // Current phase of inclusion check.
AuditState state; AuditState state;
// The proof to be filled in by the LogDnsClient // The audit proof query performed by LogDnsClient.
MerkleAuditProof proof; // It is null unless a query has been started.
std::unique_ptr<LogDnsClient::AuditProofQuery> audit_proof_query;
// The root hash of the tree for which an inclusion proof was requested. // The root hash of the tree for which an inclusion proof was requested.
// The root hash is needed after the inclusion proof is fetched for validating // The root hash is needed after the inclusion proof is fetched for validating
...@@ -414,9 +415,9 @@ void SingleTreeTracker::ProcessPendingEntries() { ...@@ -414,9 +415,9 @@ void SingleTreeTracker::ProcessPendingEntries() {
crypto::kSHA256Length); crypto::kSHA256Length);
net::Error result = dns_client_->QueryAuditProof( net::Error result = dns_client_->QueryAuditProof(
ct_log_->dns_domain(), leaf_hash, verified_sth_.tree_size, ct_log_->dns_domain(), leaf_hash, verified_sth_.tree_size,
&(it->second.proof), &(it->second.audit_proof_query),
base::Bind(&SingleTreeTracker::OnAuditProofObtained, base::Bind(&SingleTreeTracker::OnAuditProofObtained,
weak_factory_.GetWeakPtr(), it->first)); base::Unretained(this), it->first));
// Handling proofs returned synchronously is not implemeted. // Handling proofs returned synchronously is not implemeted.
DCHECK_NE(result, net::OK); DCHECK_NE(result, net::OK);
if (result == net::ERR_IO_PENDING) { if (result == net::ERR_IO_PENDING) {
...@@ -424,8 +425,11 @@ void SingleTreeTracker::ProcessPendingEntries() { ...@@ -424,8 +425,11 @@ void SingleTreeTracker::ProcessPendingEntries() {
// and continue to the next one. // and continue to the next one.
it->second.state = INCLUSION_PROOF_REQUESTED; it->second.state = INCLUSION_PROOF_REQUESTED;
} else if (result == net::ERR_TEMPORARILY_THROTTLED) { } else if (result == net::ERR_TEMPORARILY_THROTTLED) {
// Need to use a weak pointer here, as this callback could be triggered
// when the SingleTreeTracker is deleted (and pending queries are
// cancelled).
dns_client_->NotifyWhenNotThrottled( dns_client_->NotifyWhenNotThrottled(
base::Bind(&SingleTreeTracker::ProcessPendingEntries, base::BindOnce(&SingleTreeTracker::ProcessPendingEntries,
weak_factory_.GetWeakPtr())); weak_factory_.GetWeakPtr()));
// Exit the loop since all subsequent calls to QueryAuditProof // Exit the loop since all subsequent calls to QueryAuditProof
// will be throttled. // will be throttled.
...@@ -494,7 +498,8 @@ void SingleTreeTracker::OnAuditProofObtained(const EntryToAudit& entry, ...@@ -494,7 +498,8 @@ void SingleTreeTracker::OnAuditProofObtained(const EntryToAudit& entry,
std::string leaf_hash(reinterpret_cast<const char*>(entry.leaf_hash.data), std::string leaf_hash(reinterpret_cast<const char*>(entry.leaf_hash.data),
crypto::kSHA256Length); crypto::kSHA256Length);
bool verified = ct_log_->VerifyAuditProof(it->second.proof, bool verified =
ct_log_->VerifyAuditProof(it->second.audit_proof_query->GetProof(),
it->second.root_hash, leaf_hash); it->second.root_hash, leaf_hash);
LogAuditResultToNetLog(entry, verified); LogAuditResultToNetLog(entry, verified);
...@@ -514,7 +519,7 @@ void SingleTreeTracker::OnMemoryPressure( ...@@ -514,7 +519,7 @@ void SingleTreeTracker::OnMemoryPressure(
case base::MemoryPressureListener::MEMORY_PRESSURE_LEVEL_NONE: case base::MemoryPressureListener::MEMORY_PRESSURE_LEVEL_NONE:
break; break;
case base::MemoryPressureListener::MEMORY_PRESSURE_LEVEL_CRITICAL: case base::MemoryPressureListener::MEMORY_PRESSURE_LEVEL_CRITICAL:
pending_entries_.clear(); ResetPendingQueue();
// Fall through to clearing the other cache. // Fall through to clearing the other cache.
FALLTHROUGH; FALLTHROUGH;
case base::MemoryPressureListener::MEMORY_PRESSURE_LEVEL_MODERATE: case base::MemoryPressureListener::MEMORY_PRESSURE_LEVEL_MODERATE:
......
...@@ -655,6 +655,45 @@ TEST_F(SingleTreeTrackerTest, TestEntryIncludedAfterInclusionCheckSuccess) { ...@@ -655,6 +655,45 @@ TEST_F(SingleTreeTrackerTest, TestEntryIncludedAfterInclusionCheckSuccess) {
net_log_, LeafHash(chain_.get(), cert_sct_.get()), true)); net_log_, LeafHash(chain_.get(), cert_sct_.get()), true));
} }
// Tests that inclusion checks are aborted and SCTs discarded if under critical
// memory pressure.
TEST_F(SingleTreeTrackerTest,
TestInclusionCheckCancelledIfUnderMemoryPressure) {
CreateTreeTracker();
AddCacheEntry(host_resolver_.GetHostCache(), kHostname,
net::HostCache::Entry::SOURCE_DNS, kZeroTTL);
tree_tracker_->OnSCTVerified(kHostname, chain_.get(), cert_sct_.get());
EXPECT_EQ(
SingleTreeTracker::SCT_PENDING_NEWER_STH,
tree_tracker_->GetLogEntryInclusionStatus(chain_.get(), cert_sct_.get()));
// Provide with a fresh STH, which is for a tree of size 2.
SignedTreeHead sth;
ASSERT_TRUE(GetSignedTreeHeadForTreeOfSize2(&sth));
ASSERT_TRUE(log_->VerifySignedTreeHead(sth));
// Make the first event that is processed a critical memory pressure
// notification. This should be handled before the response to the first DNS
// request, so no requests after the first one should be sent (the leaf index
// request).
base::MemoryPressureListener::NotifyMemoryPressure(
base::MemoryPressureListener::MEMORY_PRESSURE_LEVEL_CRITICAL);
ASSERT_TRUE(mock_dns_.ExpectLeafIndexRequestAndResponse(
Base32LeafHash(chain_.get(), cert_sct_.get()) + ".hash." +
kDNSRequestSuffix,
0));
tree_tracker_->NewSTHObserved(sth);
base::RunLoop().RunUntilIdle();
// Expect the SCT to have been discarded.
EXPECT_EQ(
SingleTreeTracker::SCT_NOT_OBSERVED,
tree_tracker_->GetLogEntryInclusionStatus(chain_.get(), cert_sct_.get()));
}
// Test that pending entries transition states correctly according to the // Test that pending entries transition states correctly according to the
// STHs provided: // STHs provided:
// * Start without an STH. // * Start without an STH.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment