Commit ed0b9e03 authored by Bence Béky's avatar Bence Béky Committed by Commit Bot

Use OnceCallback in ChannelIDService.

Use OnceCallback instead of Callback, and CompletionOnceCallback instead
of CompletionCallback, in ChannelIDService.  Also restructure

ChannelIDService::JoinToInFlightRequest() to use early return if in
flight request in not found.

Bug: 807724
Change-Id: I45b55b0713ca9460fda8480d4f8ac14d5ead4435
Reviewed-on: https://chromium-review.googlesource.com/1117029Reviewed-by: default avatarMaks Orlovich <morlovich@chromium.org>
Commit-Queue: Bence Béky <bnc@chromium.org>
Cr-Commit-Position: refs/heads/master@{#570844}
parent 319c61a3
...@@ -67,29 +67,30 @@ std::unique_ptr<ChannelIDStore::ChannelID> GenerateChannelID( ...@@ -67,29 +67,30 @@ std::unique_ptr<ChannelIDStore::ChannelID> GenerateChannelID(
// generation. Will take care of deleting itself once Start() is called. // generation. Will take care of deleting itself once Start() is called.
class ChannelIDServiceWorker { class ChannelIDServiceWorker {
public: public:
typedef base::Callback< typedef base::OnceCallback<
void(const std::string&, int, std::unique_ptr<ChannelIDStore::ChannelID>)> void(const std::string&, int, std::unique_ptr<ChannelIDStore::ChannelID>)>
WorkerDoneCallback; WorkerDoneCallback;
ChannelIDServiceWorker(const std::string& server_identifier, ChannelIDServiceWorker(const std::string& server_identifier,
const WorkerDoneCallback& callback) WorkerDoneCallback callback)
: server_identifier_(server_identifier), : server_identifier_(server_identifier),
origin_task_runner_(base::ThreadTaskRunnerHandle::Get()), origin_task_runner_(base::ThreadTaskRunnerHandle::Get()),
callback_(callback) {} callback_(std::move(callback)) {}
// Starts the worker asynchronously. // Starts the worker asynchronously.
void Start(const scoped_refptr<base::TaskRunner>& task_runner) { void Start(const scoped_refptr<base::TaskRunner>& task_runner) {
DCHECK(origin_task_runner_->RunsTasksInCurrentSequence()); DCHECK(origin_task_runner_->RunsTasksInCurrentSequence());
auto callback = base::Bind(&ChannelIDServiceWorker::Run, base::Owned(this)); auto callback =
base::BindOnce(&ChannelIDServiceWorker::Run, base::Owned(this));
if (task_runner) { if (task_runner) {
task_runner->PostTask(FROM_HERE, callback); task_runner->PostTask(FROM_HERE, std::move(callback));
} else { } else {
base::PostTaskWithTraits( base::PostTaskWithTraits(
FROM_HERE, FROM_HERE,
{base::MayBlock(), base::TaskShutdownBehavior::CONTINUE_ON_SHUTDOWN}, {base::MayBlock(), base::TaskShutdownBehavior::CONTINUE_ON_SHUTDOWN},
callback); std::move(callback));
} }
} }
...@@ -100,8 +101,8 @@ class ChannelIDServiceWorker { ...@@ -100,8 +101,8 @@ class ChannelIDServiceWorker {
std::unique_ptr<ChannelIDStore::ChannelID> channel_id = std::unique_ptr<ChannelIDStore::ChannelID> channel_id =
GenerateChannelID(server_identifier_, &error); GenerateChannelID(server_identifier_, &error);
origin_task_runner_->PostTask( origin_task_runner_->PostTask(
FROM_HERE, base::Bind(callback_, server_identifier_, error, FROM_HERE, base::BindOnce(std::move(callback_), server_identifier_,
base::Passed(&channel_id))); error, base::Passed(&channel_id)));
} }
const std::string server_identifier_; const std::string server_identifier_;
...@@ -176,12 +177,12 @@ void ChannelIDService::Request::Cancel() { ...@@ -176,12 +177,12 @@ void ChannelIDService::Request::Cancel() {
void ChannelIDService::Request::RequestStarted( void ChannelIDService::Request::RequestStarted(
ChannelIDService* service, ChannelIDService* service,
const CompletionCallback& callback, CompletionOnceCallback callback,
std::unique_ptr<crypto::ECPrivateKey>* key, std::unique_ptr<crypto::ECPrivateKey>* key,
ChannelIDServiceJob* job) { ChannelIDServiceJob* job) {
DCHECK(service_ == NULL); DCHECK(service_ == NULL);
service_ = service; service_ = service;
callback_ = callback; callback_ = std::move(callback);
key_ = key; key_ = key;
job_ = job; job_ = job;
} }
...@@ -225,7 +226,7 @@ std::string ChannelIDService::GetDomainForHost(const std::string& host) { ...@@ -225,7 +226,7 @@ std::string ChannelIDService::GetDomainForHost(const std::string& host) {
int ChannelIDService::GetOrCreateChannelID( int ChannelIDService::GetOrCreateChannelID(
const std::string& host, const std::string& host,
std::unique_ptr<crypto::ECPrivateKey>* key, std::unique_ptr<crypto::ECPrivateKey>* key,
const CompletionCallback& callback, CompletionOnceCallback callback,
Request* out_req) { Request* out_req) {
DVLOG(1) << __func__ << " " << host; DVLOG(1) << __func__ << " " << host;
DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
...@@ -243,19 +244,18 @@ int ChannelIDService::GetOrCreateChannelID( ...@@ -243,19 +244,18 @@ int ChannelIDService::GetOrCreateChannelID(
// See if a request for the same domain is currently in flight. // See if a request for the same domain is currently in flight.
bool create_if_missing = true; bool create_if_missing = true;
if (JoinToInFlightRequest(domain, key, create_if_missing, callback, if (JoinToInFlightRequest(domain, key, create_if_missing, &callback,
out_req)) { out_req)) {
return ERR_IO_PENDING; return ERR_IO_PENDING;
} }
int err = LookupChannelID(domain, key, create_if_missing, callback, out_req); int err = LookupChannelID(domain, key, create_if_missing, &callback, out_req);
if (err == ERR_FILE_NOT_FOUND) { if (err == ERR_FILE_NOT_FOUND) {
// Sync lookup did not find a valid channel ID. Start generating a new one. // Sync lookup did not find a valid channel ID. Start generating a new one.
workers_created_++; workers_created_++;
ChannelIDServiceWorker* worker = new ChannelIDServiceWorker( ChannelIDServiceWorker* worker = new ChannelIDServiceWorker(
domain, domain, base::BindOnce(&ChannelIDService::GeneratedChannelID,
base::Bind(&ChannelIDService::GeneratedChannelID, weak_ptr_factory_.GetWeakPtr()));
weak_ptr_factory_.GetWeakPtr()));
worker->Start(task_runner_); worker->Start(task_runner_);
// We are waiting for key generation. Create a job & request to track it. // We are waiting for key generation. Create a job & request to track it.
...@@ -263,7 +263,7 @@ int ChannelIDService::GetOrCreateChannelID( ...@@ -263,7 +263,7 @@ int ChannelIDService::GetOrCreateChannelID(
inflight_[domain] = base::WrapUnique(job); inflight_[domain] = base::WrapUnique(job);
job->AddRequest(out_req); job->AddRequest(out_req);
out_req->RequestStarted(this, callback, key, job); out_req->RequestStarted(this, std::move(callback), key, job);
return ERR_IO_PENDING; return ERR_IO_PENDING;
} }
...@@ -272,7 +272,7 @@ int ChannelIDService::GetOrCreateChannelID( ...@@ -272,7 +272,7 @@ int ChannelIDService::GetOrCreateChannelID(
int ChannelIDService::GetChannelID(const std::string& host, int ChannelIDService::GetChannelID(const std::string& host,
std::unique_ptr<crypto::ECPrivateKey>* key, std::unique_ptr<crypto::ECPrivateKey>* key,
const CompletionCallback& callback, CompletionOnceCallback callback,
Request* out_req) { Request* out_req) {
DVLOG(1) << __func__ << " " << host; DVLOG(1) << __func__ << " " << host;
DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
...@@ -290,12 +290,12 @@ int ChannelIDService::GetChannelID(const std::string& host, ...@@ -290,12 +290,12 @@ int ChannelIDService::GetChannelID(const std::string& host,
// See if a request for the same domain currently in flight. // See if a request for the same domain currently in flight.
bool create_if_missing = false; bool create_if_missing = false;
if (JoinToInFlightRequest(domain, key, create_if_missing, callback, if (JoinToInFlightRequest(domain, key, create_if_missing, &callback,
out_req)) { out_req)) {
return ERR_IO_PENDING; return ERR_IO_PENDING;
} }
int err = LookupChannelID(domain, key, create_if_missing, callback, out_req); int err = LookupChannelID(domain, key, create_if_missing, &callback, out_req);
return err; return err;
} }
...@@ -328,9 +328,8 @@ void ChannelIDService::GotChannelID(int err, ...@@ -328,9 +328,8 @@ void ChannelIDService::GotChannelID(int err,
// one. // one.
workers_created_++; workers_created_++;
ChannelIDServiceWorker* worker = new ChannelIDServiceWorker( ChannelIDServiceWorker* worker = new ChannelIDServiceWorker(
server_identifier, server_identifier, base::BindOnce(&ChannelIDService::GeneratedChannelID,
base::Bind(&ChannelIDService::GeneratedChannelID, weak_ptr_factory_.GetWeakPtr()));
weak_ptr_factory_.GetWeakPtr()));
worker->Start(task_runner_); worker->Start(task_runner_);
} }
...@@ -372,29 +371,27 @@ bool ChannelIDService::JoinToInFlightRequest( ...@@ -372,29 +371,27 @@ bool ChannelIDService::JoinToInFlightRequest(
const std::string& domain, const std::string& domain,
std::unique_ptr<crypto::ECPrivateKey>* key, std::unique_ptr<crypto::ECPrivateKey>* key,
bool create_if_missing, bool create_if_missing,
const CompletionCallback& callback, CompletionOnceCallback* callback,
Request* out_req) { Request* out_req) {
ChannelIDServiceJob* job = NULL;
auto j = inflight_.find(domain); auto j = inflight_.find(domain);
if (j != inflight_.end()) { if (j == inflight_.end())
// A request for the same domain is in flight already. We'll attach our return false;
// callback, but we'll also mark it as requiring a channel ID if one's
// mising. // A request for the same domain is in flight already. We'll attach our
job = j->second.get(); // callback, but we'll also mark it as requiring a channel ID if one's mising.
inflight_joins_++; ChannelIDServiceJob* job = j->second.get();
inflight_joins_++;
job->AddRequest(out_req, create_if_missing);
out_req->RequestStarted(this, callback, key, job); job->AddRequest(out_req, create_if_missing);
return true; out_req->RequestStarted(this, std::move(*callback), key, job);
} return true;
return false;
} }
int ChannelIDService::LookupChannelID( int ChannelIDService::LookupChannelID(
const std::string& domain, const std::string& domain,
std::unique_ptr<crypto::ECPrivateKey>* key, std::unique_ptr<crypto::ECPrivateKey>* key,
bool create_if_missing, bool create_if_missing,
const CompletionCallback& callback, CompletionOnceCallback* callback,
Request* out_req) { Request* out_req) {
// Check if a channel ID key already exists for this domain. // Check if a channel ID key already exists for this domain.
int err = channel_id_store_->GetChannelID( int err = channel_id_store_->GetChannelID(
...@@ -414,7 +411,7 @@ int ChannelIDService::LookupChannelID( ...@@ -414,7 +411,7 @@ int ChannelIDService::LookupChannelID(
inflight_[domain] = base::WrapUnique(job); inflight_[domain] = base::WrapUnique(job);
job->AddRequest(out_req); job->AddRequest(out_req);
out_req->RequestStarted(this, callback, key, job); out_req->RequestStarted(this, std::move(*callback), key, job);
return ERR_IO_PENDING; return ERR_IO_PENDING;
} }
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
#include "base/memory/weak_ptr.h" #include "base/memory/weak_ptr.h"
#include "base/task_runner.h" #include "base/task_runner.h"
#include "base/threading/thread_checker.h" #include "base/threading/thread_checker.h"
#include "net/base/completion_callback.h" #include "net/base/completion_once_callback.h"
#include "net/base/net_export.h" #include "net/base/net_export.h"
#include "net/ssl/channel_id_store.h" #include "net/ssl/channel_id_store.h"
...@@ -49,14 +49,14 @@ class NET_EXPORT ChannelIDService { ...@@ -49,14 +49,14 @@ class NET_EXPORT ChannelIDService {
friend class ChannelIDServiceJob; friend class ChannelIDServiceJob;
void RequestStarted(ChannelIDService* service, void RequestStarted(ChannelIDService* service,
const CompletionCallback& callback, CompletionOnceCallback callback,
std::unique_ptr<crypto::ECPrivateKey>* key, std::unique_ptr<crypto::ECPrivateKey>* key,
ChannelIDServiceJob* job); ChannelIDServiceJob* job);
void Post(int error, std::unique_ptr<crypto::ECPrivateKey> key); void Post(int error, std::unique_ptr<crypto::ECPrivateKey> key);
ChannelIDService* service_; ChannelIDService* service_;
CompletionCallback callback_; CompletionOnceCallback callback_;
std::unique_ptr<crypto::ECPrivateKey>* key_; std::unique_ptr<crypto::ECPrivateKey>* key_;
ChannelIDServiceJob* job_; ChannelIDServiceJob* job_;
}; };
...@@ -91,7 +91,7 @@ class NET_EXPORT ChannelIDService { ...@@ -91,7 +91,7 @@ class NET_EXPORT ChannelIDService {
// |*out_req| will be initialized with a handle to the async request. // |*out_req| will be initialized with a handle to the async request.
int GetOrCreateChannelID(const std::string& host, int GetOrCreateChannelID(const std::string& host,
std::unique_ptr<crypto::ECPrivateKey>* key, std::unique_ptr<crypto::ECPrivateKey>* key,
const CompletionCallback& callback, CompletionOnceCallback callback,
Request* out_req); Request* out_req);
// Fetches the channel ID for the specified host if one exists. // Fetches the channel ID for the specified host if one exists.
...@@ -111,7 +111,7 @@ class NET_EXPORT ChannelIDService { ...@@ -111,7 +111,7 @@ class NET_EXPORT ChannelIDService {
// |*out_req| will be initialized with a handle to the async request. // |*out_req| will be initialized with a handle to the async request.
int GetChannelID(const std::string& host, int GetChannelID(const std::string& host,
std::unique_ptr<crypto::ECPrivateKey>* key, std::unique_ptr<crypto::ECPrivateKey>* key,
const CompletionCallback& callback, CompletionOnceCallback callback,
Request* out_req); Request* out_req);
// Returns the backing ChannelIDStore. // Returns the backing ChannelIDStore.
...@@ -140,23 +140,24 @@ class NET_EXPORT ChannelIDService { ...@@ -140,23 +140,24 @@ class NET_EXPORT ChannelIDService {
const std::string& server_identifier, const std::string& server_identifier,
std::unique_ptr<crypto::ECPrivateKey> key); std::unique_ptr<crypto::ECPrivateKey> key);
// Searches for an in-flight request for the same domain. If found, // Searches for an in-flight request for the same domain. If found, attaches
// attaches to the request and returns true. Returns false if no in-flight // to the request, consumes |*callback|, and returns true. Otherwise does not
// request is found. // consume |*callback| and returns false.
bool JoinToInFlightRequest(const std::string& domain, bool JoinToInFlightRequest(const std::string& domain,
std::unique_ptr<crypto::ECPrivateKey>* key, std::unique_ptr<crypto::ECPrivateKey>* key,
bool create_if_missing, bool create_if_missing,
const CompletionCallback& callback, CompletionOnceCallback* callback,
Request* out_req); Request* out_req);
// Looks for the channel ID for |domain| in this service's store. // Looks for the channel ID for |domain| in this service's store. Returns OK
// Returns OK if it can be found synchronously, ERR_IO_PENDING if the // if it can be found synchronously, ERR_IO_PENDING if the result cannot be
// result cannot be obtained synchronously, or a network error code on // obtained synchronously, or a different network error code on failure
// failure (including failure to find a channel ID of |domain|). // (including failure to find a channel ID of |domain|). Consumes |*callback|
// if and only if ERR_IO_PENDING is returned.
int LookupChannelID(const std::string& domain, int LookupChannelID(const std::string& domain,
std::unique_ptr<crypto::ECPrivateKey>* key, std::unique_ptr<crypto::ECPrivateKey>* key,
bool create_if_missing, bool create_if_missing,
const CompletionCallback& callback, CompletionOnceCallback* callback,
Request* out_req); Request* out_req);
std::unique_ptr<ChannelIDStore> channel_id_store_; std::unique_ptr<ChannelIDStore> channel_id_store_;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment