Commit 63fbbb92 authored by Etienne Pierre-doray's avatar Etienne Pierre-doray Committed by Commit Bot

[ThreadPool]: Implement Job Delegate's ShouldYield and concurrency increase.

This CL partially implement job's delegate, as well as concurrency
usage assertion:
- ShouldYield must be called in worker task.
- max concurrency is expected to decrease unless
  NotifyConcurrencyIncrease() is called.

To implement ShouldYield, PooledTaskRunnerDelegate is used
to communicate with ThreadPool and a racy priority is added to
task source.

Bug: 839091
Change-Id: I3504f00ec48ab600f79b64e44151bc9dded408a8
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/1713146
Commit-Queue: Etienne Pierre-Doray <etiennep@chromium.org>
Reviewed-by: default avatarFrançois Doray <fdoray@chromium.org>
Reviewed-by: default avatarGabriel Charette <gab@chromium.org>
Cr-Commit-Position: refs/heads/master@{#690127}
parent 0d40333a
......@@ -5,16 +5,49 @@
#include "base/task/post_job.h"
#include "base/task/thread_pool/job_task_source.h"
#include "base/task/thread_pool/pooled_task_runner_delegate.h"
namespace base {
namespace experimental {
JobDelegate::JobDelegate(internal::JobTaskSource* task_source)
: task_source_(task_source) {}
JobDelegate::JobDelegate(
internal::JobTaskSource* task_source,
internal::PooledTaskRunnerDelegate* pooled_task_runner_delegate)
: task_source_(task_source),
pooled_task_runner_delegate_(pooled_task_runner_delegate) {
DCHECK(task_source_);
DCHECK(pooled_task_runner_delegate_);
#if DCHECK_IS_ON()
recorded_increase_version_ = task_source_->GetConcurrencyIncreaseVersion();
// Record max concurrency before running the worker task.
recorded_max_concurrency_ = task_source_->GetMaxConcurrency();
#endif // DCHECK_IS_ON()
}
JobDelegate::~JobDelegate() {
#if DCHECK_IS_ON()
// When ShouldYield() returns false, the worker task is expected to do
// work before returning.
size_t expected_max_concurrency = recorded_max_concurrency_;
if (!last_should_yield_ && expected_max_concurrency > 0)
--expected_max_concurrency;
AssertExpectedConcurrency(expected_max_concurrency);
#endif // DCHECK_IS_ON()
}
bool JobDelegate::ShouldYield() {
// TODO(crbug.com/839091): Implement this.
return false;
#if DCHECK_IS_ON()
// ShouldYield() shouldn't be called again after returning true.
DCHECK(!last_should_yield_);
AssertExpectedConcurrency(recorded_max_concurrency_);
#endif // DCHECK_IS_ON()
const bool should_yield =
pooled_task_runner_delegate_->ShouldYield(task_source_);
#if DCHECK_IS_ON()
last_should_yield_ = should_yield;
#endif // DCHECK_IS_ON()
return should_yield;
}
void JobDelegate::YieldIfNeeded() {
......@@ -25,5 +58,45 @@ void JobDelegate::NotifyConcurrencyIncrease() {
task_source_->NotifyConcurrencyIncrease();
}
void JobDelegate::AssertExpectedConcurrency(size_t expected_max_concurrency) {
// In dcheck builds, verify that max concurrency falls in one of the following
// cases:
// 1) max concurrency behaves normally and is below or equals the expected
// value.
// 2) max concurrency increased above the expected value, which implies
// there are new work items that the associated worker task didn't see and
// NotifyConcurrencyIncrease() should be called to adjust the number of
// worker.
// a) NotifyConcurrencyIncrease() was already called and the recorded
// concurrency version is out of date, i.e. less than the actual version.
// b) NotifyConcurrencyIncrease() has not yet been called, in which case the
// function waits for an imminent increase of the concurrency version.
// This prevent ill-formed GetMaxConcurrency() implementations that:
// - Don't decrease with the number of remaining work items.
// - Don't return an up-to-date value.
#if DCHECK_IS_ON()
// Case 1:
const size_t max_concurrency = task_source_->GetMaxConcurrency();
if (max_concurrency <= expected_max_concurrency)
return;
// Case 2a:
const size_t actual_version = task_source_->GetConcurrencyIncreaseVersion();
DCHECK_LE(recorded_increase_version_, actual_version);
if (recorded_increase_version_ < actual_version)
return;
// Case 2b:
const bool updated = task_source_->WaitForConcurrencyIncreaseUpdate(
recorded_increase_version_);
DCHECK(updated)
<< "Value returned by |max_concurrency_callback| is expected to "
"decrease, unless NotifyConcurrencyIncrease() is called.";
recorded_increase_version_ = task_source_->GetConcurrencyIncreaseVersion();
recorded_max_concurrency_ = task_source_->GetMaxConcurrency();
#endif // DCHECK_IS_ON()
}
} // namespace experimental
} // namespace base
\ No newline at end of file
......@@ -6,12 +6,14 @@
#define BASE_TASK_POST_JOB_H_
#include "base/base_export.h"
#include "base/logging.h"
#include "base/macros.h"
#include "base/time/time.h"
namespace base {
namespace internal {
class JobTaskSource;
class PooledTaskRunnerDelegate;
}
namespace experimental {
......@@ -19,7 +21,13 @@ namespace experimental {
// communicate with the scheduler.
class BASE_EXPORT JobDelegate {
public:
explicit JobDelegate(internal::JobTaskSource* task_source);
// A JobDelegate is instantiated for each worker task that is run.
// |task_source| is the task source whose worker task is running with this
// delegate and |pooled_task_runner_delegate| provides communication with the
// thread pool.
JobDelegate(internal::JobTaskSource* task_source,
internal::PooledTaskRunnerDelegate* pooled_task_runner_delegate);
~JobDelegate();
// Returns true if this thread should return from the worker task on the
// current thread ASAP. Workers should periodically invoke ShouldYield (or
......@@ -38,7 +46,25 @@ class BASE_EXPORT JobDelegate {
void NotifyConcurrencyIncrease();
private:
// Verifies that either max concurrency is lower or equal to
// |expected_max_concurrency|, or there is an increase version update
// triggered by NotifyConcurrencyIncrease().
void AssertExpectedConcurrency(size_t expected_max_concurrency);
internal::JobTaskSource* const task_source_;
internal::PooledTaskRunnerDelegate* const pooled_task_runner_delegate_;
#if DCHECK_IS_ON()
// Used in AssertExpectedConcurrency(), see that method's impl for details.
// Value of max concurrency recorded before running the worker task.
size_t recorded_max_concurrency_;
// Value of the increase version recorded before running the worker task.
size_t recorded_increase_version_;
// Value returned by the last call to ShouldYield().
bool last_should_yield_ = false;
#endif
DISALLOW_COPY_AND_ASSIGN(JobDelegate);
};
} // namespace experimental
......
......@@ -11,7 +11,9 @@
#include "base/logging.h"
#include "base/memory/ptr_util.h"
#include "base/task/task_features.h"
#include "base/task/thread_pool/pooled_task_runner_delegate.h"
#include "base/time/time.h"
#include "base/time/time_override.h"
namespace base {
namespace internal {
......@@ -20,7 +22,8 @@ JobTaskSource::JobTaskSource(
const Location& from_here,
const TaskTraits& traits,
RepeatingCallback<void(experimental::JobDelegate*)> worker_task,
RepeatingCallback<size_t()> max_concurrency_callback)
RepeatingCallback<size_t()> max_concurrency_callback,
PooledTaskRunnerDelegate* delegate)
: TaskSource(traits, nullptr, TaskSourceExecutionMode::kJob),
from_here_(from_here),
max_concurrency_callback_(std::move(max_concurrency_callback)),
......@@ -29,14 +32,15 @@ JobTaskSource::JobTaskSource(
const RepeatingCallback<void(experimental::JobDelegate*)>&
worker_task) {
// Each worker task has its own delegate with associated state.
// TODO(crbug.com/839091): Implement assertions on max concurrency
// increase in the delegate.
experimental::JobDelegate job_delegate{self};
experimental::JobDelegate job_delegate{self, self->delegate_};
worker_task.Run(&job_delegate);
},
base::Unretained(this),
std::move(worker_task))),
queue_time_(TimeTicks::Now()) {}
queue_time_(TimeTicks::Now()),
delegate_(delegate) {
DCHECK(delegate_);
}
JobTaskSource::~JobTaskSource() {
#if DCHECK_IS_ON()
......@@ -100,16 +104,49 @@ size_t JobTaskSource::GetRemainingConcurrency() const {
}
void JobTaskSource::NotifyConcurrencyIncrease() {
// TODO(839091): Implement this.
#if DCHECK_IS_ON()
{
AutoLock auto_lock(version_lock_);
++increase_version_;
version_condition_.Broadcast();
}
#endif // DCHECK_IS_ON()
// Make sure the task source is in the queue if not already.
// Caveat: it's possible but unlikely that the task source has already reached
// its intended concurrency and doesn't need to be enqueued if there
// previously were too many worker. For simplicity, the task source is always
// enqueued and will get discarded if already saturated when it is popped from
// the priority queue.
delegate_->EnqueueJobTaskSource(this);
}
size_t JobTaskSource::GetMaxConcurrency() const {
return max_concurrency_callback_.Run();
}
#if DCHECK_IS_ON()
size_t JobTaskSource::GetConcurrencyIncreaseVersion() const {
AutoLock auto_lock(version_lock_);
return increase_version_;
}
bool JobTaskSource::WaitForConcurrencyIncreaseUpdate(size_t recorded_version) {
AutoLock auto_lock(version_lock_);
constexpr TimeDelta timeout = TimeDelta::FromSeconds(1);
const base::TimeTicks start_time = subtle::TimeTicksNowIgnoringOverride();
do {
DCHECK_LE(recorded_version, increase_version_);
if (recorded_version != increase_version_)
return true;
version_condition_.TimedWait(timeout);
} while (subtle::TimeTicksNowIgnoringOverride() - start_time < timeout);
return false;
}
#endif // DCHECK_IS_ON()
Optional<Task> JobTaskSource::TakeTask(TaskSource::Transaction* transaction) {
// JobTaskSource members are not lock-protected so no need to acquire a lock
// if |transaction| is nullptr.
DCHECK_GT(worker_count_.load(std::memory_order_relaxed), 0U);
DCHECK(worker_task_);
return base::make_optional<Task>(from_here_, worker_task_, TimeDelta());
......
......@@ -14,6 +14,8 @@
#include "base/callback.h"
#include "base/macros.h"
#include "base/optional.h"
#include "base/synchronization/condition_variable.h"
#include "base/synchronization/lock.h"
#include "base/task/post_job.h"
#include "base/task/task_traits.h"
#include "base/task/thread_pool/sequence_sort_key.h"
......@@ -23,6 +25,8 @@
namespace base {
namespace internal {
class PooledTaskRunnerDelegate;
// A JobTaskSource generates many Tasks from a single RepeatingClosure.
//
// Derived classes control the intended concurrency with GetMaxConcurrency().
......@@ -31,7 +35,8 @@ class BASE_EXPORT JobTaskSource : public TaskSource {
JobTaskSource(const Location& from_here,
const TaskTraits& traits,
RepeatingCallback<void(experimental::JobDelegate*)> worker_task,
RepeatingCallback<size_t()> max_concurrency_callback);
RepeatingCallback<size_t()> max_concurrency_callback,
PooledTaskRunnerDelegate* delegate);
// Notifies this task source that max concurrency was increased, and the
// number of worker should be adjusted.
......@@ -41,17 +46,23 @@ class BASE_EXPORT JobTaskSource : public TaskSource {
ExecutionEnvironment GetExecutionEnvironment() override;
size_t GetRemainingConcurrency() const override;
// Returns the maximum number of tasks from this TaskSource that can run
// concurrently.
size_t GetMaxConcurrency() const;
#if DCHECK_IS_ON()
size_t GetConcurrencyIncreaseVersion() const;
// Returns true if the concurrency version was updated above
// |recorded_version|, or false on timeout.
bool WaitForConcurrencyIncreaseUpdate(size_t recorded_version);
#endif // DCHECK_IS_ON()
private:
static constexpr size_t kInvalidWorkerCount =
std::numeric_limits<size_t>::max();
~JobTaskSource() override;
// Returns the maximum number of tasks from this TaskSource that can run
// concurrently. The implementation can only return values lower than or equal
// to previously returned values.
size_t GetMaxConcurrency() const;
// TaskSource:
RunStatus WillRunTask() override;
Optional<Task> TakeTask(TaskSource::Transaction* transaction) override;
......@@ -67,6 +78,16 @@ class BASE_EXPORT JobTaskSource : public TaskSource {
base::RepeatingCallback<size_t()> max_concurrency_callback_;
base::RepeatingClosure worker_task_;
const TimeTicks queue_time_;
PooledTaskRunnerDelegate* delegate_;
#if DCHECK_IS_ON()
// Synchronizes accesses to |increase_version_|.
mutable Lock version_lock_;
// Signaled whenever increase_version_ is updated.
ConditionVariable version_condition_{&version_lock_};
// Incremented every time max concurrency is increased.
size_t increase_version_ GUARDED_BY(version_lock_) = 0;
#endif // DCHECK_IS_ON()
DISALLOW_COPY_AND_ASSIGN(JobTaskSource);
};
......
......@@ -8,20 +8,45 @@
#include "base/bind_helpers.h"
#include "base/memory/ptr_util.h"
#include "base/task/thread_pool/pooled_task_runner_delegate.h"
#include "base/task/thread_pool/test_utils.h"
#include "base/test/bind_test_util.h"
#include "base/test/gtest_util.h"
#include "base/test/test_timeouts.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
using ::testing::_;
using ::testing::Return;
namespace base {
namespace internal {
class MockPooledTaskRunnerDelegate : public PooledTaskRunnerDelegate {
public:
MOCK_METHOD2(PostTaskWithSequence,
bool(Task task, scoped_refptr<Sequence> sequence));
MOCK_CONST_METHOD1(ShouldYield, bool(TaskSource* task_source));
MOCK_METHOD1(EnqueueJobTaskSource,
bool(scoped_refptr<JobTaskSource> task_source));
MOCK_CONST_METHOD1(IsRunningPoolWithTraits, bool(const TaskTraits& traits));
MOCK_METHOD2(UpdatePriority,
void(scoped_refptr<TaskSource> task_source,
TaskPriority priority));
};
class ThreadPoolJobTaskSourceTest : public testing::Test {
protected:
testing::StrictMock<MockPooledTaskRunnerDelegate>
pooled_task_runner_delegate_;
};
// Verifies the normal flow of running 2 tasks one after the other.
TEST(ThreadPoolJobTaskSourceTest, RunTasks) {
TEST_F(ThreadPoolJobTaskSourceTest, RunTasks) {
auto job_task = base::MakeRefCounted<test::MockJobTask>(
DoNothing(), /* num_tasks_to_run */ 2);
scoped_refptr<JobTaskSource> task_source = job_task->GetJobTaskSource(
FROM_HERE, {ThreadPool(), TaskPriority::BEST_EFFORT});
FROM_HERE, ThreadPool(), &pooled_task_runner_delegate_);
auto registered_task_source =
RegisteredTaskSource::CreateForTesting(task_source);
......@@ -55,11 +80,11 @@ TEST(ThreadPoolJobTaskSourceTest, RunTasks) {
// Verifies that a job task source doesn't allow any new RunStatus after Clear()
// is called.
TEST(ThreadPoolJobTaskSourceTest, Clear) {
TEST_F(ThreadPoolJobTaskSourceTest, Clear) {
auto job_task = base::MakeRefCounted<test::MockJobTask>(
DoNothing(), /* num_tasks_to_run */ 5);
scoped_refptr<JobTaskSource> task_source = job_task->GetJobTaskSource(
FROM_HERE, {ThreadPool(), TaskPriority::BEST_EFFORT});
FROM_HERE, ThreadPool(), &pooled_task_runner_delegate_);
EXPECT_EQ(5U, task_source->GetRemainingConcurrency());
auto registered_task_source_a =
......@@ -116,11 +141,11 @@ TEST(ThreadPoolJobTaskSourceTest, Clear) {
}
// Verifies that multiple tasks can run in parallel up to |max_concurrency|.
TEST(ThreadPoolJobTaskSourceTest, RunTasksInParallel) {
TEST_F(ThreadPoolJobTaskSourceTest, RunTasksInParallel) {
auto job_task = base::MakeRefCounted<test::MockJobTask>(
DoNothing(), /* num_tasks_to_run */ 2);
scoped_refptr<JobTaskSource> task_source = job_task->GetJobTaskSource(
FROM_HERE, {ThreadPool(), TaskPriority::BEST_EFFORT});
FROM_HERE, ThreadPool(), &pooled_task_runner_delegate_);
auto registered_task_source_a =
RegisteredTaskSource::CreateForTesting(task_source);
......@@ -158,12 +183,138 @@ TEST(ThreadPoolJobTaskSourceTest, RunTasksInParallel) {
EXPECT_FALSE(registered_task_source_c.DidProcessTask());
}
TEST(ThreadPoolJobTaskSourceTest, InvalidTakeTask) {
// Verifies that a call to NotifyConcurrencyIncrease() calls the delegate
// and allows to run additional tasks.
TEST_F(ThreadPoolJobTaskSourceTest, NotifyConcurrencyIncrease) {
auto job_task = base::MakeRefCounted<test::MockJobTask>(
DoNothing(), /* num_tasks_to_run */ 1);
scoped_refptr<JobTaskSource> task_source = job_task->GetJobTaskSource(
FROM_HERE, ThreadPool(), &pooled_task_runner_delegate_);
auto registered_task_source_a =
RegisteredTaskSource::CreateForTesting(task_source);
EXPECT_EQ(registered_task_source_a.WillRunTask(),
TaskSource::RunStatus::kAllowedSaturated);
auto task_a = registered_task_source_a.TakeTask();
EXPECT_EQ(RegisteredTaskSource::CreateForTesting(task_source).WillRunTask(),
TaskSource::RunStatus::kDisallowed);
job_task->SetNumTasksToRun(2);
EXPECT_CALL(pooled_task_runner_delegate_, EnqueueJobTaskSource(_)).Times(1);
task_source->NotifyConcurrencyIncrease();
auto registered_task_source_b =
RegisteredTaskSource::CreateForTesting(task_source);
// WillRunTask() should return a valid RunStatus because max concurrency was
// increased to 2.
EXPECT_EQ(registered_task_source_b.WillRunTask(),
TaskSource::RunStatus::kAllowedSaturated);
auto task_b = registered_task_source_b.TakeTask();
EXPECT_EQ(RegisteredTaskSource::CreateForTesting(task_source).WillRunTask(),
TaskSource::RunStatus::kDisallowed);
std::move(task_a->task).Run();
EXPECT_FALSE(registered_task_source_a.DidProcessTask());
std::move(task_b->task).Run();
EXPECT_FALSE(registered_task_source_b.DidProcessTask());
}
// Verifies that ShouldYield() calls the delegate.
TEST_F(ThreadPoolJobTaskSourceTest, ShouldYield) {
auto job_task = base::MakeRefCounted<test::MockJobTask>(
BindLambdaForTesting([](experimental::JobDelegate* delegate) {
// As set up below, the mock will return false once and true the second
// time.
EXPECT_FALSE(delegate->ShouldYield());
EXPECT_TRUE(delegate->ShouldYield());
}),
/* num_tasks_to_run */ 1);
scoped_refptr<JobTaskSource> task_source = job_task->GetJobTaskSource(
FROM_HERE, ThreadPool(), &pooled_task_runner_delegate_);
auto registered_task_source =
RegisteredTaskSource::CreateForTesting(task_source);
ASSERT_EQ(registered_task_source.WillRunTask(),
TaskSource::RunStatus::kAllowedSaturated);
auto task = registered_task_source.TakeTask();
EXPECT_CALL(pooled_task_runner_delegate_, ShouldYield(_))
.Times(2)
.WillOnce(Return(false))
.WillOnce(Return(true));
std::move(task->task).Run();
EXPECT_FALSE(registered_task_source.DidProcessTask());
}
// Verifies that max concurrency is allowed to stagnate when ShouldYield returns
// true.
TEST_F(ThreadPoolJobTaskSourceTest, MaxConcurrencyStagnateIfShouldYield) {
scoped_refptr<JobTaskSource> task_source =
base::MakeRefCounted<JobTaskSource>(
FROM_HERE, ThreadPool(),
BindRepeating([](experimental::JobDelegate* delegate) {
// As set up below, the mock will return true once.
ASSERT_TRUE(delegate->ShouldYield());
}),
BindRepeating([]() -> size_t {
return 1; // max concurrency is always 1.
}),
&pooled_task_runner_delegate_);
EXPECT_CALL(pooled_task_runner_delegate_, ShouldYield(_))
.WillOnce(Return(true));
auto registered_task_source =
RegisteredTaskSource::CreateForTesting(task_source);
ASSERT_EQ(registered_task_source.WillRunTask(),
TaskSource::RunStatus::kAllowedSaturated);
auto task = registered_task_source.TakeTask();
// Running the task should not fail even though max concurrency remained at 1,
// since ShouldYield() returned true.
std::move(task->task).Run();
registered_task_source.DidProcessTask();
}
// Verifies that a missing call to NotifyConcurrencyIncrease() causes a DCHECK
// death after a timeout.
TEST_F(ThreadPoolJobTaskSourceTest, InvalidConcurrency) {
testing::FLAGS_gtest_death_test_style = "threadsafe";
scoped_refptr<test::MockJobTask> job_task;
job_task = base::MakeRefCounted<test::MockJobTask>(
BindLambdaForTesting([&](experimental::JobDelegate* delegate) {
EXPECT_FALSE(delegate->ShouldYield());
job_task->SetNumTasksToRun(2);
EXPECT_FALSE(delegate->ShouldYield());
// After returning, a DCHECK should trigger because we never called
// NotifyConcurrencyIncrease().
}),
/* num_tasks_to_run */ 1);
scoped_refptr<JobTaskSource> task_source = job_task->GetJobTaskSource(
FROM_HERE, ThreadPool(), &pooled_task_runner_delegate_);
auto registered_task_source =
RegisteredTaskSource::CreateForTesting(task_source);
ASSERT_EQ(registered_task_source.WillRunTask(),
TaskSource::RunStatus::kAllowedSaturated);
auto task = registered_task_source.TakeTask();
EXPECT_DCHECK_DEATH(std::move(task->task).Run());
registered_task_source.DidProcessTask();
}
TEST_F(ThreadPoolJobTaskSourceTest, InvalidTakeTask) {
auto job_task =
base::MakeRefCounted<test::MockJobTask>(DoNothing(),
/* num_tasks_to_run */ 1);
scoped_refptr<JobTaskSource> task_source = job_task->GetJobTaskSource(
FROM_HERE, {ThreadPool(), TaskPriority::BEST_EFFORT});
FROM_HERE, ThreadPool(), &pooled_task_runner_delegate_);
auto registered_task_source_a =
RegisteredTaskSource::CreateForTesting(task_source);
......@@ -182,12 +333,12 @@ TEST(ThreadPoolJobTaskSourceTest, InvalidTakeTask) {
registered_task_source_a.DidProcessTask();
}
TEST(ThreadPoolJobTaskSourceTest, InvalidDidProcessTask) {
TEST_F(ThreadPoolJobTaskSourceTest, InvalidDidProcessTask) {
auto job_task =
base::MakeRefCounted<test::MockJobTask>(DoNothing(),
/* num_tasks_to_run */ 1);
scoped_refptr<JobTaskSource> task_source = job_task->GetJobTaskSource(
FROM_HERE, {ThreadPool(), TaskPriority::BEST_EFFORT});
FROM_HERE, ThreadPool(), &pooled_task_runner_delegate_);
auto registered_task_source =
RegisteredTaskSource::CreateForTesting(task_source);
......
......@@ -27,6 +27,11 @@ class BASE_EXPORT PooledTaskRunnerDelegate {
// outlives the ThreadPoolInstance that created it.
static bool Exists();
// Returns true if |task_source| currently running must return ASAP.
// Thread-safe but may return an outdated result (if a task unnecessarily
// yields due to this, it will simply be re-scheduled).
virtual bool ShouldYield(TaskSource* task_source) const = 0;
// Invoked when a |task| is posted to the PooledParallelTaskRunner or
// PooledSequencedTaskRunner. The implementation must post |task| to
// |sequence| within the appropriate priority queue, depending on |sequence|
......@@ -36,7 +41,8 @@ class BASE_EXPORT PooledTaskRunnerDelegate {
// Invoked when a task is posted as a Job. The implementation must add
// |task_source| to the appropriate priority queue, depending on |task_source|
// traits. Returns true if task source was successfully enqueued.
// traits, if it's not there already. Returns true if task source was
// successfully enqueued or was already enqueued.
virtual bool EnqueueJobTaskSource(
scoped_refptr<JobTaskSource> task_source) = 0;
......
......@@ -40,6 +40,8 @@ void TaskSource::Transaction::UpdatePriority(TaskPriority priority) {
if (FeatureList::IsEnabled(kAllTasksUserBlocking))
return;
task_source_->traits_.UpdatePriority(priority);
task_source_->priority_racy_.store(task_source_->traits_.priority(),
std::memory_order_relaxed);
}
void TaskSource::SetHeapHandle(const HeapHandle& handle) {
......@@ -54,6 +56,7 @@ TaskSource::TaskSource(const TaskTraits& traits,
TaskRunner* task_runner,
TaskSourceExecutionMode execution_mode)
: traits_(traits),
priority_racy_(traits.priority()),
task_runner_(task_runner),
execution_mode_(execution_mode) {
DCHECK(task_runner_ ||
......
......@@ -160,6 +160,14 @@ class BASE_EXPORT TaskSource : public RefCountedThreadSafe<TaskSource> {
TaskShutdownBehavior shutdown_behavior() const {
return traits_.shutdown_behavior();
}
// Returns a racy priority of the TaskSource. Can be accessed without a
// Transaction but may return an outdated result.
TaskPriority priority_racy() const {
return priority_racy_.load(std::memory_order_relaxed);
}
// Returns the thread policy of the TaskSource. Can be accessed without a
// Transaction because it is never mutated.
ThreadPolicy thread_policy() const { return traits_.thread_policy(); }
// A reference to TaskRunner is only retained between PushTask() and when
// DidProcessTask() returns false, guaranteeing it is safe to dereference this
......@@ -193,6 +201,9 @@ class BASE_EXPORT TaskSource : public RefCountedThreadSafe<TaskSource> {
// The TaskTraits of all Tasks in the TaskSource.
TaskTraits traits_;
// The cached priority for atomic access.
std::atomic<TaskPriority> priority_racy_;
// Synchronizes access to all members.
mutable CheckedLock lock_{UniversalPredecessor()};
......
......@@ -56,8 +56,8 @@ bool MockJobTaskRunner::PostDelayedTask(const Location& from_here,
return false;
auto job_task = base::MakeRefCounted<MockJobTask>(std::move(closure));
scoped_refptr<JobTaskSource> task_source =
job_task->GetJobTaskSource(from_here, traits_);
scoped_refptr<JobTaskSource> task_source = job_task->GetJobTaskSource(
from_here, traits_, pooled_task_runner_delegate_);
return pooled_task_runner_delegate_->EnqueueJobTaskSource(
std::move(task_source));
}
......@@ -217,6 +217,10 @@ void MockPooledTaskRunnerDelegate::PostTaskWithSequenceNow(
}
}
bool MockPooledTaskRunnerDelegate::ShouldYield(TaskSource* task_source) const {
return thread_group_->ShouldYield(task_source->priority_racy());
}
bool MockPooledTaskRunnerDelegate::EnqueueJobTaskSource(
scoped_refptr<JobTaskSource> task_source) {
// |thread_group_| must be initialized with SetThreadGroup() before
......@@ -277,8 +281,6 @@ size_t MockJobTask::GetMaxConcurrency() const {
}
void MockJobTask::Run(experimental::JobDelegate* delegate) {
if (delegate->ShouldYield())
return;
worker_task_.Run(delegate);
size_t before = remaining_num_tasks_to_run_.fetch_sub(1);
DCHECK_GT(before, 0U);
......@@ -286,10 +288,12 @@ void MockJobTask::Run(experimental::JobDelegate* delegate) {
scoped_refptr<JobTaskSource> MockJobTask::GetJobTaskSource(
const Location& from_here,
const TaskTraits& traits) {
const TaskTraits& traits,
PooledTaskRunnerDelegate* delegate) {
return MakeRefCounted<JobTaskSource>(
from_here, traits, base::BindRepeating(&test::MockJobTask::Run, this),
base::BindRepeating(&test::MockJobTask::GetMaxConcurrency, this));
base::BindRepeating(&test::MockJobTask::GetMaxConcurrency, this),
delegate);
}
RegisteredTaskSource QueueAndRunTaskSource(
......
......@@ -62,6 +62,7 @@ class MockPooledTaskRunnerDelegate : public PooledTaskRunnerDelegate {
bool PostTaskWithSequence(Task task,
scoped_refptr<Sequence> sequence) override;
bool EnqueueJobTaskSource(scoped_refptr<JobTaskSource> task_source) override;
bool ShouldYield(TaskSource* task_source) const override;
bool IsRunningPoolWithTraits(const TaskTraits& traits) const override;
void UpdatePriority(scoped_refptr<TaskSource> task_source,
TaskPriority priority) override;
......@@ -81,8 +82,6 @@ class MockPooledTaskRunnerDelegate : public PooledTaskRunnerDelegate {
class MockJobTask : public base::RefCountedThreadSafe<MockJobTask> {
public:
// Gives |worker_task| to requesting workers |num_tasks_to_run| times.
// ShouldYield() is automatically called on JobDelegate before running
// |worker_task| so that DoNothing() may be passed.
MockJobTask(
base::RepeatingCallback<void(experimental::JobDelegate*)> worker_task,
size_t num_tasks_to_run);
......@@ -99,8 +98,10 @@ class MockJobTask : public base::RefCountedThreadSafe<MockJobTask> {
size_t GetMaxConcurrency() const;
void Run(experimental::JobDelegate* delegate);
scoped_refptr<JobTaskSource> GetJobTaskSource(const Location& from_here,
const TaskTraits& traits);
scoped_refptr<JobTaskSource> GetJobTaskSource(
const Location& from_here,
const TaskTraits& traits,
PooledTaskRunnerDelegate* delegate);
private:
friend class base::RefCountedThreadSafe<MockJobTask>;
......
......@@ -311,7 +311,8 @@ TEST_F(ThreadGroupImplImplTest, ShouldYieldFloodedUserVisible) {
}),
/* num_tasks_to_run */ kMaxTasks);
scoped_refptr<JobTaskSource> task_source = job_task->GetJobTaskSource(
FROM_HERE, {ThreadPool(), TaskPriority::USER_VISIBLE});
FROM_HERE, {ThreadPool(), TaskPriority::USER_VISIBLE},
&mock_pooled_task_runner_delegate_);
auto registered_task_source = task_tracker_.RegisterTaskSource(task_source);
ASSERT_TRUE(registered_task_source);
......
......@@ -596,8 +596,8 @@ TEST_P(ThreadGroupTest, ScheduleJobTaskSource) {
test::WaitWithoutBlockingObserver(&threads_continue);
}),
/* num_tasks_to_run */ kMaxTasks);
scoped_refptr<JobTaskSource> task_source =
job_task->GetJobTaskSource(FROM_HERE, {ThreadPool()});
scoped_refptr<JobTaskSource> task_source = job_task->GetJobTaskSource(
FROM_HERE, ThreadPool(), &mock_pooled_task_runner_delegate_);
auto registered_task_source =
task_tracker_.RegisterTaskSource(std::move(task_source));
......@@ -614,6 +614,95 @@ TEST_P(ThreadGroupTest, ScheduleJobTaskSource) {
task_tracker_.FlushForTesting();
}
// Verify that tasks from a JobTaskSource run at the intended concurrency.
TEST_P(ThreadGroupTest, ScheduleJobTaskSourceMultipleTime) {
StartThreadGroup();
WaitableEvent thread_running;
WaitableEvent thread_continue;
auto job_task = base::MakeRefCounted<test::MockJobTask>(
BindLambdaForTesting(
[&thread_running, &thread_continue](experimental::JobDelegate*) {
DCHECK(!thread_running.IsSignaled());
thread_running.Signal();
test::WaitWithoutBlockingObserver(&thread_continue);
}),
/* num_tasks_to_run */ 1);
scoped_refptr<JobTaskSource> task_source = job_task->GetJobTaskSource(
FROM_HERE, ThreadPool(), &mock_pooled_task_runner_delegate_);
thread_group_->PushTaskSourceAndWakeUpWorkers(
TransactionWithRegisteredTaskSource::FromTaskSource(
task_tracker_.RegisterTaskSource(task_source)));
// Enqueuing the task source again shouldn't affect the number of time it's
// run.
thread_group_->PushTaskSourceAndWakeUpWorkers(
TransactionWithRegisteredTaskSource::FromTaskSource(
task_tracker_.RegisterTaskSource(task_source)));
thread_running.Wait();
thread_continue.Signal();
// Once the worker task ran, enqueuing the task source has no effect.
thread_group_->PushTaskSourceAndWakeUpWorkers(
TransactionWithRegisteredTaskSource::FromTaskSource(
task_tracker_.RegisterTaskSource(task_source)));
// Flush the task tracker to be sure that no local variables are accessed by
// tasks after the end of the scope.
task_tracker_.FlushForTesting();
}
// Verify that calling JobTaskSource::NotifyConcurrencyIncrease() (re-)schedule
// tasks with the intended concurrency.
TEST_P(ThreadGroupTest, JobTaskSourceConcurrencyIncrease) {
StartThreadGroup();
WaitableEvent threads_running_a;
WaitableEvent threads_continue;
// Initially schedule half the tasks.
RepeatingClosure threads_running_barrier = BarrierClosure(
kMaxTasks / 2,
BindOnce(&WaitableEvent::Signal, Unretained(&threads_running_a)));
auto job_state = base::MakeRefCounted<test::MockJobTask>(
BindLambdaForTesting([&threads_running_barrier,
&threads_continue](experimental::JobDelegate*) {
threads_running_barrier.Run();
test::WaitWithoutBlockingObserver(&threads_continue);
}),
/* num_tasks_to_run */ kMaxTasks / 2);
auto task_source = job_state->GetJobTaskSource(
FROM_HERE, ThreadPool(), &mock_pooled_task_runner_delegate_);
auto registered_task_source = task_tracker_.RegisterTaskSource(task_source);
EXPECT_TRUE(registered_task_source);
thread_group_->PushTaskSourceAndWakeUpWorkers(
TransactionWithRegisteredTaskSource::FromTaskSource(
std::move(registered_task_source)));
threads_running_a.Wait();
// Reset |threads_running_barrier| for the remaining tasks.
WaitableEvent threads_running_b;
threads_running_barrier = BarrierClosure(
kMaxTasks / 2,
BindOnce(&WaitableEvent::Signal, Unretained(&threads_running_b)));
job_state->SetNumTasksToRun(kMaxTasks);
// Unblocks tasks to let them racily wait for NotifyConcurrencyIncrease() to
// be called.
threads_continue.Signal();
task_source->NotifyConcurrencyIncrease();
// Wait for the remaining tasks. This should not block forever.
threads_running_b.Wait();
// Flush the task tracker to be sure that no local variables are accessed by
// tasks after the end of the scope.
task_tracker_.FlushForTesting();
}
// Verify that a JobTaskSource that becomes empty while in the queue eventually
// gets discarded.
TEST_P(ThreadGroupTest, ScheduleEmptyJobTaskSource) {
......@@ -624,8 +713,8 @@ TEST_P(ThreadGroupTest, ScheduleEmptyJobTaskSource) {
auto job_task = base::MakeRefCounted<test::MockJobTask>(
BindRepeating([](experimental::JobDelegate*) { ShouldNotRun(); }),
/* num_tasks_to_run */ 1);
scoped_refptr<JobTaskSource> task_source =
job_task->GetJobTaskSource(FROM_HERE, {ThreadPool()});
scoped_refptr<JobTaskSource> task_source = job_task->GetJobTaskSource(
FROM_HERE, ThreadPool(), &mock_pooled_task_runner_delegate_);
auto registered_task_source =
task_tracker_.RegisterTaskSource(std::move(task_source));
......@@ -674,7 +763,8 @@ TEST_P(ThreadGroupTest, JobTaskSourceUpdatePriority) {
}),
/* num_tasks_to_run */ kMaxTasks);
scoped_refptr<JobTaskSource> task_source = job_task->GetJobTaskSource(
FROM_HERE, {ThreadPool(), TaskPriority::BEST_EFFORT});
FROM_HERE, {ThreadPool(), TaskPriority::BEST_EFFORT},
&mock_pooled_task_runner_delegate_);
auto registered_task_source = task_tracker_.RegisterTaskSource(task_source);
EXPECT_TRUE(registered_task_source);
......
......@@ -18,6 +18,7 @@
#include "base/metrics/field_trial_params.h"
#include "base/stl_util.h"
#include "base/strings/string_util.h"
#include "base/task/scoped_set_task_priority_for_current_thread.h"
#include "base/task/task_features.h"
#include "base/task/thread_pool/pooled_parallel_task_runner.h"
#include "base/task/thread_pool/pooled_sequenced_task_runner.h"
......@@ -396,6 +397,18 @@ bool ThreadPoolImpl::PostTaskWithSequence(Task task,
return true;
}
bool ThreadPoolImpl::ShouldYield(TaskSource* task_source) const {
const TaskPriority priority = task_source->priority_racy();
auto* const thread_group =
GetThreadGroupForTraits({priority, task_source->thread_policy()});
// A task whose priority changed and is now running in the wrong thread group
// should yield so it's rescheduled in the right one.
if (!thread_group->IsBoundToCurrentThread())
return true;
return GetThreadGroupForTraits({priority, task_source->thread_policy()})
->ShouldYield(priority);
}
bool ThreadPoolImpl::EnqueueJobTaskSource(
scoped_refptr<JobTaskSource> task_source) {
auto registered_task_source =
......
......@@ -103,6 +103,8 @@ class BASE_EXPORT ThreadPoolImpl : public ThreadPoolInstance,
// PooledTaskRunnerDelegate:
bool EnqueueJobTaskSource(scoped_refptr<JobTaskSource> task_source) override;
void UpdatePriority(scoped_refptr<TaskSource> task_source,
TaskPriority priority) override;
// Returns the TimeTicks of the next task scheduled on ThreadPool (Now() if
// immediate, nullopt if none). This is thread-safe, i.e., it's safe if tasks
......@@ -139,8 +141,7 @@ class BASE_EXPORT ThreadPoolImpl : public ThreadPoolInstance,
bool PostTaskWithSequence(Task task,
scoped_refptr<Sequence> sequence) override;
bool IsRunningPoolWithTraits(const TaskTraits& traits) const override;
void UpdatePriority(scoped_refptr<TaskSource> task_source,
TaskPriority priority) override;
bool ShouldYield(TaskSource* task_source) const override;
const std::unique_ptr<TaskTrackerImpl> task_tracker_;
std::unique_ptr<Thread> service_thread_;
......
......@@ -1087,12 +1087,46 @@ TEST_P(ThreadPoolImplTest, ScheduleJobTaskSource) {
}),
/* num_tasks_to_run */ 1);
scoped_refptr<JobTaskSource> task_source =
job_task->GetJobTaskSource(FROM_HERE, {ThreadPool()});
job_task->GetJobTaskSource(FROM_HERE, ThreadPool(), thread_pool_.get());
thread_pool_->EnqueueJobTaskSource(task_source);
threads_running.Wait();
}
// Verify that calling ShouldYield() returns true for a job task source that
// needs to change thread group because of a priority update.
TEST_P(ThreadPoolImplTest, ThreadGroupChangeShouldYield) {
StartThreadPool();
WaitableEvent threads_running;
WaitableEvent threads_continue;
auto job_task = base::MakeRefCounted<test::MockJobTask>(
BindLambdaForTesting([&threads_running, &threads_continue](
experimental::JobDelegate* delegate) {
EXPECT_FALSE(delegate->ShouldYield());
threads_running.Signal();
test::WaitWithoutBlockingObserver(&threads_continue);
// The task source needs to yield if background thread groups exist.
EXPECT_EQ(delegate->ShouldYield(),
CanUseBackgroundPriorityForWorkerThread());
}),
/* num_tasks_to_run */ 1);
scoped_refptr<JobTaskSource> task_source = job_task->GetJobTaskSource(
FROM_HERE, TaskPriority::USER_VISIBLE, thread_pool_.get());
thread_pool_->EnqueueJobTaskSource(task_source);
threads_running.Wait();
thread_pool_->UpdatePriority(task_source, TaskPriority::BEST_EFFORT);
threads_continue.Signal();
// Flush the task tracker to be sure that no local variables are accessed by
// tasks after the end of the scope.
thread_pool_->FlushForTesting();
}
namespace {
class MustBeDestroyed {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment