Commit 322f39a9 authored by Pavol Marko's avatar Pavol Marko Committed by Commit Bot

Cert Provisioning: Report state changes from worker/scheduler

Add CertProvisioningSchedulerObserver which can be used to get
notifications about state changes regarding the set of workers
/ the set of failed cert profile ids / states of individual workers.

The CertProvisioningWorker gets a StateChangeCallback to be
able to report state changes back to the owner (the
CertProvisioningScheduler).

Bug: 1081396
Test: unit_tests
Change-Id: Ib41f2913edbefdffb250aba417dca5ace9e85dbd
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2339997
Commit-Queue: Pavol Marko <pmarko@chromium.org>
Reviewed-by: default avatarMichael Ershov <miersh@google.com>
Cr-Commit-Position: refs/heads/master@{#804873}
parent 34be7d92
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#include "base/containers/flat_set.h" #include "base/containers/flat_set.h"
#include "base/location.h" #include "base/location.h"
#include "base/logging.h" #include "base/logging.h"
#include "base/observer_list.h"
#include "base/observer_list_types.h"
#include "base/optional.h" #include "base/optional.h"
#include "base/stl_util.h" #include "base/stl_util.h"
#include "base/time/time.h" #include "base/time/time.h"
...@@ -334,6 +336,9 @@ void CertProvisioningSchedulerImpl::DeserializeWorkers() { ...@@ -334,6 +336,9 @@ void CertProvisioningSchedulerImpl::DeserializeWorkers() {
CertProvisioningWorkerFactory::Get()->Deserialize( CertProvisioningWorkerFactory::Get()->Deserialize(
cert_scope_, profile_, pref_service_, saved_worker, cert_scope_, profile_, pref_service_, saved_worker,
cloud_policy_client_, invalidator_factory_->Create(), cloud_policy_client_, invalidator_factory_->Create(),
base::BindRepeating(
&CertProvisioningSchedulerImpl::OnVisibleStateChanged,
weak_factory_.GetWeakPtr()),
base::BindOnce(&CertProvisioningSchedulerImpl::OnProfileFinished, base::BindOnce(&CertProvisioningSchedulerImpl::OnProfileFinished,
weak_factory_.GetWeakPtr())); weak_factory_.GetWeakPtr()));
if (!worker) { if (!worker) {
...@@ -341,7 +346,7 @@ void CertProvisioningSchedulerImpl::DeserializeWorkers() { ...@@ -341,7 +346,7 @@ void CertProvisioningSchedulerImpl::DeserializeWorkers() {
continue; continue;
} }
workers_[worker->GetCertProfile().profile_id] = std::move(worker); AddWorkerToMap(std::move(worker));
} }
} }
...@@ -501,10 +506,12 @@ void CertProvisioningSchedulerImpl::CreateCertProvisioningWorker( ...@@ -501,10 +506,12 @@ void CertProvisioningSchedulerImpl::CreateCertProvisioningWorker(
CertProvisioningWorkerFactory::Get()->Create( CertProvisioningWorkerFactory::Get()->Create(
cert_scope_, profile_, pref_service_, cert_profile, cert_scope_, profile_, pref_service_, cert_profile,
cloud_policy_client_, invalidator_factory_->Create(), cloud_policy_client_, invalidator_factory_->Create(),
base::BindRepeating(
&CertProvisioningSchedulerImpl::OnVisibleStateChanged,
weak_factory_.GetWeakPtr()),
base::BindOnce(&CertProvisioningSchedulerImpl::OnProfileFinished, base::BindOnce(&CertProvisioningSchedulerImpl::OnProfileFinished,
weak_factory_.GetWeakPtr())); weak_factory_.GetWeakPtr()));
CertProvisioningWorker* worker_unowned = worker.get(); CertProvisioningWorker* worker_unowned = AddWorkerToMap(std::move(worker));
workers_[cert_profile.profile_id] = std::move(worker);
worker_unowned->DoStep(); worker_unowned->DoStep();
} }
...@@ -539,7 +546,7 @@ void CertProvisioningSchedulerImpl::OnProfileFinished( ...@@ -539,7 +546,7 @@ void CertProvisioningSchedulerImpl::OnProfileFinished(
break; break;
} }
workers_.erase(worker_iter); RemoveWorkerFromMap(worker_iter);
} }
CertProvisioningWorker* CertProvisioningSchedulerImpl::FindWorker( CertProvisioningWorker* CertProvisioningSchedulerImpl::FindWorker(
...@@ -554,6 +561,20 @@ CertProvisioningWorker* CertProvisioningSchedulerImpl::FindWorker( ...@@ -554,6 +561,20 @@ CertProvisioningWorker* CertProvisioningSchedulerImpl::FindWorker(
return iter->second.get(); return iter->second.get();
} }
CertProvisioningWorker* CertProvisioningSchedulerImpl::AddWorkerToMap(
std::unique_ptr<CertProvisioningWorker> worker) {
CertProvisioningWorker* worker_unowned = worker.get();
workers_[worker_unowned->GetCertProfile().profile_id] = std::move(worker);
OnVisibleStateChanged();
return worker_unowned;
}
void CertProvisioningSchedulerImpl::RemoveWorkerFromMap(
WorkerMap::iterator worker_iter) {
workers_.erase(worker_iter);
OnVisibleStateChanged();
}
base::Optional<CertProfile> CertProvisioningSchedulerImpl::GetOneCertProfile( base::Optional<CertProfile> CertProvisioningSchedulerImpl::GetOneCertProfile(
const CertProfileId& cert_profile_id) { const CertProfileId& cert_profile_id) {
DCHECK_CURRENTLY_ON(content::BrowserThread::UI); DCHECK_CURRENTLY_ON(content::BrowserThread::UI);
...@@ -612,6 +633,16 @@ CertProvisioningSchedulerImpl::GetFailedCertProfileIds() const { ...@@ -612,6 +633,16 @@ CertProvisioningSchedulerImpl::GetFailedCertProfileIds() const {
return failed_cert_profiles_; return failed_cert_profiles_;
} }
void CertProvisioningSchedulerImpl::AddObserver(
CertProvisioningSchedulerObserver* observer) {
observers_.AddObserver(observer);
}
void CertProvisioningSchedulerImpl::RemoveObserver(
CertProvisioningSchedulerObserver* observer) {
observers_.RemoveObserver(observer);
}
bool CertProvisioningSchedulerImpl::MaybeWaitForInternetConnection() { bool CertProvisioningSchedulerImpl::MaybeWaitForInternetConnection() {
const NetworkState* network = network_state_handler_->DefaultNetwork(); const NetworkState* network = network_state_handler_->DefaultNetwork();
bool is_online = network && network->IsOnline(); bool is_online = network && network->IsOnline();
...@@ -729,5 +760,26 @@ void CertProvisioningSchedulerImpl::CancelWorkersWithoutPolicy( ...@@ -729,5 +760,26 @@ void CertProvisioningSchedulerImpl::CancelWorkersWithoutPolicy(
} }
} }
void CertProvisioningSchedulerImpl::OnVisibleStateChanged() {
DCHECK_CURRENTLY_ON(content::BrowserThread::UI);
if (notify_observers_pending_) {
return;
}
notify_observers_pending_ = true;
base::SequencedTaskRunnerHandle::Get()->PostTask(
FROM_HERE,
base::Bind(
&CertProvisioningSchedulerImpl::NotifyObserversVisibleStateChanged,
weak_factory_.GetWeakPtr()));
}
void CertProvisioningSchedulerImpl::NotifyObserversVisibleStateChanged() {
DCHECK_CURRENTLY_ON(content::BrowserThread::UI);
notify_observers_pending_ = false;
for (auto& observer : observers_) {
observer.OnVisibleStateChanged();
}
}
} // namespace cert_provisioning } // namespace cert_provisioning
} // namespace chromeos } // namespace chromeos
...@@ -55,6 +55,18 @@ struct FailedWorkerInfo { ...@@ -55,6 +55,18 @@ struct FailedWorkerInfo {
base::Time last_update_time; base::Time last_update_time;
}; };
// An observer that gets notified about state changes of the
// CertProvisioningScheduler.
class CertProvisioningSchedulerObserver : public base::CheckedObserver {
public:
// Called when the "visible state" of the observerd CertProvisioningScheduler
// has changed, i.e. when:
// (*) the list of active workers changed,
// (*) the list of recently failed workers changed,
// (*) the state of a worker changed.
virtual void OnVisibleStateChanged() = 0;
};
// Interface for the scheduler for client certificate provisioning using device // Interface for the scheduler for client certificate provisioning using device
// management. // management.
class CertProvisioningScheduler { class CertProvisioningScheduler {
...@@ -73,6 +85,12 @@ class CertProvisioningScheduler { ...@@ -73,6 +85,12 @@ class CertProvisioningScheduler {
// failed and have not been restarted (yet). // failed and have not been restarted (yet).
virtual const base::flat_map<CertProfileId, FailedWorkerInfo>& virtual const base::flat_map<CertProfileId, FailedWorkerInfo>&
GetFailedCertProfileIds() const = 0; GetFailedCertProfileIds() const = 0;
// Adds |observer| which will observer this CertProvisioningScheduler.
virtual void AddObserver(CertProvisioningSchedulerObserver* observer) = 0;
// Removes a previously added |observer|.
virtual void RemoveObserver(CertProvisioningSchedulerObserver* observer) = 0;
}; };
// This class is a part of certificate provisioning feature. It tracks updates // This class is a part of certificate provisioning feature. It tracks updates
...@@ -113,10 +131,19 @@ class CertProvisioningSchedulerImpl ...@@ -113,10 +131,19 @@ class CertProvisioningSchedulerImpl
const WorkerMap& GetWorkers() const override; const WorkerMap& GetWorkers() const override;
const base::flat_map<CertProfileId, FailedWorkerInfo>& const base::flat_map<CertProfileId, FailedWorkerInfo>&
GetFailedCertProfileIds() const override; GetFailedCertProfileIds() const override;
void AddObserver(CertProvisioningSchedulerObserver* observer) override;
void RemoveObserver(CertProvisioningSchedulerObserver* observer) override;
// Invoked when the CertProvisioningWorker corresponding to |profile| reached
// its final state.
// Public so it can be called from tests.
void OnProfileFinished(const CertProfile& profile, void OnProfileFinished(const CertProfile& profile,
CertProvisioningWorkerState state); CertProvisioningWorkerState state);
// Called when any state visible from the outside has changed.
// Public so it can be called from tests.
void OnVisibleStateChanged();
private: private:
void ScheduleInitialUpdate(); void ScheduleInitialUpdate();
void ScheduleDailyUpdate(); void ScheduleDailyUpdate();
...@@ -156,6 +183,13 @@ class CertProvisioningSchedulerImpl ...@@ -156,6 +183,13 @@ class CertProvisioningSchedulerImpl
void CreateCertProvisioningWorker(CertProfile profile); void CreateCertProvisioningWorker(CertProfile profile);
CertProvisioningWorker* FindWorker(CertProfileId profile_id); CertProvisioningWorker* FindWorker(CertProfileId profile_id);
// Adds |worker| to |workers_| and returns an unowned pointer to |worker|.
// Triggers a state change notification.
CertProvisioningWorker* AddWorkerToMap(
std::unique_ptr<CertProvisioningWorker> worker);
// Removes the element referenced by |worker_iter| from |workers_|.
// Triggers a state change notification.
void RemoveWorkerFromMap(WorkerMap::iterator worker_iter);
// Returns true if the process can be continued (if it's not required to // Returns true if the process can be continued (if it's not required to
// wait). // wait).
...@@ -171,6 +205,9 @@ class CertProvisioningSchedulerImpl ...@@ -171,6 +205,9 @@ class CertProvisioningSchedulerImpl
// PlatformKeysServiceObserver // PlatformKeysServiceObserver
void OnPlatformKeysServiceShutDown() override; void OnPlatformKeysServiceShutDown() override;
// Notifies each observer from |observers_| that the state has changed.
void NotifyObserversVisibleStateChanged();
CertScope cert_scope_ = CertScope::kUser; CertScope cert_scope_ = CertScope::kUser;
// |profile_| can be nullptr for the device-wide instance of // |profile_| can be nullptr for the device-wide instance of
// CertProvisioningScheduler. // CertProvisioningScheduler.
...@@ -204,6 +241,12 @@ class CertProvisioningSchedulerImpl ...@@ -204,6 +241,12 @@ class CertProvisioningSchedulerImpl
CertDeleter cert_deleter_; CertDeleter cert_deleter_;
std::unique_ptr<CertProvisioningInvalidatorFactory> invalidator_factory_; std::unique_ptr<CertProvisioningInvalidatorFactory> invalidator_factory_;
// Observers that are observing this CertProvisioningSchedulerImpl.
base::ObserverList<CertProvisioningSchedulerObserver> observers_;
// True when a task for notifying observers about a state change has been
// scheduled but not executed yet.
bool notify_observers_pending_ = false;
ScopedObserver<platform_keys::PlatformKeysService, ScopedObserver<platform_keys::PlatformKeysService,
platform_keys::PlatformKeysServiceObserver> platform_keys::PlatformKeysServiceObserver>
scoped_platform_keys_service_observer_{this}; scoped_platform_keys_service_observer_{this};
......
...@@ -45,7 +45,36 @@ constexpr char kCertProfileId[] = "cert_profile_id_1"; ...@@ -45,7 +45,36 @@ constexpr char kCertProfileId[] = "cert_profile_id_1";
constexpr char kCertProfileVersion[] = "cert_profile_version_1"; constexpr char kCertProfileVersion[] = "cert_profile_version_1";
constexpr TimeDelta kCertProfileRenewalPeriod = TimeDelta::FromSeconds(0); constexpr TimeDelta kCertProfileRenewalPeriod = TimeDelta::FromSeconds(0);
//================ CertProvisioningSchedulerTest =============================== //=============== TestCertProvisioningSchedulerObserver ========================
class TestCertProvisioningSchedulerObserver
: public CertProvisioningSchedulerObserver {
public:
TestCertProvisioningSchedulerObserver() = default;
~TestCertProvisioningSchedulerObserver() override = default;
TestCertProvisioningSchedulerObserver(
const TestCertProvisioningSchedulerObserver& other) = delete;
TestCertProvisioningSchedulerObserver& operator=(
const TestCertProvisioningSchedulerObserver& other) = delete;
// CertProvisioningSchedulerObserver:
void OnVisibleStateChanged() override { run_loop_->Quit(); }
// Waits for one call to happen (since construction or since the previous
// WaitForOneCall has returned).
void WaitForOneCall() {
run_loop_->Run();
// Create a new RunLoop so it can already be terminated when the next
// OnVisibleStateChanged() call comes in.
run_loop_ = std::make_unique<base::RunLoop>();
}
private:
std::unique_ptr<base::RunLoop> run_loop_ = std::make_unique<base::RunLoop>();
};
//=================== CertProvisioningSchedulerTest ============================
class CertProvisioningSchedulerTest : public testing::Test { class CertProvisioningSchedulerTest : public testing::Test {
public: public:
...@@ -952,6 +981,95 @@ TEST_F(CertProvisioningSchedulerTest, PlatformKeysServiceShutDown) { ...@@ -952,6 +981,95 @@ TEST_F(CertProvisioningSchedulerTest, PlatformKeysServiceShutDown) {
scheduler.UpdateAllCerts(); scheduler.UpdateAllCerts();
} }
TEST_F(CertProvisioningSchedulerTest, StateChangeNotifications) {
const CertScope kCertScope = CertScope::kDevice;
CertProvisioningSchedulerImpl scheduler(
kCertScope, GetProfile(), &pref_service_, &cloud_policy_client_,
&platform_keys_service_,
network_state_test_helper_.network_state_handler(),
MakeFakeInvalidationFactory());
TestCertProvisioningSchedulerObserver observer;
scheduler.AddObserver(&observer);
// From CertProvisioningSchedulerImpl::CleanVaKeysIfIdle.
EXPECT_CALL(fake_cryptohome_client_,
OnTpmAttestationDeleteKeysByPrefix(
attestation::AttestationKeyType::KEY_DEVICE, kKeyNamePrefix))
.Times(1);
// The policy is empty, so no workers should be created yet.
FastForwardBy(TimeDelta::FromSeconds(1));
ASSERT_EQ(scheduler.GetWorkers().size(), 0U);
// Two new workers will be created on prefs update.
// Expect a state change notification for this.
const char kCertProfileId0[] = "cert_profile_id_0";
const char kCertProfileVersion0[] = "cert_profile_version_0";
CertProfile cert_profile0(kCertProfileId0, kCertProfileVersion0,
/*is_va_enabled=*/true, kCertProfileRenewalPeriod);
const char kCertProfileId1[] = "cert_profile_id_1";
const char kCertProfileVersion1[] = "cert_profile_version_1";
CertProfile cert_profile1(kCertProfileId1, kCertProfileVersion1,
/*is_va_enabled=*/true, kCertProfileRenewalPeriod);
MockCertProvisioningWorker* worker0 =
mock_factory_.ExpectCreateReturnMock(kCertScope, cert_profile0);
worker0->SetExpectations(/*do_step_times=*/AtLeast(1), /*is_waiting=*/false,
cert_profile0);
MockCertProvisioningWorker* worker1 =
mock_factory_.ExpectCreateReturnMock(kCertScope, cert_profile1);
worker1->SetExpectations(/*do_step_times=*/AtLeast(1), /*is_waiting=*/false,
cert_profile1);
// Add 2 certificate profiles to the policy (the values are the same as
// in |cert_profile|-s)
base::Value config = ParseJson(
R"([{
"name": "Certificate Profile 0",
"cert_profile_id":"cert_profile_id_0",
"policy_version":"cert_profile_version_0",
"key_algorithm":"rsa"
},
{
"name": "Certificate Profile 1",
"cert_profile_id":"cert_profile_id_1",
"policy_version":"cert_profile_version_1",
"key_algorithm":"rsa"
}])");
pref_service_.Set(GetPrefNameForCertProfiles(kCertScope), config);
observer.WaitForOneCall();
// Now one worker for each profile should be created.
ASSERT_EQ(scheduler.GetWorkers().size(), 2U);
// Emulate a worker reporting a state change.
// A state change event should be fired by the scheduler for that.
scheduler.OnVisibleStateChanged();
observer.WaitForOneCall();
// Emulate a worker reporting a state changeand successfully finishing.
// Should be just deleted, and state change event should be
// fired for that.
scheduler.OnVisibleStateChanged();
scheduler.OnProfileFinished(cert_profile0,
CertProvisioningWorkerState::kSucceeded);
observer.WaitForOneCall();
// worker1 failed. Should be deleted and the profile id should be saved, and a
// state change event should be fired for that.
scheduler.OnProfileFinished(cert_profile1,
CertProvisioningWorkerState::kFailed);
observer.WaitForOneCall();
EXPECT_EQ(scheduler.GetWorkers().size(), 0U);
EXPECT_TRUE(
base::Contains(scheduler.GetFailedCertProfileIds(), kCertProfileId1));
scheduler.RemoveObserver(&observer);
}
} // namespace } // namespace
} // namespace cert_provisioning } // namespace cert_provisioning
} // namespace chromeos } // namespace chromeos
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include "base/base64.h" #include "base/base64.h"
#include "base/bind.h" #include "base/bind.h"
#include "base/bind_helpers.h" #include "base/bind_helpers.h"
#include "base/callback_forward.h"
#include "base/no_destructor.h" #include "base/no_destructor.h"
#include "base/optional.h" #include "base/optional.h"
#include "base/time/time.h" #include "base/time/time.h"
...@@ -135,11 +136,13 @@ std::unique_ptr<CertProvisioningWorker> CertProvisioningWorkerFactory::Create( ...@@ -135,11 +136,13 @@ std::unique_ptr<CertProvisioningWorker> CertProvisioningWorkerFactory::Create(
const CertProfile& cert_profile, const CertProfile& cert_profile,
policy::CloudPolicyClient* cloud_policy_client, policy::CloudPolicyClient* cloud_policy_client,
std::unique_ptr<CertProvisioningInvalidator> invalidator, std::unique_ptr<CertProvisioningInvalidator> invalidator,
CertProvisioningWorkerCallback callback) { base::RepeatingClosure state_change_callback,
CertProvisioningWorkerCallback result_callback) {
RecordEvent(cert_scope, CertProvisioningEvent::kWorkerCreated); RecordEvent(cert_scope, CertProvisioningEvent::kWorkerCreated);
return std::make_unique<CertProvisioningWorkerImpl>( return std::make_unique<CertProvisioningWorkerImpl>(
cert_scope, profile, pref_service, cert_profile, cloud_policy_client, cert_scope, profile, pref_service, cert_profile, cloud_policy_client,
std::move(invalidator), std::move(callback)); std::move(invalidator), std::move(state_change_callback),
std::move(result_callback));
} }
std::unique_ptr<CertProvisioningWorker> std::unique_ptr<CertProvisioningWorker>
...@@ -150,10 +153,12 @@ CertProvisioningWorkerFactory::Deserialize( ...@@ -150,10 +153,12 @@ CertProvisioningWorkerFactory::Deserialize(
const base::Value& saved_worker, const base::Value& saved_worker,
policy::CloudPolicyClient* cloud_policy_client, policy::CloudPolicyClient* cloud_policy_client,
std::unique_ptr<CertProvisioningInvalidator> invalidator, std::unique_ptr<CertProvisioningInvalidator> invalidator,
CertProvisioningWorkerCallback callback) { base::RepeatingClosure state_change_callback,
CertProvisioningWorkerCallback result_callback) {
auto worker = std::make_unique<CertProvisioningWorkerImpl>( auto worker = std::make_unique<CertProvisioningWorkerImpl>(
cert_scope, profile, pref_service, CertProfile(), cloud_policy_client, cert_scope, profile, pref_service, CertProfile(), cloud_policy_client,
std::move(invalidator), std::move(callback)); std::move(invalidator), std::move(state_change_callback),
std::move(result_callback));
if (!CertProvisioningSerializer::DeserializeWorker(saved_worker, if (!CertProvisioningSerializer::DeserializeWorker(saved_worker,
worker.get())) { worker.get())) {
RecordEvent(cert_scope, RecordEvent(cert_scope,
...@@ -179,12 +184,14 @@ CertProvisioningWorkerImpl::CertProvisioningWorkerImpl( ...@@ -179,12 +184,14 @@ CertProvisioningWorkerImpl::CertProvisioningWorkerImpl(
const CertProfile& cert_profile, const CertProfile& cert_profile,
policy::CloudPolicyClient* cloud_policy_client, policy::CloudPolicyClient* cloud_policy_client,
std::unique_ptr<CertProvisioningInvalidator> invalidator, std::unique_ptr<CertProvisioningInvalidator> invalidator,
CertProvisioningWorkerCallback callback) base::RepeatingClosure state_change_callback,
CertProvisioningWorkerCallback result_callback)
: cert_scope_(cert_scope), : cert_scope_(cert_scope),
profile_(profile), profile_(profile),
pref_service_(pref_service), pref_service_(pref_service),
cert_profile_(cert_profile), cert_profile_(cert_profile),
callback_(std::move(callback)), state_change_callback_(std::move(state_change_callback)),
result_callback_(std::move(result_callback)),
request_backoff_(&kBackoffPolicy), request_backoff_(&kBackoffPolicy),
cloud_policy_client_(cloud_policy_client), cloud_policy_client_(cloud_policy_client),
invalidator_(std::move(invalidator)) { invalidator_(std::move(invalidator)) {
...@@ -311,6 +318,7 @@ void CertProvisioningWorkerImpl::UpdateState( ...@@ -311,6 +318,7 @@ void CertProvisioningWorkerImpl::UpdateState(
HandleSerialization(); HandleSerialization();
state_change_callback_.Run();
if (IsFinalState(state_)) { if (IsFinalState(state_)) {
CleanUpAndRunCallback(); CleanUpAndRunCallback();
} }
...@@ -805,7 +813,7 @@ void CertProvisioningWorkerImpl::OnCleanUpDone() { ...@@ -805,7 +813,7 @@ void CertProvisioningWorkerImpl::OnCleanUpDone() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
RecordResult(cert_scope_, state_, prev_state_); RecordResult(cert_scope_, state_, prev_state_);
std::move(callback_).Run(cert_profile_, state_); std::move(result_callback_).Run(cert_profile_, state_);
} }
void CertProvisioningWorkerImpl::HandleSerialization() { void CertProvisioningWorkerImpl::HandleSerialization() {
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <memory> #include <memory>
#include "base/callback_forward.h"
#include "base/memory/weak_ptr.h" #include "base/memory/weak_ptr.h"
#include "base/sequence_checker.h" #include "base/sequence_checker.h"
#include "base/time/time.h" #include "base/time/time.h"
...@@ -28,6 +29,8 @@ namespace cert_provisioning { ...@@ -28,6 +29,8 @@ namespace cert_provisioning {
class CertProvisioningInvalidator; class CertProvisioningInvalidator;
// A OnceCallback that is invoked when the CertProvisioningWorker is done and
// has a result (which could be success or failure).
using CertProvisioningWorkerCallback = using CertProvisioningWorkerCallback =
base::OnceCallback<void(const CertProfile& profile, base::OnceCallback<void(const CertProfile& profile,
CertProvisioningWorkerState state)>; CertProvisioningWorkerState state)>;
...@@ -47,7 +50,8 @@ class CertProvisioningWorkerFactory { ...@@ -47,7 +50,8 @@ class CertProvisioningWorkerFactory {
const CertProfile& cert_profile, const CertProfile& cert_profile,
policy::CloudPolicyClient* cloud_policy_client, policy::CloudPolicyClient* cloud_policy_client,
std::unique_ptr<CertProvisioningInvalidator> invalidator, std::unique_ptr<CertProvisioningInvalidator> invalidator,
CertProvisioningWorkerCallback callback); base::RepeatingClosure state_change_callback,
CertProvisioningWorkerCallback result_callback);
virtual std::unique_ptr<CertProvisioningWorker> Deserialize( virtual std::unique_ptr<CertProvisioningWorker> Deserialize(
CertScope cert_scope, CertScope cert_scope,
...@@ -56,7 +60,8 @@ class CertProvisioningWorkerFactory { ...@@ -56,7 +60,8 @@ class CertProvisioningWorkerFactory {
const base::Value& saved_worker, const base::Value& saved_worker,
policy::CloudPolicyClient* cloud_policy_client, policy::CloudPolicyClient* cloud_policy_client,
std::unique_ptr<CertProvisioningInvalidator> invalidator, std::unique_ptr<CertProvisioningInvalidator> invalidator,
CertProvisioningWorkerCallback callback); base::RepeatingClosure state_change_callback,
CertProvisioningWorkerCallback result_callback);
// Doesn't take ownership. // Doesn't take ownership.
static void SetFactoryForTesting(CertProvisioningWorkerFactory* test_factory); static void SetFactoryForTesting(CertProvisioningWorkerFactory* test_factory);
...@@ -108,7 +113,8 @@ class CertProvisioningWorkerImpl : public CertProvisioningWorker { ...@@ -108,7 +113,8 @@ class CertProvisioningWorkerImpl : public CertProvisioningWorker {
const CertProfile& cert_profile, const CertProfile& cert_profile,
policy::CloudPolicyClient* cloud_policy_client, policy::CloudPolicyClient* cloud_policy_client,
std::unique_ptr<CertProvisioningInvalidator> invalidator, std::unique_ptr<CertProvisioningInvalidator> invalidator,
CertProvisioningWorkerCallback callback); base::RepeatingClosure state_change_callback,
CertProvisioningWorkerCallback result_callback);
~CertProvisioningWorkerImpl() override; ~CertProvisioningWorkerImpl() override;
// CertProvisioningWorker // CertProvisioningWorker
...@@ -219,7 +225,8 @@ class CertProvisioningWorkerImpl : public CertProvisioningWorker { ...@@ -219,7 +225,8 @@ class CertProvisioningWorkerImpl : public CertProvisioningWorker {
Profile* profile_ = nullptr; Profile* profile_ = nullptr;
PrefService* pref_service_ = nullptr; PrefService* pref_service_ = nullptr;
CertProfile cert_profile_; CertProfile cert_profile_;
CertProvisioningWorkerCallback callback_; base::RepeatingClosure state_change_callback_;
CertProvisioningWorkerCallback result_callback_;
// This field should be updated only via |UpdateState| function. It will // This field should be updated only via |UpdateState| function. It will
// trigger update of the serialized data. // trigger update of the serialized data.
......
...@@ -38,6 +38,7 @@ using base::test::ParseJson; ...@@ -38,6 +38,7 @@ using base::test::ParseJson;
using base::test::RunOnceCallback; using base::test::RunOnceCallback;
using chromeos::attestation::MockTpmChallengeKeySubtle; using chromeos::attestation::MockTpmChallengeKeySubtle;
using testing::_; using testing::_;
using testing::AtLeast;
using testing::Mock; using testing::Mock;
using testing::StrictMock; using testing::StrictMock;
...@@ -307,6 +308,7 @@ const std::string& GetPublicKey() { ...@@ -307,6 +308,7 @@ const std::string& GetPublicKey() {
.WillOnce(RunOnceCallback<2>(platform_keys::Status::kSuccess)); \ .WillOnce(RunOnceCallback<2>(platform_keys::Status::kSuccess)); \
} }
// A mock for observing the result callback of the worker.
class CallbackObserver { class CallbackObserver {
public: public:
MOCK_METHOD(void, MOCK_METHOD(void,
...@@ -314,6 +316,12 @@ class CallbackObserver { ...@@ -314,6 +316,12 @@ class CallbackObserver {
(const CertProfile& profile, CertProvisioningWorkerState state)); (const CertProfile& profile, CertProvisioningWorkerState state));
}; };
// A mock for observing the state change callback of the worker.
class StateChangeCallbackObserver {
public:
MOCK_METHOD(void, StateChangeCallback, ());
};
class CertProvisioningWorkerTest : public ::testing::Test { class CertProvisioningWorkerTest : public ::testing::Test {
public: public:
CertProvisioningWorkerTest() { Init(); } CertProvisioningWorkerTest() { Init(); }
...@@ -376,7 +384,13 @@ class CertProvisioningWorkerTest : public ::testing::Test { ...@@ -376,7 +384,13 @@ class CertProvisioningWorkerTest : public ::testing::Test {
return tpm_challenge_key_impl; return tpm_challenge_key_impl;
} }
CertProvisioningWorkerCallback GetCallback() { base::RepeatingClosure GetStateChangeCallback() {
return base::BindRepeating(
&StateChangeCallbackObserver ::StateChangeCallback,
base::Unretained(&state_change_callback_observer_));
}
CertProvisioningWorkerCallback GetResultCallback() {
return base::BindOnce(&CallbackObserver::Callback, return base::BindOnce(&CallbackObserver::Callback,
base::Unretained(&callback_observer_)); base::Unretained(&callback_observer_));
} }
...@@ -397,6 +411,7 @@ class CertProvisioningWorkerTest : public ::testing::Test { ...@@ -397,6 +411,7 @@ class CertProvisioningWorkerTest : public ::testing::Test {
content::BrowserTaskEnvironment task_environment_{ content::BrowserTaskEnvironment task_environment_{
base::test::TaskEnvironment::TimeSource::MOCK_TIME}; base::test::TaskEnvironment::TimeSource::MOCK_TIME};
StrictMock<StateChangeCallbackObserver> state_change_callback_observer_;
StrictMock<CallbackObserver> callback_observer_; StrictMock<CallbackObserver> callback_observer_;
StrictMock<SpyingFakeCryptohomeClient> fake_cryptohome_client_; StrictMock<SpyingFakeCryptohomeClient> fake_cryptohome_client_;
ProfileHelperForTesting profile_helper_for_testing_; ProfileHelperForTesting profile_helper_for_testing_;
...@@ -418,7 +433,8 @@ TEST_F(CertProvisioningWorkerTest, Success) { ...@@ -418,7 +433,8 @@ TEST_F(CertProvisioningWorkerTest, Success) {
MockCertProvisioningInvalidator* mock_invalidator = nullptr; MockCertProvisioningInvalidator* mock_invalidator = nullptr;
CertProvisioningWorkerImpl worker( CertProvisioningWorkerImpl worker(
CertScope::kUser, GetProfile(), &testing_pref_service_, cert_profile, CertScope::kUser, GetProfile(), &testing_pref_service_, cert_profile,
&cloud_policy_client_, MakeInvalidator(&mock_invalidator), GetCallback()); &cloud_policy_client_, MakeInvalidator(&mock_invalidator),
GetStateChangeCallback(), GetResultCallback());
{ {
testing::InSequence seq; testing::InSequence seq;
...@@ -430,31 +446,38 @@ TEST_F(CertProvisioningWorkerTest, Success) { ...@@ -430,31 +446,38 @@ TEST_F(CertProvisioningWorkerTest, Success) {
GetKeyName(kCertProfileId), GetKeyName(kCertProfileId),
/*profile=*/_, /*profile=*/_,
/*callback=*/_)); /*callback=*/_));
EXPECT_CALL(state_change_callback_observer_, StateChangeCallback());
EXPECT_START_CSR_OK(ClientCertProvisioningStartCsr( EXPECT_START_CSR_OK(ClientCertProvisioningStartCsr(
kCertScopeStrUser, kCertProfileId, kCertProfileVersion, GetPublicKey(), kCertScopeStrUser, kCertProfileId, kCertProfileVersion, GetPublicKey(),
/*callback=*/_)); /*callback=*/_));
EXPECT_CALL(state_change_callback_observer_, StateChangeCallback());
EXPECT_CALL(*mock_invalidator, Register(kInvalidationTopic, _)).Times(1); EXPECT_CALL(*mock_invalidator, Register(kInvalidationTopic, _)).Times(1);
EXPECT_SIGN_CHALLENGE_OK(*mock_tpm_challenge_key, EXPECT_SIGN_CHALLENGE_OK(*mock_tpm_challenge_key,
StartSignChallengeStep(kChallenge, StartSignChallengeStep(kChallenge,
/*callback=*/_)); /*callback=*/_));
EXPECT_CALL(state_change_callback_observer_, StateChangeCallback());
EXPECT_REGISTER_KEY_OK(*mock_tpm_challenge_key, StartRegisterKeyStep); EXPECT_REGISTER_KEY_OK(*mock_tpm_challenge_key, StartRegisterKeyStep);
EXPECT_CALL(state_change_callback_observer_, StateChangeCallback());
EXPECT_SET_ATTRIBUTE_FOR_KEY_OK(SetAttributeForKey( EXPECT_SET_ATTRIBUTE_FOR_KEY_OK(SetAttributeForKey(
platform_keys::TokenId::kUser, GetPublicKey(), platform_keys::TokenId::kUser, GetPublicKey(),
platform_keys::KeyAttributeType::kCertificateProvisioningId, platform_keys::KeyAttributeType::kCertificateProvisioningId,
kCertProfileId, _)); kCertProfileId, _));
EXPECT_CALL(state_change_callback_observer_, StateChangeCallback());
EXPECT_SIGN_RSAPKC1_DIGEST_OK(SignRSAPKCS1Digest( EXPECT_SIGN_RSAPKC1_DIGEST_OK(SignRSAPKCS1Digest(
::testing::Optional(platform_keys::TokenId::kUser), kDataToSign, ::testing::Optional(platform_keys::TokenId::kUser), kDataToSign,
GetPublicKey(), kPkHashAlgo, /*callback=*/_)); GetPublicKey(), kPkHashAlgo, /*callback=*/_));
EXPECT_CALL(state_change_callback_observer_, StateChangeCallback());
EXPECT_FINISH_CSR_OK(ClientCertProvisioningFinishCsr( EXPECT_FINISH_CSR_OK(ClientCertProvisioningFinishCsr(
kCertScopeStrUser, kCertProfileId, kCertProfileVersion, GetPublicKey(), kCertScopeStrUser, kCertProfileId, kCertProfileVersion, GetPublicKey(),
kChallengeResponse, kSignature, /*callback=*/_)); kChallengeResponse, kSignature, /*callback=*/_));
EXPECT_CALL(state_change_callback_observer_, StateChangeCallback());
EXPECT_DOWNLOAD_CERT_OK(ClientCertProvisioningDownloadCert( EXPECT_DOWNLOAD_CERT_OK(ClientCertProvisioningDownloadCert(
kCertScopeStrUser, kCertProfileId, kCertProfileVersion, GetPublicKey(), kCertScopeStrUser, kCertProfileId, kCertProfileVersion, GetPublicKey(),
...@@ -462,6 +485,7 @@ TEST_F(CertProvisioningWorkerTest, Success) { ...@@ -462,6 +485,7 @@ TEST_F(CertProvisioningWorkerTest, Success) {
EXPECT_IMPORT_CERTIFICATE_OK(ImportCertificate( EXPECT_IMPORT_CERTIFICATE_OK(ImportCertificate(
platform_keys::TokenId::kUser, /*certificate=*/_, /*callback=*/_)); platform_keys::TokenId::kUser, /*certificate=*/_, /*callback=*/_));
EXPECT_CALL(state_change_callback_observer_, StateChangeCallback());
EXPECT_CALL(*mock_invalidator, Unregister()).Times(1); EXPECT_CALL(*mock_invalidator, Unregister()).Times(1);
...@@ -493,8 +517,11 @@ TEST_F(CertProvisioningWorkerTest, NoVaSuccess) { ...@@ -493,8 +517,11 @@ TEST_F(CertProvisioningWorkerTest, NoVaSuccess) {
CertProvisioningWorkerImpl worker( CertProvisioningWorkerImpl worker(
CertScope::kUser, GetProfile(), &testing_pref_service_, cert_profile, CertScope::kUser, GetProfile(), &testing_pref_service_, cert_profile,
&cloud_policy_client_, MakeInvalidator(), GetCallback()); &cloud_policy_client_, MakeInvalidator(), GetStateChangeCallback(),
GetResultCallback());
EXPECT_CALL(state_change_callback_observer_, StateChangeCallback)
.Times(AtLeast(1));
{ {
testing::InSequence seq; testing::InSequence seq;
...@@ -546,9 +573,12 @@ TEST_F(CertProvisioningWorkerTest, TryLaterManualRetry) { ...@@ -546,9 +573,12 @@ TEST_F(CertProvisioningWorkerTest, TryLaterManualRetry) {
MockTpmChallengeKeySubtle* mock_tpm_challenge_key = PrepareTpmChallengeKey(); MockTpmChallengeKeySubtle* mock_tpm_challenge_key = PrepareTpmChallengeKey();
CertProvisioningWorkerImpl worker( CertProvisioningWorkerImpl worker(
CertScope::kDevice, GetProfile(), &testing_pref_service_, cert_profile, CertScope::kDevice, GetProfile(), &testing_pref_service_, cert_profile,
&cloud_policy_client_, MakeInvalidator(), GetCallback()); &cloud_policy_client_, MakeInvalidator(), GetStateChangeCallback(),
GetResultCallback());
const TimeDelta delay = TimeDelta::FromSeconds(30); const TimeDelta delay = TimeDelta::FromSeconds(30);
EXPECT_CALL(state_change_callback_observer_, StateChangeCallback)
.Times(AtLeast(1));
{ {
testing::InSequence seq; testing::InSequence seq;
...@@ -649,7 +679,8 @@ TEST_F(CertProvisioningWorkerTest, TryLaterWait) { ...@@ -649,7 +679,8 @@ TEST_F(CertProvisioningWorkerTest, TryLaterWait) {
MockTpmChallengeKeySubtle* mock_tpm_challenge_key = PrepareTpmChallengeKey(); MockTpmChallengeKeySubtle* mock_tpm_challenge_key = PrepareTpmChallengeKey();
CertProvisioningWorkerImpl worker( CertProvisioningWorkerImpl worker(
CertScope::kUser, GetProfile(), &testing_pref_service_, cert_profile, CertScope::kUser, GetProfile(), &testing_pref_service_, cert_profile,
&cloud_policy_client_, MakeInvalidator(), GetCallback()); &cloud_policy_client_, MakeInvalidator(), GetStateChangeCallback(),
GetResultCallback());
const TimeDelta start_csr_delay = TimeDelta::FromSeconds(30); const TimeDelta start_csr_delay = TimeDelta::FromSeconds(30);
const TimeDelta finish_csr_delay = TimeDelta::FromSeconds(30); const TimeDelta finish_csr_delay = TimeDelta::FromSeconds(30);
...@@ -657,6 +688,8 @@ TEST_F(CertProvisioningWorkerTest, TryLaterWait) { ...@@ -657,6 +688,8 @@ TEST_F(CertProvisioningWorkerTest, TryLaterWait) {
const TimeDelta download_cert_real_delay = TimeDelta::FromSeconds(10); const TimeDelta download_cert_real_delay = TimeDelta::FromSeconds(10);
const TimeDelta small_delay = TimeDelta::FromMilliseconds(500); const TimeDelta small_delay = TimeDelta::FromMilliseconds(500);
EXPECT_CALL(state_change_callback_observer_, StateChangeCallback)
.Times(AtLeast(1));
{ {
testing::InSequence seq; testing::InSequence seq;
...@@ -760,8 +793,11 @@ TEST_F(CertProvisioningWorkerTest, StatusErrorHandling) { ...@@ -760,8 +793,11 @@ TEST_F(CertProvisioningWorkerTest, StatusErrorHandling) {
MockTpmChallengeKeySubtle* mock_tpm_challenge_key = PrepareTpmChallengeKey(); MockTpmChallengeKeySubtle* mock_tpm_challenge_key = PrepareTpmChallengeKey();
CertProvisioningWorkerImpl worker( CertProvisioningWorkerImpl worker(
CertScope::kUser, GetProfile(), &testing_pref_service_, cert_profile, CertScope::kUser, GetProfile(), &testing_pref_service_, cert_profile,
&cloud_policy_client_, MakeInvalidator(), GetCallback()); &cloud_policy_client_, MakeInvalidator(), GetStateChangeCallback(),
GetResultCallback());
EXPECT_CALL(state_change_callback_observer_, StateChangeCallback)
.Times(AtLeast(1));
{ {
testing::InSequence seq; testing::InSequence seq;
...@@ -803,8 +839,11 @@ TEST_F(CertProvisioningWorkerTest, ResponseErrorHandling) { ...@@ -803,8 +839,11 @@ TEST_F(CertProvisioningWorkerTest, ResponseErrorHandling) {
MockTpmChallengeKeySubtle* mock_tpm_challenge_key = PrepareTpmChallengeKey(); MockTpmChallengeKeySubtle* mock_tpm_challenge_key = PrepareTpmChallengeKey();
auto worker = CertProvisioningWorkerFactory::Get()->Create( auto worker = CertProvisioningWorkerFactory::Get()->Create(
CertScope::kUser, GetProfile(), &testing_pref_service_, cert_profile, CertScope::kUser, GetProfile(), &testing_pref_service_, cert_profile,
&cloud_policy_client_, MakeInvalidator(), GetCallback()); &cloud_policy_client_, MakeInvalidator(), GetStateChangeCallback(),
GetResultCallback());
EXPECT_CALL(state_change_callback_observer_, StateChangeCallback)
.Times(AtLeast(1));
{ {
testing::InSequence seq; testing::InSequence seq;
...@@ -847,8 +886,11 @@ TEST_F(CertProvisioningWorkerTest, InconsistentDataErrorHandling) { ...@@ -847,8 +886,11 @@ TEST_F(CertProvisioningWorkerTest, InconsistentDataErrorHandling) {
MockTpmChallengeKeySubtle* mock_tpm_challenge_key = PrepareTpmChallengeKey(); MockTpmChallengeKeySubtle* mock_tpm_challenge_key = PrepareTpmChallengeKey();
auto worker = CertProvisioningWorkerFactory::Get()->Create( auto worker = CertProvisioningWorkerFactory::Get()->Create(
CertScope::kUser, GetProfile(), &testing_pref_service_, cert_profile, CertScope::kUser, GetProfile(), &testing_pref_service_, cert_profile,
&cloud_policy_client_, MakeInvalidator(), GetCallback()); &cloud_policy_client_, MakeInvalidator(), GetStateChangeCallback(),
GetResultCallback());
EXPECT_CALL(state_change_callback_observer_, StateChangeCallback)
.Times(AtLeast(1));
{ {
testing::InSequence seq; testing::InSequence seq;
...@@ -887,11 +929,14 @@ TEST_F(CertProvisioningWorkerTest, BackoffStrategy) { ...@@ -887,11 +929,14 @@ TEST_F(CertProvisioningWorkerTest, BackoffStrategy) {
MockTpmChallengeKeySubtle* mock_tpm_challenge_key = PrepareTpmChallengeKey(); MockTpmChallengeKeySubtle* mock_tpm_challenge_key = PrepareTpmChallengeKey();
CertProvisioningWorkerImpl worker( CertProvisioningWorkerImpl worker(
CertScope::kUser, GetProfile(), &testing_pref_service_, cert_profile, CertScope::kUser, GetProfile(), &testing_pref_service_, cert_profile,
&cloud_policy_client_, MakeInvalidator(), GetCallback()); &cloud_policy_client_, MakeInvalidator(), GetStateChangeCallback(),
GetResultCallback());
TimeDelta next_delay = TimeDelta::FromSeconds(30); TimeDelta next_delay = TimeDelta::FromSeconds(30);
const TimeDelta small_delay = TimeDelta::FromMilliseconds(500); const TimeDelta small_delay = TimeDelta::FromMilliseconds(500);
EXPECT_CALL(state_change_callback_observer_, StateChangeCallback)
.Times(AtLeast(1));
{ {
testing::InSequence seq; testing::InSequence seq;
...@@ -951,8 +996,11 @@ TEST_F(CertProvisioningWorkerTest, RemoveRegisteredKey) { ...@@ -951,8 +996,11 @@ TEST_F(CertProvisioningWorkerTest, RemoveRegisteredKey) {
MockCertProvisioningInvalidator* mock_invalidator = nullptr; MockCertProvisioningInvalidator* mock_invalidator = nullptr;
CertProvisioningWorkerImpl worker( CertProvisioningWorkerImpl worker(
CertScope::kUser, GetProfile(), &testing_pref_service_, cert_profile, CertScope::kUser, GetProfile(), &testing_pref_service_, cert_profile,
&cloud_policy_client_, MakeInvalidator(&mock_invalidator), GetCallback()); &cloud_policy_client_, MakeInvalidator(&mock_invalidator),
GetStateChangeCallback(), GetResultCallback());
EXPECT_CALL(state_change_callback_observer_, StateChangeCallback)
.Times(AtLeast(1));
{ {
testing::InSequence seq; testing::InSequence seq;
...@@ -1045,12 +1093,16 @@ TEST_F(CertProvisioningWorkerTest, SerializationSuccess) { ...@@ -1045,12 +1093,16 @@ TEST_F(CertProvisioningWorkerTest, SerializationSuccess) {
std::unique_ptr<CertProvisioningWorker> worker = std::unique_ptr<CertProvisioningWorker> worker =
CertProvisioningWorkerFactory::Get()->Create( CertProvisioningWorkerFactory::Get()->Create(
kCertScope, GetProfile(), &testing_pref_service_, cert_profile, kCertScope, GetProfile(), &testing_pref_service_, cert_profile,
&cloud_policy_client_, MakeInvalidator(), GetCallback()); &cloud_policy_client_, MakeInvalidator(), GetStateChangeCallback(),
GetResultCallback());
StrictMock<PrefServiceObserver> pref_observer( StrictMock<PrefServiceObserver> pref_observer(
&testing_pref_service_, GetPrefNameForSerialization(CertScope::kUser)); &testing_pref_service_, GetPrefNameForSerialization(CertScope::kUser));
base::Value pref_val; base::Value pref_val;
EXPECT_CALL(state_change_callback_observer_, StateChangeCallback)
.Times(AtLeast(1));
// Prepare key, send start csr request. // Prepare key, send start csr request.
{ {
testing::InSequence seq; testing::InSequence seq;
...@@ -1105,7 +1157,7 @@ TEST_F(CertProvisioningWorkerTest, SerializationSuccess) { ...@@ -1105,7 +1157,7 @@ TEST_F(CertProvisioningWorkerTest, SerializationSuccess) {
kCertScope, GetProfile(), &testing_pref_service_, kCertScope, GetProfile(), &testing_pref_service_,
*pref_val.FindKeyOfType(kCertProfileId, base::Value::Type::DICTIONARY), *pref_val.FindKeyOfType(kCertProfileId, base::Value::Type::DICTIONARY),
&cloud_policy_client_, MakeInvalidator(&mock_invalidator), &cloud_policy_client_, MakeInvalidator(&mock_invalidator),
GetCallback()); GetStateChangeCallback(), GetResultCallback());
} }
// Retry start csr request, receive response, try sign challenge. // Retry start csr request, receive response, try sign challenge.
...@@ -1183,7 +1235,8 @@ TEST_F(CertProvisioningWorkerTest, SerializationSuccess) { ...@@ -1183,7 +1235,8 @@ TEST_F(CertProvisioningWorkerTest, SerializationSuccess) {
worker = CertProvisioningWorkerFactory::Get()->Deserialize( worker = CertProvisioningWorkerFactory::Get()->Deserialize(
kCertScope, GetProfile(), &testing_pref_service_, kCertScope, GetProfile(), &testing_pref_service_,
*pref_val.FindKeyOfType(kCertProfileId, base::Value::Type::DICTIONARY), *pref_val.FindKeyOfType(kCertProfileId, base::Value::Type::DICTIONARY),
&cloud_policy_client_, std::move(mock_invalidator_obj), GetCallback()); &cloud_policy_client_, std::move(mock_invalidator_obj),
GetStateChangeCallback(), GetResultCallback());
} }
// Retry download cert request, receive response, try import certificate. // Retry download cert request, receive response, try import certificate.
...@@ -1216,12 +1269,15 @@ TEST_F(CertProvisioningWorkerTest, SerializationOnFailure) { ...@@ -1216,12 +1269,15 @@ TEST_F(CertProvisioningWorkerTest, SerializationOnFailure) {
MockTpmChallengeKeySubtle* mock_tpm_challenge_key = PrepareTpmChallengeKey(); MockTpmChallengeKeySubtle* mock_tpm_challenge_key = PrepareTpmChallengeKey();
auto worker = CertProvisioningWorkerFactory::Get()->Create( auto worker = CertProvisioningWorkerFactory::Get()->Create(
CertScope::kUser, GetProfile(), &testing_pref_service_, cert_profile, CertScope::kUser, GetProfile(), &testing_pref_service_, cert_profile,
&cloud_policy_client_, MakeInvalidator(), GetCallback()); &cloud_policy_client_, MakeInvalidator(), GetStateChangeCallback(),
GetResultCallback());
PrefServiceObserver pref_observer( PrefServiceObserver pref_observer(
&testing_pref_service_, GetPrefNameForSerialization(CertScope::kUser)); &testing_pref_service_, GetPrefNameForSerialization(CertScope::kUser));
base::Value pref_val; base::Value pref_val;
EXPECT_CALL(state_change_callback_observer_, StateChangeCallback)
.Times(AtLeast(1));
{ {
testing::InSequence seq; testing::InSequence seq;
...@@ -1277,8 +1333,11 @@ TEST_F(CertProvisioningWorkerTest, InformationalGetters) { ...@@ -1277,8 +1333,11 @@ TEST_F(CertProvisioningWorkerTest, InformationalGetters) {
MockTpmChallengeKeySubtle* mock_tpm_challenge_key = PrepareTpmChallengeKey(); MockTpmChallengeKeySubtle* mock_tpm_challenge_key = PrepareTpmChallengeKey();
CertProvisioningWorkerImpl worker( CertProvisioningWorkerImpl worker(
CertScope::kUser, GetProfile(), &testing_pref_service_, cert_profile, CertScope::kUser, GetProfile(), &testing_pref_service_, cert_profile,
&cloud_policy_client_, MakeInvalidator(), GetCallback()); &cloud_policy_client_, MakeInvalidator(), GetStateChangeCallback(),
GetResultCallback());
EXPECT_CALL(state_change_callback_observer_, StateChangeCallback)
.Times(AtLeast(1));
{ {
testing::InSequence seq; testing::InSequence seq;
...@@ -1329,10 +1388,13 @@ TEST_F(CertProvisioningWorkerTest, CancelDeviceWorker) { ...@@ -1329,10 +1388,13 @@ TEST_F(CertProvisioningWorkerTest, CancelDeviceWorker) {
CertProfile cert_profile(kCertProfileId, kCertProfileVersion, CertProfile cert_profile(kCertProfileId, kCertProfileVersion,
/*is_va_enabled=*/true, kCertProfileRenewalPeriod); /*is_va_enabled=*/true, kCertProfileRenewalPeriod);
EXPECT_CALL(state_change_callback_observer_, StateChangeCallback)
.Times(AtLeast(1));
MockTpmChallengeKeySubtle* mock_tpm_challenge_key = PrepareTpmChallengeKey(); MockTpmChallengeKeySubtle* mock_tpm_challenge_key = PrepareTpmChallengeKey();
auto worker = CertProvisioningWorkerFactory::Get()->Create( auto worker = CertProvisioningWorkerFactory::Get()->Create(
kCertScope, GetProfile(), &testing_pref_service_, cert_profile, kCertScope, GetProfile(), &testing_pref_service_, cert_profile,
&cloud_policy_client_, MakeInvalidator(), GetCallback()); &cloud_policy_client_, MakeInvalidator(), GetStateChangeCallback(),
GetResultCallback());
EXPECT_CALL(callback_observer_, Callback).Times(0); EXPECT_CALL(callback_observer_, Callback).Times(0);
......
...@@ -25,11 +25,18 @@ class MockCertProvisioningScheduler : public CertProvisioningScheduler { ...@@ -25,11 +25,18 @@ class MockCertProvisioningScheduler : public CertProvisioningScheduler {
(override)); (override));
MOCK_METHOD(void, UpdateAllCerts, (), (override)); MOCK_METHOD(void, UpdateAllCerts, (), (override));
MOCK_METHOD(const WorkerMap&, GetWorkers, (), (const override)); MOCK_METHOD(const WorkerMap&, GetWorkers, (), (const override));
MOCK_METHOD((const base::flat_map<CertProfileId, FailedWorkerInfo>&), MOCK_METHOD((const base::flat_map<CertProfileId, FailedWorkerInfo>&),
GetFailedCertProfileIds, GetFailedCertProfileIds,
(), (),
(const override)); (const override));
MOCK_METHOD(void,
AddObserver,
(CertProvisioningSchedulerObserver*),
(override));
MOCK_METHOD(void,
RemoveObserver,
(CertProvisioningSchedulerObserver*),
(override));
}; };
} // namespace cert_provisioning } // namespace cert_provisioning
......
...@@ -28,7 +28,7 @@ MockCertProvisioningWorkerFactory::ExpectCreateReturnMock( ...@@ -28,7 +28,7 @@ MockCertProvisioningWorkerFactory::ExpectCreateReturnMock(
auto mock_worker = std::make_unique<MockCertProvisioningWorker>(); auto mock_worker = std::make_unique<MockCertProvisioningWorker>();
MockCertProvisioningWorker* pointer = mock_worker.get(); MockCertProvisioningWorker* pointer = mock_worker.get();
EXPECT_CALL(*this, Create(cert_scope, _, _, cert_profile, _, _, _)) EXPECT_CALL(*this, Create(cert_scope, _, _, cert_profile, _, _, _, _))
.Times(1) .Times(1)
.WillOnce(Return(testing::ByMove(std::move(mock_worker)))); .WillOnce(Return(testing::ByMove(std::move(mock_worker))));
...@@ -43,7 +43,7 @@ MockCertProvisioningWorkerFactory::ExpectDeserializeReturnMock( ...@@ -43,7 +43,7 @@ MockCertProvisioningWorkerFactory::ExpectDeserializeReturnMock(
MockCertProvisioningWorker* pointer = mock_worker.get(); MockCertProvisioningWorker* pointer = mock_worker.get();
EXPECT_CALL(*this, EXPECT_CALL(*this,
Deserialize(cert_scope, _, _, IsJson(saved_worker), _, _, _)) Deserialize(cert_scope, _, _, IsJson(saved_worker), _, _, _, _))
.Times(1) .Times(1)
.WillOnce(Return(testing::ByMove(std::move(mock_worker)))); .WillOnce(Return(testing::ByMove(std::move(mock_worker))));
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#ifndef CHROME_BROWSER_CHROMEOS_CERT_PROVISIONING_MOCK_CERT_PROVISIONING_WORKER_H_ #ifndef CHROME_BROWSER_CHROMEOS_CERT_PROVISIONING_MOCK_CERT_PROVISIONING_WORKER_H_
#define CHROME_BROWSER_CHROMEOS_CERT_PROVISIONING_MOCK_CERT_PROVISIONING_WORKER_H_ #define CHROME_BROWSER_CHROMEOS_CERT_PROVISIONING_MOCK_CERT_PROVISIONING_WORKER_H_
#include "base/callback_forward.h"
#include "base/containers/queue.h" #include "base/containers/queue.h"
#include "chrome/browser/chromeos/cert_provisioning/cert_provisioning_worker.h" #include "chrome/browser/chromeos/cert_provisioning/cert_provisioning_worker.h"
#include "chrome/browser/chromeos/cert_provisioning/mock_cert_provisioning_invalidator.h" #include "chrome/browser/chromeos/cert_provisioning/mock_cert_provisioning_invalidator.h"
...@@ -33,6 +34,7 @@ class MockCertProvisioningWorkerFactory : public CertProvisioningWorkerFactory { ...@@ -33,6 +34,7 @@ class MockCertProvisioningWorkerFactory : public CertProvisioningWorkerFactory {
const CertProfile& cert_profile, const CertProfile& cert_profile,
policy::CloudPolicyClient* cloud_policy_client, policy::CloudPolicyClient* cloud_policy_client,
std::unique_ptr<CertProvisioningInvalidator> invalidator, std::unique_ptr<CertProvisioningInvalidator> invalidator,
base::RepeatingClosure state_change_callback,
CertProvisioningWorkerCallback callback), CertProvisioningWorkerCallback callback),
(override)); (override));
...@@ -44,6 +46,7 @@ class MockCertProvisioningWorkerFactory : public CertProvisioningWorkerFactory { ...@@ -44,6 +46,7 @@ class MockCertProvisioningWorkerFactory : public CertProvisioningWorkerFactory {
const base::Value& saved_worker, const base::Value& saved_worker,
policy::CloudPolicyClient* cloud_policy_client, policy::CloudPolicyClient* cloud_policy_client,
std::unique_ptr<CertProvisioningInvalidator> invalidator, std::unique_ptr<CertProvisioningInvalidator> invalidator,
base::RepeatingClosure state_change_callback,
CertProvisioningWorkerCallback callback), CertProvisioningWorkerCallback callback),
(override)); (override));
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment