Add cppcoro::task<T> and cppcoro::single_consumer_event.

Includes a few tests for basic functionality of task<T>.
This commit is contained in:
Lewis Baker 2017-04-01 22:02:33 +10:30
parent e6560707c8
commit bbd41b0702
4 changed files with 863 additions and 0 deletions

View file

@ -0,0 +1,20 @@
#ifndef CPPCORO_BROKEN_PROMISE_HPP_INCLUDED
#define CPPCORO_BROKEN_PROMISE_HPP_INCLUDED
#include <stdexcept>
namespace cppcoro
{
/// \brief
/// Exception thrown when you attempt to retrieve the result of
/// a task that has been detached from its promise/coroutine.
class broken_promise : public std::logic_error
{
public:
broken_promise()
: std::logic_error("broken promise")
{}
};
}
#endif

View file

@ -0,0 +1,118 @@
#ifndef CPPCORO_SINGLE_CONSUMER_EVENT_HPP_INCLUDED
#define CPPCORO_SINGLE_CONSUMER_EVENT_HPP_INCLUDED
#include <atomic>
#include <experimental/coroutine>
namespace cppcoro
{
/// \brief
/// A manual-reset event that supports only a single awaiting
/// coroutine at a time.
///
/// You can co_await the event to suspend the current coroutine until
/// some thread calls set(). If the event is already set then the
/// coroutine will not be suspended and will continue execution.
/// If the event was not yet set then the coroutine will be resumed
/// on the thread that calls set() within the call to set().
///
/// Callers must ensure that only one coroutine is executing a
/// co_await statement at any point in time.
class single_consumer_event
{
public:
single_consumer_event(bool initiallySet = false) noexcept
: m_state(initiallySet ? state::set : state::not_set)
{}
/// Query if this event has been set.
bool is_set() const noexcept
{
return m_state.load(std::memory_order_acquire) == state::set;
}
/// \brief
/// Transition this event to the 'set' state if it is not already set.
///
/// If there was a coroutine awaiting the event then it will be resumed
/// inside this call.
void set()
{
const state oldState = m_state.exchange(state::set, std::memory_order_acq_rel);
if (oldState == state::not_set_consumer_waiting)
{
m_awaiter.resume();
}
}
/// \brief
/// Transition this event to the 'non set' state if it was in the set state.
void reset() noexcept
{
state oldState = state::set;
m_state.compare_exchange_strong(oldState, state::not_set, std::memory_order_relaxed);
}
/// \brief
/// Wait until the event becomes set.
///
/// If the event is already set then the awaiting coroutine will not be suspended
/// and will continue execution. If the event was not yet set then the coroutine
/// will be suspended and will be later resumed inside a subsequent call to set()
/// on the thread that calls set().
auto operator co_await() noexcept
{
class awaiter
{
public:
awaiter(single_consumer_event& event) : m_event(event) {}
bool await_ready() const noexcept
{
return m_event.is_set();
}
bool await_suspend(std::experimental::coroutine_handle<> awaiter)
{
m_event.m_awaiter = awaiter;
state oldState = state::not_set;
return m_event.m_state.compare_exchange_strong(
oldState,
state::not_set_consumer_waiting,
std::memory_order_release,
std::memory_order_acquire);
}
void await_resume() noexcept {}
private:
single_consumer_event& m_event;
};
return awaiter{ *this };
}
private:
enum class state
{
not_set,
not_set_consumer_waiting,
set
};
// TODO: Merge these two fields into a single std::atomic<std::uintptr_t>
// by encoding 'not_set' as 0 (nullptr), 'set' as 1 and
// 'not_set_consumer_waiting' as a coroutine handle pointer.
std::atomic<state> m_state;
std::experimental::coroutine_handle<> m_awaiter;
};
}
#endif

393
include/cppcoro/task.hpp Normal file
View file

@ -0,0 +1,393 @@
#ifndef CPPCORO_TASK_HPP_INCLUDED
#define CPPCORO_TASK_HPP_INCLUDED
#include <cppcoro/broken_promise.hpp>
#include <atomic>
#include <exception>
#include <utility>
#include <type_traits>
#include <experimental/coroutine>
namespace cppcoro
{
template<typename T>
class task;
namespace detail
{
class task_promise_base
{
public:
task_promise_base() noexcept
: m_state(state::running)
{}
auto initial_suspend() noexcept
{
return std::experimental::suspend_never{};
}
auto final_suspend() noexcept
{
struct awaitable
{
task_promise_base& m_promise;
awaitable(task_promise_base& promise) noexcept
: m_promise(promise)
{}
bool await_ready() const noexcept
{
return m_promise.m_state.load(std::memory_order_acquire) == state::consumer_detached;
}
// If resuming awaiter can potentially throw what state would that leave this coroutine in?
bool await_suspend(std::experimental::coroutine_handle<>) noexcept
{
state oldState = m_promise.m_state.exchange(state::finished, std::memory_order_acq_rel);
if (oldState == state::consumer_suspended)
{
m_promise.m_awaiter.resume();
}
return oldState != state::consumer_detached;
}
void await_resume() noexcept
{}
};
return awaitable{ *this };
}
void unhandled_exception() noexcept
{
// No point capturing exception if consumer already detached.
if (m_state.load(std::memory_order_relaxed) != state::consumer_detached)
{
m_exception = std::current_exception();
}
}
void set_exception(std::exception_ptr exception)
{
m_exception = std::move(exception);
}
bool is_ready() const noexcept
{
return m_state.load(std::memory_order_acquire) == state::finished;
}
bool try_detach() noexcept
{
return m_state.exchange(
state::consumer_detached,
std::memory_order_acq_rel) == state::running;
}
bool try_await(std::experimental::coroutine_handle<> awaiter)
{
m_awaiter = awaiter;
state oldState = state::running;
return m_state.compare_exchange_strong(
oldState,
state::consumer_suspended,
std::memory_order_release,
std::memory_order_acquire);
}
protected:
bool completed_with_unhandled_exception()
{
return m_exception != nullptr;
}
void rethrow_if_unhandled_exception()
{
if (m_exception != nullptr)
{
std::rethrow_exception(m_exception);
}
}
private:
enum class state
{
running,
consumer_suspended,
consumer_detached,
finished
};
std::atomic<state> m_state;
std::experimental::coroutine_handle<> m_awaiter;
std::exception_ptr m_exception;
};
template<typename T>
class task_promise : public task_promise_base
{
public:
task_promise() noexcept = default;
~task_promise()
{
if (!completed_with_unhandled_exception())
{
reinterpret_cast<T*>(&m_valueStorage)->~T();
}
}
auto get_return_object() noexcept
{
return std::experimental::coroutine_handle<task_promise>::from_promise(*this);
}
template<
typename VALUE,
typename = std::enable_if_t<std::is_convertible_v<VALUE&&, T>>>
void return_value(VALUE&& value)
noexcept(std::is_nothrow_constructible_v<T, VALUE&&>)
{
new (&m_valueStorage) T(std::forward<VALUE>(value));
}
T& result() &
{
rethrow_if_unhandled_exception();
return *reinterpret_cast<T*>(&m_valueStorage);
}
T&& result() &&
{
rethrow_if_unhandled_exception();
return std::move(*reinterpret_cast<T*>(&m_valueStorage));
}
private:
// Not using std::aligned_storage here due to bug in MSVC 2015 Update 2
// that means it doesn't work for types with alignof(T) > 8.
// See MS-Connect bug #2658635.
alignas(T) char m_valueStorage[sizeof(T)];
};
template<>
class task_promise<void> : public task_promise_base
{
public:
task_promise() noexcept = default;
auto get_return_object() noexcept
{
return std::experimental::coroutine_handle<task_promise>::from_promise(*this);
}
void return_void() noexcept
{}
void result()
{
rethrow_if_unhandled_exception();
}
};
template<typename T>
class task_promise<T&> : public task_promise_base
{
public:
task_promise() noexcept = default;
auto get_return_object() noexcept
{
return std::experimental::coroutine_handle<task_promise>::from_promise(*this);
}
void return_value(T& value) noexcept
{
m_value = std::addressof(value);
}
T& result()
{
rethrow_if_unhandled_exception();
return *m_value;
}
private:
T* m_value;
};
}
template<typename T = void>
class task
{
public:
using promise_type = detail::task_promise<T>;
private:
struct awaitable_base
{
std::experimental::coroutine_handle<promise_type> m_coroutine;
awaitable_base(std::experimental::coroutine_handle<promise_type> coroutine) noexcept
: m_coroutine(coroutine)
{}
bool await_ready() const noexcept
{
return !m_coroutine || m_coroutine.promise().is_ready();
}
bool await_suspend(std::experimental::coroutine_handle<> awaiter) noexcept
{
return m_coroutine.promise().try_await(awaiter);
}
};
public:
task() noexcept
: m_coroutine(nullptr)
{}
explicit task(std::experimental::coroutine_handle<promise_type> coroutine)
: m_coroutine(coroutine)
{}
task(task&& t) noexcept
: m_coroutine(t.m_coroutine)
{
t.m_coroutine = nullptr;
}
/// Disable copy construction/assignment.
task(const task&) = delete;
task& operator=(const task&) = delete;
/// Frees resources used by this task.
///
/// Calls std::terminate() if the task is not complete and
/// has not been detached (by calling detach() or moving into
/// another task).
~task()
{
if (m_coroutine)
{
if (!m_coroutine.promise().is_ready())
{
std::terminate();
}
m_coroutine.destroy();
}
}
/// \brief
/// Query if the task result is complete.
///
/// Awaiting a task that is ready will not block.
bool is_ready() const noexcept
{
return !m_coroutine || m_coroutine.promise().is_ready();
}
/// \brief
/// Detach this task value from the coroutine.
///
/// You will not be able to await the result of the task after this.
void detach()
{
if (m_coroutine)
{
auto coro = m_coroutine;
m_coroutine = nullptr;
if (!coro.promise().try_detach())
{
coro.destroy();
}
}
}
auto operator co_await() const & noexcept
{
struct awaitable : awaitable_base
{
using awaitable_base::awaitable_base;
decltype(auto) await_resume()
{
if (!m_coroutine)
{
throw broken_promise{};
}
return m_coroutine.promise().result();
}
};
return awaitable{ m_coroutine };
}
auto operator co_await() const && noexcept
{
struct awaitable : awaitable_base
{
using awaitable_base::awaitable_base;
decltype(auto) await_resume()
{
if (!m_coroutine)
{
throw broken_promise{};
}
return std::move(m_coroutine.promise()).result();
}
};
return awaitable{ m_coroutine };
}
/// \brief
/// Returns an awaitable that will await completion of the task without
/// attempting to retrieve the result.
auto when_ready() const noexcept
{
struct awaitable : awaitable_base
{
using awaitable_base::awaitable_base;
void await_resume() const noexcept {}
};
return awaitable{ m_coroutine };
}
private:
std::experimental::coroutine_handle<promise_type> m_coroutine;
};
}
#endif

332
test/main.cpp Normal file
View file

@ -0,0 +1,332 @@
#include <cppcoro/task.hpp>
#include <cppcoro/single_consumer_event.hpp>
#include <memory>
#include <cassert>
struct counter
{
static int default_construction_count;
static int copy_construction_count;
static int move_construction_count;
static int destruction_count;
int id;
static void reset_counts()
{
default_construction_count = 0;
copy_construction_count = 0;
move_construction_count = 0;
destruction_count = 0;
}
static int construction_count()
{
return default_construction_count + copy_construction_count + move_construction_count;
}
static int active_count()
{
return construction_count() - destruction_count;
}
counter() : id(default_construction_count++) {}
counter(const counter& other) : id(other.id) { ++copy_construction_count; }
counter(counter&& other) : id(other.id) { ++move_construction_count; other.id = -1; }
~counter() { ++destruction_count; }
};
int counter::default_construction_count;
int counter::copy_construction_count;
int counter::move_construction_count;
int counter::destruction_count;
void testAwaitSynchronouslyCompletingVoidFunction()
{
auto doNothingAsync = []() -> cppcoro::task<>
{
co_return;
};
auto task = doNothingAsync();
assert(task.is_ready());
bool ok = false;
auto test = [&]() -> cppcoro::task<>
{
co_await task;
ok = true;
};
test();
assert(ok);
}
void testAwaitTaskReturningMoveOnlyType()
{
auto getIntPtrAsync = []() -> cppcoro::task<std::unique_ptr<int>>
{
co_return std::make_unique<int>(123);
};
auto test = [&]() -> cppcoro::task<>
{
auto intPtr = co_await getIntPtrAsync();
assert(*intPtr == 123);
auto intPtrTask = getIntPtrAsync();
{
// co_await yields l-value reference if task is l-value
auto& intPtr2 = co_await intPtrTask;
assert(*intPtr2 == 123);
}
{
// Returns r-value reference if task is r-value
auto intPtr3 = co_await std::move(intPtrTask);
assert(*intPtr3 == 123);
}
};
auto task = test();
assert(task.is_ready());
}
void testAwaitTaskReturningReference()
{
int value = 0;
auto getRefAsync = [&]() -> cppcoro::task<int&>
{
co_return value;
};
auto test = [&]() -> cppcoro::task<>
{
// Await r-value task results in l-value reference
decltype(auto) result = co_await getRefAsync();
assert(&result == &value);
// Await l-value task results in l-value reference
auto getRefTask = getRefAsync();
decltype(auto) result2 = co_await getRefTask;
assert(&result2 == &value);
};
auto task = test();
assert(task.is_ready());
}
void testAwaitTaskReturningValueMovesIntoPromiseIfPassedRValue()
{
counter::reset_counts();
auto f = []() -> cppcoro::task<counter>
{
co_return counter{};
};
assert(counter::active_count() == 0);
{
auto t = f();
assert(counter::default_construction_count == 1);
assert(counter::copy_construction_count == 0);
assert(counter::move_construction_count == 1);
assert(counter::destruction_count == 1);
assert(counter::active_count() == 1);
// Moving task doesn't move/copy result.
auto t2 = std::move(t);
assert(counter::default_construction_count == 1);
assert(counter::copy_construction_count == 0);
assert(counter::move_construction_count == 1);
assert(counter::destruction_count == 1);
assert(counter::active_count() == 1);
}
assert(counter::active_count() == 0);
}
void testAwaitTaskReturningValueCopiesIntoPromiseIfPassedLValue()
{
counter::reset_counts();
auto f = []() -> cppcoro::task<counter>
{
counter temp;
// Should be calling copy-constructor here since <promise>.return_value()
// is being passed an l-value reference.
co_return temp;
};
assert(counter::active_count() == 0);
{
auto t = f();
assert(counter::default_construction_count == 1);
assert(counter::copy_construction_count == 1);
assert(counter::move_construction_count == 0);
assert(counter::destruction_count == 1);
assert(counter::active_count() == 1);
// Moving the task doesn't move/copy the result
auto t2 = std::move(t);
assert(counter::default_construction_count == 1);
assert(counter::copy_construction_count == 1);
assert(counter::move_construction_count == 0);
assert(counter::destruction_count == 1);
assert(counter::active_count() == 1);
}
assert(counter::active_count() == 0);
}
void testAwaitDelayedCompletionChain()
{
cppcoro::single_consumer_event event;
bool reachedPointA = false;
bool reachedPointB = false;
auto async1 = [&]() -> cppcoro::task<int>
{
reachedPointA = true;
co_await event;
reachedPointB = true;
co_return 1;
};
bool reachedPointC = false;
bool reachedPointD = false;
auto async2 = [&]() -> cppcoro::task<int>
{
reachedPointC = true;
int result = co_await async1();
reachedPointD = true;
co_return result;
};
auto task = async2();
assert(!task.is_ready());
assert(reachedPointA);
assert(!reachedPointB);
assert(reachedPointC);
assert(!reachedPointD);
event.set();
assert(task.is_ready());
assert(reachedPointB);
assert(reachedPointD);
[](cppcoro::task<int> t) -> cppcoro::task<>
{
int value = co_await t;
assert(value == 1);
}(std::move(task));
}
void testAwaitingBrokenPromiseThrows()
{
bool ok = false;
auto test = [&]() -> cppcoro::task<>
{
cppcoro::task<> broken;
try
{
co_await broken;
}
catch (cppcoro::broken_promise)
{
ok = true;
}
};
auto t = test();
assert(t.is_ready());
assert(ok);
}
void testAwaitRethrowsException()
{
class X {};
auto run = [](bool doThrow) -> cppcoro::task<>
{
if (doThrow) throw X{};
co_return;
};
auto t = run(true);
bool ok = false;
auto consumeT = [&]() -> cppcoro::task<>
{
try
{
co_await t;
}
catch (X)
{
ok = true;
}
};
auto consumer = consumeT();
assert(t.is_ready());
assert(consumer.is_ready());
assert(ok);
}
void testAwaitWhenReadyDoesntThrowException()
{
class X {};
auto run = [](bool doThrow) -> cppcoro::task<>
{
if (doThrow) throw X{};
co_return;
};
auto t = run(true);
bool ok = false;
auto consumeT = [&]() -> cppcoro::task<>
{
try
{
co_await t.when_ready();
ok = true;
}
catch (...)
{
}
};
auto consumer = consumeT();
assert(t.is_ready());
assert(consumer.is_ready());
assert(ok);
}
int main(int argc, char** argv)
{
testAwaitSynchronouslyCompletingVoidFunction();
testAwaitTaskReturningMoveOnlyType();
testAwaitTaskReturningReference();
testAwaitDelayedCompletionChain();
testAwaitTaskReturningValueMovesIntoPromiseIfPassedRValue();
testAwaitTaskReturningValueCopiesIntoPromiseIfPassedLValue();
testAwaitingBrokenPromiseThrows();
testAwaitRethrowsException();
testAwaitWhenReadyDoesntThrowException();
return 0;
}