Commit dbdbe9c4 authored by Michael Ershov's avatar Michael Ershov Committed by Commit Bot

Cert Provisioning: Certificate renewal

Start using renewal_period_seconds field from cert provisioning
policies. Track expiration date in CertProvisioningScheduler,
provision a new certificate when necessary (according to policy),
delete renewed certificates.

Bug: 1045895
Change-Id: I756fab39c96cd8de80a5e89802533d9e3c38508d
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2283145Reviewed-by: default avatarPavol Marko <pmarko@chromium.org>
Commit-Queue: Michael Ershov <miersh@google.com>
Cr-Commit-Position: refs/heads/master@{#788589}
parent 968f7168
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include "base/bind_helpers.h" #include "base/bind_helpers.h"
#include "base/notreached.h" #include "base/notreached.h"
#include "base/optional.h" #include "base/optional.h"
#include "base/time/time.h"
#include "chrome/browser/chromeos/platform_keys/platform_keys_service.h" #include "chrome/browser/chromeos/platform_keys/platform_keys_service.h"
#include "chrome/browser/chromeos/profiles/profile_helper.h" #include "chrome/browser/chromeos/profiles/profile_helper.h"
#include "chrome/common/pref_names.h" #include "chrome/common/pref_names.h"
...@@ -53,35 +54,47 @@ bool IsFinalState(CertProvisioningWorkerState state) { ...@@ -53,35 +54,47 @@ bool IsFinalState(CertProvisioningWorkerState state) {
//===================== CertProfile ============================================ //===================== CertProfile ============================================
CertProfile::CertProfile(CertProfileId profile_id,
std::string policy_version,
bool is_va_enabled,
base::TimeDelta renewal_period)
: profile_id(profile_id),
policy_version(policy_version),
is_va_enabled(is_va_enabled),
renewal_period(renewal_period) {}
base::Optional<CertProfile> CertProfile::MakeFromValue( base::Optional<CertProfile> CertProfile::MakeFromValue(
const base::Value& value) { const base::Value& value) {
static_assert(kVersion == 3, "This function should be updated"); static_assert(kVersion == 4, "This function should be updated");
const std::string* id = value.FindStringKey(kCertProfileIdKey); const std::string* id = value.FindStringKey(kCertProfileIdKey);
const std::string* policy_version = const std::string* policy_version =
value.FindStringKey(kCertProfilePolicyVersionKey); value.FindStringKey(kCertProfilePolicyVersionKey);
base::Optional<bool> is_va_enabled = base::Optional<bool> is_va_enabled =
value.FindBoolKey(kCertProfileIsVaEnabledKey); value.FindBoolKey(kCertProfileIsVaEnabledKey);
base::Optional<int> renewal_period_sec =
value.FindIntKey(kCertProfileRenewalPeroidSec);
if (!id || !policy_version) { if (!id || !policy_version) {
return base::nullopt; return base::nullopt;
} }
if (!is_va_enabled) {
is_va_enabled = true;
}
CertProfile result; CertProfile result;
result.profile_id = *id; result.profile_id = *id;
result.policy_version = *policy_version; result.policy_version = *policy_version;
result.is_va_enabled = *is_va_enabled; result.is_va_enabled = is_va_enabled.value_or(true);
result.renewal_period =
base::TimeDelta::FromSeconds(renewal_period_sec.value_or(0));
return result; return result;
} }
bool CertProfile::operator==(const CertProfile& other) const { bool CertProfile::operator==(const CertProfile& other) const {
static_assert(kVersion == 3, "This function should be updated"); static_assert(kVersion == 4, "This function should be updated");
return ((profile_id == other.profile_id) && return ((profile_id == other.profile_id) &&
(policy_version == other.policy_version) && (policy_version == other.policy_version) &&
(is_va_enabled == other.is_va_enabled)); (is_va_enabled == other.is_va_enabled) &&
(renewal_period == other.renewal_period));
} }
bool CertProfile::operator!=(const CertProfile& other) const { bool CertProfile::operator!=(const CertProfile& other) const {
...@@ -90,10 +103,11 @@ bool CertProfile::operator!=(const CertProfile& other) const { ...@@ -90,10 +103,11 @@ bool CertProfile::operator!=(const CertProfile& other) const {
bool CertProfileComparator::operator()(const CertProfile& a, bool CertProfileComparator::operator()(const CertProfile& a,
const CertProfile& b) const { const CertProfile& b) const {
static_assert(CertProfile::kVersion == 3, "This function should be updated"); static_assert(CertProfile::kVersion == 4, "This function should be updated");
return ((a.profile_id < b.profile_id) || return ((a.profile_id < b.profile_id) ||
(a.policy_version < b.policy_version) || (a.policy_version < b.policy_version) ||
(a.is_va_enabled < b.is_va_enabled)); (a.is_va_enabled < b.is_va_enabled) ||
(a.renewal_period < b.renewal_period));
} }
//============================================================================== //==============================================================================
...@@ -109,6 +123,15 @@ void RegisterLocalStatePrefs(PrefRegistrySimple* registry) { ...@@ -109,6 +123,15 @@ void RegisterLocalStatePrefs(PrefRegistrySimple* registry) {
prefs::kCertificateProvisioningStateForDevice); prefs::kCertificateProvisioningStateForDevice);
} }
const char* GetPrefNameForCertProfiles(CertScope scope) {
switch (scope) {
case CertScope::kUser:
return prefs::kRequiredClientCertificateForUser;
case CertScope::kDevice:
return prefs::kRequiredClientCertificateForDevice;
}
}
const char* GetPrefNameForSerialization(CertScope scope) { const char* GetPrefNameForSerialization(CertScope scope) {
switch (scope) { switch (scope) {
case CertScope::kUser: case CertScope::kUser:
......
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include "base/callback.h" #include "base/callback.h"
#include "base/callback_forward.h" #include "base/callback_forward.h"
#include "base/optional.h" #include "base/optional.h"
#include "base/time/time.h"
#include "base/values.h" #include "base/values.h"
#include "chrome/browser/chromeos/platform_keys/platform_keys_service.h" #include "chrome/browser/chromeos/platform_keys/platform_keys_service.h"
#include "chromeos/dbus/constants/attestation_constants.h" #include "chromeos/dbus/constants/attestation_constants.h"
...@@ -66,20 +67,32 @@ using CertProfileId = std::string; ...@@ -66,20 +67,32 @@ using CertProfileId = std::string;
// with definitions of RequiredClientCertificateForDevice and // with definitions of RequiredClientCertificateForDevice and
// RequiredClientCertificateForUser policies in policy_templates.json file. // RequiredClientCertificateForUser policies in policy_templates.json file.
const char kCertProfileIdKey[] = "cert_profile_id"; const char kCertProfileIdKey[] = "cert_profile_id";
const char kCertProfileRenewalPeroidSec[] = "renewal_period_seconds";
const char kCertProfilePolicyVersionKey[] = "policy_version"; const char kCertProfilePolicyVersionKey[] = "policy_version";
const char kCertProfileIsVaEnabledKey[] = "enable_remote_attestation_check"; const char kCertProfileIsVaEnabledKey[] = "enable_remote_attestation_check";
struct CertProfile { struct CertProfile {
static base::Optional<CertProfile> MakeFromValue(const base::Value& value);
CertProfile() = default;
// For tests.
CertProfile(CertProfileId profile_id,
std::string policy_version,
bool is_va_enabled,
base::TimeDelta renewal_period);
CertProfileId profile_id; CertProfileId profile_id;
std::string policy_version; std::string policy_version;
bool is_va_enabled = true; bool is_va_enabled = true;
// Default renewal period 0 means that a certificate will be renewed only
// after the previous one has expired (0 seconds before it is expires).
base::TimeDelta renewal_period = base::TimeDelta::FromSeconds(0);
// IMPORTANT: // IMPORTANT:
// Increment this when you add/change any member in CertProfile (and update // Increment this when you add/change any member in CertProfile (and update
// all functions that fail to compile because of it). // all functions that fail to compile because of it).
static constexpr int kVersion = 3; static constexpr int kVersion = 4;
static base::Optional<CertProfile> MakeFromValue(const base::Value& value);
bool operator==(const CertProfile& other) const; bool operator==(const CertProfile& other) const;
bool operator!=(const CertProfile& other) const; bool operator!=(const CertProfile& other) const;
}; };
...@@ -90,6 +103,7 @@ struct CertProfileComparator { ...@@ -90,6 +103,7 @@ struct CertProfileComparator {
void RegisterProfilePrefs(PrefRegistrySimple* registry); void RegisterProfilePrefs(PrefRegistrySimple* registry);
void RegisterLocalStatePrefs(PrefRegistrySimple* registry); void RegisterLocalStatePrefs(PrefRegistrySimple* registry);
const char* GetPrefNameForCertProfiles(CertScope scope);
const char* GetPrefNameForSerialization(CertScope scope); const char* GetPrefNameForSerialization(CertScope scope);
// Returns the nickname (CKA_LABEL) for keys created for the |profile_id|. // Returns the nickname (CKA_LABEL) for keys created for the |profile_id|.
......
...@@ -22,96 +22,150 @@ class PlatformKeysService; ...@@ -22,96 +22,150 @@ class PlatformKeysService;
namespace chromeos { namespace chromeos {
namespace cert_provisioning { namespace cert_provisioning {
// ========= CertProvisioningCertsWithIdsGetter ================================ // ========= CertIterator ======================================================
using GetCertsWithIdsCallback = base::OnceCallback<void( using CertIteratorForEachCallback =
base::flat_map<CertProfileId, scoped_refptr<net::X509Certificate>> base::RepeatingCallback<void(scoped_refptr<net::X509Certificate> cert,
certs_with_ids, const CertProfileId& cert_profile_id,
const std::string& error_message)>; const std::string& error_message)>;
using CertIteratorOnFinishedCallback =
base::OnceCallback<void(const std::string& error_message)>;
// Helper class that retrieves list of all certificates in a given scope with // Iterates over all existing certificates of a given |cert_scope| and combines
// their certificate profile ids. Certificates without the id are ignored. // them with their certificate provisioning ids when possible. Runs |callback|
class CertProvisioningCertsWithIdsGetter { // on every (cert, cert_profile_id) pair that had a present and non-empty
// |cert_profile_id|. If |error_message| is not empty, then the pair is not
// valid.
class CertIterator {
public: public:
CertProvisioningCertsWithIdsGetter(); CertIterator(CertScope cert_scope,
CertProvisioningCertsWithIdsGetter( platform_keys::PlatformKeysService* platform_keys_service);
const CertProvisioningCertsWithIdsGetter&) = delete; CertIterator(const CertIterator&) = delete;
CertProvisioningCertsWithIdsGetter& operator=( CertIterator& operator=(const CertIterator&) = delete;
const CertProvisioningCertsWithIdsGetter&) = delete; ~CertIterator();
~CertProvisioningCertsWithIdsGetter();
// Can be called more than once. If previous iteration is not finished, it
bool IsRunning() const; // will be canceled.
void IterateAll(CertIteratorForEachCallback for_each_callback,
void GetCertsWithIds( CertIteratorOnFinishedCallback on_finished_callback);
CertScope cert_scope, void Cancel();
platform_keys::PlatformKeysService* platform_keys_service,
GetCertsWithIdsCallback callback);
private: private:
void OnGetCertificatesDone( void OnGetCertificatesDone(
std::unique_ptr<net::CertificateList> existing_certs, std::unique_ptr<net::CertificateList> existing_certs,
const std::string& error_message); const std::string& error_message);
void OnGetAttributeForKeyDone(scoped_refptr<net::X509Certificate> cert,
void CollectOneResult(scoped_refptr<net::X509Certificate> cert, const base::Optional<std::string>& attr_value,
const base::Optional<CertProfileId>& cert_id,
const std::string& error_message); const std::string& error_message);
void StopIteration(const std::string& error_message);
CertScope cert_scope_ = CertScope::kDevice; const CertScope cert_scope_ = CertScope::kDevice;
platform_keys::PlatformKeysService* platform_keys_service_ = nullptr; platform_keys::PlatformKeysService* const platform_keys_service_ = nullptr;
size_t wait_counter_ = 0; size_t wait_counter_ = 0;
base::flat_map<CertProfileId, scoped_refptr<net::X509Certificate>> CertIteratorForEachCallback for_each_callback_;
certs_with_ids_; CertIteratorOnFinishedCallback on_finished_callback_;
GetCertsWithIdsCallback callback_;
SEQUENCE_CHECKER(sequence_checker_); SEQUENCE_CHECKER(sequence_checker_);
base::WeakPtrFactory<CertProvisioningCertsWithIdsGetter> weak_factory_{this}; base::WeakPtrFactory<CertIterator> weak_factory_{this};
}; };
// ========= CertProvisioningCertDeleter ======================================= // ========= LatestCertsWithIdsGetter ==========================================
using DeleteCertsCallback = using LatestCertsWithIdsGetterCallback = base::OnceCallback<void(
base::OnceCallback<void(const std::string& error_message)>; base::flat_map<CertProfileId, scoped_refptr<net::X509Certificate>>
certs_with_ids,
const std::string& error_message)>;
// Helper class that deletes all certificates in a given scope with certificate // Collects map of certificates with their certificate provisioning ids and
// profile ids that are not specified to be kept. Certificates without the id // returns it via |callback|. If there are several certificates for the same id,
// are ignored. // only the newest one will be stored in the map. Only one call to
class CertProvisioningCertDeleter { // GetCertsWithIds() for one instance is allowed.
class LatestCertsWithIdsGetter {
public: public:
CertProvisioningCertDeleter(); LatestCertsWithIdsGetter(
CertProvisioningCertDeleter(const CertProvisioningCertDeleter&) = delete; CertScope cert_scope,
CertProvisioningCertDeleter& operator=(const CertProvisioningCertDeleter&) = platform_keys::PlatformKeysService* platform_keys_service);
delete; LatestCertsWithIdsGetter(const LatestCertsWithIdsGetter&) = delete;
~CertProvisioningCertDeleter(); LatestCertsWithIdsGetter& operator=(const LatestCertsWithIdsGetter&) = delete;
~LatestCertsWithIdsGetter();
void DeleteCerts(CertScope cert_scope,
platform_keys::PlatformKeysService* platform_keys_service, // Can be called more than once. If previous task is not finished, it will be
base::flat_set<CertProfileId> cert_profile_ids_to_keep, // canceled.
DeleteCertsCallback callback); void GetCertsWithIds(LatestCertsWithIdsGetterCallback callback);
bool IsRunning() const;
void Cancel();
private: private:
void OnGetCertsWithIdsDone( void ProcessOneCert(scoped_refptr<net::X509Certificate> new_cert,
base::flat_map<CertProfileId, scoped_refptr<net::X509Certificate>> const CertProfileId& cert_profile_id,
certs_with_ids,
const std::string& error_message); const std::string& error_message);
void OnIterationFinished(const std::string& error_message);
void OnRemoveCertificateDone(const std::string& error_message); CertIterator iterator_;
// Accumulates results that will be returned at the end via |callback_|.
base::flat_map<CertProfileId, scoped_refptr<net::X509Certificate>>
certs_with_ids_;
LatestCertsWithIdsGetterCallback callback_;
// Keeps track of how many certificates are already processed. Calls the SEQUENCE_CHECKER(sequence_checker_);
// |callback_| when all work is done. base::WeakPtrFactory<LatestCertsWithIdsGetter> weak_factory_{this};
void AccountOneResult(); };
CertScope cert_scope_ = CertScope::kDevice; // ========= CertDeleter =======================================================
platform_keys::PlatformKeysService* platform_keys_service_ = nullptr;
size_t wait_counter_ = 0; using CertDeleterCallback =
base::OnceCallback<void(const std::string& error_message)>;
// Finds and deletes certificates that 1) have ids that are not in
// |cert_profile_ids_to_keep| set or 2) have another certificate for the same
// id with later expiration date. Only one call to DeleteCerts() for one
// instance is allowed.
class CertDeleter {
public:
CertDeleter(CertScope cert_scope,
platform_keys::PlatformKeysService* platform_keys_service);
CertDeleter(const CertDeleter&) = delete;
CertDeleter& operator=(const CertDeleter&) = delete;
~CertDeleter();
void Cancel();
// Can be called more than once. If previous task is not finished, it will be
// canceled.
void DeleteCerts(base::flat_set<CertProfileId> cert_profile_ids_to_keep,
CertDeleterCallback callback);
private:
void ProcessOneCert(scoped_refptr<net::X509Certificate> cert,
const CertProfileId& cert_profile_id,
const std::string& error_message);
void RememberOrDelete(scoped_refptr<net::X509Certificate> new_cert,
const CertProfileId& cert_profile_id);
void DeleteCert(scoped_refptr<net::X509Certificate> cert);
void OnDeleteCertDone(const std::string& error_message);
void OnIterationFinished(const std::string& error_message);
void CheckStateAndMaybeFinish();
void ReturnStatus(const std::string& error_message);
const CertScope cert_scope_ = CertScope::kDevice;
platform_keys::PlatformKeysService* const platform_keys_service_ = nullptr;
CertIterator iterator_;
bool iteration_finished_ = false;
size_t pending_delete_tasks_counter_ = 0;
CertDeleterCallback callback_;
// Contains list of currently existing certificate profile ids. Certificates
// with ids outside of this set can be deleted.
base::flat_set<CertProfileId> cert_profile_ids_to_keep_; base::flat_set<CertProfileId> cert_profile_ids_to_keep_;
DeleteCertsCallback callback_;
std::unique_ptr<CertProvisioningCertsWithIdsGetter> cert_getter_; // Stores previously seen certificates that allows to find duplicates.
base::flat_map<CertProfileId, scoped_refptr<net::X509Certificate>>
certs_with_ids_;
SEQUENCE_CHECKER(sequence_checker_); SEQUENCE_CHECKER(sequence_checker_);
base::WeakPtrFactory<CertProvisioningCertDeleter> weak_factory_{this}; base::WeakPtrFactory<CertDeleter> weak_factory_{this};
}; };
} // namespace cert_provisioning } // namespace cert_provisioning
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "base/location.h" #include "base/location.h"
#include "base/logging.h" #include "base/logging.h"
#include "base/optional.h" #include "base/optional.h"
#include "base/stl_util.h"
#include "base/time/time.h" #include "base/time/time.h"
#include "chrome/browser/browser_process.h" #include "chrome/browser/browser_process.h"
#include "chrome/browser/browser_process_platform_part.h" #include "chrome/browser/browser_process_platform_part.h"
...@@ -103,6 +104,8 @@ CertProvisioningScheduler::CreateUserCertProvisioningScheduler( ...@@ -103,6 +104,8 @@ CertProvisioningScheduler::CreateUserCertProvisioningScheduler(
PrefService* pref_service = profile->GetPrefs(); PrefService* pref_service = profile->GetPrefs();
policy::CloudPolicyClient* cloud_policy_client = policy::CloudPolicyClient* cloud_policy_client =
GetCloudPolicyClientForUser(profile); GetCloudPolicyClientForUser(profile);
platform_keys::PlatformKeysService* platform_keys_service =
platform_keys::PlatformKeysServiceFactory::GetForBrowserContext(profile);
NetworkStateHandler* network_state_handler = GetNetworkStateHandler(); NetworkStateHandler* network_state_handler = GetNetworkStateHandler();
if (!profile || !pref_service || !cloud_policy_client || if (!profile || !pref_service || !cloud_policy_client ||
...@@ -112,9 +115,8 @@ CertProvisioningScheduler::CreateUserCertProvisioningScheduler( ...@@ -112,9 +115,8 @@ CertProvisioningScheduler::CreateUserCertProvisioningScheduler(
} }
return std::make_unique<CertProvisioningScheduler>( return std::make_unique<CertProvisioningScheduler>(
CertScope::kUser, profile, pref_service, CertScope::kUser, profile, pref_service, cloud_policy_client,
prefs::kRequiredClientCertificateForUser, cloud_policy_client, platform_keys_service, network_state_handler,
network_state_handler,
std::make_unique<CertProvisioningUserInvalidatorFactory>(profile)); std::make_unique<CertProvisioningUserInvalidatorFactory>(profile));
} }
...@@ -127,18 +129,19 @@ CertProvisioningScheduler::CreateDeviceCertProvisioningScheduler( ...@@ -127,18 +129,19 @@ CertProvisioningScheduler::CreateDeviceCertProvisioningScheduler(
PrefService* pref_service = g_browser_process->local_state(); PrefService* pref_service = g_browser_process->local_state();
policy::CloudPolicyClient* cloud_policy_client = policy::CloudPolicyClient* cloud_policy_client =
GetCloudPolicyClientForDevice(); GetCloudPolicyClientForDevice();
platform_keys::PlatformKeysService* platform_keys_service =
platform_keys::PlatformKeysServiceFactory::GetForBrowserContext(profile);
NetworkStateHandler* network_state_handler = GetNetworkStateHandler(); NetworkStateHandler* network_state_handler = GetNetworkStateHandler();
if (!profile || !pref_service || !cloud_policy_client || if (!profile || !pref_service || !cloud_policy_client ||
!network_state_handler) { !network_state_handler || !platform_keys_service) {
LOG(ERROR) << "Failed to create device certificate provisioning scheduler"; LOG(ERROR) << "Failed to create device certificate provisioning scheduler";
return nullptr; return nullptr;
} }
return std::make_unique<CertProvisioningScheduler>( return std::make_unique<CertProvisioningScheduler>(
CertScope::kDevice, profile, pref_service, CertScope::kDevice, profile, pref_service, cloud_policy_client,
prefs::kRequiredClientCertificateForDevice, cloud_policy_client, platform_keys_service, network_state_handler,
network_state_handler,
std::make_unique<CertProvisioningDeviceInvalidatorFactory>( std::make_unique<CertProvisioningDeviceInvalidatorFactory>(
invalidation_service_provider)); invalidation_service_provider));
} }
...@@ -147,26 +150,28 @@ CertProvisioningScheduler::CertProvisioningScheduler( ...@@ -147,26 +150,28 @@ CertProvisioningScheduler::CertProvisioningScheduler(
CertScope cert_scope, CertScope cert_scope,
Profile* profile, Profile* profile,
PrefService* pref_service, PrefService* pref_service,
const char* pref_name,
policy::CloudPolicyClient* cloud_policy_client, policy::CloudPolicyClient* cloud_policy_client,
platform_keys::PlatformKeysService* platform_keys_service,
NetworkStateHandler* network_state_handler, NetworkStateHandler* network_state_handler,
std::unique_ptr<CertProvisioningInvalidatorFactory> invalidator_factory) std::unique_ptr<CertProvisioningInvalidatorFactory> invalidator_factory)
: cert_scope_(cert_scope), : cert_scope_(cert_scope),
profile_(profile), profile_(profile),
pref_service_(pref_service), pref_service_(pref_service),
pref_name_(pref_name),
cloud_policy_client_(cloud_policy_client), cloud_policy_client_(cloud_policy_client),
platform_keys_service_(platform_keys_service),
network_state_handler_(network_state_handler), network_state_handler_(network_state_handler),
certs_with_ids_getter_(cert_scope, platform_keys_service),
cert_deleter_(cert_scope, platform_keys_service),
invalidator_factory_(std::move(invalidator_factory)) { invalidator_factory_(std::move(invalidator_factory)) {
CHECK(profile);
CHECK(pref_service_); CHECK(pref_service_);
CHECK(pref_name_);
CHECK(cloud_policy_client_); CHECK(cloud_policy_client_);
CHECK(profile); CHECK(platform_keys_service_);
CHECK(network_state_handler);
CHECK(invalidator_factory_); CHECK(invalidator_factory_);
platform_keys_service_ = pref_name_ = GetPrefNameForCertProfiles(cert_scope);
platform_keys::PlatformKeysServiceFactory::GetForBrowserContext(profile); CHECK(pref_name_);
CHECK(platform_keys_service_);
network_state_handler_->AddObserver(this, FROM_HERE); network_state_handler_->AddObserver(this, FROM_HERE);
...@@ -198,16 +203,31 @@ void CertProvisioningScheduler::ScheduleDailyUpdate() { ...@@ -198,16 +203,31 @@ void CertProvisioningScheduler::ScheduleDailyUpdate() {
base::TimeDelta::FromDays(1)); base::TimeDelta::FromDays(1));
} }
void CertProvisioningScheduler::ScheduleRetry(const CertProfile& profile) { void CertProvisioningScheduler::ScheduleRetry(const CertProfileId& profile_id) {
DCHECK_CURRENTLY_ON(content::BrowserThread::UI); DCHECK_CURRENTLY_ON(content::BrowserThread::UI);
base::SequencedTaskRunnerHandle::Get()->PostDelayedTask( base::SequencedTaskRunnerHandle::Get()->PostDelayedTask(
FROM_HERE, FROM_HERE,
base::Bind(&CertProvisioningScheduler::UpdateOneCertImpl, base::Bind(&CertProvisioningScheduler::UpdateOneCertImpl,
weak_factory_.GetWeakPtr(), profile.profile_id), weak_factory_.GetWeakPtr(), profile_id),
kInconsistentDataErrorRetryDelay); kInconsistentDataErrorRetryDelay);
} }
void CertProvisioningScheduler::ScheduleRenewal(const CertProfileId& profile_id,
base::TimeDelta delay) {
DCHECK_CURRENTLY_ON(content::BrowserThread::UI);
if (base::Contains(scheduled_renewals_, profile_id)) {
return;
}
base::SequencedTaskRunnerHandle::Get()->PostDelayedTask(
FROM_HERE,
base::Bind(&CertProvisioningScheduler::InitiateRenewal,
weak_factory_.GetWeakPtr(), profile_id),
delay);
}
void CertProvisioningScheduler::InitialUpdateCerts() { void CertProvisioningScheduler::InitialUpdateCerts() {
DCHECK_CURRENTLY_ON(content::BrowserThread::UI); DCHECK_CURRENTLY_ON(content::BrowserThread::UI);
...@@ -227,9 +247,8 @@ void CertProvisioningScheduler::DeleteCertsWithoutPolicy() { ...@@ -227,9 +247,8 @@ void CertProvisioningScheduler::DeleteCertsWithoutPolicy() {
cert_profile_ids_to_keep = base::flat_set<CertProfileId>(std::move(ids)); cert_profile_ids_to_keep = base::flat_set<CertProfileId>(std::move(ids));
} }
cert_deleter_ = std::make_unique<CertProvisioningCertDeleter>(); cert_deleter_.DeleteCerts(
cert_deleter_->DeleteCerts( cert_profile_ids_to_keep,
cert_scope_, platform_keys_service_, cert_profile_ids_to_keep,
base::BindOnce(&CertProvisioningScheduler::OnDeleteCertsWithoutPolicyDone, base::BindOnce(&CertProvisioningScheduler::OnDeleteCertsWithoutPolicyDone,
weak_factory_.GetWeakPtr())); weak_factory_.GetWeakPtr()));
} }
...@@ -238,15 +257,12 @@ void CertProvisioningScheduler::OnDeleteCertsWithoutPolicyDone( ...@@ -238,15 +257,12 @@ void CertProvisioningScheduler::OnDeleteCertsWithoutPolicyDone(
const std::string& error_message) { const std::string& error_message) {
DCHECK_CURRENTLY_ON(content::BrowserThread::UI); DCHECK_CURRENTLY_ON(content::BrowserThread::UI);
cert_deleter_.reset();
if (!error_message.empty()) { if (!error_message.empty()) {
LOG(ERROR) << "Failed to delete certificates without policies: " LOG(ERROR) << "Failed to delete certificates without policies: "
<< error_message; << error_message;
} }
DeserializeWorkers(); DeserializeWorkers();
CleanVaKeysIfIdle(); CleanVaKeysIfIdle();
} }
...@@ -326,6 +342,12 @@ void CertProvisioningScheduler::OnPrefsChange() { ...@@ -326,6 +342,12 @@ void CertProvisioningScheduler::OnPrefsChange() {
UpdateAllCerts(); UpdateAllCerts();
} }
void CertProvisioningScheduler::InitiateRenewal(
const CertProfileId& cert_profile_id) {
scheduled_renewals_.erase(cert_profile_id);
UpdateOneCertImpl(cert_profile_id);
}
void CertProvisioningScheduler::UpdateOneCert( void CertProvisioningScheduler::UpdateOneCert(
const CertProfileId& cert_profile_id) { const CertProfileId& cert_profile_id) {
DCHECK_CURRENTLY_ON(content::BrowserThread::UI); DCHECK_CURRENTLY_ON(content::BrowserThread::UI);
...@@ -336,6 +358,8 @@ void CertProvisioningScheduler::UpdateOneCert( ...@@ -336,6 +358,8 @@ void CertProvisioningScheduler::UpdateOneCert(
void CertProvisioningScheduler::UpdateOneCertImpl( void CertProvisioningScheduler::UpdateOneCertImpl(
const CertProfileId& cert_profile_id) { const CertProfileId& cert_profile_id) {
DCHECK_CURRENTLY_ON(content::BrowserThread::UI);
EraseByKey(failed_cert_profiles_, cert_profile_id); EraseByKey(failed_cert_profiles_, cert_profile_id);
base::Optional<CertProfile> cert_profile = GetOneCertProfile(cert_profile_id); base::Optional<CertProfile> cert_profile = GetOneCertProfile(cert_profile_id);
...@@ -351,6 +375,11 @@ void CertProvisioningScheduler::UpdateAllCerts() { ...@@ -351,6 +375,11 @@ void CertProvisioningScheduler::UpdateAllCerts() {
std::vector<CertProfile> profiles = GetCertProfiles(); std::vector<CertProfile> profiles = GetCertProfiles();
CancelWorkersWithoutPolicy(profiles); CancelWorkersWithoutPolicy(profiles);
if (profiles.empty()) {
return;
}
UpdateCertList(std::move(profiles)); UpdateCertList(std::move(profiles));
} }
...@@ -362,17 +391,13 @@ void CertProvisioningScheduler::UpdateCertList( ...@@ -362,17 +391,13 @@ void CertProvisioningScheduler::UpdateCertList(
return; return;
} }
if (certs_with_ids_getter_ && certs_with_ids_getter_->IsRunning()) { if (certs_with_ids_getter_.IsRunning()) {
queued_profiles_to_update_.insert(std::make_move_iterator(profiles.begin()), queued_profiles_to_update_.insert(std::make_move_iterator(profiles.begin()),
std::make_move_iterator(profiles.end())); std::make_move_iterator(profiles.end()));
return; return;
} }
certs_with_ids_getter_ = certs_with_ids_getter_.GetCertsWithIds(base::BindOnce(
std::make_unique<CertProvisioningCertsWithIdsGetter>();
certs_with_ids_getter_->GetCertsWithIds(
cert_scope_, platform_keys_service_,
base::BindOnce(
&CertProvisioningScheduler::UpdateCertListWithExistingCerts, &CertProvisioningScheduler::UpdateCertListWithExistingCerts,
weak_factory_.GetWeakPtr(), std::move(profiles))); weak_factory_.GetWeakPtr(), std::move(profiles)));
} }
...@@ -384,20 +409,38 @@ void CertProvisioningScheduler::UpdateCertListWithExistingCerts( ...@@ -384,20 +409,38 @@ void CertProvisioningScheduler::UpdateCertListWithExistingCerts(
const std::string& error_message) { const std::string& error_message) {
DCHECK_CURRENTLY_ON(content::BrowserThread::UI); DCHECK_CURRENTLY_ON(content::BrowserThread::UI);
certs_with_ids_getter_.reset();
if (!error_message.empty()) { if (!error_message.empty()) {
LOG(ERROR) << "Failed to get existing cert ids: " << error_message; LOG(ERROR) << "Failed to get existing cert ids: " << error_message;
return; return;
} }
for (const auto& profile : profiles) { for (const auto& profile : profiles) {
if (base::Contains(existing_certs_with_ids, profile.profile_id) || if (base::Contains(failed_cert_profiles_, profile.profile_id)) {
base::Contains(failed_cert_profiles_, profile.profile_id)) { continue;
}
auto cert_iter = existing_certs_with_ids.find(profile.profile_id);
if (cert_iter == existing_certs_with_ids.end()) {
// The certificate does not exists and should be provisioned.
ProcessProfile(profile);
continue; continue;
} }
const auto& cert = cert_iter->second;
base::Time now = base::Time::Now();
if ((now + profile.renewal_period) >= cert->valid_expiry()) {
// The certificate should be renewed immediately.
ProcessProfile(profile); ProcessProfile(profile);
continue;
}
if ((now + base::TimeDelta::FromDays(1) + profile.renewal_period) >=
cert->valid_expiry()) {
// The certificate should be renewed within 1 day.
base::Time target_time = cert->valid_expiry() - profile.renewal_period;
ScheduleRenewal(profile.profile_id, /*delay=*/target_time - now);
continue;
}
} }
if (!queued_profiles_to_update_.empty()) { if (!queued_profiles_to_update_.empty()) {
...@@ -470,7 +513,7 @@ void CertProvisioningScheduler::OnProfileFinished( ...@@ -470,7 +513,7 @@ void CertProvisioningScheduler::OnProfileFinished(
case CertProvisioningWorkerState::kInconsistentDataError: case CertProvisioningWorkerState::kInconsistentDataError:
LOG(WARNING) << "Inconsistent data error for certificate profile: " LOG(WARNING) << "Inconsistent data error for certificate profile: "
<< profile.profile_id; << profile.profile_id;
ScheduleRetry(profile); ScheduleRetry(profile.profile_id);
break; break;
case CertProvisioningWorkerState::kCanceled: case CertProvisioningWorkerState::kCanceled:
break; break;
......
...@@ -75,8 +75,8 @@ class CertProvisioningScheduler : public NetworkStateHandlerObserver { ...@@ -75,8 +75,8 @@ class CertProvisioningScheduler : public NetworkStateHandlerObserver {
CertScope cert_scope, CertScope cert_scope,
Profile* profile, Profile* profile,
PrefService* pref_service, PrefService* pref_service,
const char* pref_name,
policy::CloudPolicyClient* cloud_policy_client, policy::CloudPolicyClient* cloud_policy_client,
platform_keys::PlatformKeysService* platform_keys_service,
NetworkStateHandler* network_state_handler, NetworkStateHandler* network_state_handler,
std::unique_ptr<CertProvisioningInvalidatorFactory> invalidator_factory); std::unique_ptr<CertProvisioningInvalidatorFactory> invalidator_factory);
~CertProvisioningScheduler() override; ~CertProvisioningScheduler() override;
...@@ -99,8 +99,9 @@ class CertProvisioningScheduler : public NetworkStateHandlerObserver { ...@@ -99,8 +99,9 @@ class CertProvisioningScheduler : public NetworkStateHandlerObserver {
private: private:
void ScheduleInitialUpdate(); void ScheduleInitialUpdate();
void ScheduleDailyUpdate(); void ScheduleDailyUpdate();
// Posts delayed task to call ProcessProfile. // Posts delayed task to call UpdateOneCertImpl.
void ScheduleRetry(const CertProfile& profile); void ScheduleRetry(const CertProfileId& profile_id);
void ScheduleRenewal(const CertProfileId& profile_id, base::TimeDelta delay);
void InitialUpdateCerts(); void InitialUpdateCerts();
void DeleteCertsWithoutPolicy(); void DeleteCertsWithoutPolicy();
...@@ -110,6 +111,7 @@ class CertProvisioningScheduler : public NetworkStateHandlerObserver { ...@@ -110,6 +111,7 @@ class CertProvisioningScheduler : public NetworkStateHandlerObserver {
void OnCleanVaKeysIfIdleDone(base::Optional<bool> delete_result); void OnCleanVaKeysIfIdleDone(base::Optional<bool> delete_result);
void RegisterForPrefsChanges(); void RegisterForPrefsChanges();
void InitiateRenewal(const CertProfileId& cert_profile_id);
void UpdateOneCertImpl(const CertProfileId& cert_profile_id); void UpdateOneCertImpl(const CertProfileId& cert_profile_id);
void UpdateCertList(std::vector<CertProfile> profiles); void UpdateCertList(std::vector<CertProfile> profiles);
void UpdateCertListWithExistingCerts( void UpdateCertListWithExistingCerts(
...@@ -150,10 +152,14 @@ class CertProvisioningScheduler : public NetworkStateHandlerObserver { ...@@ -150,10 +152,14 @@ class CertProvisioningScheduler : public NetworkStateHandlerObserver {
PrefService* pref_service_ = nullptr; PrefService* pref_service_ = nullptr;
const char* pref_name_ = nullptr; const char* pref_name_ = nullptr;
policy::CloudPolicyClient* cloud_policy_client_ = nullptr; policy::CloudPolicyClient* cloud_policy_client_ = nullptr;
NetworkStateHandler* network_state_handler_ = nullptr;
platform_keys::PlatformKeysService* platform_keys_service_ = nullptr; platform_keys::PlatformKeysService* platform_keys_service_ = nullptr;
NetworkStateHandler* network_state_handler_ = nullptr;
PrefChangeRegistrar pref_change_registrar_; PrefChangeRegistrar pref_change_registrar_;
WorkerMap workers_; WorkerMap workers_;
// Contains cert profile ids that will be renewed before next daily update.
// Helps to prevent creation of more than one delayed task for renewal. When
// the renewal starts for a profile id, it is removed from the set.
base::flat_set<CertProfileId> scheduled_renewals_;
// Collection of cert profile ids that failed recently. They will not be // Collection of cert profile ids that failed recently. They will not be
// retried until next |DailyUpdateCerts|. FailedWorkerInfo contains some extra // retried until next |DailyUpdateCerts|. FailedWorkerInfo contains some extra
// information about the failure. Profiles that failed with // information about the failure. Profiles that failed with
...@@ -167,8 +173,8 @@ class CertProvisioningScheduler : public NetworkStateHandlerObserver { ...@@ -167,8 +173,8 @@ class CertProvisioningScheduler : public NetworkStateHandlerObserver {
// run, because an update for them was triggered during the current run. // run, because an update for them was triggered during the current run.
CertProfileSet queued_profiles_to_update_; CertProfileSet queued_profiles_to_update_;
std::unique_ptr<CertProvisioningCertsWithIdsGetter> certs_with_ids_getter_; LatestCertsWithIdsGetter certs_with_ids_getter_;
std::unique_ptr<CertProvisioningCertDeleter> cert_deleter_; CertDeleter cert_deleter_;
std::unique_ptr<CertProvisioningInvalidatorFactory> invalidator_factory_; std::unique_ptr<CertProvisioningInvalidatorFactory> invalidator_factory_;
base::WeakPtrFactory<CertProvisioningScheduler> weak_factory_{this}; base::WeakPtrFactory<CertProvisioningScheduler> weak_factory_{this};
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include "base/base64.h" #include "base/base64.h"
#include "base/logging.h" #include "base/logging.h"
#include "base/optional.h" #include "base/optional.h"
#include "base/time/time.h"
#include "chrome/browser/chromeos/cert_provisioning/cert_provisioning_common.h" #include "chrome/browser/chromeos/cert_provisioning/cert_provisioning_common.h"
#include "components/prefs/pref_service.h" #include "components/prefs/pref_service.h"
#include "components/prefs/scoped_user_pref_update.h" #include "components/prefs/scoped_user_pref_update.h"
...@@ -16,15 +17,16 @@ namespace cert_provisioning { ...@@ -16,15 +17,16 @@ namespace cert_provisioning {
namespace { namespace {
const char kKeyNameCertScope[] = "cert_scope"; constexpr char kKeyNameCertScope[] = "cert_scope";
const char kKeyNameCertProfile[] = "cert_profile"; constexpr char kKeyNameCertProfile[] = "cert_profile";
const char kKeyNameState[] = "state"; constexpr char kKeyNameState[] = "state";
const char kKeyNamePublicKey[] = "public_key"; constexpr char kKeyNamePublicKey[] = "public_key";
const char kKeyNameInvalidationTopic[] = "invalidation_topic"; constexpr char kKeyNameInvalidationTopic[] = "invalidation_topic";
const char kKeyNameCertProfileId[] = "profile_id"; constexpr char kKeyNameCertProfileId[] = "profile_id";
const char kKeyNameCertProfileVersion[] = "policy_version"; constexpr char kKeyNameCertProfileVersion[] = "policy_version";
const char kKeyNameCertProfileVaEnabled[] = "va_enabled"; constexpr char kKeyNameCertProfileVaEnabled[] = "va_enabled";
constexpr char kKeyNameCertProfileRenewalPeriod[] = "renewal_period";
template <typename T> template <typename T>
bool ConvertToEnum(int value, T* dst) { bool ConvertToEnum(int value, T* dst) {
...@@ -71,21 +73,34 @@ bool DeserializeBoolValue(const base::Value& parent_value, ...@@ -71,21 +73,34 @@ bool DeserializeBoolValue(const base::Value& parent_value,
return true; return true;
} }
bool DeserializeRenewalPeriod(const base::Value& parent_value,
const char* value_name,
base::TimeDelta* dst) {
base::Optional<int> serialized_time = parent_value.FindIntKey(value_name);
*dst = base::TimeDelta::FromSeconds(serialized_time.value_or(0));
return true;
}
base::Value SerializeCertProfile(const CertProfile& profile) { base::Value SerializeCertProfile(const CertProfile& profile) {
static_assert(CertProfile::kVersion == 3, "This function should be updated"); static_assert(CertProfile::kVersion == 4, "This function should be updated");
base::Value result(base::Value::Type::DICTIONARY); base::Value result(base::Value::Type::DICTIONARY);
result.SetStringKey(kKeyNameCertProfileId, profile.profile_id); result.SetStringKey(kKeyNameCertProfileId, profile.profile_id);
result.SetStringKey(kKeyNameCertProfileVersion, profile.policy_version); result.SetStringKey(kKeyNameCertProfileVersion, profile.policy_version);
result.SetBoolKey(kKeyNameCertProfileVaEnabled, profile.is_va_enabled); result.SetBoolKey(kKeyNameCertProfileVaEnabled, profile.is_va_enabled);
if (!profile.renewal_period.is_zero()) {
result.SetIntKey(kKeyNameCertProfileRenewalPeriod,
profile.renewal_period.InSeconds());
}
return result; return result;
} }
bool DeserializeCertProfile(const base::Value& parent_value, bool DeserializeCertProfile(const base::Value& parent_value,
const char* value_name, const char* value_name,
CertProfile* dst) { CertProfile* dst) {
static_assert(CertProfile::kVersion == 3, "This function should be updated"); static_assert(CertProfile::kVersion == 4, "This function should be updated");
const base::Value* serialized_profile = const base::Value* serialized_profile =
parent_value.FindKeyOfType(value_name, base::Value::Type::DICTIONARY); parent_value.FindKeyOfType(value_name, base::Value::Type::DICTIONARY);
...@@ -104,6 +119,9 @@ bool DeserializeCertProfile(const base::Value& parent_value, ...@@ -104,6 +119,9 @@ bool DeserializeCertProfile(const base::Value& parent_value,
is_ok = is_ok && DeserializeBoolValue(*serialized_profile, is_ok = is_ok && DeserializeBoolValue(*serialized_profile,
kKeyNameCertProfileVaEnabled, kKeyNameCertProfileVaEnabled,
&(dst->is_va_enabled)); &(dst->is_va_enabled));
is_ok = is_ok && DeserializeRenewalPeriod(*serialized_profile,
kKeyNameCertProfileRenewalPeriod,
&(dst->renewal_period));
return is_ok; return is_ok;
} }
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#include "base/optional.h" #include "base/optional.h"
#include "base/test/gmock_callback_support.h" #include "base/test/gmock_callback_support.h"
#include "chrome/browser/chromeos/platform_keys/platform_keys_service.h" #include "base/time/time.h"
#include "chrome/browser/chromeos/profiles/profile_helper.h" #include "chrome/browser/chromeos/profiles/profile_helper.h"
#include "chrome/test/base/testing_browser_process.h" #include "chrome/test/base/testing_browser_process.h"
#include "net/test/cert_builder.h" #include "net/test/cert_builder.h"
...@@ -63,18 +63,15 @@ void CertificateHelperForTesting::GetCertificates( ...@@ -63,18 +63,15 @@ void CertificateHelperForTesting::GetCertificates(
std::move(callback).Run(std::move(result), ""); std::move(callback).Run(std::move(result), "");
} }
void CertificateHelperForTesting::AddCert( scoped_refptr<net::X509Certificate> CertificateHelperForTesting::AddCert(
CertScope cert_scope,
const base::Optional<CertProfileId>& cert_profile_id) {
AddCert(cert_scope, cert_profile_id, /*error_message=*/"");
}
void CertificateHelperForTesting::AddCert(
CertScope cert_scope, CertScope cert_scope,
const base::Optional<CertProfileId>& cert_profile_id, const base::Optional<CertProfileId>& cert_profile_id,
const std::string& error_message) { const std::string& error_message,
base::Time not_valid_before,
base::Time not_valid_after) {
net::CertBuilder cert_builder(template_cert_->cert_buffer(), net::CertBuilder cert_builder(template_cert_->cert_buffer(),
/*issuer=*/nullptr); /*issuer=*/nullptr);
cert_builder.SetValidity(not_valid_before, not_valid_after);
auto cert = cert_builder.GetX509Certificate(); auto cert = cert_builder.GetX509Certificate();
EXPECT_CALL( EXPECT_CALL(
...@@ -86,6 +83,30 @@ void CertificateHelperForTesting::AddCert( ...@@ -86,6 +83,30 @@ void CertificateHelperForTesting::AddCert(
.WillRepeatedly(RunOnceCallback<3>(cert_profile_id, error_message)); .WillRepeatedly(RunOnceCallback<3>(cert_profile_id, error_message));
cert_list_.push_back(cert); cert_list_.push_back(cert);
return cert;
}
scoped_refptr<net::X509Certificate> CertificateHelperForTesting::AddCert(
CertScope cert_scope,
const base::Optional<CertProfileId>& cert_profile_id) {
base::Time not_valid_before =
base::Time::Now() - base::TimeDelta::FromDays(1);
base::Time not_valid_after =
base::Time::Now() + base::TimeDelta::FromDays(365);
return AddCert(cert_scope, cert_profile_id, /*error_message=*/"",
not_valid_before, not_valid_after);
}
scoped_refptr<net::X509Certificate> CertificateHelperForTesting::AddCert(
CertScope cert_scope,
const base::Optional<CertProfileId>& cert_profile_id,
const std::string& error_message) {
base::Time not_valid_before =
base::Time::Now() - base::TimeDelta::FromDays(1);
base::Time not_valid_after =
base::Time::Now() + base::TimeDelta::FromDays(365);
return AddCert(cert_scope, cert_profile_id, error_message, not_valid_before,
not_valid_after);
} }
void CertificateHelperForTesting::ClearCerts() { void CertificateHelperForTesting::ClearCerts() {
......
...@@ -19,19 +19,41 @@ namespace cert_provisioning { ...@@ -19,19 +19,41 @@ namespace cert_provisioning {
//================ CertificateHelperForTesting ================================= //================ CertificateHelperForTesting =================================
// Redirects PlatformKeysService::GetCertificate calls to itself. Allows to add // Allows to add certificate to a fake storage with assigned CertProfileId-s.
// certificate to a fake storage with assigned CertProfileId-s. // Redirects PlatformKeysService::GetCertificate calls to itself and return all
// stored certificates as a result.
struct CertificateHelperForTesting { struct CertificateHelperForTesting {
public: public:
explicit CertificateHelperForTesting( explicit CertificateHelperForTesting(
platform_keys::MockPlatformKeysService* platform_keys_service); platform_keys::MockPlatformKeysService* platform_keys_service);
~CertificateHelperForTesting(); ~CertificateHelperForTesting();
void AddCert(CertScope cert_scope, // Generates and adds a certificate to internal fake certificate storage.
// Returns refpointer to the generated certificate. If |error_message| is not
// empty, an attempt to retrieve |cert_profile_id| via
// PlatformKeysService::GetAttributeForKey() will fail with |error_message|.
// |not_valid_before|, |not_valid_after| configure validity period of the
// certificate.
scoped_refptr<net::X509Certificate> AddCert(
CertScope cert_scope,
const base::Optional<CertProfileId>& cert_profile_id,
const std::string& error_message,
base::Time not_valid_before,
base::Time not_valid_after);
// Simplified version of AddCert(). The certificate is not expired and has
// |cert_profile_id|.
scoped_refptr<net::X509Certificate> AddCert(
CertScope cert_scope,
const base::Optional<CertProfileId>& cert_profile_id); const base::Optional<CertProfileId>& cert_profile_id);
void AddCert(CertScope cert_scope,
// Simplified version of AddCert(). The certificate is not expired, but fails
// to retrieve |cert_profile_id|.
scoped_refptr<net::X509Certificate> AddCert(
CertScope cert_scope,
const base::Optional<CertProfileId>& cert_profile_id, const base::Optional<CertProfileId>& cert_profile_id,
const std::string& error_message); const std::string& error_message);
void ClearCerts(); void ClearCerts();
const net::CertificateList& GetCerts() const; const net::CertificateList& GetCerts() const;
......
...@@ -109,19 +109,6 @@ int GetStateOrderedIndex(CertProvisioningWorkerState state) { ...@@ -109,19 +109,6 @@ int GetStateOrderedIndex(CertProvisioningWorkerState state) {
return res; return res;
} }
bool CheckPublicKeyInCertificate(
const scoped_refptr<net::X509Certificate>& cert,
const std::string& public_key) {
base::StringPiece spki_from_cert;
if (!net::asn1::ExtractSPKIFromDERCert(
net::x509_util::CryptoBufferAsStringPiece(cert->cert_buffer()),
&spki_from_cert)) {
return false;
}
return (public_key == spki_from_cert);
}
} // namespace } // namespace
// ============= CertProvisioningWorkerFactory ================================= // ============= CertProvisioningWorkerFactory =================================
...@@ -633,7 +620,9 @@ void CertProvisioningWorkerImpl::ImportCert( ...@@ -633,7 +620,9 @@ void CertProvisioningWorkerImpl::ImportCert(
return; return;
} }
if (!CheckPublicKeyInCertificate(cert, public_key_)) { std::string public_key_from_cert =
platform_keys::GetSubjectPublicKeyInfo(cert);
if (public_key_from_cert != public_key_) {
LOG(ERROR) << "Downloaded certificate does not match the expected key pair"; LOG(ERROR) << "Downloaded certificate does not match the expected key pair";
UpdateState(CertProvisioningWorkerState::kFailed); UpdateState(CertProvisioningWorkerState::kFailed);
return; return;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment