413 lines
8.7 KiB
C++
413 lines
8.7 KiB
C++
///////////////////////////////////////////////////////////////////////////////
|
|
// Copyright (c) Lewis Baker
|
|
// Licenced under MIT license. See LICENSE.txt for details.
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
|
|
#include <cppcoro/generator.hpp>
|
|
#include <cppcoro/on_scope_exit.hpp>
|
|
#include <cppcoro/fmap.hpp>
|
|
|
|
#include <iostream>
|
|
#include <vector>
|
|
#include <string>
|
|
#include <forward_list>
|
|
|
|
#include "doctest/doctest.h"
|
|
|
|
// Every TEST_CASE between here and TEST_SUITE_END() is grouped under the "generator" suite.
TEST_SUITE_BEGIN("generator");
|
|
|
|
// Lets the tests below write generator<T> unqualified (some still spell out cppcoro::generator).
using cppcoro::generator;
|
|
|
|
TEST_CASE("default-constructed generator is empty sequence")
{
	// A generator with no coroutine attached behaves as an empty range:
	// its begin() must already compare equal to its end().
	generator<int> empty;
	auto first = empty.begin();
	CHECK(first == empty.end());
}
|
|
|
|
TEST_CASE("generator of arithmetic type returns by copy")
{
	auto makeFloats = []() -> generator<float>
	{
		co_yield 1.0f;
		co_yield 2.0f;
	};

	auto floats = makeFloats();
	auto pos = floats.begin();

	// TODO: Should this really be required? e.g.
	// static_assert(std::is_same<decltype(*pos), float>::value, "operator* should return float by value");

	CHECK(*pos == 1.0f);

	++pos;
	CHECK(*pos == 2.0f);

	// Advancing past the last co_yield runs the body to completion.
	++pos;
	CHECK(pos == floats.end());
}
|
|
|
|
TEST_CASE("generator of reference returns by reference")
{
	// Yields a reference to the caller's own variable rather than a copy of it.
	auto yieldRef = [](float& target) -> generator<float&>
	{
		co_yield target;
	};

	float value = 1.0f;
	for (float& element : yieldRef(value))
	{
		// The element must alias `value` itself, so writing through it is visible outside.
		CHECK(&element == &value);
		element += 1.0f;
	}

	CHECK(value == 2.0f);
}
|
|
|
|
TEST_CASE("generator of const type")
{
	// Infinite Fibonacci sequence 1, 1, 2, 3, 5, ... yielded as const values.
	auto fibonacci = []() -> generator<const std::uint64_t>
	{
		std::uint64_t previous = 0;
		std::uint64_t current = 1;
		for (;;)
		{
			co_yield current;
			const std::uint64_t next = previous + current;
			previous = current;
			current = next;
		}
	};

	// Count the Fibonacci numbers that do not exceed one million.
	std::uint64_t count = 0;
	for (const std::uint64_t term : fibonacci())
	{
		if (term > 1'000'000)
		{
			break;
		}
		++count;
	}

	// The 30th Fibonacci number is 832'040; the 31st (1'346'269) is the first above 1'000'000.
	CHECK(count == 30);
}
|
|
|
|
TEST_CASE("value-category of fmap() matches reference type")
{
	using cppcoro::fmap;

	// Each checker is the function handed to fmap(). Its static_asserts pin the exact value
	// category and constness with which fmap() forwards the source generator's elements;
	// the runtime CHECK confirms the yielded value. Each returns a copy of the element.
	auto expectRvalue = [](auto&& value) {
		using Ref = decltype(value);
		static_assert(std::is_rvalue_reference<Ref>::value);
		static_assert(!std::is_const<std::remove_reference_t<Ref>>::value);
		CHECK(value == 123);
		return value;
	};
	auto expectLvalue = [](auto&& value) {
		using Ref = decltype(value);
		static_assert(std::is_lvalue_reference<Ref>::value);
		static_assert(!std::is_const<std::remove_reference_t<Ref>>::value);
		CHECK(value == 123);
		return value;
	};
	auto expectConstLvalue = [](auto&& value) {
		using Ref = decltype(value);
		static_assert(std::is_lvalue_reference<Ref>::value);
		static_assert(std::is_const<std::remove_reference_t<Ref>>::value);
		CHECK(value == 123);
		return value;
	};
	auto expectConstRvalue = [](auto&& value) {
		using Ref = decltype(value);
		static_assert(std::is_rvalue_reference<Ref>::value);
		static_assert(std::is_const<std::remove_reference_t<Ref>>::value);
		CHECK(value == 123);
		return value;
	};

	// Pulls every element through the pipeline so that the checker actually runs.
	auto drain = [](auto&& pipeline) {
		for (auto&& element : pipeline) {
			static_cast<void>(element);
		}
	};

	// Value and lvalue-reference element types arrive as lvalues (const-ness preserved);
	// rvalue-reference element types arrive as rvalues.
	drain([]() -> generator<int> { co_yield 123; }() | fmap(expectLvalue));
	drain([]() -> generator<const int> { co_yield 123; }() | fmap(expectConstLvalue));
	drain([]() -> generator<int&> { co_yield 123; }() | fmap(expectLvalue));
	drain([]() -> generator<const int&> { co_yield 123; }() | fmap(expectConstLvalue));
	drain([]() -> generator<int&&> { co_yield 123; }() | fmap(expectRvalue));
	drain([]() -> generator<const int&&> { co_yield 123; }() | fmap(expectConstRvalue));
}
|
|
|
|
TEST_CASE("generator doesn't start until its called")
{
	bool startedBody = false;
	bool resumedAfterFirst = false;
	bool ranToEnd = false;

	auto lazy = [&]() -> generator<int>
	{
		startedBody = true;
		co_yield 1;
		resumedAfterFirst = true;
		co_yield 2;
		ranToEnd = true;
	};

	// Creating the generator must not execute any of its body.
	auto gen = lazy();
	CHECK(!startedBody);

	// begin() runs the body up to the first co_yield and no further.
	auto iter = gen.begin();
	CHECK(startedBody);
	CHECK(!resumedAfterFirst);
	CHECK(*iter == 1);

	// Each increment resumes the body up to the next co_yield...
	++iter;
	CHECK(resumedAfterFirst);
	CHECK(!ranToEnd);
	CHECK(*iter == 2);

	// ...or, after the last one, to the end of the body.
	++iter;
	CHECK(ranToEnd);
	CHECK(iter == gen.end());
}
|
|
|
|
TEST_CASE("destroying generator before completion destructs objects on stack")
{
	bool destructed = false;
	bool completed = false;

	// The body holds a scope guard as a local. Its callback should run when the
	// coroutine's locals are destroyed, even if the body never reaches its end.
	auto f = [&]() -> generator<int>
	{
		auto onExit = cppcoro::on_scope_exit([&]
		{
			destructed = true;
		});

		co_yield 1;
		co_yield 2;
		completed = true;
	};

	{
		auto g = f();
		auto it = g.begin();
		auto itEnd = g.end();
		CHECK(it != itEnd);
		// Compare as int (generator<int>'s element type) to avoid a signed/unsigned comparison.
		CHECK(*it == 1);
		CHECK(!destructed);
	}

	// `g` went out of scope while suspended at the first co_yield: the body never
	// finished, but its locals (including the scope guard) must have been destroyed.
	CHECK(!completed);
	CHECK(destructed);
}
|
|
|
|
TEST_CASE("generator throwing before yielding first element rethrows out of begin()")
{
	class X {};

	// The body throws before reaching any co_yield. Generators start lazily, so the
	// exception is only observed when begin() first resumes the coroutine.
	auto g = []() -> cppcoro::generator<int>
	{
		throw X{};
		co_return; // makes the lambda a coroutine
	}();

	// Fails (and stops the test) if begin() does not throw, or throws something other than X.
	REQUIRE_THROWS_AS(g.begin(), const X&);
}
|
|
|
|
TEST_CASE("generator throwing after first element rethrows out of operator++")
{
	class X {};

	// Yields one element, then throws when resumed for the second.
	auto g = []() -> cppcoro::generator<int>
	{
		co_yield 1;
		throw X{};
	}();

	auto iter = g.begin();
	REQUIRE(iter != g.end());

	// The increment resumes the body, so the exception must surface from operator++.
	// Fails (and stops the test) if nothing is thrown, or something other than X is thrown.
	REQUIRE_THROWS_AS(++iter, const X&);
}
|
|
|
|
namespace
|
|
{
|
|
template<typename FIRST, typename SECOND>
|
|
auto concat(FIRST&& first, SECOND&& second)
|
|
{
|
|
using value_type = std::remove_reference_t<decltype(*first.begin())>;
|
|
return [](FIRST first, SECOND second) -> cppcoro::generator<value_type>
|
|
{
|
|
for (auto&& x : first) co_yield x;
|
|
for (auto&& y : second) co_yield y;
|
|
}(std::forward<FIRST>(first), std::forward<SECOND>(second));
|
|
}
|
|
}
|
|
|
|
TEST_CASE("safe capture of r-value reference args")
{
	using namespace std::string_literals;

	// The l-value `byRef` should be captured by reference, so the later reassignment is
	// observed. The temporaries ("foo"s and the vector) should be moved into the coroutine
	// frame, so they outlive this full-expression.
	std::string byRef = "bar";
	auto joined = concat("foo"s, concat(byRef, std::vector<char>{ 'b', 'a', 'z' }));

	byRef = "buzz";

	std::string collected;
	for (const char ch : joined)
	{
		collected.push_back(ch);
	}

	CHECK(collected == "foobuzzbaz");
}
|
|
|
|
namespace
|
|
{
|
|
cppcoro::generator<int> range(int start, int end)
|
|
{
|
|
for (; start < end; ++start)
|
|
{
|
|
co_yield start;
|
|
}
|
|
}
|
|
}
|
|
|
|
TEST_CASE("fmap operator")
{
	// range(0, 5) yields 0..4; piping it through fmap triples each element lazily.
	cppcoro::generator<int> gen = range(0, 5)
		| cppcoro::fmap([](int x) { return x * 3; });

	const int expected[] = { 0, 3, 6, 9, 12 };

	auto it = gen.begin();
	for (const int want : expected)
	{
		CHECK(*it == want);
		++it;
	}
	CHECK(it == gen.end());
}
|
|
|
|
namespace
|
|
{
|
|
template<std::size_t window, typename Range>
|
|
cppcoro::generator<const double> low_pass(Range rng)
|
|
{
|
|
auto it = std::begin(rng);
|
|
const auto itEnd = std::end(rng);
|
|
|
|
const double invCount = 1.0 / window;
|
|
double sum = 0;
|
|
|
|
using iter_cat =
|
|
typename std::iterator_traits<decltype(it)>::iterator_category;
|
|
|
|
if constexpr (std::is_base_of_v<std::random_access_iterator_tag, iter_cat>)
|
|
{
|
|
for (std::size_t count = 0; it != itEnd && count < window; ++it)
|
|
{
|
|
sum += *it;
|
|
++count;
|
|
co_yield sum / count;
|
|
}
|
|
|
|
for (; it != itEnd; ++it)
|
|
{
|
|
sum -= *(it - window);
|
|
sum += *it;
|
|
co_yield sum * invCount;
|
|
}
|
|
}
|
|
else if constexpr (std::is_base_of_v<std::forward_iterator_tag, iter_cat>)
|
|
{
|
|
auto windowStart = it;
|
|
for (std::size_t count = 0; it != itEnd && count < window; ++it)
|
|
{
|
|
sum += *it;
|
|
++count;
|
|
co_yield sum / count;
|
|
}
|
|
|
|
for (; it != itEnd; ++it, ++windowStart)
|
|
{
|
|
sum -= *windowStart;
|
|
sum += *it;
|
|
co_yield sum * invCount;
|
|
}
|
|
}
|
|
else
|
|
{
|
|
// Just assume an input iterator
|
|
double buffer[window];
|
|
|
|
for (std::size_t count = 0; it != itEnd && count < window; ++it)
|
|
{
|
|
buffer[count] = *it;
|
|
sum += buffer[count];
|
|
++count;
|
|
co_yield sum / count;
|
|
}
|
|
|
|
for (std::size_t pos = 0; it != itEnd; ++it, pos = (pos + 1 == window) ? 0 : (pos + 1))
|
|
{
|
|
sum -= std::exchange(buffer[pos], *it);
|
|
sum += buffer[pos];
|
|
co_yield sum * invCount;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// HACK: Disable this test as it's causing heap corruption errors under MSVC 2017 Update 5 x86 debug builds.
|
|
// Still needs investigation of root cause.
|
|
TEST_CASE("low_pass" * doctest::skip{ true })
|
|
{
|
|
// With random-access iterator
|
|
{
|
|
auto gen = low_pass<4>(std::vector<int>{ 10, 13, 10, 15, 18, 9, 11, 15 });
|
|
auto it = gen.begin();
|
|
CHECK(*it == 10.0);
|
|
CHECK(*++it == 11.5);
|
|
CHECK(*++it == 11.0);
|
|
CHECK(*++it == 12.0);
|
|
CHECK(*++it == 14.0);
|
|
CHECK(*++it == 13.0);
|
|
CHECK(*++it == 13.25);
|
|
CHECK(*++it == 13.25);
|
|
CHECK(++it == gen.end());
|
|
}
|
|
|
|
// With forward-iterator
|
|
{
|
|
auto gen = low_pass<4>(std::forward_list<int>{ 10, 13, 10, 15, 18, 9, 11, 15 });
|
|
auto it = gen.begin();
|
|
CHECK(*it == 10.0);
|
|
CHECK(*++it == 11.5);
|
|
CHECK(*++it == 11.0);
|
|
CHECK(*++it == 12.0);
|
|
CHECK(*++it == 14.0);
|
|
CHECK(*++it == 13.0);
|
|
CHECK(*++it == 13.25);
|
|
CHECK(*++it == 13.25);
|
|
CHECK(++it == gen.end());
|
|
}
|
|
|
|
// With input-iterator
|
|
{
|
|
auto gen = low_pass<3>(range(10, 20));
|
|
auto it = gen.begin();
|
|
CHECK(*it == 10.0);
|
|
CHECK(*++it == 10.5);
|
|
CHECK(*++it == 11.0);
|
|
CHECK(*++it == 12.0);
|
|
CHECK(*++it == 13.0);
|
|
CHECK(*++it == 14.0);
|
|
CHECK(*++it == 15.0);
|
|
CHECK(*++it == 16.0);
|
|
CHECK(*++it == 17.0);
|
|
CHECK(*++it == 18.0);
|
|
CHECK(++it == gen.end());
|
|
}
|
|
}
|
|
|
|
// Closes the "generator" test suite opened by TEST_SUITE_BEGIN at the top of the file.
TEST_SUITE_END();
|