Commit 30df8d48 authored by Omar Morsi's avatar Omar Morsi Committed by Commit Bot

Refactor platform keys service token ID

This CL changes how platform keys service expects token IDs from the
service consumers and makes it clear what to expect in special cases
like empty/unprovided token IDs.

Before this CL, platform keys service was accepting token IDs as strings
which is not as clear as enums especially for the cases of empty
strings.

Bug: 1073512
Change-Id: I4cbdfdc8f22b23ce0297314915e52804b2e495f3
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2270069
Commit-Queue: Omar Morsi <omorsi@google.com>
Reviewed-by: default avatarPavol Marko <pmarko@chromium.org>
Reviewed-by: default avatarMaksim Ivanov <emaxx@chromium.org>
Reviewed-by: default avatarMichael Ershov <miersh@google.com>
Cr-Commit-Position: refs/heads/master@{#784448}
parent 6df7df3b
...@@ -149,12 +149,12 @@ std::string GetVaKeyNameForSpkac(CertScope scope, CertProfileId profile_id) { ...@@ -149,12 +149,12 @@ std::string GetVaKeyNameForSpkac(CertScope scope, CertProfileId profile_id) {
} }
} }
const char* GetPlatformKeysTokenId(CertScope scope) { platform_keys::TokenId GetPlatformKeysTokenId(CertScope scope) {
switch (scope) { switch (scope) {
case CertScope::kUser: case CertScope::kUser:
return platform_keys::kTokenIdUser; return platform_keys::TokenId::kUser;
case CertScope::kDevice: case CertScope::kDevice:
return platform_keys::kTokenIdSystem; return platform_keys::TokenId::kSystem;
} }
} }
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include "base/callback_forward.h" #include "base/callback_forward.h"
#include "base/optional.h" #include "base/optional.h"
#include "base/values.h" #include "base/values.h"
#include "chrome/browser/chromeos/platform_keys/platform_keys_service.h"
#include "chromeos/dbus/constants/attestation_constants.h" #include "chromeos/dbus/constants/attestation_constants.h"
#include "components/policy/proto/device_management_backend.pb.h" #include "components/policy/proto/device_management_backend.pb.h"
#include "net/cert/x509_certificate.h" #include "net/cert/x509_certificate.h"
...@@ -95,7 +96,7 @@ const char* GetPrefNameForSerialization(CertScope scope); ...@@ -95,7 +96,7 @@ const char* GetPrefNameForSerialization(CertScope scope);
std::string GetKeyName(CertProfileId profile_id); std::string GetKeyName(CertProfileId profile_id);
// Returns the key type for VA API calls for |scope|. // Returns the key type for VA API calls for |scope|.
attestation::AttestationKeyType GetVaKeyType(CertScope scope); attestation::AttestationKeyType GetVaKeyType(CertScope scope);
const char* GetPlatformKeysTokenId(CertScope scope); platform_keys::TokenId GetPlatformKeysTokenId(CertScope scope);
// The Verified Access APIs are used to generate key pairs. For user-specific // The Verified Access APIs are used to generate key pairs. For user-specific
// key pairs, it is possible to reuse the key pair that is used for Verified // key pairs, it is possible to reuse the key pair that is used for Verified
......
...@@ -73,7 +73,7 @@ struct CertificateTestHelper { ...@@ -73,7 +73,7 @@ struct CertificateTestHelper {
DCHECK(cert); DCHECK(cert);
} }
void GetCertificates(const std::string& token_id, void GetCertificates(chromeos::platform_keys::TokenId token_id,
const platform_keys::GetCertificatesCallback& callback) { const platform_keys::GetCertificatesCallback& callback) {
auto result = std::make_unique<net::CertificateList>(); auto result = std::make_unique<net::CertificateList>();
*result = cert_list; *result = cert_list;
......
...@@ -436,12 +436,12 @@ TEST_F(CertProvisioningWorkerTest, Success) { ...@@ -436,12 +436,12 @@ TEST_F(CertProvisioningWorkerTest, Success) {
EXPECT_REGISTER_KEY_OK(*mock_tpm_challenge_key, StartRegisterKeyStep); EXPECT_REGISTER_KEY_OK(*mock_tpm_challenge_key, StartRegisterKeyStep);
EXPECT_SET_ATTRIBUTE_FOR_KEY_OK(SetAttributeForKey( EXPECT_SET_ATTRIBUTE_FOR_KEY_OK(SetAttributeForKey(
platform_keys::kTokenIdUser, GetPublicKey(), platform_keys::TokenId::kUser, GetPublicKey(),
platform_keys::KeyAttributeType::CertificateProvisioningId, platform_keys::KeyAttributeType::CertificateProvisioningId,
kCertProfileId, _)); kCertProfileId, _));
EXPECT_SIGN_RSAPKC1_DIGEST_OK( EXPECT_SIGN_RSAPKC1_DIGEST_OK(SignRSAPKCS1Digest(
SignRSAPKCS1Digest(platform_keys::kTokenIdUser, kDataToSign, ::testing::Optional(platform_keys::TokenId::kUser), kDataToSign,
GetPublicKey(), kPkHashAlgo, /*callback=*/_)); GetPublicKey(), kPkHashAlgo, /*callback=*/_));
EXPECT_FINISH_CSR_OK(ClientCertProvisioningFinishCsr( EXPECT_FINISH_CSR_OK(ClientCertProvisioningFinishCsr(
...@@ -453,7 +453,7 @@ TEST_F(CertProvisioningWorkerTest, Success) { ...@@ -453,7 +453,7 @@ TEST_F(CertProvisioningWorkerTest, Success) {
/*callback=*/_)); /*callback=*/_));
EXPECT_IMPORT_CERTIFICATE_OK(ImportCertificate( EXPECT_IMPORT_CERTIFICATE_OK(ImportCertificate(
platform_keys::kTokenIdUser, /*certificate=*/_, /*callback=*/_)); platform_keys::TokenId::kUser, /*certificate=*/_, /*callback=*/_));
EXPECT_CALL(*mock_invalidator, Unregister()).Times(1); EXPECT_CALL(*mock_invalidator, Unregister()).Times(1);
...@@ -490,9 +490,9 @@ TEST_F(CertProvisioningWorkerTest, NoVaSuccess) { ...@@ -490,9 +490,9 @@ TEST_F(CertProvisioningWorkerTest, NoVaSuccess) {
{ {
testing::InSequence seq; testing::InSequence seq;
EXPECT_CALL( EXPECT_CALL(*platform_keys_service_,
*platform_keys_service_, GenerateRSAKey(platform_keys::TokenId::kUser,
GenerateRSAKey("user", kNonVaKeyModulusLengthBits, /*callback=*/_)) kNonVaKeyModulusLengthBits, /*callback=*/_))
.Times(1) .Times(1)
.WillOnce(RunOnceCallback<2>(GetPublicKey(), "")); .WillOnce(RunOnceCallback<2>(GetPublicKey(), ""));
...@@ -501,12 +501,12 @@ TEST_F(CertProvisioningWorkerTest, NoVaSuccess) { ...@@ -501,12 +501,12 @@ TEST_F(CertProvisioningWorkerTest, NoVaSuccess) {
/*callback=*/_)); /*callback=*/_));
EXPECT_SET_ATTRIBUTE_FOR_KEY_OK(SetAttributeForKey( EXPECT_SET_ATTRIBUTE_FOR_KEY_OK(SetAttributeForKey(
platform_keys::kTokenIdUser, GetPublicKey(), platform_keys::TokenId::kUser, GetPublicKey(),
platform_keys::KeyAttributeType::CertificateProvisioningId, platform_keys::KeyAttributeType::CertificateProvisioningId,
kCertProfileId, _)); kCertProfileId, _));
EXPECT_SIGN_RSAPKC1_DIGEST_OK( EXPECT_SIGN_RSAPKC1_DIGEST_OK(SignRSAPKCS1Digest(
SignRSAPKCS1Digest(platform_keys::kTokenIdUser, kDataToSign, ::testing::Optional(platform_keys::TokenId::kUser), kDataToSign,
GetPublicKey(), kPkHashAlgo, /*callback=*/_)); GetPublicKey(), kPkHashAlgo, /*callback=*/_));
EXPECT_FINISH_CSR_OK(ClientCertProvisioningFinishCsr( EXPECT_FINISH_CSR_OK(ClientCertProvisioningFinishCsr(
...@@ -518,7 +518,7 @@ TEST_F(CertProvisioningWorkerTest, NoVaSuccess) { ...@@ -518,7 +518,7 @@ TEST_F(CertProvisioningWorkerTest, NoVaSuccess) {
/*callback=*/_)); /*callback=*/_));
EXPECT_IMPORT_CERTIFICATE_OK(ImportCertificate( EXPECT_IMPORT_CERTIFICATE_OK(ImportCertificate(
platform_keys::kTokenIdUser, /*certificate=*/_, /*callback=*/_)); platform_keys::TokenId::kUser, /*certificate=*/_, /*callback=*/_));
EXPECT_CALL(callback_observer_, EXPECT_CALL(callback_observer_,
Callback(cert_profile, CertProvisioningWorkerState::kSucceeded)) Callback(cert_profile, CertProvisioningWorkerState::kSucceeded))
...@@ -576,7 +576,7 @@ TEST_F(CertProvisioningWorkerTest, TryLaterManualRetry) { ...@@ -576,7 +576,7 @@ TEST_F(CertProvisioningWorkerTest, TryLaterManualRetry) {
EXPECT_REGISTER_KEY_OK(*mock_tpm_challenge_key, StartRegisterKeyStep); EXPECT_REGISTER_KEY_OK(*mock_tpm_challenge_key, StartRegisterKeyStep);
EXPECT_SET_ATTRIBUTE_FOR_KEY_OK(SetAttributeForKey( EXPECT_SET_ATTRIBUTE_FOR_KEY_OK(SetAttributeForKey(
platform_keys::kTokenIdSystem, GetPublicKey(), platform_keys::TokenId::kSystem, GetPublicKey(),
platform_keys::KeyAttributeType::CertificateProvisioningId, platform_keys::KeyAttributeType::CertificateProvisioningId,
kCertProfileId, _)); kCertProfileId, _));
...@@ -619,7 +619,7 @@ TEST_F(CertProvisioningWorkerTest, TryLaterManualRetry) { ...@@ -619,7 +619,7 @@ TEST_F(CertProvisioningWorkerTest, TryLaterManualRetry) {
/*callback=*/_)); /*callback=*/_));
EXPECT_IMPORT_CERTIFICATE_OK(ImportCertificate( EXPECT_IMPORT_CERTIFICATE_OK(ImportCertificate(
platform_keys::kTokenIdSystem, /*certificate=*/_, /*callback=*/_)); platform_keys::TokenId::kSystem, /*certificate=*/_, /*callback=*/_));
EXPECT_CALL(callback_observer_, EXPECT_CALL(callback_observer_,
Callback(cert_profile, CertProvisioningWorkerState::kSucceeded)) Callback(cert_profile, CertProvisioningWorkerState::kSucceeded))
...@@ -682,12 +682,12 @@ TEST_F(CertProvisioningWorkerTest, TryLaterWait) { ...@@ -682,12 +682,12 @@ TEST_F(CertProvisioningWorkerTest, TryLaterWait) {
EXPECT_REGISTER_KEY_OK(*mock_tpm_challenge_key, StartRegisterKeyStep); EXPECT_REGISTER_KEY_OK(*mock_tpm_challenge_key, StartRegisterKeyStep);
EXPECT_SET_ATTRIBUTE_FOR_KEY_OK(SetAttributeForKey( EXPECT_SET_ATTRIBUTE_FOR_KEY_OK(SetAttributeForKey(
platform_keys::kTokenIdUser, GetPublicKey(), platform_keys::TokenId::kUser, GetPublicKey(),
platform_keys::KeyAttributeType::CertificateProvisioningId, platform_keys::KeyAttributeType::CertificateProvisioningId,
kCertProfileId, _)); kCertProfileId, _));
EXPECT_SIGN_RSAPKC1_DIGEST_OK( EXPECT_SIGN_RSAPKC1_DIGEST_OK(SignRSAPKCS1Digest(
SignRSAPKCS1Digest(platform_keys::kTokenIdUser, kDataToSign, ::testing::Optional(platform_keys::TokenId::kUser), kDataToSign,
GetPublicKey(), kPkHashAlgo, /*callback=*/_)); GetPublicKey(), kPkHashAlgo, /*callback=*/_));
EXPECT_FINISH_CSR_TRY_LATER( EXPECT_FINISH_CSR_TRY_LATER(
...@@ -724,7 +724,7 @@ TEST_F(CertProvisioningWorkerTest, TryLaterWait) { ...@@ -724,7 +724,7 @@ TEST_F(CertProvisioningWorkerTest, TryLaterWait) {
EXPECT_DOWNLOAD_CERT_OK(ClientCertProvisioningDownloadCert); EXPECT_DOWNLOAD_CERT_OK(ClientCertProvisioningDownloadCert);
EXPECT_IMPORT_CERTIFICATE_OK(ImportCertificate( EXPECT_IMPORT_CERTIFICATE_OK(ImportCertificate(
platform_keys::kTokenIdUser, /*certificate=*/_, /*callback=*/_)); platform_keys::TokenId::kUser, /*certificate=*/_, /*callback=*/_));
FastForwardBy(small_delay); FastForwardBy(small_delay);
// Check that minimum wait time is not too small even if the server // Check that minimum wait time is not too small even if the server
...@@ -957,7 +957,7 @@ TEST_F(CertProvisioningWorkerTest, RemoveRegisteredKey) { ...@@ -957,7 +957,7 @@ TEST_F(CertProvisioningWorkerTest, RemoveRegisteredKey) {
EXPECT_REGISTER_KEY_OK(*mock_tpm_challenge_key, StartRegisterKeyStep); EXPECT_REGISTER_KEY_OK(*mock_tpm_challenge_key, StartRegisterKeyStep);
EXPECT_SET_ATTRIBUTE_FOR_KEY_FAIL(SetAttributeForKey( EXPECT_SET_ATTRIBUTE_FOR_KEY_FAIL(SetAttributeForKey(
platform_keys::kTokenIdUser, GetPublicKey(), platform_keys::TokenId::kUser, GetPublicKey(),
platform_keys::KeyAttributeType::CertificateProvisioningId, platform_keys::KeyAttributeType::CertificateProvisioningId,
kCertProfileId, _)); kCertProfileId, _));
...@@ -965,7 +965,7 @@ TEST_F(CertProvisioningWorkerTest, RemoveRegisteredKey) { ...@@ -965,7 +965,7 @@ TEST_F(CertProvisioningWorkerTest, RemoveRegisteredKey) {
EXPECT_CALL( EXPECT_CALL(
*platform_keys_service_, *platform_keys_service_,
RemoveKey(platform_keys::kTokenIdUser, RemoveKey(platform_keys::TokenId::kUser,
/*public_key_spki_der=*/GetPublicKey(), /*callback=*/_)) /*public_key_spki_der=*/GetPublicKey(), /*callback=*/_))
.Times(1) .Times(1)
.WillOnce(RunOnceCallback<2>(/*error_message=*/"")); .WillOnce(RunOnceCallback<2>(/*error_message=*/""));
...@@ -1104,12 +1104,12 @@ TEST_F(CertProvisioningWorkerTest, SerializationSuccess) { ...@@ -1104,12 +1104,12 @@ TEST_F(CertProvisioningWorkerTest, SerializationSuccess) {
EXPECT_REGISTER_KEY_OK(*mock_tpm_challenge_key, StartRegisterKeyStep); EXPECT_REGISTER_KEY_OK(*mock_tpm_challenge_key, StartRegisterKeyStep);
EXPECT_SET_ATTRIBUTE_FOR_KEY_OK(SetAttributeForKey( EXPECT_SET_ATTRIBUTE_FOR_KEY_OK(SetAttributeForKey(
platform_keys::kTokenIdUser, GetPublicKey(), platform_keys::TokenId::kUser, GetPublicKey(),
platform_keys::KeyAttributeType::CertificateProvisioningId, platform_keys::KeyAttributeType::CertificateProvisioningId,
kCertProfileId, _)); kCertProfileId, _));
EXPECT_SIGN_RSAPKC1_DIGEST_OK( EXPECT_SIGN_RSAPKC1_DIGEST_OK(SignRSAPKCS1Digest(
SignRSAPKCS1Digest(platform_keys::kTokenIdUser, kDataToSign, ::testing::Optional(platform_keys::TokenId::kUser), kDataToSign,
GetPublicKey(), kPkHashAlgo, /*callback=*/_)); GetPublicKey(), kPkHashAlgo, /*callback=*/_));
EXPECT_FINISH_CSR_OK(ClientCertProvisioningFinishCsr( EXPECT_FINISH_CSR_OK(ClientCertProvisioningFinishCsr(
...@@ -1169,7 +1169,7 @@ TEST_F(CertProvisioningWorkerTest, SerializationSuccess) { ...@@ -1169,7 +1169,7 @@ TEST_F(CertProvisioningWorkerTest, SerializationSuccess) {
/*callback=*/_)); /*callback=*/_));
EXPECT_IMPORT_CERTIFICATE_OK(ImportCertificate( EXPECT_IMPORT_CERTIFICATE_OK(ImportCertificate(
platform_keys::kTokenIdUser, /*certificate=*/_, /*callback=*/_)); platform_keys::TokenId::kUser, /*certificate=*/_, /*callback=*/_));
pref_val = ParseJson("{}"); pref_val = ParseJson("{}");
EXPECT_CALL(pref_observer, OnPrefValueUpdated(IsJson(pref_val))).Times(1); EXPECT_CALL(pref_observer, OnPrefValueUpdated(IsJson(pref_val))).Times(1);
......
...@@ -58,18 +58,19 @@ bool IsExtensionAllowlisted(const extensions::Extension* extension) { ...@@ -58,18 +58,19 @@ bool IsExtensionAllowlisted(const extensions::Extension* extension) {
#endif // defined(OS_CHROMEOS) #endif // defined(OS_CHROMEOS)
// Converts |token_ids| (string-based token identifiers used in the // Converts |token_ids| (string-based token identifiers used in the
// platformKeys API) to a vector of KeyPermissions::KeyLocation. Currently only // platformKeys API) to a vector of KeyPermissions::KeyLocation.
// accepts |kTokenIdUser| and |kTokenIdSystem| as |token_ids| elements.
std::vector<KeyPermissions::KeyLocation> TokenIdsToKeyLocations( std::vector<KeyPermissions::KeyLocation> TokenIdsToKeyLocations(
const std::vector<std::string>& token_ids) { const std::vector<platform_keys::TokenId>& token_ids) {
std::vector<KeyPermissions::KeyLocation> key_locations; std::vector<KeyPermissions::KeyLocation> key_locations;
for (const auto& token_id : token_ids) { for (const auto& token_id : token_ids) {
if (token_id == platform_keys::kTokenIdUser) switch (token_id) {
case platform_keys::TokenId::kUser:
key_locations.push_back(KeyPermissions::KeyLocation::kUserSlot); key_locations.push_back(KeyPermissions::KeyLocation::kUserSlot);
else if (token_id == platform_keys::kTokenIdSystem) break;
case platform_keys::TokenId::kSystem:
key_locations.push_back(KeyPermissions::KeyLocation::kSystemSlot); key_locations.push_back(KeyPermissions::KeyLocation::kSystemSlot);
else break;
NOTREACHED() << "Unknown platformKeys API token id " << token_id; }
} }
return key_locations; return key_locations;
} }
...@@ -96,7 +97,7 @@ class ExtensionPlatformKeysService::GenerateKeyTask : public Task { ...@@ -96,7 +97,7 @@ class ExtensionPlatformKeysService::GenerateKeyTask : public Task {
DONE, DONE,
}; };
GenerateKeyTask(const std::string& token_id, GenerateKeyTask(platform_keys::TokenId token_id,
const std::string& extension_id, const std::string& extension_id,
const GenerateKeyCallback& callback, const GenerateKeyCallback& callback,
KeyPermissions* key_permissions, KeyPermissions* key_permissions,
...@@ -119,7 +120,7 @@ class ExtensionPlatformKeysService::GenerateKeyTask : public Task { ...@@ -119,7 +120,7 @@ class ExtensionPlatformKeysService::GenerateKeyTask : public Task {
protected: protected:
virtual void GenerateKey(GenerateKeyCallback callback) = 0; virtual void GenerateKey(GenerateKeyCallback callback) = 0;
const std::string token_id_; platform_keys::TokenId token_id_;
std::string public_key_spki_der_; std::string public_key_spki_der_;
const std::string extension_id_; const std::string extension_id_;
GenerateKeyCallback callback_; GenerateKeyCallback callback_;
...@@ -201,7 +202,7 @@ class ExtensionPlatformKeysService::GenerateRSAKeyTask ...@@ -201,7 +202,7 @@ class ExtensionPlatformKeysService::GenerateRSAKeyTask
// This key task generates an RSA key with the parameters |token_id| and // This key task generates an RSA key with the parameters |token_id| and
// |modulus_length| and registers it for the extension with id |extension_id|. // |modulus_length| and registers it for the extension with id |extension_id|.
// The generated key will be passed to |callback|. // The generated key will be passed to |callback|.
GenerateRSAKeyTask(const std::string& token_id, GenerateRSAKeyTask(platform_keys::TokenId token_id,
unsigned int modulus_length, unsigned int modulus_length,
const std::string& extension_id, const std::string& extension_id,
const GenerateKeyCallback& callback, const GenerateKeyCallback& callback,
...@@ -231,7 +232,7 @@ class ExtensionPlatformKeysService::GenerateECKeyTask : public GenerateKeyTask { ...@@ -231,7 +232,7 @@ class ExtensionPlatformKeysService::GenerateECKeyTask : public GenerateKeyTask {
// This Task generates an EC key with the parameters |token_id| and // This Task generates an EC key with the parameters |token_id| and
// |named_curve| and registers it for the extension with id |extension_id|. // |named_curve| and registers it for the extension with id |extension_id|.
// The generated key will be passed to |callback|. // The generated key will be passed to |callback|.
GenerateECKeyTask(const std::string& token_id, GenerateECKeyTask(platform_keys::TokenId token_id,
const std::string& named_curve, const std::string& named_curve,
const std::string& extension_id, const std::string& extension_id,
const GenerateKeyCallback& callback, const GenerateKeyCallback& callback,
...@@ -272,7 +273,7 @@ class ExtensionPlatformKeysService::SignTask : public Task { ...@@ -272,7 +273,7 @@ class ExtensionPlatformKeysService::SignTask : public Task {
// multiple times, also updates the permission to prevent any future signing // multiple times, also updates the permission to prevent any future signing
// operation of that extension using that same key. If an error occurs, an // operation of that extension using that same key. If an error occurs, an
// error message is passed to |callback| instead. // error message is passed to |callback| instead.
SignTask(const std::string& token_id, SignTask(base::Optional<platform_keys::TokenId> token_id,
const std::string& data, const std::string& data,
const std::string& public_key_spki_der, const std::string& public_key_spki_der,
bool raw_pkcs1, bool raw_pkcs1,
...@@ -358,7 +359,7 @@ class ExtensionPlatformKeysService::SignTask : public Task { ...@@ -358,7 +359,7 @@ class ExtensionPlatformKeysService::SignTask : public Task {
base::BindRepeating(&SignTask::GotKeyLocation, base::Unretained(this))); base::BindRepeating(&SignTask::GotKeyLocation, base::Unretained(this)));
} }
void GotKeyLocation(const std::vector<std::string>& token_ids, void GotKeyLocation(const std::vector<platform_keys::TokenId>& token_ids,
const std::string& error_message) { const std::string& error_message) {
if (!error_message.empty()) { if (!error_message.empty()) {
next_step_ = Step::DONE; next_step_ = Step::DONE;
...@@ -408,7 +409,7 @@ class ExtensionPlatformKeysService::SignTask : public Task { ...@@ -408,7 +409,7 @@ class ExtensionPlatformKeysService::SignTask : public Task {
Step next_step_ = Step::GET_EXTENSION_PERMISSIONS; Step next_step_ = Step::GET_EXTENSION_PERMISSIONS;
const std::string token_id_; base::Optional<platform_keys::TokenId> token_id_;
const std::string data_; const std::string data_;
const std::string public_key_spki_der_; const std::string public_key_spki_der_;
...@@ -602,7 +603,7 @@ class ExtensionPlatformKeysService::SelectTask : public Task { ...@@ -602,7 +603,7 @@ class ExtensionPlatformKeysService::SelectTask : public Task {
} }
void GotKeyLocations(const scoped_refptr<net::X509Certificate>& certificate, void GotKeyLocations(const scoped_refptr<net::X509Certificate>& certificate,
const std::vector<std::string>& token_ids, const std::vector<platform_keys::TokenId>& token_ids,
const std::string& error_message) { const std::string& error_message) {
if (!error_message.empty()) { if (!error_message.empty()) {
next_step_ = Step::DONE; next_step_ = Step::DONE;
...@@ -779,7 +780,7 @@ void ExtensionPlatformKeysService::SetSelectDelegate( ...@@ -779,7 +780,7 @@ void ExtensionPlatformKeysService::SetSelectDelegate(
} }
void ExtensionPlatformKeysService::GenerateRSAKey( void ExtensionPlatformKeysService::GenerateRSAKey(
const std::string& token_id, platform_keys::TokenId token_id,
unsigned int modulus_length, unsigned int modulus_length,
const std::string& extension_id, const std::string& extension_id,
const GenerateKeyCallback& callback) { const GenerateKeyCallback& callback) {
...@@ -790,7 +791,7 @@ void ExtensionPlatformKeysService::GenerateRSAKey( ...@@ -790,7 +791,7 @@ void ExtensionPlatformKeysService::GenerateRSAKey(
} }
void ExtensionPlatformKeysService::GenerateECKey( void ExtensionPlatformKeysService::GenerateECKey(
const std::string& token_id, platform_keys::TokenId token_id,
const std::string& named_curve, const std::string& named_curve,
const std::string& extension_id, const std::string& extension_id,
const GenerateKeyCallback& callback) { const GenerateKeyCallback& callback) {
...@@ -805,7 +806,7 @@ bool ExtensionPlatformKeysService::IsUsingSigninProfile() { ...@@ -805,7 +806,7 @@ bool ExtensionPlatformKeysService::IsUsingSigninProfile() {
} }
void ExtensionPlatformKeysService::SignDigest( void ExtensionPlatformKeysService::SignDigest(
const std::string& token_id, base::Optional<platform_keys::TokenId> token_id,
const std::string& data, const std::string& data,
const std::string& public_key_spki_der, const std::string& public_key_spki_der,
platform_keys::KeyType key_type, platform_keys::KeyType key_type,
...@@ -820,7 +821,7 @@ void ExtensionPlatformKeysService::SignDigest( ...@@ -820,7 +821,7 @@ void ExtensionPlatformKeysService::SignDigest(
} }
void ExtensionPlatformKeysService::SignRSAPKCS1Raw( void ExtensionPlatformKeysService::SignRSAPKCS1Raw(
const std::string& token_id, base::Optional<platform_keys::TokenId> token_id,
const std::string& data, const std::string& data,
const std::string& public_key_spki_der, const std::string& public_key_spki_der,
const std::string& extension_id, const std::string& extension_id,
......
...@@ -102,7 +102,7 @@ class ExtensionPlatformKeysService : public KeyedService { ...@@ -102,7 +102,7 @@ class ExtensionPlatformKeysService : public KeyedService {
// specifies the token to store the key pair on. |callback| will be invoked // specifies the token to store the key pair on. |callback| will be invoked
// with the resulting public key or an error. Will only call back during the // with the resulting public key or an error. Will only call back during the
// lifetime of this object. // lifetime of this object.
void GenerateRSAKey(const std::string& token_id, void GenerateRSAKey(platform_keys::TokenId token_id,
unsigned int modulus_length_bits, unsigned int modulus_length_bits,
const std::string& extension_id, const std::string& extension_id,
const GenerateKeyCallback& callback); const GenerateKeyCallback& callback);
...@@ -112,7 +112,7 @@ class ExtensionPlatformKeysService : public KeyedService { ...@@ -112,7 +112,7 @@ class ExtensionPlatformKeysService : public KeyedService {
// token to store the key pair on. |callback| will be invoked with the // token to store the key pair on. |callback| will be invoked with the
// resulting public key or an error. Will only call back during the lifetime // resulting public key or an error. Will only call back during the lifetime
// of this object. // of this object.
void GenerateECKey(const std::string& token_id, void GenerateECKey(platform_keys::TokenId token_id,
const std::string& named_curve, const std::string& named_curve,
const std::string& extension_id, const std::string& extension_id,
const GenerateKeyCallback& callback); const GenerateKeyCallback& callback);
...@@ -130,14 +130,16 @@ class ExtensionPlatformKeysService : public KeyedService { ...@@ -130,14 +130,16 @@ class ExtensionPlatformKeysService : public KeyedService {
// Digests |data|, applies PKCS1 padding if specified by |hash_algorithm| and // Digests |data|, applies PKCS1 padding if specified by |hash_algorithm| and
// chooses the signature algorithm according to |key_type| and signs the data // chooses the signature algorithm according to |key_type| and signs the data
// with the private key matching |public_key_spki_der|. If a non empty token // with the private key matching |public_key_spki_der|. If a |token_id|
// id is provided and the key is not found in that token, the operation // is provided and the key is not found in that token, the operation aborts.
// aborts. If the extension does not have permissions for signing with this // If |token_id| is not provided (nullopt), all tokens available to the caller
// key, the operation aborts. In case of a one time permission (granted after // will be considered while searching for the key.
// If the extension does not have permissions for signing with this key, the
// operation aborts. In case of a one time permission (granted after
// generating the key), this function also removes the permission to prevent // generating the key), this function also removes the permission to prevent
// future signing attempts. |callback| will be invoked with the signature or // future signing attempts. |callback| will be invoked with the signature or
// an error message. Will only call back during the lifetime of this object. // an error message. Will only call back during the lifetime of this object.
void SignDigest(const std::string& token_id, void SignDigest(base::Optional<platform_keys::TokenId> token_id,
const std::string& data, const std::string& data,
const std::string& public_key_spki_der, const std::string& public_key_spki_der,
platform_keys::KeyType key_type, platform_keys::KeyType key_type,
...@@ -146,18 +148,17 @@ class ExtensionPlatformKeysService : public KeyedService { ...@@ -146,18 +148,17 @@ class ExtensionPlatformKeysService : public KeyedService {
const SignCallback& callback); const SignCallback& callback);
// Applies PKCS1 padding and afterwards signs the data with the private key // Applies PKCS1 padding and afterwards signs the data with the private key
// matching |public_key_spki_der|. |data| is not digested. If a non empty // matching |public_key_spki_der|. |data| is not digested. If a |token_id|
// token id is provided and the key is not found in that token, the operation // is provided and the key is not found in that token, the operation aborts.
// aborts. // If |token_id| is not provided (nullopt), all available tokens to the caller
// The size of |data| (number of octets) must be smaller than k - 11, where k // will be considered while searching for the key. The size of |data| (number
// is the key size in octets. // of octets) must be smaller than k - 11, where k is the key size in octets.
// If the extension does not have permissions for signing with this key, the // If the extension does not have permissions for signing with this key, the
// operation aborts. In case of a one time permission (granted after // operation aborts. In case of a one time permission (granted after
// generating the key), this function also removes the permission to prevent // generating the key), this function also removes the permission to prevent
// future signing attempts. // future signing attempts. |callback| will be invoked with the signature or
// |callback| will be invoked with the signature or an error message. // an error message. Will only call back during the lifetime of this object.
// Will only call back during the lifetime of this object. void SignRSAPKCS1Raw(base::Optional<platform_keys::TokenId> token_id,
void SignRSAPKCS1Raw(const std::string& token_id,
const std::string& data, const std::string& data,
const std::string& public_key_spki_der, const std::string& public_key_spki_der,
const std::string& extension_id, const std::string& extension_id,
......
...@@ -23,21 +23,21 @@ class MockPlatformKeysService : public PlatformKeysService { ...@@ -23,21 +23,21 @@ class MockPlatformKeysService : public PlatformKeysService {
MOCK_METHOD(void, MOCK_METHOD(void,
GenerateRSAKey, GenerateRSAKey,
(const std::string& token_id, (TokenId token_id,
unsigned int modulus_length_bits, unsigned int modulus_length_bits,
const GenerateKeyCallback& callback), const GenerateKeyCallback& callback),
(override)); (override));
MOCK_METHOD(void, MOCK_METHOD(void,
GenerateECKey, GenerateECKey,
(const std::string& token_id, (TokenId token_id,
const std::string& named_curve, const std::string& named_curve,
const GenerateKeyCallback& callback), const GenerateKeyCallback& callback),
(override)); (override));
MOCK_METHOD(void, MOCK_METHOD(void,
SignRSAPKCS1Digest, SignRSAPKCS1Digest,
(const std::string& token_id, (base::Optional<TokenId> token_id,
const std::string& data, const std::string& data,
const std::string& public_key_spki_der, const std::string& public_key_spki_der,
HashAlgorithm hash_algorithm, HashAlgorithm hash_algorithm,
...@@ -46,7 +46,7 @@ class MockPlatformKeysService : public PlatformKeysService { ...@@ -46,7 +46,7 @@ class MockPlatformKeysService : public PlatformKeysService {
MOCK_METHOD(void, MOCK_METHOD(void,
SignRSAPKCS1Raw, SignRSAPKCS1Raw,
(const std::string& token_id, (base::Optional<TokenId> token_id,
const std::string& data, const std::string& data,
const std::string& public_key_spki_der, const std::string& public_key_spki_der,
const SignCallback& callback), const SignCallback& callback),
...@@ -54,7 +54,7 @@ class MockPlatformKeysService : public PlatformKeysService { ...@@ -54,7 +54,7 @@ class MockPlatformKeysService : public PlatformKeysService {
MOCK_METHOD(void, MOCK_METHOD(void,
SignECDSADigest, SignECDSADigest,
(const std::string& token_id, (base::Optional<TokenId> token_id,
const std::string& data, const std::string& data,
const std::string& public_key_spki_der, const std::string& public_key_spki_der,
HashAlgorithm hash_algorithm, HashAlgorithm hash_algorithm,
...@@ -69,32 +69,31 @@ class MockPlatformKeysService : public PlatformKeysService { ...@@ -69,32 +69,31 @@ class MockPlatformKeysService : public PlatformKeysService {
MOCK_METHOD(void, MOCK_METHOD(void,
GetCertificates, GetCertificates,
(const std::string& token_id, (TokenId token_id, const GetCertificatesCallback& callback),
const GetCertificatesCallback& callback),
(override)); (override));
MOCK_METHOD(void, MOCK_METHOD(void,
GetAllKeys, GetAllKeys,
(const std::string& token_id, GetAllKeysCallback callback), (TokenId token_id, GetAllKeysCallback callback),
(override)); (override));
MOCK_METHOD(void, MOCK_METHOD(void,
ImportCertificate, ImportCertificate,
(const std::string& token_id, (TokenId token_id,
const scoped_refptr<net::X509Certificate>& certificate, const scoped_refptr<net::X509Certificate>& certificate,
const ImportCertificateCallback& callback), const ImportCertificateCallback& callback),
(override)); (override));
MOCK_METHOD(void, MOCK_METHOD(void,
RemoveCertificate, RemoveCertificate,
(const std::string& token_id, (TokenId token_id,
const scoped_refptr<net::X509Certificate>& certificate, const scoped_refptr<net::X509Certificate>& certificate,
const RemoveCertificateCallback& callback), const RemoveCertificateCallback& callback),
(override)); (override));
MOCK_METHOD(void, MOCK_METHOD(void,
RemoveKey, RemoveKey,
(const std::string& token_id, (TokenId token_id,
const std::string& public_key_spki_der, const std::string& public_key_spki_der,
RemoveKeyCallback callback), RemoveKeyCallback callback),
(override)); (override));
...@@ -109,7 +108,7 @@ class MockPlatformKeysService : public PlatformKeysService { ...@@ -109,7 +108,7 @@ class MockPlatformKeysService : public PlatformKeysService {
MOCK_METHOD(void, MOCK_METHOD(void,
SetAttributeForKey, SetAttributeForKey,
(const std::string& token_id, (TokenId token_id,
const std::string& public_key_spki_der, const std::string& public_key_spki_der,
KeyAttributeType attribute_type, KeyAttributeType attribute_type,
const std::string& attribute_value, const std::string& attribute_value,
...@@ -118,7 +117,7 @@ class MockPlatformKeysService : public PlatformKeysService { ...@@ -118,7 +117,7 @@ class MockPlatformKeysService : public PlatformKeysService {
MOCK_METHOD(void, MOCK_METHOD(void,
GetAttributeForKey, GetAttributeForKey,
(const std::string& token_id, (TokenId token_id,
const std::string& public_key_spki_der, const std::string& public_key_spki_der,
KeyAttributeType attribute_type, KeyAttributeType attribute_type,
GetAttributeForKeyCallback callback), GetAttributeForKeyCallback callback),
......
...@@ -17,9 +17,6 @@ ...@@ -17,9 +17,6 @@
namespace chromeos { namespace chromeos {
namespace platform_keys { namespace platform_keys {
const char kTokenIdUser[] = "user";
const char kTokenIdSystem[] = "system";
namespace { namespace {
void IntersectOnWorkerThread(const net::CertificateList& certs1, void IntersectOnWorkerThread(const net::CertificateList& certs1,
......
...@@ -24,15 +24,6 @@ class BrowserContext; ...@@ -24,15 +24,6 @@ class BrowserContext;
namespace chromeos { namespace chromeos {
namespace platform_keys { namespace platform_keys {
// A token is a store for keys or certs and can provide cryptographic
// operations.
// ChromeOS provides itself a user token and conditionally a system wide token,
// thus these tokens use static identifiers. The platform keys API is designed
// to support arbitrary other tokens in the future, which could then use
// run-time generated IDs.
extern const char kTokenIdUser[];
extern const char kTokenIdSystem[];
// Supported key types. // Supported key types.
enum class KeyType { kRsassaPkcs1V15, kEcdsa }; enum class KeyType { kRsassaPkcs1V15, kEcdsa };
...@@ -48,6 +39,12 @@ enum HashAlgorithm { ...@@ -48,6 +39,12 @@ enum HashAlgorithm {
HASH_ALGORITHM_SHA512 HASH_ALGORITHM_SHA512
}; };
// Supported token IDs.
// A token is a store for keys or certs and can provide cryptographic
// operations.
// ChromeOS provides itself a user token and conditionally a system wide token.
enum class TokenId { kUser, kSystem };
// Returns the DER encoding of the X.509 Subject Public Key Info of the public // Returns the DER encoding of the X.509 Subject Public Key Info of the public
// key in |certificate|. // key in |certificate|.
std::string GetSubjectPublicKeyInfo( std::string GetSubjectPublicKeyInfo(
...@@ -148,14 +145,13 @@ using RemoveKeyCallback = ...@@ -148,14 +145,13 @@ using RemoveKeyCallback =
// will contain the token ids. If an error occurs, |token_ids| will be nullptr // will contain the token ids. If an error occurs, |token_ids| will be nullptr
// and |error_message| will be set to an error message. // and |error_message| will be set to an error message.
using GetTokensCallback = using GetTokensCallback =
base::Callback<void(std::unique_ptr<std::vector<std::string>> token_ids, base::Callback<void(std::unique_ptr<std::vector<TokenId>> token_ids,
const std::string& error_message)>; const std::string& error_message)>;
// If token ids have been successfully retrieved, |error_message| will be empty. // If token ids have been successfully retrieved, |error_message| will be empty.
// Two cases are possible then: // Two cases are possible then:
// If |token_ids| is not empty, |token_ids| has been filled with the identifiers // If |token_ids| is not empty, |token_ids| has been filled with the identifiers
// of the tokens the private key was found on and the user has access to. // of the tokens the private key was found on and the user has access to.
// Currently, valid token identifiers are |kTokenIdUser| and |kTokenIdSystem|.
// If |token_ids| is empty, the private key has not been found on any token the // If |token_ids| is empty, the private key has not been found on any token the
// user has access to. Note that this is also the case if the key exists on the // user has access to. Note that this is also the case if the key exists on the
// system token, but the current user does not have access to the system token. // system token, but the current user does not have access to the system token.
...@@ -164,7 +160,7 @@ using GetTokensCallback = ...@@ -164,7 +160,7 @@ using GetTokensCallback =
// TODO(pmarko): This is currently a RepeatingCallback because of // TODO(pmarko): This is currently a RepeatingCallback because of
// GetNSSCertDatabaseForResourceContext semantics. // GetNSSCertDatabaseForResourceContext semantics.
using GetKeyLocationsCallback = using GetKeyLocationsCallback =
base::RepeatingCallback<void(const std::vector<std::string>& token_ids, base::RepeatingCallback<void(const std::vector<TokenId>& token_ids,
const std::string& error_message)>; const std::string& error_message)>;
// If the attribute value has been successfully set, |error_message| will be // If the attribute value has been successfully set, |error_message| will be
...@@ -193,47 +189,48 @@ class PlatformKeysService : public KeyedService { ...@@ -193,47 +189,48 @@ class PlatformKeysService : public KeyedService {
~PlatformKeysService() override = default; ~PlatformKeysService() override = default;
// Generates a RSA key pair with |modulus_length_bits|. |token_id| specifies // Generates a RSA key pair with |modulus_length_bits|. |token_id| specifies
// the token to store the key pair on and can currently be |kTokenIdUser| or // the token to store the key pair on. |callback| will be invoked with the
// |kTokenIdSystem|. |callback| will be invoked with the resulting public key // resulting public key
// or an error. // or an error.
virtual void GenerateRSAKey(const std::string& token_id, virtual void GenerateRSAKey(TokenId token_id,
unsigned int modulus_length_bits, unsigned int modulus_length_bits,
const GenerateKeyCallback& callback) = 0; const GenerateKeyCallback& callback) = 0;
// Generates a EC key pair with |named_curve|. |token_id| specifies the token // Generates a EC key pair with |named_curve|. |token_id| specifies the token
// to store the key pair on and can currently be |kTokenIdUser| or // to store the key pair on. |callback| will be invoked with the resulting
// |kTokenIdSystem|. |callback| will be invoked with the resulting public key // public key or an error.
// or an error. virtual void GenerateECKey(TokenId token_id,
virtual void GenerateECKey(const std::string& token_id,
const std::string& named_curve, const std::string& named_curve,
const GenerateKeyCallback& callback) = 0; const GenerateKeyCallback& callback) = 0;
// Digests |data|, applies PKCS1 padding and afterwards signs the data with // Digests |data|, applies PKCS1 padding and afterwards signs the data with
// the private key matching |public_key_spki_der|. If a non empty token id is // the private key matching |public_key_spki_der|. If the key is not found in
// provided and the key is not found in that token, the operation aborts. // that |token_id| (or in none of the available tokens if |token_id| is not
// |callback| will be invoked with the signature or an error message. // specified), the operation aborts. |callback| will be invoked with the
virtual void SignRSAPKCS1Digest(const std::string& token_id, // signature or an error message.
virtual void SignRSAPKCS1Digest(base::Optional<TokenId> token_id,
const std::string& data, const std::string& data,
const std::string& public_key_spki_der, const std::string& public_key_spki_der,
HashAlgorithm hash_algorithm, HashAlgorithm hash_algorithm,
const SignCallback& callback) = 0; const SignCallback& callback) = 0;
// Applies PKCS1 padding and afterwards signs the data with the private key // Applies PKCS1 padding and afterwards signs the data with the private key
// matching |public_key_spki_der|. |data| is not digested. If a non empty // matching |public_key_spki_der|. |data| is not digested. If the key is not
// token id is provided and the key is not found in that token, the operation // found in that |token_id| (or in none of the available tokens if |token_id|
// aborts. The size of |data| (number of octets) must be smaller than k - 11, // is not specified), the operation aborts. The size of |data| (number of
// where k is the key size in octets. |callback| will be invoked with the // octets) must be smaller than k - 11, where k is the key size in octets.
// signature or an error message. // |callback| will be invoked with the signature or an error message.
virtual void SignRSAPKCS1Raw(const std::string& token_id, virtual void SignRSAPKCS1Raw(base::Optional<TokenId> token_id,
const std::string& data, const std::string& data,
const std::string& public_key_spki_der, const std::string& public_key_spki_der,
const SignCallback& callback) = 0; const SignCallback& callback) = 0;
// Digests |data| and afterwards signs the data with the private key matching // Digests |data| and afterwards signs the data with the private key matching
// |public_key_spki_der|. If a non empty token id is provided and the key is // |public_key_spki_der|. If the key is not found in that |token_id| (or in
// not found in that token, the operation aborts. |callback| will be invoked // none of the available tokens if |token_id| is not specified), the operation
// with the ECDSA signature or an error message. // aborts. |callback| will be invoked with the ECDSA signature or an error
virtual void SignECDSADigest(const std::string& token_id, // message.
virtual void SignECDSADigest(base::Optional<TokenId> token_id,
const std::string& data, const std::string& data,
const std::string& public_key_spki_der, const std::string& public_key_spki_der,
HashAlgorithm hash_algorithm, HashAlgorithm hash_algorithm,
...@@ -248,45 +245,41 @@ class PlatformKeysService : public KeyedService { ...@@ -248,45 +245,41 @@ class PlatformKeysService : public KeyedService {
const SelectCertificatesCallback& callback) = 0; const SelectCertificatesCallback& callback) = 0;
// Returns the list of all certificates with stored private key available from // Returns the list of all certificates with stored private key available from
// the given token. If an empty |token_id| is provided, all certificates the // the given token. Only certificates from the specified |token_id| are
// user associated with |browser_context| has access to are listed. Otherwise, // listed. |callback| will be invoked with the list of available certificates
// only certificates from the specified token are listed. |callback| will be // or an error message.
// invoked with the list of available certificates or an error message. virtual void GetCertificates(TokenId token_id,
virtual void GetCertificates(const std::string& token_id,
const GetCertificatesCallback& callback) = 0; const GetCertificatesCallback& callback) = 0;
// Returns the list of all keys available from the given |token_id| as a list // Returns the list of all keys available from the given |token_id| as a list
// of der-encoded SubjectPublicKeyInfo strings. |callback| will be invoked on // of der-encoded SubjectPublicKeyInfo strings. |callback| will be invoked on
// the UI thread with the list of available public keys, possibly with an // the UI thread with the list of available public keys, possibly with an
// error message. // error message.
virtual void GetAllKeys(const std::string& token_id, virtual void GetAllKeys(TokenId token_id, GetAllKeysCallback callback) = 0;
GetAllKeysCallback callback) = 0;
// Imports |certificate| to the given token if the certified key is already // Imports |certificate| to the given token if the certified key is already
// stored in this token. Any intermediate of |certificate| will be ignored. // stored in this token. Any intermediate of |certificate| will be ignored.
// |token_id| specifies the token to store the certificate on and can // |token_id| specifies the token to store the certificate on. The private key
// currently be |kTokenIdUser| or |kTokenIdSystem|. The private key must be // must be stored on the same token. |callback| will be invoked when the
// stored on the same token. |callback| will be invoked when the import is // import is finished, possibly with an error message.
// finished, possibly with an error message.
virtual void ImportCertificate( virtual void ImportCertificate(
const std::string& token_id, TokenId token_id,
const scoped_refptr<net::X509Certificate>& certificate, const scoped_refptr<net::X509Certificate>& certificate,
const ImportCertificateCallback& callback) = 0; const ImportCertificateCallback& callback) = 0;
// Removes |certificate| from the given token if present. Any intermediate of // Removes |certificate| from the given token. Any intermediate of
// |certificate| will be ignored. |token_id| specifies the token to remove the // |certificate| will be ignored. |token_id| specifies the token to remove the
// certificate from and can currently be empty (any token), |kTokenIdUser| or // certificate from. |callback| will be invoked when the removal is finished,
// |kTokenIdSystem|. |callback| will be invoked when the removal is finished,
// possibly with an error message. // possibly with an error message.
virtual void RemoveCertificate( virtual void RemoveCertificate(
const std::string& token_id, TokenId token_id,
const scoped_refptr<net::X509Certificate>& certificate, const scoped_refptr<net::X509Certificate>& certificate,
const RemoveCertificateCallback& callback) = 0; const RemoveCertificateCallback& callback) = 0;
// Removes the key pair if no matching certificates exist. Only keys in the // Removes the key pair if no matching certificates exist. Only keys in the
// given |token_id| are considered. |callback| will be invoked on the UI // given |token_id| are considered. |callback| will be invoked on the UI
// thread when the removal is finished, possibly with an error message. // thread when the removal is finished, possibly with an error message.
virtual void RemoveKey(const std::string& token_id, virtual void RemoveKey(TokenId token_id,
const std::string& public_key_spki_der, const std::string& public_key_spki_der,
RemoveKeyCallback callback) = 0; RemoveKeyCallback callback) = 0;
...@@ -306,7 +299,7 @@ class PlatformKeysService : public KeyedService { ...@@ -306,7 +299,7 @@ class PlatformKeysService : public KeyedService {
// |public_key_spki_der| to |attribute_value| only if the key is in // |public_key_spki_der| to |attribute_value| only if the key is in
// |token_id|. |callback| will be invoked on the UI thread when setting the // |token_id|. |callback| will be invoked on the UI thread when setting the
// attribute is done, possibly with an error message. // attribute is done, possibly with an error message.
virtual void SetAttributeForKey(const std::string& token_id, virtual void SetAttributeForKey(TokenId token_id,
const std::string& public_key_spki_der, const std::string& public_key_spki_der,
KeyAttributeType attribute_type, KeyAttributeType attribute_type,
const std::string& attribute_value, const std::string& attribute_value,
...@@ -316,7 +309,7 @@ class PlatformKeysService : public KeyedService { ...@@ -316,7 +309,7 @@ class PlatformKeysService : public KeyedService {
// |public_key_spki_der| only if the key is in |token_id|. // |public_key_spki_der| only if the key is in |token_id|.
// |callback| will be invoked on the UI thread when getting the attribute // |callback| will be invoked on the UI thread when getting the attribute
// is done, possibly with an error message. // is done, possibly with an error message.
virtual void GetAttributeForKey(const std::string& token_id, virtual void GetAttributeForKey(TokenId token_id,
const std::string& public_key_spki_der, const std::string& public_key_spki_der,
KeyAttributeType attribute_type, KeyAttributeType attribute_type,
GetAttributeForKeyCallback callback) = 0; GetAttributeForKeyCallback callback) = 0;
...@@ -339,22 +332,22 @@ class PlatformKeysServiceImpl final : public PlatformKeysService { ...@@ -339,22 +332,22 @@ class PlatformKeysServiceImpl final : public PlatformKeysService {
~PlatformKeysServiceImpl() override; ~PlatformKeysServiceImpl() override;
// PlatformKeysService // PlatformKeysService
void GenerateRSAKey(const std::string& token_id, void GenerateRSAKey(TokenId token_id,
unsigned int modulus_length_bits, unsigned int modulus_length_bits,
const GenerateKeyCallback& callback) override; const GenerateKeyCallback& callback) override;
void GenerateECKey(const std::string& token_id, void GenerateECKey(TokenId token_id,
const std::string& named_curve, const std::string& named_curve,
const GenerateKeyCallback& callback) override; const GenerateKeyCallback& callback) override;
void SignRSAPKCS1Digest(const std::string& token_id, void SignRSAPKCS1Digest(base::Optional<TokenId> token_id,
const std::string& data, const std::string& data,
const std::string& public_key_spki_der, const std::string& public_key_spki_der,
HashAlgorithm hash_algorithm, HashAlgorithm hash_algorithm,
const SignCallback& callback) override; const SignCallback& callback) override;
void SignRSAPKCS1Raw(const std::string& token_id, void SignRSAPKCS1Raw(base::Optional<TokenId> token_id,
const std::string& data, const std::string& data,
const std::string& public_key_spki_der, const std::string& public_key_spki_der,
const SignCallback& callback) override; const SignCallback& callback) override;
void SignECDSADigest(const std::string& token_id, void SignECDSADigest(base::Optional<TokenId> token_id,
const std::string& data, const std::string& data,
const std::string& public_key_spki_der, const std::string& public_key_spki_der,
HashAlgorithm hash_algorithm, HashAlgorithm hash_algorithm,
...@@ -362,28 +355,27 @@ class PlatformKeysServiceImpl final : public PlatformKeysService { ...@@ -362,28 +355,27 @@ class PlatformKeysServiceImpl final : public PlatformKeysService {
void SelectClientCertificates( void SelectClientCertificates(
const std::vector<std::string>& certificate_authorities, const std::vector<std::string>& certificate_authorities,
const SelectCertificatesCallback& callback) override; const SelectCertificatesCallback& callback) override;
void GetCertificates(const std::string& token_id, void GetCertificates(TokenId token_id,
const GetCertificatesCallback& callback) override; const GetCertificatesCallback& callback) override;
void GetAllKeys(const std::string& token_id, void GetAllKeys(TokenId token_id, GetAllKeysCallback callback) override;
GetAllKeysCallback callback) override; void ImportCertificate(TokenId token_id,
void ImportCertificate(const std::string& token_id,
const scoped_refptr<net::X509Certificate>& certificate, const scoped_refptr<net::X509Certificate>& certificate,
const ImportCertificateCallback& callback) override; const ImportCertificateCallback& callback) override;
void RemoveCertificate(const std::string& token_id, void RemoveCertificate(TokenId token_id,
const scoped_refptr<net::X509Certificate>& certificate, const scoped_refptr<net::X509Certificate>& certificate,
const RemoveCertificateCallback& callback) override; const RemoveCertificateCallback& callback) override;
void RemoveKey(const std::string& token_id, void RemoveKey(TokenId token_id,
const std::string& public_key_spki_der, const std::string& public_key_spki_der,
RemoveKeyCallback callback) override; RemoveKeyCallback callback) override;
void GetTokens(const GetTokensCallback& callback) override; void GetTokens(const GetTokensCallback& callback) override;
void GetKeyLocations(const std::string& public_key_spki_der, void GetKeyLocations(const std::string& public_key_spki_der,
const GetKeyLocationsCallback& callback) override; const GetKeyLocationsCallback& callback) override;
void SetAttributeForKey(const std::string& token_id, void SetAttributeForKey(TokenId token_id,
const std::string& public_key_spki_der, const std::string& public_key_spki_der,
KeyAttributeType attribute_type, KeyAttributeType attribute_type,
const std::string& attribute_value, const std::string& attribute_value,
SetAttributeForKeyCallback callback) override; SetAttributeForKeyCallback callback) override;
void GetAttributeForKey(const std::string& token_id, void GetAttributeForKey(TokenId token_id,
const std::string& public_key_spki_der, const std::string& public_key_spki_der,
KeyAttributeType attribute_type, KeyAttributeType attribute_type,
GetAttributeForKeyCallback callback) override; GetAttributeForKeyCallback callback) override;
......
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
#include "base/location.h" #include "base/location.h"
#include "base/memory/weak_ptr.h" #include "base/memory/weak_ptr.h"
#include "base/run_loop.h" #include "base/run_loop.h"
#include "base/strings/string_number_conversions.h"
#include "base/strings/stringprintf.h"
#include "base/task/post_task.h" #include "base/task/post_task.h"
#include "base/threading/thread_restrictions.h" #include "base/threading/thread_restrictions.h"
#include "chrome/browser/chromeos/login/test/device_state_mixin.h" #include "chrome/browser/chromeos/login/test/device_state_mixin.h"
...@@ -58,8 +60,6 @@ namespace { ...@@ -58,8 +60,6 @@ namespace {
constexpr char kTestUserEmail[] = "test@example.com"; constexpr char kTestUserEmail[] = "test@example.com";
constexpr char kTestAffiliationId[] = "test_affiliation_id"; constexpr char kTestAffiliationId[] = "test_affiliation_id";
constexpr char kSystemToken[] = "system";
constexpr char kUserToken[] = "user";
enum class ProfileToUse { enum class ProfileToUse {
// A Profile that belongs to a user that is not affiliated with the device (no // A Profile that belongs to a user that is not affiliated with the device (no
...@@ -80,7 +80,7 @@ struct TestConfig { ...@@ -80,7 +80,7 @@ struct TestConfig {
// The token IDs that are expected to be available. This will be checked by // The token IDs that are expected to be available. This will be checked by
// the GetTokens test, and operation for these tokens will be performed by the // the GetTokens test, and operation for these tokens will be performed by the
// other tests. // other tests.
std::vector<std::string> token_ids; std::vector<TokenId> token_ids;
}; };
// Softoken NSS PKCS11 module (used for testing) allows only predefined key // Softoken NSS PKCS11 module (used for testing) allows only predefined key
...@@ -177,12 +177,12 @@ class ExecutionWaiter { ...@@ -177,12 +177,12 @@ class ExecutionWaiter {
// Supports waiting for the result of PlatformKeysService::GetTokens. // Supports waiting for the result of PlatformKeysService::GetTokens.
class GetTokensExecutionWaiter class GetTokensExecutionWaiter
: public ExecutionWaiter<std::unique_ptr<std::vector<std::string>>> { : public ExecutionWaiter<std::unique_ptr<std::vector<TokenId>>> {
public: public:
GetTokensExecutionWaiter() = default; GetTokensExecutionWaiter() = default;
~GetTokensExecutionWaiter() = default; ~GetTokensExecutionWaiter() = default;
const std::unique_ptr<std::vector<std::string>>& token_ids() const { const std::unique_ptr<std::vector<TokenId>>& token_ids() const {
return std::get<0>(result_callback_args()); return std::get<0>(result_callback_args());
} }
}; };
...@@ -330,17 +330,18 @@ class PlatformKeysServiceBrowserTest ...@@ -330,17 +330,18 @@ class PlatformKeysServiceBrowserTest
} }
// Returns the slot to be used depending on |token_id|. // Returns the slot to be used depending on |token_id|.
PK11SlotInfo* GetSlot(const std::string& token_id) { PK11SlotInfo* GetSlot(TokenId token_id) {
if (token_id == kSystemToken) { switch (token_id) {
case TokenId::kSystem:
return system_nss_key_slot_mixin_.slot(); return system_nss_key_slot_mixin_.slot();
} case TokenId::kUser:
DCHECK_EQ(token_id, kUserToken);
return user_slot_.get(); return user_slot_.get();
} }
}
// Generates a key pair in the given |token_id| using platform keys service // Generates a key pair in the given |token_id| using platform keys service
// and returns the SubjectPublicKeyInfo string encoded in DER format. // and returns the SubjectPublicKeyInfo string encoded in DER format.
std::string GenerateKeyPair(const std::string& token_id) { std::string GenerateKeyPair(TokenId token_id) {
const unsigned int kKeySize = 2048; const unsigned int kKeySize = 2048;
GenerateKeyExecutionWaiter generate_key_waiter; GenerateKeyExecutionWaiter generate_key_waiter;
...@@ -404,7 +405,7 @@ IN_PROC_BROWSER_TEST_P(PlatformKeysServiceBrowserTest, GenerateRsaAndSign) { ...@@ -404,7 +405,7 @@ IN_PROC_BROWSER_TEST_P(PlatformKeysServiceBrowserTest, GenerateRsaAndSign) {
const crypto::SignatureVerifier::SignatureAlgorithm signature_algorithm = const crypto::SignatureVerifier::SignatureAlgorithm signature_algorithm =
crypto::SignatureVerifier::RSA_PKCS1_SHA256; crypto::SignatureVerifier::RSA_PKCS1_SHA256;
for (const std::string& token_id : GetParam().token_ids) { for (TokenId token_id : GetParam().token_ids) {
GenerateKeyExecutionWaiter generate_key_waiter; GenerateKeyExecutionWaiter generate_key_waiter;
platform_keys_service()->GenerateRSAKey(token_id, kKeySize, platform_keys_service()->GenerateRSAKey(token_id, kKeySize,
generate_key_waiter.GetCallback()); generate_key_waiter.GetCallback());
...@@ -438,8 +439,10 @@ IN_PROC_BROWSER_TEST_P(PlatformKeysServiceBrowserTest, SetAndGetKeyAttribute) { ...@@ -438,8 +439,10 @@ IN_PROC_BROWSER_TEST_P(PlatformKeysServiceBrowserTest, SetAndGetKeyAttribute) {
const KeyAttributeType kAttributeType = const KeyAttributeType kAttributeType =
KeyAttributeType::CertificateProvisioningId; KeyAttributeType::CertificateProvisioningId;
for (const std::string& token_id : GetParam().token_ids) { for (TokenId token_id : GetParam().token_ids) {
const std::string kAttributeValue = "test" + token_id; const int token_id_as_int = static_cast<int>(token_id);
const std::string attribute_value =
base::StringPrintf("test%d", token_id_as_int);
// Generate key pair. // Generate key pair.
const std::string public_key_spki_der = GenerateKeyPair(token_id); const std::string public_key_spki_der = GenerateKeyPair(token_id);
...@@ -448,7 +451,7 @@ IN_PROC_BROWSER_TEST_P(PlatformKeysServiceBrowserTest, SetAndGetKeyAttribute) { ...@@ -448,7 +451,7 @@ IN_PROC_BROWSER_TEST_P(PlatformKeysServiceBrowserTest, SetAndGetKeyAttribute) {
// Set key attribute. // Set key attribute.
SetAttributeForKeyExecutionWaiter set_attribute_for_key_execution_waiter; SetAttributeForKeyExecutionWaiter set_attribute_for_key_execution_waiter;
platform_keys_service()->SetAttributeForKey( platform_keys_service()->SetAttributeForKey(
token_id, public_key_spki_der, kAttributeType, kAttributeValue, token_id, public_key_spki_der, kAttributeType, attribute_value,
set_attribute_for_key_execution_waiter.GetCallback()); set_attribute_for_key_execution_waiter.GetCallback());
set_attribute_for_key_execution_waiter.Wait(); set_attribute_for_key_execution_waiter.Wait();
...@@ -461,7 +464,7 @@ IN_PROC_BROWSER_TEST_P(PlatformKeysServiceBrowserTest, SetAndGetKeyAttribute) { ...@@ -461,7 +464,7 @@ IN_PROC_BROWSER_TEST_P(PlatformKeysServiceBrowserTest, SetAndGetKeyAttribute) {
EXPECT_TRUE(get_attribute_for_key_execution_waiter.error_message().empty()); EXPECT_TRUE(get_attribute_for_key_execution_waiter.error_message().empty());
EXPECT_EQ(get_attribute_for_key_execution_waiter.attribute_value(), EXPECT_EQ(get_attribute_for_key_execution_waiter.attribute_value(),
kAttributeValue); attribute_value);
} }
} }
...@@ -469,7 +472,7 @@ IN_PROC_BROWSER_TEST_P(PlatformKeysServiceBrowserTest, GetUnsetKeyAttribute) { ...@@ -469,7 +472,7 @@ IN_PROC_BROWSER_TEST_P(PlatformKeysServiceBrowserTest, GetUnsetKeyAttribute) {
const KeyAttributeType kAttributeType = const KeyAttributeType kAttributeType =
KeyAttributeType::CertificateProvisioningId; KeyAttributeType::CertificateProvisioningId;
for (const std::string& token_id : GetParam().token_ids) { for (TokenId token_id : GetParam().token_ids) {
// Generate key pair. // Generate key pair.
const std::string public_key_spki_der = GenerateKeyPair(token_id); const std::string public_key_spki_der = GenerateKeyPair(token_id);
ASSERT_FALSE(public_key_spki_der.empty()); ASSERT_FALSE(public_key_spki_der.empty());
...@@ -493,7 +496,7 @@ IN_PROC_BROWSER_TEST_P(PlatformKeysServiceBrowserTest, ...@@ -493,7 +496,7 @@ IN_PROC_BROWSER_TEST_P(PlatformKeysServiceBrowserTest,
KeyAttributeType::CertificateProvisioningId; KeyAttributeType::CertificateProvisioningId;
const std::string kPublicKey = "Non Existing public key"; const std::string kPublicKey = "Non Existing public key";
for (const std::string& token_id : GetParam().token_ids) { for (TokenId token_id : GetParam().token_ids) {
// Get key attribute. // Get key attribute.
GetAttributeForKeyExecutionWaiter get_attribute_for_key_execution_waiter; GetAttributeForKeyExecutionWaiter get_attribute_for_key_execution_waiter;
platform_keys_service()->GetAttributeForKey( platform_keys_service()->GetAttributeForKey(
...@@ -513,7 +516,7 @@ IN_PROC_BROWSER_TEST_P(PlatformKeysServiceBrowserTest, ...@@ -513,7 +516,7 @@ IN_PROC_BROWSER_TEST_P(PlatformKeysServiceBrowserTest,
const std::string kAttributeValue = "test"; const std::string kAttributeValue = "test";
const std::string kPublicKey = "Non Existing public key"; const std::string kPublicKey = "Non Existing public key";
for (const std::string& token_id : GetParam().token_ids) { for (TokenId token_id : GetParam().token_ids) {
// Set key attribute. // Set key attribute.
SetAttributeForKeyExecutionWaiter set_attribute_for_key_execution_waiter; SetAttributeForKeyExecutionWaiter set_attribute_for_key_execution_waiter;
platform_keys_service()->SetAttributeForKey( platform_keys_service()->SetAttributeForKey(
...@@ -528,7 +531,7 @@ IN_PROC_BROWSER_TEST_P(PlatformKeysServiceBrowserTest, ...@@ -528,7 +531,7 @@ IN_PROC_BROWSER_TEST_P(PlatformKeysServiceBrowserTest,
IN_PROC_BROWSER_TEST_P(PlatformKeysServiceBrowserTest, IN_PROC_BROWSER_TEST_P(PlatformKeysServiceBrowserTest,
RemoveKeyWithNoMatchingCertificates) { RemoveKeyWithNoMatchingCertificates) {
for (const std::string& token_id : GetParam().token_ids) { for (TokenId token_id : GetParam().token_ids) {
// Generate first key pair. // Generate first key pair.
const std::string public_key_1 = GenerateKeyPair(token_id); const std::string public_key_1 = GenerateKeyPair(token_id);
ASSERT_FALSE(public_key_1.empty()); ASSERT_FALSE(public_key_1.empty());
...@@ -555,7 +558,7 @@ IN_PROC_BROWSER_TEST_P(PlatformKeysServiceBrowserTest, ...@@ -555,7 +558,7 @@ IN_PROC_BROWSER_TEST_P(PlatformKeysServiceBrowserTest,
IN_PROC_BROWSER_TEST_P(PlatformKeysServiceBrowserTest, IN_PROC_BROWSER_TEST_P(PlatformKeysServiceBrowserTest,
RemoveKeyWithMatchingCertificate) { RemoveKeyWithMatchingCertificate) {
for (const std::string& token_id : GetParam().token_ids) { for (TokenId token_id : GetParam().token_ids) {
PK11SlotInfo* const slot = GetSlot(token_id); PK11SlotInfo* const slot = GetSlot(token_id);
// Assert that there are no certificates before importing. // Assert that there are no certificates before importing.
...@@ -616,15 +619,15 @@ IN_PROC_BROWSER_TEST_P(PlatformKeysServiceBrowserTest, ...@@ -616,15 +619,15 @@ IN_PROC_BROWSER_TEST_P(PlatformKeysServiceBrowserTest,
// retrieves them. // retrieves them.
IN_PROC_BROWSER_TEST_P(PlatformKeysServiceBrowserTest, GetAllKeys) { IN_PROC_BROWSER_TEST_P(PlatformKeysServiceBrowserTest, GetAllKeys) {
// Generate key pair in every token. // Generate key pair in every token.
std::map<std::string, std::string> token_key_map; std::map<TokenId, std::string> token_key_map;
for (const std::string& token_id : GetParam().token_ids) { for (TokenId token_id : GetParam().token_ids) {
const std::string public_key_spki_der = GenerateKeyPair(token_id); const std::string public_key_spki_der = GenerateKeyPair(token_id);
ASSERT_FALSE(public_key_spki_der.empty()); ASSERT_FALSE(public_key_spki_der.empty());
token_key_map[token_id] = public_key_spki_der; token_key_map[token_id] = public_key_spki_der;
} }
// Only keys in the requested token should be retrieved. // Only keys in the requested token should be retrieved.
for (const std::string& token_id : GetParam().token_ids) { for (TokenId token_id : GetParam().token_ids) {
GetAllKeysExecutionWaiter get_all_keys_waiter; GetAllKeysExecutionWaiter get_all_keys_waiter;
platform_keys_service()->GetAllKeys(token_id, platform_keys_service()->GetAllKeys(token_id,
get_all_keys_waiter.GetCallback()); get_all_keys_waiter.GetCallback());
...@@ -639,7 +642,7 @@ IN_PROC_BROWSER_TEST_P(PlatformKeysServiceBrowserTest, GetAllKeys) { ...@@ -639,7 +642,7 @@ IN_PROC_BROWSER_TEST_P(PlatformKeysServiceBrowserTest, GetAllKeys) {
IN_PROC_BROWSER_TEST_P(PlatformKeysServiceBrowserTest, IN_PROC_BROWSER_TEST_P(PlatformKeysServiceBrowserTest,
GetAllKeysWhenNoKeysGenerated) { GetAllKeysWhenNoKeysGenerated) {
for (const std::string& token_id : GetParam().token_ids) { for (TokenId token_id : GetParam().token_ids) {
GetAllKeysExecutionWaiter get_all_keys_waiter; GetAllKeysExecutionWaiter get_all_keys_waiter;
platform_keys_service()->GetAllKeys(token_id, platform_keys_service()->GetAllKeys(token_id,
get_all_keys_waiter.GetCallback()); get_all_keys_waiter.GetCallback());
...@@ -654,11 +657,11 @@ IN_PROC_BROWSER_TEST_P(PlatformKeysServiceBrowserTest, ...@@ -654,11 +657,11 @@ IN_PROC_BROWSER_TEST_P(PlatformKeysServiceBrowserTest,
INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P(
AllSupportedProfileTypes, AllSupportedProfileTypes,
PlatformKeysServiceBrowserTest, PlatformKeysServiceBrowserTest,
::testing::Values(TestConfig{ProfileToUse::kSigninProfile, {kSystemToken}}, ::testing::Values(
TestConfig{ProfileToUse::kUnaffiliatedUserProfile, TestConfig{ProfileToUse::kSigninProfile, {TokenId::kSystem}},
{kUserToken}}, TestConfig{ProfileToUse::kUnaffiliatedUserProfile, {TokenId::kUser}},
TestConfig{ProfileToUse::kAffiliatedUserProfile, TestConfig{ProfileToUse::kAffiliatedUserProfile,
{kSystemToken, kUserToken}})); {TokenId::kSystem, TokenId::kUser}}));
} // namespace platform_keys } // namespace platform_keys
} // namespace chromeos } // namespace chromeos
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
// Use of this source code is governed by a BSD-style license that can be // Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file. // found in the LICENSE file.
#include "base/optional.h"
#include "chrome/browser/chromeos/platform_keys/platform_keys_service.h" #include "chrome/browser/chromeos/platform_keys/platform_keys_service.h"
#include <cert.h> #include <cert.h>
...@@ -118,9 +119,9 @@ using GetCertDBCallback = base::Callback<void(net::NSSCertDatabase* cert_db)>; ...@@ -118,9 +119,9 @@ using GetCertDBCallback = base::Callback<void(net::NSSCertDatabase* cert_db)>;
// Used by GetCertDatabaseOnIoThread and called back with the requested // Used by GetCertDatabaseOnIoThread and called back with the requested
// NSSCertDatabase. // NSSCertDatabase.
// If |token_id| is not empty, sets |slot_| of |state| accordingly and calls // If |token_id| is provided, sets |slot_| of |state| accordingly and calls
// |callback| if the database was successfully retrieved. // |callback| if the database was successfully retrieved.
void DidGetCertDbOnIoThread(const std::string& token_id, void DidGetCertDbOnIoThread(base::Optional<TokenId> token_id,
const GetCertDBCallback& callback, const GetCertDBCallback& callback,
NSSOperationState* state, NSSOperationState* state,
net::NSSCertDatabase* cert_db) { net::NSSCertDatabase* cert_db) {
...@@ -131,14 +132,19 @@ void DidGetCertDbOnIoThread(const std::string& token_id, ...@@ -131,14 +132,19 @@ void DidGetCertDbOnIoThread(const std::string& token_id,
return; return;
} }
if (!token_id.empty()) { if (token_id) {
if (token_id == kTokenIdUser) switch (token_id.value()) {
case TokenId::kUser:
state->slot_ = cert_db->GetPrivateSlot(); state->slot_ = cert_db->GetPrivateSlot();
else if (token_id == kTokenIdSystem) break;
case TokenId::kSystem:
state->slot_ = cert_db->GetSystemSlot(); state->slot_ = cert_db->GetSystemSlot();
break;
}
if (!state->slot_) { if (!state->slot_) {
LOG(ERROR) << "Slot for token id '" << token_id << "' not available."; LOG(ERROR) << "Slot for token id '" << static_cast<int>(token_id.value())
<< "' not available.";
state->OnError(FROM_HERE, kErrorInternal); state->OnError(FROM_HERE, kErrorInternal);
return; return;
} }
...@@ -147,10 +153,10 @@ void DidGetCertDbOnIoThread(const std::string& token_id, ...@@ -147,10 +153,10 @@ void DidGetCertDbOnIoThread(const std::string& token_id,
callback.Run(cert_db); callback.Run(cert_db);
} }
// Retrieves the NSSCertDatabase from |context| and, if |token_id| is not empty, // Retrieves the NSSCertDatabase from |context| and, if |token_id| is provided,
// the slot for |token_id|. // the slot for |token_id|.
// Must be called on the IO thread. // Must be called on the IO thread.
void GetCertDatabaseOnIoThread(const std::string& token_id, void GetCertDatabaseOnIoThread(base::Optional<TokenId> token_id,
const GetCertDBCallback& callback, const GetCertDBCallback& callback,
content::ResourceContext* context, content::ResourceContext* context,
NSSOperationState* state) { NSSOperationState* state) {
...@@ -164,7 +170,7 @@ void GetCertDatabaseOnIoThread(const std::string& token_id, ...@@ -164,7 +170,7 @@ void GetCertDatabaseOnIoThread(const std::string& token_id,
// Called by SystemTokenCertDBInitializer on the UI thread with the system token // Called by SystemTokenCertDBInitializer on the UI thread with the system token
// certificate database when it is initialized. // certificate database when it is initialized.
void DidGetSystemTokenCertDbOnUiThread(const std::string& token_id, void DidGetSystemTokenCertDbOnUiThread(base::Optional<TokenId> token_id,
const GetCertDBCallback& callback, const GetCertDBCallback& callback,
NSSOperationState* state, NSSOperationState* state,
net::NSSCertDatabase* cert_db) { net::NSSCertDatabase* cert_db) {
...@@ -178,11 +184,11 @@ void DidGetSystemTokenCertDbOnUiThread(const std::string& token_id, ...@@ -178,11 +184,11 @@ void DidGetSystemTokenCertDbOnUiThread(const std::string& token_id,
} }
// Asynchronously fetches the NSSCertDatabase for |browser_context| and, if // Asynchronously fetches the NSSCertDatabase for |browser_context| and, if
// |token_id| is not empty, the slot for |token_id|. Stores the slot in |state| // |token_id| is provided, the slot for |token_id|. Stores the slot in |state|
// and passes the database to |callback|. Will run |callback| on the IO thread. // and passes the database to |callback|. Will run |callback| on the IO thread.
// TODO(omorsi): Introduce timeout for retrieving certificate database in // TODO(omorsi): Introduce timeout for retrieving certificate database in
// platform keys. // platform keys.
void GetCertDatabase(const std::string& token_id, void GetCertDatabase(base::Optional<TokenId> token_id,
const GetCertDBCallback& callback, const GetCertDBCallback& callback,
BrowserContext* browser_context, BrowserContext* browser_context,
NSSOperationState* state) { NSSOperationState* state) {
...@@ -489,13 +495,12 @@ class GetTokensState : public NSSOperationState { ...@@ -489,13 +495,12 @@ class GetTokensState : public NSSOperationState {
void OnError(const base::Location& from, void OnError(const base::Location& from,
const std::string& error_message) override { const std::string& error_message) override {
CallBack(from, CallBack(from, std::unique_ptr<std::vector<TokenId>>() /* no token ids */,
std::unique_ptr<std::vector<std::string>>() /* no token ids */,
error_message); error_message);
} }
void CallBack(const base::Location& from, void CallBack(const base::Location& from,
std::unique_ptr<std::vector<std::string>> token_ids, std::unique_ptr<std::vector<TokenId>> token_ids,
const std::string& error_message) { const std::string& error_message) {
auto bound_callback = auto bound_callback =
base::BindOnce(callback_, std::move(token_ids), error_message); base::BindOnce(callback_, std::move(token_ids), error_message);
...@@ -518,11 +523,11 @@ class GetKeyLocationsState : public NSSOperationState { ...@@ -518,11 +523,11 @@ class GetKeyLocationsState : public NSSOperationState {
void OnError(const base::Location& from, void OnError(const base::Location& from,
const std::string& error_message) override { const std::string& error_message) override {
CallBack(from, std::vector<std::string>(), error_message); CallBack(from, std::vector<TokenId>() /* no token ids */, error_message);
} }
void CallBack(const base::Location& from, void CallBack(const base::Location& from,
const std::vector<std::string>& token_ids, const std::vector<TokenId>& token_ids,
const std::string& error_message) { const std::string& error_message) {
auto bound_callback = base::BindOnce(callback_, token_ids, error_message); auto bound_callback = base::BindOnce(callback_, token_ids, error_message);
origin_task_runner_->PostTask( origin_task_runner_->PostTask(
...@@ -1247,16 +1252,15 @@ void RemoveKeyWithDb(std::unique_ptr<RemoveKeyState> state, ...@@ -1247,16 +1252,15 @@ void RemoveKeyWithDb(std::unique_ptr<RemoveKeyState> state,
void GetTokensWithDB(std::unique_ptr<GetTokensState> state, void GetTokensWithDB(std::unique_ptr<GetTokensState> state,
net::NSSCertDatabase* cert_db) { net::NSSCertDatabase* cert_db) {
DCHECK_CURRENTLY_ON(BrowserThread::IO); DCHECK_CURRENTLY_ON(BrowserThread::IO);
std::unique_ptr<std::vector<std::string>> token_ids( auto token_ids = std::make_unique<std::vector<TokenId>>();
new std::vector<std::string>);
// The user token will be unavailable in case of no logged in user in this // The user token will be unavailable in case of no logged in user in this
// profile. // profile.
if (cert_db->GetPrivateSlot()) if (cert_db->GetPrivateSlot())
token_ids->push_back(kTokenIdUser); token_ids->push_back(TokenId::kUser);
if (cert_db->GetSystemSlot()) if (cert_db->GetSystemSlot())
token_ids->push_back(kTokenIdSystem); token_ids->push_back(TokenId::kSystem);
DCHECK(!token_ids->empty()); DCHECK(!token_ids->empty());
...@@ -1269,7 +1273,7 @@ void GetKeyLocationsWithDB(std::unique_ptr<GetKeyLocationsState> state, ...@@ -1269,7 +1273,7 @@ void GetKeyLocationsWithDB(std::unique_ptr<GetKeyLocationsState> state,
net::NSSCertDatabase* cert_db) { net::NSSCertDatabase* cert_db) {
DCHECK_CURRENTLY_ON(BrowserThread::IO); DCHECK_CURRENTLY_ON(BrowserThread::IO);
std::vector<std::string> token_ids; std::vector<TokenId> token_ids;
const uint8_t* public_key_uint8 = const uint8_t* public_key_uint8 =
reinterpret_cast<const uint8_t*>(state->public_key_spki_der_.data()); reinterpret_cast<const uint8_t*>(state->public_key_spki_der_.data());
...@@ -1281,14 +1285,14 @@ void GetKeyLocationsWithDB(std::unique_ptr<GetKeyLocationsState> state, ...@@ -1281,14 +1285,14 @@ void GetKeyLocationsWithDB(std::unique_ptr<GetKeyLocationsState> state,
crypto::FindNSSKeyFromPublicKeyInfoInSlot( crypto::FindNSSKeyFromPublicKeyInfoInSlot(
public_key_vector, cert_db->GetPrivateSlot().get()); public_key_vector, cert_db->GetPrivateSlot().get());
if (rsa_key) if (rsa_key)
token_ids.push_back(kTokenIdUser); token_ids.push_back(TokenId::kUser);
} }
if (token_ids.empty() && cert_db->GetPublicSlot().get()) { if (token_ids.empty() && cert_db->GetPublicSlot().get()) {
crypto::ScopedSECKEYPrivateKey rsa_key = crypto::ScopedSECKEYPrivateKey rsa_key =
crypto::FindNSSKeyFromPublicKeyInfoInSlot( crypto::FindNSSKeyFromPublicKeyInfoInSlot(
public_key_vector, cert_db->GetPublicSlot().get()); public_key_vector, cert_db->GetPublicSlot().get());
if (rsa_key) if (rsa_key)
token_ids.push_back(kTokenIdUser); token_ids.push_back(TokenId::kUser);
} }
if (cert_db->GetSystemSlot().get()) { if (cert_db->GetSystemSlot().get()) {
...@@ -1296,7 +1300,7 @@ void GetKeyLocationsWithDB(std::unique_ptr<GetKeyLocationsState> state, ...@@ -1296,7 +1300,7 @@ void GetKeyLocationsWithDB(std::unique_ptr<GetKeyLocationsState> state,
crypto::FindNSSKeyFromPublicKeyInfoInSlot( crypto::FindNSSKeyFromPublicKeyInfoInSlot(
public_key_vector, cert_db->GetSystemSlot().get()); public_key_vector, cert_db->GetSystemSlot().get());
if (rsa_key) if (rsa_key)
token_ids.push_back(kTokenIdSystem); token_ids.push_back(TokenId::kSystem);
} }
state->CallBack(FROM_HERE, std::move(token_ids), state->CallBack(FROM_HERE, std::move(token_ids),
...@@ -1398,7 +1402,7 @@ void GetAttributeForKeyWithDb(std::unique_ptr<GetAttributeForKeyState> state, ...@@ -1398,7 +1402,7 @@ void GetAttributeForKeyWithDb(std::unique_ptr<GetAttributeForKeyState> state,
} // namespace } // namespace
void PlatformKeysServiceImpl::GenerateRSAKey( void PlatformKeysServiceImpl::GenerateRSAKey(
const std::string& token_id, TokenId token_id,
unsigned int modulus_length_bits, unsigned int modulus_length_bits,
const GenerateKeyCallback& callback) { const GenerateKeyCallback& callback) {
DCHECK_CURRENTLY_ON(BrowserThread::UI); DCHECK_CURRENTLY_ON(BrowserThread::UI);
...@@ -1418,7 +1422,7 @@ void PlatformKeysServiceImpl::GenerateRSAKey( ...@@ -1418,7 +1422,7 @@ void PlatformKeysServiceImpl::GenerateRSAKey(
} }
void PlatformKeysServiceImpl::GenerateECKey( void PlatformKeysServiceImpl::GenerateECKey(
const std::string& token_id, TokenId token_id,
const std::string& named_curve, const std::string& named_curve,
const GenerateKeyCallback& callback) { const GenerateKeyCallback& callback) {
DCHECK_CURRENTLY_ON(BrowserThread::UI); DCHECK_CURRENTLY_ON(BrowserThread::UI);
...@@ -1433,7 +1437,7 @@ void PlatformKeysServiceImpl::GenerateECKey( ...@@ -1433,7 +1437,7 @@ void PlatformKeysServiceImpl::GenerateECKey(
} }
void PlatformKeysServiceImpl::SignRSAPKCS1Digest( void PlatformKeysServiceImpl::SignRSAPKCS1Digest(
const std::string& token_id, base::Optional<TokenId> token_id,
const std::string& data, const std::string& data,
const std::string& public_key_spki_der, const std::string& public_key_spki_der,
HashAlgorithm hash_algorithm, HashAlgorithm hash_algorithm,
...@@ -1455,7 +1459,7 @@ void PlatformKeysServiceImpl::SignRSAPKCS1Digest( ...@@ -1455,7 +1459,7 @@ void PlatformKeysServiceImpl::SignRSAPKCS1Digest(
} }
void PlatformKeysServiceImpl::SignRSAPKCS1Raw( void PlatformKeysServiceImpl::SignRSAPKCS1Raw(
const std::string& token_id, base::Optional<TokenId> token_id,
const std::string& data, const std::string& data,
const std::string& public_key_spki_der, const std::string& public_key_spki_der,
const SignCallback& callback) { const SignCallback& callback) {
...@@ -1476,7 +1480,7 @@ void PlatformKeysServiceImpl::SignRSAPKCS1Raw( ...@@ -1476,7 +1480,7 @@ void PlatformKeysServiceImpl::SignRSAPKCS1Raw(
} }
void PlatformKeysServiceImpl::SignECDSADigest( void PlatformKeysServiceImpl::SignECDSADigest(
const std::string& token_id, base::Optional<TokenId> token_id,
const std::string& data, const std::string& data,
const std::string& public_key_spki_der, const std::string& public_key_spki_der,
HashAlgorithm hash_algorithm, HashAlgorithm hash_algorithm,
...@@ -1660,7 +1664,7 @@ bool GetPublicKeyBySpki(const std::string& spki, ...@@ -1660,7 +1664,7 @@ bool GetPublicKeyBySpki(const std::string& spki,
} }
void PlatformKeysServiceImpl::GetCertificates( void PlatformKeysServiceImpl::GetCertificates(
const std::string& token_id, TokenId token_id,
const GetCertificatesCallback& callback) { const GetCertificatesCallback& callback) {
DCHECK_CURRENTLY_ON(BrowserThread::UI); DCHECK_CURRENTLY_ON(BrowserThread::UI);
auto state = std::make_unique<GetCertificatesState>( auto state = std::make_unique<GetCertificatesState>(
...@@ -1672,7 +1676,7 @@ void PlatformKeysServiceImpl::GetCertificates( ...@@ -1672,7 +1676,7 @@ void PlatformKeysServiceImpl::GetCertificates(
browser_context_, state_ptr); browser_context_, state_ptr);
} }
void PlatformKeysServiceImpl::GetAllKeys(const std::string& token_id, void PlatformKeysServiceImpl::GetAllKeys(TokenId token_id,
GetAllKeysCallback callback) { GetAllKeysCallback callback) {
DCHECK_CURRENTLY_ON(BrowserThread::UI); DCHECK_CURRENTLY_ON(BrowserThread::UI);
...@@ -1686,7 +1690,7 @@ void PlatformKeysServiceImpl::GetAllKeys(const std::string& token_id, ...@@ -1686,7 +1690,7 @@ void PlatformKeysServiceImpl::GetAllKeys(const std::string& token_id,
} }
void PlatformKeysServiceImpl::ImportCertificate( void PlatformKeysServiceImpl::ImportCertificate(
const std::string& token_id, TokenId token_id,
const scoped_refptr<net::X509Certificate>& certificate, const scoped_refptr<net::X509Certificate>& certificate,
const ImportCertificateCallback& callback) { const ImportCertificateCallback& callback) {
DCHECK_CURRENTLY_ON(BrowserThread::UI); DCHECK_CURRENTLY_ON(BrowserThread::UI);
...@@ -1704,7 +1708,7 @@ void PlatformKeysServiceImpl::ImportCertificate( ...@@ -1704,7 +1708,7 @@ void PlatformKeysServiceImpl::ImportCertificate(
} }
void PlatformKeysServiceImpl::RemoveCertificate( void PlatformKeysServiceImpl::RemoveCertificate(
const std::string& token_id, TokenId token_id,
const scoped_refptr<net::X509Certificate>& certificate, const scoped_refptr<net::X509Certificate>& certificate,
const RemoveCertificateCallback& callback) { const RemoveCertificateCallback& callback) {
DCHECK_CURRENTLY_ON(BrowserThread::UI); DCHECK_CURRENTLY_ON(BrowserThread::UI);
...@@ -1720,7 +1724,7 @@ void PlatformKeysServiceImpl::RemoveCertificate( ...@@ -1720,7 +1724,7 @@ void PlatformKeysServiceImpl::RemoveCertificate(
browser_context_, state_ptr); browser_context_, state_ptr);
} }
void PlatformKeysServiceImpl::RemoveKey(const std::string& token_id, void PlatformKeysServiceImpl::RemoveKey(TokenId token_id,
const std::string& public_key_spki_der, const std::string& public_key_spki_der,
RemoveKeyCallback callback) { RemoveKeyCallback callback) {
DCHECK_CURRENTLY_ON(BrowserThread::UI); DCHECK_CURRENTLY_ON(BrowserThread::UI);
...@@ -1743,7 +1747,7 @@ void PlatformKeysServiceImpl::GetTokens(const GetTokensCallback& callback) { ...@@ -1743,7 +1747,7 @@ void PlatformKeysServiceImpl::GetTokens(const GetTokensCallback& callback) {
std::make_unique<GetTokensState>(weak_factory_.GetWeakPtr(), callback); std::make_unique<GetTokensState>(weak_factory_.GetWeakPtr(), callback);
// Get the pointer to |state| before base::Passed releases |state|. // Get the pointer to |state| before base::Passed releases |state|.
NSSOperationState* state_ptr = state.get(); NSSOperationState* state_ptr = state.get();
GetCertDatabase(std::string() /* don't get any specific slot */, GetCertDatabase(/*token_id=*/base::nullopt /* don't get any specific slot */,
base::Bind(&GetTokensWithDB, base::Passed(&state)), base::Bind(&GetTokensWithDB, base::Passed(&state)),
browser_context_, state_ptr); browser_context_, state_ptr);
} }
...@@ -1757,13 +1761,13 @@ void PlatformKeysServiceImpl::GetKeyLocations( ...@@ -1757,13 +1761,13 @@ void PlatformKeysServiceImpl::GetKeyLocations(
NSSOperationState* state_ptr = state.get(); NSSOperationState* state_ptr = state.get();
GetCertDatabase( GetCertDatabase(
std::string() /* don't get any specific slot - we need all slots */, /*token_id=*/base::nullopt /* don't get any specific slot */,
base::BindRepeating(&GetKeyLocationsWithDB, base::Passed(&state)), base::BindRepeating(&GetKeyLocationsWithDB, base::Passed(&state)),
browser_context_, state_ptr); browser_context_, state_ptr);
} }
void PlatformKeysServiceImpl::SetAttributeForKey( void PlatformKeysServiceImpl::SetAttributeForKey(
const std::string& token_id, TokenId token_id,
const std::string& public_key_spki_der, const std::string& public_key_spki_der,
KeyAttributeType attribute_type, KeyAttributeType attribute_type,
const std::string& attribute_value, const std::string& attribute_value,
...@@ -1790,7 +1794,7 @@ void PlatformKeysServiceImpl::SetAttributeForKey( ...@@ -1790,7 +1794,7 @@ void PlatformKeysServiceImpl::SetAttributeForKey(
} }
void PlatformKeysServiceImpl::GetAttributeForKey( void PlatformKeysServiceImpl::GetAttributeForKey(
const std::string& token_id, TokenId token_id,
const std::string& public_key_spki_der, const std::string& public_key_spki_der,
KeyAttributeType attribute_type, KeyAttributeType attribute_type,
GetAttributeForKeyCallback callback) { GetAttributeForKeyCallback callback) {
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <utility> #include <utility>
#include "base/bind.h" #include "base/bind.h"
#include "base/optional.h"
#include "base/values.h" #include "base/values.h"
#include "chrome/browser/chromeos/platform_keys/extension_platform_keys_service.h" #include "chrome/browser/chromeos/platform_keys/extension_platform_keys_service.h"
#include "chrome/browser/chromeos/platform_keys/extension_platform_keys_service_factory.h" #include "chrome/browser/chromeos/platform_keys/extension_platform_keys_service_factory.h"
...@@ -53,8 +54,9 @@ EnterprisePlatformKeysInternalGenerateKeyFunction::Run() { ...@@ -53,8 +54,9 @@ EnterprisePlatformKeysInternalGenerateKeyFunction::Run() {
api_epki::GenerateKey::Params::Create(*args_)); api_epki::GenerateKey::Params::Create(*args_));
EXTENSION_FUNCTION_VALIDATE(params); EXTENSION_FUNCTION_VALIDATE(params);
std::string platform_keys_token_id; base::Optional<chromeos::platform_keys::TokenId> platform_keys_token_id =
if (!platform_keys::ValidateToken(params->token_id, &platform_keys_token_id)) platform_keys::ApiIdToPlatformKeysTokenId(params->token_id);
if (!platform_keys_token_id)
return RespondNow(Error(platform_keys::kErrorInvalidToken)); return RespondNow(Error(platform_keys::kErrorInvalidToken));
chromeos::ExtensionPlatformKeysService* service = chromeos::ExtensionPlatformKeysService* service =
...@@ -67,7 +69,7 @@ EnterprisePlatformKeysInternalGenerateKeyFunction::Run() { ...@@ -67,7 +69,7 @@ EnterprisePlatformKeysInternalGenerateKeyFunction::Run() {
EXTENSION_FUNCTION_VALIDATE(params->algorithm.modulus_length && EXTENSION_FUNCTION_VALIDATE(params->algorithm.modulus_length &&
*(params->algorithm.modulus_length) >= 0); *(params->algorithm.modulus_length) >= 0);
service->GenerateRSAKey( service->GenerateRSAKey(
platform_keys_token_id, *(params->algorithm.modulus_length), platform_keys_token_id.value(), *(params->algorithm.modulus_length),
extension_id(), extension_id(),
base::Bind( base::Bind(
&EnterprisePlatformKeysInternalGenerateKeyFunction::OnGeneratedKey, &EnterprisePlatformKeysInternalGenerateKeyFunction::OnGeneratedKey,
...@@ -75,7 +77,7 @@ EnterprisePlatformKeysInternalGenerateKeyFunction::Run() { ...@@ -75,7 +77,7 @@ EnterprisePlatformKeysInternalGenerateKeyFunction::Run() {
} else if (params->algorithm.name == "ECDSA") { } else if (params->algorithm.name == "ECDSA") {
EXTENSION_FUNCTION_VALIDATE(params->algorithm.named_curve); EXTENSION_FUNCTION_VALIDATE(params->algorithm.named_curve);
service->GenerateECKey( service->GenerateECKey(
platform_keys_token_id, *(params->algorithm.named_curve), platform_keys_token_id.value(), *(params->algorithm.named_curve),
extension_id(), extension_id(),
base::Bind( base::Bind(
&EnterprisePlatformKeysInternalGenerateKeyFunction::OnGeneratedKey, &EnterprisePlatformKeysInternalGenerateKeyFunction::OnGeneratedKey,
...@@ -107,15 +109,16 @@ EnterprisePlatformKeysGetCertificatesFunction::Run() { ...@@ -107,15 +109,16 @@ EnterprisePlatformKeysGetCertificatesFunction::Run() {
std::unique_ptr<api_epk::GetCertificates::Params> params( std::unique_ptr<api_epk::GetCertificates::Params> params(
api_epk::GetCertificates::Params::Create(*args_)); api_epk::GetCertificates::Params::Create(*args_));
EXTENSION_FUNCTION_VALIDATE(params); EXTENSION_FUNCTION_VALIDATE(params);
std::string platform_keys_token_id; base::Optional<chromeos::platform_keys::TokenId> platform_keys_token_id =
if (!platform_keys::ValidateToken(params->token_id, &platform_keys_token_id)) platform_keys::ApiIdToPlatformKeysTokenId(params->token_id);
if (!platform_keys_token_id)
return RespondNow(Error(platform_keys::kErrorInvalidToken)); return RespondNow(Error(platform_keys::kErrorInvalidToken));
chromeos::platform_keys::PlatformKeysService* platform_keys_service = chromeos::platform_keys::PlatformKeysService* platform_keys_service =
chromeos::platform_keys::PlatformKeysServiceFactory::GetForBrowserContext( chromeos::platform_keys::PlatformKeysServiceFactory::GetForBrowserContext(
browser_context()); browser_context());
platform_keys_service->GetCertificates( platform_keys_service->GetCertificates(
platform_keys_token_id, platform_keys_token_id.value(),
base::Bind( base::Bind(
&EnterprisePlatformKeysGetCertificatesFunction::OnGotCertificates, &EnterprisePlatformKeysGetCertificatesFunction::OnGotCertificates,
this)); this));
...@@ -153,8 +156,9 @@ EnterprisePlatformKeysImportCertificateFunction::Run() { ...@@ -153,8 +156,9 @@ EnterprisePlatformKeysImportCertificateFunction::Run() {
std::unique_ptr<api_epk::ImportCertificate::Params> params( std::unique_ptr<api_epk::ImportCertificate::Params> params(
api_epk::ImportCertificate::Params::Create(*args_)); api_epk::ImportCertificate::Params::Create(*args_));
EXTENSION_FUNCTION_VALIDATE(params); EXTENSION_FUNCTION_VALIDATE(params);
std::string platform_keys_token_id; base::Optional<chromeos::platform_keys::TokenId> platform_keys_token_id =
if (!platform_keys::ValidateToken(params->token_id, &platform_keys_token_id)) platform_keys::ApiIdToPlatformKeysTokenId(params->token_id);
if (!platform_keys_token_id)
return RespondNow(Error(platform_keys::kErrorInvalidToken)); return RespondNow(Error(platform_keys::kErrorInvalidToken));
const std::vector<uint8_t>& cert_der = params->certificate; const std::vector<uint8_t>& cert_der = params->certificate;
...@@ -175,7 +179,7 @@ EnterprisePlatformKeysImportCertificateFunction::Run() { ...@@ -175,7 +179,7 @@ EnterprisePlatformKeysImportCertificateFunction::Run() {
CHECK(platform_keys_service); CHECK(platform_keys_service);
platform_keys_service->ImportCertificate( platform_keys_service->ImportCertificate(
platform_keys_token_id, cert_x509, platform_keys_token_id.value(), cert_x509,
base::Bind(&EnterprisePlatformKeysImportCertificateFunction:: base::Bind(&EnterprisePlatformKeysImportCertificateFunction::
OnImportedCertificate, OnImportedCertificate,
this)); this));
...@@ -199,8 +203,9 @@ EnterprisePlatformKeysRemoveCertificateFunction::Run() { ...@@ -199,8 +203,9 @@ EnterprisePlatformKeysRemoveCertificateFunction::Run() {
std::unique_ptr<api_epk::RemoveCertificate::Params> params( std::unique_ptr<api_epk::RemoveCertificate::Params> params(
api_epk::RemoveCertificate::Params::Create(*args_)); api_epk::RemoveCertificate::Params::Create(*args_));
EXTENSION_FUNCTION_VALIDATE(params); EXTENSION_FUNCTION_VALIDATE(params);
std::string platform_keys_token_id; base::Optional<chromeos::platform_keys::TokenId> platform_keys_token_id =
if (!platform_keys::ValidateToken(params->token_id, &platform_keys_token_id)) platform_keys::ApiIdToPlatformKeysTokenId(params->token_id);
if (!platform_keys_token_id)
return RespondNow(Error(platform_keys::kErrorInvalidToken)); return RespondNow(Error(platform_keys::kErrorInvalidToken));
const std::vector<uint8_t>& cert_der = params->certificate; const std::vector<uint8_t>& cert_der = params->certificate;
...@@ -221,7 +226,7 @@ EnterprisePlatformKeysRemoveCertificateFunction::Run() { ...@@ -221,7 +226,7 @@ EnterprisePlatformKeysRemoveCertificateFunction::Run() {
CHECK(platform_keys_service); CHECK(platform_keys_service);
platform_keys_service->RemoveCertificate( platform_keys_service->RemoveCertificate(
platform_keys_token_id, cert_x509, platform_keys_token_id.value(), cert_x509,
base::Bind(&EnterprisePlatformKeysRemoveCertificateFunction:: base::Bind(&EnterprisePlatformKeysRemoveCertificateFunction::
OnRemovedCertificate, OnRemovedCertificate,
this)); this));
...@@ -255,7 +260,8 @@ EnterprisePlatformKeysInternalGetTokensFunction::Run() { ...@@ -255,7 +260,8 @@ EnterprisePlatformKeysInternalGetTokensFunction::Run() {
} }
void EnterprisePlatformKeysInternalGetTokensFunction::OnGotTokens( void EnterprisePlatformKeysInternalGetTokensFunction::OnGotTokens(
std::unique_ptr<std::vector<std::string>> platform_keys_token_ids, std::unique_ptr<std::vector<chromeos::platform_keys::TokenId>>
platform_keys_token_ids,
const std::string& error_message) { const std::string& error_message) {
DCHECK_CURRENTLY_ON(content::BrowserThread::UI); DCHECK_CURRENTLY_ON(content::BrowserThread::UI);
if (!error_message.empty()) { if (!error_message.empty()) {
...@@ -264,15 +270,14 @@ void EnterprisePlatformKeysInternalGetTokensFunction::OnGotTokens( ...@@ -264,15 +270,14 @@ void EnterprisePlatformKeysInternalGetTokensFunction::OnGotTokens(
} }
std::vector<std::string> token_ids; std::vector<std::string> token_ids;
for (std::vector<std::string>::const_iterator it = for (auto token_id : *platform_keys_token_ids) {
platform_keys_token_ids->begin(); std::string api_token_id =
it != platform_keys_token_ids->end(); ++it) { platform_keys::PlatformKeysTokenIdToApiId(token_id);
std::string token_id = platform_keys::PlatformKeysTokenIdToApiId(*it); if (api_token_id.empty()) {
if (token_id.empty()) {
Respond(Error(kEnterprisePlatformErrorInternal)); Respond(Error(kEnterprisePlatformErrorInternal));
return; return;
} }
token_ids.push_back(token_id); token_ids.push_back(api_token_id);
} }
Respond(ArgumentList(api_epki::GetTokens::Results::Create(token_ids))); Respond(ArgumentList(api_epki::GetTokens::Results::Create(token_ids)));
......
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include <vector> #include <vector>
#include "base/memory/ref_counted.h" #include "base/memory/ref_counted.h"
#include "chrome/browser/chromeos/platform_keys/platform_keys_service.h"
#include "chrome/browser/extensions/api/enterprise_platform_keys_private/enterprise_platform_keys_private_api.h" #include "chrome/browser/extensions/api/enterprise_platform_keys_private/enterprise_platform_keys_private_api.h"
#include "extensions/browser/extension_function.h" #include "extensions/browser/extension_function.h"
...@@ -85,7 +86,8 @@ class EnterprisePlatformKeysInternalGetTokensFunction ...@@ -85,7 +86,8 @@ class EnterprisePlatformKeysInternalGetTokensFunction
// Called when the list of tokens was determined. If an error occurred, // Called when the list of tokens was determined. If an error occurred,
// |token_ids| will be NULL and instead |error_message| be set. // |token_ids| will be NULL and instead |error_message| be set.
void OnGotTokens(std::unique_ptr<std::vector<std::string>> token_ids, void OnGotTokens(
std::unique_ptr<std::vector<chromeos::platform_keys::TokenId>> token_ids,
const std::string& error_message); const std::string& error_message);
DECLARE_EXTENSION_FUNCTION("enterprise.platformKeysInternal.getTokens", DECLARE_EXTENSION_FUNCTION("enterprise.platformKeysInternal.getTokens",
......
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include "base/bind.h" #include "base/bind.h"
#include "base/logging.h" #include "base/logging.h"
#include "base/optional.h"
#include "base/stl_util.h" #include "base/stl_util.h"
#include "base/strings/string_piece.h" #include "base/strings/string_piece.h"
#include "base/values.h" #include "base/values.h"
...@@ -103,29 +104,25 @@ const char kErrorInvalidX509Cert[] = ...@@ -103,29 +104,25 @@ const char kErrorInvalidX509Cert[] =
const char kTokenIdUser[] = "user"; const char kTokenIdUser[] = "user";
const char kTokenIdSystem[] = "system"; const char kTokenIdSystem[] = "system";
// Returns whether |token_id| references a known Token. base::Optional<chromeos::platform_keys::TokenId> ApiIdToPlatformKeysTokenId(
bool ValidateToken(const std::string& token_id, const std::string& token_id) {
std::string* platform_keys_token_id) { if (token_id == kTokenIdUser)
platform_keys_token_id->clear(); return chromeos::platform_keys::TokenId::kUser;
if (token_id == kTokenIdUser) {
*platform_keys_token_id = chromeos::platform_keys::kTokenIdUser; if (token_id == kTokenIdSystem)
return true; return chromeos::platform_keys::TokenId::kSystem;
}
if (token_id == kTokenIdSystem) { return base::nullopt;
*platform_keys_token_id = chromeos::platform_keys::kTokenIdSystem;
return true;
}
return false;
} }
std::string PlatformKeysTokenIdToApiId( std::string PlatformKeysTokenIdToApiId(
const std::string& platform_keys_token_id) { chromeos::platform_keys::TokenId platform_keys_token_id) {
if (platform_keys_token_id == chromeos::platform_keys::kTokenIdUser) switch (platform_keys_token_id) {
case chromeos::platform_keys::TokenId::kUser:
return kTokenIdUser; return kTokenIdUser;
if (platform_keys_token_id == chromeos::platform_keys::kTokenIdSystem) case chromeos::platform_keys::TokenId::kSystem:
return kTokenIdSystem; return kTokenIdSystem;
}
return std::string();
} }
} // namespace platform_keys } // namespace platform_keys
...@@ -364,12 +361,17 @@ ExtensionFunction::ResponseAction PlatformKeysInternalSignFunction::Run() { ...@@ -364,12 +361,17 @@ ExtensionFunction::ResponseAction PlatformKeysInternalSignFunction::Run() {
std::unique_ptr<api_pki::Sign::Params> params( std::unique_ptr<api_pki::Sign::Params> params(
api_pki::Sign::Params::Create(*args_)); api_pki::Sign::Params::Create(*args_));
EXTENSION_FUNCTION_VALIDATE(params); EXTENSION_FUNCTION_VALIDATE(params);
std::string platform_keys_token_id;
if (!params->token_id.empty() && base::Optional<chromeos::platform_keys::TokenId> platform_keys_token_id;
!platform_keys::ValidateToken(params->token_id, // If |params->token_id| is not specified (empty string), the key will be
&platform_keys_token_id)) { // searched for in all available tokens.
if (!params->token_id.empty()) {
platform_keys_token_id =
platform_keys::ApiIdToPlatformKeysTokenId(params->token_id);
if (!platform_keys_token_id) {
return RespondNow(Error(platform_keys::kErrorInvalidToken)); return RespondNow(Error(platform_keys::kErrorInvalidToken));
} }
}
chromeos::ExtensionPlatformKeysService* service = chromeos::ExtensionPlatformKeysService* service =
chromeos::ExtensionPlatformKeysServiceFactory::GetForBrowserContext( chromeos::ExtensionPlatformKeysServiceFactory::GetForBrowserContext(
......
...@@ -8,12 +8,14 @@ ...@@ -8,12 +8,14 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "base/optional.h"
#include "chrome/browser/chromeos/platform_keys/platform_keys_service.h"
#include "extensions/browser/extension_function.h" #include "extensions/browser/extension_function.h"
namespace net { namespace net {
class X509Certificate; class X509Certificate;
typedef std::vector<scoped_refptr<X509Certificate>> CertificateList; typedef std::vector<scoped_refptr<X509Certificate>> CertificateList;
} // net } // namespace net
namespace extensions { namespace extensions {
namespace platform_keys { namespace platform_keys {
...@@ -21,14 +23,15 @@ namespace platform_keys { ...@@ -21,14 +23,15 @@ namespace platform_keys {
extern const char kErrorInvalidToken[]; extern const char kErrorInvalidToken[];
extern const char kErrorInvalidX509Cert[]; extern const char kErrorInvalidX509Cert[];
// Returns whether |token_id| references a known Token. // Returns a known token if |token_id| is valid and returns nullopt for both
bool ValidateToken(const std::string& token_id, // empty or unknown |token_id|.
std::string* platform_keys_token_id); base::Optional<chromeos::platform_keys::TokenId> ApiIdToPlatformKeysTokenId(
const std::string& token_id);
// Converts a token id from ::chromeos::platform_keys to the platformKeys API // Converts a token id from ::chromeos::platform_keys to the platformKeys API
// token id. // token id.
std::string PlatformKeysTokenIdToApiId( std::string PlatformKeysTokenIdToApiId(
const std::string& platform_keys_token_id); chromeos::platform_keys::TokenId platform_keys_token_id);
} // namespace platform_keys } // namespace platform_keys
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment