Commit bcab5d9d authored by Jesse McKenna's avatar Jesse McKenna Committed by Commit Bot

TaskScheduler: Delegate reinsertion of Sequences in SchedulerWorkerPool...

TaskScheduler: Delegate reinsertion of Sequences in SchedulerWorkerPool priority queue to TaskScheduler

Bug: 889029
Change-Id: I424024245e74c0c08d76a94263c94fad7a3631cc
Reviewed-on: https://chromium-review.googlesource.com/c/1271794
Commit-Queue: Jesse McKenna <jessemckenna@google.com>
Reviewed-by: default avatarFrançois Doray <fdoray@chromium.org>
Cr-Commit-Position: refs/heads/master@{#601615}
parent 54fcb574
...@@ -11,8 +11,11 @@ namespace internal { ...@@ -11,8 +11,11 @@ namespace internal {
PlatformNativeWorkerPoolWin::PlatformNativeWorkerPoolWin( PlatformNativeWorkerPoolWin::PlatformNativeWorkerPoolWin(
TrackedRef<TaskTracker> task_tracker, TrackedRef<TaskTracker> task_tracker,
DelayedTaskManager* delayed_task_manager) DelayedTaskManager* delayed_task_manager,
: SchedulerWorkerPool(task_tracker, delayed_task_manager) {} TrackedRef<Delegate> delegate)
: SchedulerWorkerPool(std::move(task_tracker),
delayed_task_manager,
std::move(delegate)) {}
PlatformNativeWorkerPoolWin::~PlatformNativeWorkerPoolWin() { PlatformNativeWorkerPoolWin::~PlatformNativeWorkerPoolWin() {
#if DCHECK_IS_ON() #if DCHECK_IS_ON()
...@@ -59,6 +62,11 @@ void PlatformNativeWorkerPoolWin::JoinForTesting() { ...@@ -59,6 +62,11 @@ void PlatformNativeWorkerPoolWin::JoinForTesting() {
#endif #endif
} }
void PlatformNativeWorkerPoolWin::ReEnqueueSequence(
scoped_refptr<Sequence> sequence) {
OnCanScheduleSequence(std::move(sequence));
}
// static // static
void CALLBACK PlatformNativeWorkerPoolWin::RunNextSequence( void CALLBACK PlatformNativeWorkerPoolWin::RunNextSequence(
PTP_CALLBACK_INSTANCE, PTP_CALLBACK_INSTANCE,
......
...@@ -30,7 +30,8 @@ namespace internal { ...@@ -30,7 +30,8 @@ namespace internal {
class BASE_EXPORT PlatformNativeWorkerPoolWin : public SchedulerWorkerPool { class BASE_EXPORT PlatformNativeWorkerPoolWin : public SchedulerWorkerPool {
public: public:
PlatformNativeWorkerPoolWin(TrackedRef<TaskTracker> task_tracker, PlatformNativeWorkerPoolWin(TrackedRef<TaskTracker> task_tracker,
DelayedTaskManager* delayed_task_manager); DelayedTaskManager* delayed_task_manager,
TrackedRef<Delegate> delegate);
// Destroying a PlatformNativeWorkerPoolWin is not allowed in // Destroying a PlatformNativeWorkerPoolWin is not allowed in
// production; it is always leaked. In tests, it can only be destroyed after // production; it is always leaked. In tests, it can only be destroyed after
...@@ -42,6 +43,7 @@ class BASE_EXPORT PlatformNativeWorkerPoolWin : public SchedulerWorkerPool { ...@@ -42,6 +43,7 @@ class BASE_EXPORT PlatformNativeWorkerPoolWin : public SchedulerWorkerPool {
// SchedulerWorkerPool: // SchedulerWorkerPool:
void JoinForTesting() override; void JoinForTesting() override;
void ReEnqueueSequence(scoped_refptr<Sequence> sequence) override;
private: private:
// Callback that gets run by |pool_|. It runs a task off the next sequence on // Callback that gets run by |pool_|. It runs a task off the next sequence on
......
...@@ -169,9 +169,11 @@ bool SchedulerWorkerPool::PostTaskWithSequence( ...@@ -169,9 +169,11 @@ bool SchedulerWorkerPool::PostTaskWithSequence(
SchedulerWorkerPool::SchedulerWorkerPool( SchedulerWorkerPool::SchedulerWorkerPool(
TrackedRef<TaskTracker> task_tracker, TrackedRef<TaskTracker> task_tracker,
DelayedTaskManager* delayed_task_manager) DelayedTaskManager* delayed_task_manager,
TrackedRef<Delegate> delegate)
: task_tracker_(std::move(task_tracker)), : task_tracker_(std::move(task_tracker)),
delayed_task_manager_(delayed_task_manager) { delayed_task_manager_(delayed_task_manager),
delegate_(std::move(delegate)) {
DCHECK(task_tracker_); DCHECK(task_tracker_);
DCHECK(delayed_task_manager_); DCHECK(delayed_task_manager_);
++g_active_pools_count; ++g_active_pools_count;
...@@ -192,6 +194,10 @@ void SchedulerWorkerPool::UnbindFromCurrentThread() { ...@@ -192,6 +194,10 @@ void SchedulerWorkerPool::UnbindFromCurrentThread() {
tls_current_worker_pool.Get().Set(nullptr); tls_current_worker_pool.Get().Set(nullptr);
} }
bool SchedulerWorkerPool::IsBoundToCurrentThread() const {
return GetCurrentWorkerPool() == this;
}
void SchedulerWorkerPool::PostTaskWithSequenceNow( void SchedulerWorkerPool::PostTaskWithSequenceNow(
Task task, Task task,
scoped_refptr<Sequence> sequence) { scoped_refptr<Sequence> sequence) {
......
...@@ -24,6 +24,17 @@ class TaskTracker; ...@@ -24,6 +24,17 @@ class TaskTracker;
// Interface for a worker pool. // Interface for a worker pool.
class BASE_EXPORT SchedulerWorkerPool : public CanScheduleSequenceObserver { class BASE_EXPORT SchedulerWorkerPool : public CanScheduleSequenceObserver {
public: public:
// Delegate interface for SchedulerWorkerPool.
class BASE_EXPORT Delegate {
public:
virtual ~Delegate() = default;
// Invoked when a |sequence| is non-empty after the SchedulerWorkerPool has
// run a task from it. The implementation must enqueue |sequence| in the
// appropriate priority queue, depending on |sequence| traits.
virtual void ReEnqueueSequence(scoped_refptr<Sequence> sequence) = 0;
};
~SchedulerWorkerPool() override; ~SchedulerWorkerPool() override;
// Returns a TaskRunner whose PostTask invocations result in scheduling tasks // Returns a TaskRunner whose PostTask invocations result in scheduling tasks
...@@ -49,6 +60,9 @@ class BASE_EXPORT SchedulerWorkerPool : public CanScheduleSequenceObserver { ...@@ -49,6 +60,9 @@ class BASE_EXPORT SchedulerWorkerPool : public CanScheduleSequenceObserver {
// Resets the worker pool in TLS. // Resets the worker pool in TLS.
void UnbindFromCurrentThread(); void UnbindFromCurrentThread();
// Returns true if the worker pool is registered in TLS.
bool IsBoundToCurrentThread() const;
// Prevents new tasks from starting to run and waits for currently running // Prevents new tasks from starting to run and waits for currently running
// tasks to complete their execution. It is guaranteed that no thread will do // tasks to complete their execution. It is guaranteed that no thread will do
// work on behalf of this SchedulerWorkerPool after this returns. It is // work on behalf of this SchedulerWorkerPool after this returns. It is
...@@ -57,9 +71,15 @@ class BASE_EXPORT SchedulerWorkerPool : public CanScheduleSequenceObserver { ...@@ -57,9 +71,15 @@ class BASE_EXPORT SchedulerWorkerPool : public CanScheduleSequenceObserver {
// task during JoinForTesting(). This can only be called once. // task during JoinForTesting(). This can only be called once.
virtual void JoinForTesting() = 0; virtual void JoinForTesting() = 0;
// Enqueues |sequence| in the worker pool's priority queue, then wakes up a
// worker if the worker pool is not bound to the current thread, i.e. if
// |sequence| is changing pools.
virtual void ReEnqueueSequence(scoped_refptr<Sequence> sequence) = 0;
protected: protected:
SchedulerWorkerPool(TrackedRef<TaskTracker> task_tracker, SchedulerWorkerPool(TrackedRef<TaskTracker> task_tracker,
DelayedTaskManager* delayed_task_manager); DelayedTaskManager* delayed_task_manager,
TrackedRef<Delegate> delegate);
// Posts |task| to be executed by this SchedulerWorkerPool as part of // Posts |task| to be executed by this SchedulerWorkerPool as part of
// |sequence|. This must only be called after |task| has gone through // |sequence|. This must only be called after |task| has gone through
...@@ -68,6 +88,7 @@ class BASE_EXPORT SchedulerWorkerPool : public CanScheduleSequenceObserver { ...@@ -68,6 +88,7 @@ class BASE_EXPORT SchedulerWorkerPool : public CanScheduleSequenceObserver {
const TrackedRef<TaskTracker> task_tracker_; const TrackedRef<TaskTracker> task_tracker_;
DelayedTaskManager* const delayed_task_manager_; DelayedTaskManager* const delayed_task_manager_;
const TrackedRef<Delegate> delegate_;
private: private:
DISALLOW_COPY_AND_ASSIGN(SchedulerWorkerPool); DISALLOW_COPY_AND_ASSIGN(SchedulerWorkerPool);
......
...@@ -159,8 +159,11 @@ SchedulerWorkerPoolImpl::SchedulerWorkerPoolImpl( ...@@ -159,8 +159,11 @@ SchedulerWorkerPoolImpl::SchedulerWorkerPoolImpl(
StringPiece pool_label, StringPiece pool_label,
ThreadPriority priority_hint, ThreadPriority priority_hint,
TrackedRef<TaskTracker> task_tracker, TrackedRef<TaskTracker> task_tracker,
DelayedTaskManager* delayed_task_manager) DelayedTaskManager* delayed_task_manager,
: SchedulerWorkerPool(std::move(task_tracker), delayed_task_manager), TrackedRef<Delegate> delegate)
: SchedulerWorkerPool(std::move(task_tracker),
delayed_task_manager,
std::move(delegate)),
pool_label_(pool_label.as_string()), pool_label_(pool_label.as_string()),
priority_hint_(priority_hint), priority_hint_(priority_hint),
lock_(shared_priority_queue_.container_lock()), lock_(shared_priority_queue_.container_lock()),
...@@ -272,11 +275,16 @@ SchedulerWorkerPoolImpl::~SchedulerWorkerPoolImpl() { ...@@ -272,11 +275,16 @@ SchedulerWorkerPoolImpl::~SchedulerWorkerPoolImpl() {
void SchedulerWorkerPoolImpl::OnCanScheduleSequence( void SchedulerWorkerPoolImpl::OnCanScheduleSequence(
scoped_refptr<Sequence> sequence) { scoped_refptr<Sequence> sequence) {
PushSequenceToPriorityQueue(std::move(sequence));
WakeUpOneWorker();
}
void SchedulerWorkerPoolImpl::PushSequenceToPriorityQueue(
scoped_refptr<Sequence> sequence) {
DCHECK(sequence);
const auto sequence_sort_key = sequence->GetSortKey(); const auto sequence_sort_key = sequence->GetSortKey();
shared_priority_queue_.BeginTransaction()->Push(std::move(sequence), shared_priority_queue_.BeginTransaction()->Push(std::move(sequence),
sequence_sort_key); sequence_sort_key);
WakeUpOneWorker();
} }
void SchedulerWorkerPoolImpl::GetHistograms( void SchedulerWorkerPoolImpl::GetHistograms(
...@@ -355,6 +363,13 @@ void SchedulerWorkerPoolImpl::JoinForTesting() { ...@@ -355,6 +363,13 @@ void SchedulerWorkerPoolImpl::JoinForTesting() {
workers_.clear(); workers_.clear();
} }
void SchedulerWorkerPoolImpl::ReEnqueueSequence(
scoped_refptr<Sequence> sequence) {
PushSequenceToPriorityQueue(std::move(sequence));
if (!IsBoundToCurrentThread())
WakeUpOneWorker();
}
size_t SchedulerWorkerPoolImpl::NumberOfWorkersForTesting() const { size_t SchedulerWorkerPoolImpl::NumberOfWorkersForTesting() const {
AutoSchedulerLock auto_lock(lock_); AutoSchedulerLock auto_lock(lock_);
return workers_.size(); return workers_.size();
...@@ -543,14 +558,7 @@ void SchedulerWorkerPoolImpl::SchedulerWorkerDelegateImpl::DidRunTask() { ...@@ -543,14 +558,7 @@ void SchedulerWorkerPoolImpl::SchedulerWorkerDelegateImpl::DidRunTask() {
void SchedulerWorkerPoolImpl::SchedulerWorkerDelegateImpl::ReEnqueueSequence( void SchedulerWorkerPoolImpl::SchedulerWorkerDelegateImpl::ReEnqueueSequence(
scoped_refptr<Sequence> sequence) { scoped_refptr<Sequence> sequence) {
DCHECK_CALLED_ON_VALID_THREAD(worker_thread_checker_); outer_->delegate_->ReEnqueueSequence(std::move(sequence));
const SequenceSortKey sequence_sort_key = sequence->GetSortKey();
outer_->shared_priority_queue_.BeginTransaction()->Push(std::move(sequence),
sequence_sort_key);
// This worker will soon call GetWork(). Therefore, there is no need to wake
// up a worker to run the sequence that was just inserted into
// |outer_->shared_priority_queue_|.
} }
TimeDelta TimeDelta
......
...@@ -73,7 +73,8 @@ class BASE_EXPORT SchedulerWorkerPoolImpl : public SchedulerWorkerPool { ...@@ -73,7 +73,8 @@ class BASE_EXPORT SchedulerWorkerPoolImpl : public SchedulerWorkerPool {
StringPiece pool_label, StringPiece pool_label,
ThreadPriority priority_hint, ThreadPriority priority_hint,
TrackedRef<TaskTracker> task_tracker, TrackedRef<TaskTracker> task_tracker,
DelayedTaskManager* delayed_task_manager); DelayedTaskManager* delayed_task_manager,
TrackedRef<Delegate> delegate);
// Creates workers following the |params| specification, allowing existing and // Creates workers following the |params| specification, allowing existing and
// future tasks to run. The pool will run at most |max_best_effort_tasks| // future tasks to run. The pool will run at most |max_best_effort_tasks|
...@@ -97,6 +98,7 @@ class BASE_EXPORT SchedulerWorkerPoolImpl : public SchedulerWorkerPool { ...@@ -97,6 +98,7 @@ class BASE_EXPORT SchedulerWorkerPoolImpl : public SchedulerWorkerPool {
// SchedulerWorkerPool: // SchedulerWorkerPool:
void JoinForTesting() override; void JoinForTesting() override;
void ReEnqueueSequence(scoped_refptr<Sequence> sequence) override;
const HistogramBase* num_tasks_before_detach_histogram() const { const HistogramBase* num_tasks_before_detach_histogram() const {
return num_tasks_before_detach_histogram_; return num_tasks_before_detach_histogram_;
...@@ -165,6 +167,9 @@ class BASE_EXPORT SchedulerWorkerPoolImpl : public SchedulerWorkerPool { ...@@ -165,6 +167,9 @@ class BASE_EXPORT SchedulerWorkerPoolImpl : public SchedulerWorkerPool {
// SchedulerWorkerPool: // SchedulerWorkerPool:
void OnCanScheduleSequence(scoped_refptr<Sequence> sequence) override; void OnCanScheduleSequence(scoped_refptr<Sequence> sequence) override;
// Pushes |sequence| to |shared_priority_queue_|.
void PushSequenceToPriorityQueue(scoped_refptr<Sequence> sequence);
// Waits until at least |n| workers are idle. |lock_| must be held to call // Waits until at least |n| workers are idle. |lock_| must be held to call
// this function. // this function.
void WaitForWorkersIdleLockRequiredForTesting(size_t n); void WaitForWorkersIdleLockRequiredForTesting(size_t n);
......
...@@ -75,10 +75,12 @@ void WaitWithoutBlockingObserver(WaitableEvent* event) { ...@@ -75,10 +75,12 @@ void WaitWithoutBlockingObserver(WaitableEvent* event) {
event->Wait(); event->Wait();
} }
class TaskSchedulerWorkerPoolImplTestBase { class TaskSchedulerWorkerPoolImplTestBase
: public SchedulerWorkerPool::Delegate {
protected: protected:
TaskSchedulerWorkerPoolImplTestBase() TaskSchedulerWorkerPoolImplTestBase()
: service_thread_("TaskSchedulerServiceThread"){}; : service_thread_("TaskSchedulerServiceThread"),
tracked_ref_factory_(this){};
void CommonSetUp(TimeDelta suggested_reclaim_time = TimeDelta::Max()) { void CommonSetUp(TimeDelta suggested_reclaim_time = TimeDelta::Max()) {
CreateAndStartWorkerPool(suggested_reclaim_time, kMaxTasks); CreateAndStartWorkerPool(suggested_reclaim_time, kMaxTasks);
...@@ -89,6 +91,7 @@ class TaskSchedulerWorkerPoolImplTestBase { ...@@ -89,6 +91,7 @@ class TaskSchedulerWorkerPoolImplTestBase {
task_tracker_.FlushForTesting(); task_tracker_.FlushForTesting();
if (worker_pool_) if (worker_pool_)
worker_pool_->JoinForTesting(); worker_pool_->JoinForTesting();
worker_pool_.reset();
} }
void CreateWorkerPool() { void CreateWorkerPool() {
...@@ -97,7 +100,8 @@ class TaskSchedulerWorkerPoolImplTestBase { ...@@ -97,7 +100,8 @@ class TaskSchedulerWorkerPoolImplTestBase {
delayed_task_manager_.Start(service_thread_.task_runner()); delayed_task_manager_.Start(service_thread_.task_runner());
worker_pool_ = std::make_unique<SchedulerWorkerPoolImpl>( worker_pool_ = std::make_unique<SchedulerWorkerPoolImpl>(
"TestWorkerPool", "A", ThreadPriority::NORMAL, "TestWorkerPool", "A", ThreadPriority::NORMAL,
task_tracker_.GetTrackedRef(), &delayed_task_manager_); task_tracker_.GetTrackedRef(), &delayed_task_manager_,
tracked_ref_factory_.GetTrackedRef());
ASSERT_TRUE(worker_pool_); ASSERT_TRUE(worker_pool_);
} }
...@@ -118,11 +122,15 @@ class TaskSchedulerWorkerPoolImplTestBase { ...@@ -118,11 +122,15 @@ class TaskSchedulerWorkerPoolImplTestBase {
Thread service_thread_; Thread service_thread_;
TaskTracker task_tracker_ = {"Test"}; TaskTracker task_tracker_ = {"Test"};
std::unique_ptr<SchedulerWorkerPoolImpl> worker_pool_; std::unique_ptr<SchedulerWorkerPoolImpl> worker_pool_;
DelayedTaskManager delayed_task_manager_;
TrackedRefFactory<SchedulerWorkerPool::Delegate> tracked_ref_factory_;
private: private:
DelayedTaskManager delayed_task_manager_; // SchedulerWorkerPool::Delegate:
void ReEnqueueSequence(scoped_refptr<Sequence> sequence) override {
worker_pool_->ReEnqueueSequence(std::move(sequence));
}
DISALLOW_COPY_AND_ASSIGN(TaskSchedulerWorkerPoolImplTestBase); DISALLOW_COPY_AND_ASSIGN(TaskSchedulerWorkerPoolImplTestBase);
}; };
...@@ -1314,28 +1322,43 @@ TEST_F(TaskSchedulerWorkerPoolBlockingTest, ...@@ -1314,28 +1322,43 @@ TEST_F(TaskSchedulerWorkerPoolBlockingTest,
EXPECT_EQ(worker_pool_->GetMaxTasksForTesting(), kMaxTasks); EXPECT_EQ(worker_pool_->GetMaxTasksForTesting(), kMaxTasks);
} }
// Verify that workers that become idle due to the pool being over capacity will class TaskSchedulerWorkerPoolOverCapacityTest
// eventually cleanup. : public TaskSchedulerWorkerPoolImplTestBase,
TEST(TaskSchedulerWorkerPoolOverCapacityTest, VerifyCleanup) { public testing::Test {
constexpr size_t kLocalMaxTasks = 3; public:
TaskSchedulerWorkerPoolOverCapacityTest() = default;
TaskTracker task_tracker("Test");
DelayedTaskManager delayed_task_manager;
scoped_refptr<TaskRunner> service_thread_task_runner =
MakeRefCounted<TestSimpleTaskRunner>();
delayed_task_manager.Start(service_thread_task_runner);
SchedulerWorkerPoolImpl worker_pool(
"OverCapacityTestWorkerPool", "A", ThreadPriority::NORMAL,
task_tracker.GetTrackedRef(), &delayed_task_manager);
worker_pool.Start(
SchedulerWorkerPoolParams(kLocalMaxTasks, kReclaimTimeForCleanupTests),
kLocalMaxTasks, service_thread_task_runner, nullptr,
SchedulerWorkerPoolImpl::WorkerEnvironment::NONE);
scoped_refptr<TaskRunner> task_runner = void SetUp() override {
worker_pool.CreateTaskRunnerWithTraits( CreateAndStartWorkerPool(kReclaimTimeForCleanupTests, kLocalMaxTasks);
{MayBlock(), WithBaseSyncPrimitives()}); task_runner_ = worker_pool_->CreateTaskRunnerWithTraits(
{MayBlock(), WithBaseSyncPrimitives()});
}
void TearDown() override {
TaskSchedulerWorkerPoolImplTestBase::CommonTearDown();
}
protected:
scoped_refptr<TaskRunner> task_runner_;
static constexpr size_t kLocalMaxTasks = 3;
void CreateWorkerPool() {
ASSERT_FALSE(worker_pool_);
service_thread_.Start();
delayed_task_manager_.Start(service_thread_.task_runner());
worker_pool_ = std::make_unique<SchedulerWorkerPoolImpl>(
"OverCapacityTestWorkerPool", "A", ThreadPriority::NORMAL,
task_tracker_.GetTrackedRef(), &delayed_task_manager_,
tracked_ref_factory_.GetTrackedRef());
ASSERT_TRUE(worker_pool_);
}
DISALLOW_COPY_AND_ASSIGN(TaskSchedulerWorkerPoolOverCapacityTest);
};
// Verify that workers that become idle due to the pool being over capacity will
// eventually cleanup.
TEST_F(TaskSchedulerWorkerPoolOverCapacityTest, VerifyCleanup) {
WaitableEvent threads_running; WaitableEvent threads_running;
WaitableEvent threads_continue; WaitableEvent threads_continue;
RepeatingClosure threads_running_barrier = BarrierClosure( RepeatingClosure threads_running_barrier = BarrierClosure(
...@@ -1357,7 +1380,7 @@ TEST(TaskSchedulerWorkerPoolOverCapacityTest, VerifyCleanup) { ...@@ -1357,7 +1380,7 @@ TEST(TaskSchedulerWorkerPoolOverCapacityTest, VerifyCleanup) {
Unretained(&blocked_call_continue)); Unretained(&blocked_call_continue));
for (size_t i = 0; i < kLocalMaxTasks; ++i) for (size_t i = 0; i < kLocalMaxTasks; ++i)
task_runner->PostTask(FROM_HERE, closure); task_runner_->PostTask(FROM_HERE, closure);
threads_running.Wait(); threads_running.Wait();
...@@ -1369,7 +1392,7 @@ TEST(TaskSchedulerWorkerPoolOverCapacityTest, VerifyCleanup) { ...@@ -1369,7 +1392,7 @@ TEST(TaskSchedulerWorkerPoolOverCapacityTest, VerifyCleanup) {
BindOnce(&WaitableEvent::Signal, Unretained(&extra_threads_running))); BindOnce(&WaitableEvent::Signal, Unretained(&extra_threads_running)));
// These tasks should run on the new threads from increasing max tasks. // These tasks should run on the new threads from increasing max tasks.
for (size_t i = 0; i < kLocalMaxTasks; ++i) { for (size_t i = 0; i < kLocalMaxTasks; ++i) {
task_runner->PostTask( task_runner_->PostTask(
FROM_HERE, BindOnce( FROM_HERE, BindOnce(
[](Closure* extra_threads_running_barrier, [](Closure* extra_threads_running_barrier,
WaitableEvent* extra_threads_continue) { WaitableEvent* extra_threads_continue) {
...@@ -1381,26 +1404,25 @@ TEST(TaskSchedulerWorkerPoolOverCapacityTest, VerifyCleanup) { ...@@ -1381,26 +1404,25 @@ TEST(TaskSchedulerWorkerPoolOverCapacityTest, VerifyCleanup) {
} }
extra_threads_running.Wait(); extra_threads_running.Wait();
ASSERT_EQ(kLocalMaxTasks * 2, worker_pool.NumberOfWorkersForTesting()); ASSERT_EQ(kLocalMaxTasks * 2, worker_pool_->NumberOfWorkersForTesting());
EXPECT_EQ(kLocalMaxTasks * 2, worker_pool.GetMaxTasksForTesting()); EXPECT_EQ(kLocalMaxTasks * 2, worker_pool_->GetMaxTasksForTesting());
blocked_call_continue.Signal(); blocked_call_continue.Signal();
extra_threads_continue.Signal(); extra_threads_continue.Signal();
// Periodically post tasks to ensure that posting tasks does not prevent // Periodically post tasks to ensure that posting tasks does not prevent
// workers that are idle due to the pool being over capacity from cleaning up. // workers that are idle due to the pool being over capacity from cleaning up.
for (int i = 0; i < 16; ++i) { for (int i = 0; i < 16; ++i) {
task_runner->PostDelayedTask(FROM_HERE, DoNothing(), task_runner_->PostDelayedTask(FROM_HERE, DoNothing(),
kReclaimTimeForCleanupTests * i * 0.5); kReclaimTimeForCleanupTests * i * 0.5);
} }
// Note: one worker above capacity will not get cleaned up since it's on the // Note: one worker above capacity will not get cleaned up since it's on the
// top of the idle stack. // top of the idle stack.
worker_pool.WaitForWorkersCleanedUpForTesting(kLocalMaxTasks - 1); worker_pool_->WaitForWorkersCleanedUpForTesting(kLocalMaxTasks - 1);
EXPECT_EQ(kLocalMaxTasks + 1, worker_pool.NumberOfWorkersForTesting()); EXPECT_EQ(kLocalMaxTasks + 1, worker_pool_->NumberOfWorkersForTesting());
threads_continue.Signal(); threads_continue.Signal();
task_tracker_.FlushForTesting();
worker_pool.JoinForTesting();
} }
// Verify that the maximum number of workers is 256 and that hitting the max // Verify that the maximum number of workers is 256 and that hitting the max
......
...@@ -91,10 +91,12 @@ class ThreadPostingTasks : public SimpleThread { ...@@ -91,10 +91,12 @@ class ThreadPostingTasks : public SimpleThread {
}; };
class TaskSchedulerWorkerPoolTest class TaskSchedulerWorkerPoolTest
: public testing::TestWithParam<PoolExecutionType> { : public testing::TestWithParam<PoolExecutionType>,
public SchedulerWorkerPool::Delegate {
protected: protected:
TaskSchedulerWorkerPoolTest() TaskSchedulerWorkerPoolTest()
: service_thread_("TaskSchedulerServiceThread") {} : service_thread_("TaskSchedulerServiceThread"),
tracked_ref_factory_(this) {}
void SetUp() override { void SetUp() override {
service_thread_.Start(); service_thread_.Start();
...@@ -106,6 +108,7 @@ class TaskSchedulerWorkerPoolTest ...@@ -106,6 +108,7 @@ class TaskSchedulerWorkerPoolTest
service_thread_.Stop(); service_thread_.Stop();
if (worker_pool_) if (worker_pool_)
worker_pool_->JoinForTesting(); worker_pool_->JoinForTesting();
worker_pool_.reset();
} }
void CreateWorkerPool() { void CreateWorkerPool() {
...@@ -114,12 +117,14 @@ class TaskSchedulerWorkerPoolTest ...@@ -114,12 +117,14 @@ class TaskSchedulerWorkerPoolTest
case PoolType::GENERIC: case PoolType::GENERIC:
worker_pool_ = std::make_unique<SchedulerWorkerPoolImpl>( worker_pool_ = std::make_unique<SchedulerWorkerPoolImpl>(
"TestWorkerPool", "A", ThreadPriority::NORMAL, "TestWorkerPool", "A", ThreadPriority::NORMAL,
task_tracker_.GetTrackedRef(), &delayed_task_manager_); task_tracker_.GetTrackedRef(), &delayed_task_manager_,
tracked_ref_factory_.GetTrackedRef());
break; break;
#if defined(OS_WIN) #if defined(OS_WIN)
case PoolType::WINDOWS: case PoolType::WINDOWS:
worker_pool_ = std::make_unique<PlatformNativeWorkerPoolWin>( worker_pool_ = std::make_unique<PlatformNativeWorkerPoolWin>(
task_tracker_.GetTrackedRef(), &delayed_task_manager_); task_tracker_.GetTrackedRef(), &delayed_task_manager_,
tracked_ref_factory_.GetTrackedRef());
break; break;
#endif #endif
} }
...@@ -156,6 +161,13 @@ class TaskSchedulerWorkerPoolTest ...@@ -156,6 +161,13 @@ class TaskSchedulerWorkerPoolTest
std::unique_ptr<SchedulerWorkerPool> worker_pool_; std::unique_ptr<SchedulerWorkerPool> worker_pool_;
private: private:
// SchedulerWorkerPool::Delegate:
void ReEnqueueSequence(scoped_refptr<Sequence> sequence) override {
worker_pool_->ReEnqueueSequence(std::move(sequence));
}
TrackedRefFactory<SchedulerWorkerPool::Delegate> tracked_ref_factory_;
DISALLOW_COPY_AND_ASSIGN(TaskSchedulerWorkerPoolTest); DISALLOW_COPY_AND_ASSIGN(TaskSchedulerWorkerPoolTest);
}; };
......
...@@ -63,7 +63,8 @@ TaskSchedulerImpl::TaskSchedulerImpl( ...@@ -63,7 +63,8 @@ TaskSchedulerImpl::TaskSchedulerImpl(
BindRepeating(&TaskSchedulerImpl::ReportHeartbeatMetrics, BindRepeating(&TaskSchedulerImpl::ReportHeartbeatMetrics,
Unretained(this)))), Unretained(this)))),
single_thread_task_runner_manager_(task_tracker_->GetTrackedRef(), single_thread_task_runner_manager_(task_tracker_->GetTrackedRef(),
&delayed_task_manager_) { &delayed_task_manager_),
tracked_ref_factory_(this) {
DCHECK(!histogram_label.empty()); DCHECK(!histogram_label.empty());
static_assert(arraysize(environment_to_worker_pool_) == ENVIRONMENT_COUNT, static_assert(arraysize(environment_to_worker_pool_) == ENVIRONMENT_COUNT,
...@@ -84,7 +85,8 @@ TaskSchedulerImpl::TaskSchedulerImpl( ...@@ -84,7 +85,8 @@ TaskSchedulerImpl::TaskSchedulerImpl(
"."), "."),
kEnvironmentParams[environment_type].name_suffix, kEnvironmentParams[environment_type].name_suffix,
kEnvironmentParams[environment_type].priority_hint, kEnvironmentParams[environment_type].priority_hint,
task_tracker_->GetTrackedRef(), &delayed_task_manager_)); task_tracker_->GetTrackedRef(), &delayed_task_manager_,
tracked_ref_factory_.GetTrackedRef()));
} }
// Map environment indexes to pools. |kMergeBlockingNonBlockingPools| is // Map environment indexes to pools. |kMergeBlockingNonBlockingPools| is
...@@ -107,6 +109,9 @@ TaskSchedulerImpl::~TaskSchedulerImpl() { ...@@ -107,6 +109,9 @@ TaskSchedulerImpl::~TaskSchedulerImpl() {
#if DCHECK_IS_ON() #if DCHECK_IS_ON()
DCHECK(join_for_testing_returned_.IsSet()); DCHECK(join_for_testing_returned_.IsSet());
#endif #endif
// Clear |worker_pools_| to release held TrackedRefs, which block teardown.
worker_pools_.clear();
} }
void TaskSchedulerImpl::Start( void TaskSchedulerImpl::Start(
...@@ -312,6 +317,13 @@ void TaskSchedulerImpl::SetExecutionFenceEnabled(bool execution_fence_enabled) { ...@@ -312,6 +317,13 @@ void TaskSchedulerImpl::SetExecutionFenceEnabled(bool execution_fence_enabled) {
task_tracker_->SetExecutionFenceEnabled(execution_fence_enabled); task_tracker_->SetExecutionFenceEnabled(execution_fence_enabled);
} }
void TaskSchedulerImpl::ReEnqueueSequence(scoped_refptr<Sequence> sequence) {
DCHECK(sequence);
const TaskTraits new_traits =
SetUserBlockingPriorityIfNeeded(sequence->traits());
GetWorkerPoolForTraits(new_traits)->ReEnqueueSequence(std::move(sequence));
}
SchedulerWorkerPoolImpl* TaskSchedulerImpl::GetWorkerPoolForTraits( SchedulerWorkerPoolImpl* TaskSchedulerImpl::GetWorkerPoolForTraits(
const TaskTraits& traits) const { const TaskTraits& traits) const {
return environment_to_worker_pool_[GetEnvironmentIndexForTraits(traits)]; return environment_to_worker_pool_[GetEnvironmentIndexForTraits(traits)];
......
...@@ -45,7 +45,8 @@ namespace internal { ...@@ -45,7 +45,8 @@ namespace internal {
extern const BASE_EXPORT base::Feature kMergeBlockingNonBlockingPools; extern const BASE_EXPORT base::Feature kMergeBlockingNonBlockingPools;
// Default TaskScheduler implementation. This class is thread-safe. // Default TaskScheduler implementation. This class is thread-safe.
class BASE_EXPORT TaskSchedulerImpl : public TaskScheduler { class BASE_EXPORT TaskSchedulerImpl : public TaskScheduler,
public SchedulerWorkerPool::Delegate {
public: public:
using TaskTrackerImpl = using TaskTrackerImpl =
#if defined(OS_POSIX) && !defined(OS_NACL_SFI) #if defined(OS_POSIX) && !defined(OS_NACL_SFI)
...@@ -105,6 +106,9 @@ class BASE_EXPORT TaskSchedulerImpl : public TaskScheduler { ...@@ -105,6 +106,9 @@ class BASE_EXPORT TaskSchedulerImpl : public TaskScheduler {
void ReportHeartbeatMetrics() const; void ReportHeartbeatMetrics() const;
// SchedulerWorkerPool::Delegate:
void ReEnqueueSequence(scoped_refptr<Sequence> sequence) override;
const std::unique_ptr<TaskTrackerImpl> task_tracker_; const std::unique_ptr<TaskTrackerImpl> task_tracker_;
std::unique_ptr<Thread> service_thread_; std::unique_ptr<Thread> service_thread_;
DelayedTaskManager delayed_task_manager_; DelayedTaskManager delayed_task_manager_;
...@@ -135,6 +139,8 @@ class BASE_EXPORT TaskSchedulerImpl : public TaskScheduler { ...@@ -135,6 +139,8 @@ class BASE_EXPORT TaskSchedulerImpl : public TaskScheduler {
base::win::ComInitCheckHook com_init_check_hook_; base::win::ComInitCheckHook com_init_check_hook_;
#endif #endif
TrackedRefFactory<SchedulerWorkerPool::Delegate> tracked_ref_factory_;
DISALLOW_COPY_AND_ASSIGN(TaskSchedulerImpl); DISALLOW_COPY_AND_ASSIGN(TaskSchedulerImpl);
}; };
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment