Commit 76e743dc authored by Ivan Sandrk's avatar Ivan Sandrk Committed by Commit Bot

[Remote Commands] Add tests for CloudPolicyClient and RemoteCommandsService

Bug: 891222
Change-Id: I6675c208550affd8f8bb1316d4935c02f0e07a5e
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/1783439Reviewed-by: default avatarRoman Sorokin [CET] <rsorokin@chromium.org>
Commit-Queue: Ivan Šandrk <isandrk@chromium.org>
Cr-Commit-Position: refs/heads/master@{#697198}
parent ef95fb21
...@@ -98,11 +98,20 @@ void RemoteCommandsService::VerifyAndEnqueueSignedCommand( ...@@ -98,11 +98,20 @@ void RemoteCommandsService::VerifyAndEnqueueSignedCommand(
signed_command.signature(), signed_command.signature(),
CloudPolicyValidatorBase::SignatureType::SHA1); CloudPolicyValidatorBase::SignatureType::SHA1);
auto ignore_result = base::BindOnce(
[](std::vector<em::RemoteCommandResult>* unsent_results,
const char* error_msg) {
SYSLOG(ERROR) << error_msg;
em::RemoteCommandResult result;
result.set_result(em::RemoteCommandResult_ResultType_RESULT_IGNORED);
result.set_command_id(-1);
unsent_results->push_back(result);
},
&unsent_results_);
if (!valid_signature) { if (!valid_signature) {
SYSLOG(ERROR) << "Secure remote command signature verification failed"; std::move(ignore_result)
em::RemoteCommandResult result; .Run("Secure remote command signature verification failed");
result.set_result(em::RemoteCommandResult_ResultType_RESULT_IGNORED);
unsent_results_.push_back(result);
return; return;
} }
...@@ -111,24 +120,25 @@ void RemoteCommandsService::VerifyAndEnqueueSignedCommand( ...@@ -111,24 +120,25 @@ void RemoteCommandsService::VerifyAndEnqueueSignedCommand(
!policy_data.has_policy_type() || !policy_data.has_policy_type() ||
policy_data.policy_type() != policy_data.policy_type() !=
dm_protocol::kChromeRemoteCommandPolicyType) { dm_protocol::kChromeRemoteCommandPolicyType) {
SYSLOG(ERROR) << "Secure remote command with wrong PolicyData type"; std::move(ignore_result)
em::RemoteCommandResult result; .Run("Secure remote command with wrong PolicyData type");
result.set_result(em::RemoteCommandResult_ResultType_RESULT_IGNORED);
unsent_results_.push_back(result);
return; return;
} }
em::RemoteCommand command; em::RemoteCommand command;
if (!policy_data.has_policy_value() || if (!policy_data.has_policy_value() ||
!command.ParseFromString(policy_data.policy_value())) { !command.ParseFromString(policy_data.policy_value())) {
SYSLOG(ERROR) << "Secure remote command invalid RemoteCommand data"; std::move(ignore_result)
em::RemoteCommandResult result; .Run("Secure remote command invalid RemoteCommand data");
result.set_result(em::RemoteCommandResult_ResultType_RESULT_IGNORED);
unsent_results_.push_back(result);
return; return;
} }
// TODO(isandrk): Also make sure that target_device_id matches and add tests! const em::PolicyData* const policy = store_->policy();
if (!policy || policy->device_id() != command.target_device_id()) {
std::move(ignore_result)
.Run("Secure remote command wrong target device id");
return;
}
// Signature verification passed. // Signature verification passed.
EnqueueCommand(command, &signed_command); EnqueueCommand(command, &signed_command);
......
...@@ -4,33 +4,61 @@ ...@@ -4,33 +4,61 @@
#include "components/policy/core/common/remote_commands/testing_remote_commands_server.h" #include "components/policy/core/common/remote_commands/testing_remote_commands_server.h"
#include <iterator>
#include <utility> #include <utility>
#include "base/bind.h" #include "base/bind.h"
#include "base/callback.h" #include "base/callback.h"
#include "base/hash/sha1.h"
#include "base/location.h" #include "base/location.h"
#include "base/logging.h" #include "base/logging.h"
#include "base/optional.h"
#include "base/single_thread_task_runner.h" #include "base/single_thread_task_runner.h"
#include "base/threading/thread_task_runner_handle.h" #include "base/threading/thread_task_runner_handle.h"
#include "base/time/default_tick_clock.h" #include "base/time/default_tick_clock.h"
#include "base/time/tick_clock.h" #include "base/time/tick_clock.h"
#include "base/time/time.h" #include "base/time/time.h"
#include "components/policy/core/common/cloud/policy_builder.h"
#include "crypto/signature_creator.h"
#include "testing/gtest/include/gtest/gtest.h" #include "testing/gtest/include/gtest/gtest.h"
namespace em = enterprise_management; namespace em = enterprise_management;
namespace policy { namespace policy {
namespace {
std::string SignDataWithTestKey(const std::string& data) {
std::unique_ptr<crypto::RSAPrivateKey> private_key =
PolicyBuilder::CreateTestSigningKey();
std::string sha1 = base::SHA1HashString(data);
std::vector<uint8_t> digest(sha1.begin(), sha1.end());
std::vector<uint8_t> result;
CHECK(crypto::SignatureCreator::Sign(private_key.get(),
crypto::SignatureCreator::SHA1,
digest.data(), digest.size(), &result));
return std::string(result.begin(), result.end());
}
} // namespace
struct TestingRemoteCommandsServer::RemoteCommandWithCallback { struct TestingRemoteCommandsServer::RemoteCommandWithCallback {
RemoteCommandWithCallback(const em::RemoteCommand& command_proto, RemoteCommandWithCallback(em::RemoteCommand command_proto,
base::Optional<em::SignedData> signed_command_proto,
base::TimeTicks issued_time, base::TimeTicks issued_time,
const ResultReportedCallback& reported_callback) ResultReportedCallback reported_callback)
: command_proto(command_proto), : command_proto(command_proto),
signed_command_proto(signed_command_proto),
issued_time(issued_time), issued_time(issued_time),
reported_callback(reported_callback) {} reported_callback(std::move(reported_callback)) {}
RemoteCommandWithCallback(RemoteCommandWithCallback&& other) = default;
RemoteCommandWithCallback& operator=(RemoteCommandWithCallback&& other) =
default;
~RemoteCommandWithCallback() {} ~RemoteCommandWithCallback() {}
em::RemoteCommand command_proto; em::RemoteCommand command_proto;
base::Optional<em::SignedData> signed_command_proto;
base::TimeTicks issued_time; base::TimeTicks issued_time;
ResultReportedCallback reported_callback; ResultReportedCallback reported_callback;
}; };
...@@ -53,10 +81,9 @@ TestingRemoteCommandsServer::~TestingRemoteCommandsServer() { ...@@ -53,10 +81,9 @@ TestingRemoteCommandsServer::~TestingRemoteCommandsServer() {
void TestingRemoteCommandsServer::IssueCommand( void TestingRemoteCommandsServer::IssueCommand(
em::RemoteCommand_Type type, em::RemoteCommand_Type type,
const std::string& payload, const std::string& payload,
const ResultReportedCallback& reported_callback, ResultReportedCallback reported_callback,
bool skip_next_fetch) { bool skip_next_fetch) {
DCHECK(thread_checker_.CalledOnValidThread()); DCHECK(thread_checker_.CalledOnValidThread());
base::AutoLock auto_lock(lock_); base::AutoLock auto_lock(lock_);
em::RemoteCommand command; em::RemoteCommand command;
...@@ -65,18 +92,59 @@ void TestingRemoteCommandsServer::IssueCommand( ...@@ -65,18 +92,59 @@ void TestingRemoteCommandsServer::IssueCommand(
if (!payload.empty()) if (!payload.empty())
command.set_payload(payload); command.set_payload(payload);
const RemoteCommandWithCallback command_with_callback( RemoteCommandWithCallback command_with_callback(
command, clock_->NowTicks(), reported_callback); command, base::nullopt, clock_->NowTicks(), std::move(reported_callback));
if (skip_next_fetch) if (skip_next_fetch)
commands_issued_after_next_fetch_.push_back(command_with_callback); commands_issued_after_next_fetch_.push_back(
std::move(command_with_callback));
else else
commands_.push_back(command_with_callback); commands_.push_back(std::move(command_with_callback));
}
void TestingRemoteCommandsServer::IssueSignedCommand(
ResultReportedCallback reported_callback,
em::RemoteCommand* command_in,
em::PolicyData* policy_data_in,
em::SignedData* signed_data_in) {
DCHECK(thread_checker_.CalledOnValidThread());
base::AutoLock auto_lock(lock_);
em::RemoteCommand command;
em::PolicyData policy_data;
em::SignedData signed_data;
if (command_in) {
command = *command_in;
} else {
command.set_target_device_id("acme-device");
command.set_type(em::RemoteCommand_Type_COMMAND_ECHO_TEST);
command.set_command_id(++last_generated_unique_id_);
}
if (policy_data_in) {
policy_data = *policy_data_in;
} else {
policy_data.set_policy_type("google/chromeos/remotecommand");
EXPECT_TRUE(command.SerializeToString(policy_data.mutable_policy_value()));
}
if (signed_data_in) {
signed_data = *signed_data_in;
} else {
EXPECT_TRUE(policy_data.SerializeToString(signed_data.mutable_data()));
signed_data.set_signature(SignDataWithTestKey(signed_data.data()));
}
RemoteCommandWithCallback command_with_callback(
command, signed_data, clock_->NowTicks(), std::move(reported_callback));
commands_.push_back(std::move(command_with_callback));
} }
TestingRemoteCommandsServer::RemoteCommands void TestingRemoteCommandsServer::FetchCommands(
TestingRemoteCommandsServer::FetchCommands(
std::unique_ptr<RemoteCommandJob::UniqueIDType> last_command_id, std::unique_ptr<RemoteCommandJob::UniqueIDType> last_command_id,
const RemoteCommandResults& previous_job_results) { const RemoteCommandResults& previous_job_results,
std::vector<em::RemoteCommand>* fetched_commands,
std::vector<em::SignedData>* signed_commands) {
base::AutoLock auto_lock(lock_); base::AutoLock auto_lock(lock_);
for (const auto& job_result : previous_job_results) { for (const auto& job_result : previous_job_results) {
...@@ -86,6 +154,15 @@ TestingRemoteCommandsServer::FetchCommands( ...@@ -86,6 +154,15 @@ TestingRemoteCommandsServer::FetchCommands(
bool found_command = false; bool found_command = false;
ResultReportedCallback reported_callback; ResultReportedCallback reported_callback;
if (job_result.command_id() == -1) {
// The result can have command_id equal to -1 in case a signed command was
// rejected at the validation stage before it could be unpacked.
CHECK(commands_.size() == 1);
found_command = true;
reported_callback = std::move(commands_[0].reported_callback);
commands_.clear();
}
if (last_command_id) { if (last_command_id) {
// This relies on us generating commands with increasing IDs. // This relies on us generating commands with increasing IDs.
EXPECT_GE(*last_command_id, job_result.command_id()); EXPECT_GE(*last_command_id, job_result.command_id());
...@@ -93,7 +170,7 @@ TestingRemoteCommandsServer::FetchCommands( ...@@ -93,7 +170,7 @@ TestingRemoteCommandsServer::FetchCommands(
for (auto it = commands_.begin(); it != commands_.end(); ++it) { for (auto it = commands_.begin(); it != commands_.end(); ++it) {
if (it->command_proto.command_id() == job_result.command_id()) { if (it->command_proto.command_id() == job_result.command_id()) {
reported_callback = it->reported_callback; reported_callback = std::move(it->reported_callback);
commands_.erase(it); commands_.erase(it);
found_command = true; found_command = true;
break; break;
...@@ -103,35 +180,40 @@ TestingRemoteCommandsServer::FetchCommands( ...@@ -103,35 +180,40 @@ TestingRemoteCommandsServer::FetchCommands(
// Verify that the command result is for an existing command actually // Verify that the command result is for an existing command actually
// expecting a result. // expecting a result.
EXPECT_TRUE(found_command); EXPECT_TRUE(found_command);
EXPECT_FALSE(reported_callback.is_null());
if (reported_callback.is_null()) { if (!reported_callback.is_null()) {
// Post task to the original thread which will report the result. // Post task to the original thread which will report the result.
task_runner_->PostTask( task_runner_->PostTask(
FROM_HERE, FROM_HERE,
base::BindOnce(&TestingRemoteCommandsServer::ReportJobResult, base::BindOnce(&TestingRemoteCommandsServer::ReportJobResult,
weak_ptr_to_this_, reported_callback, job_result)); weak_ptr_to_this_, std::move(reported_callback),
job_result));
} }
} }
RemoteCommands fetched_commands;
for (const auto& command_with_callback : commands_) { for (const auto& command_with_callback : commands_) {
if (!last_command_id || if (command_with_callback.signed_command_proto) {
command_with_callback.command_proto.command_id() > *last_command_id) { // Signed commands.
fetched_commands.push_back(command_with_callback.command_proto); signed_commands->push_back(
command_with_callback.signed_command_proto.value());
} else if (!last_command_id ||
command_with_callback.command_proto.command_id() >
*last_command_id) {
// Old style, unsigned commands.
fetched_commands->push_back(command_with_callback.command_proto);
// Simulate the age of commands calculation on the server side. // Simulate the age of commands calculation on the server side.
fetched_commands.back().set_age_of_command( fetched_commands->back().set_age_of_command(
(clock_->NowTicks() - command_with_callback.issued_time) (clock_->NowTicks() - command_with_callback.issued_time)
.InMilliseconds()); .InMilliseconds());
} }
} }
// Push delayed commands into the main queue. // Push delayed commands into the main queue.
commands_.insert(commands_.end(), commands_issued_after_next_fetch_.begin(), commands_.insert(
commands_issued_after_next_fetch_.end()); commands_.end(),
std::make_move_iterator(commands_issued_after_next_fetch_.begin()),
std::make_move_iterator(commands_issued_after_next_fetch_.end()));
commands_issued_after_next_fetch_.clear(); commands_issued_after_next_fetch_.clear();
return fetched_commands;
} }
void TestingRemoteCommandsServer::SetClock(const base::TickClock* clock) { void TestingRemoteCommandsServer::SetClock(const base::TickClock* clock) {
...@@ -145,10 +227,10 @@ size_t TestingRemoteCommandsServer::NumberOfCommandsPendingResult() const { ...@@ -145,10 +227,10 @@ size_t TestingRemoteCommandsServer::NumberOfCommandsPendingResult() const {
} }
void TestingRemoteCommandsServer::ReportJobResult( void TestingRemoteCommandsServer::ReportJobResult(
const ResultReportedCallback& reported_callback, ResultReportedCallback reported_callback,
const em::RemoteCommandResult& job_result) const { const em::RemoteCommandResult& job_result) const {
DCHECK(thread_checker_.CalledOnValidThread()); DCHECK(thread_checker_.CalledOnValidThread());
reported_callback.Run(job_result); std::move(reported_callback).Run(job_result);
} }
} // namespace policy } // namespace policy
...@@ -27,6 +27,10 @@ class SingleThreadTaskRunner; ...@@ -27,6 +27,10 @@ class SingleThreadTaskRunner;
namespace policy { namespace policy {
// Callback called when a command's result is reported back to the server.
using ResultReportedCallback =
base::OnceCallback<void(const enterprise_management::RemoteCommandResult&)>;
// This class implements server-side logic for remote commands service tests. It // This class implements server-side logic for remote commands service tests. It
// acts just like a queue, and there are mainly two exposed methods for this // acts just like a queue, and there are mainly two exposed methods for this
// purpose. Test authors are expected to call IssueCommand() to add commands to // purpose. Test authors are expected to call IssueCommand() to add commands to
...@@ -45,14 +49,9 @@ class TestingRemoteCommandsServer { ...@@ -45,14 +49,9 @@ class TestingRemoteCommandsServer {
TestingRemoteCommandsServer(); TestingRemoteCommandsServer();
virtual ~TestingRemoteCommandsServer(); virtual ~TestingRemoteCommandsServer();
using RemoteCommands = std::vector<enterprise_management::RemoteCommand>;
using RemoteCommandResults = using RemoteCommandResults =
std::vector<enterprise_management::RemoteCommandResult>; std::vector<enterprise_management::RemoteCommandResult>;
// Callback called when a command's result is reported back to the server.
using ResultReportedCallback =
base::Callback<void(const enterprise_management::RemoteCommandResult&)>;
// Create and add a command with |type| as command type and |payload| as // Create and add a command with |type| as command type and |payload| as
// command payload if it's not empty. |clock_| will be used to get the // command payload if it's not empty. |clock_| will be used to get the
// command issue time. |reported_callback| will be called from the same // command issue time. |reported_callback| will be called from the same
...@@ -63,8 +62,12 @@ class TestingRemoteCommandsServer { ...@@ -63,8 +62,12 @@ class TestingRemoteCommandsServer {
// the server and |reported_callback| itself will be called at that time. // the server and |reported_callback| itself will be called at that time.
void IssueCommand(enterprise_management::RemoteCommand_Type type, void IssueCommand(enterprise_management::RemoteCommand_Type type,
const std::string& payload, const std::string& payload,
const ResultReportedCallback& reported_callback, ResultReportedCallback reported_callback,
bool skip_next_fetch); bool skip_next_fetch);
void IssueSignedCommand(ResultReportedCallback reported_callback,
enterprise_management::RemoteCommand* command_in,
enterprise_management::PolicyData* policy_data_in,
enterprise_management::SignedData* signed_data_in);
// Fetch commands, acknowledging all commands up to and including // Fetch commands, acknowledging all commands up to and including
// |last_command_id|, and provide |previous_job_results| as results for // |last_command_id|, and provide |previous_job_results| as results for
...@@ -74,9 +77,11 @@ class TestingRemoteCommandsServer { ...@@ -74,9 +77,11 @@ class TestingRemoteCommandsServer {
// and client for remote command fetching. // and client for remote command fetching.
// Unlike every other methods in the class, this method can be called from // Unlike every other methods in the class, this method can be called from
// any thread. // any thread.
RemoteCommands FetchCommands( void FetchCommands(
std::unique_ptr<RemoteCommandJob::UniqueIDType> last_command_id, std::unique_ptr<RemoteCommandJob::UniqueIDType> last_command_id,
const RemoteCommandResults& previous_job_results); const RemoteCommandResults& previous_job_results,
std::vector<enterprise_management::RemoteCommand>* fetched_commands,
std::vector<enterprise_management::SignedData>* signed_commands);
// Set alternative clock for obtaining the command issue time. The default // Set alternative clock for obtaining the command issue time. The default
// clock uses the system clock. // clock uses the system clock.
...@@ -90,7 +95,7 @@ class TestingRemoteCommandsServer { ...@@ -90,7 +95,7 @@ class TestingRemoteCommandsServer {
struct RemoteCommandWithCallback; struct RemoteCommandWithCallback;
void ReportJobResult( void ReportJobResult(
const ResultReportedCallback& reported_callback, ResultReportedCallback reported_callback,
const enterprise_management::RemoteCommandResult& job_result) const; const enterprise_management::RemoteCommandResult& job_result) const;
// The main command queue. // The main command queue.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment