Commit 6a509131 authored by Leonid Baraz's avatar Leonid Baraz Committed by Commit Bot

Refactor encryption to receive public key id.

It will now receive public key id from caller (server in prod mode)
rather than calculating it locally. This was we avoid any discrepancy
between the way it is calculated on the server and on the client.

Bug: b:153651358
Change-Id: If1d43e7c86c142ed8be753b14b49976f9f5659ce
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2510609Reviewed-by: default avatarZach Trudo <zatrudo@google.com>
Commit-Queue: Leonid Baraz <lbaraz@chromium.org>
Cr-Commit-Position: refs/heads/master@{#822761}
parent ecd51275
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include "base/containers/span.h" #include "base/containers/span.h"
#include "base/hash/hash.h" #include "base/hash/hash.h"
#include "base/memory/ptr_util.h" #include "base/memory/ptr_util.h"
#include "base/rand_util.h"
#include "base/strings/strcat.h" #include "base/strings/strcat.h"
#include "base/strings/string_number_conversions.h" #include "base/strings/string_number_conversions.h"
#include "base/strings/string_piece.h" #include "base/strings/string_piece.h"
...@@ -123,16 +124,16 @@ Decryptor::~Decryptor() = default; ...@@ -123,16 +124,16 @@ Decryptor::~Decryptor() = default;
void Decryptor::RecordKeyPair(base::StringPiece private_key, void Decryptor::RecordKeyPair(base::StringPiece private_key,
base::StringPiece public_key, base::StringPiece public_key,
base::OnceCallback<void(Status)> cb) { base::OnceCallback<void(StatusOr<int64_t>)> cb) {
// Schedule key recording on the sequenced task runner. // Schedule key recording on the sequenced task runner.
keys_sequenced_task_runner_->PostTask( keys_sequenced_task_runner_->PostTask(
FROM_HERE, FROM_HERE,
base::BindOnce( base::BindOnce(
[](std::string public_key, KeyInfo key_info, [](std::string public_key, KeyInfo key_info,
base::OnceCallback<void(Status)> cb, base::OnceCallback<void(StatusOr<int64_t>)> cb,
scoped_refptr<Decryptor> decryptor) { scoped_refptr<Decryptor> decryptor) {
DCHECK_CALLED_ON_VALID_SEQUENCE(decryptor->keys_sequence_checker_); DCHECK_CALLED_ON_VALID_SEQUENCE(decryptor->keys_sequence_checker_);
Status result; StatusOr<int64_t> result;
if (key_info.private_key.size() != X25519_PRIVATE_KEY_LEN) { if (key_info.private_key.size() != X25519_PRIVATE_KEY_LEN) {
result = Status( result = Status(
error::FAILED_PRECONDITION, error::FAILED_PRECONDITION,
...@@ -147,19 +148,27 @@ void Decryptor::RecordKeyPair(base::StringPiece private_key, ...@@ -147,19 +148,27 @@ void Decryptor::RecordKeyPair(base::StringPiece private_key,
{"Public key size mismatch, expected=", {"Public key size mismatch, expected=",
base::NumberToString(X25519_PUBLIC_VALUE_LEN), base::NumberToString(X25519_PUBLIC_VALUE_LEN),
" actual=", base::NumberToString(public_key.size())})); " actual=", base::NumberToString(public_key.size())}));
} else if (!decryptor->keys_ } else {
.emplace(base::PersistentHash(public_key), key_info) // Assign a random number to be public key id for testing purposes
.second) { // only (in production it will be Java Fingerprint2011 which is
result = Status(error::ALREADY_EXISTS, // 'long').
base::StrCat({"Public key='", public_key, int64_t public_key_id;
"' already recorded"})); base::RandBytes(&public_key_id, sizeof(public_key_id));
if (!decryptor->keys_.emplace(public_key_id, key_info).second) {
result = Status(error::ALREADY_EXISTS,
base::StrCat({"Public key='", public_key,
"' already recorded"}));
} else {
result = public_key_id;
}
} }
// Schedule response on a generic thread pool. // Schedule response on a generic thread pool.
base::ThreadPool::PostTask( base::ThreadPool::PostTask(
FROM_HERE, FROM_HERE,
base::BindOnce([](base::OnceCallback<void(Status)> cb, base::BindOnce(
Status result) { std::move(cb).Run(result); }, [](base::OnceCallback<void(StatusOr<int64_t>)> cb,
std::move(cb), result)); StatusOr<int64_t> result) { std::move(cb).Run(result); },
std::move(cb), result));
}, },
std::string(public_key), std::string(public_key),
KeyInfo{.private_key = std::string(private_key), KeyInfo{.private_key = std::string(private_key),
...@@ -168,13 +177,13 @@ void Decryptor::RecordKeyPair(base::StringPiece private_key, ...@@ -168,13 +177,13 @@ void Decryptor::RecordKeyPair(base::StringPiece private_key,
} }
void Decryptor::RetrieveMatchingPrivateKey( void Decryptor::RetrieveMatchingPrivateKey(
uint32_t public_key_id, int64_t public_key_id,
base::OnceCallback<void(StatusOr<std::string>)> cb) { base::OnceCallback<void(StatusOr<std::string>)> cb) {
// Schedule key retrieval on the sequenced task runner. // Schedule key retrieval on the sequenced task runner.
keys_sequenced_task_runner_->PostTask( keys_sequenced_task_runner_->PostTask(
FROM_HERE, FROM_HERE,
base::BindOnce( base::BindOnce(
[](uint32_t public_key_id, [](int64_t public_key_id,
base::OnceCallback<void(StatusOr<std::string>)> cb, base::OnceCallback<void(StatusOr<std::string>)> cb,
scoped_refptr<Decryptor> decryptor) { scoped_refptr<Decryptor> decryptor) {
DCHECK_CALLED_ON_VALID_SEQUENCE(decryptor->keys_sequence_checker_); DCHECK_CALLED_ON_VALID_SEQUENCE(decryptor->keys_sequence_checker_);
......
...@@ -78,15 +78,15 @@ class Decryptor : public base::RefCountedThreadSafe<Decryptor> { ...@@ -78,15 +78,15 @@ class Decryptor : public base::RefCountedThreadSafe<Decryptor> {
base::StringPiece peer_public_value); base::StringPiece peer_public_value);
// Records a key pair (stores only private key). // Records a key pair (stores only private key).
// Executes on a sequenced thread, returns with callback. // Executes on a sequenced thread, returns key id or error with callback.
void RecordKeyPair(base::StringPiece private_key, void RecordKeyPair(base::StringPiece private_key,
base::StringPiece public_key, base::StringPiece public_key,
base::OnceCallback<void(Status)> cb); base::OnceCallback<void(StatusOr<int64_t>)> cb);
// Retrieves private key matching the public key hash. // Retrieves private key matching the public key hash.
// Executes on a sequenced thread, returns with callback. // Executes on a sequenced thread, returns with callback.
void RetrieveMatchingPrivateKey( void RetrieveMatchingPrivateKey(
uint32_t public_key_id, int64_t public_key_id,
base::OnceCallback<void(StatusOr<std::string>)> cb); base::OnceCallback<void(StatusOr<std::string>)> cb);
private: private:
...@@ -102,7 +102,7 @@ class Decryptor : public base::RefCountedThreadSafe<Decryptor> { ...@@ -102,7 +102,7 @@ class Decryptor : public base::RefCountedThreadSafe<Decryptor> {
std::string private_key; std::string private_key;
base::Time time_stamp; base::Time time_stamp;
}; };
base::flat_map<uint32_t, KeyInfo> keys_; base::flat_map<int64_t, KeyInfo> keys_;
// Sequential task runner for all keys_ activities: // Sequential task runner for all keys_ activities:
// recording, lookup, purge. // recording, lookup, purge.
......
...@@ -47,7 +47,7 @@ void Encryptor::Handle::CloseRecord( ...@@ -47,7 +47,7 @@ void Encryptor::Handle::CloseRecord(
void Encryptor::Handle::ProduceEncryptedRecord( void Encryptor::Handle::ProduceEncryptedRecord(
base::OnceCallback<void(StatusOr<EncryptedRecord>)> cb, base::OnceCallback<void(StatusOr<EncryptedRecord>)> cb,
StatusOr<std::string> asymmetric_key_result) { StatusOr<std::pair<std::string, int64_t>> asymmetric_key_result) {
// Make sure the record self-destructs when returning from this method. // Make sure the record self-destructs when returning from this method.
const auto self_destruct = base::WrapUnique(this); const auto self_destruct = base::WrapUnique(this);
...@@ -57,12 +57,12 @@ void Encryptor::Handle::ProduceEncryptedRecord( ...@@ -57,12 +57,12 @@ void Encryptor::Handle::ProduceEncryptedRecord(
return; return;
} }
const auto& asymmetric_key = asymmetric_key_result.ValueOrDie(); const auto& asymmetric_key = asymmetric_key_result.ValueOrDie();
if (asymmetric_key.size() != X25519_PUBLIC_VALUE_LEN) { if (asymmetric_key.first.size() != X25519_PUBLIC_VALUE_LEN) {
std::move(cb).Run(Status( std::move(cb).Run(Status(
error::INTERNAL, error::INTERNAL,
base::StrCat({"Asymmetric key size mismatch, expected=", base::StrCat({"Asymmetric key size mismatch, expected=",
base::NumberToString(X25519_PUBLIC_VALUE_LEN), " actual=", base::NumberToString(X25519_PUBLIC_VALUE_LEN), " actual=",
base::NumberToString(asymmetric_key.size())}))); base::NumberToString(asymmetric_key.first.size())})));
return; return;
} }
...@@ -74,7 +74,7 @@ void Encryptor::Handle::ProduceEncryptedRecord( ...@@ -74,7 +74,7 @@ void Encryptor::Handle::ProduceEncryptedRecord(
// Compute shared secret. // Compute shared secret.
uint8_t out_shared_secret[X25519_SHARED_KEY_LEN]; uint8_t out_shared_secret[X25519_SHARED_KEY_LEN];
if (!X25519(out_shared_secret, out_private_key, if (!X25519(out_shared_secret, out_private_key,
reinterpret_cast<const uint8_t*>(asymmetric_key.data()))) { reinterpret_cast<const uint8_t*>(asymmetric_key.first.data()))) {
std::move(cb).Run(Status(error::DATA_LOSS, "Curve25519 encryption failed")); std::move(cb).Run(Status(error::DATA_LOSS, "Curve25519 encryption failed"));
return; return;
} }
...@@ -105,7 +105,7 @@ void Encryptor::Handle::ProduceEncryptedRecord( ...@@ -105,7 +105,7 @@ void Encryptor::Handle::ProduceEncryptedRecord(
// Prepare encrypted record. // Prepare encrypted record.
EncryptedRecord encrypted_record; EncryptedRecord encrypted_record;
encrypted_record.mutable_encryption_info()->set_public_key_id( encrypted_record.mutable_encryption_info()->set_public_key_id(
base::PersistentHash(asymmetric_key)); asymmetric_key.second);
encrypted_record.mutable_encryption_info()->set_encryption_key( encrypted_record.mutable_encryption_info()->set_encryption_key(
reinterpret_cast<const char*>(out_public_value), X25519_PUBLIC_VALUE_LEN); reinterpret_cast<const char*>(out_public_value), X25519_PUBLIC_VALUE_LEN);
...@@ -132,9 +132,10 @@ Encryptor::Encryptor() ...@@ -132,9 +132,10 @@ Encryptor::Encryptor()
Encryptor::~Encryptor() = default; Encryptor::~Encryptor() = default;
void Encryptor::UpdateAsymmetricKey( void Encryptor::UpdateAsymmetricKey(
base::StringPiece new_key, base::StringPiece new_public_key,
int64_t new_public_key_id,
base::OnceCallback<void(Status)> response_cb) { base::OnceCallback<void(Status)> response_cb) {
if (new_key.empty()) { if (new_public_key.empty()) {
std::move(response_cb) std::move(response_cb)
.Run(Status(error::INVALID_ARGUMENT, "Provided key is empty")); .Run(Status(error::INVALID_ARGUMENT, "Provided key is empty"));
return; return;
...@@ -144,10 +145,13 @@ void Encryptor::UpdateAsymmetricKey( ...@@ -144,10 +145,13 @@ void Encryptor::UpdateAsymmetricKey(
asymmetric_key_sequenced_task_runner_->PostTask( asymmetric_key_sequenced_task_runner_->PostTask(
FROM_HERE, FROM_HERE,
base::BindOnce( base::BindOnce(
[](base::StringPiece new_key, scoped_refptr<Encryptor> encryptor) { [](base::StringPiece new_public_key, int64_t new_public_key_id,
encryptor->asymmetric_key_ = std::string(new_key); scoped_refptr<Encryptor> encryptor) {
encryptor->asymmetric_key_ =
std::make_pair(std::string(new_public_key), new_public_key_id);
}, },
std::string(new_key), base::WrapRefCounted(this))); std::string(new_public_key), new_public_key_id,
base::WrapRefCounted(this)));
// Response OK not waiting for the update. // Response OK not waiting for the update.
std::move(response_cb).Run(Status::StatusOK()); std::move(response_cb).Run(Status::StatusOK());
...@@ -158,27 +162,29 @@ void Encryptor::OpenRecord(base::OnceCallback<void(StatusOr<Handle*>)> cb) { ...@@ -158,27 +162,29 @@ void Encryptor::OpenRecord(base::OnceCallback<void(StatusOr<Handle*>)> cb) {
} }
void Encryptor::RetrieveAsymmetricKey( void Encryptor::RetrieveAsymmetricKey(
base::OnceCallback<void(StatusOr<std::string>)> cb) { base::OnceCallback<void(StatusOr<std::pair<std::string, int64_t>>)> cb) {
// Schedule key retrieval on the sequenced task runner. // Schedule key retrieval on the sequenced task runner.
asymmetric_key_sequenced_task_runner_->PostTask( asymmetric_key_sequenced_task_runner_->PostTask(
FROM_HERE, FROM_HERE,
base::BindOnce( base::BindOnce(
[](base::OnceCallback<void(StatusOr<std::string>)> cb, [](base::OnceCallback<void(StatusOr<std::pair<std::string, int64_t>>)>
cb,
scoped_refptr<Encryptor> encryptor) { scoped_refptr<Encryptor> encryptor) {
DCHECK_CALLED_ON_VALID_SEQUENCE( DCHECK_CALLED_ON_VALID_SEQUENCE(
encryptor->asymmetric_key_sequence_checker_); encryptor->asymmetric_key_sequence_checker_);
StatusOr<std::string> response; StatusOr<std::pair<std::string, int64_t>> response;
// Schedule response on regular thread pool. // Schedule response on regular thread pool.
base::ThreadPool::PostTask( base::ThreadPool::PostTask(
FROM_HERE, FROM_HERE,
base::BindOnce( base::BindOnce(
[](base::OnceCallback<void(StatusOr<std::string>)> cb, [](base::OnceCallback<void(
StatusOr<std::string> response) { StatusOr<std::pair<std::string, int64_t>>)> cb,
StatusOr<std::pair<std::string, int64_t>> response) {
std::move(cb).Run(response); std::move(cb).Run(response);
}, },
std::move(cb), std::move(cb),
!encryptor->asymmetric_key_.has_value() !encryptor->asymmetric_key_.has_value()
? StatusOr<std::string>(Status( ? StatusOr<std::pair<std::string, int64_t>>(Status(
error::NOT_FOUND, "Asymmetric key not set")) error::NOT_FOUND, "Asymmetric key not set"))
: encryptor->asymmetric_key_.value())); : encryptor->asymmetric_key_.value()));
}, },
......
...@@ -65,7 +65,7 @@ class Encryptor : public base::RefCountedThreadSafe<Encryptor> { ...@@ -65,7 +65,7 @@ class Encryptor : public base::RefCountedThreadSafe<Encryptor> {
// as a callback after asynchronous retrieval of the asymmetric key. // as a callback after asynchronous retrieval of the asymmetric key.
void ProduceEncryptedRecord( void ProduceEncryptedRecord(
base::OnceCallback<void(StatusOr<EncryptedRecord>)> cb, base::OnceCallback<void(StatusOr<EncryptedRecord>)> cb,
StatusOr<std::string> asymmetric_key_result); StatusOr<std::pair<std::string, int64_t>> asymmetric_key_result);
// Accumulated data to encrypt. // Accumulated data to encrypt.
std::string record_; std::string record_;
...@@ -80,25 +80,26 @@ class Encryptor : public base::RefCountedThreadSafe<Encryptor> { ...@@ -80,25 +80,26 @@ class Encryptor : public base::RefCountedThreadSafe<Encryptor> {
// Hands the Handle raw pointer over to the callback, or error status). // Hands the Handle raw pointer over to the callback, or error status).
void OpenRecord(base::OnceCallback<void(StatusOr<Handle*>)> cb); void OpenRecord(base::OnceCallback<void(StatusOr<Handle*>)> cb);
// Delivers public asymmetric key to the implementation. // Delivers public asymmetric key and its id to the implementation.
// To affect specific record, must happen before Handle::CloseRecord // To affect specific record, must happen before Handle::CloseRecord
// (it is OK to do it after OpenRecord and Handle::AddToRecord). // (it is OK to do it after OpenRecord and Handle::AddToRecord).
// Executes on a sequenced thread, returns with callback. // Executes on a sequenced thread, returns with callback.
void UpdateAsymmetricKey(base::StringPiece new_key, void UpdateAsymmetricKey(base::StringPiece new_public_key,
int64_t new_public_key_id,
base::OnceCallback<void(Status)> response_cb); base::OnceCallback<void(Status)> response_cb);
// Retrieves the current public key. // Retrieves the current public key.
// Executes on a sequenced thread, returns with callback. // Executes on a sequenced thread, returns with callback.
void RetrieveAsymmetricKey( void RetrieveAsymmetricKey(
base::OnceCallback<void(StatusOr<std::string>)> cb); base::OnceCallback<void(StatusOr<std::pair<std::string, int64_t>>)> cb);
private: private:
friend class base::RefCountedThreadSafe<Encryptor>; friend class base::RefCountedThreadSafe<Encryptor>;
Encryptor(); Encryptor();
~Encryptor(); ~Encryptor();
// Public key used for asymmetric encryption of symmetric key. // Public key used for asymmetric encryption of symmetric key and its id.
base::Optional<std::string> asymmetric_key_; base::Optional<std::pair<std::string, int64_t>> asymmetric_key_;
// Sequential task runner for all asymmetric_key_ activities: update, read. // Sequential task runner for all asymmetric_key_ activities: update, read.
scoped_refptr<base::SequencedTaskRunner> scoped_refptr<base::SequencedTaskRunner>
......
...@@ -87,9 +87,11 @@ void EncryptionModule::EncryptRecord( ...@@ -87,9 +87,11 @@ void EncryptionModule::EncryptRecord(
} }
void EncryptionModule::UpdateAsymmetricKey( void EncryptionModule::UpdateAsymmetricKey(
base::StringPiece new_key, base::StringPiece new_public_key,
int64_t new_public_key_id,
base::OnceCallback<void(Status)> response_cb) { base::OnceCallback<void(Status)> response_cb) {
encryptor_->UpdateAsymmetricKey(new_key, std::move(response_cb)); encryptor_->UpdateAsymmetricKey(new_public_key, new_public_key_id,
std::move(response_cb));
} }
} // namespace reporting } // namespace reporting
...@@ -35,7 +35,8 @@ class EncryptionModule : public base::RefCountedThreadSafe<EncryptionModule> { ...@@ -35,7 +35,8 @@ class EncryptionModule : public base::RefCountedThreadSafe<EncryptionModule> {
// Records current public asymmetric key. // Records current public asymmetric key.
virtual void UpdateAsymmetricKey( virtual void UpdateAsymmetricKey(
base::StringPiece new_key, base::StringPiece new_public_key,
int64_t new_public_key_id,
base::OnceCallback<void(Status)> response_cb); base::OnceCallback<void(Status)> response_cb);
protected: protected:
......
...@@ -114,7 +114,7 @@ class EncryptionModuleTest : public ::testing::Test { ...@@ -114,7 +114,7 @@ class EncryptionModuleTest : public ::testing::Test {
return decrypted_string; return decrypted_string;
} }
StatusOr<std::string> DecryptMatchingSecret(uint32_t public_key_id, StatusOr<std::string> DecryptMatchingSecret(int64_t public_key_id,
base::StringPiece encrypted_key) { base::StringPiece encrypted_key) {
// Retrieve private key that matches public key hash. // Retrieve private key that matches public key hash.
TestEvent<StatusOr<std::string>> retrieve_private_key; TestEvent<StatusOr<std::string>> retrieve_private_key;
...@@ -133,19 +133,19 @@ class EncryptionModuleTest : public ::testing::Test { ...@@ -133,19 +133,19 @@ class EncryptionModuleTest : public ::testing::Test {
uint8_t out_private_key[X25519_PRIVATE_KEY_LEN]; uint8_t out_private_key[X25519_PRIVATE_KEY_LEN];
X25519_keypair(out_public_value, out_private_key); X25519_keypair(out_public_value, out_private_key);
TestEvent<Status> record_keys; TestEvent<StatusOr<int64_t>> record_keys;
decryptor_->RecordKeyPair( decryptor_->RecordKeyPair(
std::string(reinterpret_cast<const char*>(out_private_key), std::string(reinterpret_cast<const char*>(out_private_key),
X25519_PRIVATE_KEY_LEN), X25519_PRIVATE_KEY_LEN),
std::string(reinterpret_cast<const char*>(out_public_value), std::string(reinterpret_cast<const char*>(out_public_value),
X25519_PUBLIC_VALUE_LEN), X25519_PUBLIC_VALUE_LEN),
record_keys.cb()); record_keys.cb());
RETURN_IF_ERROR(record_keys.result()); ASSIGN_OR_RETURN(int64_t new_public_key_id, record_keys.result());
TestEvent<Status> set_public_key; TestEvent<Status> set_public_key;
encryption_module_->UpdateAsymmetricKey( encryption_module_->UpdateAsymmetricKey(
std::string(reinterpret_cast<const char*>(out_public_value), std::string(reinterpret_cast<const char*>(out_public_value),
X25519_PUBLIC_VALUE_LEN), X25519_PUBLIC_VALUE_LEN),
set_public_key.cb()); new_public_key_id, set_public_key.cb());
RETURN_IF_ERROR(set_public_key.result()); RETURN_IF_ERROR(set_public_key.result());
return Status::StatusOK(); return Status::StatusOK();
} }
...@@ -275,10 +275,12 @@ TEST_F(EncryptionModuleTest, EncryptAndDecryptMultipleParallel) { ...@@ -275,10 +275,12 @@ TEST_F(EncryptionModuleTest, EncryptAndDecryptMultipleParallel) {
SingleEncryptionContext( SingleEncryptionContext(
base::StringPiece test_string, base::StringPiece test_string,
base::StringPiece public_key, base::StringPiece public_key,
int64_t public_key_id,
scoped_refptr<EncryptionModule> encryption_module, scoped_refptr<EncryptionModule> encryption_module,
base::OnceCallback<void(StatusOr<EncryptedRecord>)> response) base::OnceCallback<void(StatusOr<EncryptedRecord>)> response)
: test_string_(test_string), : test_string_(test_string),
public_key_(public_key), public_key_(public_key),
public_key_id_(public_key_id),
encryption_module_(encryption_module), encryption_module_(encryption_module),
response_(std::move(response)) {} response_(std::move(response)) {}
...@@ -303,7 +305,7 @@ TEST_F(EncryptionModuleTest, EncryptAndDecryptMultipleParallel) { ...@@ -303,7 +305,7 @@ TEST_F(EncryptionModuleTest, EncryptAndDecryptMultipleParallel) {
} }
void SetPublicKey() { void SetPublicKey() {
encryption_module_->UpdateAsymmetricKey( encryption_module_->UpdateAsymmetricKey(
public_key_, public_key_, public_key_id_,
base::BindOnce( base::BindOnce(
[](SingleEncryptionContext* self, Status status) { [](SingleEncryptionContext* self, Status status) {
if (!status.ok()) { if (!status.ok()) {
...@@ -334,6 +336,7 @@ TEST_F(EncryptionModuleTest, EncryptAndDecryptMultipleParallel) { ...@@ -334,6 +336,7 @@ TEST_F(EncryptionModuleTest, EncryptAndDecryptMultipleParallel) {
private: private:
const std::string test_string_; const std::string test_string_;
const std::string public_key_; const std::string public_key_;
const int64_t public_key_id_;
const scoped_refptr<EncryptionModule> encryption_module_; const scoped_refptr<EncryptionModule> encryption_module_;
base::OnceCallback<void(StatusOr<EncryptedRecord>)> response_; base::OnceCallback<void(StatusOr<EncryptedRecord>)> response_;
}; };
...@@ -465,6 +468,7 @@ TEST_F(EncryptionModuleTest, EncryptAndDecryptMultipleParallel) { ...@@ -465,6 +468,7 @@ TEST_F(EncryptionModuleTest, EncryptAndDecryptMultipleParallel) {
// Public and private key pairs in this test are reversed strings. // Public and private key pairs in this test are reversed strings.
std::vector<std::string> private_key_strings; std::vector<std::string> private_key_strings;
std::vector<std::string> public_value_strings; std::vector<std::string> public_value_strings;
std::vector<int64_t> public_value_ids;
for (size_t i = 0; i < 3; ++i) { for (size_t i = 0; i < 3; ++i) {
// Generate new pair of private key and public value. // Generate new pair of private key and public value.
uint8_t out_public_value[X25519_PUBLIC_VALUE_LEN]; uint8_t out_public_value[X25519_PUBLIC_VALUE_LEN];
...@@ -477,27 +481,16 @@ TEST_F(EncryptionModuleTest, EncryptAndDecryptMultipleParallel) { ...@@ -477,27 +481,16 @@ TEST_F(EncryptionModuleTest, EncryptAndDecryptMultipleParallel) {
X25519_PUBLIC_VALUE_LEN); X25519_PUBLIC_VALUE_LEN);
} }
// Encrypt all records in parallel.
std::vector<TestEvent<StatusOr<EncryptedRecord>>> results(
kTestStrings.size());
for (size_t i = 0; i < kTestStrings.size(); ++i) {
// Choose random key pair.
size_t i_key_pair = base::RandInt(0, public_value_strings.size() - 1);
(new SingleEncryptionContext(kTestStrings[i],
public_value_strings[i_key_pair],
encryption_module_, results[i].cb()))
->Start();
}
// Register all key pairs for decryption. // Register all key pairs for decryption.
std::vector<TestEvent<Status>> record_results(public_value_strings.size()); std::vector<TestEvent<StatusOr<int64_t>>> record_results(
public_value_strings.size());
for (size_t i = 0; i < public_value_strings.size(); ++i) { for (size_t i = 0; i < public_value_strings.size(); ++i) {
base::ThreadPool::PostTask( base::ThreadPool::PostTask(
FROM_HERE, base::BindOnce( FROM_HERE, base::BindOnce(
[](base::StringPiece private_key_string, [](base::StringPiece private_key_string,
base::StringPiece public_key_string, base::StringPiece public_key_string,
scoped_refptr<Decryptor> decryptor, scoped_refptr<Decryptor> decryptor,
base::OnceCallback<void(Status)> done_cb) { base::OnceCallback<void(StatusOr<int64_t>)> done_cb) {
decryptor->RecordKeyPair(private_key_string, decryptor->RecordKeyPair(private_key_string,
public_key_string, public_key_string,
std::move(done_cb)); std::move(done_cb));
...@@ -507,7 +500,21 @@ TEST_F(EncryptionModuleTest, EncryptAndDecryptMultipleParallel) { ...@@ -507,7 +500,21 @@ TEST_F(EncryptionModuleTest, EncryptAndDecryptMultipleParallel) {
} }
// Verify registration success. // Verify registration success.
for (auto& record_result : record_results) { for (auto& record_result : record_results) {
ASSERT_OK(record_result.result()) << record_result.result(); const auto result = record_result.result();
ASSERT_OK(result.status()) << result.status();
public_value_ids.push_back(result.ValueOrDie());
}
// Encrypt all records in parallel.
std::vector<TestEvent<StatusOr<EncryptedRecord>>> results(
kTestStrings.size());
for (size_t i = 0; i < kTestStrings.size(); ++i) {
// Choose random key pair.
size_t i_key_pair = base::RandInt(0, public_value_strings.size() - 1);
(new SingleEncryptionContext(
kTestStrings[i], public_value_strings[i_key_pair],
public_value_ids[i_key_pair], encryption_module_, results[i].cb()))
->Start();
} }
// Decrypt all records in parallel. // Decrypt all records in parallel.
......
...@@ -134,7 +134,7 @@ class EncryptionTest : public ::testing::Test { ...@@ -134,7 +134,7 @@ class EncryptionTest : public ::testing::Test {
return decrypted_string; return decrypted_string;
} }
StatusOr<std::string> DecryptMatchingSecret(uint32_t public_key_id, StatusOr<std::string> DecryptMatchingSecret(int64_t public_key_id,
base::StringPiece encrypted_key) { base::StringPiece encrypted_key) {
// Retrieve private key that matches public key hash. // Retrieve private key that matches public key hash.
TestEvent<StatusOr<std::string>> retrieve_private_key; TestEvent<StatusOr<std::string>> retrieve_private_key;
...@@ -153,19 +153,19 @@ class EncryptionTest : public ::testing::Test { ...@@ -153,19 +153,19 @@ class EncryptionTest : public ::testing::Test {
uint8_t out_private_key[X25519_PRIVATE_KEY_LEN]; uint8_t out_private_key[X25519_PRIVATE_KEY_LEN];
X25519_keypair(out_public_value, out_private_key); X25519_keypair(out_public_value, out_private_key);
TestEvent<Status> record_keys; TestEvent<StatusOr<int64_t>> record_keys;
decryptor_->RecordKeyPair( decryptor_->RecordKeyPair(
std::string(reinterpret_cast<const char*>(out_private_key), std::string(reinterpret_cast<const char*>(out_private_key),
X25519_PRIVATE_KEY_LEN), X25519_PRIVATE_KEY_LEN),
std::string(reinterpret_cast<const char*>(out_public_value), std::string(reinterpret_cast<const char*>(out_public_value),
X25519_PUBLIC_VALUE_LEN), X25519_PUBLIC_VALUE_LEN),
record_keys.cb()); record_keys.cb());
RETURN_IF_ERROR(record_keys.result()); ASSIGN_OR_RETURN(int64_t new_public_key_id, record_keys.result());
TestEvent<Status> set_public_key; TestEvent<Status> set_public_key;
encryptor_->UpdateAsymmetricKey( encryptor_->UpdateAsymmetricKey(
std::string(reinterpret_cast<const char*>(out_public_value), std::string(reinterpret_cast<const char*>(out_public_value),
X25519_PUBLIC_VALUE_LEN), X25519_PUBLIC_VALUE_LEN),
set_public_key.cb()); new_public_key_id, set_public_key.cb());
RETURN_IF_ERROR(set_public_key.result()); RETURN_IF_ERROR(set_public_key.result());
return Status::StatusOK(); return Status::StatusOK();
} }
...@@ -275,10 +275,12 @@ TEST_F(EncryptionTest, EncryptAndDecryptMultipleParallel) { ...@@ -275,10 +275,12 @@ TEST_F(EncryptionTest, EncryptAndDecryptMultipleParallel) {
SingleEncryptionContext( SingleEncryptionContext(
base::StringPiece test_string, base::StringPiece test_string,
base::StringPiece public_key, base::StringPiece public_key,
int64_t public_key_id,
scoped_refptr<Encryptor> encryptor, scoped_refptr<Encryptor> encryptor,
base::OnceCallback<void(StatusOr<EncryptedRecord>)> response) base::OnceCallback<void(StatusOr<EncryptedRecord>)> response)
: test_string_(test_string), : test_string_(test_string),
public_key_(public_key), public_key_(public_key),
public_key_id_(public_key_id),
encryptor_(encryptor), encryptor_(encryptor),
response_(std::move(response)) {} response_(std::move(response)) {}
...@@ -303,7 +305,7 @@ TEST_F(EncryptionTest, EncryptAndDecryptMultipleParallel) { ...@@ -303,7 +305,7 @@ TEST_F(EncryptionTest, EncryptAndDecryptMultipleParallel) {
} }
void SetPublicKey() { void SetPublicKey() {
encryptor_->UpdateAsymmetricKey( encryptor_->UpdateAsymmetricKey(
public_key_, public_key_, public_key_id_,
base::BindOnce( base::BindOnce(
[](SingleEncryptionContext* self, Status status) { [](SingleEncryptionContext* self, Status status) {
if (!status.ok()) { if (!status.ok()) {
...@@ -363,6 +365,7 @@ TEST_F(EncryptionTest, EncryptAndDecryptMultipleParallel) { ...@@ -363,6 +365,7 @@ TEST_F(EncryptionTest, EncryptAndDecryptMultipleParallel) {
private: private:
const std::string test_string_; const std::string test_string_;
const std::string public_key_; const std::string public_key_;
const int64_t public_key_id_;
const scoped_refptr<Encryptor> encryptor_; const scoped_refptr<Encryptor> encryptor_;
base::OnceCallback<void(StatusOr<EncryptedRecord>)> response_; base::OnceCallback<void(StatusOr<EncryptedRecord>)> response_;
}; };
...@@ -494,6 +497,7 @@ TEST_F(EncryptionTest, EncryptAndDecryptMultipleParallel) { ...@@ -494,6 +497,7 @@ TEST_F(EncryptionTest, EncryptAndDecryptMultipleParallel) {
// Public and private key pairs in this test are reversed strings. // Public and private key pairs in this test are reversed strings.
std::vector<std::string> private_key_strings; std::vector<std::string> private_key_strings;
std::vector<std::string> public_value_strings; std::vector<std::string> public_value_strings;
std::vector<int64_t> public_value_ids;
for (size_t i = 0; i < 3; ++i) { for (size_t i = 0; i < 3; ++i) {
// Generate new pair of private key and public value. // Generate new pair of private key and public value.
uint8_t out_public_value[X25519_PUBLIC_VALUE_LEN]; uint8_t out_public_value[X25519_PUBLIC_VALUE_LEN];
...@@ -506,27 +510,16 @@ TEST_F(EncryptionTest, EncryptAndDecryptMultipleParallel) { ...@@ -506,27 +510,16 @@ TEST_F(EncryptionTest, EncryptAndDecryptMultipleParallel) {
X25519_PUBLIC_VALUE_LEN); X25519_PUBLIC_VALUE_LEN);
} }
// Encrypt all records in parallel.
std::vector<TestEvent<StatusOr<EncryptedRecord>>> results(
kTestStrings.size());
for (size_t i = 0; i < kTestStrings.size(); ++i) {
// Choose random key pair.
size_t i_key_pair = base::RandInt(0, public_value_strings.size() - 1);
(new SingleEncryptionContext(kTestStrings[i],
public_value_strings[i_key_pair], encryptor_,
results[i].cb()))
->Start();
}
// Register all key pairs for decryption. // Register all key pairs for decryption.
std::vector<TestEvent<Status>> record_results(public_value_strings.size()); std::vector<TestEvent<StatusOr<int64_t>>> record_results(
public_value_strings.size());
for (size_t i = 0; i < public_value_strings.size(); ++i) { for (size_t i = 0; i < public_value_strings.size(); ++i) {
base::ThreadPool::PostTask( base::ThreadPool::PostTask(
FROM_HERE, base::BindOnce( FROM_HERE, base::BindOnce(
[](base::StringPiece private_key_string, [](base::StringPiece private_key_string,
base::StringPiece public_key_string, base::StringPiece public_key_string,
scoped_refptr<Decryptor> decryptor, scoped_refptr<Decryptor> decryptor,
base::OnceCallback<void(Status)> done_cb) { base::OnceCallback<void(StatusOr<int64_t>)> done_cb) {
decryptor->RecordKeyPair(private_key_string, decryptor->RecordKeyPair(private_key_string,
public_key_string, public_key_string,
std::move(done_cb)); std::move(done_cb));
...@@ -536,7 +529,21 @@ TEST_F(EncryptionTest, EncryptAndDecryptMultipleParallel) { ...@@ -536,7 +529,21 @@ TEST_F(EncryptionTest, EncryptAndDecryptMultipleParallel) {
} }
// Verify registration success. // Verify registration success.
for (auto& record_result : record_results) { for (auto& record_result : record_results) {
ASSERT_OK(record_result.result()) << record_result.result(); const auto result = record_result.result();
ASSERT_OK(result.status()) << result.status();
public_value_ids.push_back(result.ValueOrDie());
}
// Encrypt all records in parallel.
std::vector<TestEvent<StatusOr<EncryptedRecord>>> results(
kTestStrings.size());
for (size_t i = 0; i < kTestStrings.size(); ++i) {
// Choose random key pair.
size_t i_key_pair = base::RandInt(0, public_value_strings.size() - 1);
(new SingleEncryptionContext(
kTestStrings[i], public_value_strings[i_key_pair],
public_value_ids[i_key_pair], encryptor_, results[i].cb()))
->Start();
} }
// Decrypt all records in parallel. // Decrypt all records in parallel.
......
...@@ -27,7 +27,8 @@ TestEncryptionModuleStrict::TestEncryptionModuleStrict() { ...@@ -27,7 +27,8 @@ TestEncryptionModuleStrict::TestEncryptionModuleStrict() {
} }
void TestEncryptionModuleStrict::UpdateAsymmetricKey( void TestEncryptionModuleStrict::UpdateAsymmetricKey(
base::StringPiece new_key, base::StringPiece new_public_key,
int64_t new_public_key_id,
base::OnceCallback<void(Status)> response_cb) { base::OnceCallback<void(Status)> response_cb) {
std::move(response_cb) std::move(response_cb)
.Run(Status(error::UNIMPLEMENTED, .Run(Status(error::UNIMPLEMENTED,
......
...@@ -28,7 +28,8 @@ class TestEncryptionModuleStrict : public EncryptionModule { ...@@ -28,7 +28,8 @@ class TestEncryptionModuleStrict : public EncryptionModule {
(const override)); (const override));
void UpdateAsymmetricKey( void UpdateAsymmetricKey(
base::StringPiece new_key, base::StringPiece new_public_key,
int64_t new_public_key_id,
base::OnceCallback<void(Status)> response_cb) override; base::OnceCallback<void(Status)> response_cb) override;
protected: protected:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment