From bbd41b0702d4a745034efc677baf948d8c7d555f Mon Sep 17 00:00:00 2001 From: Lewis Baker Date: Sat, 1 Apr 2017 22:02:33 +1030 Subject: [PATCH] Add cppcoro::task and cppcoro::single_consumer_event. Includes a few tests for basic functionality of task. --- include/cppcoro/broken_promise.hpp | 20 ++ include/cppcoro/single_consumer_event.hpp | 118 +++++++ include/cppcoro/task.hpp | 393 ++++++++++++++++++++++ test/main.cpp | 332 ++++++++++++++++++ 4 files changed, 863 insertions(+) create mode 100644 include/cppcoro/broken_promise.hpp create mode 100644 include/cppcoro/single_consumer_event.hpp create mode 100644 include/cppcoro/task.hpp create mode 100644 test/main.cpp diff --git a/include/cppcoro/broken_promise.hpp b/include/cppcoro/broken_promise.hpp new file mode 100644 index 0000000..5acde6f --- /dev/null +++ b/include/cppcoro/broken_promise.hpp @@ -0,0 +1,20 @@ +#ifndef CPPCORO_BROKEN_PROMISE_HPP_INCLUDED +#define CPPCORO_BROKEN_PROMISE_HPP_INCLUDED + +#include + +namespace cppcoro +{ + /// \brief + /// Exception thrown when you attempt to retrieve the result of + /// a task that has been detached from its promise/coroutine. + class broken_promise : public std::logic_error + { + public: + broken_promise() + : std::logic_error("broken promise") + {} + }; +} + +#endif diff --git a/include/cppcoro/single_consumer_event.hpp b/include/cppcoro/single_consumer_event.hpp new file mode 100644 index 0000000..387036a --- /dev/null +++ b/include/cppcoro/single_consumer_event.hpp @@ -0,0 +1,118 @@ +#ifndef CPPCORO_SINGLE_CONSUMER_EVENT_HPP_INCLUDED +#define CPPCORO_SINGLE_CONSUMER_EVENT_HPP_INCLUDED + +#include +#include + +namespace cppcoro +{ + /// \brief + /// A manual-reset event that supports only a single awaiting + /// coroutine at a time. + /// + /// You can co_await the event to suspend the current coroutine until + /// some thread calls set(). If the event is already set then the + /// coroutine will not be suspended and will continue execution. + /// If the event was not yet set then the coroutine will be resumed + /// on the thread that calls set() within the call to set(). + /// + /// Callers must ensure that only one coroutine is executing a + /// co_await statement at any point in time. + class single_consumer_event + { + public: + + single_consumer_event(bool initiallySet = false) noexcept + : m_state(initiallySet ? state::set : state::not_set) + {} + + /// Query if this event has been set. + bool is_set() const noexcept + { + return m_state.load(std::memory_order_acquire) == state::set; + } + + /// \brief + /// Transition this event to the 'set' state if it is not already set. + /// + /// If there was a coroutine awaiting the event then it will be resumed + /// inside this call. + void set() + { + const state oldState = m_state.exchange(state::set, std::memory_order_acq_rel); + if (oldState == state::not_set_consumer_waiting) + { + m_awaiter.resume(); + } + } + + /// \brief + /// Transition this event to the 'non set' state if it was in the set state. + void reset() noexcept + { + state oldState = state::set; + m_state.compare_exchange_strong(oldState, state::not_set, std::memory_order_relaxed); + } + + /// \brief + /// Wait until the event becomes set. + /// + /// If the event is already set then the awaiting coroutine will not be suspended + /// and will continue execution. If the event was not yet set then the coroutine + /// will be suspended and will be later resumed inside a subsequent call to set() + /// on the thread that calls set(). + auto operator co_await() noexcept + { + class awaiter + { + public: + + awaiter(single_consumer_event& event) : m_event(event) {} + + bool await_ready() const noexcept + { + return m_event.is_set(); + } + + bool await_suspend(std::experimental::coroutine_handle<> awaiter) + { + m_event.m_awaiter = awaiter; + + state oldState = state::not_set; + return m_event.m_state.compare_exchange_strong( + oldState, + state::not_set_consumer_waiting, + std::memory_order_release, + std::memory_order_acquire); + } + + void await_resume() noexcept {} + + private: + + single_consumer_event& m_event; + + }; + + return awaiter{ *this }; + } + + private: + + enum class state + { + not_set, + not_set_consumer_waiting, + set + }; + + // TODO: Merge these two fields into a single std::atomic + // by encoding 'not_set' as 0 (nullptr), 'set' as 1 and + // 'not_set_consumer_waiting' as a coroutine handle pointer. + std::atomic m_state; + std::experimental::coroutine_handle<> m_awaiter; + + }; +} + +#endif diff --git a/include/cppcoro/task.hpp b/include/cppcoro/task.hpp new file mode 100644 index 0000000..675571f --- /dev/null +++ b/include/cppcoro/task.hpp @@ -0,0 +1,393 @@ +#ifndef CPPCORO_TASK_HPP_INCLUDED +#define CPPCORO_TASK_HPP_INCLUDED + +#include + +#include +#include +#include +#include + +#include + +namespace cppcoro +{ + template + class task; + + namespace detail + { + class task_promise_base + { + public: + + task_promise_base() noexcept + : m_state(state::running) + {} + + auto initial_suspend() noexcept + { + return std::experimental::suspend_never{}; + } + + auto final_suspend() noexcept + { + struct awaitable + { + task_promise_base& m_promise; + + awaitable(task_promise_base& promise) noexcept + : m_promise(promise) + {} + + bool await_ready() const noexcept + { + return m_promise.m_state.load(std::memory_order_acquire) == state::consumer_detached; + } + + // If resuming awaiter can potentially throw what state would that leave this coroutine in? + bool await_suspend(std::experimental::coroutine_handle<>) noexcept + { + state oldState = m_promise.m_state.exchange(state::finished, std::memory_order_acq_rel); + if (oldState == state::consumer_suspended) + { + m_promise.m_awaiter.resume(); + } + + return oldState != state::consumer_detached; + } + + void await_resume() noexcept + {} + }; + + return awaitable{ *this }; + } + + void unhandled_exception() noexcept + { + // No point capturing exception if consumer already detached. + if (m_state.load(std::memory_order_relaxed) != state::consumer_detached) + { + m_exception = std::current_exception(); + } + } + + void set_exception(std::exception_ptr exception) + { + m_exception = std::move(exception); + } + + bool is_ready() const noexcept + { + return m_state.load(std::memory_order_acquire) == state::finished; + } + + bool try_detach() noexcept + { + return m_state.exchange( + state::consumer_detached, + std::memory_order_acq_rel) == state::running; + } + + bool try_await(std::experimental::coroutine_handle<> awaiter) + { + m_awaiter = awaiter; + + state oldState = state::running; + return m_state.compare_exchange_strong( + oldState, + state::consumer_suspended, + std::memory_order_release, + std::memory_order_acquire); + } + + protected: + + bool completed_with_unhandled_exception() + { + return m_exception != nullptr; + } + + void rethrow_if_unhandled_exception() + { + if (m_exception != nullptr) + { + std::rethrow_exception(m_exception); + } + } + + private: + + enum class state + { + running, + consumer_suspended, + consumer_detached, + finished + }; + + std::atomic m_state; + std::experimental::coroutine_handle<> m_awaiter; + std::exception_ptr m_exception; + + }; + + template + class task_promise : public task_promise_base + { + public: + + task_promise() noexcept = default; + + ~task_promise() + { + if (!completed_with_unhandled_exception()) + { + reinterpret_cast(&m_valueStorage)->~T(); + } + } + + auto get_return_object() noexcept + { + return std::experimental::coroutine_handle::from_promise(*this); + } + + template< + typename VALUE, + typename = std::enable_if_t>> + void return_value(VALUE&& value) + noexcept(std::is_nothrow_constructible_v) + { + new (&m_valueStorage) T(std::forward(value)); + } + + T& result() & + { + rethrow_if_unhandled_exception(); + return *reinterpret_cast(&m_valueStorage); + } + + T&& result() && + { + rethrow_if_unhandled_exception(); + return std::move(*reinterpret_cast(&m_valueStorage)); + } + + private: + + // Not using std::aligned_storage here due to bug in MSVC 2015 Update 2 + // that means it doesn't work for types with alignof(T) > 8. + // See MS-Connect bug #2658635. + alignas(T) char m_valueStorage[sizeof(T)]; + + }; + + template<> + class task_promise : public task_promise_base + { + public: + + task_promise() noexcept = default; + + auto get_return_object() noexcept + { + return std::experimental::coroutine_handle::from_promise(*this); + } + + void return_void() noexcept + {} + + void result() + { + rethrow_if_unhandled_exception(); + } + + }; + + template + class task_promise : public task_promise_base + { + public: + + task_promise() noexcept = default; + + auto get_return_object() noexcept + { + return std::experimental::coroutine_handle::from_promise(*this); + } + + void return_value(T& value) noexcept + { + m_value = std::addressof(value); + } + + T& result() + { + rethrow_if_unhandled_exception(); + return *m_value; + } + + private: + + T* m_value; + + }; + } + + template + class task + { + public: + + using promise_type = detail::task_promise; + + private: + + struct awaitable_base + { + std::experimental::coroutine_handle m_coroutine; + + awaitable_base(std::experimental::coroutine_handle coroutine) noexcept + : m_coroutine(coroutine) + {} + + bool await_ready() const noexcept + { + return !m_coroutine || m_coroutine.promise().is_ready(); + } + + bool await_suspend(std::experimental::coroutine_handle<> awaiter) noexcept + { + return m_coroutine.promise().try_await(awaiter); + } + }; + + public: + + task() noexcept + : m_coroutine(nullptr) + {} + + explicit task(std::experimental::coroutine_handle coroutine) + : m_coroutine(coroutine) + {} + + task(task&& t) noexcept + : m_coroutine(t.m_coroutine) + { + t.m_coroutine = nullptr; + } + + /// Disable copy construction/assignment. + task(const task&) = delete; + task& operator=(const task&) = delete; + + /// Frees resources used by this task. + /// + /// Calls std::terminate() if the task is not complete and + /// has not been detached (by calling detach() or moving into + /// another task). + ~task() + { + if (m_coroutine) + { + if (!m_coroutine.promise().is_ready()) + { + std::terminate(); + } + + m_coroutine.destroy(); + } + } + + /// \brief + /// Query if the task result is complete. + /// + /// Awaiting a task that is ready will not block. + bool is_ready() const noexcept + { + return !m_coroutine || m_coroutine.promise().is_ready(); + } + + /// \brief + /// Detach this task value from the coroutine. + /// + /// You will not be able to await the result of the task after this. + void detach() + { + if (m_coroutine) + { + auto coro = m_coroutine; + m_coroutine = nullptr; + + if (!coro.promise().try_detach()) + { + coro.destroy(); + } + } + } + + auto operator co_await() const & noexcept + { + struct awaitable : awaitable_base + { + using awaitable_base::awaitable_base; + + decltype(auto) await_resume() + { + if (!m_coroutine) + { + throw broken_promise{}; + } + + return m_coroutine.promise().result(); + } + }; + + return awaitable{ m_coroutine }; + } + + auto operator co_await() const && noexcept + { + struct awaitable : awaitable_base + { + using awaitable_base::awaitable_base; + + decltype(auto) await_resume() + { + if (!m_coroutine) + { + throw broken_promise{}; + } + + return std::move(m_coroutine.promise()).result(); + } + }; + + return awaitable{ m_coroutine }; + } + + /// \brief + /// Returns an awaitable that will await completion of the task without + /// attempting to retrieve the result. + auto when_ready() const noexcept + { + struct awaitable : awaitable_base + { + using awaitable_base::awaitable_base; + + void await_resume() const noexcept {} + }; + + return awaitable{ m_coroutine }; + } + + private: + + std::experimental::coroutine_handle m_coroutine; + + }; +} + +#endif diff --git a/test/main.cpp b/test/main.cpp new file mode 100644 index 0000000..d3139b8 --- /dev/null +++ b/test/main.cpp @@ -0,0 +1,332 @@ +#include +#include + +#include + +#include + +struct counter +{ + static int default_construction_count; + static int copy_construction_count; + static int move_construction_count; + static int destruction_count; + + int id; + + static void reset_counts() + { + default_construction_count = 0; + copy_construction_count = 0; + move_construction_count = 0; + destruction_count = 0; + } + + static int construction_count() + { + return default_construction_count + copy_construction_count + move_construction_count; + } + + static int active_count() + { + return construction_count() - destruction_count; + } + + counter() : id(default_construction_count++) {} + counter(const counter& other) : id(other.id) { ++copy_construction_count; } + counter(counter&& other) : id(other.id) { ++move_construction_count; other.id = -1; } + ~counter() { ++destruction_count; } + +}; + +int counter::default_construction_count; +int counter::copy_construction_count; +int counter::move_construction_count; +int counter::destruction_count; + +void testAwaitSynchronouslyCompletingVoidFunction() +{ + auto doNothingAsync = []() -> cppcoro::task<> + { + co_return; + }; + + auto task = doNothingAsync(); + + assert(task.is_ready()); + + bool ok = false; + auto test = [&]() -> cppcoro::task<> + { + co_await task; + ok = true; + }; + + test(); + + assert(ok); +} + +void testAwaitTaskReturningMoveOnlyType() +{ + auto getIntPtrAsync = []() -> cppcoro::task> + { + co_return std::make_unique(123); + }; + + auto test = [&]() -> cppcoro::task<> + { + auto intPtr = co_await getIntPtrAsync(); + assert(*intPtr == 123); + + auto intPtrTask = getIntPtrAsync(); + { + // co_await yields l-value reference if task is l-value + auto& intPtr2 = co_await intPtrTask; + assert(*intPtr2 == 123); + } + + { + // Returns r-value reference if task is r-value + auto intPtr3 = co_await std::move(intPtrTask); + assert(*intPtr3 == 123); + } + }; + + auto task = test(); + + assert(task.is_ready()); +} + +void testAwaitTaskReturningReference() +{ + int value = 0; + auto getRefAsync = [&]() -> cppcoro::task + { + co_return value; + }; + + auto test = [&]() -> cppcoro::task<> + { + // Await r-value task results in l-value reference + decltype(auto) result = co_await getRefAsync(); + assert(&result == &value); + + // Await l-value task results in l-value reference + auto getRefTask = getRefAsync(); + decltype(auto) result2 = co_await getRefTask; + assert(&result2 == &value); + }; + + auto task = test(); + assert(task.is_ready()); +} + +void testAwaitTaskReturningValueMovesIntoPromiseIfPassedRValue() +{ + counter::reset_counts(); + + auto f = []() -> cppcoro::task + { + co_return counter{}; + }; + + assert(counter::active_count() == 0); + + { + auto t = f(); + assert(counter::default_construction_count == 1); + assert(counter::copy_construction_count == 0); + assert(counter::move_construction_count == 1); + assert(counter::destruction_count == 1); + assert(counter::active_count() == 1); + + // Moving task doesn't move/copy result. + auto t2 = std::move(t); + assert(counter::default_construction_count == 1); + assert(counter::copy_construction_count == 0); + assert(counter::move_construction_count == 1); + assert(counter::destruction_count == 1); + assert(counter::active_count() == 1); + } + + assert(counter::active_count() == 0); +} + +void testAwaitTaskReturningValueCopiesIntoPromiseIfPassedLValue() +{ + counter::reset_counts(); + + auto f = []() -> cppcoro::task + { + counter temp; + + // Should be calling copy-constructor here since .return_value() + // is being passed an l-value reference. + co_return temp; + }; + + assert(counter::active_count() == 0); + + { + auto t = f(); + assert(counter::default_construction_count == 1); + assert(counter::copy_construction_count == 1); + assert(counter::move_construction_count == 0); + assert(counter::destruction_count == 1); + assert(counter::active_count() == 1); + + // Moving the task doesn't move/copy the result + auto t2 = std::move(t); + assert(counter::default_construction_count == 1); + assert(counter::copy_construction_count == 1); + assert(counter::move_construction_count == 0); + assert(counter::destruction_count == 1); + assert(counter::active_count() == 1); + } + + assert(counter::active_count() == 0); +} + +void testAwaitDelayedCompletionChain() +{ + cppcoro::single_consumer_event event; + bool reachedPointA = false; + bool reachedPointB = false; + auto async1 = [&]() -> cppcoro::task + { + reachedPointA = true; + co_await event; + reachedPointB = true; + co_return 1; + }; + + bool reachedPointC = false; + bool reachedPointD = false; + auto async2 = [&]() -> cppcoro::task + { + reachedPointC = true; + int result = co_await async1(); + reachedPointD = true; + co_return result; + }; + + auto task = async2(); + + assert(!task.is_ready()); + assert(reachedPointA); + assert(!reachedPointB); + assert(reachedPointC); + assert(!reachedPointD); + + event.set(); + + assert(task.is_ready()); + assert(reachedPointB); + assert(reachedPointD); + + [](cppcoro::task t) -> cppcoro::task<> + { + int value = co_await t; + assert(value == 1); + }(std::move(task)); +} + +void testAwaitingBrokenPromiseThrows() +{ + bool ok = false; + auto test = [&]() -> cppcoro::task<> + { + cppcoro::task<> broken; + try + { + co_await broken; + } + catch (cppcoro::broken_promise) + { + ok = true; + } + }; + + auto t = test(); + assert(t.is_ready()); + assert(ok); +} + +void testAwaitRethrowsException() +{ + class X {}; + + auto run = [](bool doThrow) -> cppcoro::task<> + { + if (doThrow) throw X{}; + co_return; + }; + + auto t = run(true); + + bool ok = false; + auto consumeT = [&]() -> cppcoro::task<> + { + try + { + co_await t; + } + catch (X) + { + ok = true; + } + }; + + auto consumer = consumeT(); + + assert(t.is_ready()); + assert(consumer.is_ready()); + assert(ok); +} + +void testAwaitWhenReadyDoesntThrowException() +{ + class X {}; + + auto run = [](bool doThrow) -> cppcoro::task<> + { + if (doThrow) throw X{}; + co_return; + }; + + auto t = run(true); + + bool ok = false; + auto consumeT = [&]() -> cppcoro::task<> + { + try + { + co_await t.when_ready(); + ok = true; + } + catch (...) + { + } + }; + + auto consumer = consumeT(); + + assert(t.is_ready()); + assert(consumer.is_ready()); + assert(ok); +} + +int main(int argc, char** argv) +{ + testAwaitSynchronouslyCompletingVoidFunction(); + testAwaitTaskReturningMoveOnlyType(); + testAwaitTaskReturningReference(); + testAwaitDelayedCompletionChain(); + testAwaitTaskReturningValueMovesIntoPromiseIfPassedRValue(); + testAwaitTaskReturningValueCopiesIntoPromiseIfPassedLValue(); + testAwaitingBrokenPromiseThrows(); + testAwaitRethrowsException(); + testAwaitWhenReadyDoesntThrowException(); + return 0; +}