426 lines
9.9 KiB
C++
426 lines
9.9 KiB
C++
///////////////////////////////////////////////////////////////////////////////
|
|
// Copyright (c) Lewis Baker
|
|
// Licenced under MIT license. See LICENSE.txt for details.
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
#ifndef CPPCORO_TASK_HPP_INCLUDED
|
|
#define CPPCORO_TASK_HPP_INCLUDED
|
|
|
|
#include <cppcoro/config.hpp>
|
|
#include <cppcoro/awaitable_traits.hpp>
|
|
#include <cppcoro/broken_promise.hpp>
|
|
|
|
#include <cppcoro/detail/remove_rvalue_reference.hpp>
|
|
|
|
#include <atomic>
|
|
#include <exception>
|
|
#include <utility>
|
|
#include <type_traits>
|
|
#include <cstdint>
|
|
|
|
#include <experimental/coroutine>
|
|
|
|
namespace cppcoro
|
|
{
|
|
template<typename T> class task;
|
|
|
|
namespace detail
|
|
{
|
|
class task_promise_base
|
|
{
|
|
friend struct final_awaitable;
|
|
|
|
struct final_awaitable
|
|
{
|
|
bool await_ready() const noexcept { return false; }
|
|
|
|
template<typename PROMISE>
|
|
void await_suspend(std::experimental::coroutine_handle<PROMISE> coroutine)
|
|
{
|
|
task_promise_base& promise = coroutine.promise();
|
|
|
|
// Use 'release' memory semantics in case we finish before the
|
|
// awaiter can suspend so that the awaiting thread sees our
|
|
// writes to the resulting value.
|
|
// Use 'acquire' memory semantics in case the caller registered
|
|
// the continuation before we finished. Ensure we see their write
|
|
// to m_continuation.
|
|
if (promise.m_state.exchange(true, std::memory_order_acq_rel))
|
|
{
|
|
promise.m_continuation.resume();
|
|
}
|
|
}
|
|
|
|
void await_resume() noexcept {}
|
|
};
|
|
|
|
public:
|
|
|
|
task_promise_base() noexcept
|
|
: m_state(false)
|
|
{}
|
|
|
|
auto initial_suspend() noexcept
|
|
{
|
|
return std::experimental::suspend_always{};
|
|
}
|
|
|
|
auto final_suspend() noexcept
|
|
{
|
|
return final_awaitable{};
|
|
}
|
|
|
|
void unhandled_exception() noexcept
|
|
{
|
|
m_exception = std::current_exception();
|
|
}
|
|
|
|
bool try_set_continuation(std::experimental::coroutine_handle<> continuation)
|
|
{
|
|
m_continuation = continuation;
|
|
return !m_state.exchange(true, std::memory_order_acq_rel);
|
|
}
|
|
|
|
protected:
|
|
|
|
bool completed() const noexcept
|
|
{
|
|
return m_state.load(std::memory_order_relaxed);
|
|
}
|
|
|
|
bool completed_with_unhandled_exception()
|
|
{
|
|
return m_exception != nullptr;
|
|
}
|
|
|
|
void rethrow_if_unhandled_exception()
|
|
{
|
|
if (m_exception != nullptr)
|
|
{
|
|
std::rethrow_exception(m_exception);
|
|
}
|
|
}
|
|
|
|
private:
|
|
|
|
std::experimental::coroutine_handle<> m_continuation;
|
|
std::exception_ptr m_exception;
|
|
|
|
// Initially false. Set to true when either a continuation is registered
|
|
// or when the coroutine has run to completion. Whichever operation
|
|
// successfully transitions from false->true got there first.
|
|
std::atomic<bool> m_state;
|
|
|
|
};
|
|
|
|
template<typename T>
|
|
class task_promise final : public task_promise_base
|
|
{
|
|
public:
|
|
|
|
task_promise() noexcept = default;
|
|
|
|
~task_promise()
|
|
{
|
|
if (completed() && !completed_with_unhandled_exception())
|
|
{
|
|
reinterpret_cast<T*>(&m_valueStorage)->~T();
|
|
}
|
|
}
|
|
|
|
task<T> get_return_object() noexcept;
|
|
|
|
template<
|
|
typename VALUE,
|
|
typename = std::enable_if_t<std::is_convertible_v<VALUE&&, T>>>
|
|
void return_value(VALUE&& value)
|
|
noexcept(std::is_nothrow_constructible_v<T, VALUE&&>)
|
|
{
|
|
new (&m_valueStorage) T(std::forward<VALUE>(value));
|
|
}
|
|
|
|
T& result() &
|
|
{
|
|
rethrow_if_unhandled_exception();
|
|
return *reinterpret_cast<T*>(&m_valueStorage);
|
|
}
|
|
|
|
// HACK: Need to have co_await of task<int> return prvalue rather than
|
|
// rvalue-reference to work around an issue with MSVC where returning
|
|
// rvalue reference of a fundamental type from await_resume() will
|
|
// cause the value to be copied to a temporary. This breaks the
|
|
// sync_wait() implementation.
|
|
// See https://github.com/lewissbaker/cppcoro/issues/40#issuecomment-326864107
|
|
using rvalue_type = std::conditional_t<
|
|
std::is_arithmetic_v<T> || std::is_pointer_v<T>,
|
|
T,
|
|
T&&>;
|
|
|
|
rvalue_type result() &&
|
|
{
|
|
rethrow_if_unhandled_exception();
|
|
return std::move(*reinterpret_cast<T*>(&m_valueStorage));
|
|
}
|
|
|
|
private:
|
|
|
|
#if CPPCORO_COMPILER_MSVC
|
|
# pragma warning(push)
|
|
# pragma warning(disable : 4324) // structure was padded due to alignment.
|
|
#endif
|
|
|
|
// Not using std::aligned_storage here due to bug in MSVC 2015 Update 2
|
|
// that means it doesn't work for types with alignof(T) > 8.
|
|
// See MS-Connect bug #2658635.
|
|
alignas(T) char m_valueStorage[sizeof(T)];
|
|
|
|
#if CPPCORO_COMPILER_MSVC
|
|
# pragma warning(pop)
|
|
#endif
|
|
|
|
};
|
|
|
|
template<>
|
|
class task_promise<void> : public task_promise_base
|
|
{
|
|
public:
|
|
|
|
task_promise() noexcept = default;
|
|
|
|
task<void> get_return_object() noexcept;
|
|
|
|
void return_void() noexcept
|
|
{}
|
|
|
|
void result()
|
|
{
|
|
rethrow_if_unhandled_exception();
|
|
}
|
|
|
|
};
|
|
|
|
template<typename T>
|
|
class task_promise<T&> : public task_promise_base
|
|
{
|
|
public:
|
|
|
|
task_promise() noexcept = default;
|
|
|
|
task<T&> get_return_object() noexcept;
|
|
|
|
void return_value(T& value) noexcept
|
|
{
|
|
m_value = std::addressof(value);
|
|
}
|
|
|
|
T& result()
|
|
{
|
|
rethrow_if_unhandled_exception();
|
|
return *m_value;
|
|
}
|
|
|
|
private:
|
|
|
|
T* m_value;
|
|
|
|
};
|
|
}
|
|
|
|
/// \brief
|
|
/// A task represents an operation that produces a result both lazily
|
|
/// and asynchronously.
|
|
///
|
|
/// When you call a coroutine that returns a task, the coroutine
|
|
/// simply captures any passed parameters and returns exeuction to the
|
|
/// caller. Execution of the coroutine body does not start until the
|
|
/// coroutine is first co_await'ed.
|
|
template<typename T = void>
|
|
class task
|
|
{
|
|
public:
|
|
|
|
using promise_type = detail::task_promise<T>;
|
|
|
|
using value_type = T;
|
|
|
|
private:
|
|
|
|
struct awaitable_base
|
|
{
|
|
std::experimental::coroutine_handle<promise_type> m_coroutine;
|
|
|
|
awaitable_base(std::experimental::coroutine_handle<promise_type> coroutine) noexcept
|
|
: m_coroutine(coroutine)
|
|
{}
|
|
|
|
bool await_ready() const noexcept
|
|
{
|
|
return !m_coroutine || m_coroutine.done();
|
|
}
|
|
|
|
bool await_suspend(std::experimental::coroutine_handle<> awaitingCoroutine) noexcept
|
|
{
|
|
// NOTE: We are using the bool-returning version of await_suspend() here
|
|
// to work around a potential stack-overflow issue if a coroutine
|
|
// awaits many synchronously-completing tasks in a loop.
|
|
//
|
|
// We first start the task by calling resume() and then conditionally
|
|
// attach the continuation if it has not already completed. This allows us
|
|
// to immediately resume the awaiting coroutine without increasing
|
|
// the stack depth, avoiding the stack-overflow problem. However, it has
|
|
// the down-side of requiring a std::atomic to arbitrate the race between
|
|
// the coroutine potentially completing on another thread concurrently
|
|
// with registering the continuation on this thread.
|
|
//
|
|
// We can eliminate the use of the std::atomic once we have access to
|
|
// coroutine_handle-returning await_suspend() on both MSVC and Clang
|
|
// as this will provide ability to suspend the awaiting coroutine and
|
|
// resume another coroutine with a guaranteed tail-call to resume().
|
|
m_coroutine.resume();
|
|
return m_coroutine.promise().try_set_continuation(awaitingCoroutine);
|
|
}
|
|
};
|
|
|
|
public:
|
|
|
|
task() noexcept
|
|
: m_coroutine(nullptr)
|
|
{}
|
|
|
|
explicit task(std::experimental::coroutine_handle<promise_type> coroutine)
|
|
: m_coroutine(coroutine)
|
|
{}
|
|
|
|
task(task&& t) noexcept
|
|
: m_coroutine(t.m_coroutine)
|
|
{
|
|
t.m_coroutine = nullptr;
|
|
}
|
|
|
|
/// Disable copy construction/assignment.
|
|
task(const task&) = delete;
|
|
task& operator=(const task&) = delete;
|
|
|
|
/// Frees resources used by this task.
|
|
~task()
|
|
{
|
|
if (m_coroutine)
|
|
{
|
|
m_coroutine.destroy();
|
|
}
|
|
}
|
|
|
|
task& operator=(task&& other) noexcept
|
|
{
|
|
if (std::addressof(other) != this)
|
|
{
|
|
if (m_coroutine)
|
|
{
|
|
m_coroutine.destroy();
|
|
}
|
|
|
|
m_coroutine = other.m_coroutine;
|
|
other.m_coroutine = nullptr;
|
|
}
|
|
|
|
return *this;
|
|
}
|
|
|
|
/// \brief
|
|
/// Query if the task result is complete.
|
|
///
|
|
/// Awaiting a task that is ready is guaranteed not to block/suspend.
|
|
bool is_ready() const noexcept
|
|
{
|
|
return !m_coroutine || m_coroutine.done();
|
|
}
|
|
|
|
auto operator co_await() const & noexcept
|
|
{
|
|
struct awaitable : awaitable_base
|
|
{
|
|
using awaitable_base::awaitable_base;
|
|
|
|
decltype(auto) await_resume()
|
|
{
|
|
if (!this->m_coroutine)
|
|
{
|
|
throw broken_promise{};
|
|
}
|
|
|
|
return this->m_coroutine.promise().result();
|
|
}
|
|
};
|
|
|
|
return awaitable{ m_coroutine };
|
|
}
|
|
|
|
auto operator co_await() const && noexcept
|
|
{
|
|
struct awaitable : awaitable_base
|
|
{
|
|
using awaitable_base::awaitable_base;
|
|
|
|
decltype(auto) await_resume()
|
|
{
|
|
if (!this->m_coroutine)
|
|
{
|
|
throw broken_promise{};
|
|
}
|
|
|
|
return std::move(this->m_coroutine.promise()).result();
|
|
}
|
|
};
|
|
|
|
return awaitable{ m_coroutine };
|
|
}
|
|
|
|
/// \brief
|
|
/// Returns an awaitable that will await completion of the task without
|
|
/// attempting to retrieve the result.
|
|
auto when_ready() const noexcept
|
|
{
|
|
struct awaitable : awaitable_base
|
|
{
|
|
using awaitable_base::awaitable_base;
|
|
|
|
void await_resume() const noexcept {}
|
|
};
|
|
|
|
return awaitable{ m_coroutine };
|
|
}
|
|
|
|
private:
|
|
|
|
std::experimental::coroutine_handle<promise_type> m_coroutine;
|
|
|
|
};
|
|
|
|
namespace detail
|
|
{
|
|
template<typename T>
|
|
task<T> task_promise<T>::get_return_object() noexcept
|
|
{
|
|
return task<T>{ std::experimental::coroutine_handle<task_promise>::from_promise(*this) };
|
|
}
|
|
|
|
inline task<void> task_promise<void>::get_return_object() noexcept
|
|
{
|
|
return task<void>{ std::experimental::coroutine_handle<task_promise>::from_promise(*this) };
|
|
}
|
|
|
|
template<typename T>
|
|
task<T&> task_promise<T&>::get_return_object() noexcept
|
|
{
|
|
return task<T&>{ std::experimental::coroutine_handle<task_promise>::from_promise(*this) };
|
|
}
|
|
}
|
|
|
|
template<typename AWAITABLE>
|
|
auto make_task(AWAITABLE awaitable)
|
|
-> task<detail::remove_rvalue_reference_t<typename awaitable_traits<AWAITABLE>::await_result_t>>
|
|
{
|
|
co_return co_await static_cast<AWAITABLE&&>(awaitable);
|
|
}
|
|
}
|
|
|
|
#endif
|