Commit 59afe0e2 authored by Zach Trudo's avatar Zach Trudo Committed by Commit Bot

Build CloudPolicyClient asynchronously

CloudPolicyClient must be built on the main UI thread, this means that
it needs to be built asynchronously.

Bug: chromium:1078512
Change-Id: I7aca6e9085091a8ac4d2c6209e10260f1295c450
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2368318
Commit-Queue: Zach Trudo <zatrudo@google.com>
Reviewed-by: default avatarLeonid Baraz <lbaraz@chromium.org>
Cr-Commit-Position: refs/heads/master@{#802846}
parent 9bd5253b
...@@ -37,43 +37,13 @@ namespace reporting { ...@@ -37,43 +37,13 @@ namespace reporting {
// } // }
class ReportingClient { class ReportingClient {
public: public:
// Uploader is passed to Storage in order to upload messages using the
// UploadClient.
class Uploader : public Storage::UploaderInterface {
public:
using UploadCallback = base::OnceCallback<Status(
std::unique_ptr<std::vector<EncryptedRecord>>)>;
static StatusOr<std::unique_ptr<Uploader>> Create(
UploadCallback upload_callback);
~Uploader() override;
Uploader(const Uploader& other) = delete;
Uploader& operator=(const Uploader& other) = delete;
void ProcessRecord(StatusOr<EncryptedRecord> data,
base::OnceCallback<void(bool)> processed_cb) override;
void Completed(Status final_status) override;
private:
explicit Uploader(UploadCallback upload_callback_);
void RunUpload();
UploadCallback upload_callback_;
bool completed_;
std::unique_ptr<std::vector<EncryptedRecord>> encrypted_records_;
scoped_refptr<base::SequencedTaskRunner> sequenced_task_runner_;
};
struct Configuration { struct Configuration {
Configuration(); Configuration();
~Configuration(); ~Configuration();
scoped_refptr<StorageModule> storage_; std::unique_ptr<policy::CloudPolicyClient> cloud_policy_client;
scoped_refptr<EncryptionModule> encryption_; scoped_refptr<StorageModule> storage;
scoped_refptr<EncryptionModule> encryption;
}; };
using CreateReportQueueResponse = StatusOr<std::unique_ptr<ReportQueue>>; using CreateReportQueueResponse = StatusOr<std::unique_ptr<ReportQueue>>;
...@@ -82,8 +52,11 @@ class ReportingClient { ...@@ -82,8 +52,11 @@ class ReportingClient {
base::OnceCallback<void(CreateReportQueueResponse)>; base::OnceCallback<void(CreateReportQueueResponse)>;
using UpdateConfigurationCallback = using UpdateConfigurationCallback =
base::OnceCallback<void(const Configuration&, base::OnceCallback<void(std::unique_ptr<Configuration>,
base::OnceCallback<void(Status)>)>; base::OnceCallback<void(Status)>)>;
using BuildCloudPolicyClientCallback = base::OnceCallback<void(
base::OnceCallback<void(
StatusOr<std::unique_ptr<policy::CloudPolicyClient>>)>)>;
using InitCompleteCallback = base::OnceCallback<void(Status)>; using InitCompleteCallback = base::OnceCallback<void(Status)>;
...@@ -137,6 +110,7 @@ class ReportingClient { ...@@ -137,6 +110,7 @@ class ReportingClient {
class InitializingContext : public TaskRunnerContext<Status> { class InitializingContext : public TaskRunnerContext<Status> {
public: public:
InitializingContext( InitializingContext(
BuildCloudPolicyClientCallback build_client_cb,
Storage::StartUploadCb start_upload_cb, Storage::StartUploadCb start_upload_cb,
UpdateConfigurationCallback update_config_cb, UpdateConfigurationCallback update_config_cb,
InitCompleteCallback init_complete_cb, InitCompleteCallback init_complete_cb,
...@@ -152,6 +126,10 @@ class ReportingClient { ...@@ -152,6 +126,10 @@ class ReportingClient {
StatusOr<InitializationStateTracker::ReleaseLeaderCallback> StatusOr<InitializationStateTracker::ReleaseLeaderCallback>
promo_result); promo_result);
void ConfigureCloudPolicyClient();
void OnCloudPolicyClientConfigured(
StatusOr<std::unique_ptr<policy::CloudPolicyClient>> client_result);
// ConfigureStorageModule will build a StorageModule and add it to the // ConfigureStorageModule will build a StorageModule and add it to the
// |client_config_|. // |client_config_|.
void ConfigureStorageModule(); void ConfigureStorageModule();
...@@ -169,12 +147,13 @@ class ReportingClient { ...@@ -169,12 +147,13 @@ class ReportingClient {
// Complete calls response with |client_config_| // Complete calls response with |client_config_|
void Complete(Status status); void Complete(Status status);
BuildCloudPolicyClientCallback build_client_cb_;
Storage::StartUploadCb start_upload_cb_; Storage::StartUploadCb start_upload_cb_;
UpdateConfigurationCallback update_config_cb_; UpdateConfigurationCallback update_config_cb_;
scoped_refptr<InitializationStateTracker> init_state_tracker_; scoped_refptr<InitializationStateTracker> init_state_tracker_;
InitializationStateTracker::ReleaseLeaderCallback release_leader_cb_; InitializationStateTracker::ReleaseLeaderCallback release_leader_cb_;
Configuration client_config_; std::unique_ptr<Configuration> client_config_;
}; };
~ReportingClient(); ~ReportingClient();
...@@ -194,11 +173,46 @@ class ReportingClient { ...@@ -194,11 +173,46 @@ class ReportingClient {
std::unique_ptr<ReportQueueConfiguration> config, std::unique_ptr<ReportQueueConfiguration> config,
CreateReportQueueCallback create_cb); CreateReportQueueCallback create_cb);
// Sets up the ReportingClient for testing with a specified CloudPolicyClient.
static void Setup_test(std::unique_ptr<policy::CloudPolicyClient> client);
// Resets the singleton object. Should only be used in tests when the current // Resets the singleton object. Should only be used in tests when the current
// TaskEnvironment will be invalidated. // TaskEnvironment will be invalidated.
static void Reset_test(); static void Reset_test();
private: private:
// Uploader is passed to Storage in order to upload messages using the
// UploadClient.
class Uploader : public Storage::UploaderInterface {
public:
using UploadCallback = base::OnceCallback<Status(
std::unique_ptr<std::vector<EncryptedRecord>>)>;
static StatusOr<std::unique_ptr<Uploader>> Create(
UploadCallback upload_callback);
~Uploader() override;
Uploader(const Uploader& other) = delete;
Uploader& operator=(const Uploader& other) = delete;
void ProcessRecord(StatusOr<EncryptedRecord> data,
base::OnceCallback<void(bool)> processed_cb) override;
void Completed(Status final_status) override;
private:
explicit Uploader(UploadCallback upload_callback_);
static void RunUpload(
UploadCallback upload_callback,
std::unique_ptr<std::vector<EncryptedRecord>> encrypted_records);
UploadCallback upload_callback_;
bool completed_{false};
std::unique_ptr<std::vector<EncryptedRecord>> encrypted_records_;
scoped_refptr<base::SequencedTaskRunner> sequenced_task_runner_;
};
// Holds the creation request for a ReportQueue. // Holds the creation request for a ReportQueue.
class CreateReportQueueRequest { class CreateReportQueueRequest {
public: public:
...@@ -222,7 +236,7 @@ class ReportingClient { ...@@ -222,7 +236,7 @@ class ReportingClient {
void OnPushComplete(); void OnPushComplete();
void OnInitState(bool reporting_client_configured); void OnInitState(bool reporting_client_configured);
void OnConfigResult(const Configuration& config, void OnConfigResult(std::unique_ptr<Configuration> config,
base::OnceCallback<void(Status)> continue_init_cb); base::OnceCallback<void(Status)> continue_init_cb);
void OnInitializationComplete(Status init_status); void OnInitializationComplete(Status init_status);
...@@ -236,13 +250,11 @@ class ReportingClient { ...@@ -236,13 +250,11 @@ class ReportingClient {
// Queue for storing creation requests while the ReportingClient is // Queue for storing creation requests while the ReportingClient is
// initializing. // initializing.
scoped_refptr<SharedQueue<CreateReportQueueRequest>> create_request_queue_; scoped_refptr<SharedQueue<CreateReportQueueRequest>> create_request_queue_;
scoped_refptr<InitializationStateTracker> init_state_tracker_; scoped_refptr<InitializationStateTracker> init_state_tracker_;
BuildCloudPolicyClientCallback build_cloud_policy_client_cb_;
scoped_refptr<StorageModule> storage_;
scoped_refptr<EncryptionModule> encryption_;
std::unique_ptr<UploadClient> upload_client_; std::unique_ptr<UploadClient> upload_client_;
Configuration config_; std::unique_ptr<Configuration> config_;
}; };
} // namespace reporting } // namespace reporting
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include "base/memory/singleton.h" #include "base/memory/singleton.h"
#include "base/synchronization/waitable_event.h" #include "base/synchronization/waitable_event.h"
#include "base/task/post_task.h"
#include "base/test/task_environment.h" #include "base/test/task_environment.h"
#include "chrome/browser/policy/messaging_layer/public/report_queue.h" #include "chrome/browser/policy/messaging_layer/public/report_queue.h"
#include "chrome/browser/policy/messaging_layer/public/report_queue_configuration.h" #include "chrome/browser/policy/messaging_layer/public/report_queue_configuration.h"
...@@ -13,9 +14,20 @@ ...@@ -13,9 +14,20 @@
#include "chrome/browser/policy/messaging_layer/util/status_macros.h" #include "chrome/browser/policy/messaging_layer/util/status_macros.h"
#include "chrome/browser/policy/messaging_layer/util/statusor.h" #include "chrome/browser/policy/messaging_layer/util/statusor.h"
#include "components/policy/core/common/cloud/dm_token.h" #include "components/policy/core/common/cloud/dm_token.h"
#include "components/policy/core/common/cloud/mock_cloud_policy_client.h"
#include "components/policy/proto/record_constants.pb.h" #include "components/policy/proto/record_constants.pb.h"
#include "content/public/browser/browser_task_traits.h"
#include "content/public/browser/browser_thread.h"
#include "content/public/test/browser_task_environment.h"
#include "testing/gtest/include/gtest/gtest.h" #include "testing/gtest/include/gtest/gtest.h"
// #if defined(OS_CHROMEOS)
// #include "chrome/browser/chromeos/settings/device_settings_service.h"
// #include "components/policy/proto/chrome_device_policy.pb.h"
// #else
// #include "chrome/browser/policy/chrome_browser_policy_connector.h"
// #endif
namespace reporting { namespace reporting {
namespace { namespace {
...@@ -25,29 +37,33 @@ using reporting::Priority; ...@@ -25,29 +37,33 @@ using reporting::Priority;
class TestCallbackWaiter { class TestCallbackWaiter {
public: public:
TestCallbackWaiter() TestCallbackWaiter() : run_loop_(std::make_unique<base::RunLoop>()) {}
: completed_(base::WaitableEvent::ResetPolicy::MANUAL,
base::WaitableEvent::InitialState::NOT_SIGNALED) {}
virtual void Signal() { virtual void Signal() { run_loop_->Quit(); }
DCHECK(!completed_.IsSignaled());
completed_.Signal();
}
void Wait() { completed_.Wait(); } void Wait() { run_loop_->Run(); }
void Reset() { completed_.Reset(); } void Reset() {
run_loop_.reset();
run_loop_ = std::make_unique<base::RunLoop>();
}
protected: protected:
base::WaitableEvent completed_; std::unique_ptr<base::RunLoop> run_loop_;
}; };
class ReportingClientTest : public testing::Test { class ReportingClientTest : public testing::Test {
public: public:
void SetUp() override {
auto client = std::make_unique<policy::MockCloudPolicyClient>();
client->SetDMToken(
policy::DMToken::CreateValidTokenForTesting("FAKE_DM_TOKEN").value());
ReportingClient::Setup_test(std::move(client));
}
void TearDown() override { ReportingClient::Reset_test(); } void TearDown() override { ReportingClient::Reset_test(); }
protected: protected:
base::test::TaskEnvironment task_envrionment_{ content::BrowserTaskEnvironment task_envrionment_;
base::test::TaskEnvironment::TimeSource::MOCK_TIME};
const DMToken dm_token_ = DMToken::CreateValidTokenForTesting("TOKEN"); const DMToken dm_token_ = DMToken::CreateValidTokenForTesting("TOKEN");
const Destination destination_ = Destination::UPLOAD_EVENTS; const Destination destination_ = Destination::UPLOAD_EVENTS;
const Priority priority_ = Priority::IMMEDIATE; const Priority priority_ = Priority::IMMEDIATE;
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include "chrome/browser/policy/messaging_layer/util/statusor.h" #include "chrome/browser/policy/messaging_layer/util/statusor.h"
#include "components/policy/core/common/cloud/dm_token.h" #include "components/policy/core/common/cloud/dm_token.h"
#include "components/policy/proto/record_constants.pb.h" #include "components/policy/proto/record_constants.pb.h"
#include "content/public/test/browser_task_environment.h"
#include "testing/gmock/include/gmock/gmock.h" #include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h" #include "testing/gtest/include/gtest/gtest.h"
...@@ -144,8 +145,7 @@ class ReportQueueTest : public testing::Test { ...@@ -144,8 +145,7 @@ class ReportQueueTest : public testing::Test {
MOCK_METHOD(Status, MockedPolicyCheck, (), ()); MOCK_METHOD(Status, MockedPolicyCheck, (), ());
base::test::TaskEnvironment task_envrionment_{ content::BrowserTaskEnvironment task_envrionment_;
base::test::TaskEnvironment::TimeSource::MOCK_TIME};
const Priority priority_; const Priority priority_;
......
...@@ -24,6 +24,8 @@ ...@@ -24,6 +24,8 @@
#include "components/policy/core/common/cloud/cloud_policy_client.h" #include "components/policy/core/common/cloud/cloud_policy_client.h"
#include "components/policy/proto/record.pb.h" #include "components/policy/proto/record.pb.h"
#include "components/policy/proto/record_constants.pb.h" #include "components/policy/proto/record_constants.pb.h"
#include "content/public/browser/browser_task_traits.h"
#include "content/public/browser/browser_thread.h"
namespace reporting { namespace reporting {
...@@ -111,7 +113,14 @@ void AppInstallReportUploader::OnPopResult(StatusOr<base::Value> pop_result) { ...@@ -111,7 +113,14 @@ void AppInstallReportUploader::OnPopResult(StatusOr<base::Value> pop_result) {
void AppInstallReportUploader::StartUpload(base::Value record) { void AppInstallReportUploader::StartUpload(base::Value record) {
ClientCallback cb = base::BindOnce( ClientCallback cb = base::BindOnce(
&AppInstallReportUploader::OnUploadComplete, base::Unretained(this)); &AppInstallReportUploader::OnUploadComplete, base::Unretained(this));
client_->UploadAppInstallReport(std::move(record), std::move(cb)); base::PostTask(FROM_HERE, {content::BrowserThread::UI},
base::BindOnce(
[](policy::CloudPolicyClient* client, base::Value record,
ClientCallback cb) {
client->UploadExtensionInstallReport(std::move(record),
std::move(cb));
},
client_, std::move(record), std::move(cb)));
} }
void AppInstallReportUploader::OnUploadComplete(bool success) { void AppInstallReportUploader::OnUploadComplete(bool success) {
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include "components/policy/core/common/cloud/mock_cloud_policy_client.h" #include "components/policy/core/common/cloud/mock_cloud_policy_client.h"
#include "components/policy/proto/record.pb.h" #include "components/policy/proto/record.pb.h"
#include "components/policy/proto/record_constants.pb.h" #include "components/policy/proto/record_constants.pb.h"
#include "content/public/test/browser_task_environment.h"
#include "testing/gtest/include/gtest/gtest.h" #include "testing/gtest/include/gtest/gtest.h"
namespace reporting { namespace reporting {
...@@ -51,19 +52,14 @@ MATCHER_P(MatchValue, expected, "matches base::Value") { ...@@ -51,19 +52,14 @@ MATCHER_P(MatchValue, expected, "matches base::Value") {
class TestCallbackWaiter { class TestCallbackWaiter {
public: public:
TestCallbackWaiter() TestCallbackWaiter() : run_loop_(std::make_unique<base::RunLoop>()) {}
: completed_(base::WaitableEvent::ResetPolicy::MANUAL,
base::WaitableEvent::InitialState::NOT_SIGNALED) {}
virtual void Signal() { virtual void Signal() { run_loop_->Quit(); }
DCHECK(!completed_.IsSignaled());
completed_.Signal();
}
void Wait() { completed_.Wait(); } void Wait() { run_loop_->Run(); }
protected: protected:
base::WaitableEvent completed_; std::unique_ptr<base::RunLoop> run_loop_;
}; };
class AppInstallReportHandlerTest : public testing::Test { class AppInstallReportHandlerTest : public testing::Test {
...@@ -76,8 +72,7 @@ class AppInstallReportHandlerTest : public testing::Test { ...@@ -76,8 +72,7 @@ class AppInstallReportHandlerTest : public testing::Test {
} }
protected: protected:
base::test::TaskEnvironment task_envrionment_{ content::BrowserTaskEnvironment task_envrionment_;
base::test::TaskEnvironment::TimeSource::MOCK_TIME};
policy::MockCloudPolicyClient client_; policy::MockCloudPolicyClient client_;
}; };
...@@ -104,7 +99,7 @@ TEST_F(AppInstallReportHandlerTest, AcceptsValidRecord) { ...@@ -104,7 +99,7 @@ TEST_F(AppInstallReportHandlerTest, AcceptsValidRecord) {
TestCallbackWaiter waiter; TestCallbackWaiter waiter;
TestRecord test_record; TestRecord test_record;
EXPECT_CALL(client_, EXPECT_CALL(client_,
UploadAppInstallReport_(MatchValue(test_record.data()), _)) UploadExtensionInstallReport_(MatchValue(test_record.data()), _))
.WillOnce(WithArgs<1>( .WillOnce(WithArgs<1>(
Invoke([&waiter](AppInstallReportHandler::ClientCallback& callback) { Invoke([&waiter](AppInstallReportHandler::ClientCallback& callback) {
std::move(callback).Run(true); std::move(callback).Run(true);
...@@ -118,7 +113,7 @@ TEST_F(AppInstallReportHandlerTest, AcceptsValidRecord) { ...@@ -118,7 +113,7 @@ TEST_F(AppInstallReportHandlerTest, AcceptsValidRecord) {
} }
TEST_F(AppInstallReportHandlerTest, DeniesInvalidDestination) { TEST_F(AppInstallReportHandlerTest, DeniesInvalidDestination) {
EXPECT_CALL(client_, UploadAppInstallReport_(_, _)).Times(0); EXPECT_CALL(client_, UploadExtensionInstallReport_(_, _)).Times(0);
AppInstallReportHandler handler(&client_); AppInstallReportHandler handler(&client_);
TestRecord test_record; TestRecord test_record;
...@@ -130,7 +125,7 @@ TEST_F(AppInstallReportHandlerTest, DeniesInvalidDestination) { ...@@ -130,7 +125,7 @@ TEST_F(AppInstallReportHandlerTest, DeniesInvalidDestination) {
} }
TEST_F(AppInstallReportHandlerTest, DeniesInvalidData) { TEST_F(AppInstallReportHandlerTest, DeniesInvalidData) {
EXPECT_CALL(client_, UploadAppInstallReport_(_, _)).Times(0); EXPECT_CALL(client_, UploadExtensionInstallReport_(_, _)).Times(0);
AppInstallReportHandler handler(&client_); AppInstallReportHandler handler(&client_);
TestRecord test_record; TestRecord test_record;
...@@ -145,7 +140,7 @@ TEST_F(AppInstallReportHandlerTest, ReportsUnsuccessfulCall) { ...@@ -145,7 +140,7 @@ TEST_F(AppInstallReportHandlerTest, ReportsUnsuccessfulCall) {
TestRecord test_record; TestRecord test_record;
EXPECT_CALL(client_, EXPECT_CALL(client_,
UploadAppInstallReport_(MatchValue(test_record.data()), _)) UploadExtensionInstallReport_(MatchValue(test_record.data()), _))
.WillOnce(WithArgs<1>( .WillOnce(WithArgs<1>(
Invoke([&waiter](AppInstallReportHandler::ClientCallback& callback) { Invoke([&waiter](AppInstallReportHandler::ClientCallback& callback) {
std::move(callback).Run(false); std::move(callback).Run(false);
...@@ -164,13 +159,10 @@ class TestCallbackWaiterWithCounter : public TestCallbackWaiter { ...@@ -164,13 +159,10 @@ class TestCallbackWaiterWithCounter : public TestCallbackWaiter {
: counter_limit_(counter_limit) {} : counter_limit_(counter_limit) {}
void Signal() override { void Signal() override {
DCHECK(!completed_.IsSignaled()); DCHECK_GT(counter_limit_, 0);
const int new_counter = --counter_limit_; if (--counter_limit_ == 0) {
DCHECK_GE(new_counter, 0); run_loop_->Quit();
if (new_counter > 0) {
return;
} }
completed_.Signal();
} }
private: private:
...@@ -183,7 +175,7 @@ TEST_F(AppInstallReportHandlerTest, AcceptsMultipleValidRecords) { ...@@ -183,7 +175,7 @@ TEST_F(AppInstallReportHandlerTest, AcceptsMultipleValidRecords) {
TestRecord test_record; TestRecord test_record;
EXPECT_CALL(client_, EXPECT_CALL(client_,
UploadAppInstallReport_(MatchValue(test_record.data()), _)) UploadExtensionInstallReport_(MatchValue(test_record.data()), _))
.WillRepeatedly(WithArgs<1>( .WillRepeatedly(WithArgs<1>(
Invoke([&waiter](AppInstallReportHandler::ClientCallback& callback) { Invoke([&waiter](AppInstallReportHandler::ClientCallback& callback) {
std::move(callback).Run(true); std::move(callback).Run(true);
......
...@@ -20,6 +20,8 @@ ...@@ -20,6 +20,8 @@
#include "components/policy/core/common/cloud/user_cloud_policy_manager.h" #include "components/policy/core/common/cloud/user_cloud_policy_manager.h"
#include "components/policy/proto/record.pb.h" #include "components/policy/proto/record.pb.h"
#include "components/policy/proto/record_constants.pb.h" #include "components/policy/proto/record_constants.pb.h"
#include "content/public/browser/browser_task_traits.h"
#include "content/public/browser/browser_thread.h"
#if defined(OS_CHROMEOS) #if defined(OS_CHROMEOS)
#include "chrome/browser/chromeos/policy/user_cloud_policy_manager_chromeos.h" #include "chrome/browser/chromeos/policy/user_cloud_policy_manager_chromeos.h"
...@@ -218,14 +220,20 @@ DmServerUploadService::DmServerUploadService( ...@@ -218,14 +220,20 @@ DmServerUploadService::DmServerUploadService(
upload_cb_(upload_cb), upload_cb_(upload_cb),
sequenced_task_runner_(base::ThreadPool::CreateSequencedTaskRunner({})) {} sequenced_task_runner_(base::ThreadPool::CreateSequencedTaskRunner({})) {}
DmServerUploadService::~DmServerUploadService() = default; DmServerUploadService::~DmServerUploadService() {
if (client_) {
base::PostTask(
FROM_HERE, {content::BrowserThread::UI},
base::BindOnce(
[](std::unique_ptr<policy::CloudPolicyClient> cloud_policy_client) {
cloud_policy_client.reset();
},
std::move(client_)));
}
}
Status DmServerUploadService::EnqueueUpload( Status DmServerUploadService::EnqueueUpload(
std::unique_ptr<std::vector<EncryptedRecord>> records) { std::unique_ptr<std::vector<EncryptedRecord>> records) {
if (!GetClient()->is_registered()) {
return Status(error::UNAVAILABLE, "DmServer is currently unavailable.");
}
Start<DmServerUploader>( Start<DmServerUploader>(
std::move(records), &record_handlers_, std::move(records), &record_handlers_,
base::BindOnce(&DmServerUploadService::UploadCompletion, base::BindOnce(&DmServerUploadService::UploadCompletion,
......
...@@ -3,19 +3,20 @@ ...@@ -3,19 +3,20 @@
// found in the LICENSE file. // found in the LICENSE file.
#include "chrome/browser/policy/messaging_layer/upload/upload_client.h" #include "chrome/browser/policy/messaging_layer/upload/upload_client.h"
#include "base/bind.h"
#include "base/files/file_path.h"
#include "base/json/json_writer.h" #include "base/json/json_writer.h"
#include "base/test/task_environment.h" #include "base/test/task_environment.h"
#include "base/test/test_mock_time_task_runner.h"
#include "base/values.h" #include "base/values.h"
#include "chrome/browser/policy/messaging_layer/upload/app_install_report_handler.h" #include "chrome/browser/policy/messaging_layer/upload/app_install_report_handler.h"
#include "components/account_id/account_id.h"
#include "components/policy/core/common/cloud/dm_token.h" #include "components/policy/core/common/cloud/dm_token.h"
#include "components/policy/core/common/cloud/mock_cloud_policy_client.h" #include "components/policy/core/common/cloud/mock_cloud_policy_client.h"
#include "components/policy/proto/record.pb.h" #include "components/policy/proto/record.pb.h"
#include "components/policy/proto/record_constants.pb.h" #include "components/policy/proto/record_constants.pb.h"
#include "content/public/test/browser_task_environment.h"
#include "base/bind.h"
#include "base/files/file_path.h"
#include "base/test/test_mock_time_task_runner.h"
#include "components/account_id/account_id.h"
#include "services/network/test/test_network_connection_tracker.h" #include "services/network/test/test_network_connection_tracker.h"
namespace reporting { namespace reporting {
...@@ -29,19 +30,14 @@ using testing::WithArgs; ...@@ -29,19 +30,14 @@ using testing::WithArgs;
class TestCallbackWaiter { class TestCallbackWaiter {
public: public:
TestCallbackWaiter() TestCallbackWaiter() : run_loop_(std::make_unique<base::RunLoop>()) {}
: completed_(base::WaitableEvent::ResetPolicy::MANUAL,
base::WaitableEvent::InitialState::NOT_SIGNALED) {}
virtual void Signal() { virtual void Signal() { run_loop_->Quit(); }
DCHECK(!completed_.IsSignaled());
completed_.Signal();
}
void Wait() { completed_.Wait(); } void Wait() { run_loop_->Run(); }
protected: protected:
base::WaitableEvent completed_; std::unique_ptr<base::RunLoop> run_loop_;
}; };
class TestCallbackWaiterWithCounter : public TestCallbackWaiter { class TestCallbackWaiterWithCounter : public TestCallbackWaiter {
...@@ -50,10 +46,9 @@ class TestCallbackWaiterWithCounter : public TestCallbackWaiter { ...@@ -50,10 +46,9 @@ class TestCallbackWaiterWithCounter : public TestCallbackWaiter {
: counter_limit_(counter_limit) {} : counter_limit_(counter_limit) {}
void Signal() override { void Signal() override {
DCHECK(!completed_.IsSignaled());
DCHECK_GT(counter_limit_, 0); DCHECK_GT(counter_limit_, 0);
if (--counter_limit_ == 0) { if (--counter_limit_ == 0) {
completed_.Signal(); run_loop_->Quit();
} }
} }
...@@ -62,8 +57,7 @@ class TestCallbackWaiterWithCounter : public TestCallbackWaiter { ...@@ -62,8 +57,7 @@ class TestCallbackWaiterWithCounter : public TestCallbackWaiter {
}; };
TEST(UploadClientTest, CreateUploadClient) { TEST(UploadClientTest, CreateUploadClient) {
base::test::TaskEnvironment task_envrionment{ content::BrowserTaskEnvironment task_envrionment_;
base::test::TaskEnvironment::TimeSource::MOCK_TIME};
const int kExpectedCallTimes = 10; const int kExpectedCallTimes = 10;
const uint64_t kGenerationId = 1234; const uint64_t kGenerationId = 1234;
...@@ -74,7 +68,7 @@ TEST(UploadClientTest, CreateUploadClient) { ...@@ -74,7 +68,7 @@ TEST(UploadClientTest, CreateUploadClient) {
client->SetDMToken( client->SetDMToken(
policy::DMToken::CreateValidTokenForTesting("FAKE_DM_TOKEN").value()); policy::DMToken::CreateValidTokenForTesting("FAKE_DM_TOKEN").value());
EXPECT_CALL(*client, UploadAppInstallReport_(_, _)) EXPECT_CALL(*client, UploadExtensionInstallReport_(_, _))
.WillRepeatedly(WithArgs<1>( .WillRepeatedly(WithArgs<1>(
Invoke([&waiter](AppInstallReportHandler::ClientCallback& callback) { Invoke([&waiter](AppInstallReportHandler::ClientCallback& callback) {
std::move(callback).Run(true); std::move(callback).Run(true);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment