Commit fb5abd41 authored by Simeon Anfinrud's avatar Simeon Anfinrud Committed by Commit Bot

[Chromecast] Type-safe static sequences.

With this, you can make the compiler ensure, statically, that
the methods of your class always run on the same sequence.
This is vastly safer than runtime checks like SequenceChecker
since it will fail to compile if you violate the requirement.

(It wouldn't be C++ if there weren't hacky ways around this, but
any attempt to trick the compiler into allowing a sequence
violation should look pretty obvious to reviewers.)

To use this, declare a struct that inherits from StaticSequence.
This struct will automatically declare a Key class that can only
be constructed inside the StaticSequence's PostTask function. To
force users to go through that StaticSequence's PostTask() to
call your method, simply declare a reference to the Key object
as the last parameter to the method.

This also includes a wrapper template called Sequenced, similar
to base::SequenceBound, but with the TaskRunner known at compile
time rather than runtime. This can be used to add an extra level
of safety over runtime checks like
DCHECK_CALLED_ON_VALID_SEQUENCE or
TaskRunner::RunsTasksInCurrentThread. This wrapper ensures the
wrapped object is destroyed on the correct sequence, essentially
turning any sequence-affine object to a thread-safe object.

This will be made even more useful once base::PostTask() returns
a base::Promise, as it will work with methods that return values
as well as void methods.

Bug: None
Test: cast_base_unittests
Change-Id: Ic408e343bc084c19c4c6f9e983b343b98502daf2
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/1863549Reviewed-by: default avatarKenneth MacKay <kmackay@chromium.org>
Reviewed-by: default avatarYuchen Liu <yucliu@chromium.org>
Commit-Queue: Simeon Anfinrud <sanfin@chromium.org>
Cr-Commit-Position: refs/heads/master@{#711007}
parent a375a850
......@@ -185,6 +185,7 @@ test("cast_base_unittests") {
":thread_health_checker",
"//base/test:run_all_unittests",
"//base/test:test_support",
"//chromecast/base/static_sequence:tests",
"//testing/gmock",
"//testing/gtest",
]
......
# Copyright 2019 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
import("//build/nocompile.gni")
import("//chromecast/chromecast.gni")
cast_source_set("static_sequence") {
sources = [
"static_sequence.cc",
"static_sequence.h",
]
deps = [
"//base",
]
}
cast_source_set("tests") {
testonly = true
sources = [
"static_sequence_unittest.cc",
]
deps = [
":static_sequence",
"//base/test:test_support",
"//testing/gtest",
]
if (enable_nocompile_tests) {
deps += [ ":nocompile_tests" ]
}
}
if (enable_nocompile_tests) {
nocompile_test("nocompile_tests") {
sources = [
"static_sequence_unittest.nc",
]
deps = [
":static_sequence",
"//base/test:run_all_unittests",
"//testing/gtest",
]
}
}
sanfin@chromium.org
thoren@chromium.org
// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chromecast/base/static_sequence/static_sequence.h"
namespace util {
namespace internal {
StaticTaskRunnerHolder::StaticTaskRunnerHolder(base::TaskTraits traits)
: traits_(traits), initialized_(false) {}
StaticTaskRunnerHolder::~StaticTaskRunnerHolder() = default;
void StaticTaskRunnerHolder::WillDestroyCurrentMessageLoop() {
initialized_ = false;
task_runner_ = nullptr;
}
const scoped_refptr<base::SequencedTaskRunner>& StaticTaskRunnerHolder::Get() {
if (!initialized_) {
task_runner_ = base::CreateSequencedTaskRunner(traits_);
base::MessageLoopCurrent::Get().AddDestructionObserver(this);
initialized_ = true;
}
return task_runner_;
}
} // namespace internal
} // namespace util
// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef CHROMECAST_BASE_STATIC_SEQUENCE_STATIC_SEQUENCE_H_
#define CHROMECAST_BASE_STATIC_SEQUENCE_STATIC_SEQUENCE_H_
#include <memory>
#include <utility>
#include "base/callback_forward.h"
#include "base/location.h"
#include "base/memory/ref_counted.h"
#include "base/message_loop/message_loop_current.h"
#include "base/no_destructor.h"
#include "base/task/post_task.h"
// Allows sequences to be defined at compile time so that objects can opt into
// requiring that their methods are called on a specific sequence in a way that
// can be checked by the compiler rather than DCHECKs.
//
// To define a sequence, just create a class that extends this one using the
// Curiously Recurring Template Pattern:
//
// struct MySequence : util::StaticSequence<MySequence> {};
//
// To require that a function run on that sequence, add a Key parameter from the
// sequence:
//
// void MyFunction(int x, int y, MySequence::Key&);
//
// Such a function must be called through the MySequence's PostTask() method:
//
// // Can run on any thread.
// void MyFunctionThreadSafe(int x, int y) {
// MySequence::PostTask(FROM_HERE, base::BindOnce(&MyFunction, x, y));
// }
//
// You can also add the Key as the final parameter to instance methods to
// similarly require that the method be called on the sequence:
//
// struct MyStruct {
// // The Key needs to be the last parameter!
// void MyMethod(int x, int y, MySequence::Key&);
// };
//
// void CallMyMethodFromOriginThreadSafe(MyStruct* m) {
// MySequence::PostTask(
// FROM_HERE,
// base::BindOnce(&MyStruct::MyMethod, base::Unretained(m), 0, 0));
// }
//
// If a class is tightly coupled to a given sequence (i.e. expects to always be
// called on that sequence), it may be worth wrapping in Sequenced, which is
// similar to base::SequenceBound but will work with statically-sequenced
// method calls. This will also ensure the destructor is run on the same
// sequence.
namespace util {
template <typename T, typename TraitsProvider>
class StaticSequence;
namespace internal {
// Provides a TaskRunner and can persist after the message loop is destroyed,
// which is useful if e.g. a StaticTaskRunnerHolder outlives a
// base::test::TaskEnvironment in tests. Only usable by StaticSequence.
class StaticTaskRunnerHolder
: public base::MessageLoopCurrent::DestructionObserver {
public:
~StaticTaskRunnerHolder() override;
private:
template <typename T, typename TraitsProvider>
friend class ::util::StaticSequence;
explicit StaticTaskRunnerHolder(base::TaskTraits traits);
void WillDestroyCurrentMessageLoop() override;
const scoped_refptr<base::SequencedTaskRunner>& Get();
const base::TaskTraits traits_;
bool initialized_;
scoped_refptr<base::SequencedTaskRunner> task_runner_;
};
} // namespace internal
// Default traits for a static sequence. They can be overridden by specifying
// another struct with a GetTraits() static method as the second template
// parameter to StaticSequence.
//
// Example:
//
// class MyBackgroundService {
// struct BackgroundTaskTraitsProvider {
// static constexpr base::TaskTraits GetTraits() {
// return {
// base::ThreadPool(),
// base::TaskPriority::BEST_EFFORT,
// base::MayBlock(),
// };
// }
// };
// public:
// struct BackgroundSequence
// : util::StaticSequence<BackgroundSequence,
// BackgroundTaskTraitsProvider> {};
// void DoBackgroundWork(const std::string& request,
// BackgroundSequence::Key&);
// };
struct DefaultStaticSequenceTraitsProvider {
static constexpr base::TaskTraits GetTraits() { return {base::ThreadPool()}; }
};
// A class that extends StaticSequence is a holder for a process-global
// TaskRunner that is created on-demand with the desired traits, which also
// provides static PostTask overloads that can take callbacks that require a
// special Key that only the StaticSequence can provide. This trick is what
// guarantees at compile time that all invocations of a statically-sequenced
// function are run on the correct TaskRunner.
template <typename T,
typename TraitsProvider = DefaultStaticSequenceTraitsProvider>
class StaticSequence {
public:
// Can only be constructed by the StaticSequence implementation. This
// restriction allows functions and methods to statically assert that they are
// being called on the correct sequence because StaticSequences will only
// provide a reference to its Key through their PostTask() method.
//
// The reference can be passed around, but the key itself cannot be copied or
// moved, and the address cannot be taken.
class Key {
public:
using Sequence = T;
~Key() = default;
private:
friend class StaticSequence;
constexpr Key() = default;
// Cannot copy, move, or take the address of a Key. This prevents the common
// ways one might attempt to obtain a Key outside the scope where it is
// valid.
Key(const Key&) = delete;
Key& operator=(const Key&) = delete;
const Key* operator&() const = delete;
};
static const scoped_refptr<base::SequencedTaskRunner>& TaskRunner() {
// A StaticTaskRunnerHolder is able to regenerate a TaskRunner after the
// global thread pool is destroyed and re-created (which can happen between
// unittests that use base::test::TaskEnvironment).
static internal::StaticTaskRunnerHolder task_runner(
TraitsProvider::GetTraits());
return task_runner.Get();
}
// Catches you if you attempt to post a callback that consumes a Key of
// another StaticSequence. The compiler will print a message containing
// PostedTo, the StaticSequence whose PostTask method was called; and
// Expected, the StaticSequence whose Key was requested by the task.
template <typename U>
using IncompatibleCallback = base::OnceCallback<void(U&)>;
template <typename U, typename Expected = typename U::Sequence>
static void PostTask(
IncompatibleCallback<U> cb,
const base::Location& from_here = base::Location::Current()) {
using PostedTo = T;
static_assert(invalid<PostedTo, Expected>,
"Attempting to post a statically-sequenced task to the wrong "
"static sequence!");
}
// Takes a callback that specifically requires that it be invoked from this
// sequence. Such callbacks can only be invoked through this method because
// the Key is only constructible here.
using CompatibleCallback = base::OnceCallback<void(Key&)>;
static void PostTask(
CompatibleCallback cb,
const base::Location& from_here = base::Location::Current()) {
TaskRunner()->PostTask(from_here,
base::BindOnce(std::move(cb), std::ref(key_)));
}
// Takes any closure with no unbound arguments.
static void PostTask(
base::OnceClosure cb,
const base::Location& from_here = base::Location::Current()) {
TaskRunner()->PostTask(from_here, std::move(cb));
}
// The Run() overload set can only be invoked on the sequence, and accepts
// callbacks that may or may not require a Key to the sequence.
static void Run(CompatibleCallback cb, Key& key) { std::move(cb).Run(key); }
static void Run(base::OnceClosure cb, Key&) { std::move(cb).Run(); }
template <typename U, typename Expected = typename U::Sequence>
static void Run(IncompatibleCallback<U> cb, Key&) {
using PostedTo = T;
static_assert(invalid<PostedTo, Expected>,
"Attempting to post a statically-sequenced task to the wrong "
"static sequence!");
}
// Forwards a functor and arguments before posting as a task, to avoid
// unnecessary mallocs. Prefer this to PostTask() when possible to reduce
// runtime overhead.
template <typename F, typename... Args>
static void Post(const base::Location& from_here, F&& f, Args&&... args) {
TaskRunner()->PostTask(
from_here, BindHelper<needs_key<F>, F, Args...>::Bind(
std::forward<F>(f), std::forward<Args>(args)...));
}
private:
// Used to help print readable compiler messages in static_assert failures.
template <typename... Args>
constexpr static bool invalid = false;
template <typename... Ts>
struct Pack;
template <typename Pack>
struct LastArgumentIsKey;
template <typename First, typename... Rest>
struct LastArgumentIsKey<Pack<First, Rest...>>
: LastArgumentIsKey<Pack<Rest...>> {};
template <>
struct LastArgumentIsKey<Pack<Key&>> : std::true_type {};
template <>
struct LastArgumentIsKey<Pack<>> : std::false_type {};
template <typename F>
struct GetArgs;
template <typename R, typename... Args>
struct GetArgs<R (*)(Args...)> {
using type = Pack<Args...>;
};
template <typename R, typename Obj, typename... Args>
struct GetArgs<R (Obj::*)(Args...)> {
using type = Pack<Args...>;
};
template <typename F>
constexpr static bool needs_key =
LastArgumentIsKey<typename GetArgs<F>::type>::value;
template <bool requires_key, typename... Args>
struct BindHelper;
template <typename... Args>
struct BindHelper<false, Args...> {
static base::OnceClosure Bind(Args... args) {
return base::BindOnce(std::forward<Args>(args)...);
}
};
template <typename... Args>
struct BindHelper<true, Args...> {
static base::OnceClosure Bind(Args... args) {
return base::BindOnce(std::forward<Args>(args)..., std::ref(key_));
}
};
static Key key_;
};
template <typename T, typename TraitsProvider>
typename StaticSequence<T, TraitsProvider>::Key
StaticSequence<T, TraitsProvider>::key_ = {};
// Behaves like the SequenceBound class wrapper for static sequences, wrapping
// an object and forcing all method calls to go through Post(), which ensures
// they are all called on the statically assigned sequence, whether the methods
// ask for a Key or not.
template <typename T, typename Sequence>
class Sequenced {
public:
template <typename... Args>
explicit Sequenced(Args&&... args) : obj_(Uninitialized()) {
Sequence::Post(FROM_HERE, &Sequenced::Construct<Args...>,
base::Unretained(this), std::forward<Args>(args)...);
}
template <typename... Args, typename... Bound>
void Post(const base::Location& from_here,
void (T::*method)(Args...),
Bound&&... args) {
Sequence::Post(from_here, &Sequenced::Call<decltype(method), Bound...>,
base::Unretained(this), method,
std::forward<Bound>(args)...);
}
private:
using UniquePtr = std::unique_ptr<T, base::OnTaskRunnerDeleter>;
template <typename... Args>
void Construct(Args&&... args, typename Sequence::Key& key) {
obj_ = MakeUnique<Args...>(std::forward<Args>(args)..., key);
}
static UniquePtr Uninitialized() {
return UniquePtr(nullptr,
base::OnTaskRunnerDeleter(Sequence::TaskRunner()));
}
template <typename... Args>
UniquePtr MakeUnique(Args&&... args, typename Sequence::Key&) {
return UniquePtr(new T(std::forward<Args>(args)...),
base::OnTaskRunnerDeleter(Sequence::TaskRunner()));
}
template <typename Method, typename... Bound>
void Call(Method method, Bound&&... args, typename Sequence::Key& key) {
Sequence::Run(base::BindOnce(method, base::Unretained(obj_.get()),
std::forward<Bound>(args)...),
key);
}
UniquePtr obj_;
};
} // namespace util
#endif // CHROMECAST_BASE_STATIC_SEQUENCE_STATIC_SEQUENCE_H_
// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chromecast/base/static_sequence/static_sequence.h"
#include "base/test/task_environment.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace util {
namespace {
struct TestSequence : StaticSequence<TestSequence> {};
struct CustomTraitsProvider {
static constexpr base::TaskTraits GetTraits() {
return {base::ThreadPool(), base::TaskPriority::LOWEST,
base::ThreadPolicy::PREFER_BACKGROUND, base::MayBlock()};
}
};
struct TestSequenceWithCustomTraits
: StaticSequence<TestSequenceWithCustomTraits, CustomTraitsProvider> {};
void DoSomething(bool* activated) {
*activated = true;
}
void DoSomethingWithRequiredSequence(bool* activated, TestSequence::Key&) {
*activated = true;
}
class TestObject {
public:
void DoSomething(bool* activated) { *activated = true; }
void DoSomethingWithRequiredSequence(bool* activated, TestSequence::Key&) {
*activated = true;
}
};
class ParameterizedObject {
public:
explicit ParameterizedObject(int increment_by)
: increment_by_(increment_by) {}
void Increment(int* out, TestSequence::Key&) { *out += increment_by_; }
private:
int increment_by_;
};
class HasSideEffectsInConstructor {
public:
HasSideEffectsInConstructor(int x, int y, int* r) { *r = x + y; }
};
class HasSideEffectsInDestructor {
public:
HasSideEffectsInDestructor(int x, int y, int* r) : r_(r), sum_(x + y) {}
~HasSideEffectsInDestructor() { *r_ = sum_; }
private:
int* r_;
int sum_;
};
} // namespace
TEST(StaticSequenceTest, StaticProperties) {
static_assert(!std::is_copy_constructible<TestSequence::Key>::value,
"Keys must not be copyable.");
static_assert(!std::is_move_constructible<TestSequence::Key>::value,
"Keys must not be movable.");
}
TEST(StaticSequenceTest, InvokeUnprotectedCallback) {
base::test::TaskEnvironment env;
bool activated = false;
TestSequence::PostTask(base::BindOnce(&DoSomething, &activated));
EXPECT_FALSE(activated);
env.RunUntilIdle();
EXPECT_TRUE(activated);
}
TEST(StaticSequenceTest, InvokeProtectedCallback) {
base::test::TaskEnvironment env;
bool activated = false;
TestSequence::PostTask(
base::BindOnce(&DoSomethingWithRequiredSequence, &activated));
EXPECT_FALSE(activated);
env.RunUntilIdle();
EXPECT_TRUE(activated);
}
TEST(StaticSequenceTest, InvokeObjectUnprotectedMethod) {
base::test::TaskEnvironment env;
bool activated = false;
TestObject obj;
TestSequence::PostTask(base::BindOnce(&TestObject::DoSomething,
base::Unretained(&obj), &activated));
EXPECT_FALSE(activated);
env.RunUntilIdle();
EXPECT_TRUE(activated);
}
TEST(StaticSequenceTest, InvokeSequencedObjectUnprotectedMethod) {
base::test::TaskEnvironment env;
bool activated = false;
Sequenced<TestObject, TestSequence> obj;
obj.Post(FROM_HERE, &TestObject::DoSomething, &activated);
EXPECT_FALSE(activated);
env.RunUntilIdle();
EXPECT_TRUE(activated);
}
TEST(StaticSequenceTest, InvokeSequencedObjectProtectedMethod) {
base::test::TaskEnvironment env;
bool activated = false;
Sequenced<TestObject, TestSequence> obj;
obj.Post(FROM_HERE, &TestObject::DoSomethingWithRequiredSequence, &activated);
EXPECT_FALSE(activated);
env.RunUntilIdle();
EXPECT_TRUE(activated);
}
TEST(StaticSequenceTest, SequencedConstructorIncludesArguments) {
base::test::TaskEnvironment env;
int r = 0;
Sequenced<ParameterizedObject, TestSequence> obj(2);
obj.Post(FROM_HERE, &ParameterizedObject::Increment, &r);
EXPECT_EQ(r, 0);
env.RunUntilIdle();
EXPECT_EQ(r, 2);
}
TEST(StaticSequenceTest, UseCustomTraits) {
base::test::TaskEnvironment env;
bool r = false;
Sequenced<TestObject, TestSequenceWithCustomTraits> obj;
obj.Post(FROM_HERE, &TestObject::DoSomething, &r);
EXPECT_FALSE(r);
env.RunUntilIdle();
EXPECT_TRUE(r);
}
TEST(StaticSequenceTest, ConstructsOnSequence) {
base::test::TaskEnvironment env;
int r = 0;
// The constructor for HasSideEffectsInConstructor will set |r| to the sum of
// the first two arguments, but should only run on the sequence.
Sequenced<HasSideEffectsInConstructor, TestSequence> obj(1, 2, &r);
EXPECT_EQ(r, 0);
env.RunUntilIdle();
EXPECT_EQ(r, 3);
}
TEST(StaticSequenceTest, DestructOnSequence) {
base::test::TaskEnvironment env;
int r = 0;
{
// The destructor for HasSideEffectsInDestructor will set |r| to the sum of
// the first two constructor arguments, but should only run on the sequence.
Sequenced<HasSideEffectsInDestructor, TestSequence> obj(2, 3, &r);
env.RunUntilIdle();
EXPECT_EQ(r, 0);
}
EXPECT_EQ(r, 0);
env.RunUntilIdle();
EXPECT_EQ(r, 5);
}
TEST(StaticSequenceTest, PostUnprotectedMemberFunction) {
base::test::TaskEnvironment env;
TestObject x;
bool r = false;
TestSequence::Post(FROM_HERE, &TestObject::DoSomething, base::Unretained(&x),
&r);
EXPECT_FALSE(r);
env.RunUntilIdle();
EXPECT_TRUE(r);
}
TEST(StaticSequenceTest, PostProtectedMemberFunction) {
base::test::TaskEnvironment env;
TestObject x;
bool r = false;
TestSequence::Post(FROM_HERE, &TestObject::DoSomethingWithRequiredSequence,
base::Unretained(&x), &r);
EXPECT_FALSE(r);
env.RunUntilIdle();
EXPECT_TRUE(r);
}
TEST(StaticSequenceTest, PostUnprotectedFreeFunction) {
base::test::TaskEnvironment env;
bool r = false;
TestSequence::Post(FROM_HERE, &DoSomething, &r);
EXPECT_FALSE(r);
env.RunUntilIdle();
EXPECT_TRUE(r);
}
TEST(StaticSequenceTest, PostProtectedFreeFunction) {
base::test::TaskEnvironment env;
bool r = false;
TestSequence::Post(FROM_HERE, &DoSomethingWithRequiredSequence, &r);
EXPECT_FALSE(r);
env.RunUntilIdle();
EXPECT_TRUE(r);
}
} // namespace util
// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
// This is a no-compile test suite.
// http://dev.chromium.org/developers/testing/no-compile-tests
#include "chromecast/base/static_sequence/static_sequence.h"
namespace util {
struct SequenceA : StaticSequence<SequenceA> {};
struct SequenceB : StaticSequence<SequenceB> {};
void Foo(SequenceA::Key&);
class Bar {
public:
void Baz(SequenceA::Key&) {}
};
void StaticSequenceNoCompileTests() {
Sequenced<Bar, SequenceB> bar;
#if defined(NCTEST_POST_FUNCTION_TO_WRONG_SEQUENCE) // [r"fatal error: static_assert failed due to requirement 'invalid<util::SequenceB, util::SequenceA>' \"Attempting to post a statically-sequenced task to the wrong static sequence!\""]
SequenceB::PostTask(base::BindOnce(&Foo));
#elif defined(NCTEST_POST_METHOD_TO_WRONG_SEQUENCE) // [r"fatal error: static_assert failed due to requirement 'invalid<util::SequenceB, util::SequenceA>' \"Attempting to post a statically-sequenced task to the wrong static sequence!\""]
bar.Post(FROM_HERE, &Bar::Baz);
#endif
}
} // namespace util
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment