Commit 075d9579 authored by Michael Ershov's avatar Michael Ershov Committed by Commit Bot

Cert Provisioning: Check existing certs in UpdateOneCert

Check existing certs in UpdateOneCert to avoid creating a worker
for certificate profile that already has a provisioned certificate.
Also add a unit test for UpdateOneCert method.

Bug: 1045895
Test: CertProvisioning*
Change-Id: I284bec0c90766673b6ba3a2c1ac69f28303258ce
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2245150
Commit-Queue: Michael Ershov <miersh@google.com>
Reviewed-by: default avatarPavol Marko <pmarko@chromium.org>
Cr-Commit-Position: refs/heads/master@{#781407}
parent 3530c216
...@@ -88,6 +88,14 @@ bool CertProfile::operator!=(const CertProfile& other) const { ...@@ -88,6 +88,14 @@ bool CertProfile::operator!=(const CertProfile& other) const {
return !(*this == other); return !(*this == other);
} }
bool CertProfileComparator::operator()(const CertProfile& a,
const CertProfile& b) const {
static_assert(CertProfile::kVersion == 3, "This function should be updated");
return ((a.profile_id < b.profile_id) ||
(a.policy_version < b.policy_version) ||
(a.is_va_enabled < b.is_va_enabled));
}
//============================================================================== //==============================================================================
void RegisterProfilePrefs(PrefRegistrySimple* registry) { void RegisterProfilePrefs(PrefRegistrySimple* registry) {
......
...@@ -83,6 +83,10 @@ struct CertProfile { ...@@ -83,6 +83,10 @@ struct CertProfile {
bool operator!=(const CertProfile& other) const; bool operator!=(const CertProfile& other) const;
}; };
struct CertProfileComparator {
bool operator()(const CertProfile& a, const CertProfile& b) const;
};
void RegisterProfilePrefs(PrefRegistrySimple* registry); void RegisterProfilePrefs(PrefRegistrySimple* registry);
void RegisterLocalStatePrefs(PrefRegistrySimple* registry); void RegisterLocalStatePrefs(PrefRegistrySimple* registry);
const char* GetPrefNameForSerialization(CertScope scope); const char* GetPrefNameForSerialization(CertScope scope);
......
...@@ -3,10 +3,12 @@ ...@@ -3,10 +3,12 @@
// found in the LICENSE file. // found in the LICENSE file.
#include "chrome/browser/chromeos/cert_provisioning/cert_provisioning_platform_keys_helpers.h" #include "chrome/browser/chromeos/cert_provisioning/cert_provisioning_platform_keys_helpers.h"
#include <memory> #include <memory>
#include "base/bind.h" #include "base/bind.h"
#include "base/check.h" #include "base/check.h"
#include "base/containers/flat_set.h"
#include "base/stl_util.h" #include "base/stl_util.h"
#include "chrome/browser/chromeos/cert_provisioning/cert_provisioning_common.h" #include "chrome/browser/chromeos/cert_provisioning/cert_provisioning_common.h"
#include "chrome/browser/chromeos/platform_keys/platform_keys_service.h" #include "chrome/browser/chromeos/platform_keys/platform_keys_service.h"
...@@ -74,7 +76,7 @@ void CertProvisioningCertsWithIdsGetter::OnGetCertificatesDone( ...@@ -74,7 +76,7 @@ void CertProvisioningCertsWithIdsGetter::OnGetCertificatesDone(
void CertProvisioningCertsWithIdsGetter::CollectOneResult( void CertProvisioningCertsWithIdsGetter::CollectOneResult(
scoped_refptr<net::X509Certificate> cert, scoped_refptr<net::X509Certificate> cert,
const std::string& cert_id, const CertProfileId& cert_id,
const std::string& error_message) { const std::string& error_message) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
DCHECK(wait_counter_ > 0); DCHECK(wait_counter_ > 0);
...@@ -110,7 +112,7 @@ CertProvisioningCertDeleter::~CertProvisioningCertDeleter() = default; ...@@ -110,7 +112,7 @@ CertProvisioningCertDeleter::~CertProvisioningCertDeleter() = default;
void CertProvisioningCertDeleter::DeleteCerts( void CertProvisioningCertDeleter::DeleteCerts(
CertScope cert_scope, CertScope cert_scope,
platform_keys::PlatformKeysService* platform_keys_service, platform_keys::PlatformKeysService* platform_keys_service,
const std::set<std::string>& cert_profile_ids_to_keep, base::flat_set<CertProfileId> cert_profile_ids_to_keep,
DeleteCertsCallback callback) { DeleteCertsCallback callback) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
DCHECK(platform_keys_service); DCHECK(platform_keys_service);
...@@ -119,7 +121,7 @@ void CertProvisioningCertDeleter::DeleteCerts( ...@@ -119,7 +121,7 @@ void CertProvisioningCertDeleter::DeleteCerts(
cert_scope_ = cert_scope; cert_scope_ = cert_scope;
platform_keys_service_ = platform_keys_service; platform_keys_service_ = platform_keys_service;
callback_ = std::move(callback); callback_ = std::move(callback);
cert_profile_ids_to_keep_ = cert_profile_ids_to_keep; cert_profile_ids_to_keep_ = std::move(cert_profile_ids_to_keep);
cert_getter_ = std::make_unique<CertProvisioningCertsWithIdsGetter>(); cert_getter_ = std::make_unique<CertProvisioningCertsWithIdsGetter>();
cert_getter_->GetCertsWithIds( cert_getter_->GetCertsWithIds(
...@@ -129,7 +131,8 @@ void CertProvisioningCertDeleter::DeleteCerts( ...@@ -129,7 +131,8 @@ void CertProvisioningCertDeleter::DeleteCerts(
} }
void CertProvisioningCertDeleter::OnGetCertsWithIdsDone( void CertProvisioningCertDeleter::OnGetCertsWithIdsDone(
std::map<std::string, scoped_refptr<net::X509Certificate>> certs_with_ids, base::flat_map<CertProfileId, scoped_refptr<net::X509Certificate>>
certs_with_ids,
const std::string& error_message) { const std::string& error_message) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
...@@ -146,7 +149,7 @@ void CertProvisioningCertDeleter::OnGetCertsWithIdsDone( ...@@ -146,7 +149,7 @@ void CertProvisioningCertDeleter::OnGetCertsWithIdsDone(
wait_counter_ = certs_with_ids.size(); wait_counter_ = certs_with_ids.size();
for (const auto& kv : certs_with_ids) { for (const auto& kv : certs_with_ids) {
const std::string& cert_id = kv.first; const CertProfileId& cert_id = kv.first;
if (base::Contains(cert_profile_ids_to_keep_, cert_id)) { if (base::Contains(cert_profile_ids_to_keep_, cert_id)) {
AccountOneResult(); AccountOneResult();
continue; continue;
......
...@@ -6,6 +6,8 @@ ...@@ -6,6 +6,8 @@
#define CHROME_BROWSER_CHROMEOS_CERT_PROVISIONING_CERT_PROVISIONING_PLATFORM_KEYS_HELPERS_H_ #define CHROME_BROWSER_CHROMEOS_CERT_PROVISIONING_CERT_PROVISIONING_PLATFORM_KEYS_HELPERS_H_
#include "base/callback.h" #include "base/callback.h"
#include "base/containers/flat_map.h"
#include "base/containers/flat_set.h"
#include "base/memory/weak_ptr.h" #include "base/memory/weak_ptr.h"
#include "chrome/browser/chromeos/cert_provisioning/cert_provisioning_common.h" #include "chrome/browser/chromeos/cert_provisioning/cert_provisioning_common.h"
#include "net/cert/x509_certificate.h" #include "net/cert/x509_certificate.h"
...@@ -22,7 +24,8 @@ namespace cert_provisioning { ...@@ -22,7 +24,8 @@ namespace cert_provisioning {
// ========= CertProvisioningCertsWithIdsGetter ================================ // ========= CertProvisioningCertsWithIdsGetter ================================
using GetCertsWithIdsCallback = base::OnceCallback<void( using GetCertsWithIdsCallback = base::OnceCallback<void(
std::map<std::string, scoped_refptr<net::X509Certificate>> certs_with_ids, base::flat_map<CertProfileId, scoped_refptr<net::X509Certificate>>
certs_with_ids,
const std::string& error_message)>; const std::string& error_message)>;
// Helper class that retrieves list of all certificates in a given scope with // Helper class that retrieves list of all certificates in a given scope with
...@@ -49,14 +52,15 @@ class CertProvisioningCertsWithIdsGetter { ...@@ -49,14 +52,15 @@ class CertProvisioningCertsWithIdsGetter {
const std::string& error_message); const std::string& error_message);
void CollectOneResult(scoped_refptr<net::X509Certificate> cert, void CollectOneResult(scoped_refptr<net::X509Certificate> cert,
const std::string& cert_id, const CertProfileId& cert_id,
const std::string& error_message); const std::string& error_message);
CertScope cert_scope_ = CertScope::kDevice; CertScope cert_scope_ = CertScope::kDevice;
platform_keys::PlatformKeysService* platform_keys_service_ = nullptr; platform_keys::PlatformKeysService* platform_keys_service_ = nullptr;
size_t wait_counter_ = 0; size_t wait_counter_ = 0;
std::map<std::string, scoped_refptr<net::X509Certificate>> certs_with_ids_; base::flat_map<CertProfileId, scoped_refptr<net::X509Certificate>>
certs_with_ids_;
GetCertsWithIdsCallback callback_; GetCertsWithIdsCallback callback_;
SEQUENCE_CHECKER(sequence_checker_); SEQUENCE_CHECKER(sequence_checker_);
...@@ -81,12 +85,13 @@ class CertProvisioningCertDeleter { ...@@ -81,12 +85,13 @@ class CertProvisioningCertDeleter {
void DeleteCerts(CertScope cert_scope, void DeleteCerts(CertScope cert_scope,
platform_keys::PlatformKeysService* platform_keys_service, platform_keys::PlatformKeysService* platform_keys_service,
const std::set<std::string>& cert_profile_ids_to_keep, base::flat_set<CertProfileId> cert_profile_ids_to_keep,
DeleteCertsCallback callback); DeleteCertsCallback callback);
private: private:
void OnGetCertsWithIdsDone( void OnGetCertsWithIdsDone(
std::map<std::string, scoped_refptr<net::X509Certificate>> certs_with_ids, base::flat_map<CertProfileId, scoped_refptr<net::X509Certificate>>
certs_with_ids,
const std::string& error_message); const std::string& error_message);
void OnRemoveCertificateDone(const std::string& error_message); void OnRemoveCertificateDone(const std::string& error_message);
...@@ -99,7 +104,7 @@ class CertProvisioningCertDeleter { ...@@ -99,7 +104,7 @@ class CertProvisioningCertDeleter {
platform_keys::PlatformKeysService* platform_keys_service_ = nullptr; platform_keys::PlatformKeysService* platform_keys_service_ = nullptr;
size_t wait_counter_ = 0; size_t wait_counter_ = 0;
std::set<std::string> cert_profile_ids_to_keep_; base::flat_set<CertProfileId> cert_profile_ids_to_keep_;
DeleteCertsCallback callback_; DeleteCertsCallback callback_;
std::unique_ptr<CertProvisioningCertsWithIdsGetter> cert_getter_; std::unique_ptr<CertProvisioningCertsWithIdsGetter> cert_getter_;
......
...@@ -5,9 +5,10 @@ ...@@ -5,9 +5,10 @@
#ifndef CHROME_BROWSER_CHROMEOS_CERT_PROVISIONING_CERT_PROVISIONING_SCHEDULER_H_ #ifndef CHROME_BROWSER_CHROMEOS_CERT_PROVISIONING_CERT_PROVISIONING_SCHEDULER_H_
#define CHROME_BROWSER_CHROMEOS_CERT_PROVISIONING_CERT_PROVISIONING_SCHEDULER_H_ #define CHROME_BROWSER_CHROMEOS_CERT_PROVISIONING_CERT_PROVISIONING_SCHEDULER_H_
#include <map> #include <vector>
#include <set>
#include "base/containers/flat_map.h"
#include "base/containers/flat_set.h"
#include "base/memory/weak_ptr.h" #include "base/memory/weak_ptr.h"
#include "base/sequence_checker.h" #include "base/sequence_checker.h"
#include "base/time/time.h" #include "base/time/time.h"
...@@ -39,6 +40,8 @@ class CertProvisioningWorker; ...@@ -39,6 +40,8 @@ class CertProvisioningWorker;
using WorkerMap = using WorkerMap =
std::map<CertProfileId, std::unique_ptr<CertProvisioningWorker>>; std::map<CertProfileId, std::unique_ptr<CertProvisioningWorker>>;
using CertProfileSet = base::flat_set<CertProfile, CertProfileComparator>;
struct FailedWorkerInfo { struct FailedWorkerInfo {
CertProvisioningWorkerState state = CertProvisioningWorkerState::kInitState; CertProvisioningWorkerState state = CertProvisioningWorkerState::kInitState;
std::string public_key; std::string public_key;
...@@ -75,14 +78,15 @@ class CertProvisioningScheduler : public NetworkStateHandlerObserver { ...@@ -75,14 +78,15 @@ class CertProvisioningScheduler : public NetworkStateHandlerObserver {
delete; delete;
// Intended to be called when a user presses a button in certificate manager // Intended to be called when a user presses a button in certificate manager
// UI. Retries provisioning of a specific certificate. // UI. Retries provisioning of a specific certificate.
void UpdateOneCert(const std::string& cert_profile_id); void UpdateOneCert(const CertProfileId& cert_profile_id);
void UpdateCerts(); void UpdateAllCerts();
void OnProfileFinished(const CertProfile& profile, void OnProfileFinished(const CertProfile& profile,
CertProvisioningWorkerState state); CertProvisioningWorkerState state);
const WorkerMap& GetWorkers() const; const WorkerMap& GetWorkers() const;
const std::map<std::string, FailedWorkerInfo>& GetFailedCertProfileIds()
const; const base::flat_map<CertProfileId, FailedWorkerInfo>&
GetFailedCertProfileIds() const;
private: private:
void ScheduleInitialUpdate(); void ScheduleInitialUpdate();
...@@ -98,30 +102,33 @@ class CertProvisioningScheduler : public NetworkStateHandlerObserver { ...@@ -98,30 +102,33 @@ class CertProvisioningScheduler : public NetworkStateHandlerObserver {
void OnCleanVaKeysIfIdleDone(base::Optional<bool> delete_result); void OnCleanVaKeysIfIdleDone(base::Optional<bool> delete_result);
void RegisterForPrefsChanges(); void RegisterForPrefsChanges();
void UpdateOneCertImpl(const std::string& cert_profile_id); void UpdateOneCertImpl(const CertProfileId& cert_profile_id);
void UpdateCertList(std::vector<CertProfile> profiles);
void UpdateCertListWithExistingCerts(
std::vector<CertProfile> profiles,
base::flat_map<CertProfileId, scoped_refptr<net::X509Certificate>>
existing_certs_with_ids,
const std::string& error_message);
void OnPrefsChange(); void OnPrefsChange();
void DailyUpdateCerts(); void DailyUpdateCerts();
void DeserializeWorkers(); void DeserializeWorkers();
void OnGetCertsWithIdsDone(
std::map<std::string, scoped_refptr<net::X509Certificate>>
existing_certs_with_ids,
const std::string& error_message);
// Creates a new worker for |profile| if there is no at the moment. // Creates a new worker for |profile| if there is no at the moment.
// Recreates a worker if existing one has a different version of the profile. // Recreates a worker if existing one has a different version of the profile.
// Continues an existing worker if it is in a waiting state. // Continues an existing worker if it is in a waiting state.
void ProcessProfile(const CertProfile& profile); void ProcessProfile(const CertProfile& profile);
base::Optional<CertProfile> GetOneCertProfile( base::Optional<CertProfile> GetOneCertProfile(
const std::string& cert_profile_id); const CertProfileId& cert_profile_id);
std::vector<CertProfile> GetCertProfiles(); std::vector<CertProfile> GetCertProfiles();
void CreateCertProvisioningWorker(CertProfile profile); void CreateCertProvisioningWorker(CertProfile profile);
CertProvisioningWorker* FindWorker(CertProfileId profile_id); CertProvisioningWorker* FindWorker(CertProfileId profile_id);
bool CheckInternetConnection(); // Returns true if the process can be continued (if it's not required to
// wait).
bool MaybeWaitForInternetConnection();
void WaitForInternetConnection(); void WaitForInternetConnection();
void OnNetworkChange(const NetworkState* network); void OnNetworkChange(const NetworkState* network);
// NetworkStateHandlerObserver // NetworkStateHandlerObserver
...@@ -143,12 +150,15 @@ class CertProvisioningScheduler : public NetworkStateHandlerObserver { ...@@ -143,12 +150,15 @@ class CertProvisioningScheduler : public NetworkStateHandlerObserver {
// retried until next |DailyUpdateCerts|. FailedWorkerInfo contains some extra // retried until next |DailyUpdateCerts|. FailedWorkerInfo contains some extra
// information about the failure. Profiles that failed with // information about the failure. Profiles that failed with
// kInconsistentDataError will not be stored into this collection. // kInconsistentDataError will not be stored into this collection.
std::map<std::string /*cert_profile_id*/, FailedWorkerInfo> base::flat_map<CertProfileId, FailedWorkerInfo> failed_cert_profiles_;
failed_cert_profiles_;
// Equals true if the last attempt to update certificates failed because there // Equals true if the last attempt to update certificates failed because there
// was no internet connection. // was no internet connection.
bool is_waiting_for_online_ = false; bool is_waiting_for_online_ = false;
// Contains profiles that should be updated after the current update batch
// run, because an update for them was triggered during the current run.
CertProfileSet queued_profiles_to_update_;
std::unique_ptr<CertProvisioningCertsWithIdsGetter> certs_with_ids_getter_; std::unique_ptr<CertProvisioningCertsWithIdsGetter> certs_with_ids_getter_;
std::unique_ptr<CertProvisioningCertDeleter> cert_deleter_; std::unique_ptr<CertProvisioningCertDeleter> cert_deleter_;
std::unique_ptr<CertProvisioningInvalidatorFactory> invalidator_factory_; std::unique_ptr<CertProvisioningInvalidatorFactory> invalidator_factory_;
......
...@@ -262,7 +262,7 @@ TEST_F(CertProvisioningSchedulerTest, Success) { ...@@ -262,7 +262,7 @@ TEST_F(CertProvisioningSchedulerTest, Success) {
// Check one more time that scheduler doesn't create new workers for // Check one more time that scheduler doesn't create new workers for
// finished certificate profiles (the factory will fail on an attempt to // finished certificate profiles (the factory will fail on an attempt to
// do so). // do so).
scheduler.UpdateCerts(); scheduler.UpdateAllCerts();
FastForwardBy(base::TimeDelta::FromSeconds(100)); FastForwardBy(base::TimeDelta::FromSeconds(100));
} }
...@@ -331,7 +331,7 @@ TEST_F(CertProvisioningSchedulerTest, WorkerFailed) { ...@@ -331,7 +331,7 @@ TEST_F(CertProvisioningSchedulerTest, WorkerFailed) {
// Check one more time that scheduler doesn't create new workers for failed // Check one more time that scheduler doesn't create new workers for failed
// certificate profiles (the factory will fail on an attempt to do so). // certificate profiles (the factory will fail on an attempt to do so).
scheduler.UpdateCerts(); scheduler.UpdateAllCerts();
} }
TEST_F(CertProvisioningSchedulerTest, InitialAndDailyUpdates) { TEST_F(CertProvisioningSchedulerTest, InitialAndDailyUpdates) {
...@@ -497,7 +497,7 @@ TEST_F(CertProvisioningSchedulerTest, MultipleWorkers) { ...@@ -497,7 +497,7 @@ TEST_F(CertProvisioningSchedulerTest, MultipleWorkers) {
.WillOnce(base::test::RunOnceCallback<3>(kCertProfileId0, "")); .WillOnce(base::test::RunOnceCallback<3>(kCertProfileId0, ""));
// Make scheduler check workers state. // Make scheduler check workers state.
scheduler.UpdateCerts(); scheduler.UpdateAllCerts();
EXPECT_EQ(scheduler.GetWorkers().size(), 1U); EXPECT_EQ(scheduler.GetWorkers().size(), 1U);
EXPECT_TRUE( EXPECT_TRUE(
...@@ -514,7 +514,7 @@ TEST_F(CertProvisioningSchedulerTest, MultipleWorkers) { ...@@ -514,7 +514,7 @@ TEST_F(CertProvisioningSchedulerTest, MultipleWorkers) {
// Check one more time that scheduler doesn't create new workers for failed // Check one more time that scheduler doesn't create new workers for failed
// certificate profiles (the factory will fail on an attempt to do so). // certificate profiles (the factory will fail on an attempt to do so).
scheduler.UpdateCerts(); scheduler.UpdateAllCerts();
EXPECT_EQ(scheduler.GetWorkers().size(), 1U); EXPECT_EQ(scheduler.GetWorkers().size(), 1U);
} }
...@@ -685,7 +685,7 @@ TEST_F(CertProvisioningSchedulerTest, InconsistentDataErrorHandling) { ...@@ -685,7 +685,7 @@ TEST_F(CertProvisioningSchedulerTest, InconsistentDataErrorHandling) {
// If another update happens, workers with matching policy versions should not // If another update happens, workers with matching policy versions should not
// be deleted. // be deleted.
scheduler.UpdateCerts(); scheduler.UpdateAllCerts();
EXPECT_EQ(scheduler.GetWorkers().size(), 1U); EXPECT_EQ(scheduler.GetWorkers().size(), 1U);
// On policy update if existing profile has changed its policy_version, // On policy update if existing profile has changed its policy_version,
...@@ -872,6 +872,103 @@ TEST_F(CertProvisioningSchedulerTest, DeleteVaKeysOnIdle) { ...@@ -872,6 +872,103 @@ TEST_F(CertProvisioningSchedulerTest, DeleteVaKeysOnIdle) {
} }
} }
TEST_F(CertProvisioningSchedulerTest, UpdateOneCert) {
CertScope cert_scope = CertScope::kUser;
CertProvisioningScheduler scheduler(
cert_scope, GetProfile(), &pref_service_,
prefs::kRequiredClientCertificateForUser, &cloud_policy_client_,
network_state_test_helper_.network_state_handler(),
MakeFakeInvalidationFactory());
const char kCertProfileId[] = "cert_profile_id_1";
const char kCertProfileVersion[] = "cert_profile_version_1";
CertProfile cert_profile{kCertProfileId, kCertProfileVersion};
// From CertProvisioningScheduler::CleanVaKeysIfIdle.
EXPECT_CALL(fake_cryptohome_client_, OnTpmAttestationDeleteKeysByPrefix);
// There is no policies yet, |kCertProfileId| will not be found.
scheduler.UpdateOneCert(kCertProfileId);
FastForwardBy(base::TimeDelta::FromSeconds(1));
ASSERT_TRUE(scheduler.GetWorkers().empty());
MockCertProvisioningWorker* worker =
mock_factory_.ExpectCreateReturnMock(cert_scope, cert_profile);
worker->SetExpectations(/*do_step_times=*/Exactly(1),
/*is_waiting=*/false, cert_profile);
// Add 1 certificate profile to the policy ("cert_profile_id" is the same as
// above). That will trigger creation of a worker.
base::Value config = ParseJson(
R"([{"name": "Certificate Profile 1",
"cert_profile_id":"cert_profile_id_1",
"policy_version":"cert_profile_version_1",
"key_algorithm":"rsa",
"renewal_period_seconds": 365000}])");
pref_service_.Set(prefs::kRequiredClientCertificateForUser, config);
// If worker is waiting, it should be continued.
{
worker->SetExpectations(/*do_step_times=*/Exactly(1),
/*is_waiting=*/true, cert_profile);
scheduler.UpdateOneCert(kCertProfileId);
FastForwardBy(base::TimeDelta::FromSeconds(1));
ASSERT_EQ(scheduler.GetWorkers().size(), 1U);
}
// If worker is not waiting, it should not be continued.
{
worker->SetExpectations(/*do_step_times=*/Exactly(0),
/*is_waiting=*/false, cert_profile);
scheduler.UpdateOneCert(kCertProfileId);
FastForwardBy(base::TimeDelta::FromSeconds(1));
ASSERT_EQ(scheduler.GetWorkers().size(), 1U);
}
// If there is no intenet connection, the worker should not be continued
// until it is restored.
{
SetWifiNetworkState(shill::kStateIdle);
worker->SetExpectations(/*do_step_times=*/Exactly(0),
/*is_waiting=*/true, cert_profile);
scheduler.UpdateOneCert(kCertProfileId);
FastForwardBy(base::TimeDelta::FromSeconds(1));
ASSERT_EQ(scheduler.GetWorkers().size(), 1U);
worker->SetExpectations(/*do_step_times=*/Exactly(1),
/*is_waiting=*/true, cert_profile);
SetWifiNetworkState(shill::kStateOnline);
}
// Emulate callback from the worker.
scheduler.OnProfileFinished(cert_profile,
CertProvisioningWorkerState::kSucceeded);
FastForwardBy(base::TimeDelta::FromSeconds(1));
ASSERT_TRUE(scheduler.GetWorkers().empty());
certificate_helper_.AddCert();
EXPECT_CALL(
*platform_keys_service_,
GetAttributeForKey(
GetPlatformKeysTokenId(cert_scope),
certificate_helper_.GetPublicKeyForCert(),
platform_keys::KeyAttributeType::CertificateProvisioningId, _))
.Times(1)
.WillOnce(base::test::RunOnceCallback<3>(kCertProfileId, ""));
{
// If a certificate already exists, a new worker should not be created.
scheduler.UpdateOneCert(kCertProfileId);
FastForwardBy(base::TimeDelta::FromSeconds(1));
ASSERT_TRUE(scheduler.GetWorkers().empty());
}
}
} // namespace } // namespace
} // namespace cert_provisioning } // namespace cert_provisioning
} // namespace chromeos } // namespace chromeos
...@@ -213,27 +213,32 @@ CertProvisioningWorkerImpl::~CertProvisioningWorkerImpl() = default; ...@@ -213,27 +213,32 @@ CertProvisioningWorkerImpl::~CertProvisioningWorkerImpl() = default;
bool CertProvisioningWorkerImpl::IsWaiting() const { bool CertProvisioningWorkerImpl::IsWaiting() const {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
return is_waiting_; return is_waiting_;
} }
const CertProfile& CertProvisioningWorkerImpl::GetCertProfile() const { const CertProfile& CertProvisioningWorkerImpl::GetCertProfile() const {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
return cert_profile_; return cert_profile_;
} }
const std::string& CertProvisioningWorkerImpl::GetPublicKey() const { const std::string& CertProvisioningWorkerImpl::GetPublicKey() const {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
return public_key_; return public_key_;
} }
CertProvisioningWorkerState CertProvisioningWorkerImpl::GetState() const { CertProvisioningWorkerState CertProvisioningWorkerImpl::GetState() const {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
return state_; return state_;
} }
CertProvisioningWorkerState CertProvisioningWorkerImpl::GetPreviousState() CertProvisioningWorkerState CertProvisioningWorkerImpl::GetPreviousState()
const { const {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
return prev_state_; return prev_state_;
} }
...@@ -243,6 +248,7 @@ base::Time CertProvisioningWorkerImpl::GetLastUpdateTime() const { ...@@ -243,6 +248,7 @@ base::Time CertProvisioningWorkerImpl::GetLastUpdateTime() const {
void CertProvisioningWorkerImpl::Stop(CertProvisioningWorkerState state) { void CertProvisioningWorkerImpl::Stop(CertProvisioningWorkerState state) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
DCHECK(IsFinalState(state)); DCHECK(IsFinalState(state));
CancelScheduledTasks(); CancelScheduledTasks();
...@@ -251,6 +257,7 @@ void CertProvisioningWorkerImpl::Stop(CertProvisioningWorkerState state) { ...@@ -251,6 +257,7 @@ void CertProvisioningWorkerImpl::Stop(CertProvisioningWorkerState state) {
void CertProvisioningWorkerImpl::Pause() { void CertProvisioningWorkerImpl::Pause() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
CancelScheduledTasks(); CancelScheduledTasks();
is_waiting_ = true; is_waiting_ = true;
} }
...@@ -300,6 +307,7 @@ void CertProvisioningWorkerImpl::DoStep() { ...@@ -300,6 +307,7 @@ void CertProvisioningWorkerImpl::DoStep() {
void CertProvisioningWorkerImpl::UpdateState( void CertProvisioningWorkerImpl::UpdateState(
CertProvisioningWorkerState new_state) { CertProvisioningWorkerState new_state) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
DCHECK(GetStateOrderedIndex(state_) < GetStateOrderedIndex(new_state)); DCHECK(GetStateOrderedIndex(state_) < GetStateOrderedIndex(new_state));
prev_state_ = state_; prev_state_ = state_;
...@@ -482,6 +490,7 @@ void CertProvisioningWorkerImpl::OnBuildVaChallengeResponseDone( ...@@ -482,6 +490,7 @@ void CertProvisioningWorkerImpl::OnBuildVaChallengeResponseDone(
void CertProvisioningWorkerImpl::RegisterKey() { void CertProvisioningWorkerImpl::RegisterKey() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
tpm_challenge_key_subtle_impl_->StartRegisterKeyStep( tpm_challenge_key_subtle_impl_->StartRegisterKeyStep(
base::BindOnce(&CertProvisioningWorkerImpl::OnRegisterKeyDone, base::BindOnce(&CertProvisioningWorkerImpl::OnRegisterKeyDone,
weak_factory_.GetWeakPtr())); weak_factory_.GetWeakPtr()));
...@@ -734,6 +743,7 @@ void CertProvisioningWorkerImpl::OnShouldContinue(ContinueReason reason) { ...@@ -734,6 +743,7 @@ void CertProvisioningWorkerImpl::OnShouldContinue(ContinueReason reason) {
void CertProvisioningWorkerImpl::CancelScheduledTasks() { void CertProvisioningWorkerImpl::CancelScheduledTasks() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
weak_factory_.InvalidateWeakPtrs(); weak_factory_.InvalidateWeakPtrs();
} }
...@@ -800,6 +810,7 @@ void CertProvisioningWorkerImpl::OnRemoveKeyDone( ...@@ -800,6 +810,7 @@ void CertProvisioningWorkerImpl::OnRemoveKeyDone(
void CertProvisioningWorkerImpl::OnCleanUpDone() { void CertProvisioningWorkerImpl::OnCleanUpDone() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
RecordResult(cert_scope_, state_, prev_state_); RecordResult(cert_scope_, state_, prev_state_);
std::move(callback_).Run(cert_profile_, state_); std::move(callback_).Run(cert_profile_, state_);
} }
...@@ -853,6 +864,7 @@ void CertProvisioningWorkerImpl::InitAfterDeserialization() { ...@@ -853,6 +864,7 @@ void CertProvisioningWorkerImpl::InitAfterDeserialization() {
void CertProvisioningWorkerImpl::RegisterForInvalidationTopic() { void CertProvisioningWorkerImpl::RegisterForInvalidationTopic() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
DCHECK(invalidator_); DCHECK(invalidator_);
// Can be empty after deserialization if no topic was received yet. Also // Can be empty after deserialization if no topic was received yet. Also
...@@ -873,6 +885,7 @@ void CertProvisioningWorkerImpl::RegisterForInvalidationTopic() { ...@@ -873,6 +885,7 @@ void CertProvisioningWorkerImpl::RegisterForInvalidationTopic() {
void CertProvisioningWorkerImpl::UnregisterFromInvalidationTopic() { void CertProvisioningWorkerImpl::UnregisterFromInvalidationTopic() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
DCHECK(invalidator_); DCHECK(invalidator_);
invalidator_->Unregister(); invalidator_->Unregister();
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment