Commit 2a3cc46b authored by Sorin Jianu's avatar Sorin Jianu Committed by Commit Bot

Make CrxDownloader a ref-counted thread safe type.

This avoid retention issues and allows instances of CrxDownloader to
be shared between sequences when interfacing with RPC modules such as
BITS/DO for Windows.

There are some mechanical changes such as removing macros.h and
replacing thread checkers with sequence checkers (but not for the
BITS downloader, which has thread affinity due to COM RPC).

Change-Id: I62e159d4ae8999ec70da94247f91f9cc26a5d278
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2293461Reviewed-by: default avatarJoshua Pawlicki <waffles@chromium.org>
Commit-Queue: Sorin Jianu <sorin@chromium.org>
Cr-Commit-Position: refs/heads/master@{#787857}
parent 365b0274
...@@ -20,7 +20,6 @@ ...@@ -20,7 +20,6 @@
#include "base/files/file_util.h" #include "base/files/file_util.h"
#include "base/location.h" #include "base/location.h"
#include "base/logging.h" #include "base/logging.h"
#include "base/macros.h"
#include "base/metrics/histogram_macros.h" #include "base/metrics/histogram_macros.h"
#include "base/sequenced_task_runner.h" #include "base/sequenced_task_runner.h"
#include "base/strings/string_piece.h" #include "base/strings/string_piece.h"
...@@ -396,7 +395,7 @@ void CleanupJob(const ComPtr<IBackgroundCopyJob>& job) { ...@@ -396,7 +395,7 @@ void CleanupJob(const ComPtr<IBackgroundCopyJob>& job) {
} // namespace } // namespace
BackgroundDownloader::BackgroundDownloader( BackgroundDownloader::BackgroundDownloader(
std::unique_ptr<CrxDownloader> successor) scoped_refptr<CrxDownloader> successor)
: CrxDownloader(std::move(successor)), : CrxDownloader(std::move(successor)),
com_task_runner_(base::ThreadPool::CreateCOMSTATaskRunner( com_task_runner_(base::ThreadPool::CreateCOMSTATaskRunner(
kTaskTraitsBackgroundDownloader)), kTaskTraitsBackgroundDownloader)),
...@@ -417,15 +416,14 @@ void BackgroundDownloader::StartTimer() { ...@@ -417,15 +416,14 @@ void BackgroundDownloader::StartTimer() {
void BackgroundDownloader::OnTimer() { void BackgroundDownloader::OnTimer() {
DCHECK(thread_checker_.CalledOnValidThread()); DCHECK(thread_checker_.CalledOnValidThread());
com_task_runner_->PostTask( com_task_runner_->PostTask(
FROM_HERE, base::BindOnce(&BackgroundDownloader::OnDownloading, FROM_HERE, base::BindOnce(&BackgroundDownloader::OnDownloading, this));
base::Unretained(this)));
} }
void BackgroundDownloader::DoStartDownload(const GURL& url) { void BackgroundDownloader::DoStartDownload(const GURL& url) {
DCHECK(thread_checker_.CalledOnValidThread()); DCHECK(thread_checker_.CalledOnValidThread());
com_task_runner_->PostTask( com_task_runner_->PostTask(
FROM_HERE, base::BindOnce(&BackgroundDownloader::BeginDownload, FROM_HERE,
base::Unretained(this), url)); base::BindOnce(&BackgroundDownloader::BeginDownload, this, url));
} }
// Called one time when this class is asked to do a download. // Called one time when this class is asked to do a download.
...@@ -444,9 +442,8 @@ void BackgroundDownloader::BeginDownload(const GURL& url) { ...@@ -444,9 +442,8 @@ void BackgroundDownloader::BeginDownload(const GURL& url) {
VLOG(1) << "Starting BITS download for: " << url.spec(); VLOG(1) << "Starting BITS download for: " << url.spec();
ResetInterfacePointers(); ResetInterfacePointers();
main_task_runner()->PostTask(FROM_HERE, main_task_runner()->PostTask(
base::BindOnce(&BackgroundDownloader::StartTimer, FROM_HERE, base::BindOnce(&BackgroundDownloader::StartTimer, this));
base::Unretained(this)));
} }
// Creates or opens an existing BITS job to download the |url|, and handles // Creates or opens an existing BITS job to download the |url|, and handles
...@@ -535,9 +532,8 @@ void BackgroundDownloader::OnDownloading() { ...@@ -535,9 +532,8 @@ void BackgroundDownloader::OnDownloading() {
return; return;
ResetInterfacePointers(); ResetInterfacePointers();
main_task_runner()->PostTask(FROM_HERE, main_task_runner()->PostTask(
base::BindOnce(&BackgroundDownloader::StartTimer, FROM_HERE, base::BindOnce(&BackgroundDownloader::StartTimer, this));
base::Unretained(this)));
} }
// Completes the BITS download, picks up the file path of the response, and // Completes the BITS download, picks up the file path of the response, and
...@@ -582,9 +578,8 @@ void BackgroundDownloader::EndDownload(HRESULT error) { ...@@ -582,9 +578,8 @@ void BackgroundDownloader::EndDownload(HRESULT error) {
if (!result.error) if (!result.error)
result.response = response_; result.response = response_;
main_task_runner()->PostTask( main_task_runner()->PostTask(
FROM_HERE, base::BindOnce(&BackgroundDownloader::OnDownloadComplete, FROM_HERE, base::BindOnce(&BackgroundDownloader::OnDownloadComplete, this,
base::Unretained(this), is_handled, result, is_handled, result, download_metrics));
download_metrics));
// Once the task is posted to the the main thread, this object may be deleted // Once the task is posted to the the main thread, this object may be deleted
// by its owner. It is not safe to access members of this object on this task // by its owner. It is not safe to access members of this object on this task
...@@ -666,9 +661,8 @@ bool BackgroundDownloader::OnStateTransferring() { ...@@ -666,9 +661,8 @@ bool BackgroundDownloader::OnStateTransferring() {
GetJobByteCount(job_, &downloaded_bytes, &total_bytes); GetJobByteCount(job_, &downloaded_bytes, &total_bytes);
main_task_runner()->PostTask( main_task_runner()->PostTask(
FROM_HERE, FROM_HERE, base::BindOnce(&BackgroundDownloader::OnDownloadProgress, this,
base::BindOnce(&BackgroundDownloader::OnDownloadProgress, downloaded_bytes, total_bytes));
base::Unretained(this), downloaded_bytes, total_bytes));
return false; return false;
} }
......
...@@ -11,7 +11,6 @@ ...@@ -11,7 +11,6 @@
#include <memory> #include <memory>
#include "base/macros.h"
#include "base/memory/ref_counted.h" #include "base/memory/ref_counted.h"
#include "base/strings/string16.h" #include "base/strings/string16.h"
#include "base/threading/thread_checker.h" #include "base/threading/thread_checker.h"
...@@ -39,11 +38,11 @@ namespace update_client { ...@@ -39,11 +38,11 @@ namespace update_client {
// the BITS service. // the BITS service.
class BackgroundDownloader : public CrxDownloader { class BackgroundDownloader : public CrxDownloader {
public: public:
explicit BackgroundDownloader(std::unique_ptr<CrxDownloader> successor); explicit BackgroundDownloader(scoped_refptr<CrxDownloader> successor);
~BackgroundDownloader() override;
private: private:
// Overrides for CrxDownloader. // Overrides for CrxDownloader.
~BackgroundDownloader() override;
void DoStartDownload(const GURL& url) override; void DoStartDownload(const GURL& url) override;
// Called asynchronously on the |com_task_runner_| at different stages during // Called asynchronously on the |com_task_runner_| at different stages during
...@@ -148,8 +147,6 @@ class BackgroundDownloader : public CrxDownloader { ...@@ -148,8 +147,6 @@ class BackgroundDownloader : public CrxDownloader {
// Contains the path of the downloaded file if the download was successful. // Contains the path of the downloaded file if the download was successful.
base::FilePath response_; base::FilePath response_;
DISALLOW_COPY_AND_ASSIGN(BackgroundDownloader);
}; };
} // namespace update_client } // namespace update_client
......
...@@ -759,7 +759,7 @@ void Component::StateDownloadingDiff::DownloadComplete( ...@@ -759,7 +759,7 @@ void Component::StateDownloadingDiff::DownloadComplete(
for (const auto& download_metrics : crx_downloader_->download_metrics()) for (const auto& download_metrics : crx_downloader_->download_metrics())
component.AppendEvent(component.MakeEventDownloadMetrics(download_metrics)); component.AppendEvent(component.MakeEventDownloadMetrics(download_metrics));
crx_downloader_.reset(); crx_downloader_ = nullptr;
if (download_result.error) { if (download_result.error) {
DCHECK(download_result.response.empty()); DCHECK(download_result.response.empty());
...@@ -832,7 +832,7 @@ void Component::StateDownloading::DownloadComplete( ...@@ -832,7 +832,7 @@ void Component::StateDownloading::DownloadComplete(
for (const auto& download_metrics : crx_downloader_->download_metrics()) for (const auto& download_metrics : crx_downloader_->download_metrics())
component.AppendEvent(component.MakeEventDownloadMetrics(download_metrics)); component.AppendEvent(component.MakeEventDownloadMetrics(download_metrics));
crx_downloader_.reset(); crx_downloader_ = nullptr;
if (download_result.error) { if (download_result.error) {
DCHECK(download_result.response.empty()); DCHECK(download_result.response.empty());
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
#include "base/callback.h" #include "base/callback.h"
#include "base/files/file_path.h" #include "base/files/file_path.h"
#include "base/gtest_prod_util.h" #include "base/gtest_prod_util.h"
#include "base/macros.h"
#include "base/memory/ref_counted.h" #include "base/memory/ref_counted.h"
#include "base/optional.h" #include "base/optional.h"
#include "base/sequence_checker.h" #include "base/sequence_checker.h"
...@@ -45,6 +44,8 @@ class Component { ...@@ -45,6 +44,8 @@ class Component {
using CallbackHandleComplete = base::OnceCallback<void()>; using CallbackHandleComplete = base::OnceCallback<void()>;
Component(const UpdateContext& update_context, const std::string& id); Component(const UpdateContext& update_context, const std::string& id);
Component(const Component&) = delete;
Component& operator=(const Component&) = delete;
~Component(); ~Component();
// Handles the current state of the component and makes it transition // Handles the current state of the component and makes it transition
...@@ -183,18 +184,20 @@ class Component { ...@@ -183,18 +184,20 @@ class Component {
class StateNew : public State { class StateNew : public State {
public: public:
explicit StateNew(Component* component); explicit StateNew(Component* component);
StateNew(const StateNew&) = delete;
StateNew& operator=(const StateNew&) = delete;
~StateNew() override; ~StateNew() override;
private: private:
// State overrides. // State overrides.
void DoHandle() override; void DoHandle() override;
DISALLOW_COPY_AND_ASSIGN(StateNew);
}; };
class StateChecking : public State { class StateChecking : public State {
public: public:
explicit StateChecking(Component* component); explicit StateChecking(Component* component);
StateChecking(const StateChecking&) = delete;
StateChecking& operator=(const StateChecking&) = delete;
~StateChecking() override; ~StateChecking() override;
private: private:
...@@ -202,50 +205,50 @@ class Component { ...@@ -202,50 +205,50 @@ class Component {
void DoHandle() override; void DoHandle() override;
void UpdateCheckComplete(); void UpdateCheckComplete();
DISALLOW_COPY_AND_ASSIGN(StateChecking);
}; };
class StateUpdateError : public State { class StateUpdateError : public State {
public: public:
explicit StateUpdateError(Component* component); explicit StateUpdateError(Component* component);
StateUpdateError(const StateUpdateError&) = delete;
StateUpdateError& operator=(const StateUpdateError&) = delete;
~StateUpdateError() override; ~StateUpdateError() override;
private: private:
// State overrides. // State overrides.
void DoHandle() override; void DoHandle() override;
DISALLOW_COPY_AND_ASSIGN(StateUpdateError);
}; };
class StateCanUpdate : public State { class StateCanUpdate : public State {
public: public:
explicit StateCanUpdate(Component* component); explicit StateCanUpdate(Component* component);
StateCanUpdate(const StateCanUpdate&) = delete;
StateCanUpdate& operator=(const StateCanUpdate&) = delete;
~StateCanUpdate() override; ~StateCanUpdate() override;
private: private:
// State overrides. // State overrides.
void DoHandle() override; void DoHandle() override;
bool CanTryDiffUpdate() const; bool CanTryDiffUpdate() const;
DISALLOW_COPY_AND_ASSIGN(StateCanUpdate);
}; };
class StateUpToDate : public State { class StateUpToDate : public State {
public: public:
explicit StateUpToDate(Component* component); explicit StateUpToDate(Component* component);
StateUpToDate(const StateUpToDate&) = delete;
StateUpToDate& operator=(const StateUpToDate&) = delete;
~StateUpToDate() override; ~StateUpToDate() override;
private: private:
// State overrides. // State overrides.
void DoHandle() override; void DoHandle() override;
DISALLOW_COPY_AND_ASSIGN(StateUpToDate);
}; };
class StateDownloadingDiff : public State { class StateDownloadingDiff : public State {
public: public:
explicit StateDownloadingDiff(Component* component); explicit StateDownloadingDiff(Component* component);
StateDownloadingDiff(const StateDownloadingDiff&) = delete;
StateDownloadingDiff& operator=(const StateDownloadingDiff&) = delete;
~StateDownloadingDiff() override; ~StateDownloadingDiff() override;
private: private:
...@@ -260,14 +263,14 @@ class Component { ...@@ -260,14 +263,14 @@ class Component {
void DownloadComplete(const CrxDownloader::Result& download_result); void DownloadComplete(const CrxDownloader::Result& download_result);
// Downloads updates for one CRX id only. // Downloads updates for one CRX id only.
std::unique_ptr<CrxDownloader> crx_downloader_; scoped_refptr<CrxDownloader> crx_downloader_;
DISALLOW_COPY_AND_ASSIGN(StateDownloadingDiff);
}; };
class StateDownloading : public State { class StateDownloading : public State {
public: public:
explicit StateDownloading(Component* component); explicit StateDownloading(Component* component);
StateDownloading(const StateDownloading&) = delete;
StateDownloading& operator=(const StateDownloading&) = delete;
~StateDownloading() override; ~StateDownloading() override;
private: private:
...@@ -282,14 +285,14 @@ class Component { ...@@ -282,14 +285,14 @@ class Component {
void DownloadComplete(const CrxDownloader::Result& download_result); void DownloadComplete(const CrxDownloader::Result& download_result);
// Downloads updates for one CRX id only. // Downloads updates for one CRX id only.
std::unique_ptr<CrxDownloader> crx_downloader_; scoped_refptr<CrxDownloader> crx_downloader_;
DISALLOW_COPY_AND_ASSIGN(StateDownloading);
}; };
class StateUpdatingDiff : public State { class StateUpdatingDiff : public State {
public: public:
explicit StateUpdatingDiff(Component* component); explicit StateUpdatingDiff(Component* component);
StateUpdatingDiff(const StateUpdatingDiff&) = delete;
StateUpdatingDiff& operator=(const StateUpdatingDiff&) = delete;
~StateUpdatingDiff() override; ~StateUpdatingDiff() override;
private: private:
...@@ -300,13 +303,13 @@ class Component { ...@@ -300,13 +303,13 @@ class Component {
void InstallComplete(ErrorCategory error_category, void InstallComplete(ErrorCategory error_category,
int error_code, int error_code,
int extra_code1); int extra_code1);
DISALLOW_COPY_AND_ASSIGN(StateUpdatingDiff);
}; };
class StateUpdating : public State { class StateUpdating : public State {
public: public:
explicit StateUpdating(Component* component); explicit StateUpdating(Component* component);
StateUpdating(const StateUpdating&) = delete;
StateUpdating& operator=(const StateUpdating&) = delete;
~StateUpdating() override; ~StateUpdating() override;
private: private:
...@@ -317,37 +320,37 @@ class Component { ...@@ -317,37 +320,37 @@ class Component {
void InstallComplete(ErrorCategory error_category, void InstallComplete(ErrorCategory error_category,
int error_code, int error_code,
int extra_code1); int extra_code1);
DISALLOW_COPY_AND_ASSIGN(StateUpdating);
}; };
class StateUpdated : public State { class StateUpdated : public State {
public: public:
explicit StateUpdated(Component* component); explicit StateUpdated(Component* component);
StateUpdated(const StateUpdated&) = delete;
StateUpdated& operator=(const StateUpdated&) = delete;
~StateUpdated() override; ~StateUpdated() override;
private: private:
// State overrides. // State overrides.
void DoHandle() override; void DoHandle() override;
DISALLOW_COPY_AND_ASSIGN(StateUpdated);
}; };
class StateUninstalled : public State { class StateUninstalled : public State {
public: public:
explicit StateUninstalled(Component* component); explicit StateUninstalled(Component* component);
StateUninstalled(const StateUninstalled&) = delete;
StateUninstalled& operator=(const StateUninstalled&) = delete;
~StateUninstalled() override; ~StateUninstalled() override;
private: private:
// State overrides. // State overrides.
void DoHandle() override; void DoHandle() override;
DISALLOW_COPY_AND_ASSIGN(StateUninstalled);
}; };
class StateRun : public State { class StateRun : public State {
public: public:
explicit StateRun(Component* component); explicit StateRun(Component* component);
StateRun(const StateRun&) = delete;
StateRun& operator=(const StateRun&) = delete;
~StateRun() override; ~StateRun() override;
private: private:
...@@ -359,8 +362,6 @@ class Component { ...@@ -359,8 +362,6 @@ class Component {
// Runs the action referred by the |action_run_| member of the Component // Runs the action referred by the |action_run_| member of the Component
// class. // class.
std::unique_ptr<ActionRunner> action_runner_; std::unique_ptr<ActionRunner> action_runner_;
DISALLOW_COPY_AND_ASSIGN(StateRun);
}; };
// Returns true is the update payload for this component can be downloaded // Returns true is the update payload for this component can be downloaded
...@@ -479,8 +480,6 @@ class Component { ...@@ -479,8 +480,6 @@ class Component {
// True if this component has reached a final state because all its states // True if this component has reached a final state because all its states
// have been handled. // have been handled.
bool is_handled_ = false; bool is_handled_ = false;
DISALLOW_COPY_AND_ASSIGN(Component);
}; };
using IdToComponentPtrMap = std::map<std::string, std::unique_ptr<Component>>; using IdToComponentPtrMap = std::map<std::string, std::unique_ptr<Component>>;
......
...@@ -33,15 +33,16 @@ CrxDownloader::DownloadMetrics::DownloadMetrics() ...@@ -33,15 +33,16 @@ CrxDownloader::DownloadMetrics::DownloadMetrics()
// On Windows, the first downloader in the chain is a background downloader, // On Windows, the first downloader in the chain is a background downloader,
// which uses the BITS service. // which uses the BITS service.
std::unique_ptr<CrxDownloader> CrxDownloader::Create( scoped_refptr<CrxDownloader> CrxDownloader::Create(
bool is_background_download, bool is_background_download,
scoped_refptr<NetworkFetcherFactory> network_fetcher_factory) { scoped_refptr<NetworkFetcherFactory> network_fetcher_factory) {
std::unique_ptr<CrxDownloader> url_fetcher_downloader = scoped_refptr<CrxDownloader> url_fetcher_downloader =
std::make_unique<UrlFetcherDownloader>(nullptr, network_fetcher_factory); base::MakeRefCounted<UrlFetcherDownloader>(nullptr,
network_fetcher_factory);
#if defined(OS_WIN) #if defined(OS_WIN)
if (is_background_download) { if (is_background_download) {
return std::make_unique<BackgroundDownloader>( return base::MakeRefCounted<BackgroundDownloader>(
std::move(url_fetcher_downloader)); std::move(url_fetcher_downloader));
} }
#endif #endif
...@@ -49,7 +50,7 @@ std::unique_ptr<CrxDownloader> CrxDownloader::Create( ...@@ -49,7 +50,7 @@ std::unique_ptr<CrxDownloader> CrxDownloader::Create(
return url_fetcher_downloader; return url_fetcher_downloader;
} }
CrxDownloader::CrxDownloader(std::unique_ptr<CrxDownloader> successor) CrxDownloader::CrxDownloader(scoped_refptr<CrxDownloader> successor)
: main_task_runner_(base::ThreadTaskRunnerHandle::Get()), : main_task_runner_(base::ThreadTaskRunnerHandle::Get()),
successor_(std::move(successor)) {} successor_(std::move(successor)) {}
...@@ -86,7 +87,7 @@ void CrxDownloader::StartDownloadFromUrl(const GURL& url, ...@@ -86,7 +87,7 @@ void CrxDownloader::StartDownloadFromUrl(const GURL& url,
void CrxDownloader::StartDownload(const std::vector<GURL>& urls, void CrxDownloader::StartDownload(const std::vector<GURL>& urls,
const std::string& expected_hash, const std::string& expected_hash,
DownloadCallback download_callback) { DownloadCallback download_callback) {
DCHECK(thread_checker_.CalledOnValidThread()); DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
auto error = CrxDownloaderError::NONE; auto error = CrxDownloaderError::NONE;
if (urls.empty()) { if (urls.empty()) {
...@@ -115,23 +116,22 @@ void CrxDownloader::OnDownloadComplete( ...@@ -115,23 +116,22 @@ void CrxDownloader::OnDownloadComplete(
bool is_handled, bool is_handled,
const Result& result, const Result& result,
const DownloadMetrics& download_metrics) { const DownloadMetrics& download_metrics) {
DCHECK(thread_checker_.CalledOnValidThread()); DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
if (!result.error) if (!result.error)
base::ThreadPool::PostTask( base::ThreadPool::PostTask(
FROM_HERE, kTaskTraits, FROM_HERE, kTaskTraits,
base::BindOnce(&CrxDownloader::VerifyResponse, base::Unretained(this), base::BindOnce(&CrxDownloader::VerifyResponse, this, is_handled, result,
is_handled, result, download_metrics)); download_metrics));
else else
main_task_runner()->PostTask( main_task_runner()->PostTask(
FROM_HERE, base::BindOnce(&CrxDownloader::HandleDownloadError, FROM_HERE, base::BindOnce(&CrxDownloader::HandleDownloadError, this,
base::Unretained(this), is_handled, result, is_handled, result, download_metrics));
download_metrics));
} }
void CrxDownloader::OnDownloadProgress(int64_t downloaded_bytes, void CrxDownloader::OnDownloadProgress(int64_t downloaded_bytes,
int64_t total_bytes) { int64_t total_bytes) {
DCHECK(thread_checker_.CalledOnValidThread()); DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
if (progress_callback_.is_null()) if (progress_callback_.is_null())
return; return;
...@@ -164,16 +164,15 @@ void CrxDownloader::VerifyResponse(bool is_handled, ...@@ -164,16 +164,15 @@ void CrxDownloader::VerifyResponse(bool is_handled,
result.response.clear(); result.response.clear();
main_task_runner()->PostTask( main_task_runner()->PostTask(
FROM_HERE, base::BindOnce(&CrxDownloader::HandleDownloadError, FROM_HERE, base::BindOnce(&CrxDownloader::HandleDownloadError, this,
base::Unretained(this), is_handled, result, is_handled, result, download_metrics));
download_metrics));
} }
void CrxDownloader::HandleDownloadError( void CrxDownloader::HandleDownloadError(
bool is_handled, bool is_handled,
const Result& result, const Result& result,
const DownloadMetrics& download_metrics) { const DownloadMetrics& download_metrics) {
DCHECK(thread_checker_.CalledOnValidThread()); DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
DCHECK_NE(0, result.error); DCHECK_NE(0, result.error);
DCHECK(result.response.empty()); DCHECK(result.response.empty());
DCHECK_NE(0, download_metrics.error); DCHECK_NE(0, download_metrics.error);
......
...@@ -13,10 +13,9 @@ ...@@ -13,10 +13,9 @@
#include "base/callback.h" #include "base/callback.h"
#include "base/files/file_path.h" #include "base/files/file_path.h"
#include "base/macros.h"
#include "base/memory/ref_counted.h" #include "base/memory/ref_counted.h"
#include "base/sequence_checker.h"
#include "base/single_thread_task_runner.h" #include "base/single_thread_task_runner.h"
#include "base/threading/thread_checker.h"
#include "url/gurl.h" #include "url/gurl.h"
namespace update_client { namespace update_client {
...@@ -33,7 +32,7 @@ class NetworkFetcherFactory; ...@@ -33,7 +32,7 @@ class NetworkFetcherFactory;
// the order they are provided in the StartDownload function argument. After // the order they are provided in the StartDownload function argument. After
// that, the download request is routed to the next downloader in the chain. // that, the download request is routed to the next downloader in the chain.
// The members of this class expect to be called from the main thread only. // The members of this class expect to be called from the main thread only.
class CrxDownloader { class CrxDownloader : public base::RefCountedThreadSafe<CrxDownloader> {
public: public:
struct DownloadMetrics { struct DownloadMetrics {
enum Downloader { kNone = 0, kUrlFetcher, kBits }; enum Downloader { kNone = 0, kUrlFetcher, kBits };
...@@ -78,18 +77,20 @@ class CrxDownloader { ...@@ -78,18 +77,20 @@ class CrxDownloader {
int64_t total_bytes)>; int64_t total_bytes)>;
using Factory = using Factory =
std::unique_ptr<CrxDownloader> (*)(bool, scoped_refptr<CrxDownloader> (*)(bool,
scoped_refptr<NetworkFetcherFactory>); scoped_refptr<NetworkFetcherFactory>);
CrxDownloader(const CrxDownloader&) = delete;
CrxDownloader& operator=(const CrxDownloader&) = delete;
// Factory method to create an instance of this class and build the // Factory method to create an instance of this class and build the
// chain of responsibility. |is_background_download| specifies that a // chain of responsibility. |is_background_download| specifies that a
// background downloader be used, if the platform supports it. // background downloader be used, if the platform supports it.
// |task_runner| should be a task runner able to run blocking // |task_runner| should be a task runner able to run blocking
// code such as file IO operations. // code such as file IO operations.
static std::unique_ptr<CrxDownloader> Create( static scoped_refptr<CrxDownloader> Create(
bool is_background_download, bool is_background_download,
scoped_refptr<NetworkFetcherFactory> network_fetcher_factory); scoped_refptr<NetworkFetcherFactory> network_fetcher_factory);
virtual ~CrxDownloader();
void set_progress_callback(const ProgressCallback& progress_callback); void set_progress_callback(const ProgressCallback& progress_callback);
...@@ -108,7 +109,8 @@ class CrxDownloader { ...@@ -108,7 +109,8 @@ class CrxDownloader {
const std::vector<DownloadMetrics> download_metrics() const; const std::vector<DownloadMetrics> download_metrics() const;
protected: protected:
explicit CrxDownloader(std::unique_ptr<CrxDownloader> successor); explicit CrxDownloader(scoped_refptr<CrxDownloader> successor);
virtual ~CrxDownloader();
// Handles the fallback in the case of multiple urls and routing of the // Handles the fallback in the case of multiple urls and routing of the
// download to the following successor in the chain. Derived classes must call // download to the following successor in the chain. Derived classes must call
...@@ -133,6 +135,8 @@ class CrxDownloader { ...@@ -133,6 +135,8 @@ class CrxDownloader {
} }
private: private:
friend class base::RefCountedThreadSafe<CrxDownloader>;
virtual void DoStartDownload(const GURL& url) = 0; virtual void DoStartDownload(const GURL& url) = 0;
void VerifyResponse(bool is_handled, void VerifyResponse(bool is_handled,
...@@ -143,7 +147,7 @@ class CrxDownloader { ...@@ -143,7 +147,7 @@ class CrxDownloader {
const Result& result, const Result& result,
const DownloadMetrics& download_metrics); const DownloadMetrics& download_metrics);
base::ThreadChecker thread_checker_; SEQUENCE_CHECKER(sequence_checker_);
// Used to post callbacks to the main thread. // Used to post callbacks to the main thread.
scoped_refptr<base::SingleThreadTaskRunner> main_task_runner_; scoped_refptr<base::SingleThreadTaskRunner> main_task_runner_;
...@@ -152,15 +156,13 @@ class CrxDownloader { ...@@ -152,15 +156,13 @@ class CrxDownloader {
// The SHA256 hash of the download payload in hexadecimal format. // The SHA256 hash of the download payload in hexadecimal format.
std::string expected_hash_; std::string expected_hash_;
std::unique_ptr<CrxDownloader> successor_; scoped_refptr<CrxDownloader> successor_;
DownloadCallback download_callback_; DownloadCallback download_callback_;
ProgressCallback progress_callback_; ProgressCallback progress_callback_;
std::vector<GURL>::iterator current_url_; std::vector<GURL>::iterator current_url_;
std::vector<DownloadMetrics> download_metrics_; std::vector<DownloadMetrics> download_metrics_;
DISALLOW_COPY_AND_ASSIGN(CrxDownloader);
}; };
} // namespace update_client } // namespace update_client
......
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
#include "components/update_client/crx_downloader.h" #include "components/update_client/crx_downloader.h"
#include <memory>
#include <utility> #include <utility>
#include "base/bind.h" #include "base/bind.h"
...@@ -69,7 +68,7 @@ class CrxDownloaderTest : public testing::Test { ...@@ -69,7 +68,7 @@ class CrxDownloaderTest : public testing::Test {
int net_error); int net_error);
protected: protected:
std::unique_ptr<CrxDownloader> crx_downloader_; scoped_refptr<CrxDownloader> crx_downloader_;
network::TestURLLoaderFactory test_url_loader_factory_; network::TestURLLoaderFactory test_url_loader_factory_;
...@@ -120,7 +119,7 @@ void CrxDownloaderTest::SetUp() { ...@@ -120,7 +119,7 @@ void CrxDownloaderTest::SetUp() {
} }
void CrxDownloaderTest::TearDown() { void CrxDownloaderTest::TearDown() {
crx_downloader_.reset(); crx_downloader_ = nullptr;
} }
void CrxDownloaderTest::Quit() { void CrxDownloaderTest::Quit() {
...@@ -175,6 +174,7 @@ void CrxDownloaderTest::RunThreads() { ...@@ -175,6 +174,7 @@ void CrxDownloaderTest::RunThreads() {
RunThreadsUntilIdle(); RunThreadsUntilIdle();
} }
// TODO(crbug.com/1104691): rewrite the tests to not use RunUntilIdle().
void CrxDownloaderTest::RunThreadsUntilIdle() { void CrxDownloaderTest::RunThreadsUntilIdle() {
task_environment_.RunUntilIdle(); task_environment_.RunUntilIdle();
base::RunLoop().RunUntilIdle(); base::RunLoop().RunUntilIdle();
......
...@@ -21,23 +21,21 @@ ...@@ -21,23 +21,21 @@
namespace update_client { namespace update_client {
UrlFetcherDownloader::UrlFetcherDownloader( UrlFetcherDownloader::UrlFetcherDownloader(
std::unique_ptr<CrxDownloader> successor, scoped_refptr<CrxDownloader> successor,
scoped_refptr<NetworkFetcherFactory> network_fetcher_factory) scoped_refptr<NetworkFetcherFactory> network_fetcher_factory)
: CrxDownloader(std::move(successor)), : CrxDownloader(std::move(successor)),
network_fetcher_factory_(network_fetcher_factory) {} network_fetcher_factory_(network_fetcher_factory) {}
UrlFetcherDownloader::~UrlFetcherDownloader() { UrlFetcherDownloader::~UrlFetcherDownloader() {
DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
} }
void UrlFetcherDownloader::DoStartDownload(const GURL& url) { void UrlFetcherDownloader::DoStartDownload(const GURL& url) {
DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
base::ThreadPool::PostTaskAndReply( base::ThreadPool::PostTaskAndReply(
FROM_HERE, kTaskTraits, FROM_HERE, kTaskTraits,
base::BindOnce(&UrlFetcherDownloader::CreateDownloadDir, base::BindOnce(&UrlFetcherDownloader::CreateDownloadDir, this),
base::Unretained(this)), base::BindOnce(&UrlFetcherDownloader::StartURLFetch, this, url));
base::BindOnce(&UrlFetcherDownloader::StartURLFetch,
base::Unretained(this), url));
} }
void UrlFetcherDownloader::CreateDownloadDir() { void UrlFetcherDownloader::CreateDownloadDir() {
...@@ -46,7 +44,7 @@ void UrlFetcherDownloader::CreateDownloadDir() { ...@@ -46,7 +44,7 @@ void UrlFetcherDownloader::CreateDownloadDir() {
} }
void UrlFetcherDownloader::StartURLFetch(const GURL& url) { void UrlFetcherDownloader::StartURLFetch(const GURL& url) {
DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
if (download_dir_.empty()) { if (download_dir_.empty()) {
Result result; Result result;
...@@ -62,8 +60,7 @@ void UrlFetcherDownloader::StartURLFetch(const GURL& url) { ...@@ -62,8 +60,7 @@ void UrlFetcherDownloader::StartURLFetch(const GURL& url) {
main_task_runner()->PostTask( main_task_runner()->PostTask(
FROM_HERE, base::BindOnce(&UrlFetcherDownloader::OnDownloadComplete, FROM_HERE, base::BindOnce(&UrlFetcherDownloader::OnDownloadComplete,
base::Unretained(this), false, result, this, false, result, download_metrics));
download_metrics));
return; return;
} }
...@@ -71,19 +68,16 @@ void UrlFetcherDownloader::StartURLFetch(const GURL& url) { ...@@ -71,19 +68,16 @@ void UrlFetcherDownloader::StartURLFetch(const GURL& url) {
network_fetcher_ = network_fetcher_factory_->Create(); network_fetcher_ = network_fetcher_factory_->Create();
network_fetcher_->DownloadToFile( network_fetcher_->DownloadToFile(
url, file_path_, url, file_path_,
base::BindOnce(&UrlFetcherDownloader::OnResponseStarted, base::BindOnce(&UrlFetcherDownloader::OnResponseStarted, this),
base::Unretained(this)), base::BindRepeating(&UrlFetcherDownloader::OnDownloadProgress, this),
base::BindRepeating(&UrlFetcherDownloader::OnDownloadProgress, base::BindOnce(&UrlFetcherDownloader::OnNetworkFetcherComplete, this));
base::Unretained(this)),
base::BindOnce(&UrlFetcherDownloader::OnNetworkFetcherComplete,
base::Unretained(this)));
download_start_time_ = base::TimeTicks::Now(); download_start_time_ = base::TimeTicks::Now();
} }
void UrlFetcherDownloader::OnNetworkFetcherComplete(int net_error, void UrlFetcherDownloader::OnNetworkFetcherComplete(int net_error,
int64_t content_size) { int64_t content_size) {
DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
const base::TimeTicks download_end_time(base::TimeTicks::Now()); const base::TimeTicks download_end_time(base::TimeTicks::Now());
const base::TimeDelta download_time = const base::TimeDelta download_time =
...@@ -131,15 +125,15 @@ void UrlFetcherDownloader::OnNetworkFetcherComplete(int net_error, ...@@ -131,15 +125,15 @@ void UrlFetcherDownloader::OnNetworkFetcherComplete(int net_error,
} }
main_task_runner()->PostTask( main_task_runner()->PostTask(
FROM_HERE, base::BindOnce(&UrlFetcherDownloader::OnDownloadComplete, FROM_HERE, base::BindOnce(&UrlFetcherDownloader::OnDownloadComplete, this,
base::Unretained(this), is_handled, result, is_handled, result, download_metrics));
download_metrics)); network_fetcher_ = nullptr;
} }
// This callback is used to indicate that a download has been started. // This callback is used to indicate that a download has been started.
void UrlFetcherDownloader::OnResponseStarted(int response_code, void UrlFetcherDownloader::OnResponseStarted(int response_code,
int64_t content_length) { int64_t content_length) {
DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
VLOG(1) << "url fetcher response started for: " << url().spec(); VLOG(1) << "url fetcher response started for: " << url().spec();
...@@ -148,7 +142,7 @@ void UrlFetcherDownloader::OnResponseStarted(int response_code, ...@@ -148,7 +142,7 @@ void UrlFetcherDownloader::OnResponseStarted(int response_code,
} }
void UrlFetcherDownloader::OnDownloadProgress(int64_t current) { void UrlFetcherDownloader::OnDownloadProgress(int64_t current) {
DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
CrxDownloader::OnDownloadProgress(current, total_bytes_); CrxDownloader::OnDownloadProgress(current, total_bytes_);
} }
......
...@@ -10,9 +10,8 @@ ...@@ -10,9 +10,8 @@
#include <memory> #include <memory>
#include "base/files/file_path.h" #include "base/files/file_path.h"
#include "base/macros.h"
#include "base/memory/ref_counted.h" #include "base/memory/ref_counted.h"
#include "base/threading/thread_checker.h" #include "base/sequence_checker.h"
#include "base/time/time.h" #include "base/time/time.h"
#include "components/update_client/crx_downloader.h" #include "components/update_client/crx_downloader.h"
...@@ -25,12 +24,14 @@ class NetworkFetcherFactory; ...@@ -25,12 +24,14 @@ class NetworkFetcherFactory;
class UrlFetcherDownloader : public CrxDownloader { class UrlFetcherDownloader : public CrxDownloader {
public: public:
UrlFetcherDownloader( UrlFetcherDownloader(
std::unique_ptr<CrxDownloader> successor, scoped_refptr<CrxDownloader> successor,
scoped_refptr<NetworkFetcherFactory> network_fetcher_factory); scoped_refptr<NetworkFetcherFactory> network_fetcher_factory);
~UrlFetcherDownloader() override; UrlFetcherDownloader(const UrlFetcherDownloader&) = delete;
UrlFetcherDownloader& operator=(const UrlFetcherDownloader&) = delete;
private: private:
// Overrides for CrxDownloader. // Overrides for CrxDownloader.
~UrlFetcherDownloader() override;
void DoStartDownload(const GURL& url) override; void DoStartDownload(const GURL& url) override;
void CreateDownloadDir(); void CreateDownloadDir();
...@@ -39,7 +40,7 @@ class UrlFetcherDownloader : public CrxDownloader { ...@@ -39,7 +40,7 @@ class UrlFetcherDownloader : public CrxDownloader {
void OnResponseStarted(int response_code, int64_t content_length); void OnResponseStarted(int response_code, int64_t content_length);
void OnDownloadProgress(int64_t content_length); void OnDownloadProgress(int64_t content_length);
THREAD_CHECKER(thread_checker_); SEQUENCE_CHECKER(sequence_checker_);
scoped_refptr<NetworkFetcherFactory> network_fetcher_factory_; scoped_refptr<NetworkFetcherFactory> network_fetcher_factory_;
std::unique_ptr<NetworkFetcher> network_fetcher_; std::unique_ptr<NetworkFetcher> network_fetcher_;
...@@ -54,8 +55,6 @@ class UrlFetcherDownloader : public CrxDownloader { ...@@ -54,8 +55,6 @@ class UrlFetcherDownloader : public CrxDownloader {
int response_code_ = -1; int response_code_ = -1;
int64_t total_bytes_ = -1; int64_t total_bytes_ = -1;
DISALLOW_COPY_AND_ASSIGN(UrlFetcherDownloader);
}; };
} // namespace update_client } // namespace update_client
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment