Commit c18f5642 authored by Leonid Baraz's avatar Leonid Baraz Committed by Chromium LUCI CQ

Add keys delivery back to the client.

Bug: b:170054326
Change-Id: Id1948ddad64e6bec3ed960a83b910073b18bb811
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2582863
Commit-Queue: Leonid Baraz <lbaraz@chromium.org>
Reviewed-by: default avatarZach Trudo <zatrudo@google.com>
Cr-Commit-Position: refs/heads/master@{#835838}
parent 9ee6c6d0
...@@ -414,6 +414,8 @@ void ReportingClient::InitializingContext::CreateUploadClient() { ...@@ -414,6 +414,8 @@ void ReportingClient::InitializingContext::CreateUploadClient() {
std::move(client_config_->cloud_policy_client), std::move(client_config_->cloud_policy_client),
base::BindRepeating(&StorageModule::ReportSuccess, base::BindRepeating(&StorageModule::ReportSuccess,
client_config_->storage), client_config_->storage),
base::BindRepeating(&StorageModule::UpdateEncryptionKey,
client_config_->storage),
base::BindOnce(&InitializingContext::OnUploadClientCreated, base::BindOnce(&InitializingContext::OnUploadClientCreated,
base::Unretained(this))); base::Unretained(this)));
} }
......
...@@ -166,6 +166,14 @@ void Storage::Create( ...@@ -166,6 +166,14 @@ void Storage::Create(
void OnStart() override { void OnStart() override {
CheckOnValidSequence(); CheckOnValidSequence();
// TODO(b/170054326): Locate the latest signed_encryption_key file with
// matching key signature after deserialization. Call
// storage_->encryption_module_->UpdateAsymmetricKey(...) with the key and
// id.
// DCHECK(storage_->encryption_module_->has_encryption_key());
// Construct all queues.
for (const auto& queue_options : queues_options_) { for (const auto& queue_options : queues_options_) {
StorageQueue::Create( StorageQueue::Create(
/*options=*/queue_options.second, /*options=*/queue_options.second,
...@@ -264,6 +272,24 @@ Status Storage::Flush(Priority priority) { ...@@ -264,6 +272,24 @@ Status Storage::Flush(Priority priority) {
return Status::StatusOK(); return Status::StatusOK();
} }
void Storage::UpdateEncryptionKey(SignedEncryptionInfo signed_encryption_key) {
// TODO(b/170054326): Verify received key signature. Bail out if failed.
// TODO(b/170054326): Serialize whole signed_encryption_key to a new file,
// discard the old one.
// Assign the received key to encryption module.
encryption_module_->UpdateAsymmetricKey(
signed_encryption_key.public_asymmetric_key(),
signed_encryption_key.public_key_id(), base::BindOnce([](Status status) {
if (!status.ok()) {
LOG(WARNING) << "Encryption key update failed, status=" << status;
return;
}
// Encryption key updated successfully.
}));
}
bool Storage::has_encryption_key() const { bool Storage::has_encryption_key() const {
return !encryption_module_->has_encryption_key(); return !encryption_module_->has_encryption_key();
} }
......
...@@ -65,6 +65,10 @@ class Storage : public base::RefCountedThreadSafe<Storage> { ...@@ -65,6 +65,10 @@ class Storage : public base::RefCountedThreadSafe<Storage> {
// Returns error if cannot start upload. // Returns error if cannot start upload.
Status Flush(Priority priority); Status Flush(Priority priority);
// If the server attached signed encryption key to the response, it needs to
// be paased here.
void UpdateEncryptionKey(SignedEncryptionInfo signed_encryption_key);
// Returns `false` if encryption key has not been found in the Storage during // Returns `false` if encryption key has not been found in the Storage during
// initialization and not received from the server yet, and `true` otherwise. // initialization and not received from the server yet, and `true` otherwise.
// The result is lazy: the method may return `false` for some time even after // The result is lazy: the method may return `false` for some time even after
......
...@@ -41,6 +41,11 @@ void StorageModule::ReportSuccess( ...@@ -41,6 +41,11 @@ void StorageModule::ReportSuccess(
})); }));
} }
void StorageModule::UpdateEncryptionKey(
SignedEncryptionInfo signed_encryption_key) {
storage_->UpdateEncryptionKey(std::move(signed_encryption_key));
}
// static // static
void StorageModule::Create( void StorageModule::Create(
const StorageOptions& options, const StorageOptions& options,
......
...@@ -44,6 +44,10 @@ class StorageModule : public base::RefCountedThreadSafe<StorageModule> { ...@@ -44,6 +44,10 @@ class StorageModule : public base::RefCountedThreadSafe<StorageModule> {
// can be passed back to the StorageModule here for record deletion. // can be passed back to the StorageModule here for record deletion.
virtual void ReportSuccess(SequencingInformation sequencing_information); virtual void ReportSuccess(SequencingInformation sequencing_information);
// If the server attached signed encryption key to the response, it needs to
// be paased here.
virtual void UpdateEncryptionKey(SignedEncryptionInfo signed_encryption_key);
// Returns `false` if encryption key has not been found in the Storage during // Returns `false` if encryption key has not been found in the Storage during
// initialization and not received from the server yet, and `true` otherwise. // initialization and not received from the server yet, and `true` otherwise.
// The result is lazy: the method may return `false` for some time even after // The result is lazy: the method may return `false` for some time even after
......
...@@ -73,11 +73,13 @@ DmServerUploader::DmServerUploader( ...@@ -73,11 +73,13 @@ DmServerUploader::DmServerUploader(
std::unique_ptr<std::vector<EncryptedRecord>> records, std::unique_ptr<std::vector<EncryptedRecord>> records,
RecordHandler* handler, RecordHandler* handler,
CompletionCallback completion_cb, CompletionCallback completion_cb,
EncryptionKeyAttachedCallback encryption_key_attached_cb,
scoped_refptr<base::SequencedTaskRunner> sequenced_task_runner) scoped_refptr<base::SequencedTaskRunner> sequenced_task_runner)
: TaskRunnerContext<CompletionResponse>(std::move(completion_cb), : TaskRunnerContext<CompletionResponse>(std::move(completion_cb),
sequenced_task_runner), sequenced_task_runner),
need_encryption_key_(need_encryption_key), need_encryption_key_(need_encryption_key),
encrypted_records_(std::move(records)), encrypted_records_(std::move(records)),
encryption_key_attached_cb_(encryption_key_attached_cb),
handler_(handler) { handler_(handler) {
DETACH_FROM_SEQUENCE(sequence_checker_); DETACH_FROM_SEQUENCE(sequence_checker_);
} }
...@@ -138,7 +140,8 @@ void DmServerUploader::HandleRecords() { ...@@ -138,7 +140,8 @@ void DmServerUploader::HandleRecords() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
handler_->HandleRecords( handler_->HandleRecords(
need_encryption_key_, std::move(encrypted_records_), need_encryption_key_, std::move(encrypted_records_),
base::BindOnce(&DmServerUploader::Complete, base::Unretained(this))); base::BindOnce(&DmServerUploader::Complete, base::Unretained(this)),
encryption_key_attached_cb_);
} }
void DmServerUploader::Complete(CompletionResponse completion_response) { void DmServerUploader::Complete(CompletionResponse completion_response) {
...@@ -168,6 +171,7 @@ Status DmServerUploader::IsRecordValid( ...@@ -168,6 +171,7 @@ Status DmServerUploader::IsRecordValid(
void DmServerUploadService::Create( void DmServerUploadService::Create(
policy::CloudPolicyClient* client, policy::CloudPolicyClient* client,
ReportSuccessfulUploadCallback report_upload_success_cb, ReportSuccessfulUploadCallback report_upload_success_cb,
EncryptionKeyAttachedCallback encryption_key_attached_cb,
base::OnceCallback<void(StatusOr<std::unique_ptr<DmServerUploadService>>)> base::OnceCallback<void(StatusOr<std::unique_ptr<DmServerUploadService>>)>
created_cb) { created_cb) {
if (client == nullptr) { if (client == nullptr) {
...@@ -176,16 +180,18 @@ void DmServerUploadService::Create( ...@@ -176,16 +180,18 @@ void DmServerUploadService::Create(
return; return;
} }
auto uploader = base::WrapUnique( auto uploader = base::WrapUnique(new DmServerUploadService(
new DmServerUploadService(std::move(client), report_upload_success_cb)); std::move(client), report_upload_success_cb, encryption_key_attached_cb));
InitRecordHandler(std::move(uploader), std::move(created_cb)); InitRecordHandler(std::move(uploader), std::move(created_cb));
} }
DmServerUploadService::DmServerUploadService( DmServerUploadService::DmServerUploadService(
policy::CloudPolicyClient* client, policy::CloudPolicyClient* client,
ReportSuccessfulUploadCallback upload_cb) ReportSuccessfulUploadCallback upload_cb,
EncryptionKeyAttachedCallback encryption_key_attached_cb)
: client_(std::move(client)), : client_(std::move(client)),
upload_cb_(upload_cb), upload_cb_(upload_cb),
encryption_key_attached_cb_(encryption_key_attached_cb),
sequenced_task_runner_(base::ThreadPool::CreateSequencedTaskRunner({})) {} sequenced_task_runner_(base::ThreadPool::CreateSequencedTaskRunner({})) {}
DmServerUploadService::~DmServerUploadService() = default; DmServerUploadService::~DmServerUploadService() = default;
...@@ -197,7 +203,7 @@ Status DmServerUploadService::EnqueueUpload( ...@@ -197,7 +203,7 @@ Status DmServerUploadService::EnqueueUpload(
need_encryption_key, std::move(records), handler_.get(), need_encryption_key, std::move(records), handler_.get(),
base::BindOnce(&DmServerUploadService::UploadCompletion, base::BindOnce(&DmServerUploadService::UploadCompletion,
base::Unretained(this)), base::Unretained(this)),
sequenced_task_runner_); encryption_key_attached_cb_, sequenced_task_runner_);
return Status::StatusOK(); return Status::StatusOK();
} }
......
...@@ -42,6 +42,11 @@ class DmServerUploadService { ...@@ -42,6 +42,11 @@ class DmServerUploadService {
using ReportSuccessfulUploadCallback = using ReportSuccessfulUploadCallback =
base::RepeatingCallback<void(SequencingInformation)>; base::RepeatingCallback<void(SequencingInformation)>;
// ReceivedEncryptionKeyCallback is called if server attached encryption key
// to the response.
using EncryptionKeyAttachedCallback =
base::RepeatingCallback<void(SignedEncryptionInfo)>;
using CompletionResponse = StatusOr<SequencingInformation>; using CompletionResponse = StatusOr<SequencingInformation>;
using CompletionCallback = base::OnceCallback<void(CompletionResponse)>; using CompletionCallback = base::OnceCallback<void(CompletionResponse)>;
...@@ -58,7 +63,8 @@ class DmServerUploadService { ...@@ -58,7 +63,8 @@ class DmServerUploadService {
// the encryption key from the server (either because it does not have it // the encryption key from the server (either because it does not have it
// or because the one it has is old and may be outdated). In that case // or because the one it has is old and may be outdated). In that case
// it is ok for |records| to be empty (otherwise at least one record must // it is ok for |records| to be empty (otherwise at least one record must
// be present). // be present). If response has the key info attached, it is decoded and
// handed over to |encryption_key_attached_cb|.
// Once the server has responded |upload_complete| is called with either the // Once the server has responded |upload_complete| is called with either the
// highest accepted SequencingInformation, or an error detailing the failure // highest accepted SequencingInformation, or an error detailing the failure
// cause. // cause.
...@@ -66,7 +72,9 @@ class DmServerUploadService { ...@@ -66,7 +72,9 @@ class DmServerUploadService {
virtual void HandleRecords( virtual void HandleRecords(
bool need_encryption_key, bool need_encryption_key,
std::unique_ptr<std::vector<EncryptedRecord>> records, std::unique_ptr<std::vector<EncryptedRecord>> records,
DmServerUploadService::CompletionCallback upload_complete) = 0; DmServerUploadService::CompletionCallback upload_complete,
DmServerUploadService::EncryptionKeyAttachedCallback
encryption_key_attached_cb) = 0;
protected: protected:
explicit RecordHandler(policy::CloudPolicyClient* client); explicit RecordHandler(policy::CloudPolicyClient* client);
...@@ -86,6 +94,7 @@ class DmServerUploadService { ...@@ -86,6 +94,7 @@ class DmServerUploadService {
std::unique_ptr<std::vector<EncryptedRecord>> records, std::unique_ptr<std::vector<EncryptedRecord>> records,
RecordHandler* handler, RecordHandler* handler,
CompletionCallback completion_cb, CompletionCallback completion_cb,
EncryptionKeyAttachedCallback encryption_key_attached_cb,
scoped_refptr<base::SequencedTaskRunner> sequenced_task_runner); scoped_refptr<base::SequencedTaskRunner> sequenced_task_runner);
private: private:
...@@ -129,6 +138,7 @@ class DmServerUploadService { ...@@ -129,6 +138,7 @@ class DmServerUploadService {
const bool need_encryption_key_; const bool need_encryption_key_;
std::unique_ptr<std::vector<EncryptedRecord>> encrypted_records_; std::unique_ptr<std::vector<EncryptedRecord>> encrypted_records_;
EncryptionKeyAttachedCallback encryption_key_attached_cb_;
RecordHandler* handler_; RecordHandler* handler_;
base::Optional<SequencingInformation> highest_successful_sequence_; base::Optional<SequencingInformation> highest_successful_sequence_;
...@@ -144,9 +154,12 @@ class DmServerUploadService { ...@@ -144,9 +154,12 @@ class DmServerUploadService {
// //
// |report_upload_success_cb| should report back to the holder of the created // |report_upload_success_cb| should report back to the holder of the created
// object whenever a record set is successfully uploaded. // object whenever a record set is successfully uploaded.
// |encryption_key_attached_cb| if called would update the encryption key with
// the one received from the server.
static void Create( static void Create(
policy::CloudPolicyClient* client, policy::CloudPolicyClient* client,
ReportSuccessfulUploadCallback report_upload_success_cb, ReportSuccessfulUploadCallback report_upload_success_cb,
EncryptionKeyAttachedCallback encryption_key_attached_cb,
base::OnceCallback<void(StatusOr<std::unique_ptr<DmServerUploadService>>)> base::OnceCallback<void(StatusOr<std::unique_ptr<DmServerUploadService>>)>
created_cb); created_cb);
~DmServerUploadService(); ~DmServerUploadService();
...@@ -155,8 +168,10 @@ class DmServerUploadService { ...@@ -155,8 +168,10 @@ class DmServerUploadService {
std::unique_ptr<std::vector<EncryptedRecord>> record); std::unique_ptr<std::vector<EncryptedRecord>> record);
private: private:
DmServerUploadService(policy::CloudPolicyClient* client, DmServerUploadService(
ReportSuccessfulUploadCallback completion_cb); policy::CloudPolicyClient* client,
ReportSuccessfulUploadCallback completion_cb,
EncryptionKeyAttachedCallback encryption_key_attached_cb);
static void InitRecordHandler( static void InitRecordHandler(
std::unique_ptr<DmServerUploadService> uploader, std::unique_ptr<DmServerUploadService> uploader,
...@@ -169,6 +184,7 @@ class DmServerUploadService { ...@@ -169,6 +184,7 @@ class DmServerUploadService {
policy::CloudPolicyClient* client_; policy::CloudPolicyClient* client_;
ReportSuccessfulUploadCallback upload_cb_; ReportSuccessfulUploadCallback upload_cb_;
EncryptionKeyAttachedCallback encryption_key_attached_cb_;
std::unique_ptr<RecordHandler> handler_; std::unique_ptr<RecordHandler> handler_;
scoped_refptr<base::SequencedTaskRunner> sequenced_task_runner_; scoped_refptr<base::SequencedTaskRunner> sequenced_task_runner_;
......
...@@ -25,7 +25,9 @@ namespace { ...@@ -25,7 +25,9 @@ namespace {
using ::testing::_; using ::testing::_;
using ::testing::Invoke; using ::testing::Invoke;
using ::testing::MockFunction;
using ::testing::Return; using ::testing::Return;
using ::testing::StrictMock;
using ::testing::WithArgs; using ::testing::WithArgs;
// Usage (in tests only): // Usage (in tests only):
...@@ -68,7 +70,8 @@ class TestEvent { ...@@ -68,7 +70,8 @@ class TestEvent {
TEST(DmServerUploadServiceTest, DeniesNullptrProfile) { TEST(DmServerUploadServiceTest, DeniesNullptrProfile) {
content::BrowserTaskEnvironment task_envrionment; content::BrowserTaskEnvironment task_envrionment;
TestEvent<StatusOr<std::unique_ptr<DmServerUploadService>>> e; TestEvent<StatusOr<std::unique_ptr<DmServerUploadService>>> e;
DmServerUploadService::Create(/*profile=*/nullptr, base::DoNothing(), e.cb()); DmServerUploadService::Create(/*client=*/nullptr, base::DoNothing(),
base::DoNothing(), e.cb());
StatusOr<std::unique_ptr<DmServerUploadService>> result = e.result(); StatusOr<std::unique_ptr<DmServerUploadService>> result = e.result();
EXPECT_FALSE(result.ok()); EXPECT_FALSE(result.ok());
EXPECT_EQ(result.status().error_code(), error::INVALID_ARGUMENT); EXPECT_EQ(result.status().error_code(), error::INVALID_ARGUMENT);
...@@ -123,21 +126,24 @@ class TestRecordHandler : public DmServerUploadService::RecordHandler { ...@@ -123,21 +126,24 @@ class TestRecordHandler : public DmServerUploadService::RecordHandler {
TestRecordHandler() : RecordHandler(/*client=*/nullptr) {} TestRecordHandler() : RecordHandler(/*client=*/nullptr) {}
~TestRecordHandler() override = default; ~TestRecordHandler() override = default;
void HandleRecords( void HandleRecords(bool need_encryption_key,
bool need_encryption_key,
std::unique_ptr<std::vector<EncryptedRecord>> records, std::unique_ptr<std::vector<EncryptedRecord>> records,
DmServerUploadService::CompletionCallback upload_complete) override { DmServerUploadService::CompletionCallback upload_complete,
HandleRecords_(need_encryption_key, records, upload_complete); DmServerUploadService::EncryptionKeyAttachedCallback
encryption_key_attached_cb) override {
HandleRecords_(need_encryption_key, records, upload_complete,
encryption_key_attached_cb);
} }
MOCK_METHOD(void, MOCK_METHOD(void,
HandleRecords_, HandleRecords_,
(bool, (bool,
std::unique_ptr<std::vector<EncryptedRecord>>&, std::unique_ptr<std::vector<EncryptedRecord>>&,
DmServerUploadService::CompletionCallback&)); DmServerUploadService::CompletionCallback&,
DmServerUploadService::EncryptionKeyAttachedCallback&));
}; };
class DmServerUploaderTest : public testing::Test { class DmServerUploaderTest : public ::testing::TestWithParam<bool> {
public: public:
DmServerUploaderTest() DmServerUploaderTest()
: sequenced_task_runner_(base::ThreadPool::CreateSequencedTaskRunner({})), : sequenced_task_runner_(base::ThreadPool::CreateSequencedTaskRunner({})),
...@@ -145,6 +151,8 @@ class DmServerUploaderTest : public testing::Test { ...@@ -145,6 +151,8 @@ class DmServerUploaderTest : public testing::Test {
records_(std::make_unique<std::vector<EncryptedRecord>>()) {} records_(std::make_unique<std::vector<EncryptedRecord>>()) {}
protected: protected:
bool need_encryption_key() const { return GetParam(); }
content::BrowserTaskEnvironment task_envrionment_{ content::BrowserTaskEnvironment task_envrionment_{
base::test::TaskEnvironment::TimeSource::MOCK_TIME}; base::test::TaskEnvironment::TimeSource::MOCK_TIME};
...@@ -156,29 +164,44 @@ class DmServerUploaderTest : public testing::Test { ...@@ -156,29 +164,44 @@ class DmServerUploaderTest : public testing::Test {
const base::TimeDelta kMaxDelay_ = base::TimeDelta::FromSeconds(1); const base::TimeDelta kMaxDelay_ = base::TimeDelta::FromSeconds(1);
}; };
TEST_F(DmServerUploaderTest, ProcessesRecord) { using TestEncryptionKeyAttached = MockFunction<void(SignedEncryptionInfo)>;
TEST_P(DmServerUploaderTest, ProcessesRecord) {
// Add an empty record. // Add an empty record.
records_->emplace_back(); records_->emplace_back();
EXPECT_CALL(*handler_, HandleRecords_(_, _, _)) EXPECT_CALL(*handler_, HandleRecords_(_, _, _, _))
.WillOnce(WithArgs<2>( .WillOnce(WithArgs<0, 2, 3>(
Invoke([](DmServerUploadService::CompletionCallback& callback) { Invoke([](bool need_encryption_key,
DmServerUploadService::CompletionCallback& callback,
DmServerUploadService::EncryptionKeyAttachedCallback&
encryption_key_attached_cb) {
if (need_encryption_key) {
encryption_key_attached_cb.Run(SignedEncryptionInfo());
}
std::move(callback).Run(SequencingInformation()); std::move(callback).Run(SequencingInformation());
}))); })));
StrictMock<TestEncryptionKeyAttached> encryption_key_attached;
EXPECT_CALL(encryption_key_attached, Call(_))
.Times(need_encryption_key() ? 1 : 0);
auto encryption_key_attached_cb =
base::BindRepeating(&TestEncryptionKeyAttached::Call,
base::Unretained(&encryption_key_attached));
TestCallbackWaiter callback_waiter; TestCallbackWaiter callback_waiter;
DmServerUploadService::CompletionCallback cb = DmServerUploadService::CompletionCallback cb =
base::BindOnce(&TestCallbackWaiter::CompleteExpectSuccess, base::BindOnce(&TestCallbackWaiter::CompleteExpectSuccess,
base::Unretained(&callback_waiter)); base::Unretained(&callback_waiter));
Start<DmServerUploadService::DmServerUploader>( Start<DmServerUploadService::DmServerUploader>(
/*need_encryption_key=*/false, std::move(records_), handler_.get(), need_encryption_key(), std::move(records_), handler_.get(), std::move(cb),
std::move(cb), sequenced_task_runner_); encryption_key_attached_cb, sequenced_task_runner_);
callback_waiter.Wait(); callback_waiter.Wait();
} }
TEST_F(DmServerUploaderTest, ProcessesRecords) { TEST_P(DmServerUploaderTest, ProcessesRecords) {
uint64_t kNumberOfRecords = 10; uint64_t kNumberOfRecords = 10;
uint64_t kGenerationId = 1234; uint64_t kGenerationId = 1234;
...@@ -194,82 +217,117 @@ TEST_F(DmServerUploaderTest, ProcessesRecords) { ...@@ -194,82 +217,117 @@ TEST_F(DmServerUploaderTest, ProcessesRecords) {
records_->push_back(std::move(encrypted_record)); records_->push_back(std::move(encrypted_record));
} }
EXPECT_CALL(*handler_, HandleRecords_(_, _, _)) EXPECT_CALL(*handler_, HandleRecords_(_, _, _, _))
.WillOnce(WithArgs<2>( .WillOnce(WithArgs<0, 2, 3>(
Invoke([](DmServerUploadService::CompletionCallback& callback) { Invoke([](bool need_encryption_key,
DmServerUploadService::CompletionCallback& callback,
DmServerUploadService::EncryptionKeyAttachedCallback&
encryption_key_attached_cb) {
if (need_encryption_key) {
encryption_key_attached_cb.Run(SignedEncryptionInfo());
}
std::move(callback).Run(SequencingInformation()); std::move(callback).Run(SequencingInformation());
}))); })));
StrictMock<TestEncryptionKeyAttached> encryption_key_attached;
EXPECT_CALL(encryption_key_attached, Call(_))
.Times(need_encryption_key() ? 1 : 0);
auto encryption_key_attached_cb =
base::BindRepeating(&TestEncryptionKeyAttached::Call,
base::Unretained(&encryption_key_attached));
TestCallbackWaiter callback_waiter; TestCallbackWaiter callback_waiter;
DmServerUploadService::CompletionCallback cb = DmServerUploadService::CompletionCallback cb =
base::BindOnce(&TestCallbackWaiter::CompleteExpectSuccess, base::BindOnce(&TestCallbackWaiter::CompleteExpectSuccess,
base::Unretained(&callback_waiter)); base::Unretained(&callback_waiter));
Start<DmServerUploadService::DmServerUploader>( Start<DmServerUploadService::DmServerUploader>(
/*need_encryption_key=*/false, std::move(records_), handler_.get(), need_encryption_key(), std::move(records_), handler_.get(), std::move(cb),
std::move(cb), sequenced_task_runner_); encryption_key_attached_cb, sequenced_task_runner_);
callback_waiter.Wait(); callback_waiter.Wait();
} }
TEST_F(DmServerUploaderTest, ReportsFailureToProcess) { TEST_P(DmServerUploaderTest, ReportsFailureToProcess) {
// Add an empty record. // Add an empty record.
records_->emplace_back(); records_->emplace_back();
EXPECT_CALL(*handler_, HandleRecords_(_, _, _)) EXPECT_CALL(*handler_, HandleRecords_(_, _, _, _))
.WillOnce(WithArgs<2>( .WillOnce(WithArgs<2>(
Invoke([](DmServerUploadService::CompletionCallback& callback) { Invoke([](DmServerUploadService::CompletionCallback& callback) {
std::move(callback).Run( std::move(callback).Run(
Status(error::FAILED_PRECONDITION, "Fail for test")); Status(error::FAILED_PRECONDITION, "Fail for test"));
}))); })));
StrictMock<TestEncryptionKeyAttached> encryption_key_attached;
EXPECT_CALL(encryption_key_attached, Call(_)).Times(0);
auto encryption_key_attached_cb =
base::BindRepeating(&TestEncryptionKeyAttached::Call,
base::Unretained(&encryption_key_attached));
TestCallbackWaiter callback_waiter; TestCallbackWaiter callback_waiter;
DmServerUploadService::CompletionCallback cb = DmServerUploadService::CompletionCallback cb =
base::BindOnce(&TestCallbackWaiter::CompleteExpectFailedPrecondition, base::BindOnce(&TestCallbackWaiter::CompleteExpectFailedPrecondition,
base::Unretained(&callback_waiter)); base::Unretained(&callback_waiter));
Start<DmServerUploadService::DmServerUploader>( Start<DmServerUploadService::DmServerUploader>(
/*need_encryption_key=*/false, std::move(records_), handler_.get(), need_encryption_key(), std::move(records_), handler_.get(), std::move(cb),
std::move(cb), sequenced_task_runner_); encryption_key_attached_cb, sequenced_task_runner_);
callback_waiter.Wait(); callback_waiter.Wait();
} }
TEST_F(DmServerUploaderTest, ReportsFailureToUpload) { TEST_P(DmServerUploaderTest, ReportsFailureToUpload) {
// Add an empty record. // Add an empty record.
records_->emplace_back(); records_->emplace_back();
EXPECT_CALL(*handler_, HandleRecords_(_, _, _)) EXPECT_CALL(*handler_, HandleRecords_(_, _, _, _))
.WillOnce(WithArgs<2>( .WillOnce(WithArgs<2>(
Invoke([](DmServerUploadService::CompletionCallback& callback) { Invoke([](DmServerUploadService::CompletionCallback& callback) {
std::move(callback).Run( std::move(callback).Run(
Status(error::DEADLINE_EXCEEDED, "Fail for test")); Status(error::DEADLINE_EXCEEDED, "Fail for test"));
}))); })));
StrictMock<TestEncryptionKeyAttached> encryption_key_attached;
EXPECT_CALL(encryption_key_attached, Call(_)).Times(0);
auto encryption_key_attached_cb =
base::BindRepeating(&TestEncryptionKeyAttached::Call,
base::Unretained(&encryption_key_attached));
TestCallbackWaiter callback_waiter; TestCallbackWaiter callback_waiter;
DmServerUploadService::CompletionCallback cb = DmServerUploadService::CompletionCallback cb =
base::BindOnce(&TestCallbackWaiter::CompleteExpectDeadlineExceeded, base::BindOnce(&TestCallbackWaiter::CompleteExpectDeadlineExceeded,
base::Unretained(&callback_waiter)); base::Unretained(&callback_waiter));
Start<DmServerUploadService::DmServerUploader>( Start<DmServerUploadService::DmServerUploader>(
/*need_encryption_key=*/false, std::move(records_), handler_.get(), need_encryption_key(), std::move(records_), handler_.get(), std::move(cb),
std::move(cb), sequenced_task_runner_); encryption_key_attached_cb, sequenced_task_runner_);
callback_waiter.Wait(); callback_waiter.Wait();
} }
TEST_F(DmServerUploaderTest, FailWithZeroRecords) { TEST_P(DmServerUploaderTest, FailWithZeroRecords) {
StrictMock<TestEncryptionKeyAttached> encryption_key_attached;
EXPECT_CALL(encryption_key_attached, Call(_)).Times(0);
auto encryption_key_attached_cb =
base::BindRepeating(&TestEncryptionKeyAttached::Call,
base::Unretained(&encryption_key_attached));
TestCallbackWaiter callback_waiter; TestCallbackWaiter callback_waiter;
DmServerUploadService::CompletionCallback cb = DmServerUploadService::CompletionCallback cb =
base::BindOnce(&TestCallbackWaiter::CompleteExpectInvalidArgument, base::BindOnce(&TestCallbackWaiter::CompleteExpectInvalidArgument,
base::Unretained(&callback_waiter)); base::Unretained(&callback_waiter));
Start<DmServerUploadService::DmServerUploader>( Start<DmServerUploadService::DmServerUploader>(
/*need_encryption_key=*/false, std::move(records_), handler_.get(), need_encryption_key(), std::move(records_), handler_.get(), std::move(cb),
std::move(cb), sequenced_task_runner_); base::DoNothing(), sequenced_task_runner_);
callback_waiter.Wait(); callback_waiter.Wait();
} }
INSTANTIATE_TEST_SUITE_P(NeedOrNoNeedKey,
DmServerUploaderTest,
testing::Bool());
} // namespace } // namespace
} // namespace reporting } // namespace reporting
...@@ -43,6 +43,8 @@ class RecordHandlerImpl::ReportUploader ...@@ -43,6 +43,8 @@ class RecordHandlerImpl::ReportUploader
std::unique_ptr<std::vector<EncryptedRecord>> records, std::unique_ptr<std::vector<EncryptedRecord>> records,
policy::CloudPolicyClient* client, policy::CloudPolicyClient* client,
DmServerUploadService::CompletionCallback upload_complete_cb, DmServerUploadService::CompletionCallback upload_complete_cb,
DmServerUploadService::EncryptionKeyAttachedCallback
encryption_key_attached_cb,
scoped_refptr<base::SequencedTaskRunner> sequenced_task_runner); scoped_refptr<base::SequencedTaskRunner> sequenced_task_runner);
private: private:
...@@ -68,6 +70,10 @@ class RecordHandlerImpl::ReportUploader ...@@ -68,6 +70,10 @@ class RecordHandlerImpl::ReportUploader
std::unique_ptr<std::vector<EncryptedRecord>> records_; std::unique_ptr<std::vector<EncryptedRecord>> records_;
policy::CloudPolicyClient* client_; policy::CloudPolicyClient* client_;
// Encryption key delivery callback.
DmServerUploadService::EncryptionKeyAttachedCallback
encryption_key_attached_cb_;
// Last successful response to be processed. // Last successful response to be processed.
// Note: I could not find a way to pass it as a parameter, // Note: I could not find a way to pass it as a parameter,
// so it is a class member variable. |last_response_| must be processed before // so it is a class member variable. |last_response_| must be processed before
...@@ -83,13 +89,16 @@ RecordHandlerImpl::ReportUploader::ReportUploader( ...@@ -83,13 +89,16 @@ RecordHandlerImpl::ReportUploader::ReportUploader(
std::unique_ptr<std::vector<EncryptedRecord>> records, std::unique_ptr<std::vector<EncryptedRecord>> records,
policy::CloudPolicyClient* client, policy::CloudPolicyClient* client,
DmServerUploadService::CompletionCallback client_cb, DmServerUploadService::CompletionCallback client_cb,
DmServerUploadService::EncryptionKeyAttachedCallback
encryption_key_attached_cb,
scoped_refptr<base::SequencedTaskRunner> sequenced_task_runner) scoped_refptr<base::SequencedTaskRunner> sequenced_task_runner)
: TaskRunnerContext<DmServerUploadService::CompletionResponse>( : TaskRunnerContext<DmServerUploadService::CompletionResponse>(
std::move(client_cb), std::move(client_cb),
sequenced_task_runner), sequenced_task_runner),
need_encryption_key_(need_encryption_key), need_encryption_key_(need_encryption_key),
records_(std::move(records)), records_(std::move(records)),
client_(client) {} client_(client),
encryption_key_attached_cb_(encryption_key_attached_cb) {}
RecordHandlerImpl::ReportUploader::~ReportUploader() = default; RecordHandlerImpl::ReportUploader::~ReportUploader() = default;
...@@ -129,7 +138,8 @@ void RecordHandlerImpl::ReportUploader::StartUpload( ...@@ -129,7 +138,8 @@ void RecordHandlerImpl::ReportUploader::StartUpload(
base::BindOnce(&RecordHandlerImpl::ReportUploader::OnUploadComplete, base::BindOnce(&RecordHandlerImpl::ReportUploader::OnUploadComplete,
base::Unretained(this)); base::Unretained(this));
auto request_result = UploadEncryptedReportingRequestBuilder() auto request_result =
UploadEncryptedReportingRequestBuilder(need_encryption_key)
.AddRecord(encrypted_record) .AddRecord(encrypted_record)
.Build(); .Build();
if (!request_result.has_value()) { if (!request_result.has_value()) {
...@@ -187,6 +197,7 @@ void RecordHandlerImpl::ReportUploader::HandleSuccessfulUpload() { ...@@ -187,6 +197,7 @@ void RecordHandlerImpl::ReportUploader::HandleSuccessfulUpload() {
// "failedUploadedRecord": ... // SequencingInformation proto // "failedUploadedRecord": ... // SequencingInformation proto
// "failureStatus": ... // Status proto // "failureStatus": ... // Status proto
// } // }
// "encryptionSettings": ... // EncryptionSettings proto
// } // }
// TODO(b/169883262): Factor out the decoding into a separate class. // TODO(b/169883262): Factor out the decoding into a separate class.
...@@ -214,7 +225,37 @@ void RecordHandlerImpl::ReportUploader::HandleSuccessfulUpload() { ...@@ -214,7 +225,37 @@ void RecordHandlerImpl::ReportUploader::HandleSuccessfulUpload() {
} }
} }
// TODO(b/169883262): Decode and handle failure information. // TODO(b/169883262): Decode and handle failure information.
// TODO(b/170054326): Handle the encryption settings.
// Handle the encryption settings.
// Note: server can attach it to response regardless of whether
// the response indicates success or failure, and whether the client
// set attach_encryption_settings to true in request.
const base::Value* signed_encryption_key_record =
last_response_.FindDictKey("encryptionSettings");
if (signed_encryption_key_record != nullptr) {
const std::string* public_key_str =
signed_encryption_key_record->FindStringKey("publicKey");
const auto public_key_id_result =
signed_encryption_key_record->FindIntKey("publicKeyId");
// TODO(b/170054326): Make signature mandatory too.
// const std::string* public_key_signature_str =
// signed_encryption_key_record->FindStringKey("publicKeySignature");
std::string public_key;
std::string public_key_signature;
if (public_key_str != nullptr &&
base::Base64Decode(*public_key_str, &public_key) &&
// TODO(b/170054326): Make signature mandatory too.
// public_key_signature_str != nullptr
// base::Base64Decode(*public_key_signature_str,
// &public_key_signature) &&
public_key_id_result.has_value()) {
SignedEncryptionInfo signed_encryption_key;
signed_encryption_key.set_public_asymmetric_key(public_key);
signed_encryption_key.set_public_key_id(public_key_id_result.value());
signed_encryption_key.set_signature(public_key_signature);
encryption_key_attached_cb_.Run(signed_encryption_key);
}
}
// Pop the last record that was processed. // Pop the last record that was processed.
records_->pop_back(); records_->pop_back();
...@@ -243,10 +284,13 @@ RecordHandlerImpl::~RecordHandlerImpl() = default; ...@@ -243,10 +284,13 @@ RecordHandlerImpl::~RecordHandlerImpl() = default;
void RecordHandlerImpl::HandleRecords( void RecordHandlerImpl::HandleRecords(
bool need_encryption_key, bool need_encryption_key,
std::unique_ptr<std::vector<EncryptedRecord>> records, std::unique_ptr<std::vector<EncryptedRecord>> records,
DmServerUploadService::CompletionCallback upload_complete_cb) { DmServerUploadService::CompletionCallback upload_complete_cb,
DmServerUploadService::EncryptionKeyAttachedCallback
encryption_key_attached_cb) {
Start<RecordHandlerImpl::ReportUploader>( Start<RecordHandlerImpl::ReportUploader>(
need_encryption_key, std::move(records), GetClient(), need_encryption_key, std::move(records), GetClient(),
std::move(upload_complete_cb), sequenced_task_runner_); std::move(upload_complete_cb), encryption_key_attached_cb,
sequenced_task_runner_);
} }
} // namespace reporting } // namespace reporting
...@@ -37,10 +37,11 @@ class RecordHandlerImpl : public DmServerUploadService::RecordHandler { ...@@ -37,10 +37,11 @@ class RecordHandlerImpl : public DmServerUploadService::RecordHandler {
~RecordHandlerImpl() override; ~RecordHandlerImpl() override;
// Base class RecordHandler method implementation. // Base class RecordHandler method implementation.
void HandleRecords( void HandleRecords(bool need_encryption_key,
bool need_encryption_key,
std::unique_ptr<std::vector<EncryptedRecord>> record, std::unique_ptr<std::vector<EncryptedRecord>> record,
DmServerUploadService::CompletionCallback upload_complete) override; DmServerUploadService::CompletionCallback upload_complete,
DmServerUploadService::EncryptionKeyAttachedCallback
encryption_key_attached_cb) override;
private: private:
// Helper |ReportUploader| class handles enqueuing events on the // Helper |ReportUploader| class handles enqueuing events on the
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include "chrome/browser/policy/messaging_layer/upload/record_handler_impl.h" #include "chrome/browser/policy/messaging_layer/upload/record_handler_impl.h"
#include "base/base64.h"
#include "base/json/json_writer.h" #include "base/json/json_writer.h"
#include "base/optional.h" #include "base/optional.h"
#include "base/strings/strcat.h" #include "base/strings/strcat.h"
...@@ -28,7 +29,9 @@ ...@@ -28,7 +29,9 @@
using ::testing::_; using ::testing::_;
using ::testing::Invoke; using ::testing::Invoke;
using ::testing::MockFunction;
using ::testing::Return; using ::testing::Return;
using ::testing::StrictMock;
using ::testing::WithArgs; using ::testing::WithArgs;
namespace reporting { namespace reporting {
...@@ -60,26 +63,24 @@ class TestCallbackWaiter { ...@@ -60,26 +63,24 @@ class TestCallbackWaiter {
class TestCallbackWaiterWithCounter : public TestCallbackWaiter { class TestCallbackWaiterWithCounter : public TestCallbackWaiter {
public: public:
explicit TestCallbackWaiterWithCounter(int counter_limit) explicit TestCallbackWaiterWithCounter(size_t counter_limit)
: counter_limit_(counter_limit) {} : counter_limit_(counter_limit) {}
void Signal() override { void Signal() override {
DCHECK_GT(counter_limit_, 0); DCHECK_GT(counter_limit_, 0u);
if (--counter_limit_ == 0) { if (--counter_limit_ == 0u) {
run_loop_.Quit(); run_loop_.Quit();
} }
} }
private: private:
std::atomic<int> counter_limit_; std::atomic<size_t> counter_limit_;
}; };
class TestCompletionResponder { using TestCompletionResponder =
public: MockFunction<void(DmServerUploadService::CompletionResponse)>;
MOCK_METHOD(void,
RecordsHandled, using TestEncryptionKeyAttached = MockFunction<void(SignedEncryptionInfo)>;
(DmServerUploadService::CompletionResponse));
};
// Helper function composes JSON represented as base::Value from Sequencing // Helper function composes JSON represented as base::Value from Sequencing
// information in request. // information in request.
...@@ -87,6 +88,9 @@ base::Value ValueFromSucceededSequencingInfo( ...@@ -87,6 +88,9 @@ base::Value ValueFromSucceededSequencingInfo(
const base::Optional<base::Value> request) { const base::Optional<base::Value> request) {
EXPECT_TRUE(request.has_value()); EXPECT_TRUE(request.has_value());
EXPECT_TRUE(request.value().is_dict()); EXPECT_TRUE(request.value().is_dict());
base::Value response(base::Value::Type::DICTIONARY);
// Retrieve and process sequencing information
const base::Value* const encrypted_record_list = const base::Value* const encrypted_record_list =
request.value().FindListKey("encryptedRecord"); request.value().FindListKey("encryptedRecord");
EXPECT_TRUE(encrypted_record_list != nullptr); EXPECT_TRUE(encrypted_record_list != nullptr);
...@@ -95,12 +99,30 @@ base::Value ValueFromSucceededSequencingInfo( ...@@ -95,12 +99,30 @@ base::Value ValueFromSucceededSequencingInfo(
encrypted_record_list->GetList().rbegin()->FindDictKey( encrypted_record_list->GetList().rbegin()->FindDictKey(
"sequencingInformation"); "sequencingInformation");
EXPECT_TRUE(seq_info != nullptr); EXPECT_TRUE(seq_info != nullptr);
base::Value response(base::Value::Type::DICTIONARY);
response.SetPath("lastSucceedUploadedRecord", seq_info->Clone()); response.SetPath("lastSucceedUploadedRecord", seq_info->Clone());
// If attach_encryption_settings it true, process that.
const auto attach_encryption_settings =
request.value().FindBoolKey("attachEncryptionSettings");
if (attach_encryption_settings.has_value() &&
attach_encryption_settings.value()) {
base::Value encryption_settings{base::Value::Type::DICTIONARY};
std::string public_key;
base::Base64Encode("PUBLIC KEY", &public_key);
encryption_settings.SetStringKey("publicKey", public_key);
encryption_settings.SetIntKey("publicKeyId", 12345);
std::string public_key_signature;
// TODO(b/170054326): Generate signature.
base::Base64Encode("PUBLIC KEY SIG", &public_key_signature);
encryption_settings.SetStringKey("publicKeySignature",
public_key_signature);
response.SetPath("encryptionSettings", std::move(encryption_settings));
}
return response; return response;
} }
class RecordHandlerImplTest : public testing::Test { class RecordHandlerImplTest : public ::testing::TestWithParam<bool> {
public: public:
RecordHandlerImplTest() RecordHandlerImplTest()
: client_(std::make_unique<policy::MockCloudPolicyClient>()) {} : client_(std::make_unique<policy::MockCloudPolicyClient>()) {}
...@@ -110,6 +132,9 @@ class RecordHandlerImplTest : public testing::Test { ...@@ -110,6 +132,9 @@ class RecordHandlerImplTest : public testing::Test {
client_->SetDMToken( client_->SetDMToken(
policy::DMToken::CreateValidTokenForTesting("FAKE_DM_TOKEN").value()); policy::DMToken::CreateValidTokenForTesting("FAKE_DM_TOKEN").value());
} }
bool need_encryption_key() const { return GetParam(); }
content::BrowserTaskEnvironment task_environment_; content::BrowserTaskEnvironment task_environment_;
std::unique_ptr<policy::MockCloudPolicyClient> client_; std::unique_ptr<policy::MockCloudPolicyClient> client_;
...@@ -135,7 +160,7 @@ std::unique_ptr<std::vector<EncryptedRecord>> BuildTestRecordsVector( ...@@ -135,7 +160,7 @@ std::unique_ptr<std::vector<EncryptedRecord>> BuildTestRecordsVector(
return test_records; return test_records;
} }
TEST_F(RecordHandlerImplTest, ForwardsRecordsToCloudPolicyClient) { TEST_P(RecordHandlerImplTest, ForwardsRecordsToCloudPolicyClient) {
constexpr size_t kNumTestRecords = 10; constexpr size_t kNumTestRecords = 10;
constexpr uint64_t kGenerationId = 1234; constexpr uint64_t kGenerationId = 1234;
auto test_records = BuildTestRecordsVector(kNumTestRecords, kGenerationId); auto test_records = BuildTestRecordsVector(kNumTestRecords, kGenerationId);
...@@ -154,24 +179,34 @@ TEST_F(RecordHandlerImplTest, ForwardsRecordsToCloudPolicyClient) { ...@@ -154,24 +179,34 @@ TEST_F(RecordHandlerImplTest, ForwardsRecordsToCloudPolicyClient) {
RecordHandlerImpl handler(client_.get()); RecordHandlerImpl handler(client_.get());
StrictMock<TestEncryptionKeyAttached> encryption_key_attached;
StrictMock<TestCompletionResponder> responder;
TestCallbackWaiter responder_waiter; TestCallbackWaiter responder_waiter;
TestCompletionResponder responder;
::testing::InSequence seq; EXPECT_CALL(encryption_key_attached, Call(_))
EXPECT_CALL(responder, RecordsHandled(ValueEqualsProto( .Times(need_encryption_key() ? 1 : 0);
test_records->back().sequencing_information())))
EXPECT_CALL(
responder,
Call(ValueEqualsProto(test_records->back().sequencing_information())))
.WillOnce(Invoke([&responder_waiter]() { responder_waiter.Signal(); })); .WillOnce(Invoke([&responder_waiter]() { responder_waiter.Signal(); }));
auto responder_callback = base::BindOnce( auto encryption_key_attached_callback =
&TestCompletionResponder::RecordsHandled, base::Unretained(&responder)); base::BindRepeating(&TestEncryptionKeyAttached::Call,
base::Unretained(&encryption_key_attached));
auto responder_callback = base::BindOnce(&TestCompletionResponder::Call,
base::Unretained(&responder));
handler.HandleRecords(/*need_encryption_key=*/false, std::move(test_records), handler.HandleRecords(need_encryption_key(), std::move(test_records),
std::move(responder_callback)); std::move(responder_callback),
encryption_key_attached_callback);
client_waiter.Wait(); client_waiter.Wait();
responder_waiter.Wait(); responder_waiter.Wait();
} }
TEST_F(RecordHandlerImplTest, ReportsEarlyFailure) { TEST_P(RecordHandlerImplTest, ReportsEarlyFailure) {
uint64_t kNumSuccessfulUploads = 5; uint64_t kNumSuccessfulUploads = 5;
uint64_t kNumTestRecords = 10; uint64_t kNumTestRecords = 10;
uint64_t kGenerationId = 1234; uint64_t kGenerationId = 1234;
...@@ -180,6 +215,7 @@ TEST_F(RecordHandlerImplTest, ReportsEarlyFailure) { ...@@ -180,6 +215,7 @@ TEST_F(RecordHandlerImplTest, ReportsEarlyFailure) {
// Wait kNumSuccessfulUploads times + 1 for the failure. // Wait kNumSuccessfulUploads times + 1 for the failure.
TestCallbackWaiterWithCounter client_waiter{kNumSuccessfulUploads + 1}; TestCallbackWaiterWithCounter client_waiter{kNumSuccessfulUploads + 1};
{
::testing::InSequence seq; ::testing::InSequence seq;
EXPECT_CALL(*client_, UploadEncryptedReport(_, _, _)) EXPECT_CALL(*client_, UploadEncryptedReport(_, _, _))
.Times(kNumSuccessfulUploads) .Times(kNumSuccessfulUploads)
...@@ -193,31 +229,45 @@ TEST_F(RecordHandlerImplTest, ReportsEarlyFailure) { ...@@ -193,31 +229,45 @@ TEST_F(RecordHandlerImplTest, ReportsEarlyFailure) {
}))); })));
EXPECT_CALL(*client_, UploadEncryptedReport(_, _, _)) EXPECT_CALL(*client_, UploadEncryptedReport(_, _, _))
.WillOnce(WithArgs<2>(Invoke( .WillOnce(WithArgs<2>(Invoke(
[&client_waiter]( [&client_waiter](base::OnceCallback<void(
base::OnceCallback<void(base::Optional<base::Value>)> callback) { base::Optional<base::Value>)> callback) {
std::move(callback).Run(base::nullopt); std::move(callback).Run(base::nullopt);
client_waiter.Signal(); client_waiter.Signal();
}))); })));
}
RecordHandlerImpl handler(client_.get()); RecordHandlerImpl handler(client_.get());
StrictMock<TestEncryptionKeyAttached> encryption_key_attached;
StrictMock<TestCompletionResponder> responder;
TestCallbackWaiter responder_waiter; TestCallbackWaiter responder_waiter;
TestCompletionResponder responder;
EXPECT_CALL(encryption_key_attached, Call(_))
.Times(need_encryption_key() ? 1 : 0);
EXPECT_CALL( EXPECT_CALL(
responder, responder,
RecordsHandled(ValueEqualsProto( Call(ValueEqualsProto(
(*test_records)[kNumSuccessfulUploads - 1].sequencing_information()))) (*test_records)[kNumSuccessfulUploads - 1].sequencing_information())))
.WillOnce(Invoke([&responder_waiter]() { responder_waiter.Signal(); })); .WillOnce(Invoke([&responder_waiter]() { responder_waiter.Signal(); }));
auto responder_callback = base::BindOnce( auto encryption_key_attached_callback =
&TestCompletionResponder::RecordsHandled, base::Unretained(&responder)); base::BindRepeating(&TestEncryptionKeyAttached::Call,
base::Unretained(&encryption_key_attached));
handler.HandleRecords(/*need_encryption_key=*/false, std::move(test_records), auto responder_callback = base::BindOnce(&TestCompletionResponder::Call,
std::move(responder_callback)); base::Unretained(&responder));
handler.HandleRecords(need_encryption_key(), std::move(test_records),
std::move(responder_callback),
encryption_key_attached_callback);
client_waiter.Wait(); client_waiter.Wait();
responder_waiter.Wait(); responder_waiter.Wait();
} }
INSTANTIATE_TEST_SUITE_P(NeedOrNoNeedKey,
RecordHandlerImplTest,
testing::Bool());
} // namespace } // namespace
} // namespace reporting } // namespace reporting
...@@ -20,11 +20,13 @@ namespace reporting { ...@@ -20,11 +20,13 @@ namespace reporting {
void UploadClient::Create( void UploadClient::Create(
policy::CloudPolicyClient* cloud_policy_client, policy::CloudPolicyClient* cloud_policy_client,
ReportSuccessfulUploadCallback report_upload_success_cb, ReportSuccessfulUploadCallback report_upload_success_cb,
EncryptionKeyAttachedCallback encryption_key_attached_cb,
base::OnceCallback<void(StatusOr<std::unique_ptr<UploadClient>>)> base::OnceCallback<void(StatusOr<std::unique_ptr<UploadClient>>)>
created_cb) { created_cb) {
auto upload_client = base::WrapUnique(new UploadClient()); auto upload_client = base::WrapUnique(new UploadClient());
DmServerUploadService::Create( DmServerUploadService::Create(
std::move(cloud_policy_client), report_upload_success_cb, std::move(cloud_policy_client), report_upload_success_cb,
encryption_key_attached_cb,
base::BindOnce( base::BindOnce(
[](std::unique_ptr<UploadClient> upload_client, [](std::unique_ptr<UploadClient> upload_client,
base::OnceCallback<void(StatusOr<std::unique_ptr<UploadClient>>)> base::OnceCallback<void(StatusOr<std::unique_ptr<UploadClient>>)>
...@@ -42,7 +44,7 @@ void UploadClient::Create( ...@@ -42,7 +44,7 @@ void UploadClient::Create(
} }
Status UploadClient::EnqueueUpload( Status UploadClient::EnqueueUpload(
bool need_encryption_key, bool need_encryption_keys,
std::unique_ptr<std::vector<EncryptedRecord>> records) { std::unique_ptr<std::vector<EncryptedRecord>> records) {
DCHECK(records); DCHECK(records);
...@@ -50,7 +52,7 @@ Status UploadClient::EnqueueUpload( ...@@ -50,7 +52,7 @@ Status UploadClient::EnqueueUpload(
return Status::StatusOK(); return Status::StatusOK();
} }
return dm_server_upload_service_->EnqueueUpload(need_encryption_key, return dm_server_upload_service_->EnqueueUpload(need_encryption_keys,
std::move(records)); std::move(records));
} }
......
...@@ -26,9 +26,15 @@ class UploadClient { ...@@ -26,9 +26,15 @@ class UploadClient {
using ReportSuccessfulUploadCallback = using ReportSuccessfulUploadCallback =
base::RepeatingCallback<void(SequencingInformation)>; base::RepeatingCallback<void(SequencingInformation)>;
// ReceivedEncryptionKeyCallback is called if server attached encryption key
// to the response.
using EncryptionKeyAttachedCallback =
base::RepeatingCallback<void(SignedEncryptionInfo)>;
static void Create( static void Create(
policy::CloudPolicyClient* cloud_policy_client, policy::CloudPolicyClient* cloud_policy_client,
ReportSuccessfulUploadCallback report_upload_success_cb, ReportSuccessfulUploadCallback report_upload_success_cb,
EncryptionKeyAttachedCallback encryption_key_attached_cb,
base::OnceCallback<void(StatusOr<std::unique_ptr<UploadClient>>)> base::OnceCallback<void(StatusOr<std::unique_ptr<UploadClient>>)>
created_cb); created_cb);
...@@ -36,7 +42,7 @@ class UploadClient { ...@@ -36,7 +42,7 @@ class UploadClient {
UploadClient(const UploadClient& other) = delete; UploadClient(const UploadClient& other) = delete;
UploadClient& operator=(const UploadClient& other) = delete; UploadClient& operator=(const UploadClient& other) = delete;
Status EnqueueUpload(bool need_encryption_key, Status EnqueueUpload(bool need_encryption_keys,
std::unique_ptr<std::vector<EncryptedRecord>> record); std::unique_ptr<std::vector<EncryptedRecord>> record);
private: private:
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include "chrome/browser/policy/messaging_layer/upload/upload_client.h" #include "chrome/browser/policy/messaging_layer/upload/upload_client.h"
#include "base/base64.h"
#include "base/bind.h" #include "base/bind.h"
#include "base/files/file_path.h" #include "base/files/file_path.h"
#include "base/json/json_writer.h" #include "base/json/json_writer.h"
...@@ -30,11 +31,13 @@ ...@@ -30,11 +31,13 @@
namespace reporting { namespace reporting {
namespace { namespace {
using policy::MockCloudPolicyClient; using ::policy::MockCloudPolicyClient;
using testing::_; using ::testing::_;
using testing::Invoke; using ::testing::Invoke;
using testing::InvokeArgument; using ::testing::InvokeArgument;
using testing::WithArgs; using ::testing::MockFunction;
using ::testing::StrictMock;
using ::testing::WithArgs;
MATCHER_P(EqualsProto, MATCHER_P(EqualsProto,
message, message,
...@@ -125,20 +128,43 @@ base::Value ValueFromSucceededSequencingInfo( ...@@ -125,20 +128,43 @@ base::Value ValueFromSucceededSequencingInfo(
const base::Optional<base::Value> request) { const base::Optional<base::Value> request) {
EXPECT_TRUE(request.has_value()); EXPECT_TRUE(request.has_value());
EXPECT_TRUE(request.value().is_dict()); EXPECT_TRUE(request.value().is_dict());
base::Value response(base::Value::Type::DICTIONARY);
// Retrieve and process data
const base::Value* const encrypted_record_list = const base::Value* const encrypted_record_list =
request.value().FindListKey("encryptedRecord"); request.value().FindListKey("encryptedRecord");
EXPECT_TRUE(encrypted_record_list != nullptr); EXPECT_TRUE(encrypted_record_list != nullptr);
EXPECT_FALSE(encrypted_record_list->GetList().empty()); EXPECT_FALSE(encrypted_record_list->GetList().empty());
// Retrieve and process sequencing information
const base::Value* seq_info = const base::Value* seq_info =
encrypted_record_list->GetList().rbegin()->FindDictKey( encrypted_record_list->GetList().rbegin()->FindDictKey(
"sequencingInformation"); "sequencingInformation");
EXPECT_TRUE(seq_info != nullptr); EXPECT_TRUE(seq_info != nullptr);
base::Value response(base::Value::Type::DICTIONARY);
response.SetPath("lastSucceedUploadedRecord", seq_info->Clone()); response.SetPath("lastSucceedUploadedRecord", seq_info->Clone());
// If attach_encryption_settings it true, process that.
const auto attach_encryption_settings =
request.value().FindBoolKey("attachEncryptionSettings");
if (attach_encryption_settings.has_value() &&
attach_encryption_settings.value()) {
base::Value encryption_settings{base::Value::Type::DICTIONARY};
std::string public_key;
base::Base64Encode("PUBLIC KEY", &public_key);
encryption_settings.SetStringKey("publicKey", public_key);
encryption_settings.SetIntKey("publicKeyId", 12345);
std::string public_key_signature;
// TODO(b/170054326): Generate signature.
base::Base64Encode("PUBLIC KEY SIG", &public_key_signature);
encryption_settings.SetStringKey("publicKeySignature",
public_key_signature);
response.SetPath("encryptionSettings", std::move(encryption_settings));
}
return response; return response;
} }
class UploadClientTest : public ::testing::Test { class UploadClientTest : public ::testing::TestWithParam<bool> {
public: public:
UploadClientTest() = default; UploadClientTest() = default;
...@@ -169,6 +195,8 @@ class UploadClientTest : public ::testing::Test { ...@@ -169,6 +195,8 @@ class UploadClientTest : public ::testing::Test {
#endif // OS_CHROMEOS #endif // OS_CHROMEOS
} }
bool need_encryption_key() const { return GetParam(); }
content::BrowserTaskEnvironment task_envrionment_; content::BrowserTaskEnvironment task_envrionment_;
#ifdef OS_CHROMEOS #ifdef OS_CHROMEOS
std::unique_ptr<TestingProfile> profile_; std::unique_ptr<TestingProfile> profile_;
...@@ -176,7 +204,9 @@ class UploadClientTest : public ::testing::Test { ...@@ -176,7 +204,9 @@ class UploadClientTest : public ::testing::Test {
#endif // OS_CHROMEOS #endif // OS_CHROMEOS
}; };
TEST_F(UploadClientTest, CreateUploadClientAndUploadRecords) { using TestEncryptionKeyAttached = MockFunction<void(SignedEncryptionInfo)>;
TEST_P(UploadClientTest, CreateUploadClientAndUploadRecords) {
const int kExpectedCallTimes = 10; const int kExpectedCallTimes = 10;
const uint64_t kGenerationId = 1234; const uint64_t kGenerationId = 1234;
...@@ -209,6 +239,13 @@ TEST_F(UploadClientTest, CreateUploadClientAndUploadRecords) { ...@@ -209,6 +239,13 @@ TEST_F(UploadClientTest, CreateUploadClientAndUploadRecords) {
TestCallbackWaiterWithCounter waiter(kExpectedCallTimes); TestCallbackWaiterWithCounter waiter(kExpectedCallTimes);
StrictMock<TestEncryptionKeyAttached> encryption_key_attached;
EXPECT_CALL(encryption_key_attached, Call(_))
.Times(need_encryption_key() ? 1 : 0);
auto encryption_key_attached_cb =
base::BindRepeating(&TestEncryptionKeyAttached::Call,
base::Unretained(&encryption_key_attached));
auto client = std::make_unique<MockCloudPolicyClient>(); auto client = std::make_unique<MockCloudPolicyClient>();
client->SetDMToken( client->SetDMToken(
policy::DMToken::CreateValidTokenForTesting("FAKE_DM_TOKEN").value()); policy::DMToken::CreateValidTokenForTesting("FAKE_DM_TOKEN").value());
...@@ -233,18 +270,21 @@ TEST_F(UploadClientTest, CreateUploadClientAndUploadRecords) { ...@@ -233,18 +270,21 @@ TEST_F(UploadClientTest, CreateUploadClientAndUploadRecords) {
records->back().sequencing_information()); records->back().sequencing_information());
TestEvent<StatusOr<std::unique_ptr<UploadClient>>> e; TestEvent<StatusOr<std::unique_ptr<UploadClient>>> e;
UploadClient::Create(client.get(), completion_cb, e.cb()); UploadClient::Create(client.get(), completion_cb, encryption_key_attached_cb,
e.cb());
StatusOr<std::unique_ptr<UploadClient>> upload_client_result = e.result(); StatusOr<std::unique_ptr<UploadClient>> upload_client_result = e.result();
ASSERT_OK(upload_client_result) << upload_client_result.status(); ASSERT_OK(upload_client_result) << upload_client_result.status();
auto upload_client = std::move(upload_client_result.ValueOrDie()); auto upload_client = std::move(upload_client_result.ValueOrDie());
auto enqueue_result = upload_client->EnqueueUpload( auto enqueue_result =
/*need_encryption_key=*/false, std::move(records)); upload_client->EnqueueUpload(need_encryption_key(), std::move(records));
EXPECT_TRUE(enqueue_result.ok()); EXPECT_TRUE(enqueue_result.ok());
waiter.Wait(); waiter.Wait();
completion_callback_waiter.Wait(); completion_callback_waiter.Wait();
} }
INSTANTIATE_TEST_SUITE_P(NeedOrNoNeedKey, UploadClientTest, testing::Bool());
} // namespace } // namespace
} // namespace reporting } // namespace reporting
...@@ -102,5 +102,23 @@ message EncryptedRecord { ...@@ -102,5 +102,23 @@ message EncryptedRecord {
// TODO(b/153651358): Disable an option to send record not encrypted. // TODO(b/153651358): Disable an option to send record not encrypted.
optional EncryptionInfo encryption_info = 2; optional EncryptionInfo encryption_info = 2;
// Sequencing information (required). Must be present to allow
// tracking and confirmation of the events by server.
optional SequencingInformation sequencing_information = 3; optional SequencingInformation sequencing_information = 3;
} }
// Encryption public key as delivered from the server and stored in Storage.
// Signature ensures the key was actually sent by the server and not manipulated
// afterwards.
message SignedEncryptionInfo {
// Public asymmetric key (required).
optional bytes public_asymmetric_key = 1;
// Public key id (required).
// Identifies private key matching |public_asymmetric_key| for the server.
optional uint64 public_key_id = 2;
// Signature of |public_asymmetric_key| (required).
// Verified by client against a well-known signature.
optional bytes signature = 3;
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment