Commit 322f39a9 authored by Pavol Marko's avatar Pavol Marko Committed by Commit Bot

Cert Provisioning: Report state changes from worker/scheduler

Add CertProvisioningSchedulerObserver which can be used to get
notifications about state changes regarding the set of workers
/ the set of failed cert profile ids / states of individual workers.

The CertProvisioningWorker gets a StateChangeCallback to be
able to report state changes back to the owner (the
CertProvisioningScheduler).

Bug: 1081396
Test: unit_tests
Change-Id: Ib41f2913edbefdffb250aba417dca5ace9e85dbd
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2339997
Commit-Queue: Pavol Marko <pmarko@chromium.org>
Reviewed-by: default avatarMichael Ershov <miersh@google.com>
Cr-Commit-Position: refs/heads/master@{#804873}
parent 34be7d92
......@@ -14,6 +14,8 @@
#include "base/containers/flat_set.h"
#include "base/location.h"
#include "base/logging.h"
#include "base/observer_list.h"
#include "base/observer_list_types.h"
#include "base/optional.h"
#include "base/stl_util.h"
#include "base/time/time.h"
......@@ -334,6 +336,9 @@ void CertProvisioningSchedulerImpl::DeserializeWorkers() {
CertProvisioningWorkerFactory::Get()->Deserialize(
cert_scope_, profile_, pref_service_, saved_worker,
cloud_policy_client_, invalidator_factory_->Create(),
base::BindRepeating(
&CertProvisioningSchedulerImpl::OnVisibleStateChanged,
weak_factory_.GetWeakPtr()),
base::BindOnce(&CertProvisioningSchedulerImpl::OnProfileFinished,
weak_factory_.GetWeakPtr()));
if (!worker) {
......@@ -341,7 +346,7 @@ void CertProvisioningSchedulerImpl::DeserializeWorkers() {
continue;
}
workers_[worker->GetCertProfile().profile_id] = std::move(worker);
AddWorkerToMap(std::move(worker));
}
}
......@@ -501,10 +506,12 @@ void CertProvisioningSchedulerImpl::CreateCertProvisioningWorker(
CertProvisioningWorkerFactory::Get()->Create(
cert_scope_, profile_, pref_service_, cert_profile,
cloud_policy_client_, invalidator_factory_->Create(),
base::BindRepeating(
&CertProvisioningSchedulerImpl::OnVisibleStateChanged,
weak_factory_.GetWeakPtr()),
base::BindOnce(&CertProvisioningSchedulerImpl::OnProfileFinished,
weak_factory_.GetWeakPtr()));
CertProvisioningWorker* worker_unowned = worker.get();
workers_[cert_profile.profile_id] = std::move(worker);
CertProvisioningWorker* worker_unowned = AddWorkerToMap(std::move(worker));
worker_unowned->DoStep();
}
......@@ -539,7 +546,7 @@ void CertProvisioningSchedulerImpl::OnProfileFinished(
break;
}
workers_.erase(worker_iter);
RemoveWorkerFromMap(worker_iter);
}
CertProvisioningWorker* CertProvisioningSchedulerImpl::FindWorker(
......@@ -554,6 +561,20 @@ CertProvisioningWorker* CertProvisioningSchedulerImpl::FindWorker(
return iter->second.get();
}
CertProvisioningWorker* CertProvisioningSchedulerImpl::AddWorkerToMap(
std::unique_ptr<CertProvisioningWorker> worker) {
CertProvisioningWorker* worker_unowned = worker.get();
workers_[worker_unowned->GetCertProfile().profile_id] = std::move(worker);
OnVisibleStateChanged();
return worker_unowned;
}
void CertProvisioningSchedulerImpl::RemoveWorkerFromMap(
WorkerMap::iterator worker_iter) {
workers_.erase(worker_iter);
OnVisibleStateChanged();
}
base::Optional<CertProfile> CertProvisioningSchedulerImpl::GetOneCertProfile(
const CertProfileId& cert_profile_id) {
DCHECK_CURRENTLY_ON(content::BrowserThread::UI);
......@@ -612,6 +633,16 @@ CertProvisioningSchedulerImpl::GetFailedCertProfileIds() const {
return failed_cert_profiles_;
}
void CertProvisioningSchedulerImpl::AddObserver(
CertProvisioningSchedulerObserver* observer) {
observers_.AddObserver(observer);
}
void CertProvisioningSchedulerImpl::RemoveObserver(
CertProvisioningSchedulerObserver* observer) {
observers_.RemoveObserver(observer);
}
bool CertProvisioningSchedulerImpl::MaybeWaitForInternetConnection() {
const NetworkState* network = network_state_handler_->DefaultNetwork();
bool is_online = network && network->IsOnline();
......@@ -729,5 +760,26 @@ void CertProvisioningSchedulerImpl::CancelWorkersWithoutPolicy(
}
}
void CertProvisioningSchedulerImpl::OnVisibleStateChanged() {
DCHECK_CURRENTLY_ON(content::BrowserThread::UI);
if (notify_observers_pending_) {
return;
}
notify_observers_pending_ = true;
base::SequencedTaskRunnerHandle::Get()->PostTask(
FROM_HERE,
base::Bind(
&CertProvisioningSchedulerImpl::NotifyObserversVisibleStateChanged,
weak_factory_.GetWeakPtr()));
}
void CertProvisioningSchedulerImpl::NotifyObserversVisibleStateChanged() {
DCHECK_CURRENTLY_ON(content::BrowserThread::UI);
notify_observers_pending_ = false;
for (auto& observer : observers_) {
observer.OnVisibleStateChanged();
}
}
} // namespace cert_provisioning
} // namespace chromeos
......@@ -55,6 +55,18 @@ struct FailedWorkerInfo {
base::Time last_update_time;
};
// An observer that gets notified about state changes of the
// CertProvisioningScheduler.
class CertProvisioningSchedulerObserver : public base::CheckedObserver {
public:
// Called when the "visible state" of the observerd CertProvisioningScheduler
// has changed, i.e. when:
// (*) the list of active workers changed,
// (*) the list of recently failed workers changed,
// (*) the state of a worker changed.
virtual void OnVisibleStateChanged() = 0;
};
// Interface for the scheduler for client certificate provisioning using device
// management.
class CertProvisioningScheduler {
......@@ -73,6 +85,12 @@ class CertProvisioningScheduler {
// failed and have not been restarted (yet).
virtual const base::flat_map<CertProfileId, FailedWorkerInfo>&
GetFailedCertProfileIds() const = 0;
// Adds |observer| which will observer this CertProvisioningScheduler.
virtual void AddObserver(CertProvisioningSchedulerObserver* observer) = 0;
// Removes a previously added |observer|.
virtual void RemoveObserver(CertProvisioningSchedulerObserver* observer) = 0;
};
// This class is a part of certificate provisioning feature. It tracks updates
......@@ -113,10 +131,19 @@ class CertProvisioningSchedulerImpl
const WorkerMap& GetWorkers() const override;
const base::flat_map<CertProfileId, FailedWorkerInfo>&
GetFailedCertProfileIds() const override;
void AddObserver(CertProvisioningSchedulerObserver* observer) override;
void RemoveObserver(CertProvisioningSchedulerObserver* observer) override;
// Invoked when the CertProvisioningWorker corresponding to |profile| reached
// its final state.
// Public so it can be called from tests.
void OnProfileFinished(const CertProfile& profile,
CertProvisioningWorkerState state);
// Called when any state visible from the outside has changed.
// Public so it can be called from tests.
void OnVisibleStateChanged();
private:
void ScheduleInitialUpdate();
void ScheduleDailyUpdate();
......@@ -156,6 +183,13 @@ class CertProvisioningSchedulerImpl
void CreateCertProvisioningWorker(CertProfile profile);
CertProvisioningWorker* FindWorker(CertProfileId profile_id);
// Adds |worker| to |workers_| and returns an unowned pointer to |worker|.
// Triggers a state change notification.
CertProvisioningWorker* AddWorkerToMap(
std::unique_ptr<CertProvisioningWorker> worker);
// Removes the element referenced by |worker_iter| from |workers_|.
// Triggers a state change notification.
void RemoveWorkerFromMap(WorkerMap::iterator worker_iter);
// Returns true if the process can be continued (if it's not required to
// wait).
......@@ -171,6 +205,9 @@ class CertProvisioningSchedulerImpl
// PlatformKeysServiceObserver
void OnPlatformKeysServiceShutDown() override;
// Notifies each observer from |observers_| that the state has changed.
void NotifyObserversVisibleStateChanged();
CertScope cert_scope_ = CertScope::kUser;
// |profile_| can be nullptr for the device-wide instance of
// CertProvisioningScheduler.
......@@ -204,6 +241,12 @@ class CertProvisioningSchedulerImpl
CertDeleter cert_deleter_;
std::unique_ptr<CertProvisioningInvalidatorFactory> invalidator_factory_;
// Observers that are observing this CertProvisioningSchedulerImpl.
base::ObserverList<CertProvisioningSchedulerObserver> observers_;
// True when a task for notifying observers about a state change has been
// scheduled but not executed yet.
bool notify_observers_pending_ = false;
ScopedObserver<platform_keys::PlatformKeysService,
platform_keys::PlatformKeysServiceObserver>
scoped_platform_keys_service_observer_{this};
......
......@@ -45,7 +45,36 @@ constexpr char kCertProfileId[] = "cert_profile_id_1";
constexpr char kCertProfileVersion[] = "cert_profile_version_1";
constexpr TimeDelta kCertProfileRenewalPeriod = TimeDelta::FromSeconds(0);
//================ CertProvisioningSchedulerTest ===============================
//=============== TestCertProvisioningSchedulerObserver ========================
class TestCertProvisioningSchedulerObserver
: public CertProvisioningSchedulerObserver {
public:
TestCertProvisioningSchedulerObserver() = default;
~TestCertProvisioningSchedulerObserver() override = default;
TestCertProvisioningSchedulerObserver(
const TestCertProvisioningSchedulerObserver& other) = delete;
TestCertProvisioningSchedulerObserver& operator=(
const TestCertProvisioningSchedulerObserver& other) = delete;
// CertProvisioningSchedulerObserver:
void OnVisibleStateChanged() override { run_loop_->Quit(); }
// Waits for one call to happen (since construction or since the previous
// WaitForOneCall has returned).
void WaitForOneCall() {
run_loop_->Run();
// Create a new RunLoop so it can already be terminated when the next
// OnVisibleStateChanged() call comes in.
run_loop_ = std::make_unique<base::RunLoop>();
}
private:
std::unique_ptr<base::RunLoop> run_loop_ = std::make_unique<base::RunLoop>();
};
//=================== CertProvisioningSchedulerTest ============================
class CertProvisioningSchedulerTest : public testing::Test {
public:
......@@ -952,6 +981,95 @@ TEST_F(CertProvisioningSchedulerTest, PlatformKeysServiceShutDown) {
scheduler.UpdateAllCerts();
}
TEST_F(CertProvisioningSchedulerTest, StateChangeNotifications) {
const CertScope kCertScope = CertScope::kDevice;
CertProvisioningSchedulerImpl scheduler(
kCertScope, GetProfile(), &pref_service_, &cloud_policy_client_,
&platform_keys_service_,
network_state_test_helper_.network_state_handler(),
MakeFakeInvalidationFactory());
TestCertProvisioningSchedulerObserver observer;
scheduler.AddObserver(&observer);
// From CertProvisioningSchedulerImpl::CleanVaKeysIfIdle.
EXPECT_CALL(fake_cryptohome_client_,
OnTpmAttestationDeleteKeysByPrefix(
attestation::AttestationKeyType::KEY_DEVICE, kKeyNamePrefix))
.Times(1);
// The policy is empty, so no workers should be created yet.
FastForwardBy(TimeDelta::FromSeconds(1));
ASSERT_EQ(scheduler.GetWorkers().size(), 0U);
// Two new workers will be created on prefs update.
// Expect a state change notification for this.
const char kCertProfileId0[] = "cert_profile_id_0";
const char kCertProfileVersion0[] = "cert_profile_version_0";
CertProfile cert_profile0(kCertProfileId0, kCertProfileVersion0,
/*is_va_enabled=*/true, kCertProfileRenewalPeriod);
const char kCertProfileId1[] = "cert_profile_id_1";
const char kCertProfileVersion1[] = "cert_profile_version_1";
CertProfile cert_profile1(kCertProfileId1, kCertProfileVersion1,
/*is_va_enabled=*/true, kCertProfileRenewalPeriod);
MockCertProvisioningWorker* worker0 =
mock_factory_.ExpectCreateReturnMock(kCertScope, cert_profile0);
worker0->SetExpectations(/*do_step_times=*/AtLeast(1), /*is_waiting=*/false,
cert_profile0);
MockCertProvisioningWorker* worker1 =
mock_factory_.ExpectCreateReturnMock(kCertScope, cert_profile1);
worker1->SetExpectations(/*do_step_times=*/AtLeast(1), /*is_waiting=*/false,
cert_profile1);
// Add 2 certificate profiles to the policy (the values are the same as
// in |cert_profile|-s)
base::Value config = ParseJson(
R"([{
"name": "Certificate Profile 0",
"cert_profile_id":"cert_profile_id_0",
"policy_version":"cert_profile_version_0",
"key_algorithm":"rsa"
},
{
"name": "Certificate Profile 1",
"cert_profile_id":"cert_profile_id_1",
"policy_version":"cert_profile_version_1",
"key_algorithm":"rsa"
}])");
pref_service_.Set(GetPrefNameForCertProfiles(kCertScope), config);
observer.WaitForOneCall();
// Now one worker for each profile should be created.
ASSERT_EQ(scheduler.GetWorkers().size(), 2U);
// Emulate a worker reporting a state change.
// A state change event should be fired by the scheduler for that.
scheduler.OnVisibleStateChanged();
observer.WaitForOneCall();
// Emulate a worker reporting a state changeand successfully finishing.
// Should be just deleted, and state change event should be
// fired for that.
scheduler.OnVisibleStateChanged();
scheduler.OnProfileFinished(cert_profile0,
CertProvisioningWorkerState::kSucceeded);
observer.WaitForOneCall();
// worker1 failed. Should be deleted and the profile id should be saved, and a
// state change event should be fired for that.
scheduler.OnProfileFinished(cert_profile1,
CertProvisioningWorkerState::kFailed);
observer.WaitForOneCall();
EXPECT_EQ(scheduler.GetWorkers().size(), 0U);
EXPECT_TRUE(
base::Contains(scheduler.GetFailedCertProfileIds(), kCertProfileId1));
scheduler.RemoveObserver(&observer);
}
} // namespace
} // namespace cert_provisioning
} // namespace chromeos
......@@ -7,6 +7,7 @@
#include "base/base64.h"
#include "base/bind.h"
#include "base/bind_helpers.h"
#include "base/callback_forward.h"
#include "base/no_destructor.h"
#include "base/optional.h"
#include "base/time/time.h"
......@@ -135,11 +136,13 @@ std::unique_ptr<CertProvisioningWorker> CertProvisioningWorkerFactory::Create(
const CertProfile& cert_profile,
policy::CloudPolicyClient* cloud_policy_client,
std::unique_ptr<CertProvisioningInvalidator> invalidator,
CertProvisioningWorkerCallback callback) {
base::RepeatingClosure state_change_callback,
CertProvisioningWorkerCallback result_callback) {
RecordEvent(cert_scope, CertProvisioningEvent::kWorkerCreated);
return std::make_unique<CertProvisioningWorkerImpl>(
cert_scope, profile, pref_service, cert_profile, cloud_policy_client,
std::move(invalidator), std::move(callback));
std::move(invalidator), std::move(state_change_callback),
std::move(result_callback));
}
std::unique_ptr<CertProvisioningWorker>
......@@ -150,10 +153,12 @@ CertProvisioningWorkerFactory::Deserialize(
const base::Value& saved_worker,
policy::CloudPolicyClient* cloud_policy_client,
std::unique_ptr<CertProvisioningInvalidator> invalidator,
CertProvisioningWorkerCallback callback) {
base::RepeatingClosure state_change_callback,
CertProvisioningWorkerCallback result_callback) {
auto worker = std::make_unique<CertProvisioningWorkerImpl>(
cert_scope, profile, pref_service, CertProfile(), cloud_policy_client,
std::move(invalidator), std::move(callback));
std::move(invalidator), std::move(state_change_callback),
std::move(result_callback));
if (!CertProvisioningSerializer::DeserializeWorker(saved_worker,
worker.get())) {
RecordEvent(cert_scope,
......@@ -179,12 +184,14 @@ CertProvisioningWorkerImpl::CertProvisioningWorkerImpl(
const CertProfile& cert_profile,
policy::CloudPolicyClient* cloud_policy_client,
std::unique_ptr<CertProvisioningInvalidator> invalidator,
CertProvisioningWorkerCallback callback)
base::RepeatingClosure state_change_callback,
CertProvisioningWorkerCallback result_callback)
: cert_scope_(cert_scope),
profile_(profile),
pref_service_(pref_service),
cert_profile_(cert_profile),
callback_(std::move(callback)),
state_change_callback_(std::move(state_change_callback)),
result_callback_(std::move(result_callback)),
request_backoff_(&kBackoffPolicy),
cloud_policy_client_(cloud_policy_client),
invalidator_(std::move(invalidator)) {
......@@ -311,6 +318,7 @@ void CertProvisioningWorkerImpl::UpdateState(
HandleSerialization();
state_change_callback_.Run();
if (IsFinalState(state_)) {
CleanUpAndRunCallback();
}
......@@ -805,7 +813,7 @@ void CertProvisioningWorkerImpl::OnCleanUpDone() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
RecordResult(cert_scope_, state_, prev_state_);
std::move(callback_).Run(cert_profile_, state_);
std::move(result_callback_).Run(cert_profile_, state_);
}
void CertProvisioningWorkerImpl::HandleSerialization() {
......
......@@ -7,6 +7,7 @@
#include <memory>
#include "base/callback_forward.h"
#include "base/memory/weak_ptr.h"
#include "base/sequence_checker.h"
#include "base/time/time.h"
......@@ -28,6 +29,8 @@ namespace cert_provisioning {
class CertProvisioningInvalidator;
// A OnceCallback that is invoked when the CertProvisioningWorker is done and
// has a result (which could be success or failure).
using CertProvisioningWorkerCallback =
base::OnceCallback<void(const CertProfile& profile,
CertProvisioningWorkerState state)>;
......@@ -47,7 +50,8 @@ class CertProvisioningWorkerFactory {
const CertProfile& cert_profile,
policy::CloudPolicyClient* cloud_policy_client,
std::unique_ptr<CertProvisioningInvalidator> invalidator,
CertProvisioningWorkerCallback callback);
base::RepeatingClosure state_change_callback,
CertProvisioningWorkerCallback result_callback);
virtual std::unique_ptr<CertProvisioningWorker> Deserialize(
CertScope cert_scope,
......@@ -56,7 +60,8 @@ class CertProvisioningWorkerFactory {
const base::Value& saved_worker,
policy::CloudPolicyClient* cloud_policy_client,
std::unique_ptr<CertProvisioningInvalidator> invalidator,
CertProvisioningWorkerCallback callback);
base::RepeatingClosure state_change_callback,
CertProvisioningWorkerCallback result_callback);
// Doesn't take ownership.
static void SetFactoryForTesting(CertProvisioningWorkerFactory* test_factory);
......@@ -108,7 +113,8 @@ class CertProvisioningWorkerImpl : public CertProvisioningWorker {
const CertProfile& cert_profile,
policy::CloudPolicyClient* cloud_policy_client,
std::unique_ptr<CertProvisioningInvalidator> invalidator,
CertProvisioningWorkerCallback callback);
base::RepeatingClosure state_change_callback,
CertProvisioningWorkerCallback result_callback);
~CertProvisioningWorkerImpl() override;
// CertProvisioningWorker
......@@ -219,7 +225,8 @@ class CertProvisioningWorkerImpl : public CertProvisioningWorker {
Profile* profile_ = nullptr;
PrefService* pref_service_ = nullptr;
CertProfile cert_profile_;
CertProvisioningWorkerCallback callback_;
base::RepeatingClosure state_change_callback_;
CertProvisioningWorkerCallback result_callback_;
// This field should be updated only via |UpdateState| function. It will
// trigger update of the serialized data.
......
......@@ -25,11 +25,18 @@ class MockCertProvisioningScheduler : public CertProvisioningScheduler {
(override));
MOCK_METHOD(void, UpdateAllCerts, (), (override));
MOCK_METHOD(const WorkerMap&, GetWorkers, (), (const override));
MOCK_METHOD((const base::flat_map<CertProfileId, FailedWorkerInfo>&),
GetFailedCertProfileIds,
(),
(const override));
MOCK_METHOD(void,
AddObserver,
(CertProvisioningSchedulerObserver*),
(override));
MOCK_METHOD(void,
RemoveObserver,
(CertProvisioningSchedulerObserver*),
(override));
};
} // namespace cert_provisioning
......
......@@ -28,7 +28,7 @@ MockCertProvisioningWorkerFactory::ExpectCreateReturnMock(
auto mock_worker = std::make_unique<MockCertProvisioningWorker>();
MockCertProvisioningWorker* pointer = mock_worker.get();
EXPECT_CALL(*this, Create(cert_scope, _, _, cert_profile, _, _, _))
EXPECT_CALL(*this, Create(cert_scope, _, _, cert_profile, _, _, _, _))
.Times(1)
.WillOnce(Return(testing::ByMove(std::move(mock_worker))));
......@@ -43,7 +43,7 @@ MockCertProvisioningWorkerFactory::ExpectDeserializeReturnMock(
MockCertProvisioningWorker* pointer = mock_worker.get();
EXPECT_CALL(*this,
Deserialize(cert_scope, _, _, IsJson(saved_worker), _, _, _))
Deserialize(cert_scope, _, _, IsJson(saved_worker), _, _, _, _))
.Times(1)
.WillOnce(Return(testing::ByMove(std::move(mock_worker))));
......
......@@ -5,6 +5,7 @@
#ifndef CHROME_BROWSER_CHROMEOS_CERT_PROVISIONING_MOCK_CERT_PROVISIONING_WORKER_H_
#define CHROME_BROWSER_CHROMEOS_CERT_PROVISIONING_MOCK_CERT_PROVISIONING_WORKER_H_
#include "base/callback_forward.h"
#include "base/containers/queue.h"
#include "chrome/browser/chromeos/cert_provisioning/cert_provisioning_worker.h"
#include "chrome/browser/chromeos/cert_provisioning/mock_cert_provisioning_invalidator.h"
......@@ -33,6 +34,7 @@ class MockCertProvisioningWorkerFactory : public CertProvisioningWorkerFactory {
const CertProfile& cert_profile,
policy::CloudPolicyClient* cloud_policy_client,
std::unique_ptr<CertProvisioningInvalidator> invalidator,
base::RepeatingClosure state_change_callback,
CertProvisioningWorkerCallback callback),
(override));
......@@ -44,6 +46,7 @@ class MockCertProvisioningWorkerFactory : public CertProvisioningWorkerFactory {
const base::Value& saved_worker,
policy::CloudPolicyClient* cloud_policy_client,
std::unique_ptr<CertProvisioningInvalidator> invalidator,
base::RepeatingClosure state_change_callback,
CertProvisioningWorkerCallback callback),
(override));
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment