Commit 6d12f07b authored by Nate Chapin's avatar Nate Chapin Committed by Commit Bot

Make ThreadPoolTask oilpan-friendly and merge it in to Task

Change-Id: I94ade605404051bf0d735b936923c2084bb805bc
Reviewed-on: https://chromium-review.googlesource.com/c/1316436
Commit-Queue: Nate Chapin <japhet@chromium.org>
Reviewed-by: default avatarHiroki Nakagawa <nhiroki@chromium.org>
Cr-Commit-Position: refs/heads/master@{#605457}
parent 2a90bffa
...@@ -18,41 +18,40 @@ class SerializedScriptValue; ...@@ -18,41 +18,40 @@ class SerializedScriptValue;
// Scans |arguments| for Task objects, and registers those as dependencies, // Scans |arguments| for Task objects, and registers those as dependencies,
// passing the result of those tasks in place of the Task arguments. // passing the result of those tasks in place of the Task arguments.
// All public functions are main-thread-only. // All public functions are main-thread-only.
// ThreadPoolTask keeps itself alive via a self scoped_refptr until the // Task keeps itself alive via a SelfKeepAlive until the
// the task completes and reports itself done on the main thread via // the task completes and reports itself done on the main thread via
// TaskCompleted(). Other users (e.g. Task below) can keep the task // TaskCompleted().
// alive after completion. class Task final : public ScriptWrappable {
class ThreadPoolTask final : public RefCounted<ThreadPoolTask> { DEFINE_WRAPPERTYPEINFO();
public: public:
// Called on main thread // Called on main thread
ThreadPoolTask(ThreadPoolThreadProvider*, Task(ThreadPoolThreadProvider*,
ScriptState*, ScriptState*,
const ScriptValue& function, const ScriptValue& function,
const Vector<ScriptValue>& arguments, const Vector<ScriptValue>& arguments,
TaskType); TaskType);
ThreadPoolTask(ThreadPoolThreadProvider*, Task(ThreadPoolThreadProvider*,
ScriptState*, ScriptState*,
const String& function_name, const String& function_name,
const Vector<ScriptValue>& arguments, const Vector<ScriptValue>& arguments,
TaskType); TaskType);
~ThreadPoolTask(); ~Task() override;
// Returns a promise that will be resolved with the result when it completes. // Returns a promise that will be resolved with the result when it completes.
ScriptPromise GetResult(); ScriptPromise result();
void Cancel() LOCKS_EXCLUDED(mutex_); void cancel() LOCKS_EXCLUDED(mutex_);
base::WeakPtr<ThreadPoolTask> GetWeakPtr() { void Trace(Visitor*) override;
return weak_factory_.GetWeakPtr();
}
private: private:
enum class State { kPending, kStarted, kCancelPending, kCompleted, kFailed }; enum class State { kPending, kStarted, kCancelPending, kCompleted, kFailed };
ThreadPoolTask(ThreadPoolThreadProvider*, Task(ThreadPoolThreadProvider*,
ScriptState*, ScriptState*,
const ScriptValue& function, const ScriptValue& function,
const String& function_name, const String& function_name,
const Vector<ScriptValue>& arguments, const Vector<ScriptValue>& arguments,
TaskType); TaskType);
class AsyncFunctionCompleted; class AsyncFunctionCompleted;
...@@ -74,10 +73,10 @@ class ThreadPoolTask final : public RefCounted<ThreadPoolTask> { ...@@ -74,10 +73,10 @@ class ThreadPoolTask final : public RefCounted<ThreadPoolTask> {
// Called on main thread // Called on main thread
static ThreadPoolThread* SelectThread( static ThreadPoolThread* SelectThread(
const Vector<ThreadPoolTask*>& prerequisites, const HeapVector<Member<Task>>& prerequisites,
ThreadPoolThreadProvider*); ThreadPoolThreadProvider*);
ThreadPoolThread* GetScheduledThread() LOCKS_EXCLUDED(mutex_); ThreadPoolThread* GetScheduledThread() LOCKS_EXCLUDED(mutex_);
void RegisterDependencies(const Vector<ThreadPoolTask*>& prerequisites, void RegisterDependencies(const HeapVector<Member<Task>>& prerequisites,
const Vector<size_t>& prerequisite_indices) const Vector<size_t>& prerequisite_indices)
LOCKS_EXCLUDED(mutex_); LOCKS_EXCLUDED(mutex_);
void TaskCompleted(); void TaskCompleted();
...@@ -87,8 +86,8 @@ class ThreadPoolTask final : public RefCounted<ThreadPoolTask> { ...@@ -87,8 +86,8 @@ class ThreadPoolTask final : public RefCounted<ThreadPoolTask> {
const TaskType task_type_; const TaskType task_type_;
// Main thread only // Main thread only
scoped_refptr<ThreadPoolTask> self_keep_alive_; SelfKeepAlive<Task> self_keep_alive_;
Persistent<ScriptPromiseResolver> resolver_; Member<ScriptPromiseResolver> resolver_;
// Created in constructor on the main thread, consumed and cleared on // Created in constructor on the main thread, consumed and cleared on
// worker_thread_. Those steps can't overlap, so no mutex_ required. // worker_thread_. Those steps can't overlap, so no mutex_ required.
...@@ -120,38 +119,20 @@ class ThreadPoolTask final : public RefCounted<ThreadPoolTask> { ...@@ -120,38 +119,20 @@ class ThreadPoolTask final : public RefCounted<ThreadPoolTask> {
size_t prerequisites_remaining_ GUARDED_BY(mutex_) = 0u; size_t prerequisites_remaining_ GUARDED_BY(mutex_) = 0u;
// Elements added from main thread. Cleared on completion on worker_thread_. // Elements added from main thread. Cleared on completion on worker_thread_.
// Each element in dependents_ is not yet in the kCompleted state and // Each element in dependents_ is not yet in the kCompleted state.
// therefore is guaranteed to be alive. struct Dependent final : public GarbageCollected<Dependent> {
struct Dependent {
public: public:
Dependent(ThreadPoolTask* task, size_t index) : task(task), index(index) {} Dependent(Task* task, size_t index) : task(task), index(index) {
ThreadPoolTask* task; DCHECK(IsMainThread());
}
void Trace(Visitor* visitor) { visitor->Trace(task); }
Member<Task> task;
// The index in the dependent's argument array where this result should go.
size_t index; size_t index;
}; };
HashSet<std::unique_ptr<Dependent>> dependents_ GUARDED_BY(mutex_); Vector<CrossThreadPersistent<Dependent>> dependents_ GUARDED_BY(mutex_);
Mutex mutex_; Mutex mutex_;
base::WeakPtrFactory<ThreadPoolTask> weak_factory_;
};
// This is a thin, v8-exposed wrapper around ThreadPoolTask that allows
// ThreadPoolTask to avoid being GarbageCollected.
class Task : public ScriptWrappable {
DEFINE_WRAPPERTYPEINFO();
public:
explicit Task(ThreadPoolTask* thread_pool_task)
: thread_pool_task_(thread_pool_task) {}
~Task() override = default;
ScriptPromise result() { return thread_pool_task_->GetResult(); }
void cancel() { thread_pool_task_->Cancel(); }
ThreadPoolTask* GetThreadPoolTask() const { return thread_pool_task_.get(); }
private:
scoped_refptr<ThreadPoolTask> thread_pool_task_;
}; };
} // namespace blink } // namespace blink
......
...@@ -55,17 +55,14 @@ Task* TaskWorklet::postTask(ScriptState* script_state, ...@@ -55,17 +55,14 @@ Task* TaskWorklet::postTask(ScriptState* script_state,
// TODO(japhet): Here and below: it's unclear what task type should be used, // TODO(japhet): Here and below: it's unclear what task type should be used,
// and whether the API should allow it to be configured. Using kIdleTask as a // and whether the API should allow it to be configured. Using kIdleTask as a
// placeholder for now. // placeholder for now.
ThreadPoolTask* thread_pool_task = new ThreadPoolTask( return new Task(this, script_state, function, arguments, TaskType::kIdleTask);
this, script_state, function, arguments, TaskType::kIdleTask);
return new Task(thread_pool_task);
} }
Task* TaskWorklet::postTask(ScriptState* script_state, Task* TaskWorklet::postTask(ScriptState* script_state,
const String& function_name, const String& function_name,
const Vector<ScriptValue>& arguments) { const Vector<ScriptValue>& arguments) {
ThreadPoolTask* thread_pool_task = new ThreadPoolTask( return new Task(this, script_state, function_name, arguments,
this, script_state, function_name, arguments, TaskType::kIdleTask); TaskType::kIdleTask);
return new Task(thread_pool_task);
} }
ThreadPoolThread* TaskWorklet::GetLeastBusyThread() { ThreadPoolThread* TaskWorklet::GetLeastBusyThread() {
......
...@@ -39,19 +39,17 @@ WorkerTaskQueue::WorkerTaskQueue(Document* document, TaskType task_type) ...@@ -39,19 +39,17 @@ WorkerTaskQueue::WorkerTaskQueue(Document* document, TaskType task_type)
ScriptPromise WorkerTaskQueue::postFunction( ScriptPromise WorkerTaskQueue::postFunction(
ScriptState* script_state, ScriptState* script_state,
const ScriptValue& task, const ScriptValue& function,
AbortSignal* signal, AbortSignal* signal,
const Vector<ScriptValue>& arguments) { const Vector<ScriptValue>& arguments) {
DCHECK(document_->IsContextThread()); DCHECK(document_->IsContextThread());
DCHECK(task.IsFunction()); DCHECK(function.IsFunction());
ThreadPoolTask* thread_pool_task = new ThreadPoolTask( Task* task = new Task(ThreadPool::From(*document_), script_state, function,
ThreadPool::From(*document_), script_state, task, arguments, task_type_); arguments, task_type_);
if (signal) { if (signal)
signal->AddAlgorithm( signal->AddAlgorithm(WTF::Bind(&Task::cancel, WrapWeakPersistent(task)));
WTF::Bind(&ThreadPoolTask::Cancel, thread_pool_task->GetWeakPtr())); return task->result();
}
return thread_pool_task->GetResult();
} }
Task* WorkerTaskQueue::postTask(ScriptState* script_state, Task* WorkerTaskQueue::postTask(ScriptState* script_state,
...@@ -59,11 +57,8 @@ Task* WorkerTaskQueue::postTask(ScriptState* script_state, ...@@ -59,11 +57,8 @@ Task* WorkerTaskQueue::postTask(ScriptState* script_state,
const Vector<ScriptValue>& arguments) { const Vector<ScriptValue>& arguments) {
DCHECK(document_->IsContextThread()); DCHECK(document_->IsContextThread());
DCHECK(function.IsFunction()); DCHECK(function.IsFunction());
return new Task(ThreadPool::From(*document_), script_state, function,
ThreadPoolTask* thread_pool_task = arguments, task_type_);
new ThreadPoolTask(ThreadPool::From(*document_), script_state, function,
arguments, task_type_);
return new Task(thread_pool_task);
} }
void WorkerTaskQueue::Trace(blink::Visitor* visitor) { void WorkerTaskQueue::Trace(blink::Visitor* visitor) {
......
...@@ -30,12 +30,12 @@ class CORE_EXPORT WorkerTaskQueue : public ScriptWrappable { ...@@ -30,12 +30,12 @@ class CORE_EXPORT WorkerTaskQueue : public ScriptWrappable {
~WorkerTaskQueue() override = default; ~WorkerTaskQueue() override = default;
ScriptPromise postFunction(ScriptState*, ScriptPromise postFunction(ScriptState*,
const ScriptValue& task, const ScriptValue& function,
AbortSignal*, AbortSignal*,
const Vector<ScriptValue>& arguments); const Vector<ScriptValue>& arguments);
Task* postTask(ScriptState*, Task* postTask(ScriptState*,
const ScriptValue& task, const ScriptValue& function,
const Vector<ScriptValue>& arguments); const Vector<ScriptValue>& arguments);
void Trace(blink::Visitor*) override; void Trace(blink::Visitor*) override;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment