diff --git a/CMakeLists.txt b/CMakeLists.txt index fee9fcb..484478a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -17,7 +17,7 @@ cmake_minimum_required(VERSION 3.7) set(CMAKE_OSX_DEPLOYMENT_TARGET 10.12 CACHE STRING "macOS deployment target (Apple clang only)") project(liboxenmq - VERSION 1.2.6 + VERSION 1.2.8 LANGUAGES CXX C) include(GNUInstallDirs) @@ -113,13 +113,15 @@ target_include_directories(oxenmq $ ) -target_compile_options(oxenmq PRIVATE -Wall -Wextra -Werror) -set_target_properties(oxenmq PROPERTIES - CXX_STANDARD 17 - CXX_STANDARD_REQUIRED ON - CXX_EXTENSIONS OFF - POSITION_INDEPENDENT_CODE ON -) +target_compile_options(oxenmq PRIVATE -Wall -Wextra) + +option(WARNINGS_AS_ERRORS "treat all warnings as errors" ON) +if(WARNINGS_AS_ERRORS) + target_compile_options(oxenmq PRIVATE -Werror) +endif() + +target_compile_features(oxenmq PUBLIC cxx_std_17) +set_target_properties(oxenmq PROPERTIES POSITION_INDEPENDENT_CODE ON) function(link_dep_libs target linktype libdirs) foreach(lib ${ARGN}) diff --git a/SPECS/oxenmq.spec b/SPECS/oxenmq.spec index aa450c3..ce653be 100644 --- a/SPECS/oxenmq.spec +++ b/SPECS/oxenmq.spec @@ -1,6 +1,6 @@ Name: oxenmq -Version: 1.2.6 -Release: 2%{?dist} +Version: 1.2.8 +Release: 1%{?dist} Summary: zeromq-based Oxen message passing library License: BSD @@ -87,11 +87,18 @@ build software using oxenmq. %changelog +* Mon Oct 25 2021 Technical Tumbleweed -1.2.8~1 +- Merge dev changes 1.2.8 +* Fri Oct 15 2021 Technical Tumbleweed -1.2.7~1 +- Merge dev changes 1.2.7 +- change branch Catch2 to 2.x +- bump version + * Mon Aug 09 2021 Jason Rhinelander - 1.2.6-2 - Split oxenmq into lib and devel package - Versioned the library, as we do for debs - Updated various package descriptions and build commands -* Thu Jul 22 2021 Technical Tumbleweed (necro_nemesis@hotmail.com) oxenmq +* Thu Jul 22 2021 Technical Tumbleweed oxenmq -First oxenmq RPM -Second build update to v1.2.6 diff --git a/cmake/local-libzmq/LocalLibzmq.cmake b/cmake/local-libzmq/LocalLibzmq.cmake index c4fb722..48bb67a 100644 --- a/cmake/local-libzmq/LocalLibzmq.cmake +++ b/cmake/local-libzmq/LocalLibzmq.cmake @@ -1,7 +1,7 @@ set(LIBZMQ_PREFIX ${CMAKE_BINARY_DIR}/libzmq) -set(ZeroMQ_VERSION 4.3.3) +set(ZeroMQ_VERSION 4.3.4) set(LIBZMQ_URL https://github.com/zeromq/libzmq/releases/download/v${ZeroMQ_VERSION}/zeromq-${ZeroMQ_VERSION}.tar.gz) -set(LIBZMQ_HASH SHA512=4c18d784085179c5b1fcb753a93813095a12c8d34970f2e1bfca6499be6c9d67769c71c68b7ca54ff181b20390043170e89733c22f76ff1ea46494814f7095b1) +set(LIBZMQ_HASH SHA512=e198ef9f82d392754caadd547537666d4fba0afd7d027749b3adae450516bcf284d241d4616cad3cb4ad9af8c10373d456de92dc6d115b037941659f141e7c0e) message(${LIBZMQ_URL}) diff --git a/cppzmq b/cppzmq index 76bf169..33ed542 160000 --- a/cppzmq +++ b/cppzmq @@ -1 +1 @@ -Subproject commit 76bf169fd67b8e99c1b0e6490029d9cd5ef97666 +Subproject commit 33ed54228e98a82db88da5ac9aca001147521c20 diff --git a/oxenmq/address.h b/oxenmq/address.h index 96af840..6be27cc 100644 --- a/oxenmq/address.h +++ b/oxenmq/address.h @@ -163,9 +163,8 @@ struct address { /// Returns a QR-code friendly address string. This returns an all-uppercase version of the /// address with "TCP://" or "CURVE://" for the protocol string, and uses upper-case base32z - /// encoding for the pubkey (for curve addresses). For literal IPv6 addresses we replace the - /// surround the - /// address with $ instead of $ + /// encoding for the pubkey (for curve addresses). For literal IPv6 addresses we surround the + /// address with $ instead of [...] /// /// \throws std::logic_error if called on a unix socket address. std::string qr_address() const; diff --git a/oxenmq/auth.cpp b/oxenmq/auth.cpp index fc89bcd..a114b08 100644 --- a/oxenmq/auth.cpp +++ b/oxenmq/auth.cpp @@ -119,7 +119,6 @@ void OxenMQ::proxy_set_active_sns(pubkey_set pubkeys) { } void OxenMQ::update_active_sns(pubkey_set added, pubkey_set removed) { - LMQ_LOG(info, "uh, ", added.size()); if (proxy_thread.joinable()) { std::array data; data[0] = detail::serialize_object(std::move(added)); @@ -139,7 +138,6 @@ void OxenMQ::proxy_update_active_sns(pubkey_set added, pubkey_set removed) { // values, pubkeys that already(added) or do not(removed) exist), then pass the purified lists // to the _clean version. - LMQ_LOG(info, "uh, ", added.size(), ", ", removed.size()); for (auto it = removed.begin(); it != removed.end(); ) { const auto& pk = *it; if (pk.size() != 32) { diff --git a/oxenmq/base32z.h b/oxenmq/base32z.h index 074e522..78e6973 100644 --- a/oxenmq/base32z.h +++ b/oxenmq/base32z.h @@ -74,40 +74,81 @@ static_assert(b32z_lut.from_b32z('w') == 20 && b32z_lut.from_b32z('T') == 17 && } // namespace detail -/// Converts bytes into a base32z encoded character sequence. -template -void to_base32z(InputIt begin, InputIt end, OutputIt out) { - static_assert(sizeof(decltype(*begin)) == 1, "to_base32z requires chars/bytes"); - int bits = 0; // Tracks the number of unconsumed bits held in r, will always be in [0, 4] - std::uint_fast16_t r = 0; - while (begin != end) { - r = r << 8 | static_cast(*begin++); +/// Returns the number of characters required to encode a base32z string from the given number of bytes. +inline constexpr size_t to_base32z_size(size_t byte_size) { return (byte_size*8 + 4) / 5; } // ⌈bits/5⌉ because 5 bits per byte +/// Returns the (maximum) number of bytes required to decode a base32z string of the given size. +inline constexpr size_t from_base32z_size(size_t b32z_size) { return b32z_size*5 / 8; } // ⌊bits/8⌋ - // we just added 8 bits, so we can *always* consume 5 to produce one character, so (net) we - // are adding 3 bits. - bits += 3; - *out++ = detail::b32z_lut.to_b32z(r >> bits); // Right-shift off the bits we aren't consuming right now +/// Iterable object for on-the-fly base32z encoding. Used internally, but also particularly useful +/// when converting from one encoding to another. +template +struct base32z_encoder final { +private: + InputIt _it, _end; + static_assert(sizeof(decltype(*_it)) == 1, "base32z_encoder requires chars/bytes input iterator"); + // Number of bits held in r; will always be >= 5 until we are at the end. + int bits{_it != _end ? 8 : 0}; + // Holds bits of data we've already read, which might belong to current or next chars + uint_fast16_t r{bits ? static_cast(*_it) : (unsigned char)0}; +public: + using iterator_category = std::input_iterator_tag; + using difference_type = std::ptrdiff_t; + using value_type = char; + using reference = value_type; + using pointer = void; + base32z_encoder(InputIt begin, InputIt end) : _it{std::move(begin)}, _end{std::move(end)} {} - // Drop the bits we don't want to keep (because we just consumed them) + base32z_encoder end() { return {_end, _end}; } + + bool operator==(const base32z_encoder& i) { return _it == i._it && bits == i.bits; } + bool operator!=(const base32z_encoder& i) { return !(*this == i); } + + base32z_encoder& operator++() { + assert(bits >= 5); + // Discard the most significant 5 bits + bits -= 5; r &= (1 << bits) - 1; - - if (bits >= 5) { // We have enough bits to produce a second character; essentially the same as above - bits -= 5; // Except now we are just consuming 5 without having added any more - *out++ = detail::b32z_lut.to_b32z(r >> bits); - r &= (1 << bits) - 1; + // If we end up with less than 5 significant bits then try to pull another 8 bits: + if (bits < 5 && _it != _end) { + if (++_it != _end) { + r = (r << 8) | static_cast(*_it); + bits += 8; + } else if (bits > 0) { + // No more input bytes, so shift `r` to put the bits we have into the most + // significant bit position for the final character. E.g. if we have "11" we want + // the last character to be encoded "11000". + r <<= (5 - bits); + bits = 5; + } } + return *this; } + base32z_encoder operator++(int) { base32z_encoder copy{*this}; ++*this; return copy; } - if (bits > 0) // We hit the end, but still have some unconsumed bits so need one final character to append - *out++ = detail::b32z_lut.to_b32z(r << (5 - bits)); + char operator*() { + // Right-shift off the excess bits we aren't accessing yet + return detail::b32z_lut.to_b32z(r >> (bits - 5)); + } +}; + +/// Converts bytes into a base32z encoded character sequence, writing them starting at `out`. +/// Returns the final value of out (i.e. the iterator positioned just after the last written base32z +/// character). +template +OutputIt to_base32z(InputIt begin, InputIt end, OutputIt out) { + static_assert(sizeof(decltype(*begin)) == 1, "to_base32z requires chars/bytes"); + base32z_encoder it{begin, end}; + return std::copy(it, it.end(), out); } /// Creates a base32z string from an iterator pair of a byte sequence. template std::string to_base32z(It begin, It end) { std::string base32z; - if constexpr (std::is_base_of_v::iterator_category>) - base32z.reserve((std::distance(begin, end)*8 + 4) / 5); // == bytes*8/5, rounded up. + if constexpr (std::is_base_of_v::iterator_category>) { + using std::distance; + base32z.reserve(to_base32z_size(distance(begin, end))); + } to_base32z(begin, end, std::back_inserter(base32z)); return base32z; } @@ -117,15 +158,36 @@ template std::string to_base32z(std::basic_string_view s) { return to_base32z(s.begin(), s.end()); } inline std::string to_base32z(std::string_view s) { return to_base32z<>(s); } -/// Returns true if all elements in the range are base32z characters +/// Returns true if the given [begin, end) range is an acceptable base32z string: specifically every +/// character must be in the base32z alphabet, and the string must be a valid encoding length that +/// could have been produced by to_base32z (i.e. some lengths are impossible). template constexpr bool is_base32z(It begin, It end) { static_assert(sizeof(decltype(*begin)) == 1, "is_base32z requires chars/bytes"); + size_t count = 0; + constexpr bool random = std::is_base_of_v::iterator_category>; + if constexpr (random) { + using std::distance; + count = distance(begin, end) % 8; + if (count == 1 || count == 3 || count == 6) // see below + return false; + } for (; begin != end; ++begin) { auto c = static_cast(*begin); if (detail::b32z_lut.from_b32z(c) == 0 && !(c == 'y' || c == 'Y')) return false; + if constexpr (!random) + count++; } + // Check for a valid length. + // - 5n + 0 bytes encodes to 8n chars (no padding bits) + // - 5n + 1 bytes encodes to 8n+2 chars (last 2 bits are padding) + // - 5n + 2 bytes encodes to 8n+4 chars (last 4 bits are padding) + // - 5n + 3 bytes encodes to 8n+5 chars (last 1 bit is padding) + // - 5n + 4 bytes encodes to 8n+7 chars (last 3 bits are padding) + if constexpr (!random) + if (count %= 8; count == 1 || count == 3 || count == 6) + return false; return true; } @@ -134,55 +196,89 @@ template constexpr bool is_base32z(std::basic_string_view s) { return is_base32z(s.begin(), s.end()); } constexpr bool is_base32z(std::string_view s) { return is_base32z<>(s); } -/// Converts a sequence of base32z digits to bytes. Undefined behaviour if any characters are not -/// valid base32z alphabet characters. It is permitted for the input and output ranges to overlap -/// as long as `out` is no earlier than `begin`. Note that if you pass in a sequence that could not -/// have been created by a base32z encoding of a byte sequence, we treat the excess bits as if they -/// were not provided. +/// Iterable object for on-the-fly base32z decoding. Used internally, but also particularly useful +/// when converting from one encoding to another. The input range must be a valid base32z +/// encoded string. /// -/// For example, "yyy" represents a 15-bit value, but a byte sequence is either 8-bit (requiring 2 -/// characters) or 16-bit (requiring 4). Similarly, "yb" is an impossible encoding because it has -/// its 10th bit set (b = 00001), but a base32z encoded value should have all 0's beyond the 8th (or -/// 16th or 24th or ... bit). We treat any such bits as if they were not specified (even if they -/// are): which means "yy", "yb", "yyy", "yy9", "yd", etc. all decode to the same 1-byte value "\0". -template -void from_base32z(InputIt begin, InputIt end, OutputIt out) { - static_assert(sizeof(decltype(*begin)) == 1, "from_base32z requires chars/bytes"); - uint_fast16_t curr = 0; - int bits = 0; // number of bits we've loaded into val; we always keep this < 8. - while (begin != end) { - curr = curr << 5 | detail::b32z_lut.from_b32z(static_cast(*begin++)); - if (bits >= 3) { - bits -= 3; // Added 5, removing 8 - *out++ = static_cast>( - static_cast(curr >> bits)); - curr &= (1 << bits) - 1; - } else { - bits += 5; - } +/// Note that we ignore "padding" bits without requiring that they actually be 0. For instance, the +/// bytes "\ff\ff" are ideally encoded as "999o" (16 bits of 1s + 4 padding 0 bits), but we don't +/// require that the padding bits be 0. That is, "9999", "9993", etc. will all decode to the same +/// \ff\ff output string. +template +struct base32z_decoder final { +private: + InputIt _it, _end; + static_assert(sizeof(decltype(*_it)) == 1, "base32z_decoder requires chars/bytes input iterator"); + uint_fast16_t in = 0; + int bits = 0; // number of bits loaded into `in`; will be in [8, 12] until we hit the end +public: + using iterator_category = std::input_iterator_tag; + using difference_type = std::ptrdiff_t; + using value_type = char; + using reference = value_type; + using pointer = void; + base32z_decoder(InputIt begin, InputIt end) : _it{std::move(begin)}, _end{std::move(end)} { + if (_it != _end) + load_byte(); } - // Ignore any trailing bits. base32z encoding always has at least as many bits as the source - // bytes, which means we should not be able to get here from a properly encoded b32z value with - // anything other than 0s: if we have no extra bits (e.g. 5 bytes == 8 b32z chars) then we have - // a 0-bit value; if we have some extra bits (e.g. 6 bytes requires 10 b32z chars, but that - // contains 50 bits > 48 bits) then those extra bits will be 0s (and this covers the bits -= 3 - // case above: it'll leave us with 0-4 extra bits, but those extra bits would be 0 if produced - // from an actual byte sequence). - // - // The "bits += 5" case, then, means that we could end with 5-7 bits. This, however, cannot be - // produced by a valid encoding: - // - 0 bytes gives us 0 chars with 0 leftover bits - // - 1 byte gives us 2 chars with 2 leftover bits - // - 2 bytes gives us 4 chars with 4 leftover bits - // - 3 bytes gives us 5 chars with 1 leftover bit - // - 4 bytes gives us 7 chars with 3 leftover bits - // - 5 bytes gives us 8 chars with 0 leftover bits (this is where the cycle repeats) - // - // So really the only way we can get 5-7 leftover bits is if you took a 0, 2 or 5 char output (or - // any 8n + {0,2,5} char output) and added a base32z character to the end. If you do that, - // well, too bad: you're giving invalid output and so we're just going to pretend that extra - // character you added isn't there by not doing anything here. + base32z_decoder end() { return {_end, _end}; } + + bool operator==(const base32z_decoder& i) { return _it == i._it; } + bool operator!=(const base32z_decoder& i) { return _it != i._it; } + + base32z_decoder& operator++() { + // Discard 8 most significant bits + bits -= 8; + in &= (1 << bits) - 1; + if (++_it != _end) + load_byte(); + return *this; + } + base32z_decoder operator++(int) { base32z_decoder copy{*this}; ++*this; return copy; } + + char operator*() { + return in >> (bits - 8); + } + +private: + void load_in() { + in = in << 5 + | detail::b32z_lut.from_b32z(static_cast(*_it)); + bits += 5; + } + + void load_byte() { + load_in(); + if (bits < 8 && ++_it != _end) + load_in(); + + // If we hit the _end iterator above then we hit the end of the input with fewer than 8 bits + // accumulated to make a full byte. For a properly encoded base32z string this should only + // be possible with 0-4 bits of all 0s; these are essentially "padding" bits (e.g. encoding + // 2 byte (16 bits) requires 4 b32z chars (20 bits), where only the first 16 bits are + // significant). Ideally any padding bits should be 0, but we don't check that and rather + // just ignore them. + // + // It also isn't possible to get here with 5-7 bits if the string passes `is_base32z` + // because the length checks we do there disallow such a length as valid. (If you were to + // pass such a string to us anyway then we are technically UB, but the current + // implementation just ignore the extra bits as if they are extra padding). + } +}; + +/// Converts a sequence of base32z digits to bytes. Undefined behaviour if any characters are not +/// valid base32z alphabet characters. It is permitted for the input and output ranges to overlap +/// as long as `out` is no later than `begin`. +/// +template +OutputIt from_base32z(InputIt begin, InputIt end, OutputIt out) { + static_assert(sizeof(decltype(*begin)) == 1, "from_base32z requires chars/bytes"); + base32z_decoder it{begin, end}; + auto bend = it.end(); + while (it != bend) + *out++ = static_cast>(*it++); + return out; } /// Convert a base32z sequence into a std::string of bytes. Undefined behaviour if any characters @@ -190,8 +286,10 @@ void from_base32z(InputIt begin, InputIt end, OutputIt out) { template std::string from_base32z(It begin, It end) { std::string bytes; - if constexpr (std::is_base_of_v::iterator_category>) - bytes.reserve((std::distance(begin, end)*5 + 7) / 8); // == chars*5/8, rounded up. + if constexpr (std::is_base_of_v::iterator_category>) { + using std::distance; + bytes.reserve(from_base32z_size(distance(begin, end))); + } from_base32z(begin, end, std::back_inserter(bytes)); return bytes; } diff --git a/oxenmq/base64.h b/oxenmq/base64.h index de703ae..2bfac55 100644 --- a/oxenmq/base64.h +++ b/oxenmq/base64.h @@ -76,70 +76,153 @@ static_assert(b64_lut.from_b64('/') == 63 && b64_lut.from_b64('7') == 59 && b64_ } // namespace detail -/// Converts bytes into a base64 encoded character sequence. -template -void to_base64(InputIt begin, InputIt end, OutputIt out) { - static_assert(sizeof(decltype(*begin)) == 1, "to_base64 requires chars/bytes"); - int bits = 0; // Tracks the number of unconsumed bits held in r, will always be in {0, 2, 4} - std::uint_fast16_t r = 0; - while (begin != end) { - r = r << 8 | static_cast(*begin++); - - // we just added 8 bits, so we can *always* consume 6 to produce one character, so (net) we - // are adding 2 bits. - bits += 2; - *out++ = detail::b64_lut.to_b64(r >> bits); // Right-shift off the bits we aren't consuming right now - - // Drop the bits we don't want to keep (because we just consumed them) - r &= (1 << bits) - 1; - - if (bits == 6) { // We have enough bits to produce a second character (which means we had 4 before and added 8) - bits = 0; - *out++ = detail::b64_lut.to_b64(r); - r = 0; - } - } - - // If bits == 0 then we ended our 6-bit outputs coinciding with 8-bit values, i.e. at a multiple - // of 24 bits: this means we don't have anything else to output and don't need any padding. - if (bits == 2) { - // We finished with 2 unconsumed bits, which means we ended 1 byte past a 24-bit group (e.g. - // 1 byte, 4 bytes, 301 bytes, etc.); since we need to always be a multiple of 4 output - // characters that means we've produced 1: so we right-fill 0s to get the next char, then - // add two padding ='s. - *out++ = detail::b64_lut.to_b64(r << 4); - *out++ = '='; - *out++ = '='; - } else if (bits == 4) { - // 4 bits left means we produced 2 6-bit values from the first 2 bytes of a 3-byte group. - // Fill 0s to get the last one, plus one padding output. - *out++ = detail::b64_lut.to_b64(r << 2); - *out++ = '='; - } +/// Returns the number of characters required to encode a base64 string from the given number of bytes. +inline constexpr size_t to_base64_size(size_t byte_size, bool padded = true) { + return padded + ? (byte_size + 2) / 3 * 4 // bytes*4/3, rounded up to the next multiple of 4 + : (byte_size * 4 + 2) / 3; // ⌈bytes*4/3⌉ +} +/// Returns the (maximum) number of bytes required to decode a base64 string of the given size. +/// Note that this may overallocate by 1-2 bytes if the size includes 1-2 padding chars. +inline constexpr size_t from_base64_size(size_t b64_size) { + return b64_size * 3 / 4; // == ⌊bits/8⌋; floor because we ignore trailing "impossible" bits (see below) } -/// Creates and returns a base64 string from an iterator pair of a character sequence +/// Iterable object for on-the-fly base64 encoding. Used internally, but also particularly useful +/// when converting from one encoding to another. +template +struct base64_encoder final { +private: + InputIt _it, _end; + static_assert(sizeof(decltype(*_it)) == 1, "base64_encoder requires chars/bytes input iterator"); + // How much padding (at most) we can add at the end + int padding; + // Number of bits held in r; will always be >= 6 until we are at the end. + int bits{_it != _end ? 8 : 0}; + // Holds bits of data we've already read, which might belong to current or next chars + uint_fast16_t r{bits ? static_cast(*_it) : (unsigned char)0}; +public: + using iterator_category = std::input_iterator_tag; + using difference_type = std::ptrdiff_t; + using value_type = char; + using reference = value_type; + using pointer = void; + base64_encoder(InputIt begin, InputIt end, bool padded = true) + : _it{std::move(begin)}, _end{std::move(end)}, padding{padded} {} + + base64_encoder end() { return {_end, _end, false}; } + + bool operator==(const base64_encoder& i) { return _it == i._it && bits == i.bits && padding == i.padding; } + bool operator!=(const base64_encoder& i) { return !(*this == i); } + + base64_encoder& operator++() { + if (bits == 0) { + padding--; + return *this; + } + assert(bits >= 6); + // Discard the most significant 6 bits + bits -= 6; + r &= (1 << bits) - 1; + // If we end up with less than 6 significant bits then try to pull another 8 bits: + if (bits < 6 && _it != _end) { + if (++_it != _end) { + r = (r << 8) | static_cast(*_it); + bits += 8; + } else if (bits > 0) { + // No more input bytes, so shift `r` to put the bits we have into the most + // significant bit position for the final character, and figure out how many padding + // bytes we want to append. E.g. if we have "11" we want + // the last character to be encoded "110000". + if (padding) { + // padding should be: + // 3n+0 input => 4n output, no padding, handled below + // 3n+1 input => 4n+2 output + 2 padding; we'll land here with 2 trailing bits + // 3n+2 input => 4n+3 output + 1 padding; we'll land here with 4 trailing bits + padding = 3 - bits / 2; + } + r <<= (6 - bits); + bits = 6; + } else { + padding = 0; // No excess bits, so input was a multiple of 3 and thus no padding + } + } + return *this; + } + base64_encoder operator++(int) { base64_encoder copy{*this}; ++*this; return copy; } + + char operator*() { + if (bits == 0 && padding) + return '='; + // Right-shift off the excess bits we aren't accessing yet + return detail::b64_lut.to_b64(r >> (bits - 6)); + } +}; + +/// Converts bytes into a base64 encoded character sequence, writing them starting at `out`. +/// Returns the final value of out (i.e. the iterator positioned just after the last written base64 +/// character). +template +OutputIt to_base64(InputIt begin, InputIt end, OutputIt out, bool padded = true) { + static_assert(sizeof(decltype(*begin)) == 1, "to_base64 requires chars/bytes"); + auto it = base64_encoder{begin, end, padded}; + return std::copy(it, it.end(), out); +} + +/// Creates and returns a base64 string from an iterator pair of a character sequence. The +/// resulting string will have '=' padding, if appropriate. template std::string to_base64(It begin, It end) { std::string base64; - if constexpr (std::is_base_of_v::iterator_category>) - base64.reserve((std::distance(begin, end) + 2) / 3 * 4); // bytes*4/3, rounded up to the next multiple of 4 + if constexpr (std::is_base_of_v::iterator_category>) { + using std::distance; + base64.reserve(to_base64_size(distance(begin, end))); + } to_base64(begin, end, std::back_inserter(base64)); return base64; } -/// Creates a base64 string from an iterable, std::string-like object +/// Creates and returns a base64 string from an iterator pair of a character sequence. The +/// resulting string will not be padded. +template +std::string to_base64_unpadded(It begin, It end) { + std::string base64; + if constexpr (std::is_base_of_v::iterator_category>) { + using std::distance; + base64.reserve(to_base64_size(distance(begin, end), false)); + } + to_base64(begin, end, std::back_inserter(base64), false); + return base64; +} + +/// Creates a base64 string from an iterable, std::string-like object. The string will have '=' +/// padding, if appropriate. template std::string to_base64(std::basic_string_view s) { return to_base64(s.begin(), s.end()); } inline std::string to_base64(std::string_view s) { return to_base64<>(s); } +/// Creates a base64 string from an iterable, std::string-like object. The string will not be +/// padded. +template +std::string to_base64_unpadded(std::basic_string_view s) { return to_base64_unpadded(s.begin(), s.end()); } +inline std::string to_base64_unpadded(std::string_view s) { return to_base64_unpadded<>(s); } + /// Returns true if the range is a base64 encoded value; we allow (but do not require) '=' padding, /// but only at the end, only 1 or 2, and only if it pads out the total to a multiple of 4. +/// Otherwise the string must contain only valid base64 characters, and must not have a length of +/// 4n+1 (because that cannot be produced by base64 encoding). template constexpr bool is_base64(It begin, It end) { static_assert(sizeof(decltype(*begin)) == 1, "is_base64 requires chars/bytes"); using std::distance; using std::prev; + size_t count = 0; + constexpr bool random = std::is_base_of_v::iterator_category>; + if constexpr (random) { + count = distance(begin, end) % 4; + if (count == 1) + return false; + } // Allow 1 or 2 padding chars *if* they pad it to a multiple of 4. if (begin != end && distance(begin, end) % 4 == 0) { @@ -154,7 +237,14 @@ constexpr bool is_base64(It begin, It end) { auto c = static_cast(*begin); if (detail::b64_lut.from_b64(c) == 0 && c != 'A') return false; + if constexpr (!random) + count++; } + + if constexpr (!random) + if (count % 4 == 1) // base64 encoding will produce 4n, 4n+2, 4n+3, but never 4n+1 + return false; + return true; } @@ -163,10 +253,87 @@ template constexpr bool is_base64(std::basic_string_view s) { return is_base64(s.begin(), s.end()); } constexpr bool is_base64(std::string_view s) { return is_base64(s.begin(), s.end()); } +/// Iterable object for on-the-fly base64 decoding. Used internally, but also particularly useful +/// when converting from one encoding to another. The input range must be a valid base64 encoded +/// string (with or without padding). +/// +/// Note that we ignore "padding" bits without requiring that they actually be 0. For instance, the +/// bytes "\ff\ff" are ideally encoded as "//8=" (16 bits of 1s + 2 padding 0 bits, then a full +/// 6-bit padding char). We don't, however, require that the padding bits be 0. That is, "///=", +/// "//9=", "//+=", etc. will all decode to the same \ff\ff output string. +template +struct base64_decoder final { +private: + InputIt _it, _end; + static_assert(sizeof(decltype(*_it)) == 1, "base64_decoder requires chars/bytes input iterator"); + uint_fast16_t in = 0; + int bits = 0; // number of bits loaded into `in`; will be in [8, 12] until we hit the end +public: + using iterator_category = std::input_iterator_tag; + using difference_type = std::ptrdiff_t; + using value_type = char; + using reference = value_type; + using pointer = void; + base64_decoder(InputIt begin, InputIt end) : _it{std::move(begin)}, _end{std::move(end)} { + if (_it != _end) + load_byte(); + } + + base64_decoder end() { return {_end, _end}; } + + bool operator==(const base64_decoder& i) { return _it == i._it; } + bool operator!=(const base64_decoder& i) { return _it != i._it; } + + base64_decoder& operator++() { + // Discard 8 most significant bits + bits -= 8; + in &= (1 << bits) - 1; + if (++_it != _end) + load_byte(); + return *this; + } + base64_decoder operator++(int) { base64_decoder copy{*this}; ++*this; return copy; } + + char operator*() { + return in >> (bits - 8); + } + +private: + void load_in() { + // We hit padding trying to read enough for a full byte, so we're done. (And since you were + // already supposed to have checked validity with is_base64, the padding can only be at the + // end). + auto c = static_cast(*_it); + if (c == '=') { + _it = _end; + bits = 0; + return; + } + + in = in << 6 + | detail::b64_lut.from_b64(c); + bits += 6; + } + + void load_byte() { + load_in(); + if (bits && bits < 8 && ++_it != _end) + load_in(); + + // If we hit the _end iterator above then we hit the end of the input (or hit padding) with + // fewer than 8 bits accumulated to make a full byte. For a properly encoded base64 string + // this should only be possible with 0, 2, or 4 bits of all 0s; these are essentially + // "padding" bits (e.g. encoding 2 byte (16 bits) requires 3 b64 chars (18 bits), where + // only the first 16 bits are significant). Ideally any padding bits should be 0, but we + // don't check that and rather just ignore them. + } +}; + /// Converts a sequence of base64 digits to bytes. Undefined behaviour if any characters are not /// valid base64 alphabet characters. It is permitted for the input and output ranges to overlap as -/// long as `out` is no earlier than `begin`. Trailing padding characters are permitted but not -/// required. +/// long as `out` is no later than `begin`. Trailing padding characters are permitted but not +/// required. Returns the final value of out (that is, the iterator positioned just after the +/// last written character). /// /// It is possible to provide "impossible" base64 encoded values; for example "YWJja" which has 30 /// bits of data even though a base64 encoded byte string should have 24 (4 chars) or 36 (6 chars) @@ -175,30 +342,13 @@ constexpr bool is_base64(std::string_view s) { return is_base64(s.begin(), s.end /// encoding of "abcd") and "YWJjZB", "YWJjZC", ..., "YWJjZP" all decode to the same "abcd" value: /// the last 4 bits of the last character are essentially considered padding. template -void from_base64(InputIt begin, InputIt end, OutputIt out) { +OutputIt from_base64(InputIt begin, InputIt end, OutputIt out) { static_assert(sizeof(decltype(*begin)) == 1, "from_base64 requires chars/bytes"); - uint_fast16_t curr = 0; - int bits = 0; // number of bits we've loaded into val; we always keep this < 8. - while (begin != end) { - auto c = static_cast(*begin++); - - // padding; don't bother checking if we're at the end because is_base64 is a precondition - // and we're allowed UB if it isn't satisfied. - if (c == '=') continue; - - curr = curr << 6 | detail::b64_lut.from_b64(c); - if (bits == 0) - bits = 6; - else { - bits -= 2; // Added 6, removing 8 - *out++ = static_cast>( - static_cast(curr >> bits)); - curr &= (1 << bits) - 1; - } - } - // Don't worry about leftover bits because either they have to be 0, or they can't happen at - // all. See base32z.h for why: the reasoning is exactly the same (except using 6 bits per - // character here instead of 5). + base64_decoder it{begin, end}; + auto bend = it.end(); + while (it != bend) + *out++ = static_cast>(*it++); + return out; } /// Converts base64 digits from a iterator pair of characters into a std::string of bytes. @@ -206,8 +356,10 @@ void from_base64(InputIt begin, InputIt end, OutputIt out) { template std::string from_base64(It begin, It end) { std::string bytes; - if constexpr (std::is_base_of_v::iterator_category>) - bytes.reserve(std::distance(begin, end)*6 / 8); // each digit carries 6 bits; this may overallocate by 1-2 bytes due to padding + if constexpr (std::is_base_of_v::iterator_category>) { + using std::distance; + bytes.reserve(from_base64_size(distance(begin, end))); + } from_base64(begin, end, std::back_inserter(bytes)); return bytes; } diff --git a/oxenmq/bt_producer.h b/oxenmq/bt_producer.h new file mode 100644 index 0000000..cf866c1 --- /dev/null +++ b/oxenmq/bt_producer.h @@ -0,0 +1,306 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace oxenmq { + +using namespace std::literals; + +class bt_dict_producer; + +#if defined(__APPLE__) && defined(__MAC_OS_X_VERSION_MIN_REQUIRED) && __MAC_OS_X_VERSION_MIN_REQUIRED < 101500 +#define OXENMQ_APPLE_TO_CHARS_WORKAROUND +/// Really simplistic version of std::to_chars on Apple, because Apple doesn't allow `std::to_chars` +/// to be used if targetting anything before macOS 10.15. The buffer must have at least 20 chars of +/// space (for int types up to 64-bit); we return a pointer one past the last char written. +template +char* apple_to_chars10(char* buf, IntType val) { + static_assert(std::is_integral_v && sizeof(IntType) <= 64); + if constexpr (std::is_signed_v) { + if (val < 0) { + buf[0] = '-'; + return apple_to_chars10(buf+1, static_cast>(-val)); + } + } + + // write it to the buffer in reverse (because we don't know how many chars we'll need yet, but + // writing in reverse will figure that out). + char* pos = buf; + do { + *pos++ = '0' + static_cast(val % 10); + val /= 10; + } while (val > 0); + + // Reverse the digits into the right order + int swaps = (pos - buf) / 2; + for (int i = 0; i < swaps; i++) + std::swap(buf[i], pos[-1 - i]); + + return pos; +} +#endif + + +/// Class that allows you to build a bt-encoded list manually, without copying or allocating memory. +/// This is essentially the reverse of bt_list_consumer: where it lets you stream-parse a buffer, +/// this class lets you build directly into a buffer that you own. +/// +/// Out-of-buffer-space errors throw +class bt_list_producer { + friend class bt_dict_producer; + + // Our pointers to the next write position and the past-the-end pointer of the buffer. + using buf_span = std::pair; + // Our data is a begin/end pointer pair for the root list, or a pointer to our parent if a + // sublist. + std::variant data; + // Reference to the write buffer; this is simply a reference to the value inside `data` for the + // root element, and a pointer to the root's value for sublists/subdicts. + buf_span& buffer; + // True indicates we have an open child list/dict + bool has_child = false; + // The range that contains this currently serialized value; `from` equals wherever the `l` was + // written that started this list and `to` is one past the `e` that ends it. Note that `to` + // will always be ahead of `buf_span.first` because we always write the `e`s to close open lists + // but these `e`s don't advance the write position (they will be overwritten if we append). + const char* const from; + const char* to; + + // Sublist constructors + bt_list_producer(bt_list_producer* parent, std::string_view prefix = "l"sv); + bt_list_producer(bt_dict_producer* parent, std::string_view prefix = "l"sv); + + // Does the actual appending to the buffer, and throwing if we'd overrun. If advance is false + // then we append without moving the buffer pointer (primarily when we append intermediate `e`s + // that we will overwrite if more data is added). This means that the next write will overwrite + // whatever was previously written by an `advance=false` call. + void buffer_append(std::string_view d, bool advance = true); + + // Appends the 'e's into the buffer to close off open sublists/dicts *without* advancing the + // buffer position; we do this after each append so that the buffer always contains valid + // encoded data, even while we are still appending to it, and so that appending something raises + // a length_error if appending it would not leave enough space for the required e's to close the + // open list(s)/dict(s). + void append_intermediate_ends(size_t count = 1); + + // Writes an integer to the given buffer; returns the one-past-the-data pointer. Up to 20 bytes + // will be written and must be available in buf. Used for both string and integer + // serialization. + template + char* write_integer(IntType val, char* buf) { + static_assert(sizeof(IntType) <= 64); + +#ifndef OXENMQ_APPLE_TO_CHARS_WORKAROUND + auto [ptr, ec] = std::to_chars(buf, buf+20, val); + assert(ec == std::errc()); + return ptr; +#else + // Hate apple. + return apple_to_chars10(buf, val); +#endif + } + + // Serializes an integer value and appends it to the output buffer. Does not call + // append_intermediate_ends(). + template , int> = 0> + void append_impl(IntType val) { + char buf[22]; // 'i' + base10 representation + 'e' + buf[0] = 'i'; + auto* ptr = write_integer(val, buf+1); + *ptr++ = 'e'; + buffer_append({buf, static_cast(ptr-buf)}); + } + + // Appends a string value, but does not call append_intermediate_ends() + void append_impl(std::string_view s); + +public: + bt_list_producer() = delete; + bt_list_producer(const bt_list_producer&) = delete; + bt_list_producer& operator=(const bt_list_producer&) = delete; + bt_list_producer& operator=(bt_list_producer&&) = delete; + bt_list_producer(bt_list_producer&& other); + + ~bt_list_producer(); + + /// Constructs a list producer that writes into the range [begin, end). If a write would go + /// beyond the end of the buffer an exception is raised. Note that this will happen during + /// construction if the given buffer is not large enough to contain the `le` encoding of an + /// empty list. + bt_list_producer(char* begin, char* end); + + /// Constructs a list producer that writes into the range [begin, begin+size). If a write would + /// go beyond the end of the buffer an exception is raised. + bt_list_producer(char* begin, size_t len) : bt_list_producer{begin, begin + len} {} + + /// Returns a string_view into the currently serialized data buffer. Note that the returned + /// view includes the `e` list end serialization markers which will be overwritten if the list + /// (or an active sublist/subdict) is appended to. + std::string_view view() const { + return {from, static_cast(to-from)}; + } + + /// Returns the end position in the buffer. + const char* end() const { return to; } + + /// Appends an element containing binary string data + void append(std::string_view data); + + bt_list_producer& operator+=(std::string_view data) { append(data); return *this; } + + /// Appends an integer + template , int> = 0> + void append(IntType i) { + if (has_child) throw std::logic_error{"Cannot append to list when a sublist is active"}; + append_impl(i); + append_intermediate_ends(); + } + + template , int> = 0> + bt_list_producer& operator+=(IntType i) { append(i); return *this; } + + /// Appends elements from the range [from, to) to the list. This does *not* append the elements + /// as a sublist: for that you should use something like: `l.append_list().append(from, to);` + template + void append(ForwardIt from, ForwardIt to) { + if (has_child) throw std::logic_error{"Cannot append to list when a sublist is active"}; + while (from != to) + append_impl(*from++); + append_intermediate_ends(); + } + + /// Appends a sublist to this list. Returns a new bt_list_producer that references the parent + /// list. The parent cannot be added to until the sublist is destroyed. This is meant to be + /// used via RAII: + /// + /// buf data[16]; + /// bt_list_producer list{data, sizeof(data)}; + /// { + /// auto sublist = list.append_list(); + /// sublist.append(42); + /// } + /// list.append(1); + /// // `data` now contains: `lli42eei1ee` + /// + /// If doing more complex lifetime management, take care not to allow the child instance to + /// outlive the parent. + bt_list_producer append_list(); + + /// Appends a dict to this list. Returns a new bt_dict_producer that references the parent + /// list. The parent cannot be added to until the subdict is destroyed. This is meant to be + /// used via RAII (see append_list() for details). + /// + /// If doing more complex lifetime management, take care not to allow the child instance to + /// outlive the parent. + bt_dict_producer append_dict(); +}; + + +/// Class that allows you to build a bt-encoded dict manually, without copying or allocating memory. +/// This is essentially the reverse of bt_dict_consumer: where it lets you stream-parse a buffer, +/// this class lets you build directly into a buffer that you own. +/// +/// Note that bt-encoded dicts *must* be produced in (ASCII) ascending key order, but that this is +/// only tracked/enforced for non-release builds (i.e. without -DNDEBUG). +class bt_dict_producer : bt_list_producer { + friend class bt_list_producer; + + // Subdict constructors + bt_dict_producer(bt_list_producer* parent); + bt_dict_producer(bt_dict_producer* parent); + + // Checks a just-written key string to make sure it is monotonically increasing from the last + // key. Does nothing in a release build. +#ifdef NDEBUG + constexpr void check_incrementing_key(size_t) const {} +#else + // String view into the buffer where we wrote the previous key. + std::string_view last_key; + void check_incrementing_key(size_t size); +#endif + +public: + // Construction is identical to bt_list_producer + using bt_list_producer::bt_list_producer; + + /// Returns a string_view into the currently serialized data buffer. Note that the returned + /// view includes the `e` dict end serialization markers which will be overwritten if the dict + /// (or an active sublist/subdict) is appended to. + std::string_view view() const { return bt_list_producer::view(); } + + /// Appends a key-value pair with a string or integer value. The key must be > the last key + /// added, but this is only enforced (with an assertion) in debug builds. + template || std::is_integral_v, int> = 0> + void append(std::string_view key, const T& value) { + if (has_child) throw std::logic_error{"Cannot append to list when a sublist is active"}; + append_impl(key); + check_incrementing_key(key.size()); + append_impl(value); + append_intermediate_ends(); + } + + /// Appends pairs from the range [from, to) to the dict. Elements must have a .first + /// convertible to a string_view, and a .second that is either string view convertible or an + /// integer. This does *not* append the elements as a subdict: for that you should use + /// something like: `l.append_dict().append(key, from, to);` + /// + /// Also note that the range *must* be sorted by keys, which means either using an ordered + /// container (e.g. std::map) or a manually ordered container (such as a vector or list of + /// pairs). unordered_map, however, is not acceptable. + template , int> = 0> + void append(ForwardIt from, ForwardIt to) { + if (has_child) throw std::logic_error{"Cannot append to list when a sublist is active"}; + using KeyType = std::remove_cv_tfirst)>>; + using ValType = std::decay_tsecond)>; + static_assert(std::is_convertible_vfirst), std::string_view>); + static_assert(std::is_convertible_v || std::is_integral_v); + using BadUnorderedMap = std::unordered_map; + static_assert(!( // Disallow unordered_map iterators because they are not going to be ordered. + std::is_same_v || + std::is_same_v)); + while (from != to) { + const auto& [k, v] = *from++; + append_impl(k); + check_incrementing_key(k.size()); + append_impl(v); + } + append_intermediate_ends(); + } + + /// Appends a sub-dict value to this dict with the given key. Returns a new bt_dict_producer + /// that references the parent dict. The parent cannot be added to until the subdict is + /// destroyed. Key must be (ascii-comparison) larger than the previous key. + /// + /// This is meant to be used via RAII: + /// + /// buf data[32]; + /// bt_dict_producer dict{data, sizeof(data)}; + /// { + /// auto subdict = dict.begin_dict("myKey"); + /// subdict.append("x", 42); + /// } + /// dict.append("y", ""); + /// // `data` now contains: `d5:myKeyd1:xi42ee1:y0:e` + /// + /// If doing more complex lifetime management, take care not to allow the child instance to + /// outlive the parent. + bt_dict_producer append_dict(std::string_view key); + + /// Appends a list to this dict with the given key (which must be ascii-larger than the previous + /// key). Returns a new bt_list_producer that references the parent dict. The parent cannot be + /// added to until the sublist is destroyed. + /// + /// This is meant to be used via RAII (see append_dict() for details). + /// + /// If doing more complex lifetime management, take care not to allow the child instance to + /// outlive the parent. + bt_list_producer append_list(std::string_view key); +}; + +} // namespace oxenmq diff --git a/oxenmq/bt_serialize.cpp b/oxenmq/bt_serialize.cpp index e67b9bc..0c34f0e 100644 --- a/oxenmq/bt_serialize.cpp +++ b/oxenmq/bt_serialize.cpp @@ -27,6 +27,10 @@ // THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include "bt_serialize.h" +#include "bt_producer.h" +#include "variant.h" + +#include #include namespace oxenmq { @@ -228,4 +232,138 @@ std::pair bt_dict_consumer::next_string() { } +bt_list_producer::bt_list_producer(bt_list_producer* parent, std::string_view prefix) + : data{parent}, buffer{parent->buffer}, from{buffer.first} { + parent->has_child = true; + buffer_append(prefix); + append_intermediate_ends(); +} + +bt_list_producer::bt_list_producer(bt_dict_producer* parent, std::string_view prefix) + : data{parent}, buffer{parent->buffer}, from{buffer.first} { + parent->has_child = true; + buffer_append(prefix); + append_intermediate_ends(); +} + +bt_list_producer::bt_list_producer(bt_list_producer&& other) + : data{std::move(other.data)}, buffer{other.buffer}, from{other.from}, to{other.to} { + if (other.has_child) throw std::logic_error{"Cannot move bt_list/dict_producer with active sublists/subdicts"}; + var::visit([](auto& x) { + if constexpr (!std::is_same_v) + x = nullptr; + }, other.data); +} + + +bt_list_producer::bt_list_producer(char* begin, char* end) + : data{buf_span{begin, end}}, buffer{*std::get_if(&data)}, from{buffer.first} { + buffer_append("l"sv); + append_intermediate_ends(); +} + +bt_list_producer::~bt_list_producer() { + var::visit([this](auto& x) { + if constexpr (!std::is_same_v) { + if (!x) + return; + + assert(!has_child); + assert(x->has_child); + x->has_child = false; + // We've already written the intermediate 'e', so just increment the buffer to + // finalize it. + buffer.first++; + } + }, data); +} + +void bt_list_producer::append(std::string_view data) { + if (has_child) throw std::logic_error{"Cannot append to list when a sublist is active"}; + append_impl(data); + append_intermediate_ends(); +} + +bt_list_producer bt_list_producer::append_list() { + if (has_child) throw std::logic_error{"Cannot call append_list while another nested list/dict is active"}; + return bt_list_producer{this}; +} + +bt_dict_producer bt_list_producer::append_dict() { + if (has_child) throw std::logic_error{"Cannot call append_dict while another nested list/dict is active"}; + return bt_dict_producer{this}; +} + + + +void bt_list_producer::buffer_append(std::string_view d, bool advance) { + var::visit([d, advance, this](auto& x) { + if constexpr (std::is_same_v) { + size_t avail = std::distance(x.first, x.second); + if (d.size() > avail) + throw std::length_error{"Cannot write bt_producer: buffer size exceeded"}; + std::copy(d.begin(), d.end(), x.first); + to = x.first + d.size(); + if (advance) + x.first += d.size(); + } else { + x->buffer_append(d, advance); + } + }, data); +} + +static constexpr std::string_view eee = "eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee"sv; + +void bt_list_producer::append_intermediate_ends(size_t count) { + return var::visit([this, count](auto& x) mutable { + if constexpr (std::is_same_v) { + for (; count > eee.size(); count -= eee.size()) + buffer_append(eee, false); + buffer_append(eee.substr(0, count), false); + } else { + // x is a parent pointer + x->append_intermediate_ends(count + 1); + to = x->to - 1; // Our `to` should be one 'e' before our parent's `to`. + } + }, data); +} + +void bt_list_producer::append_impl(std::string_view s) { + char buf[21]; // length + ':' + auto* ptr = write_integer(s.size(), buf); + *ptr++ = ':'; + buffer_append({buf, static_cast(ptr-buf)}); + buffer_append(s); +} + + +// Subdict constructors +bt_dict_producer::bt_dict_producer(bt_list_producer* parent) : bt_list_producer{parent, "d"sv} {} +bt_dict_producer::bt_dict_producer(bt_dict_producer* parent) : bt_list_producer{parent, "d"sv} {} + +#ifndef NDEBUG + +void bt_dict_producer::check_incrementing_key(size_t size) { + std::string_view this_key{buffer.first - size, size}; + assert(!last_key.data() || this_key > last_key); + last_key = this_key; +} + +#endif + +bt_dict_producer bt_dict_producer::append_dict(std::string_view key) { + if (has_child) throw std::logic_error{"Cannot call append_dict while another nested list/dict is active"}; + append_impl(key); + check_incrementing_key(key.size()); + return bt_dict_producer{this}; +} + +bt_list_producer bt_dict_producer::append_list(std::string_view key) { + if (has_child) throw std::logic_error{"Cannot call append_list while another nested list/dict is active"}; + append_impl(key); + check_incrementing_key(key.size()); + return bt_list_producer{this}; +} + + } // namespace oxenmq diff --git a/oxenmq/hex.h b/oxenmq/hex.h index fdfad3f..553b351 100644 --- a/oxenmq/hex.h +++ b/oxenmq/hex.h @@ -62,23 +62,65 @@ static_assert(hex_lut.from_hex('a') == 10 && hex_lut.from_hex('F') == 15 && hex_ } // namespace detail -/// Creates hex digits from a character sequence. -template -void to_hex(InputIt begin, InputIt end, OutputIt out) { - static_assert(sizeof(decltype(*begin)) == 1, "to_hex requires chars/bytes"); - for (; begin != end; ++begin) { - uint8_t c = static_cast(*begin); - *out++ = detail::hex_lut.to_hex(c >> 4); - *out++ = detail::hex_lut.to_hex(c & 0x0f); +/// Returns the number of characters required to encode a hex string from the given number of bytes. +inline constexpr size_t to_hex_size(size_t byte_size) { return byte_size * 2; } +/// Returns the number of bytes required to decode a hex string of the given size. +inline constexpr size_t from_hex_size(size_t hex_size) { return hex_size / 2; } + +/// Iterable object for on-the-fly hex encoding. Used internally, but also particularly useful when +/// converting from one encoding to another. +template +struct hex_encoder final { +private: + InputIt _it, _end; + static_assert(sizeof(decltype(*_it)) == 1, "hex_encoder requires chars/bytes input iterator"); + uint8_t c = 0; + bool second_half = false; +public: + using iterator_category = std::input_iterator_tag; + using difference_type = std::ptrdiff_t; + using value_type = char; + using reference = value_type; + using pointer = void; + hex_encoder(InputIt begin, InputIt end) : _it{std::move(begin)}, _end{std::move(end)} {} + + hex_encoder end() { return {_end, _end}; } + + bool operator==(const hex_encoder& i) { return _it == i._it && second_half == i.second_half; } + bool operator!=(const hex_encoder& i) { return !(*this == i); } + + hex_encoder& operator++() { + second_half = !second_half; + if (!second_half) + ++_it; + return *this; } + hex_encoder operator++(int) { hex_encoder copy{*this}; ++*this; return copy; } + char operator*() { + return detail::hex_lut.to_hex(second_half + ? c & 0x0f + : (c = static_cast(*_it)) >> 4); + } +}; + +/// Creates hex digits from a character sequence given by iterators, writes them starting at `out`. +/// Returns the final value of out (i.e. the iterator positioned just after the last written +/// hex character). +template +OutputIt to_hex(InputIt begin, InputIt end, OutputIt out) { + static_assert(sizeof(decltype(*begin)) == 1, "to_hex requires chars/bytes"); + auto it = hex_encoder{begin, end}; + return std::copy(it, it.end(), out); } /// Creates a string of hex digits from a character sequence iterator pair template std::string to_hex(It begin, It end) { std::string hex; - if constexpr (std::is_base_of_v::iterator_category>) - hex.reserve(2 * std::distance(begin, end)); + if constexpr (std::is_base_of_v::iterator_category>) { + using std::distance; + hex.reserve(to_hex_size(distance(begin, end))); + } to_hex(begin, end, std::back_inserter(hex)); return hex; } @@ -101,9 +143,11 @@ template constexpr bool is_hex(It begin, It end) { static_assert(sizeof(decltype(*begin)) == 1, "is_hex requires chars/bytes"); constexpr bool ra = std::is_base_of_v::iterator_category>; - if constexpr (ra) - if (std::distance(begin, end) % 2 != 0) + if constexpr (ra) { + using std::distance; + if (distance(begin, end) % 2 != 0) return false; + } size_t count = 0; for (; begin != end; ++begin) { @@ -129,20 +173,61 @@ constexpr char from_hex_digit(unsigned char x) noexcept { /// Constructs a byte value from a pair of hex digits constexpr char from_hex_pair(unsigned char a, unsigned char b) noexcept { return (from_hex_digit(a) << 4) | from_hex_digit(b); } +/// Iterable object for on-the-fly hex decoding. Used internally but also particularly useful when +/// converting from one encoding to another. Undefined behaviour if the given iterator range is not +/// a valid hex string with even length (i.e. is_hex() should return true). +template +struct hex_decoder final { +private: + InputIt _it, _end; + static_assert(sizeof(decltype(*_it)) == 1, "hex_encoder requires chars/bytes input iterator"); + char byte; +public: + using iterator_category = std::input_iterator_tag; + using difference_type = std::ptrdiff_t; + using value_type = char; + using reference = value_type; + using pointer = void; + hex_decoder(InputIt begin, InputIt end) : _it{std::move(begin)}, _end{std::move(end)} { + if (_it != _end) + load_byte(); + } + + hex_decoder end() { return {_end, _end}; } + + bool operator==(const hex_decoder& i) { return _it == i._it; } + bool operator!=(const hex_decoder& i) { return _it != i._it; } + + hex_decoder& operator++() { + if (++_it != _end) + load_byte(); + return *this; + } + hex_decoder operator++(int) { hex_decoder copy{*this}; ++*this; return copy; } + char operator*() const { return byte; } + +private: + void load_byte() { + auto a = *_it; + auto b = *++_it; + byte = from_hex_pair(static_cast(a), static_cast(b)); + } + +}; + /// Converts a sequence of hex digits to bytes. Undefined behaviour if any characters are not in /// [0-9a-fA-F] or if the input sequence length is not even: call `is_hex` first if you need to -/// check. It is permitted for the input and output ranges to overlap as long as out is no earlier -/// than begin. +/// check. It is permitted for the input and output ranges to overlap as long as out is no later +/// than begin. Returns the final value of out (that is, the iterator positioned just after the +/// last written character). template -void from_hex(InputIt begin, InputIt end, OutputIt out) { - using std::distance; +OutputIt from_hex(InputIt begin, InputIt end, OutputIt out) { assert(is_hex(begin, end)); - while (begin != end) { - auto a = *begin++; - auto b = *begin++; - *out++ = static_cast>( - from_hex_pair(static_cast(a), static_cast(b))); - } + auto it = hex_decoder(begin, end); + const auto hend = it.end(); + while (it != hend) + *out++ = static_cast>(*it++); + return out; } /// Converts a sequence of hex digits to a string of bytes and returns it. Undefined behaviour if @@ -150,8 +235,10 @@ void from_hex(InputIt begin, InputIt end, OutputIt out) { template std::string from_hex(It begin, It end) { std::string bytes; - if constexpr (std::is_base_of_v::iterator_category>) - bytes.reserve(std::distance(begin, end) / 2); + if constexpr (std::is_base_of_v::iterator_category>) { + using std::distance; + bytes.reserve(from_hex_size(distance(begin, end))); + } from_hex(begin, end, std::back_inserter(bytes)); return bytes; } diff --git a/oxenmq/message.h b/oxenmq/message.h index 73b96d0..aa399ad 100644 --- a/oxenmq/message.h +++ b/oxenmq/message.h @@ -75,6 +75,15 @@ public: explicit DeferredSend(Message& m) : oxenmq{m.oxenmq}, conn{m.conn}, reply_tag{m.reply_tag} {} + template + void operator()(Args &&...args) const { + if (reply_tag.empty()) + back(std::forward(args)...); + else + reply(std::forward(args)...); + }; + + /// Equivalent to msg.send_back(...), but can be invoked later. template void back(std::string_view command, Args&&... args) const; diff --git a/oxenmq/oxenmq.cpp b/oxenmq/oxenmq.cpp index c30095b..ef802d9 100644 --- a/oxenmq/oxenmq.cpp +++ b/oxenmq/oxenmq.cpp @@ -262,7 +262,9 @@ void OxenMQ::start() { } void OxenMQ::listen_curve(std::string bind_addr, AllowFunc allow_connection, std::function on_bind) { - if (!allow_connection) allow_connection = [](auto, auto, auto) { return AuthLevel::none; }; + if (std::string_view{bind_addr}.substr(0, 9) == "inproc://") + throw std::logic_error{"inproc:// cannot be used with listen_curve"}; + if (!allow_connection) allow_connection = [](auto&&...) { return AuthLevel::none; }; bind_data d{std::move(bind_addr), true, std::move(allow_connection), std::move(on_bind)}; if (proxy_thread.joinable()) detail::send_control(get_control_socket(), "BIND", bt_serialize(detail::serialize_object(std::move(d)))); @@ -271,7 +273,9 @@ void OxenMQ::listen_curve(std::string bind_addr, AllowFunc allow_connection, std } void OxenMQ::listen_plain(std::string bind_addr, AllowFunc allow_connection, std::function on_bind) { - if (!allow_connection) allow_connection = [](auto, auto, auto) { return AuthLevel::none; }; + if (std::string_view{bind_addr}.substr(0, 9) == "inproc://") + throw std::logic_error{"inproc:// cannot be used with listen_plain"}; + if (!allow_connection) allow_connection = [](auto&&...) { return AuthLevel::none; }; bind_data d{std::move(bind_addr), false, std::move(allow_connection), std::move(on_bind)}; if (proxy_thread.joinable()) detail::send_control(get_control_socket(), "BIND", bt_serialize(detail::serialize_object(std::move(d)))); diff --git a/oxenmq/oxenmq.h b/oxenmq/oxenmq.h index 6b49b35..afd4629 100644 --- a/oxenmq/oxenmq.h +++ b/oxenmq/oxenmq.h @@ -82,6 +82,9 @@ inline constexpr auto DEFAULT_CONNECT_SN_KEEP_ALIVE = 5min; // The default timeout for connect_remote() inline constexpr auto REMOTE_CONNECT_TIMEOUT = 10s; +// Default timeout for connect_inproc() +inline constexpr auto INPROC_CONNECT_TIMEOUT = 50ms; + // The amount of time we wait for a reply to a REQUEST before calling the callback with // `false` to signal a timeout. inline constexpr auto DEFAULT_REQUEST_TIMEOUT = 15s; @@ -410,6 +413,9 @@ private: /// The connections to/from remotes we currently have open, both listening and outgoing. std::map connections; + /// The connection ID of the built-in inproc listener for making requests to self + int64_t inproc_listener_connid; + /// If set then it indicates a change in `connections` which means we need to rebuild pollitems /// and stop using existing connections iterators. bool connections_updated = true; @@ -442,7 +448,7 @@ private: /// indices of idle, active workers std::vector idle_workers; - /// Maximum number of general task workers, specified by g`/during construction + /// Maximum number of general task workers, specified by set_general_threads() int general_workers = std::max(1, std::thread::hardware_concurrency()); /// Maximum number of possible worker threads we can have. This is calculated when starting, @@ -780,13 +786,6 @@ public: * listening in curve25519 mode (otherwise we couldn't verify its authenticity). Should return * empty for not found or if SN lookups are not supported. * - * @param allow_incoming is a callback that OxenMQ can use to determine whether an incoming - * connection should be allowed at all and, if so, whether the connection is from a known - * service node. Called with the connecting IP, the remote's verified x25519 pubkey, and the - * called on incoming connections with the (verified) incoming connection - * pubkey (32-byte binary string) to determine whether the given SN should be allowed to - * connect. - * * @param log a function or callable object that writes a log message. If omitted then all log * messages are suppressed. * @@ -1000,7 +999,7 @@ public: * connections on this address to determine the incoming remote's access and authentication * level. Note that `allow_connection` here will be called with an empty pubkey. * - * @param bind address - can be any string zmq supports; typically a tcp IP/port combination + * @param bind address - can be any string zmq supports, for example a tcp IP/port combination * such as: "tcp://\*:4567" or "tcp://1.2.3.4:5678". * * @param allow_connection function to call to determine whether to allow the connection and, if @@ -1027,7 +1026,7 @@ public: * @param pubkey - the public key (32-byte binary string) of the service node to connect to * @param options - connection options; see the structs in `connect_option`, in particular: * - keep_alive -- how long the SN connection will be kept alive after valid activity - * - remote_hint -- a remote address hint that may be used instead of doing a lookup + * - hint -- a remote address hint that may be used instead of doing a lookup * - ephemeral_routing_id -- allows you to override the EPHEMERAL_ROUTING_ID option for * this connection. * @@ -1095,6 +1094,16 @@ public: AuthLevel auth_level = AuthLevel::none, std::chrono::milliseconds timeout = REMOTE_CONNECT_TIMEOUT); + /// Connects to the built-in in-process listening socket of this OxenMQ server for local + /// communication. Note that auth_level defaults to admin (unlike connect_remote), and the + /// default timeout is much shorter. + /// + /// Also note that incoming inproc requests are unauthenticated: that is, they will always have + /// admin-level access. + template + ConnectionID connect_inproc(ConnectSuccess on_connect, ConnectFailure on_failure, + const Option&... options); + /** * Disconnects an established outgoing connection established with `connect_remote()` (or, less * commonly, `connect_sn()`). @@ -1103,8 +1112,7 @@ public: * * @param linger how long to allow the connection to linger while there are still pending * outbound messages to it before disconnecting and dropping any pending messages. (Note that - * this lingering is internal; the disconnect_remote() call does not block). The default is 1 - * second. + * this lingering is internal; the disconnect() call does not block). The default is 1 second. * * If given a pubkey, we try to close an outgoing connection to the given SN if one exists; note * however that this is often not particularly useful as messages to that SN can immediately @@ -1118,8 +1126,8 @@ public: * if not already connected). * * If a new connection is established it will have a relatively short (30s) idle timeout. If - * the connection should stay open longer you should either call `connect(pubkey, IDLETIME)` or - * pass a a `send_option::keep_alive{IDLETIME}` in `opts`. + * the connection should stay open longer you should either call `connect_sn(pubkey, IDLETIME)` + * or pass a a `send_option::keep_alive{IDLETIME}` in `opts`. * * Note that this method (along with connect) doesn't block waiting for a connection or for the * message to send; it merely instructs the proxy thread that it should send. ZMQ will @@ -1273,9 +1281,9 @@ public: /** * Adds a timer that gets scheduled periodically in the job queue. Normally jobs are not * double-booked: that is, a new timed job will not be scheduled if the timer fires before a - * previously scheduled callback of the job has not yet completed. If you want to override this - * (so that, under heavy load or long jobs, there can be more than one of the same job scheduled - * or running at a time) then specify `squelch` as `false`. + * previously scheduled callback of the job has completed. If you want to override this (so + * that, under heavy load or long jobs, there can be more than one of the same job scheduled or + * running at a time) then specify `squelch` as `false`. * * The returned value can be kept and later passed into `cancel_timer()` if you want to be able * to cancel a timer. @@ -1411,6 +1419,8 @@ struct outgoing { /// Specifies the idle timeout for the connection - if a new or existing outgoing connection is used /// for the send and its current idle timeout setting is less than this value then it is updated. +/// +/// A negative value is treated as if the option were not supplied at all. struct keep_alive { std::chrono::milliseconds time; explicit keep_alive(std::chrono::milliseconds time) : time{std::move(time)} {} @@ -1421,6 +1431,8 @@ struct keep_alive { /// (This has no effect if specified on a non-request() call). Note that requests failures are only /// processed in the CONN_CHECK_INTERVAL timer, so it can be up to that much longer than the time /// specified here before a failure callback is invoked. +/// +/// Specifying a negative timeout is equivalent to not specifying the option at all. struct request_timeout { std::chrono::milliseconds time; explicit request_timeout(std::chrono::milliseconds time) : time{std::move(time)} {} @@ -1466,7 +1478,7 @@ namespace connect_option { /// the default (OxenMQ::EPHEMERAL_ROUTING_ID). See OxenMQ::EPHEMERAL_ROUTING_ID for a description /// of this. /// -/// Typically use: `connect_options::ephemeral_routing_id{}` or `connect_options::ephemeral_routing_id{false}`. +/// Typically use: `connect_option::ephemeral_routing_id{}` or `connect_option::ephemeral_routing_id{false}`. struct ephemeral_routing_id { bool use_ephemeral_routing_id = true; // Constructor; default construction gives you ephemeral routing id, but the bool parameter can @@ -1486,6 +1498,8 @@ struct timeout { /// milliseconds. If an outgoing connection already exists, the longer of the existing and the /// given keep alive is used. /// +/// A negative value is treated as if the keep_alive option had not been specified. +/// /// Note that, if not specified, the default keep-alive for a connection established via /// connect_sn() is 5 minutes (which is much longer than the default for an implicit connect() by /// calling send() directly with a pubkey.) @@ -1501,6 +1515,7 @@ struct keep_alive { /// potentially expensive lookup call). struct hint { std::string address; + // Constructor taking a hint. If the hint is an empty string then no hint will be used. explicit hint(std::string_view address) : address{address} {} }; @@ -1554,6 +1569,7 @@ void apply_send_option(bt_list& parts, bt_dict&, const send_option::data_parts_i /// `hint` specialization: sets the hint in the control data inline void apply_send_option(bt_list&, bt_dict& control_data, const send_option::hint& hint) { + if (hint.connect_hint.empty()) return; control_data["hint"] = hint.connect_hint; } @@ -1574,12 +1590,14 @@ inline void apply_send_option(bt_list&, bt_dict& control_data, const send_option /// `keep_alive` specialization: increases the outgoing socket idle timeout (if shorter) inline void apply_send_option(bt_list&, bt_dict& control_data, const send_option::keep_alive& timeout) { - control_data["keep_alive"] = timeout.time.count(); + if (timeout.time >= 0ms) + control_data["keep_alive"] = timeout.time.count(); } /// `request_timeout` specialization: set the timeout time for a request inline void apply_send_option(bt_list&, bt_dict& control_data, const send_option::request_timeout& timeout) { - control_data["request_timeout"] = timeout.time.count(); + if (timeout.time >= 0ms) + control_data["request_timeout"] = timeout.time.count(); } /// `queue_failure` specialization @@ -1624,10 +1642,12 @@ inline void apply_connect_option(OxenMQ& omq, bool remote, bt_dict& opts, const else omq.log(LogLevel::warn, __FILE__, __LINE__, "connect_option::timeout ignored for connect_sn(...)"); } inline void apply_connect_option(OxenMQ& omq, bool remote, bt_dict& opts, const connect_option::keep_alive& ka) { - if (!remote) opts["keep_alive"] = ka.time.count(); + if (ka.time < 0ms) return; + else if (!remote) opts["keep_alive"] = ka.time.count(); else omq.log(LogLevel::warn, __FILE__, __LINE__, "connect_option::keep_alive ignored for connect_remote(...)"); } inline void apply_connect_option(OxenMQ& omq, bool remote, bt_dict& opts, const connect_option::hint& hint) { + if (hint.address.empty()) return; if (!remote) opts["hint"] = hint.address; else omq.log(LogLevel::warn, __FILE__, __LINE__, "connect_option::hint ignored for connect_remote(...)"); } @@ -1678,6 +1698,27 @@ ConnectionID OxenMQ::connect_sn(std::string_view pubkey, const Option&... option return pubkey; } +template +ConnectionID OxenMQ::connect_inproc(ConnectSuccess on_connect, ConnectFailure on_failure, + const Option&... options) { + bt_dict opts{ + {"timeout", INPROC_CONNECT_TIMEOUT.count()}, + {"auth_level", static_cast>(AuthLevel::admin)} + }; + + (detail::apply_connect_option(*this, true, opts, options), ...); + + auto id = next_conn_id++; + opts["conn_id"] = id; + opts["connect"] = detail::serialize_object(std::move(on_connect)); + opts["failure"] = detail::serialize_object(std::move(on_failure)); + opts["remote"] = "inproc://sn-self"; + + detail::send_control(get_control_socket(), "CONNECT_REMOTE", bt_serialize(opts)); + + return id; +} + template void OxenMQ::send(ConnectionID to, std::string_view cmd, const T&... opts) { detail::send_control(get_control_socket(), "SEND", diff --git a/oxenmq/proxy.cpp b/oxenmq/proxy.cpp index db564e1..1bae0fd 100644 --- a/oxenmq/proxy.cpp +++ b/oxenmq/proxy.cpp @@ -411,6 +411,13 @@ void OxenMQ::proxy_loop() { saved_umask = umask(STARTUP_UMASK); #endif + { + zmq::socket_t inproc_listener{context, zmq::socket_type::router}; + inproc_listener.bind(SN_ADDR_SELF); + inproc_listener_connid = next_conn_id++; + connections.emplace_hint(connections.end(), inproc_listener_connid, std::move(inproc_listener)); + } + for (size_t i = 0; i < bind.size(); i++) { if (!proxy_bind(bind[i], i)) { LMQ_LOG(warn, "OxenMQ failed to listen on ", bind[i].address); diff --git a/oxenmq/worker.cpp b/oxenmq/worker.cpp index ae5046f..1e71629 100644 --- a/oxenmq/worker.cpp +++ b/oxenmq/worker.cpp @@ -19,7 +19,7 @@ namespace { [[gnu::always_inline]] inline bool worker_wait_for(OxenMQ& lmq, zmq::socket_t& sock, std::vector& parts, const std::string_view worker_id, const std::string_view expect) { while (true) { - lmq.log(LogLevel::debug, __FILE__, __LINE__, "worker ", worker_id, " waiting for ", expect); + lmq.log(LogLevel::trace, __FILE__, __LINE__, "worker ", worker_id, " waiting for ", expect); parts.clear(); recv_message_parts(sock, parts); if (parts.size() != 1) { @@ -293,6 +293,11 @@ void OxenMQ::proxy_to_worker(int64_t conn_id, zmq::socket_t& sock, std::vectorsecond; + } else if (conn_id == inproc_listener_connid) { + tmp_peer.auth_level = AuthLevel::admin; + tmp_peer.pubkey = pubkey; + tmp_peer.service_node = active_service_nodes.count(pubkey); + peer = &tmp_peer; } else { std::tie(tmp_peer.pubkey, tmp_peer.auth_level) = detail::extract_metadata(parts.back()); tmp_peer.service_node = tmp_peer.pubkey.size() == 32 && active_service_nodes.count(tmp_peer.pubkey); diff --git a/tests/Catch2 b/tests/Catch2 index b3b0721..dba29b6 160000 --- a/tests/Catch2 +++ b/tests/Catch2 @@ -1 +1 @@ -Subproject commit b3b07215d1ca2224aea6ff3e21d87ad0f7750df2 +Subproject commit dba29b60d639bf8d206a9a12c223e6ed4284fb13 diff --git a/tests/common.h b/tests/common.h index 2798675..f41b9d2 100644 --- a/tests/common.h +++ b/tests/common.h @@ -5,12 +5,22 @@ using namespace oxenmq; +// Apple's mutexes, thread scheduling, and IO handling are garbage and it shows up with lots of +// spurious failures in this test suite (because it expects a system to not suck that badly), so we +// multiply the time-sensitive bits by this factor as a hack to make the test suite work. +constexpr int TIME_DILATION = +#ifdef __APPLE__ + 5; +#else + 1; +#endif + static auto startup = std::chrono::steady_clock::now(); /// Returns a localhost connection string to listen on. It can be considered random, though in -/// practice in the current implementation is sequential starting at 4500. +/// practice in the current implementation is sequential starting at 25432. inline std::string random_localhost() { - static uint16_t last = 4499; + static std::atomic last = 25432; last++; assert(last); // We should never call this enough to overflow return "tcp://127.0.0.1:" + std::to_string(last); @@ -30,7 +40,7 @@ inline void wait_for(Func f) { for (int i = 0; i < 20; i++) { if (f()) break; - std::this_thread::sleep_for(10ms); + std::this_thread::sleep_for(10ms * TIME_DILATION); } auto lock = catch_lock(); UNSCOPED_INFO("done waiting after " << (std::chrono::steady_clock::now() - start).count() << "ns"); @@ -43,7 +53,7 @@ inline void wait_for_conn(std::atomic &c) { } /// Waits enough time for us to receive a reply from a localhost remote. -inline void reply_sleep() { std::this_thread::sleep_for(10ms); } +inline void reply_sleep() { std::this_thread::sleep_for(10ms * TIME_DILATION); } inline OxenMQ::Logger get_logger(std::string prefix = "") { std::string me = "tests/common.h"; diff --git a/tests/test_bt.cpp b/tests/test_bt.cpp index b3f8b0b..d9dd019 100644 --- a/tests/test_bt.cpp +++ b/tests/test_bt.cpp @@ -1,4 +1,5 @@ #include "oxenmq/bt_serialize.h" +#include "oxenmq/bt_producer.h" #include "common.h" #include #include @@ -210,6 +211,10 @@ TEST_CASE("bt tuple serialization", "[bt][tuple][serialization]") { bt_list m{{1, 2, std::make_tuple(3, 4, "hi"sv), std::make_pair("foo"s, "bar"sv), -4}}; REQUIRE( bt_serialize(m) == "li1ei2eli3ei4e2:hiel3:foo3:barei-4ee" ); +} + +TEST_CASE("bt allocation-free consumer", "[bt][dict][list][consumer]") { + // Consumer deserialization: bt_list_consumer lc{"li1ei2eli3ei4e2:hiel3:foo3:barei-4ee"}; REQUIRE( lc.consume_integer() == 1 ); @@ -227,31 +232,91 @@ TEST_CASE("bt tuple serialization", "[bt][tuple][serialization]") { std::make_pair("b"sv, std::make_tuple(1, 2, 3)) ); } -#if 0 +TEST_CASE("bt allocation-free producer", "[bt][dict][list][producer]") { + + char smallbuf[16]; + bt_list_producer toosmall{smallbuf, 16}; // le, total = 2 + toosmall += 42; // i42e, total = 6 + toosmall += "abcdefgh"; // 8:abcdefgh, total=16 + CHECK( toosmall.view() == "li42e8:abcdefghe" ); + + CHECK_THROWS_AS( toosmall += "", std::length_error ); + + char buf[1024]; + bt_list_producer lp{buf, sizeof(buf)}; + CHECK( lp.view() == "le" ); + CHECK( (void*) lp.end() == (void*) (buf + 2) ); + + lp.append("abc"); + CHECK( lp.view() == "l3:abce" ); + lp += 42; + CHECK( lp.view() == "l3:abci42ee" ); + std::vector randos = {{1, 17, -999}}; + lp.append(randos.begin(), randos.end()); + CHECK( lp.view() == "l3:abci42ei1ei17ei-999ee" ); + { - std::cout << "zomg consumption\n"; - bt_dict_consumer dc{zomg_}; - for (int i = 0; i < 5; i++) - if (!dc.skip_until("b")) - throw std::runtime_error("Couldn't find b, but I know it's there!"); + auto sublist = lp.append_list(); + CHECK_THROWS_AS( lp.append(1), std::logic_error ); + CHECK( sublist.view() == "le" ); + CHECK( lp.view() == "l3:abci42ei1ei17ei-999elee" ); + sublist.append(0); - auto dc1 = dc; - if (dc.skip_until("z")) { - auto v = dc.consume_integer(); - std::cout << " - " << v.first << ": " << v.second << "\n"; - } else { - std::cout << " - no z (bad!)\n"; - } - - std::cout << "zomg (second pass)\n"; - for (auto &p : dc1.consume_dict().second) { - std::cout << " - " << p.first << " = (whatever)\n"; - } - while (dc1) { - auto v = dc1.consume_integer(); - std::cout << " - " << v.first << ": " << v.second << "\n"; - } + auto sublist2{std::move(sublist)}; + sublist2 += ""; + CHECK( sublist2.view() == "li0e0:e" ); + CHECK( lp.view() == "l3:abci42ei1ei17ei-999eli0e0:ee" ); } + + lp.append_list().append_list().append_list() += "omg"s; + CHECK( lp.view() == "l3:abci42ei1ei17ei-999eli0e0:elll3:omgeeee" ); + + { + auto dict = lp.append_dict(); + CHECK( dict.view() == "de" ); + CHECK( lp.view() == "l3:abci42ei1ei17ei-999eli0e0:elll3:omgeeedee" ); + + CHECK_THROWS_AS( lp.append(1), std::logic_error ); + + dict.append("foo", "bar"); + dict.append("g", 42); + + CHECK( dict.view() == "d3:foo3:bar1:gi42ee" ); + CHECK( lp.view() == "l3:abci42ei1ei17ei-999eli0e0:elll3:omgeeed3:foo3:bar1:gi42eee" ); + + dict.append_list("h").append_dict().append_dict("a").append_list("A") += 999; + CHECK( dict.view() == "d3:foo3:bar1:gi42e1:hld1:ad1:Ali999eeeeee" ); + CHECK( lp.view() == "l3:abci42ei1ei17ei-999eli0e0:elll3:omgeeed3:foo3:bar1:gi42e1:hld1:ad1:Ali999eeeeeee" ); + } +} + +#ifdef OXENMQ_APPLE_TO_CHARS_WORKAROUND +TEST_CASE("apple to_chars workaround test", "[bt][apple][sucks]") { + char buf[20]; + auto buf_view = [&](char* end) { return std::string_view{buf, static_cast(end - buf)}; }; + CHECK( buf_view(oxenmq::apple_to_chars10(buf, 0)) == "0" ); + CHECK( buf_view(oxenmq::apple_to_chars10(buf, 1)) == "1" ); + CHECK( buf_view(oxenmq::apple_to_chars10(buf, 2)) == "2" ); + CHECK( buf_view(oxenmq::apple_to_chars10(buf, 10)) == "10" ); + CHECK( buf_view(oxenmq::apple_to_chars10(buf, 42)) == "42" ); + CHECK( buf_view(oxenmq::apple_to_chars10(buf, 99)) == "99" ); + CHECK( buf_view(oxenmq::apple_to_chars10(buf, 1234567890)) == "1234567890" ); + CHECK( buf_view(oxenmq::apple_to_chars10(buf, -1)) == "-1" ); + CHECK( buf_view(oxenmq::apple_to_chars10(buf, -2)) == "-2" ); + CHECK( buf_view(oxenmq::apple_to_chars10(buf, -10)) == "-10" ); + CHECK( buf_view(oxenmq::apple_to_chars10(buf, -99)) == "-99" ); + CHECK( buf_view(oxenmq::apple_to_chars10(buf, -1234567890)) == "-1234567890" ); + CHECK( buf_view(oxenmq::apple_to_chars10(buf, char{42})) == "42" ); + CHECK( buf_view(oxenmq::apple_to_chars10(buf, (unsigned char){42})) == "42" ); + CHECK( buf_view(oxenmq::apple_to_chars10(buf, short{42})) == "42" ); + CHECK( buf_view(oxenmq::apple_to_chars10(buf, std::numeric_limits::min())) == "-128" ); + CHECK( buf_view(oxenmq::apple_to_chars10(buf, std::numeric_limits::max())) == "127" ); + CHECK( buf_view(oxenmq::apple_to_chars10(buf, (unsigned char){42})) == "42" ); + CHECK( buf_view(oxenmq::apple_to_chars10(buf, std::numeric_limits::max())) == "18446744073709551615" ); + CHECK( buf_view(oxenmq::apple_to_chars10(buf, int64_t{-1})) == "-1" ); + CHECK( buf_view(oxenmq::apple_to_chars10(buf, std::numeric_limits::min())) == "-9223372036854775808" ); + CHECK( buf_view(oxenmq::apple_to_chars10(buf, int64_t{-9223372036854775807})) == "-9223372036854775807" ); + CHECK( buf_view(oxenmq::apple_to_chars10(buf, int64_t{9223372036854775807})) == "9223372036854775807" ); + CHECK( buf_view(oxenmq::apple_to_chars10(buf, int64_t{9223372036854775806})) == "9223372036854775806" ); +} #endif - - diff --git a/tests/test_commands.cpp b/tests/test_commands.cpp index 606fed3..93e70c4 100644 --- a/tests/test_commands.cpp +++ b/tests/test_commands.cpp @@ -470,9 +470,9 @@ TEST_CASE("deferred replies", "[commands][send][deferred]") { std::string msg = m.data.empty() ? ""s : std::string{m.data.front()}; std::thread t{[send=m.send_later(), msg=std::move(msg)] { { auto lock = catch_lock(); UNSCOPED_INFO("sleeping"); } - std::this_thread::sleep_for(50ms); + std::this_thread::sleep_for(50ms * TIME_DILATION); { auto lock = catch_lock(); UNSCOPED_INFO("sending"); } - send.reply(msg); + send(msg); }}; t.detach(); }); @@ -516,7 +516,7 @@ TEST_CASE("deferred replies", "[commands][send][deferred]") { auto lock = catch_lock(); REQUIRE( replies.size() == 0 ); // The server waits 50ms before sending, so we shouldn't have any reply yet } - std::this_thread::sleep_for(60ms); // We're at least 70ms in now so the 50ms-delayed server responses should have arrived + std::this_thread::sleep_for(60ms * TIME_DILATION); // We're at least 70ms in now so the 50ms-delayed server responses should have arrived { std::lock_guard lq{reply_mut}; auto lock = catch_lock(); diff --git a/tests/test_connect.cpp b/tests/test_connect.cpp index 0b4b846..84e20e8 100644 --- a/tests/test_connect.cpp +++ b/tests/test_connect.cpp @@ -279,7 +279,7 @@ TEST_CASE("SN disconnections", "[connect][disconnect]") { lmq[2]->send(pubkey[1], "sn.hi"); lmq[1]->send(pubkey[0], "BYE"); lmq[0]->send(pubkey[2], "sn.hi"); - std::this_thread::sleep_for(50ms); + std::this_thread::sleep_for(50ms * TIME_DILATION); auto lock = catch_lock(); REQUIRE(his == 5); @@ -504,3 +504,107 @@ TEST_CASE("SN backchatter", "[connect][sn]") { auto lock = catch_lock(); REQUIRE(f == "abc"); } + +TEST_CASE("inproc connections", "[connect][inproc]") { + std::string inproc_name = "foo"; + OxenMQ omq{get_logger("OMQ» "), LogLevel::trace}; + + omq.add_category("public", Access{AuthLevel::none}); + omq.add_request_command("public", "hello", [&](Message& m) { m.send_reply("hi"); }); + + omq.start(); + + std::atomic got{0}; + bool success = false; + auto c_inproc = omq.connect_inproc( + [&](auto conn) { success = true; got++; }, + [&](auto conn, std::string_view reason) { auto lock = catch_lock(); INFO("inproc connection failed: " << reason); got++; } + ); + + wait_for([&got] { return got.load() > 0; }); + { + auto lock = catch_lock(); + REQUIRE( success ); + REQUIRE( got == 1 ); + } + + got = 0; + success = false; + omq.request(c_inproc, "public.hello", [&](auto success_, auto parts_) { + success = success_ && parts_.size() == 1 && parts_.front() == "hi"; got++; + }); + reply_sleep(); + { + auto lock = catch_lock(); + REQUIRE( got == 1 ); + REQUIRE( success ); + } +} + +TEST_CASE("no explicit inproc listening", "[connect][inproc]") { + OxenMQ omq{get_logger("OMQ» "), LogLevel::trace}; + REQUIRE_THROWS_AS(omq.listen_plain("inproc://foo"), std::logic_error); + REQUIRE_THROWS_AS(omq.listen_curve("inproc://foo"), std::logic_error); +} + +TEST_CASE("inproc connection permissions", "[connect][inproc]") { + std::string listen = random_localhost(); + OxenMQ omq{get_logger("OMQ» "), LogLevel::trace}; + + omq.add_category("public", Access{AuthLevel::none}); + omq.add_request_command("public", "hello", [&](Message& m) { m.send_reply("hi"); }); + omq.add_category("private", Access{AuthLevel::admin}); + omq.add_request_command("private", "handshake", [&](Message& m) { m.send_reply("yo dude"); }); + + omq.listen_plain(listen); + + omq.start(); + + std::atomic got{0}; + bool success = false; + auto c_inproc = omq.connect_inproc( + [&](auto conn) { success = true; got++; }, + [&](auto conn, std::string_view reason) { auto lock = catch_lock(); INFO("inproc connection failed: " << reason); got++; } + ); + + bool pub_success = false; + auto c_pub = omq.connect_remote(address{listen}, + [&](auto conn) { pub_success = true; got++; }, + [&](auto conn, std::string_view reason) { auto lock = catch_lock(); INFO("tcp connection failed: " << reason); got++; } + ); + + wait_for([&got] { return got.load() == 2; }); + { + auto lock = catch_lock(); + REQUIRE( got == 2 ); + REQUIRE( success ); + REQUIRE( pub_success ); + } + + got = 0; + success = false; + pub_success = false; + bool success_private = false; + bool pub_success_private = false; + omq.request(c_inproc, "public.hello", [&](auto success_, auto parts_) { + success = success_ && parts_.size() == 1 && parts_.front() == "hi"; got++; + }); + omq.request(c_pub, "public.hello", [&](auto success_, auto parts_) { + pub_success = success_ && parts_.size() == 1 && parts_.front() == "hi"; got++; + }); + omq.request(c_inproc, "private.handshake", [&](auto success_, auto parts_) { + success_private = success_ && parts_.size() == 1 && parts_.front() == "yo dude"; got++; + }); + omq.request(c_pub, "private.handshake", [&](auto success_, auto parts_) { + pub_success_private = success_; got++; + }); + wait_for([&got] { return got.load() == 4; }); + { + auto lock = catch_lock(); + REQUIRE( got == 4 ); + REQUIRE( success ); + REQUIRE( pub_success ); + REQUIRE( success_private ); + REQUIRE_FALSE( pub_success_private ); + } +} diff --git a/tests/test_encoding.cpp b/tests/test_encoding.cpp index f506d8c..85427f7 100644 --- a/tests/test_encoding.cpp +++ b/tests/test_encoding.cpp @@ -44,12 +44,32 @@ TEST_CASE("hex encoding/decoding", "[encoding][decoding][hex]") { std::basic_string_view b{bytes.data(), bytes.size()}; REQUIRE( oxenmq::to_hex(b) == "ff421234"s ); + // In-place decoding and truncation via to_hex's returned iterator: + std::string some_hex = "48656c6c6f"; + some_hex.erase(oxenmq::from_hex(some_hex.begin(), some_hex.end(), some_hex.begin()), some_hex.end()); + REQUIRE( some_hex == "Hello" ); + + // Test the returned iterator from encoding + std::string hellohex; + *oxenmq::to_hex(some_hex.begin(), some_hex.end(), std::back_inserter(hellohex))++ = '!'; + REQUIRE( hellohex == "48656c6c6f!" ); + bytes.resize(8); bytes[0] = std::byte{'f'}; bytes[1] = std::byte{'f'}; bytes[2] = std::byte{'4'}; bytes[3] = std::byte{'2'}; bytes[4] = std::byte{'1'}; bytes[5] = std::byte{'2'}; bytes[6] = std::byte{'3'}; bytes[7] = std::byte{'4'}; std::basic_string_view hex_bytes{bytes.data(), bytes.size()}; REQUIRE( oxenmq::is_hex(hex_bytes) ); REQUIRE( oxenmq::from_hex(hex_bytes) == "\xff\x42\x12\x34" ); + + REQUIRE( oxenmq::to_hex_size(1) == 2 ); + REQUIRE( oxenmq::to_hex_size(2) == 4 ); + REQUIRE( oxenmq::to_hex_size(3) == 6 ); + REQUIRE( oxenmq::to_hex_size(4) == 8 ); + REQUIRE( oxenmq::to_hex_size(100) == 200 ); + REQUIRE( oxenmq::from_hex_size(2) == 1 ); + REQUIRE( oxenmq::from_hex_size(4) == 2 ); + REQUIRE( oxenmq::from_hex_size(6) == 3 ); + REQUIRE( oxenmq::from_hex_size(98) == 49 ); } TEST_CASE("base32z encoding/decoding", "[encoding][decoding][base32z]") { @@ -99,6 +119,16 @@ TEST_CASE("base32z encoding/decoding", "[encoding][decoding][base32z]") { REQUIRE( pk_b32z_again == pk_b32z ); REQUIRE( pk_again == pk ); + // In-place decoding and truncation via returned iterator: + std::string some_b32z = "jb1sa5dx"; + some_b32z.erase(oxenmq::from_base32z(some_b32z.begin(), some_b32z.end(), some_b32z.begin()), some_b32z.end()); + REQUIRE( some_b32z == "Hello" ); + + // Test the returned iterator from encoding + std::string hellob32z; + *oxenmq::to_base32z(some_b32z.begin(), some_b32z.end(), std::back_inserter(hellob32z))++ = '!'; + REQUIRE( hellob32z == "jb1sa5dx!" ); + std::vector bytes{{std::byte{0}, std::byte{255}}}; std::basic_string_view b{bytes.data(), bytes.size()}; REQUIRE( oxenmq::to_base32z(b) == "yd9o" ); @@ -108,6 +138,37 @@ TEST_CASE("base32z encoding/decoding", "[encoding][decoding][base32z]") { std::basic_string_view b32_bytes{bytes.data(), bytes.size()}; REQUIRE( oxenmq::is_base32z(b32_bytes) ); REQUIRE( oxenmq::from_base32z(b32_bytes) == "\x00\xff"sv ); + + REQUIRE( oxenmq::is_base32z("") ); + REQUIRE_FALSE( oxenmq::is_base32z("y") ); + REQUIRE( oxenmq::is_base32z("yy") ); + REQUIRE_FALSE( oxenmq::is_base32z("yyy") ); + REQUIRE( oxenmq::is_base32z("yyyy") ); + REQUIRE( oxenmq::is_base32z("yyyyy") ); + REQUIRE_FALSE( oxenmq::is_base32z("yyyyyy") ); + REQUIRE( oxenmq::is_base32z("yyyyyyy") ); + REQUIRE( oxenmq::is_base32z("yyyyyyyy") ); + + REQUIRE( oxenmq::to_base32z_size(1) == 2 ); + REQUIRE( oxenmq::to_base32z_size(2) == 4 ); + REQUIRE( oxenmq::to_base32z_size(3) == 5 ); + REQUIRE( oxenmq::to_base32z_size(4) == 7 ); + REQUIRE( oxenmq::to_base32z_size(5) == 8 ); + REQUIRE( oxenmq::to_base32z_size(30) == 48 ); + REQUIRE( oxenmq::to_base32z_size(31) == 50 ); + REQUIRE( oxenmq::to_base32z_size(32) == 52 ); + REQUIRE( oxenmq::to_base32z_size(33) == 53 ); + REQUIRE( oxenmq::to_base32z_size(100) == 160 ); + REQUIRE( oxenmq::from_base32z_size(160) == 100 ); + REQUIRE( oxenmq::from_base32z_size(53) == 33 ); + REQUIRE( oxenmq::from_base32z_size(52) == 32 ); + REQUIRE( oxenmq::from_base32z_size(50) == 31 ); + REQUIRE( oxenmq::from_base32z_size(48) == 30 ); + REQUIRE( oxenmq::from_base32z_size(8) == 5 ); + REQUIRE( oxenmq::from_base32z_size(7) == 4 ); + REQUIRE( oxenmq::from_base32z_size(5) == 3 ); + REQUIRE( oxenmq::from_base32z_size(4) == 2 ); + REQUIRE( oxenmq::from_base32z_size(2) == 1 ); } TEST_CASE("base64 encoding/decoding", "[encoding][decoding][base64]") { @@ -125,6 +186,13 @@ TEST_CASE("base64 encoding/decoding", "[encoding][decoding][base64]") { REQUIRE( oxenmq::to_base64("abcde") == "YWJjZGU=" ); REQUIRE( oxenmq::to_base64("abcdef") == "YWJjZGVm" ); + REQUIRE( oxenmq::to_base64_unpadded("a") == "YQ" ); + REQUIRE( oxenmq::to_base64_unpadded("ab") == "YWI" ); + REQUIRE( oxenmq::to_base64_unpadded("abc") == "YWJj" ); + REQUIRE( oxenmq::to_base64_unpadded("abcd") == "YWJjZA" ); + REQUIRE( oxenmq::to_base64_unpadded("abcde") == "YWJjZGU" ); + REQUIRE( oxenmq::to_base64_unpadded("abcdef") == "YWJjZGVm" ); + REQUIRE( oxenmq::to_base64("\0\0\0\xff"s) == "AAAA/w==" ); REQUIRE( oxenmq::to_base64("\0\0\0\xff\xff"s) == "AAAA//8=" ); REQUIRE( oxenmq::to_base64("\0\0\0\xff\xff\xff"s) == "AAAA////" ); @@ -159,6 +227,7 @@ TEST_CASE("base64 encoding/decoding", "[encoding][decoding][base64]") { REQUIRE( oxenmq::is_base64("YWJjZB") ); // not really valid, but we explicitly accept it REQUIRE_FALSE( oxenmq::is_base64("YWJjZ=") ); // invalid padding (padding can only be 4th or 3rd+4th of a 4-char block) + REQUIRE_FALSE( oxenmq::is_base64("YYYYA") ); // invalid: base64 can never be length 4n+1 REQUIRE_FALSE( oxenmq::is_base64("YWJj=") ); REQUIRE_FALSE( oxenmq::is_base64("YWJj=A") ); REQUIRE_FALSE( oxenmq::is_base64("YWJjA===") ); @@ -189,6 +258,16 @@ TEST_CASE("base64 encoding/decoding", "[encoding][decoding][base64]") { REQUIRE( pk_b64_again == pk_b64 ); REQUIRE( pk_again == pk ); + // In-place decoding and truncation via returned iterator: + std::string some_b64 = "SGVsbG8="; + some_b64.erase(oxenmq::from_base64(some_b64.begin(), some_b64.end(), some_b64.begin()), some_b64.end()); + REQUIRE( some_b64 == "Hello" ); + + // Test the returned iterator from encoding + std::string hellob64; + *oxenmq::to_base64(some_b64.begin(), some_b64.end(), std::back_inserter(hellob64))++ = '!'; + REQUIRE( hellob64 == "SGVsbG8=!" ); + std::vector bytes{{std::byte{0}, std::byte{255}}}; std::basic_string_view b{bytes.data(), bytes.size()}; REQUIRE( oxenmq::to_base64(b) == "AP8=" ); @@ -198,6 +277,114 @@ TEST_CASE("base64 encoding/decoding", "[encoding][decoding][base64]") { std::basic_string_view b64_bytes{bytes.data(), bytes.size()}; REQUIRE( oxenmq::is_base64(b64_bytes) ); REQUIRE( oxenmq::from_base64(b64_bytes) == "\xff\x00"sv ); + + REQUIRE( oxenmq::to_base64_size(1) == 4 ); + REQUIRE( oxenmq::to_base64_size(2) == 4 ); + REQUIRE( oxenmq::to_base64_size(3) == 4 ); + REQUIRE( oxenmq::to_base64_size(4) == 8 ); + REQUIRE( oxenmq::to_base64_size(5) == 8 ); + REQUIRE( oxenmq::to_base64_size(6) == 8 ); + REQUIRE( oxenmq::to_base64_size(30) == 40 ); + REQUIRE( oxenmq::to_base64_size(31) == 44 ); + REQUIRE( oxenmq::to_base64_size(32) == 44 ); + REQUIRE( oxenmq::to_base64_size(33) == 44 ); + REQUIRE( oxenmq::to_base64_size(100) == 136 ); + REQUIRE( oxenmq::from_base64_size(136) == 102 ); // Not symmetric because we don't know the last two are padding + REQUIRE( oxenmq::from_base64_size(134) == 100 ); // Unpadded + REQUIRE( oxenmq::from_base64_size(44) == 33 ); + REQUIRE( oxenmq::from_base64_size(43) == 32 ); + REQUIRE( oxenmq::from_base64_size(42) == 31 ); + REQUIRE( oxenmq::from_base64_size(40) == 30 ); + REQUIRE( oxenmq::from_base64_size(8) == 6 ); + REQUIRE( oxenmq::from_base64_size(7) == 5 ); + REQUIRE( oxenmq::from_base64_size(6) == 4 ); + REQUIRE( oxenmq::from_base64_size(4) == 3 ); + REQUIRE( oxenmq::from_base64_size(3) == 2 ); + REQUIRE( oxenmq::from_base64_size(2) == 1 ); +} + +TEST_CASE("transcoding", "[decoding][encoding][base32z][hex][base64]") { + // Decoders: + oxenmq::base64_decoder in64{pk_b64.begin(), pk_b64.end()}; + oxenmq::base32z_decoder in32z{pk_b32z.begin(), pk_b32z.end()}; + oxenmq::hex_decoder in16{pk_hex.begin(), pk_hex.end()}; + + // Transcoders: + oxenmq::base32z_encoder b64_to_b32z{in64, in64.end()}; + oxenmq::base32z_encoder hex_to_b32z{in16, in16.end()}; + oxenmq::hex_encoder b64_to_hex{in64, in64.end()}; + oxenmq::hex_encoder b32z_to_hex{in32z, in32z.end()}; + oxenmq::base64_encoder hex_to_b64{in16, in16.end()}; + oxenmq::base64_encoder b32z_to_b64{in32z, in32z.end()}; + // These ones are stupid, but should work anyway: + oxenmq::base64_encoder b64_to_b64{in64, in64.end()}; + oxenmq::base32z_encoder b32z_to_b32z{in32z, in32z.end()}; + oxenmq::hex_encoder hex_to_hex{in16, in16.end()}; + + // Decoding to bytes: + std::string x; + auto xx = std::back_inserter(x); + std::copy(in64, in64.end(), xx); + REQUIRE( x == pk ); + x.clear(); + std::copy(in32z, in32z.end(), xx); + REQUIRE( x == pk ); + x.clear(); + std::copy(in16, in16.end(), xx); + REQUIRE( x == pk ); + + // Transcoding + x.clear(); + std::copy(b64_to_hex, b64_to_hex.end(), xx); + CHECK( x == pk_hex ); + + x.clear(); + std::copy(b64_to_b32z, b64_to_b32z.end(), xx); + CHECK( x == pk_b32z ); + + x.clear(); + std::copy(b64_to_b64, b64_to_b64.end(), xx); + CHECK( x == pk_b64 ); + + x.clear(); + std::copy(b32z_to_hex, b32z_to_hex.end(), xx); + CHECK( x == pk_hex ); + + x.clear(); + std::copy(b32z_to_b32z, b32z_to_b32z.end(), xx); + CHECK( x == pk_b32z ); + + x.clear(); + std::copy(b32z_to_b64, b32z_to_b64.end(), xx); + CHECK( x == pk_b64 ); + + x.clear(); + std::copy(hex_to_hex, hex_to_hex.end(), xx); + CHECK( x == pk_hex ); + + x.clear(); + std::copy(hex_to_b32z, hex_to_b32z.end(), xx); + CHECK( x == pk_b32z ); + + x.clear(); + std::copy(hex_to_b64, hex_to_b64.end(), xx); + CHECK( x == pk_b64 ); + + // Make a big chain of conversions + oxenmq::base32z_encoder it1{in64, in64.end()}; + oxenmq::base32z_decoder it2{it1, it1.end()}; + oxenmq::base64_encoder it3{it2, it2.end()}; + oxenmq::base64_decoder it4{it3, it3.end()}; + oxenmq::hex_encoder it5{it4, it4.end()}; + x.clear(); + std::copy(it5, it5.end(), xx); + CHECK( x == pk_hex ); + + // No-padding b64 encoding: + oxenmq::base64_encoder b64_nopad{pk.begin(), pk.end(), false}; + x.clear(); + std::copy(b64_nopad, b64_nopad.end(), xx); + CHECK( x == pk_b64.substr(0, pk_b64.size()-1) ); } TEST_CASE("std::byte decoding", "[decoding][hex][base32z][base64]") { diff --git a/tests/test_timer.cpp b/tests/test_timer.cpp index 93df350..b51a39b 100644 --- a/tests/test_timer.cpp +++ b/tests/test_timer.cpp @@ -15,9 +15,10 @@ TEST_CASE("timer test", "[timer][basic]") { auto start = std::chrono::steady_clock::now(); wait_for([&] { return ticks.load() > 3; }); { + auto elapsed_ms = std::chrono::duration_cast(std::chrono::steady_clock::now() - start).count(); auto lock = catch_lock(); REQUIRE( ticks.load() > 3 ); - REQUIRE( std::chrono::steady_clock::now() - start < 40ms ); + REQUIRE( elapsed_ms < 50 * TIME_DILATION ); } } @@ -35,13 +36,13 @@ TEST_CASE("timer squelch", "[timer][squelch]") { // finishes, by which point we set `done` and so should get exactly 1 tick. auto timer = omq.add_timer([&] { if (first.exchange(false)) { - std::this_thread::sleep_for(30ms); + std::this_thread::sleep_for(30ms * TIME_DILATION); ticks++; done = true; } else if (!done) { ticks++; } - }, 5ms, true /* squelch */); + }, 5ms * TIME_DILATION, true /* squelch */); omq.start(); wait_for([&] { return done.load(); }); @@ -58,7 +59,7 @@ TEST_CASE("timer squelch", "[timer][squelch]") { std::atomic ticks2 = 0; auto timer2 = omq.add_timer([&] { if (first2.exchange(false)) { - std::this_thread::sleep_for(30ms); + std::this_thread::sleep_for(40ms * TIME_DILATION); done2 = true; } else if (!done2) { ticks2++; @@ -82,13 +83,13 @@ TEST_CASE("timer cancel", "[timer][cancel]") { std::atomic ticks = 0; // We set up *and cancel* this timer before omq starts, so it should never fire - auto notimer = omq.add_timer([&] { ticks += 1000; }, 5ms); + auto notimer = omq.add_timer([&] { ticks += 1000; }, 5ms * TIME_DILATION); omq.cancel_timer(notimer); TimerID timer = omq.add_timer([&] { if (++ticks == 3) omq.cancel_timer(timer); - }, 5ms); + }, 5ms * TIME_DILATION); omq.start();