Initial version of cancellation_token.

This commit is contained in:
Lewis Baker 2017-05-02 07:46:09 +09:30
parent b8ada4773b
commit d04f152bf9
11 changed files with 1418 additions and 1 deletion

View File

@ -0,0 +1,80 @@
///////////////////////////////////////////////////////////////////////////////
// Copyright (c) Lewis Baker
// Licenced under MIT license. See LICENSE.txt for details.
///////////////////////////////////////////////////////////////////////////////
#ifndef CPPCORO_CANCELLATION_REGISTRATION_HPP_INCLUDED
#define CPPCORO_CANCELLATION_REGISTRATION_HPP_INCLUDED
#include <functional>
#include <utility>
#include <type_traits>
#include <atomic>
namespace cppcoro
{
class cancellation_token;
namespace detail
{
class cancellation_state;
}
class cancellation_registration
{
public:
/// Registers the callback to be executed when cancellation is requested
/// on the cancellation_token.
///
/// The callback will be executed if cancellation is requested for the
/// specified cancellation token. If cancellation has already been requested
/// then the callback will be executed immediately, before the constructor
/// returns. If cancellation has not yet been requested then the callback
/// will be executed on the first thread to request cancellation inside
/// the call to cancellation_source::request_cancellation().
///
/// \param token
/// The cancellation token to register the callback with.
///
/// \param callback
/// The callback to be executed when cancellation is requested on the
/// cancellation_token.
///
/// \throw std::bad_alloc
/// If registration failed due to insufficient memory available.
template<
typename CALLBACK,
typename = std::enable_if_t<std::is_constructible_v<std::function<void()>, CALLBACK&&>>>
cancellation_registration(cancellation_token token, CALLBACK&& callback)
: m_callback(std::forward<CALLBACK>(callback))
{
register_callback(std::move(token));
}
cancellation_registration(const cancellation_registration& other) = delete;
cancellation_registration& operator=(const cancellation_registration& other) = delete;
/// Deregisters the callback.
///
/// After the destructor returns it is guaranteed that the callback
/// will not be subsequently called during a call to request_cancellation()
/// on the cancellation_source.
///
/// This may block if cancellation has been requested on another thread,
/// as it will need to wait until this callback has finished executing
/// before the callback can be destroyed.
~cancellation_registration();
private:
friend class detail::cancellation_state;
void register_callback(cancellation_token&& token);
detail::cancellation_state* m_state;
std::function<void()> m_callback;
std::atomic<cancellation_registration*>* m_registrationSlot;
};
}
#endif
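
A minimal usage sketch, not part of this commit: it shows the RAII registration pattern the constructor and destructor above describe. do_cancellable_work() and the stop flag are hypothetical placeholders.

#include <cppcoro/cancellation_token.hpp>
#include <cppcoro/cancellation_registration.hpp>
#include <atomic>

void do_cancellable_work(std::atomic<bool>& stopRequested);  // hypothetical

void run(cppcoro::cancellation_token token)
{
	std::atomic<bool> stopRequested{ false };
	// Runs immediately if cancellation was already requested, otherwise runs
	// inside the first call to cancellation_source::request_cancellation().
	cppcoro::cancellation_registration registration(token, [&] { stopRequested = true; });
	do_cancellable_work(stopRequested);
}	// ~cancellation_registration() deregisters the callback, waiting if it is mid-execution.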

View File

@ -0,0 +1,71 @@
///////////////////////////////////////////////////////////////////////////////
// Copyright (c) Lewis Baker
// Licenced under MIT license. See LICENSE.txt for details.
///////////////////////////////////////////////////////////////////////////////
#ifndef CPPCORO_CANCELLATION_SOURCE_HPP_INCLUDED
#define CPPCORO_CANCELLATION_SOURCE_HPP_INCLUDED
namespace cppcoro
{
class cancellation_token;
namespace detail
{
class cancellation_state;
}
class cancellation_source
{
public:
/// Construct a new cancellation source.
cancellation_source();
/// Create a new reference to the same underlying cancellation
/// source as \p other.
cancellation_source(const cancellation_source& other) noexcept;
cancellation_source(cancellation_source&& other) noexcept;
~cancellation_source();
cancellation_source& operator=(const cancellation_source& other) noexcept;
cancellation_source& operator=(cancellation_source&& other) noexcept;
/// Query if this cancellation source can be cancelled.
///
/// A cancellation source object will not be cancellable if it has
/// previously been moved into another cancellation_source instance
/// or was copied from a cancellation_source that was not cancellable.
bool can_be_cancelled() const noexcept;
/// Obtain a cancellation token that can be used to query if
/// cancellation has been requested on this source.
///
/// The cancellation token can be passed into functions whose operations
/// you may later want to be able to cancel.
cancellation_token token() const noexcept;
/// Request cancellation of operations that were passed an associated
/// cancellation token.
///
/// Any cancellation callback registered via a cancellation_registration
/// object will be called inside this function by the first thread to
/// call this method.
///
/// This operation is a no-op if can_be_cancelled() returns false.
void request_cancellation();
/// Query if some thread has called 'request_cancellation()' on this
/// cancellation_source.
bool is_cancellation_requested() const noexcept;
private:
detail::cancellation_state* m_state;
};
}
#endif
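
A minimal sketch, not part of this commit, assuming a hypothetical start_operation() consumer; it shows the intended split between the source (held by the party that cancels) and the tokens handed to cancellable operations.

#include <cppcoro/cancellation_source.hpp>
#include <cppcoro/cancellation_token.hpp>

void start_operation(cppcoro::cancellation_token token);  // hypothetical consumer of the token

void example()
{
	cppcoro::cancellation_source source;
	// Hand out as many tokens as needed; they all observe the same state.
	start_operation(source.token());
	start_operation(source.token());
	// The first thread to call this runs any registered callbacks.
	source.request_cancellation();
}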

View File

@ -0,0 +1,72 @@
///////////////////////////////////////////////////////////////////////////////
// Copyright (c) Lewis Baker
// Licenced under MIT license. See LICENSE.txt for details.
///////////////////////////////////////////////////////////////////////////////
#ifndef CPPCORO_CANCELLATION_TOKEN_HPP_INCLUDED
#define CPPCORO_CANCELLATION_TOKEN_HPP_INCLUDED
namespace cppcoro
{
class cancellation_source;
class cancellation_registration;
namespace detail
{
class cancellation_state;
}
class cancellation_token
{
public:
/// Construct a cancellation token that can't be cancelled.
cancellation_token() noexcept;
/// Copy another cancellation token.
///
/// The new token will refer to the same underlying state.
cancellation_token(const cancellation_token& other) noexcept;
cancellation_token(cancellation_token&& other) noexcept;
~cancellation_token();
cancellation_token& operator=(const cancellation_token& other) noexcept;
cancellation_token& operator=(cancellation_token&& other) noexcept;
void swap(cancellation_token& other) noexcept;
/// Query if it is possible that cancellation will ever be requested
/// on this token.
///
/// Operations can take a more efficient code-path if they know that
/// they will never need to handle a cancellation request.
bool can_be_cancelled() const noexcept;
/// Query if some thread has requested cancellation on an associated
/// cancellation_source object.
bool is_cancellation_requested() const noexcept;
/// Throws cppcoro::operation_cancelled exception if cancellation
/// has been requested for the associated operation.
void throw_if_cancellation_requested() const;
private:
friend class cancellation_source;
friend class cancellation_registration;
cancellation_token(detail::cancellation_state* state) noexcept;
detail::cancellation_state* m_state;
};
inline void swap(cancellation_token& a, cancellation_token& b) noexcept
{
a.swap(b);
}
}
#endif
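
A minimal sketch, not part of this commit: a cooperative worker loop that polls the token, skipping the check entirely when the token can never be cancelled. do_work_item() is a hypothetical unit of work.

#include <cppcoro/cancellation_token.hpp>

bool do_work_item();  // hypothetical; returns false when there is no more work

void worker(cppcoro::cancellation_token token)
{
	// A token that can never be cancelled lets the loop skip the check entirely.
	if (!token.can_be_cancelled())
	{
		while (do_work_item()) {}
		return;
	}
	while (do_work_item())
	{
		if (token.is_cancellation_requested())
		{
			return;  // stop cooperatively
		}
	}
}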

View File

@ -0,0 +1,24 @@
///////////////////////////////////////////////////////////////////////////////
// Copyright (c) Lewis Baker
// Licenced under MIT license. See LICENSE.txt for details.
///////////////////////////////////////////////////////////////////////////////
#ifndef CPPCORO_OPERATION_CANCELLED_HPP_INCLUDED
#define CPPCORO_OPERATION_CANCELLED_HPP_INCLUDED
#include <exception>
namespace cppcoro
{
class operation_cancelled : public std::exception
{
public:
operation_cancelled() noexcept
: std::exception()
{}
const char* what() const noexcept override { return "operation cancelled"; }
};
}
#endif
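
A minimal sketch, not part of this commit: it pairs throw_if_cancellation_requested() from cancellation_token with the exception type above, treating cancellation as a normal outcome rather than an error. compute_step() is a hypothetical unit of work.

#include <cppcoro/cancellation_token.hpp>
#include <cppcoro/operation_cancelled.hpp>

void compute_step();  // hypothetical

bool run_to_completion(cppcoro::cancellation_token token)
{
	try
	{
		for (int i = 0; i < 1000; ++i)
		{
			token.throw_if_cancellation_requested();
			compute_step();
		}
		return true;   // finished normally
	}
	catch (const cppcoro::operation_cancelled&)
	{
		return false;  // cancelled; not an error
	}
}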

View File

@ -10,14 +10,25 @@ from cake.tools import compiler, script, env, project
includes = cake.path.join(env.expand('${CPPCORO}'), 'include', 'cppcoro', [
'async_mutex.hpp',
'broken_promise.hpp',
'cancellation_registration.hpp',
'cancellation_source.hpp',
'cancellation_token.hpp',
'lazy_task.hpp',
'shared_task.hpp',
'single_consumer_event.hpp',
'task.hpp',
])
privateHeaders = script.cwd([
'cancellation_state.hpp',
])
sources = script.cwd([
'async_mutex.cpp',
'cancellation_state.cpp',
'cancellation_token.cpp',
'cancellation_source.cpp',
'cancellation_registration.cpp',
])
extras = script.cwd([
@ -43,7 +54,7 @@ vcproj = project.project(
target=env.expand('${CPPCORO_PROJECT}/cppcoro'),
items={
'Include': includes,
'Source': sources,
'Source': sources + privateHeaders,
'': extras
},
output=lib,

View File

@ -0,0 +1,41 @@
///////////////////////////////////////////////////////////////////////////////
// Copyright (c) Lewis Baker
// Licenced under MIT license. See LICENSE.txt for details.
///////////////////////////////////////////////////////////////////////////////
#include <cppcoro/cancellation_registration.hpp>
#include "cancellation_state.hpp"
#include <cassert>
cppcoro::cancellation_registration::~cancellation_registration()
{
if (m_state != nullptr)
{
m_state->deregister_callback(this);
m_state->release_token_ref();
}
}
void cppcoro::cancellation_registration::register_callback(cancellation_token&& token)
{
auto* state = token.m_state;
if (state != nullptr && state->can_be_cancelled())
{
m_state = state;
if (state->try_register_callback(this))
{
token.m_state = nullptr;
}
else
{
m_state = nullptr;
m_callback();
}
}
else
{
m_state = nullptr;
}
}

View File

@ -0,0 +1,97 @@
///////////////////////////////////////////////////////////////////////////////
// Copyright (c) Lewis Baker
// Licenced under MIT license. See LICENSE.txt for details.
///////////////////////////////////////////////////////////////////////////////
#include <cppcoro/cancellation_source.hpp>
#include "cancellation_state.hpp"
#include <cassert>
cppcoro::cancellation_source::cancellation_source()
: m_state(detail::cancellation_state::create())
{
}
cppcoro::cancellation_source::cancellation_source(const cancellation_source& other) noexcept
: m_state(other.m_state)
{
if (m_state != nullptr)
{
m_state->add_source_ref();
}
}
cppcoro::cancellation_source::cancellation_source(cancellation_source&& other) noexcept
: m_state(other.m_state)
{
other.m_state = nullptr;
}
cppcoro::cancellation_source::~cancellation_source()
{
if (m_state != nullptr)
{
m_state->release_source_ref();
}
}
cppcoro::cancellation_source& cppcoro::cancellation_source::operator=(const cancellation_source& other) noexcept
{
if (m_state != other.m_state)
{
if (m_state != nullptr)
{
m_state->release_source_ref();
}
m_state = other.m_state;
if (m_state != nullptr)
{
m_state->add_source_ref();
}
}
return *this;
}
cppcoro::cancellation_source& cppcoro::cancellation_source::operator=(cancellation_source&& other) noexcept
{
if (this != &other)
{
if (m_state != nullptr)
{
m_state->release_source_ref();
}
m_state = other.m_state;
other.m_state = nullptr;
}
return *this;
}
bool cppcoro::cancellation_source::can_be_cancelled() const noexcept
{
return m_state != nullptr;
}
cppcoro::cancellation_token cppcoro::cancellation_source::token() const noexcept
{
return cancellation_token(m_state);
}
void cppcoro::cancellation_source::request_cancellation()
{
if (m_state != nullptr)
{
m_state->request_cancellation();
}
}
bool cppcoro::cancellation_source::is_cancellation_requested() const noexcept
{
return m_state != nullptr && m_state->is_cancellation_requested();
}

lib/cancellation_state.cpp
View File

@ -0,0 +1,535 @@
///////////////////////////////////////////////////////////////////////////////
// Copyright (c) Lewis Baker
// Licenced under MIT license. See LICENSE.txt for details.
///////////////////////////////////////////////////////////////////////////////
#include "cancellation_state.hpp"
#include <cppcoro/cancellation_registration.hpp>
#include <cassert>
#include <cstdlib>
struct cppcoro::detail::cancellation_state::registration_list_chunk
{
static registration_list_chunk* allocate(std::uint32_t entryCount);
static void free(registration_list_chunk* chunk) noexcept;
std::atomic<registration_list_chunk*> m_nextChunk;
registration_list_chunk* m_prevChunk;
std::uint32_t m_entryCount;
std::atomic<cancellation_registration*> m_entries[1];
};
struct cppcoro::detail::cancellation_state::registration_list_bucket
{
static registration_list_bucket* allocate();
static void free(registration_list_bucket* bucket) noexcept;
std::atomic<registration_list_chunk*> m_approximateTail;
registration_list_chunk m_headChunk;
};
struct cppcoro::detail::cancellation_state::registration_list
{
static registration_list* allocate();
static void free(registration_list* list) noexcept;
void add_registration(cppcoro::cancellation_registration* registration);
std::thread::id m_notificationThreadId;
std::uint32_t m_bucketCount;
std::atomic<registration_list_bucket*> m_buckets[1];
};
cppcoro::detail::cancellation_state::registration_list_chunk*
cppcoro::detail::cancellation_state::registration_list_chunk::allocate(std::uint32_t entryCount)
{
auto* chunk = static_cast<registration_list_chunk*>(std::malloc(
sizeof(registration_list_chunk) +
(entryCount - 1) * sizeof(registration_list_chunk::m_entries[0])));
if (chunk == nullptr)
{
throw std::bad_alloc{};
}
chunk->m_nextChunk.store(nullptr, std::memory_order_release);
chunk->m_prevChunk = nullptr;
chunk->m_entryCount = entryCount;
for (std::uint32_t i = 0; i < entryCount; ++i)
{
chunk->m_entries[i].store(nullptr, std::memory_order_relaxed);
}
return chunk;
}
void cppcoro::detail::cancellation_state::registration_list_chunk::free(registration_list_chunk* chunk) noexcept
{
std::free(static_cast<void*>(chunk));
}
cppcoro::detail::cancellation_state::registration_list_bucket*
cppcoro::detail::cancellation_state::registration_list_bucket::allocate()
{
constexpr std::uint32_t initialChunkSize = 16;
auto* bucket = static_cast<registration_list_bucket*>(std::malloc(
sizeof(registration_list_bucket) +
(initialChunkSize - 1) * sizeof(registration_list_chunk::m_entries[0])));
if (bucket == nullptr)
{
throw std::bad_alloc{};
}
bucket->m_approximateTail = &bucket->m_headChunk;
bucket->m_headChunk.m_nextChunk.store(nullptr, std::memory_order_relaxed);
bucket->m_headChunk.m_prevChunk = nullptr;
bucket->m_headChunk.m_entryCount = initialChunkSize;
for (std::uint32_t i = 0; i < initialChunkSize; ++i)
{
bucket->m_headChunk.m_entries[i].store(nullptr, std::memory_order_relaxed);
}
return bucket;
}
void cppcoro::detail::cancellation_state::registration_list_bucket::free(registration_list_bucket* bucket) noexcept
{
std::free(static_cast<void*>(bucket));
}
cppcoro::detail::cancellation_state::registration_list*
cppcoro::detail::cancellation_state::registration_list::allocate()
{
constexpr std::uint32_t maxBucketCount = 16;
auto bucketCount = std::thread::hardware_concurrency();
if (bucketCount > maxBucketCount)
{
bucketCount = maxBucketCount;
}
else if (bucketCount == 0)
{
bucketCount = 1;
}
auto* list = static_cast<registration_list*>(std::malloc(
sizeof(registration_list) +
(bucketCount - 1) * sizeof(registration_list::m_buckets[0])));
if (list == nullptr)
{
throw std::bad_alloc{};
}
list->m_bucketCount = bucketCount;
for (std::uint32_t i = 0; i < bucketCount; ++i)
{
list->m_buckets[i].store(nullptr, std::memory_order_relaxed);
}
return list;
}
void cppcoro::detail::cancellation_state::registration_list::free(registration_list* list) noexcept
{
std::free(static_cast<void*>(list));
}
void cppcoro::detail::cancellation_state::registration_list::add_registration(cppcoro::cancellation_registration* registration)
{
// Pick a bucket to add to based on the current thread to reduce the
// chance of contention with multiple threads concurrently registering
// callbacks.
const auto threadIdHashCode = std::hash<std::thread::id>{}(std::this_thread::get_id());
auto& bucketPtr = m_buckets[threadIdHashCode % m_bucketCount];
auto* bucket = bucketPtr.load(std::memory_order_acquire);
if (bucket == nullptr)
{
auto* newBucket = registration_list_bucket::allocate();
// Pre-claim the first slot.
auto& slot = newBucket->m_headChunk.m_entries[0];
slot.store(registration, std::memory_order_relaxed);
registration->m_registrationSlot = &slot;
if (bucketPtr.compare_exchange_strong(
bucket,
newBucket,
std::memory_order_seq_cst,
std::memory_order_acquire))
{
return;
}
else
{
registration_list_bucket::free(newBucket);
}
}
while (true)
{
// Navigate to the end of the chain of chunks and work backwards looking for a free slot.
auto* const originalLastChunk = bucket->m_approximateTail.load(std::memory_order_acquire);
auto* lastChunk = originalLastChunk;
for (auto* next = lastChunk->m_nextChunk.load(std::memory_order_acquire);
next != nullptr;
next = next->m_nextChunk.load(std::memory_order_acquire))
{
lastChunk = next;
}
if (lastChunk != originalLastChunk)
{
// Update the cached last-chunk pointer so that subsequent
// registration requests can start there instead.
// It doesn't matter if these writes race as the value will
// eventually converge to the true last chunk.
bucket->m_approximateTail.store(lastChunk, std::memory_order_release);
}
for (auto* chunk = lastChunk;
chunk != nullptr;
chunk = chunk->m_prevChunk)
{
for (std::uint32_t i = 0, entryCount = chunk->m_entryCount; i < entryCount; ++i)
{
auto& slot = chunk->m_entries[i];
// Do a cheap initial read of the slot value to see if the
// slot is likely free. This can potentially read stale values
// and so may lead to falsely thinking it's free or falsely
// thinking it's occupied. But approximate is good enough here.
auto* slotValue = slot.load(std::memory_order_relaxed);
if (slotValue == nullptr)
{
registration->m_registrationSlot = &slot;
if (slot.compare_exchange_strong(
slotValue,
registration,
std::memory_order_seq_cst,
std::memory_order_relaxed))
{
// Successfully claimed the slot.
return;
}
}
}
}
// We've traversed through all of the chunks and found no free slots.
// So try and allocate a new chunk and append it to the list.
constexpr std::uint32_t maxElementCount = 1024;
const std::uint32_t elementCount =
lastChunk->m_entryCount < maxElementCount ?
lastChunk->m_entryCount * 2 : maxElementCount;
// May throw std::bad_alloc if out of memory.
auto* newChunk = registration_list_chunk::allocate(elementCount);
newChunk->m_prevChunk = lastChunk;
// Pre-claim the first slot.
auto& slot = newChunk->m_entries[0];
registration->m_registrationSlot = &slot;
slot.store(registration, std::memory_order_relaxed);
registration_list_chunk* oldNext = nullptr;
if (lastChunk->m_nextChunk.compare_exchange_strong(
oldNext,
newChunk,
std::memory_order_seq_cst,
std::memory_order_relaxed))
{
bucket->m_approximateTail.store(newChunk, std::memory_order_release);
return;
}
// Some other thread published a new chunk to the end of the list
// concurrently. Free our chunk and go around the loop again, hopefully
// allocating a slot from the chunk the other thread just allocated.
registration_list_chunk::free(newChunk);
}
}
cppcoro::detail::cancellation_state* cppcoro::detail::cancellation_state::create()
{
return new cancellation_state();
}
cppcoro::detail::cancellation_state::~cancellation_state()
{
assert((m_state.load(std::memory_order_relaxed) & cancellation_ref_count_mask) == 0);
// Use relaxed memory order in reads here since we should already have visibility
// to all writes as the ref-count decrement that preceded the call to the destructor
// has acquire-release semantics.
auto* registrations = m_registrations.load(std::memory_order_relaxed);
if (registrations != nullptr)
{
for (std::uint32_t i = 0; i < registrations->m_bucketCount; ++i)
{
auto* bucket = registrations->m_buckets[i].load(std::memory_order_relaxed);
if (bucket != nullptr)
{
auto* chunk = bucket->m_headChunk.m_nextChunk.load(std::memory_order_relaxed);
registration_list_bucket::free(bucket);
while (chunk != nullptr)
{
auto* next = chunk->m_nextChunk.load(std::memory_order_relaxed);
registration_list_chunk::free(chunk);
chunk = next;
}
}
}
registration_list::free(registrations);
}
}
void cppcoro::detail::cancellation_state::add_token_ref() noexcept
{
m_state.fetch_add(cancellation_token_ref_increment, std::memory_order_relaxed);
}
void cppcoro::detail::cancellation_state::release_token_ref() noexcept
{
const std::uint64_t oldState = m_state.fetch_sub(cancellation_token_ref_increment, std::memory_order_acq_rel);
if ((oldState & cancellation_ref_count_mask) == cancellation_token_ref_increment)
{
delete this;
}
}
void cppcoro::detail::cancellation_state::add_source_ref() noexcept
{
m_state.fetch_add(cancellation_source_ref_increment, std::memory_order_relaxed);
}
void cppcoro::detail::cancellation_state::release_source_ref() noexcept
{
const std::uint64_t oldState = m_state.fetch_sub(cancellation_source_ref_increment, std::memory_order_acq_rel);
if ((oldState & cancellation_ref_count_mask) == cancellation_source_ref_increment)
{
delete this;
}
}
bool cppcoro::detail::cancellation_state::can_be_cancelled() const noexcept
{
return (m_state.load(std::memory_order_acquire) & can_be_cancelled_mask) != 0;
}
bool cppcoro::detail::cancellation_state::is_cancellation_requested() const noexcept
{
return (m_state.load(std::memory_order_acquire) & cancellation_requested_flag) != 0;
}
bool cppcoro::detail::cancellation_state::is_cancellation_notification_complete() const noexcept
{
return (m_state.load(std::memory_order_acquire) & cancellation_notification_complete_flag) != 0;
}
void cppcoro::detail::cancellation_state::request_cancellation()
{
const auto oldState = m_state.fetch_or(cancellation_requested_flag, std::memory_order_seq_cst);
if ((oldState & cancellation_requested_flag) != 0)
{
// Some thread has already called request_cancellation().
return;
}
// We are the first caller of request_cancellation.
// Need to execute any registered callbacks to notify them of cancellation.
// NOTE: We need to use sequentially-consistent operations here to ensure
// that if there is a concurrent call to try_register_callback() on another
// thread then either that thread will read our prior write to m_state after
// it writes to a registration slot, or we will read its write to the
// registration slot after our prior write to m_state.
auto* const registrations = m_registrations.load(std::memory_order_seq_cst);
if (registrations != nullptr)
{
// Note that there should be no data-race in writing to this value here
// as another thread will only read it if they are trying to deregister
// a callback and that fails because we have acquired the pointer to
// the registration inside the loop below. In this case the atomic
// exchange that acquires the pointer below acts as a release-operation
// that synchronises with the failed exchange operation in deregister_callback()
// which has acquire semantics and thus will have visibility of the write to
// the m_notificationThreadId value.
registrations->m_notificationThreadId = std::this_thread::get_id();
for (std::uint32_t bucketIndex = 0, bucketCount = registrations->m_bucketCount;
bucketIndex < bucketCount;
++bucketIndex)
{
auto* bucket = registrations->m_buckets[bucketIndex].load(std::memory_order_seq_cst);
if (bucket == nullptr)
{
continue;
}
auto* chunk = &bucket->m_headChunk;
do
{
for (std::uint32_t entryIndex = 0, entryCount = chunk->m_entryCount;
entryIndex < entryCount;
++entryIndex)
{
auto& entry = chunk->m_entries[entryIndex];
// Quick read-only operation to check if any registration
// is present.
auto* registration = entry.load(std::memory_order_seq_cst);
if (registration != nullptr)
{
// Try to acquire ownership of the registration by replacing its
// slot with nullptr atomically. This resolves the race between
// a concurrent call to deregister_callback() from the registration's
// destructor.
registration = entry.exchange(nullptr, std::memory_order_seq_cst);
if (registration != nullptr)
{
try
{
registration->m_callback();
}
catch (...)
{
// TODO: What should the behaviour be if a callback throws an unhandled exception here?
std::terminate();
}
}
}
}
chunk = chunk->m_nextChunk.load(std::memory_order_seq_cst);
} while (chunk != nullptr);
}
m_state.fetch_add(cancellation_notification_complete_flag, std::memory_order_release);
}
}
bool cppcoro::detail::cancellation_state::try_register_callback(cancellation_registration* registration)
{
if (is_cancellation_requested())
{
return false;
}
auto* registrationList = m_registrations.load(std::memory_order_acquire);
if (registrationList == nullptr)
{
// Could throw std::bad_alloc
auto* newRegistrationList = registration_list::allocate();
// Need to use 'sequentially consistent' memory order on the write here to
// ensure that if we subsequently read a value from m_state at the end of this
// function that doesn't have the cancellation_requested_flag bit set, then a
// subsequent call to request_cancellation() on another thread will see this write.
if (m_registrations.compare_exchange_strong(
registrationList,
newRegistrationList,
std::memory_order_seq_cst,
std::memory_order_acquire))
{
registrationList = newRegistrationList;
}
else
{
registration_list::free(newRegistrationList);
}
}
// Could throw std::bad_alloc
registrationList->add_registration(registration);
// Need to check status again to handle the case where
// another thread calls request_cancellation() concurrently
// but doesn't see our write to the registration list.
if ((m_state.load(std::memory_order_seq_cst) & cancellation_requested_flag) != 0)
{
// Cancellation was requested concurrently with adding the
// registration to the list. Try to remove the registration.
// If successful we return false to indicate that the callback
// has not been registered and the caller should execute the
// callback. If it fails it means that the thread that requested
// cancellation will execute our callback and we need to wait
// until it finishes before returning.
auto& slot = *registration->m_registrationSlot;
// Need to use compare_exchange here rather than just exchange since
// it may be possible that the thread calling request_cancellation()
// acquired our registration and executed the callback, freeing up
// the slot and then a third thread registers a new registration
// that gets allocated to this slot.
//
// Can use relaxed memory order here since in the case that this succeeds
// no other thread will have written to the cancellation_registration record
// so we can safely read from the record without synchronisation.
auto* oldValue = registration;
const bool deregisteredSuccessfully =
slot.compare_exchange_strong(oldValue, nullptr, std::memory_order_relaxed);
if (deregisteredSuccessfully)
{
return false;
}
// Otherwise, the cancelling thread has taken ownership for executing
// the callback and we can just act as if the registration succeeded.
}
return true;
}
void cppcoro::detail::cancellation_state::deregister_callback(cancellation_registration* registration) noexcept
{
// Could use 'relaxed' memory order on the success case since, if this succeeds,
// it means that no other thread will have written to the registration object.
// Use 'acquire' memory order on the failure case so that we synchronise with the
// write to the slot inside request_cancellation() that acquired the registration,
// giving us visibility of its prior write to m_notificationThreadId.
auto* oldValue = registration;
bool deregisteredSuccessfully = registration->m_registrationSlot->compare_exchange_strong(
oldValue,
nullptr,
std::memory_order_acquire);
if (!deregisteredSuccessfully)
{
// A thread executing request_cancellation() has acquired this callback and
// is executing it. Need to wait until it finishes executing before we return
// and the registration object is destructed.
// However, we also need to handle the case where the registration is being
// removed from within a callback which would otherwise deadlock waiting
// for the callbacks to finish executing.
// Use relaxed memory order here as we should already have visibility
// of the write to m_registrations from when the registration was first
// registered.
auto* registrationList = m_registrations.load(std::memory_order_relaxed);
if (std::this_thread::get_id() != registrationList->m_notificationThreadId)
{
// TODO: More efficient busy-wait backoff strategy
while (!is_cancellation_notification_complete())
{
std::this_thread::yield();
}
}
}
}
cppcoro::detail::cancellation_state::cancellation_state() noexcept
: m_state(cancellation_source_ref_increment)
, m_registrations(nullptr)
{
}

lib/cancellation_state.hpp
View File

@ -0,0 +1,110 @@
///////////////////////////////////////////////////////////////////////////////
// Copyright (c) Lewis Baker
// Licenced under MIT license. See LICENSE.txt for details.
///////////////////////////////////////////////////////////////////////////////
#ifndef CPPCORO_CANCELLATION_STATE_HPP_INCLUDED
#define CPPCORO_CANCELLATION_STATE_HPP_INCLUDED
#include <cppcoro/cancellation_token.hpp>
#include <thread>
#include <atomic>
#include <cstdint>
namespace cppcoro
{
namespace detail
{
class cancellation_state
{
public:
/// Allocates a new cancellation_state object.
///
/// \throw std::bad_alloc
/// If there was insufficient memory to allocate one.
static cancellation_state* create();
~cancellation_state();
/// Increment the reference count of cancellation_token and
/// cancellation_registration objects referencing this state.
void add_token_ref() noexcept;
/// Decrement the reference count of cancellation_token and
/// cancellation_registration objects referencing this state.
void release_token_ref() noexcept;
/// Increment the reference count of cancellation_source objects.
void add_source_ref() noexcept;
/// Decrement the reference count of cancellation_source objects.
///
/// The cancellation_state will no longer be cancellable once the
/// cancellation_source ref count reaches zero.
void release_source_ref() noexcept;
/// Query if the cancellation_state can have cancellation requested.
///
/// \return
/// Returns false if there are no more references to a cancellation_source
/// object and cancellation has not yet been requested.
bool can_be_cancelled() const noexcept;
/// Query if some thread has called request_cancellation().
bool is_cancellation_requested() const noexcept;
/// Flag the state as having had cancellation requested and execute any
/// registered callbacks.
void request_cancellation();
/// Try to register the cancellation_registration as a callback to be executed
/// when cancellation is requested.
///
/// \return
/// true if the callback was successfully registered, false if the callback was
/// not registered because cancellation had already been requested.
///
/// \throw std::bad_alloc
/// If callback was unable to be registered due to insufficient memory.
bool try_register_callback(cancellation_registration* registration);
/// Deregister a callback previously registered successfully in a call to try_register_callback().
///
/// If the callback is currently being executed on another
/// thread that is concurrently calling request_cancellation()
/// then this call will block until the callback has finished executing.
void deregister_callback(cancellation_registration* registration) noexcept;
private:
cancellation_state() noexcept;
bool is_cancellation_notification_complete() const noexcept;
struct registration_list;
struct registration_list_bucket;
struct registration_list_chunk;
static constexpr std::uint64_t cancellation_requested_flag = 1;
static constexpr std::uint64_t cancellation_notification_complete_flag = 2;
static constexpr std::uint64_t cancellation_source_ref_increment = 4;
static constexpr std::uint64_t cancellation_token_ref_increment = UINT64_C(1) << 33;
static constexpr std::uint64_t can_be_cancelled_mask = cancellation_token_ref_increment - 1;
static constexpr std::uint64_t cancellation_ref_count_mask =
~(cancellation_requested_flag | cancellation_notification_complete_flag);
// A value that has:
// - bit 0 - indicates whether cancellation has been requested.
// - bit 1 - indicates whether cancellation notification is complete.
// - bits 2-32 - ref-count for cancellation_source instances.
// - bits 33-63 - ref-count for cancellation_token/cancellation_registration instances.
std::atomic<std::uint64_t> m_state;
std::atomic<registration_list*> m_registrations;
};
}
}
#endif
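
A small sketch, not part of this commit, restating the private constants above as local values so the packing of m_state can be checked with static_asserts; the names and the anonymous namespace are illustrative only.

#include <cstdint>

namespace
{
	constexpr std::uint64_t requested_flag    = 1;                  // bit 0
	constexpr std::uint64_t notification_flag = 2;                  // bit 1
	constexpr std::uint64_t source_increment  = 4;                  // bits 2-32: cancellation_source ref-count
	constexpr std::uint64_t token_increment   = UINT64_C(1) << 33;  // bits 33-63: token/registration ref-count

	// Cancellation can still be requested while any of bits 0-32 are set:
	// either a source still exists or cancellation has already been requested.
	constexpr std::uint64_t can_be_cancelled_mask = token_increment - 1;

	static_assert((source_increment & can_be_cancelled_mask) != 0,
		"a single cancellation_source reference keeps the state cancellable");
	static_assert((token_increment & can_be_cancelled_mask) == 0,
		"token references alone do not keep the state cancellable");
	static_assert((requested_flag | notification_flag) < source_increment,
		"the two flag bits sit below the source ref-count");
}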

lib/cancellation_token.cpp
View File

@ -0,0 +1,108 @@
///////////////////////////////////////////////////////////////////////////////
// Copyright (c) Lewis Baker
// Licenced under MIT license. See LICENSE.txt for details.
///////////////////////////////////////////////////////////////////////////////
#include <cppcoro/cancellation_token.hpp>
#include <cppcoro/operation_cancelled.hpp>
#include "cancellation_state.hpp"
#include <utility>
#include <cassert>
cppcoro::cancellation_token::cancellation_token() noexcept
: m_state(nullptr)
{
}
cppcoro::cancellation_token::cancellation_token(const cancellation_token& other) noexcept
: m_state(other.m_state)
{
if (m_state != nullptr)
{
m_state->add_token_ref();
}
}
cppcoro::cancellation_token::cancellation_token(cancellation_token&& other) noexcept
: m_state(other.m_state)
{
other.m_state = nullptr;
}
cppcoro::cancellation_token::~cancellation_token()
{
if (m_state != nullptr)
{
m_state->release_token_ref();
}
}
cppcoro::cancellation_token& cppcoro::cancellation_token::operator=(const cancellation_token& other) noexcept
{
if (other.m_state != m_state)
{
if (m_state != nullptr)
{
m_state->release_token_ref();
}
m_state = other.m_state;
if (m_state != nullptr)
{
m_state->add_token_ref();
}
}
return *this;
}
cppcoro::cancellation_token& cppcoro::cancellation_token::operator=(cancellation_token&& other) noexcept
{
if (this != &other)
{
if (m_state != nullptr)
{
m_state->release_token_ref();
}
m_state = other.m_state;
other.m_state = nullptr;
}
return *this;
}
void cppcoro::cancellation_token::swap(cancellation_token& other) noexcept
{
std::swap(m_state, other.m_state);
}
bool cppcoro::cancellation_token::can_be_cancelled() const noexcept
{
return m_state != nullptr && m_state->can_be_cancelled();
}
bool cppcoro::cancellation_token::is_cancellation_requested() const noexcept
{
return m_state != nullptr && m_state->is_cancellation_requested();
}
void cppcoro::cancellation_token::throw_if_cancellation_requested() const
{
if (is_cancellation_requested())
{
throw operation_cancelled{};
}
}
cppcoro::cancellation_token::cancellation_token(detail::cancellation_state* state) noexcept
: m_state(state)
{
if (m_state != nullptr)
{
m_state->add_token_ref();
}
}

View File

@ -11,6 +11,10 @@
#include <cppcoro/single_consumer_event.hpp>
#include <cppcoro/async_mutex.hpp>
#include <cppcoro/shared_task.hpp>
#include <cppcoro/cancellation_source.hpp>
#include <cppcoro/cancellation_token.hpp>
#include <cppcoro/cancellation_registration.hpp>
#include <cppcoro/operation_cancelled.hpp>
#include <memory>
#include <string>
@ -806,6 +810,260 @@ void testMakeSharedTask()
assert(consumerTask1.is_ready());
}
void testDefaultCancellationTokenIsNotCancellable()
{
cppcoro::cancellation_token t;
assert(!t.is_cancellation_requested());
assert(!t.can_be_cancelled());
}
void testRequestCancellation()
{
cppcoro::cancellation_source s;
cppcoro::cancellation_token t = s.token();
assert(t.can_be_cancelled());
assert(!t.is_cancellation_requested());
s.request_cancellation();
assert(t.is_cancellation_requested());
assert(t.can_be_cancelled());
}
void testCantBeCancelledWhenLastSourceDestructed()
{
cppcoro::cancellation_token t;
{
cppcoro::cancellation_source s;
t = s.token();
assert(t.can_be_cancelled());
}
assert(!t.can_be_cancelled());
}
void testCanBeCancelledWhenLastSourceDestructedIfCancellationAlreadyRequested()
{
cppcoro::cancellation_token t;
{
cppcoro::cancellation_source s;
t = s.token();
assert(t.can_be_cancelled());
s.request_cancellation();
}
assert(t.can_be_cancelled());
assert(t.is_cancellation_requested());
}
void testCancellationRegistrationWhenCancellationNotRequested()
{
cppcoro::cancellation_source s;
bool callbackExecuted = false;
{
cppcoro::cancellation_registration callbackRegistration(
s.token(),
[&] { callbackExecuted = true; });
}
assert(!callbackExecuted);
{
cppcoro::cancellation_registration callbackRegistration(
s.token(),
[&] { callbackExecuted = true; });
assert(!callbackExecuted);
s.request_cancellation();
assert(callbackExecuted);
}
}
void testThrowIfCancellationRequested()
{
cppcoro::cancellation_source s;
cppcoro::cancellation_token t = s.token();
try
{
t.throw_if_cancellation_requested();
}
catch (const cppcoro::operation_cancelled&)
{
assert(false);
}
s.request_cancellation();
try
{
t.throw_if_cancellation_requested();
assert(false);
}
catch (const cppcoro::operation_cancelled&)
{
}
}
void testCancellationRegistrationCalledImmediatelyWhenCancellationAlreadyRequested()
{
cppcoro::cancellation_source s;
s.request_cancellation();
bool executed = false;
cppcoro::cancellation_registration r{ s.token(), [&] { executed = true; } };
assert(executed);
}
void testRegisteringManyCallbacks()
{
cppcoro::cancellation_source s;
auto t = s.token();
int callbackExecutionCount = 0;
auto callback = [&] { ++callbackExecutionCount; };
// Allocate enough to require a second chunk to be allocated.
cppcoro::cancellation_registration r1{ t, callback };
cppcoro::cancellation_registration r2{ t, callback };
cppcoro::cancellation_registration r3{ t, callback };
cppcoro::cancellation_registration r4{ t, callback };
cppcoro::cancellation_registration r5{ t, callback };
cppcoro::cancellation_registration r6{ t, callback };
cppcoro::cancellation_registration r7{ t, callback };
cppcoro::cancellation_registration r8{ t, callback };
cppcoro::cancellation_registration r9{ t, callback };
cppcoro::cancellation_registration r10{ t, callback };
cppcoro::cancellation_registration r11{ t, callback };
cppcoro::cancellation_registration r12{ t, callback };
cppcoro::cancellation_registration r13{ t, callback };
cppcoro::cancellation_registration r14{ t, callback };
cppcoro::cancellation_registration r15{ t, callback };
cppcoro::cancellation_registration r16{ t, callback };
cppcoro::cancellation_registration r17{ t, callback };
cppcoro::cancellation_registration r18{ t, callback };
s.request_cancellation();
assert(callbackExecutionCount == 18);
}
void testConcurrentRegistrationAndCancellation()
{
// Just check this runs and terminates without crashing.
for (int i = 0; i < 100; ++i)
{
cppcoro::cancellation_source source;
std::thread waiter1{ [token = source.token()]
{
std::atomic<bool> cancelled = false;
while (!cancelled)
{
cppcoro::cancellation_registration registration{ token, [&]
{
cancelled = true;
} };
cppcoro::cancellation_registration reg0{ token, [] {} };
cppcoro::cancellation_registration reg1{ token, [] {} };
cppcoro::cancellation_registration reg2{ token, [] {} };
cppcoro::cancellation_registration reg3{ token, [] {} };
cppcoro::cancellation_registration reg4{ token, [] {} };
cppcoro::cancellation_registration reg5{ token, [] {} };
cppcoro::cancellation_registration reg6{ token, [] {} };
cppcoro::cancellation_registration reg7{ token, [] {} };
cppcoro::cancellation_registration reg8{ token, [] {} };
cppcoro::cancellation_registration reg9{ token, [] {} };
cppcoro::cancellation_registration reg10{ token, [] {} };
cppcoro::cancellation_registration reg11{ token, [] {} };
cppcoro::cancellation_registration reg12{ token, [] {} };
cppcoro::cancellation_registration reg13{ token, [] {} };
cppcoro::cancellation_registration reg14{ token, [] {} };
cppcoro::cancellation_registration reg15{ token, [] {} };
cppcoro::cancellation_registration reg17{ token, [] {} };
std::this_thread::yield();
}
} };
std::thread waiter2{ [token = source.token()]
{
std::atomic<bool> cancelled = false;
while (!cancelled)
{
cppcoro::cancellation_registration registration{ token, [&]
{
cancelled = true;
} };
cppcoro::cancellation_registration reg0{ token, [] {} };
cppcoro::cancellation_registration reg1{ token, [] {} };
cppcoro::cancellation_registration reg2{ token, [] {} };
cppcoro::cancellation_registration reg3{ token, [] {} };
cppcoro::cancellation_registration reg4{ token, [] {} };
cppcoro::cancellation_registration reg5{ token, [] {} };
cppcoro::cancellation_registration reg6{ token, [] {} };
cppcoro::cancellation_registration reg7{ token, [] {} };
cppcoro::cancellation_registration reg8{ token, [] {} };
cppcoro::cancellation_registration reg9{ token, [] {} };
cppcoro::cancellation_registration reg10{ token, [] {} };
cppcoro::cancellation_registration reg11{ token, [] {} };
cppcoro::cancellation_registration reg12{ token, [] {} };
cppcoro::cancellation_registration reg13{ token, [] {} };
cppcoro::cancellation_registration reg14{ token, [] {} };
cppcoro::cancellation_registration reg15{ token, [] {} };
cppcoro::cancellation_registration reg16{ token, [] {} };
std::this_thread::yield();
}
} };
std::thread waiter3{ [token = source.token()]
{
std::atomic<bool> cancelled = false;
while (!cancelled)
{
cppcoro::cancellation_registration registration{ token, [&]
{
cancelled = true;
} };
cppcoro::cancellation_registration reg0{ token, [] {} };
cppcoro::cancellation_registration reg1{ token, [] {} };
cppcoro::cancellation_registration reg2{ token, [] {} };
cppcoro::cancellation_registration reg3{ token, [] {} };
cppcoro::cancellation_registration reg4{ token, [] {} };
cppcoro::cancellation_registration reg5{ token, [] {} };
cppcoro::cancellation_registration reg6{ token, [] {} };
cppcoro::cancellation_registration reg7{ token, [] {} };
cppcoro::cancellation_registration reg8{ token, [] {} };
cppcoro::cancellation_registration reg9{ token, [] {} };
cppcoro::cancellation_registration reg10{ token, [] {} };
cppcoro::cancellation_registration reg11{ token, [] {} };
cppcoro::cancellation_registration reg12{ token, [] {} };
cppcoro::cancellation_registration reg13{ token, [] {} };
cppcoro::cancellation_registration reg14{ token, [] {} };
cppcoro::cancellation_registration reg15{ token, [] {} };
cppcoro::cancellation_registration reg16{ token, [] {} };
std::this_thread::yield();
}
} };
std::thread canceller{ [&source]
{
source.request_cancellation();
} };
canceller.join();
waiter1.join();
waiter2.join();
waiter3.join();
}
}
int main(int argc, char** argv)
{
testAwaitSynchronouslyCompletingVoidFunction();
@ -843,5 +1101,15 @@ int main(int argc, char** argv)
testSharedTaskEquality();
testMakeSharedTask();
testDefaultCancellationTokenIsNotCancellable();
testRequestCancellation();
testCantBeCancelledWhenLastSourceDestructed();
testCanBeCancelledWhenLastSourceDestructedIfCancellationAlreadyRequested();
testCancellationRegistrationWhenCancellationNotRequested();
testCancellationRegistrationCalledImmediatelyWhenCancellationAlreadyRequested();
testThrowIfCancellationRequested();
testRegisteringManyCallbacks();
testConcurrentRegistrationAndCancellation();
return 0;
}