Commit 075d9579 authored by Michael Ershov's avatar Michael Ershov Committed by Commit Bot

Cert Provisioning: Check existing certs in UpdateOneCert

Check existing certs in UpdateOneCert to avoid creating a worker
for certificate profile that already has a provisioned certificate.
Also add a unit test for UpdateOneCert method.

Bug: 1045895
Test: CertProvisioning*
Change-Id: I284bec0c90766673b6ba3a2c1ac69f28303258ce
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2245150
Commit-Queue: Michael Ershov <miersh@google.com>
Reviewed-by: default avatarPavol Marko <pmarko@chromium.org>
Cr-Commit-Position: refs/heads/master@{#781407}
parent 3530c216
......@@ -88,6 +88,14 @@ bool CertProfile::operator!=(const CertProfile& other) const {
return !(*this == other);
}
bool CertProfileComparator::operator()(const CertProfile& a,
const CertProfile& b) const {
static_assert(CertProfile::kVersion == 3, "This function should be updated");
return ((a.profile_id < b.profile_id) ||
(a.policy_version < b.policy_version) ||
(a.is_va_enabled < b.is_va_enabled));
}
//==============================================================================
void RegisterProfilePrefs(PrefRegistrySimple* registry) {
......
......@@ -83,6 +83,10 @@ struct CertProfile {
bool operator!=(const CertProfile& other) const;
};
struct CertProfileComparator {
bool operator()(const CertProfile& a, const CertProfile& b) const;
};
void RegisterProfilePrefs(PrefRegistrySimple* registry);
void RegisterLocalStatePrefs(PrefRegistrySimple* registry);
const char* GetPrefNameForSerialization(CertScope scope);
......
......@@ -3,10 +3,12 @@
// found in the LICENSE file.
#include "chrome/browser/chromeos/cert_provisioning/cert_provisioning_platform_keys_helpers.h"
#include <memory>
#include "base/bind.h"
#include "base/check.h"
#include "base/containers/flat_set.h"
#include "base/stl_util.h"
#include "chrome/browser/chromeos/cert_provisioning/cert_provisioning_common.h"
#include "chrome/browser/chromeos/platform_keys/platform_keys_service.h"
......@@ -74,7 +76,7 @@ void CertProvisioningCertsWithIdsGetter::OnGetCertificatesDone(
void CertProvisioningCertsWithIdsGetter::CollectOneResult(
scoped_refptr<net::X509Certificate> cert,
const std::string& cert_id,
const CertProfileId& cert_id,
const std::string& error_message) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
DCHECK(wait_counter_ > 0);
......@@ -110,7 +112,7 @@ CertProvisioningCertDeleter::~CertProvisioningCertDeleter() = default;
void CertProvisioningCertDeleter::DeleteCerts(
CertScope cert_scope,
platform_keys::PlatformKeysService* platform_keys_service,
const std::set<std::string>& cert_profile_ids_to_keep,
base::flat_set<CertProfileId> cert_profile_ids_to_keep,
DeleteCertsCallback callback) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
DCHECK(platform_keys_service);
......@@ -119,7 +121,7 @@ void CertProvisioningCertDeleter::DeleteCerts(
cert_scope_ = cert_scope;
platform_keys_service_ = platform_keys_service;
callback_ = std::move(callback);
cert_profile_ids_to_keep_ = cert_profile_ids_to_keep;
cert_profile_ids_to_keep_ = std::move(cert_profile_ids_to_keep);
cert_getter_ = std::make_unique<CertProvisioningCertsWithIdsGetter>();
cert_getter_->GetCertsWithIds(
......@@ -129,7 +131,8 @@ void CertProvisioningCertDeleter::DeleteCerts(
}
void CertProvisioningCertDeleter::OnGetCertsWithIdsDone(
std::map<std::string, scoped_refptr<net::X509Certificate>> certs_with_ids,
base::flat_map<CertProfileId, scoped_refptr<net::X509Certificate>>
certs_with_ids,
const std::string& error_message) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
......@@ -146,7 +149,7 @@ void CertProvisioningCertDeleter::OnGetCertsWithIdsDone(
wait_counter_ = certs_with_ids.size();
for (const auto& kv : certs_with_ids) {
const std::string& cert_id = kv.first;
const CertProfileId& cert_id = kv.first;
if (base::Contains(cert_profile_ids_to_keep_, cert_id)) {
AccountOneResult();
continue;
......
......@@ -6,6 +6,8 @@
#define CHROME_BROWSER_CHROMEOS_CERT_PROVISIONING_CERT_PROVISIONING_PLATFORM_KEYS_HELPERS_H_
#include "base/callback.h"
#include "base/containers/flat_map.h"
#include "base/containers/flat_set.h"
#include "base/memory/weak_ptr.h"
#include "chrome/browser/chromeos/cert_provisioning/cert_provisioning_common.h"
#include "net/cert/x509_certificate.h"
......@@ -22,7 +24,8 @@ namespace cert_provisioning {
// ========= CertProvisioningCertsWithIdsGetter ================================
using GetCertsWithIdsCallback = base::OnceCallback<void(
std::map<std::string, scoped_refptr<net::X509Certificate>> certs_with_ids,
base::flat_map<CertProfileId, scoped_refptr<net::X509Certificate>>
certs_with_ids,
const std::string& error_message)>;
// Helper class that retrieves list of all certificates in a given scope with
......@@ -49,14 +52,15 @@ class CertProvisioningCertsWithIdsGetter {
const std::string& error_message);
void CollectOneResult(scoped_refptr<net::X509Certificate> cert,
const std::string& cert_id,
const CertProfileId& cert_id,
const std::string& error_message);
CertScope cert_scope_ = CertScope::kDevice;
platform_keys::PlatformKeysService* platform_keys_service_ = nullptr;
size_t wait_counter_ = 0;
std::map<std::string, scoped_refptr<net::X509Certificate>> certs_with_ids_;
base::flat_map<CertProfileId, scoped_refptr<net::X509Certificate>>
certs_with_ids_;
GetCertsWithIdsCallback callback_;
SEQUENCE_CHECKER(sequence_checker_);
......@@ -81,12 +85,13 @@ class CertProvisioningCertDeleter {
void DeleteCerts(CertScope cert_scope,
platform_keys::PlatformKeysService* platform_keys_service,
const std::set<std::string>& cert_profile_ids_to_keep,
base::flat_set<CertProfileId> cert_profile_ids_to_keep,
DeleteCertsCallback callback);
private:
void OnGetCertsWithIdsDone(
std::map<std::string, scoped_refptr<net::X509Certificate>> certs_with_ids,
base::flat_map<CertProfileId, scoped_refptr<net::X509Certificate>>
certs_with_ids,
const std::string& error_message);
void OnRemoveCertificateDone(const std::string& error_message);
......@@ -99,7 +104,7 @@ class CertProvisioningCertDeleter {
platform_keys::PlatformKeysService* platform_keys_service_ = nullptr;
size_t wait_counter_ = 0;
std::set<std::string> cert_profile_ids_to_keep_;
base::flat_set<CertProfileId> cert_profile_ids_to_keep_;
DeleteCertsCallback callback_;
std::unique_ptr<CertProvisioningCertsWithIdsGetter> cert_getter_;
......
......@@ -5,9 +5,10 @@
#ifndef CHROME_BROWSER_CHROMEOS_CERT_PROVISIONING_CERT_PROVISIONING_SCHEDULER_H_
#define CHROME_BROWSER_CHROMEOS_CERT_PROVISIONING_CERT_PROVISIONING_SCHEDULER_H_
#include <map>
#include <set>
#include <vector>
#include "base/containers/flat_map.h"
#include "base/containers/flat_set.h"
#include "base/memory/weak_ptr.h"
#include "base/sequence_checker.h"
#include "base/time/time.h"
......@@ -39,6 +40,8 @@ class CertProvisioningWorker;
using WorkerMap =
std::map<CertProfileId, std::unique_ptr<CertProvisioningWorker>>;
using CertProfileSet = base::flat_set<CertProfile, CertProfileComparator>;
struct FailedWorkerInfo {
CertProvisioningWorkerState state = CertProvisioningWorkerState::kInitState;
std::string public_key;
......@@ -75,14 +78,15 @@ class CertProvisioningScheduler : public NetworkStateHandlerObserver {
delete;
// Intended to be called when a user presses a button in certificate manager
// UI. Retries provisioning of a specific certificate.
void UpdateOneCert(const std::string& cert_profile_id);
void UpdateCerts();
void UpdateOneCert(const CertProfileId& cert_profile_id);
void UpdateAllCerts();
void OnProfileFinished(const CertProfile& profile,
CertProvisioningWorkerState state);
const WorkerMap& GetWorkers() const;
const std::map<std::string, FailedWorkerInfo>& GetFailedCertProfileIds()
const;
const base::flat_map<CertProfileId, FailedWorkerInfo>&
GetFailedCertProfileIds() const;
private:
void ScheduleInitialUpdate();
......@@ -98,30 +102,33 @@ class CertProvisioningScheduler : public NetworkStateHandlerObserver {
void OnCleanVaKeysIfIdleDone(base::Optional<bool> delete_result);
void RegisterForPrefsChanges();
void UpdateOneCertImpl(const std::string& cert_profile_id);
void UpdateOneCertImpl(const CertProfileId& cert_profile_id);
void UpdateCertList(std::vector<CertProfile> profiles);
void UpdateCertListWithExistingCerts(
std::vector<CertProfile> profiles,
base::flat_map<CertProfileId, scoped_refptr<net::X509Certificate>>
existing_certs_with_ids,
const std::string& error_message);
void OnPrefsChange();
void DailyUpdateCerts();
void DeserializeWorkers();
void OnGetCertsWithIdsDone(
std::map<std::string, scoped_refptr<net::X509Certificate>>
existing_certs_with_ids,
const std::string& error_message);
// Creates a new worker for |profile| if there is no at the moment.
// Recreates a worker if existing one has a different version of the profile.
// Continues an existing worker if it is in a waiting state.
void ProcessProfile(const CertProfile& profile);
base::Optional<CertProfile> GetOneCertProfile(
const std::string& cert_profile_id);
const CertProfileId& cert_profile_id);
std::vector<CertProfile> GetCertProfiles();
void CreateCertProvisioningWorker(CertProfile profile);
CertProvisioningWorker* FindWorker(CertProfileId profile_id);
bool CheckInternetConnection();
// Returns true if the process can be continued (if it's not required to
// wait).
bool MaybeWaitForInternetConnection();
void WaitForInternetConnection();
void OnNetworkChange(const NetworkState* network);
// NetworkStateHandlerObserver
......@@ -143,12 +150,15 @@ class CertProvisioningScheduler : public NetworkStateHandlerObserver {
// retried until next |DailyUpdateCerts|. FailedWorkerInfo contains some extra
// information about the failure. Profiles that failed with
// kInconsistentDataError will not be stored into this collection.
std::map<std::string /*cert_profile_id*/, FailedWorkerInfo>
failed_cert_profiles_;
base::flat_map<CertProfileId, FailedWorkerInfo> failed_cert_profiles_;
// Equals true if the last attempt to update certificates failed because there
// was no internet connection.
bool is_waiting_for_online_ = false;
// Contains profiles that should be updated after the current update batch
// run, because an update for them was triggered during the current run.
CertProfileSet queued_profiles_to_update_;
std::unique_ptr<CertProvisioningCertsWithIdsGetter> certs_with_ids_getter_;
std::unique_ptr<CertProvisioningCertDeleter> cert_deleter_;
std::unique_ptr<CertProvisioningInvalidatorFactory> invalidator_factory_;
......
......@@ -262,7 +262,7 @@ TEST_F(CertProvisioningSchedulerTest, Success) {
// Check one more time that scheduler doesn't create new workers for
// finished certificate profiles (the factory will fail on an attempt to
// do so).
scheduler.UpdateCerts();
scheduler.UpdateAllCerts();
FastForwardBy(base::TimeDelta::FromSeconds(100));
}
......@@ -331,7 +331,7 @@ TEST_F(CertProvisioningSchedulerTest, WorkerFailed) {
// Check one more time that scheduler doesn't create new workers for failed
// certificate profiles (the factory will fail on an attempt to do so).
scheduler.UpdateCerts();
scheduler.UpdateAllCerts();
}
TEST_F(CertProvisioningSchedulerTest, InitialAndDailyUpdates) {
......@@ -497,7 +497,7 @@ TEST_F(CertProvisioningSchedulerTest, MultipleWorkers) {
.WillOnce(base::test::RunOnceCallback<3>(kCertProfileId0, ""));
// Make scheduler check workers state.
scheduler.UpdateCerts();
scheduler.UpdateAllCerts();
EXPECT_EQ(scheduler.GetWorkers().size(), 1U);
EXPECT_TRUE(
......@@ -514,7 +514,7 @@ TEST_F(CertProvisioningSchedulerTest, MultipleWorkers) {
// Check one more time that scheduler doesn't create new workers for failed
// certificate profiles (the factory will fail on an attempt to do so).
scheduler.UpdateCerts();
scheduler.UpdateAllCerts();
EXPECT_EQ(scheduler.GetWorkers().size(), 1U);
}
......@@ -685,7 +685,7 @@ TEST_F(CertProvisioningSchedulerTest, InconsistentDataErrorHandling) {
// If another update happens, workers with matching policy versions should not
// be deleted.
scheduler.UpdateCerts();
scheduler.UpdateAllCerts();
EXPECT_EQ(scheduler.GetWorkers().size(), 1U);
// On policy update if existing profile has changed its policy_version,
......@@ -872,6 +872,103 @@ TEST_F(CertProvisioningSchedulerTest, DeleteVaKeysOnIdle) {
}
}
TEST_F(CertProvisioningSchedulerTest, UpdateOneCert) {
CertScope cert_scope = CertScope::kUser;
CertProvisioningScheduler scheduler(
cert_scope, GetProfile(), &pref_service_,
prefs::kRequiredClientCertificateForUser, &cloud_policy_client_,
network_state_test_helper_.network_state_handler(),
MakeFakeInvalidationFactory());
const char kCertProfileId[] = "cert_profile_id_1";
const char kCertProfileVersion[] = "cert_profile_version_1";
CertProfile cert_profile{kCertProfileId, kCertProfileVersion};
// From CertProvisioningScheduler::CleanVaKeysIfIdle.
EXPECT_CALL(fake_cryptohome_client_, OnTpmAttestationDeleteKeysByPrefix);
// There is no policies yet, |kCertProfileId| will not be found.
scheduler.UpdateOneCert(kCertProfileId);
FastForwardBy(base::TimeDelta::FromSeconds(1));
ASSERT_TRUE(scheduler.GetWorkers().empty());
MockCertProvisioningWorker* worker =
mock_factory_.ExpectCreateReturnMock(cert_scope, cert_profile);
worker->SetExpectations(/*do_step_times=*/Exactly(1),
/*is_waiting=*/false, cert_profile);
// Add 1 certificate profile to the policy ("cert_profile_id" is the same as
// above). That will trigger creation of a worker.
base::Value config = ParseJson(
R"([{"name": "Certificate Profile 1",
"cert_profile_id":"cert_profile_id_1",
"policy_version":"cert_profile_version_1",
"key_algorithm":"rsa",
"renewal_period_seconds": 365000}])");
pref_service_.Set(prefs::kRequiredClientCertificateForUser, config);
// If worker is waiting, it should be continued.
{
worker->SetExpectations(/*do_step_times=*/Exactly(1),
/*is_waiting=*/true, cert_profile);
scheduler.UpdateOneCert(kCertProfileId);
FastForwardBy(base::TimeDelta::FromSeconds(1));
ASSERT_EQ(scheduler.GetWorkers().size(), 1U);
}
// If worker is not waiting, it should not be continued.
{
worker->SetExpectations(/*do_step_times=*/Exactly(0),
/*is_waiting=*/false, cert_profile);
scheduler.UpdateOneCert(kCertProfileId);
FastForwardBy(base::TimeDelta::FromSeconds(1));
ASSERT_EQ(scheduler.GetWorkers().size(), 1U);
}
// If there is no intenet connection, the worker should not be continued
// until it is restored.
{
SetWifiNetworkState(shill::kStateIdle);
worker->SetExpectations(/*do_step_times=*/Exactly(0),
/*is_waiting=*/true, cert_profile);
scheduler.UpdateOneCert(kCertProfileId);
FastForwardBy(base::TimeDelta::FromSeconds(1));
ASSERT_EQ(scheduler.GetWorkers().size(), 1U);
worker->SetExpectations(/*do_step_times=*/Exactly(1),
/*is_waiting=*/true, cert_profile);
SetWifiNetworkState(shill::kStateOnline);
}
// Emulate callback from the worker.
scheduler.OnProfileFinished(cert_profile,
CertProvisioningWorkerState::kSucceeded);
FastForwardBy(base::TimeDelta::FromSeconds(1));
ASSERT_TRUE(scheduler.GetWorkers().empty());
certificate_helper_.AddCert();
EXPECT_CALL(
*platform_keys_service_,
GetAttributeForKey(
GetPlatformKeysTokenId(cert_scope),
certificate_helper_.GetPublicKeyForCert(),
platform_keys::KeyAttributeType::CertificateProvisioningId, _))
.Times(1)
.WillOnce(base::test::RunOnceCallback<3>(kCertProfileId, ""));
{
// If a certificate already exists, a new worker should not be created.
scheduler.UpdateOneCert(kCertProfileId);
FastForwardBy(base::TimeDelta::FromSeconds(1));
ASSERT_TRUE(scheduler.GetWorkers().empty());
}
}
} // namespace
} // namespace cert_provisioning
} // namespace chromeos
......@@ -213,27 +213,32 @@ CertProvisioningWorkerImpl::~CertProvisioningWorkerImpl() = default;
bool CertProvisioningWorkerImpl::IsWaiting() const {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
return is_waiting_;
}
const CertProfile& CertProvisioningWorkerImpl::GetCertProfile() const {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
return cert_profile_;
}
const std::string& CertProvisioningWorkerImpl::GetPublicKey() const {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
return public_key_;
}
CertProvisioningWorkerState CertProvisioningWorkerImpl::GetState() const {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
return state_;
}
CertProvisioningWorkerState CertProvisioningWorkerImpl::GetPreviousState()
const {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
return prev_state_;
}
......@@ -243,6 +248,7 @@ base::Time CertProvisioningWorkerImpl::GetLastUpdateTime() const {
void CertProvisioningWorkerImpl::Stop(CertProvisioningWorkerState state) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
DCHECK(IsFinalState(state));
CancelScheduledTasks();
......@@ -251,6 +257,7 @@ void CertProvisioningWorkerImpl::Stop(CertProvisioningWorkerState state) {
void CertProvisioningWorkerImpl::Pause() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
CancelScheduledTasks();
is_waiting_ = true;
}
......@@ -300,6 +307,7 @@ void CertProvisioningWorkerImpl::DoStep() {
void CertProvisioningWorkerImpl::UpdateState(
CertProvisioningWorkerState new_state) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
DCHECK(GetStateOrderedIndex(state_) < GetStateOrderedIndex(new_state));
prev_state_ = state_;
......@@ -482,6 +490,7 @@ void CertProvisioningWorkerImpl::OnBuildVaChallengeResponseDone(
void CertProvisioningWorkerImpl::RegisterKey() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
tpm_challenge_key_subtle_impl_->StartRegisterKeyStep(
base::BindOnce(&CertProvisioningWorkerImpl::OnRegisterKeyDone,
weak_factory_.GetWeakPtr()));
......@@ -734,6 +743,7 @@ void CertProvisioningWorkerImpl::OnShouldContinue(ContinueReason reason) {
void CertProvisioningWorkerImpl::CancelScheduledTasks() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
weak_factory_.InvalidateWeakPtrs();
}
......@@ -800,6 +810,7 @@ void CertProvisioningWorkerImpl::OnRemoveKeyDone(
void CertProvisioningWorkerImpl::OnCleanUpDone() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
RecordResult(cert_scope_, state_, prev_state_);
std::move(callback_).Run(cert_profile_, state_);
}
......@@ -853,6 +864,7 @@ void CertProvisioningWorkerImpl::InitAfterDeserialization() {
void CertProvisioningWorkerImpl::RegisterForInvalidationTopic() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
DCHECK(invalidator_);
// Can be empty after deserialization if no topic was received yet. Also
......@@ -873,6 +885,7 @@ void CertProvisioningWorkerImpl::RegisterForInvalidationTopic() {
void CertProvisioningWorkerImpl::UnregisterFromInvalidationTopic() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
DCHECK(invalidator_);
invalidator_->Unregister();
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment