diff --git a/CMakeLists.txt b/CMakeLists.txt index 01dde97..247a127 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,65 +1,83 @@ + +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +find_program(CCACHE_PROGRAM ccache) +if(CCACHE_PROGRAM) + foreach(lang C CXX) + if(NOT DEFINED CMAKE_${lang}_COMPILER_LAUNCHER AND NOT CMAKE_${lang}_COMPILER MATCHES ".*/ccache") + message(STATUS "Enabling ccache for ${lang}") + set(CMAKE_${lang}_COMPILER_LAUNCHER ${CCACHE_PROGRAM} CACHE STRING "") + endif() + endforeach() +endif() + cmake_minimum_required(VERSION 3.7) # Has to be set before `project()`, and ignored on non-macos: set(CMAKE_OSX_DEPLOYMENT_TARGET 10.12 CACHE STRING "macOS deployment target (Apple clang only)") -project(liblokimq CXX C) +project(liboxenmq CXX C) include(GNUInstallDirs) -set(LOKIMQ_VERSION_MAJOR 1) -set(LOKIMQ_VERSION_MINOR 2) -set(LOKIMQ_VERSION_PATCH 2) -set(LOKIMQ_VERSION "${LOKIMQ_VERSION_MAJOR}.${LOKIMQ_VERSION_MINOR}.${LOKIMQ_VERSION_PATCH}") -message(STATUS "lokimq v${LOKIMQ_VERSION}") +set(OXENMQ_VERSION_MAJOR 1) +set(OXENMQ_VERSION_MINOR 2) +set(OXENMQ_VERSION_PATCH 3) +set(OXENMQ_VERSION "${OXENMQ_VERSION_MAJOR}.${OXENMQ_VERSION_MINOR}.${OXENMQ_VERSION_PATCH}") +message(STATUS "oxenmq v${OXENMQ_VERSION}") -set(LOKIMQ_LIBVERSION 0) +set(OXENMQ_LIBVERSION 0) if(CMAKE_PROJECT_NAME STREQUAL PROJECT_NAME) - set(lokimq_IS_TOPLEVEL_PROJECT TRUE) + set(oxenmq_IS_TOPLEVEL_PROJECT TRUE) else() - set(lokimq_IS_TOPLEVEL_PROJECT FALSE) + set(oxenmq_IS_TOPLEVEL_PROJECT FALSE) endif() option(BUILD_SHARED_LIBS "Build shared libraries instead of static ones" ON) -set(lokimq_INSTALL_DEFAULT OFF) -if(BUILD_SHARED_LIBS OR lokimq_IS_TOPLEVEL_PROJECT) - set(lokimq_INSTALL_DEFAULT ON) +set(oxenmq_INSTALL_DEFAULT OFF) +if(BUILD_SHARED_LIBS OR oxenmq_IS_TOPLEVEL_PROJECT) + set(oxenmq_INSTALL_DEFAULT ON) endif() -option(LOKIMQ_BUILD_TESTS "Building and perform lokimq tests" ${lokimq_IS_TOPLEVEL_PROJECT}) -option(LOKIMQ_INSTALL "Add lokimq libraries and headers to cmake install target; defaults to ON if BUILD_SHARED_LIBS is enabled or we are the top-level project; OFF for a static subdirectory build" ${lokimq_INSTALL_DEFAULT}) -option(LOKIMQ_INSTALL_CPPZMQ "Install cppzmq header with lokimq/ headers (requires LOKIMQ_INSTALL)" ON) +option(OXENMQ_BUILD_TESTS "Building and perform oxenmq tests" ${oxenmq_IS_TOPLEVEL_PROJECT}) +option(OXENMQ_INSTALL "Add oxenmq libraries and headers to cmake install target; defaults to ON if BUILD_SHARED_LIBS is enabled or we are the top-level project; OFF for a static subdirectory build" ${oxenmq_INSTALL_DEFAULT}) +option(OXENMQ_INSTALL_CPPZMQ "Install cppzmq header with oxenmq/ headers (requires OXENMQ_INSTALL)" ON) +option(OXENMQ_LOKIMQ_COMPAT "Install lokimq compatibility headers and pkg-config for rename migration" ON) list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake") -configure_file(lokimq/version.h.in lokimq/version.h @ONLY) -configure_file(liblokimq.pc.in liblokimq.pc @ONLY) +configure_file(oxenmq/version.h.in oxenmq/version.h @ONLY) +configure_file(liboxenmq.pc.in liboxenmq.pc @ONLY) +if(OXENMQ_LOKIMQ_COMPAT) + configure_file(liblokimq.pc.in liblokimq.pc @ONLY) +endif() -add_library(lokimq - lokimq/address.cpp - lokimq/auth.cpp - lokimq/bt_serialize.cpp - lokimq/connections.cpp - lokimq/jobs.cpp - lokimq/lokimq.cpp - lokimq/proxy.cpp - lokimq/worker.cpp + +add_library(oxenmq + oxenmq/address.cpp + oxenmq/auth.cpp + oxenmq/bt_serialize.cpp + oxenmq/connections.cpp + oxenmq/jobs.cpp + oxenmq/oxenmq.cpp + oxenmq/proxy.cpp + oxenmq/worker.cpp ) -set_target_properties(lokimq PROPERTIES SOVERSION ${LOKIMQ_LIBVERSION}) +set_target_properties(oxenmq PROPERTIES SOVERSION ${OXENMQ_LIBVERSION}) set(THREADS_PREFER_PTHREAD_FLAG ON) find_package(Threads REQUIRED) -target_link_libraries(lokimq PRIVATE Threads::Threads) +target_link_libraries(oxenmq PRIVATE Threads::Threads) # libzmq is nearly impossible to link statically from a system-installed static library: it depends # on a ton of other libraries, some of which are not all statically available. If the caller wants # to mess with this, so be it: they can set up a libzmq target and we'll use it. Otherwise if they # asked us to do things statically, don't even try to find a system lib and just build it. -set(lokimq_build_static_libzmq OFF) +set(oxenmq_build_static_libzmq OFF) if(TARGET libzmq) - target_link_libraries(lokimq PUBLIC libzmq) + target_link_libraries(oxenmq PUBLIC libzmq) elseif(BUILD_SHARED_LIBS) include(FindPkgConfig) pkg_check_modules(libzmq libzmq>=4.3 IMPORTED_TARGET) @@ -75,30 +93,30 @@ elseif(BUILD_SHARED_LIBS) set_property(TARGET PkgConfig::libzmq PROPERTY INTERFACE_INCLUDE_DIRECTORIES ${zmq_inc}) endif() - target_link_libraries(lokimq PUBLIC PkgConfig::libzmq) + target_link_libraries(oxenmq PUBLIC PkgConfig::libzmq) else() - set(lokimq_build_static_libzmq ON) + set(oxenmq_build_static_libzmq ON) endif() else() - set(lokimq_build_static_libzmq ON) + set(oxenmq_build_static_libzmq ON) endif() -if(lokimq_build_static_libzmq) +if(oxenmq_build_static_libzmq) message(STATUS "libzmq >= 4.3 not found or static build requested, building bundled version") list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/local-libzmq") include(LocalLibzmq) - target_link_libraries(lokimq PUBLIC libzmq_vendor) + target_link_libraries(oxenmq PUBLIC libzmq_vendor) endif() -target_include_directories(lokimq +target_include_directories(oxenmq PUBLIC $ $ $ ) -target_compile_options(lokimq PRIVATE -Wall -Wextra -Werror) -set_target_properties(lokimq PROPERTIES +target_compile_options(oxenmq PRIVATE -Wall -Wextra -Werror) +set_target_properties(oxenmq PROPERTIES CXX_STANDARD 17 CXX_STANDARD_REQUIRED ON CXX_EXTENSIONS OFF @@ -117,8 +135,8 @@ endfunction() # If the caller has already set up a sodium target then we will just link to it, otherwise we go # looking for it. if(TARGET sodium) - target_link_libraries(lokimq PUBLIC sodium) - if(lokimq_build_static_libzmq) + target_link_libraries(oxenmq PUBLIC sodium) + if(oxenmq_build_static_libzmq) target_link_libraries(libzmq_vendor INTERFACE sodium) endif() else() @@ -126,67 +144,95 @@ else() pkg_check_modules(sodium REQUIRED libsodium IMPORTED_TARGET) if(BUILD_SHARED_LIBS) - target_link_libraries(lokimq PUBLIC PkgConfig::sodium) - if(lokimq_build_static_libzmq) + target_link_libraries(oxenmq PUBLIC PkgConfig::sodium) + if(oxenmq_build_static_libzmq) target_link_libraries(libzmq_vendor INTERFACE PkgConfig::sodium) endif() else() - link_dep_libs(lokimq PUBLIC "${sodium_STATIC_LIBRARY_DIRS}" ${sodium_STATIC_LIBRARIES}) - target_include_directories(lokimq PUBLIC ${sodium_STATIC_INCLUDE_DIRS}) - if(lokimq_build_static_libzmq) + link_dep_libs(oxenmq PUBLIC "${sodium_STATIC_LIBRARY_DIRS}" ${sodium_STATIC_LIBRARIES}) + target_include_directories(oxenmq PUBLIC ${sodium_STATIC_INCLUDE_DIRS}) + if(oxenmq_build_static_libzmq) link_dep_libs(libzmq_vendor INTERFACE "${sodium_STATIC_LIBRARY_DIRS}" ${sodium_STATIC_LIBRARIES}) target_link_libraries(libzmq_vendor INTERFACE ${sodium_STATIC_INCLUDE_DIRS}) endif() endif() endif() -add_library(lokimq::lokimq ALIAS lokimq) +add_library(oxenmq::oxenmq ALIAS oxenmq) +if(OXENMQ_LOKIMQ_COMPAT) + add_library(lokimq::lokimq ALIAS oxenmq) +endif() export( - TARGETS lokimq - NAMESPACE lokimq:: - FILE lokimqTargets.cmake + TARGETS oxenmq + NAMESPACE oxenmq:: + FILE oxenmqTargets.cmake ) -if(LOKIMQ_INSTALL) +if(OXENMQ_INSTALL) install( - TARGETS lokimq - EXPORT lokimqConfig + TARGETS oxenmq + EXPORT oxenmqConfig DESTINATION ${CMAKE_INSTALL_LIBDIR} ) install( - FILES lokimq/address.h - lokimq/auth.h - lokimq/base32z.h - lokimq/base64.h - lokimq/batch.h - lokimq/bt_serialize.h - lokimq/bt_value.h - lokimq/connections.h - lokimq/hex.h - lokimq/lokimq.h - lokimq/message.h - lokimq/string_view.h - lokimq/variant.h - ${CMAKE_CURRENT_BINARY_DIR}/lokimq/version.h - DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/lokimq + FILES oxenmq/address.h + oxenmq/auth.h + oxenmq/base32z.h + oxenmq/base64.h + oxenmq/batch.h + oxenmq/bt_serialize.h + oxenmq/bt_value.h + oxenmq/connections.h + oxenmq/hex.h + oxenmq/oxenmq.h + oxenmq/message.h + oxenmq/variant.h + ${CMAKE_CURRENT_BINARY_DIR}/oxenmq/version.h + DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/oxenmq ) - if(LOKIMQ_INSTALL_CPPZMQ) + if(OXENMQ_INSTALL_CPPZMQ) install( FILES cppzmq/zmq.hpp - DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/lokimq + DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/oxenmq ) endif() install( - FILES ${CMAKE_CURRENT_BINARY_DIR}/liblokimq.pc + FILES ${CMAKE_CURRENT_BINARY_DIR}/liboxenmq.pc DESTINATION ${CMAKE_INSTALL_LIBDIR}/pkgconfig ) + + if(OXENMQ_LOKIMQ_COMPAT) + install( + FILES lokimq/address.h + lokimq/auth.h + lokimq/base32z.h + lokimq/base64.h + lokimq/batch.h + lokimq/bt_serialize.h + lokimq/bt_value.h + lokimq/connections.h + lokimq/hex.h + lokimq/lokimq.h + lokimq/message.h + lokimq/variant.h + lokimq/version.h + DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/lokimq + ) + + install( + FILES ${CMAKE_CURRENT_BINARY_DIR}/liblokimq.pc + DESTINATION ${CMAKE_INSTALL_LIBDIR}/pkgconfig + ) + endif() + + endif() -if(LOKIMQ_BUILD_TESTS) +if(OXENMQ_BUILD_TESTS) add_subdirectory(tests) endif() diff --git a/liblokimq.pc.in b/liblokimq.pc.in index df5a76f..3c707a6 100644 --- a/liblokimq.pc.in +++ b/liblokimq.pc.in @@ -4,10 +4,10 @@ libdir=@CMAKE_INSTALL_FULL_LIBDIR@ includedir=@CMAKE_INSTALL_FULL_INCLUDEDIR@ Name: liblokimq -Description: ZeroMQ-based communication library for Loki -Version: @LOKIMQ_VERSION@ +Description: ZeroMQ-based communication library (compatibility package for liboxenmq) +Version: @OXENMQ_VERSION@ -Libs: -L${libdir} -llokimq +Libs: -L${libdir} -loxenmq Libs.private: @PRIVATE_LIBS@ Requires.private: libzmq libsodium Cflags: -I${includedir} diff --git a/liboxenmq.pc.in b/liboxenmq.pc.in new file mode 100644 index 0000000..0f7c3e1 --- /dev/null +++ b/liboxenmq.pc.in @@ -0,0 +1,13 @@ +prefix=@CMAKE_INSTALL_PREFIX@ +exec_prefix=${prefix} +libdir=@CMAKE_INSTALL_FULL_LIBDIR@ +includedir=@CMAKE_INSTALL_FULL_INCLUDEDIR@ + +Name: liboxenmq +Description: ZeroMQ-based communication library +Version: @OXENMQ_VERSION@ + +Libs: -L${libdir} -loxenmq +Libs.private: @PRIVATE_LIBS@ +Requires.private: libzmq libsodium +Cflags: -I${includedir} diff --git a/lokimq/address.h b/lokimq/address.h index 6c05888..5239b39 100644 --- a/lokimq/address.h +++ b/lokimq/address.h @@ -1,210 +1,4 @@ -// Copyright (c) 2020, The Loki Project -// -// All rights reserved. -// -// Redistribution and use in source and binary forms, with or without modification, are -// permitted provided that the following conditions are met: -// -// 1. Redistributions of source code must retain the above copyright notice, this list of -// conditions and the following disclaimer. -// -// 2. Redistributions in binary form must reproduce the above copyright notice, this list -// of conditions and the following disclaimer in the documentation and/or other -// materials provided with the distribution. -// -// 3. Neither the name of the copyright holder nor the names of its contributors may be -// used to endorse or promote products derived from this software without specific -// prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY -// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF -// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL -// THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, -// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF -// THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - #pragma once -#include -#include -#include -#include +#include "../oxenmq/address.h" -namespace lokimq { - -using namespace std::literals; - -/** LokiMQ address abstraction class. This class uses and extends standard ZMQ addresses allowing - * extra parameters to be passed in in a relative standard way. - * - * External ZMQ addresses generally have two forms that we are concerned with: one for TCP and one - * for Unix sockets: - * - * tcp://HOST:PORT -- HOST can be a hostname, IPv4 address, or IPv6 address in [...] - * ipc://PATH -- PATH can be absolute (ipc:///path/to/some.sock) or relative (ipc://some.sock) - * - * but this doesn't carry enough info: in particular, we can connect with two very different - * protocols: curve25519-encrypted, or plaintext, but for curve25519-encrypted we require the - * remote's public key as well to verify the connection. - * - * This class, then, handles this by allowing addresses of: - * - * Standard ZMQ address: these carry no pubkey and so the connection will be unencrypted: - * - * tcp://HOSTNAME:PORT - * ipc://PATH - * - * Non-ZMQ address formats that specify that the connection shall be x25519 encrypted: - * - * curve://HOSTNAME:PORT/PUBKEY -- PUBKEY must be specified in hex (64 characters), base32z (52) - * or base64 (43 or 44 with one '=' trailing padding) - * ipc+curve:///path/to/my.sock/PUBKEY -- same requirements on PUBKEY as above. - * tcp+curve://(whatever) -- alias for curve://(whatever) - * - * We also accept special upper-case TCP-only variants which *only* accept uppercase characters and - * a few required symbols (:, /, $, ., and -) in the string: - * - * TCP://HOSTNAME:PORT - * CURVE://HOSTNAME:PORT/B32ZPUBKEY - * - * These versions are explicitly meant to be used with QR codes; the upper-case-only requirement - * allows a smaller QR code by allowing QR's alphanumeric mode (which allows only [A-Z0-9 $%*+./:-]) - * to be used. Such a QR-friendly address can be created from the qr_address() method. To support - * literal IPv6 addresses we surround the address with $...$ instead of the usual [...]. - * - * Note that this class does very little validate the host argument at all, and no socket path - * validation whatsoever. The only constraint on host is when parsing an encoded address: we check - * that it contains no : at all, or must be a [bracketed] expression that contains only hex - * characters, :'s, or .'s. Otherwise, if you pass broken crap into the hostname, expect broken - * crap out. - */ -struct address { - /// Supported address protocols: TCP connections (tcp), or unix sockets (ipc). - enum class proto { - tcp, - tcp_curve, - ipc, - ipc_curve - }; - /// Supported public key encodings (used when regenerating an augmented address). - enum class encoding { - hex, ///< hexadecimal encoded - base32z, ///< base32z encoded - base64, ///< base64 encoded (*without* trailing = padding) - BASE32Z ///< upper-case base32z encoding, meant for QR encoding - }; - - /// The protocol: one of the `protocol` enum values for tcp or ipc (unix sockets), with or - /// without _curve encryption. - proto protocol = proto::tcp; - /// The host for tcp connections; can be a hostname or IP address. If this is an IPv6 it must be surrounded with [ ]. - std::string host; - /// The port (for tcp connections) - uint16_t port = 0; - /// The socket path (for unix socket connections) - std::string socket; - /// If a curve connection, this is the required remote public key (in bytes) - std::string pubkey; - - /// Default constructor; this gives you an unusable address. - address() = default; - - /** - * Constructs an address by parsing a string_view containing one of the formats listed in the - * class description. This is intentionally implicitly constructible so that you can pass a - * string_view into anything expecting an `address`. - * - * Throw std::invalid_argument if the given address is not parseable. - */ - address(std::string_view addr); - - /** Constructs an address from a remote string and a separate pubkey. Typically `remote` is a - * basic ZMQ connect string, though this is not enforced. Any pubkey information embedded in - * the remote string will be discarded and replaced with the given pubkey string. The result - * will be curve encrypted if `pubkey` is non-empty, plaintext if `pubkey` is empty. - * - * Throws an exception if either addr or pubkey is invalid. - * - * Exactly equivalent to `address a{remote}; a.set_pubkey(pubkey);` - */ - address(std::string_view addr, std::string_view pubkey) : address(addr) { set_pubkey(pubkey); } - - /// Replaces the address's pubkey (if any) with the given pubkey (or no pubkey if empty). If - /// changing from pubkey to no-pubkey or no-pubkey to pubkey then the protocol is update to - /// switch to or from curve encryption. - /// - /// pubkey should be the 32-byte binary pubkey, or an empty string to remove an existing pubkey. - /// - /// Returns the object itself, so that you can chain it. - address& set_pubkey(std::string_view pubkey); - - /// Constructs and builds the ZMQ connection address from the stored connection details. This - /// does not contain any of the curve-related details; those must be specified separately when - /// interfacing with ZMQ. - std::string zmq_address() const; - - /// Returns true if the connection was specified as a curve-encryption-enabled connection, false - /// otherwise. - bool curve() const { return protocol == proto::tcp_curve || protocol == proto::ipc_curve; } - - /// True if the protocol is TCP (either with or without curve) - bool tcp() const { return protocol == proto::tcp || protocol == proto::tcp_curve; } - - /// True if the protocol is unix socket (either with or without curve) - bool ipc() const { return !tcp(); } - - /// Returns the full "augmented" address string (i.e. that could be passed in to the - /// constructor). This will be equivalent (but not necessarily identical) to an augmented - /// string passed into the constructor. Takes an optional encoding format for the pubkey (if - /// any), which defaults to base32z. - std::string full_address(encoding enc = encoding::base32z) const; - - /// Returns a QR-code friendly address string. This returns an all-uppercase version of the - /// address with "TCP://" or "CURVE://" for the protocol string, and uses upper-case base32z - /// encoding for the pubkey (for curve addresses). For literal IPv6 addresses we replace the - /// surround the - /// address with $ instead of $ - /// - /// \throws std::logic_error if called on a unix socket address. - std::string qr_address() const; - - /// Returns `.pubkey` but encoded in the given format - std::string encode_pubkey(encoding enc) const; - - /// Returns true if two addresses are identical (i.e. same protocol and relevant protocol - /// arguments). - /// - /// Note that it is possible for addresses to connect to the same socket without being - /// identical: for example, using "foo.sock" and "./foo.sock", or writing IPv6 addresses (or - /// even IPv4 addresses) in slightly different ways). Such equivalent but non-equal values will - /// result in a false return here. - /// - /// Note also that we ignore irrelevant arguments: for example, we don't care whether pubkeys - /// match when comparing two non-curve TCP addresses. - bool operator==(const address& other) const; - /// Negation of == - bool operator!=(const address& other) const { return !operator==(other); } - - /// Factory function that constructs a TCP address from a host and port. The connection will be - /// plaintext. If the host is an IPv6 address it *must* be surrounded with [ and ]. - static address tcp(std::string host, uint16_t port); - - /// Factory function that constructs a curve-encrypted TCP address from a host, port, and remote - /// pubkey. The pubkey must be 32 bytes. As above, IPv6 addresses must be specified as [addr]. - static address tcp_curve(std::string host, uint16_t, std::string pubkey); - - /// Factory function that constructs a unix socket address from a path. The connection will be - /// plaintext (which is usually fine for a socket since unix sockets are local machine). - static address ipc(std::string path); - - /// Factory function that constructs a unix socket address from a path and remote pubkey. The - /// connection will be curve25519 encrypted; the remote pubkey must be 32 bytes. - static address ipc_curve(std::string path, std::string pubkey); -}; - -// Outputs address.full_address() when sent to an ostream. -std::ostream& operator<<(std::ostream& o, const address& a); - -} +namespace lokimq = oxenmq; diff --git a/lokimq/auth.h b/lokimq/auth.h index 431998f..d6c38f6 100644 --- a/lokimq/auth.h +++ b/lokimq/auth.h @@ -1,55 +1,4 @@ #pragma once -#include -#include -#include -#include +#include "../oxenmq/auth.h" -namespace lokimq { - -/// Authentication levels for command categories and connections -enum class AuthLevel { - denied, ///< Not actually an auth level, but can be returned by the AllowFunc to deny an incoming connection. - none, ///< No authentication at all; any random incoming ZMQ connection can invoke this command. - basic, ///< Basic authentication commands require a login, or a node that is specifically configured to be a public node (e.g. for public RPC). - admin, ///< Advanced authentication commands require an admin user, either via explicit login or by implicit login from localhost. This typically protects administrative commands like shutting down, starting mining, or access sensitive data. -}; - -std::ostream& operator<<(std::ostream& os, AuthLevel a); - -/// The access level for a command category -struct Access { - /// Minimum access level required - AuthLevel auth; - /// If true only remote SNs may call the category commands - bool remote_sn; - /// If true the category requires that the local node is a SN - bool local_sn; - - /// Constructor. Intentionally allows implicit conversion from an AuthLevel so that an - /// AuthLevel can be passed anywhere an Access is required (the resulting Access will have both - /// remote and local sn set to false). - Access(AuthLevel auth = AuthLevel::none, bool remote_sn = false, bool local_sn = false) - : auth{auth}, remote_sn{remote_sn}, local_sn{local_sn} {} -}; - -/// Simple hash implementation for a string that is *already* a hash-like value (such as a pubkey). -/// Falls back to std::hash if given a string smaller than a size_t. -struct already_hashed { - size_t operator()(const std::string& s) const { - if (s.size() < sizeof(size_t)) - return std::hash{}(s); - size_t hash; - std::memcpy(&hash, &s[0], sizeof(hash)); - return hash; - } -}; - -/// std::unordered_set specialization for specifying pubkeys (used, in particular, by -/// LokiMQ::set_active_sns and LokiMQ::update_active_sns); this is a std::string unordered_set that -/// also uses a specialized trivial hash function that uses part of the value itself (i.e. the -/// pubkey) directly as a hash value. (This is nice and fast for uniformly distributed values like -/// pubkeys and a terrible hash choice for anything else). -using pubkey_set = std::unordered_set; - - -} +namespace lokimq = oxenmq; diff --git a/lokimq/base32z.h b/lokimq/base32z.h index d0227d3..0f40acf 100644 --- a/lokimq/base32z.h +++ b/lokimq/base32z.h @@ -1,203 +1,4 @@ -// Copyright (c) 2019-2020, The Loki Project -// -// All rights reserved. -// -// Redistribution and use in source and binary forms, with or without modification, are -// permitted provided that the following conditions are met: -// -// 1. Redistributions of source code must retain the above copyright notice, this list of -// conditions and the following disclaimer. -// -// 2. Redistributions in binary form must reproduce the above copyright notice, this list -// of conditions and the following disclaimer in the documentation and/or other -// materials provided with the distribution. -// -// 3. Neither the name of the copyright holder nor the names of its contributors may be -// used to endorse or promote products derived from this software without specific -// prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY -// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF -// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL -// THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, -// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF -// THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - #pragma once -#include -#include -#include -#include -#include +#include "../oxenmq/base32z.h" -namespace lokimq { - -namespace detail { - -/// Compile-time generated lookup tables for base32z conversion. This is case insensitive (though -/// for byte -> b32z conversion we always produce lower case). -struct b32z_table { - // Store the 0-31 decoded value of every possible char; all the chars that aren't valid are set - // to 0. (If you don't trust your data, check it with is_base32z first, which uses these 0's - // to detect invalid characters -- which is why we want a full 256 element array). - char from_b32z_lut[256]; - // Store the encoded character of every 0-31 (5 bit) value. - char to_b32z_lut[32]; - - // constexpr constructor that fills out the above (and should do it at compile time for any half - // decent compiler). - constexpr b32z_table() noexcept : from_b32z_lut{}, - to_b32z_lut{ - 'y', 'b', 'n', 'd', 'r', 'f', 'g', '8', 'e', 'j', 'k', 'm', 'c', 'p', 'q', 'x', - 'o', 't', '1', 'u', 'w', 'i', 's', 'z', 'a', '3', '4', '5', 'h', '7', '6', '9' - } - { - for (unsigned char c = 0; c < 32; c++) { - unsigned char x = to_b32z_lut[c]; - from_b32z_lut[x] = c; - if (x >= 'a' && x <= 'z') - from_b32z_lut[x - 'a' + 'A'] = c; - } - } - // Convert a b32z encoded character into a 0-31 value - constexpr char from_b32z(unsigned char c) const noexcept { return from_b32z_lut[c]; } - // Convert a 0-31 value into a b32z encoded character - constexpr char to_b32z(unsigned char b) const noexcept { return to_b32z_lut[b]; } -} constexpr b32z_lut; - -// This main point of this static assert is to force the compiler to compile-time build the constexpr tables. -static_assert(b32z_lut.from_b32z('w') == 20 && b32z_lut.from_b32z('T') == 17 && b32z_lut.to_b32z(5) == 'f', ""); - -} // namespace detail - -/// Converts bytes into a base32z encoded character sequence. -template -void to_base32z(InputIt begin, InputIt end, OutputIt out) { - static_assert(sizeof(decltype(*begin)) == 1, "to_base32z requires chars/bytes"); - int bits = 0; // Tracks the number of unconsumed bits held in r, will always be in [0, 4] - std::uint_fast16_t r = 0; - while (begin != end) { - r = r << 8 | static_cast(*begin++); - - // we just added 8 bits, so we can *always* consume 5 to produce one character, so (net) we - // are adding 3 bits. - bits += 3; - *out++ = detail::b32z_lut.to_b32z(r >> bits); // Right-shift off the bits we aren't consuming right now - - // Drop the bits we don't want to keep (because we just consumed them) - r &= (1 << bits) - 1; - - if (bits >= 5) { // We have enough bits to produce a second character; essentially the same as above - bits -= 5; // Except now we are just consuming 5 without having added any more - *out++ = detail::b32z_lut.to_b32z(r >> bits); - r &= (1 << bits) - 1; - } - } - - if (bits > 0) // We hit the end, but still have some unconsumed bits so need one final character to append - *out++ = detail::b32z_lut.to_b32z(r << (5 - bits)); -} - -/// Creates a base32z string from an iterator pair of a byte sequence. -template -std::string to_base32z(It begin, It end) { - std::string base32z; - if constexpr (std::is_base_of_v::iterator_category>) - base32z.reserve((std::distance(begin, end)*8 + 4) / 5); // == bytes*8/5, rounded up. - to_base32z(begin, end, std::back_inserter(base32z)); - return base32z; -} - -/// Creates a base32z string from an iterable, std::string-like object -template -std::string to_base32z(std::basic_string_view s) { return to_base32z(s.begin(), s.end()); } -inline std::string to_base32z(std::string_view s) { return to_base32z<>(s); } - -/// Returns true if all elements in the range are base32z characters -template -constexpr bool is_base32z(It begin, It end) { - static_assert(sizeof(decltype(*begin)) == 1, "is_base32z requires chars/bytes"); - for (; begin != end; ++begin) { - auto c = static_cast(*begin); - if (detail::b32z_lut.from_b32z(c) == 0 && !(c == 'y' || c == 'Y')) - return false; - } - return true; -} - -/// Returns true if all elements in the string-like value are base32z characters -template -constexpr bool is_base32z(std::basic_string_view s) { return is_base32z(s.begin(), s.end()); } -constexpr bool is_base32z(std::string_view s) { return is_base32z<>(s); } - -/// Converts a sequence of base32z digits to bytes. Undefined behaviour if any characters are not -/// valid base32z alphabet characters. It is permitted for the input and output ranges to overlap -/// as long as `out` is no earlier than `begin`. Note that if you pass in a sequence that could not -/// have been created by a base32z encoding of a byte sequence, we treat the excess bits as if they -/// were not provided. -/// -/// For example, "yyy" represents a 15-bit value, but a byte sequence is either 8-bit (requiring 2 -/// characters) or 16-bit (requiring 4). Similarly, "yb" is an impossible encoding because it has -/// its 10th bit set (b = 00001), but a base32z encoded value should have all 0's beyond the 8th (or -/// 16th or 24th or ... bit). We treat any such bits as if they were not specified (even if they -/// are): which means "yy", "yb", "yyy", "yy9", "yd", etc. all decode to the same 1-byte value "\0". -template -void from_base32z(InputIt begin, InputIt end, OutputIt out) { - static_assert(sizeof(decltype(*begin)) == 1, "from_base32z requires chars/bytes"); - uint_fast16_t curr = 0; - int bits = 0; // number of bits we've loaded into val; we always keep this < 8. - while (begin != end) { - curr = curr << 5 | detail::b32z_lut.from_b32z(static_cast(*begin++)); - if (bits >= 3) { - bits -= 3; // Added 5, removing 8 - *out++ = static_cast(curr >> bits); - curr &= (1 << bits) - 1; - } else { - bits += 5; - } - } - - // Ignore any trailing bits. base32z encoding always has at least as many bits as the source - // bytes, which means we should not be able to get here from a properly encoded b32z value with - // anything other than 0s: if we have no extra bits (e.g. 5 bytes == 8 b32z chars) then we have - // a 0-bit value; if we have some extra bits (e.g. 6 bytes requires 10 b32z chars, but that - // contains 50 bits > 48 bits) then those extra bits will be 0s (and this covers the bits -= 3 - // case above: it'll leave us with 0-4 extra bits, but those extra bits would be 0 if produced - // from an actual byte sequence). - // - // The "bits += 5" case, then, means that we could end with 5-7 bits. This, however, cannot be - // produced by a valid encoding: - // - 0 bytes gives us 0 chars with 0 leftover bits - // - 1 byte gives us 2 chars with 2 leftover bits - // - 2 bytes gives us 4 chars with 4 leftover bits - // - 3 bytes gives us 5 chars with 1 leftover bit - // - 4 bytes gives us 7 chars with 3 leftover bits - // - 5 bytes gives us 8 chars with 0 leftover bits (this is where the cycle repeats) - // - // So really the only way we can get 5-7 leftover bits is if you took a 0, 2 or 5 char output (or - // any 8n + {0,2,5} char output) and added a base32z character to the end. If you do that, - // well, too bad: you're giving invalid output and so we're just going to pretend that extra - // character you added isn't there by not doing anything here. -} - -/// Convert a base32z sequence into a std::string of bytes. Undefined behaviour if any characters -/// are not valid (case-insensitive) base32z characters. -template -std::string from_base32z(It begin, It end) { - std::string bytes; - if constexpr (std::is_base_of_v::iterator_category>) - bytes.reserve((std::distance(begin, end)*5 + 7) / 8); // == chars*5/8, rounded up. - from_base32z(begin, end, std::back_inserter(bytes)); - return bytes; -} - -/// Converts base32z digits from a std::string-like object into a std::string of bytes. Undefined -/// behaviour if any characters are not valid (case-insensitive) base32z characters. -template -std::string from_base32z(std::basic_string_view s) { return from_base32z(s.begin(), s.end()); } -inline std::string from_base32z(std::string_view s) { return from_base32z<>(s); } - -} +namespace lokimq = oxenmq; diff --git a/lokimq/base64.h b/lokimq/base64.h index 172c4cc..36f2c7d 100644 --- a/lokimq/base64.h +++ b/lokimq/base64.h @@ -1,219 +1,4 @@ -// Copyright (c) 2019-2020, The Loki Project -// -// All rights reserved. -// -// Redistribution and use in source and binary forms, with or without modification, are -// permitted provided that the following conditions are met: -// -// 1. Redistributions of source code must retain the above copyright notice, this list of -// conditions and the following disclaimer. -// -// 2. Redistributions in binary form must reproduce the above copyright notice, this list -// of conditions and the following disclaimer in the documentation and/or other -// materials provided with the distribution. -// -// 3. Neither the name of the copyright holder nor the names of its contributors may be -// used to endorse or promote products derived from this software without specific -// prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY -// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF -// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL -// THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, -// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF -// THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - #pragma once -#include -#include -#include -#include -#include +#include "../oxenmq/base64.h" -namespace lokimq { - -namespace detail { - -/// Compile-time generated lookup tables for base64 conversion. -struct b64_table { - // Store the 0-63 decoded value of every possible char; all the chars that aren't valid are set - // to 0. (If you don't trust your data, check it with is_base64 first, which uses these 0's - // to detect invalid characters -- which is why we want a full 256 element array). - char from_b64_lut[256]; - // Store the encoded character of every 0-63 (6 bit) value. - char to_b64_lut[64]; - - // constexpr constructor that fills out the above (and should do it at compile time for any half - // decent compiler). - constexpr b64_table() noexcept : from_b64_lut{}, to_b64_lut{} { - for (unsigned char c = 0; c < 26; c++) { - from_b64_lut[(unsigned char)('A' + c)] = 0 + c; - to_b64_lut[ (unsigned char)( 0 + c)] = 'A' + c; - } - for (unsigned char c = 0; c < 26; c++) { - from_b64_lut[(unsigned char)('a' + c)] = 26 + c; - to_b64_lut[ (unsigned char)(26 + c)] = 'a' + c; - } - for (unsigned char c = 0; c < 10; c++) { - from_b64_lut[(unsigned char)('0' + c)] = 52 + c; - to_b64_lut[ (unsigned char)(52 + c)] = '0' + c; - } - to_b64_lut[62] = '+'; from_b64_lut[(unsigned char) '+'] = 62; - to_b64_lut[63] = '/'; from_b64_lut[(unsigned char) '/'] = 63; - } - // Convert a b64 encoded character into a 0-63 value - constexpr char from_b64(unsigned char c) const noexcept { return from_b64_lut[c]; } - // Convert a 0-31 value into a b64 encoded character - constexpr char to_b64(unsigned char b) const noexcept { return to_b64_lut[b]; } -} constexpr b64_lut; - -// This main point of this static assert is to force the compiler to compile-time build the constexpr tables. -static_assert(b64_lut.from_b64('/') == 63 && b64_lut.from_b64('7') == 59 && b64_lut.to_b64(38) == 'm', ""); - -} // namespace detail - -/// Converts bytes into a base64 encoded character sequence. -template -void to_base64(InputIt begin, InputIt end, OutputIt out) { - static_assert(sizeof(decltype(*begin)) == 1, "to_base64 requires chars/bytes"); - int bits = 0; // Tracks the number of unconsumed bits held in r, will always be in {0, 2, 4} - std::uint_fast16_t r = 0; - while (begin != end) { - r = r << 8 | static_cast(*begin++); - - // we just added 8 bits, so we can *always* consume 6 to produce one character, so (net) we - // are adding 2 bits. - bits += 2; - *out++ = detail::b64_lut.to_b64(r >> bits); // Right-shift off the bits we aren't consuming right now - - // Drop the bits we don't want to keep (because we just consumed them) - r &= (1 << bits) - 1; - - if (bits == 6) { // We have enough bits to produce a second character (which means we had 4 before and added 8) - bits = 0; - *out++ = detail::b64_lut.to_b64(r); - r = 0; - } - } - - // If bits == 0 then we ended our 6-bit outputs coinciding with 8-bit values, i.e. at a multiple - // of 24 bits: this means we don't have anything else to output and don't need any padding. - if (bits == 2) { - // We finished with 2 unconsumed bits, which means we ended 1 byte past a 24-bit group (e.g. - // 1 byte, 4 bytes, 301 bytes, etc.); since we need to always be a multiple of 4 output - // characters that means we've produced 1: so we right-fill 0s to get the next char, then - // add two padding ='s. - *out++ = detail::b64_lut.to_b64(r << 4); - *out++ = '='; - *out++ = '='; - } else if (bits == 4) { - // 4 bits left means we produced 2 6-bit values from the first 2 bytes of a 3-byte group. - // Fill 0s to get the last one, plus one padding output. - *out++ = detail::b64_lut.to_b64(r << 2); - *out++ = '='; - } -} - -/// Creates and returns a base64 string from an iterator pair of a character sequence -template -std::string to_base64(It begin, It end) { - std::string base64; - if constexpr (std::is_base_of_v::iterator_category>) - base64.reserve((std::distance(begin, end) + 2) / 3 * 4); // bytes*4/3, rounded up to the next multiple of 4 - to_base64(begin, end, std::back_inserter(base64)); - return base64; -} - -/// Creates a base64 string from an iterable, std::string-like object -template -std::string to_base64(std::basic_string_view s) { return to_base64(s.begin(), s.end()); } -inline std::string to_base64(std::string_view s) { return to_base64<>(s); } - -/// Returns true if the range is a base64 encoded value; we allow (but do not require) '=' padding, -/// but only at the end, only 1 or 2, and only if it pads out the total to a multiple of 4. -template -constexpr bool is_base64(It begin, It end) { - static_assert(sizeof(decltype(*begin)) == 1, "is_base64 requires chars/bytes"); - using std::distance; - using std::prev; - - // Allow 1 or 2 padding chars *if* they pad it to a multiple of 4. - if (begin != end && distance(begin, end) % 4 == 0) { - auto last = prev(end); - if (static_cast(*last) == '=') - end = last--; - if (static_cast(*last) == '=') - end = last; - } - - for (; begin != end; ++begin) { - auto c = static_cast(*begin); - if (detail::b64_lut.from_b64(c) == 0 && c != 'A') - return false; - } - return true; -} - -/// Returns true if the string-like value is a base64 encoded value -template -constexpr bool is_base64(std::basic_string_view s) { return is_base64(s.begin(), s.end()); } -constexpr bool is_base64(std::string_view s) { return is_base64(s.begin(), s.end()); } - -/// Converts a sequence of base64 digits to bytes. Undefined behaviour if any characters are not -/// valid base64 alphabet characters. It is permitted for the input and output ranges to overlap as -/// long as `out` is no earlier than `begin`. Trailing padding characters are permitted but not -/// required. -/// -/// It is possible to provide "impossible" base64 encoded values; for example "YWJja" which has 30 -/// bits of data even though a base64 encoded byte string should have 24 (4 chars) or 36 (6 chars) -/// bits for a 3- and 4-byte input, respectively. We ignore any such "impossible" bits, and -/// similarly ignore impossible bits in the bit "overhang"; that means "YWJjZA==" (the proper -/// encoding of "abcd") and "YWJjZB", "YWJjZC", ..., "YWJjZP" all decode to the same "abcd" value: -/// the last 4 bits of the last character are essentially considered padding. -template -void from_base64(InputIt begin, InputIt end, OutputIt out) { - static_assert(sizeof(decltype(*begin)) == 1, "from_base64 requires chars/bytes"); - uint_fast16_t curr = 0; - int bits = 0; // number of bits we've loaded into val; we always keep this < 8. - while (begin != end) { - auto c = static_cast(*begin++); - - // padding; don't bother checking if we're at the end because is_base64 is a precondition - // and we're allowed UB if it isn't satisfied. - if (c == '=') continue; - - curr = curr << 6 | detail::b64_lut.from_b64(c); - if (bits == 0) - bits = 6; - else { - bits -= 2; // Added 6, removing 8 - *out++ = static_cast(curr >> bits); - curr &= (1 << bits) - 1; - } - } - // Don't worry about leftover bits because either they have to be 0, or they can't happen at - // all. See base32z.h for why: the reasoning is exactly the same (except using 6 bits per - // character here instead of 5). -} - -/// Converts base64 digits from a iterator pair of characters into a std::string of bytes. -/// Undefined behaviour if any characters are not valid base64 characters. -template -std::string from_base64(It begin, It end) { - std::string bytes; - if constexpr (std::is_base_of_v::iterator_category>) - bytes.reserve(std::distance(begin, end)*6 / 8); // each digit carries 6 bits; this may overallocate by 1-2 bytes due to padding - from_base64(begin, end, std::back_inserter(bytes)); - return bytes; -} - -/// Converts base64 digits from a std::string-like object into a std::string of bytes. Undefined -/// behaviour if any characters are not valid base64 characters. -template -std::string from_base64(std::basic_string_view s) { return from_base64(s.begin(), s.end()); } -inline std::string from_base64(std::string_view s) { return from_base64<>(s); } - -} +namespace lokimq = oxenmq; diff --git a/lokimq/batch.h b/lokimq/batch.h index 2faff6e..bf06f14 100644 --- a/lokimq/batch.h +++ b/lokimq/batch.h @@ -1,279 +1,4 @@ -// Copyright (c) 2020, The Loki Project -// -// All rights reserved. -// -// Redistribution and use in source and binary forms, with or without modification, are -// permitted provided that the following conditions are met: -// -// 1. Redistributions of source code must retain the above copyright notice, this list of -// conditions and the following disclaimer. -// -// 2. Redistributions in binary form must reproduce the above copyright notice, this list -// of conditions and the following disclaimer in the documentation and/or other -// materials provided with the distribution. -// -// 3. Neither the name of the copyright holder nor the names of its contributors may be -// used to endorse or promote products derived from this software without specific -// prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY -// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF -// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL -// THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, -// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF -// THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - #pragma once -#include -#include -#include -#include "lokimq.h" +#include "../oxenmq/batch.h" -namespace lokimq { - -namespace detail { - -enum class BatchState { - running, // there are still jobs to run (or running) - complete, // the batch is complete but still has a completion job to call - done // the batch is complete and has no completion function -}; - -struct BatchStatus { - BatchState state; - int thread; -}; - -// Virtual base class for Batch -class Batch { -public: - // Returns the number of jobs in this batch and whether any of them are thread-specific - virtual std::pair size() const = 0; - // Returns a vector of exactly the same length of size().first containing the tagged thread ids - // of the batch jobs or 0 for general jobs. - virtual std::vector threads() const = 0; - // Called in a worker thread to run the job - virtual void run_job(int i) = 0; - // Called in the main proxy thread when the worker returns from finishing a job. The return - // value tells us whether the current finishing job finishes off the batch: `running` to tell us - // there are more jobs; `complete` to tell us that the jobs are done but the completion function - // needs to be called; and `done` to signal that the jobs are done and there is no completion - // function. - virtual BatchStatus job_finished() = 0; - // Called by a worker; not scheduled until all jobs are done. - virtual void job_completion() = 0; - - virtual ~Batch() = default; -}; - -} - -/** - * Simple class that can either hold a result or an exception and retrieves the result (or raises - * the exception) via a .get() method. - * - * This is designed to be like a very stripped down version of a std::promise/std::future pair. We - * reimplemented it, however, because by ditching all the thread synchronization that promise/future - * guarantees we can substantially reduce call overhead (by a factor of ~8 according to benchmarking - * code). Since LokiMQ's proxy<->worker communication channel already gives us thread that overhead - * would just be wasted. - * - * @tparam R the value type held by the result; must be default constructible. Note, however, that - * there are specializations provided for lvalue references types and `void` (which obviously don't - * satisfy this). - */ -template -class job_result { - R value; - std::exception_ptr exc; - -public: - /// Sets the value. Should be called only once, or not at all if set_exception was called. - void set_value(R&& v) { value = std::move(v); } - - /// Sets the exception, which will be rethrown when `get()` is called. Should be called - /// only once, or not at all if set_value() was called. - void set_exception(std::exception_ptr e) { exc = std::move(e); } - - /// Retrieves the value. If an exception was set instead of a value then that exception is - /// thrown instead. Note that the interval value is moved out of the held value so you should - /// not call this multiple times. - R get() { - if (exc) std::rethrow_exception(exc); - return std::move(value); - } -}; - -/** job_result specialization for reference types */ -template -class job_result::value>> { - std::remove_reference_t* value_ptr; - std::exception_ptr exc; - -public: - void set_value(R v) { value_ptr = &v; } - void set_exception(std::exception_ptr e) { exc = std::move(e); } - R get() { - if (exc) std::rethrow_exception(exc); - return *value_ptr; - } -}; - -/** job_result specialization for void; there is no value, but exceptions are still captured - * (rethrown when `get()` is called). - */ -template<> -class job_result { - std::exception_ptr exc; - -public: - void set_exception(std::exception_ptr e) { exc = std::move(e); } - // Returns nothing, but rethrows if there is a captured exception. - void get() { if (exc) std::rethrow_exception(exc); } -}; - -/// Helper class used to set up batches of jobs to be scheduled via the lokimq job handler. -/// -/// @tparam R - the return type of the individual jobs -/// -template -class Batch final : private detail::Batch { - friend class LokiMQ; -public: - /// The completion function type, called after all jobs have finished. - using CompletionFunc = std::function> results)>; - - // Default constructor - Batch() = default; - - // movable - Batch(Batch&&) = default; - Batch &operator=(Batch&&) = default; - - // non-copyable - Batch(const Batch&) = delete; - Batch &operator=(const Batch&) = delete; - -private: - std::vector, int>> jobs; - std::vector> results; - CompletionFunc complete; - std::size_t jobs_outstanding = 0; - int complete_in_thread = 0; - bool started = false; - bool tagged_thread_jobs = false; - - void check_not_started() { - if (started) - throw std::logic_error("Cannot add jobs or completion function after starting a lokimq::Batch!"); - } - -public: - /// Preallocates space in the internal vector that stores jobs. - void reserve(std::size_t num) { - jobs.reserve(num); - results.reserve(num); - } - - /// Adds a job. This takes any callable object that is invoked with no arguments and returns R - /// (the Batch return type). The tasks will be scheduled and run when the next worker thread is - /// available. The called function may throw exceptions (which will be propagated to the - /// completion function through the job_result values). There is no guarantee on the order of - /// invocation of the jobs. - /// - /// \param job the callback - /// \param thread an optional TaggedThreadID indicating a thread in which this job must run - void add_job(std::function job, std::optional thread = std::nullopt) { - check_not_started(); - if (thread && thread->_id == -1) - // There are some special case internal jobs where we allow this, but they use the - // private method below that doesn't have this check. - throw std::logic_error{"Cannot add a proxy thread batch job -- this makes no sense"}; - add_job(std::move(job), thread ? thread->_id : 0); - } - - /// Sets the completion function to invoke after all jobs have finished. If this is not set - /// then jobs simply run and results are discarded. - /// - /// \param comp - function to call when all jobs have finished - /// \param thread - optional tagged thread in which to schedule the completion job. If not - /// provided then the completion job is scheduled in the pool of batch job threads. - /// - /// `thread` can be provided the value &LokiMQ::run_in_proxy to invoke the completion function - /// *IN THE PROXY THREAD* itself after all jobs have finished. Be very, very careful: this - /// should be a nearly trivial job that does not require any substantial CPU time and does not - /// block for any reason. This is only intended for the case where the completion job is so - /// trivial that it will take less time than simply queuing the job to be executed by another - /// thread. - void completion(CompletionFunc comp, std::optional thread = std::nullopt) { - check_not_started(); - if (complete) - throw std::logic_error("Completion function can only be set once"); - complete = std::move(comp); - complete_in_thread = thread ? thread->_id : 0; - } - -private: - - void add_job(std::function job, int thread_id) { - jobs.emplace_back(std::move(job), thread_id); - results.emplace_back(); - jobs_outstanding++; - if (thread_id != 0) - tagged_thread_jobs = true; - } - - std::pair size() const override { - return {jobs.size(), tagged_thread_jobs}; - } - - std::vector threads() const override { - std::vector t; - t.reserve(jobs.size()); - for (auto& j : jobs) - t.push_back(j.second); - return t; - }; - - template - void set_value(job_result& r, std::function& f) { r.set_value(f()); } - void set_value(job_result&, std::function& f) { f(); } - - void run_job(const int i) override { - // called by worker thread - auto& r = results[i]; - try { - set_value(r, jobs[i].first); - } catch (...) { - r.set_exception(std::current_exception()); - } - } - - detail::BatchStatus job_finished() override { - --jobs_outstanding; - if (jobs_outstanding) - return {detail::BatchState::running, 0}; - if (complete) - return {detail::BatchState::complete, complete_in_thread}; - return {detail::BatchState::done, 0}; - } - - void job_completion() override { - return complete(std::move(results)); - } -}; - - -template -void LokiMQ::batch(Batch&& batch) { - if (batch.size().first == 0) - throw std::logic_error("Cannot batch a a job batch with 0 jobs"); - // Need to send this over to the proxy thread via the base class pointer. It assumes ownership. - auto* baseptr = static_cast(new Batch(std::move(batch))); - detail::send_control(get_control_socket(), "BATCH", bt_serialize(reinterpret_cast(baseptr))); -} - -} +namespace lokimq = oxenmq; diff --git a/lokimq/bt_serialize.h b/lokimq/bt_serialize.h index 786f0ff..1d904a6 100644 --- a/lokimq/bt_serialize.h +++ b/lokimq/bt_serialize.h @@ -1,915 +1,4 @@ -// Copyright (c) 2019-2020, The Loki Project -// -// All rights reserved. -// -// Redistribution and use in source and binary forms, with or without modification, are -// permitted provided that the following conditions are met: -// -// 1. Redistributions of source code must retain the above copyright notice, this list of -// conditions and the following disclaimer. -// -// 2. Redistributions in binary form must reproduce the above copyright notice, this list -// of conditions and the following disclaimer in the documentation and/or other -// materials provided with the distribution. -// -// 3. Neither the name of the copyright holder nor the names of its contributors may be -// used to endorse or promote products derived from this software without specific -// prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY -// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF -// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL -// THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, -// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF -// THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - #pragma once +#include "../oxenmq/bt_serialize.h" -#include -#include -#include -#include -#include -#include -#include -#include "variant.h" -#include -#include -#include -#include -#include -#include -#include - -#include "bt_value.h" - -namespace lokimq { - -using namespace std::literals; - -/** \file - * LokiMQ serialization for internal commands is very simple: we support two primitive types, - * strings and integers, and two container types, lists and dicts with string keys. On the wire - * these go in BitTorrent byte encoding as described in BEP-0003 - * (https://www.bittorrent.org/beps/bep_0003.html#bencoding). - * - * On the C++ side, on input we allow strings, integral types, STL-like containers of these types, - * and STL-like containers of pairs with a string first value and any of these types as second - * value. We also accept std::variants of these. - * - * One minor deviation from BEP-0003 is that we don't support serializing values that don't fit in a - * 64-bit integer (BEP-0003 specifies arbitrary precision integers). - * - * On deserialization we can either deserialize into a special bt_value type supports everything - * (with arbitrary nesting), or we can fill a container of your given type (though this fails if the - * container isn't compatible with the deserialized data). - * - * There is also a stream deserialization that allows you to deserialize without needing heap - * allocations (as long as you know the precise data structure layout). - */ - -/// Exception throw if deserialization fails -class bt_deserialize_invalid : public std::invalid_argument { - using std::invalid_argument::invalid_argument; -}; - -/// A more specific subclass that is thown if the serialization type is an initial mismatch: for -/// example, trying deserializing an int but the next thing in input is a list. This is not, -/// however, thrown if the type initially looks fine but, say, a nested serialization fails. This -/// error will only be thrown when the input stream has not been advanced (and so can be tried for a -/// different type). -class bt_deserialize_invalid_type : public bt_deserialize_invalid { - using bt_deserialize_invalid::bt_deserialize_invalid; -}; - -namespace detail { - -/// Reads digits into an unsigned 64-bit int. -uint64_t extract_unsigned(std::string_view& s); -// (Provide non-constant lvalue and rvalue ref functions so that we only accept explicit -// string_views but not implicitly converted ones) -inline uint64_t extract_unsigned(std::string_view&& s) { return extract_unsigned(s); } - -// Fallback base case; we only get here if none of the partial specializations below work -template -struct bt_serialize { static_assert(!std::is_same_v, "Cannot serialize T: unsupported type for bt serialization"); }; - -template -struct bt_deserialize { static_assert(!std::is_same_v, "Cannot deserialize T: unsupported type for bt deserialization"); }; - -/// Checks that we aren't at the end of a string view and throws if we are. -inline void bt_need_more(const std::string_view &s) { - if (s.empty()) - throw bt_deserialize_invalid{"Unexpected end of string while deserializing"}; -} - -/// Deserializes a signed or unsigned 64-bit integer from a string. Sets the second bool to true -/// iff the value read was negative, false if positive; in either case the unsigned value is return -/// in .first. Throws an exception if the read value doesn't fit in a int64_t (if negative) or a -/// uint64_t (if positive). Removes consumed characters from the string_view. -std::pair bt_deserialize_integer(std::string_view& s); - -/// Integer specializations -template -struct bt_serialize>> { - static_assert(sizeof(T) <= sizeof(uint64_t), "Serialization of integers larger than uint64_t is not supported"); - void operator()(std::ostream &os, const T &val) { - // Cast 1-byte types to a larger type to avoid iostream interpreting them as single characters - using output_type = std::conditional_t<(sizeof(T) > 1), T, std::conditional_t, int, unsigned>>; - os << 'i' << static_cast(val) << 'e'; - } -}; - -template -struct bt_deserialize>> { - void operator()(std::string_view& s, T &val) { - constexpr uint64_t umax = static_cast(std::numeric_limits::max()); - constexpr int64_t smin = static_cast(std::numeric_limits::min()); - - auto [magnitude, negative] = bt_deserialize_integer(s); - - if (std::is_signed_v) { - if (!negative) { - if (magnitude > umax) - throw bt_deserialize_invalid("Integer deserialization failed: found too-large value " + std::to_string(magnitude) + " > " + std::to_string(umax)); - val = static_cast(magnitude); - } else { - auto sval = -static_cast(magnitude); - if (!std::is_same_v && sval < smin) - throw bt_deserialize_invalid("Integer deserialization failed: found too-low value " + std::to_string(sval) + " < " + std::to_string(smin)); - val = static_cast(sval); - } - } else { - if (negative) - throw bt_deserialize_invalid("Integer deserialization failed: found negative value -" + std::to_string(magnitude) + " but type is unsigned"); - if (!std::is_same_v && magnitude > umax) - throw bt_deserialize_invalid("Integer deserialization failed: found too-large value " + std::to_string(magnitude) + " > " + std::to_string(umax)); - val = static_cast(magnitude); - } - } -}; - -extern template struct bt_deserialize; -extern template struct bt_deserialize; - -template <> -struct bt_serialize { - void operator()(std::ostream &os, const std::string_view &val) { os << val.size(); os.put(':'); os.write(val.data(), val.size()); } -}; -template <> -struct bt_deserialize { - void operator()(std::string_view& s, std::string_view& val); -}; - -/// String specialization -template <> -struct bt_serialize { - void operator()(std::ostream &os, const std::string &val) { bt_serialize{}(os, val); } -}; -template <> -struct bt_deserialize { - void operator()(std::string_view& s, std::string& val) { std::string_view view; bt_deserialize{}(s, view); val = {view.data(), view.size()}; } -}; - -/// char * and string literals -- we allow serialization for convenience, but not deserialization -template <> -struct bt_serialize { - void operator()(std::ostream &os, const char *str) { bt_serialize{}(os, {str, std::strlen(str)}); } -}; -template -struct bt_serialize { - void operator()(std::ostream &os, const char *str) { bt_serialize{}(os, {str, N-1}); } -}; - -/// Partial dict validity; we don't check the second type for serializability, that will be handled -/// via the base case static_assert if invalid. -template struct is_bt_input_dict_container_impl : std::false_type {}; -template -struct is_bt_input_dict_container_impl> || - std::is_same_v>, - std::void_t>> -: std::true_type {}; - -/// Determines whether the type looks like something we can insert into (using `v.insert(v.end(), x)`) -template struct is_bt_insertable_impl : std::false_type {}; -template -struct is_bt_insertable_impl().insert(std::declval().end(), std::declval()))>> -: std::true_type {}; -template -constexpr bool is_bt_insertable = is_bt_insertable_impl::value; - -/// Determines whether the given type looks like a compatible map (i.e. has std::string keys) that -/// we can insert into. -template struct is_bt_output_dict_container_impl : std::false_type {}; -template -struct is_bt_output_dict_container_impl> && is_bt_insertable, - std::void_t>> -: std::true_type {}; - -template -constexpr bool is_bt_output_dict_container = is_bt_output_dict_container_impl::value; -template -constexpr bool is_bt_input_dict_container = is_bt_output_dict_container_impl::value; - -// Sanity checks: -static_assert(is_bt_input_dict_container); -static_assert(is_bt_output_dict_container); - -/// Specialization for a dict-like container (such as an unordered_map). We accept anything for a -/// dict that is const iterable over something that looks like a pair with std::string for first -/// value type. The value (i.e. second element of the pair) also must be serializable. -template -struct bt_serialize>> { - using second_type = typename T::value_type::second_type; - using ref_pair = std::reference_wrapper; - void operator()(std::ostream &os, const T &dict) { - os << 'd'; - std::vector pairs; - pairs.reserve(dict.size()); - for (const auto &pair : dict) - pairs.emplace(pairs.end(), pair); - std::sort(pairs.begin(), pairs.end(), [](ref_pair a, ref_pair b) { return a.get().first < b.get().first; }); - for (auto &ref : pairs) { - bt_serialize{}(os, ref.get().first); - bt_serialize{}(os, ref.get().second); - } - os << 'e'; - } -}; - -template -struct bt_deserialize>> { - using second_type = typename T::value_type::second_type; - void operator()(std::string_view& s, T& dict) { - // Smallest dict is 2 bytes "de", for an empty dict. - if (s.size() < 2) throw bt_deserialize_invalid("Deserialization failed: end of string found where dict expected"); - if (s[0] != 'd') throw bt_deserialize_invalid_type("Deserialization failed: expected 'd', found '"s + s[0] + "'"s); - s.remove_prefix(1); - dict.clear(); - bt_deserialize key_deserializer; - bt_deserialize val_deserializer; - - while (!s.empty() && s[0] != 'e') { - std::string key; - second_type val; - key_deserializer(s, key); - val_deserializer(s, val); - dict.insert(dict.end(), typename T::value_type{std::move(key), std::move(val)}); - } - if (s.empty()) - throw bt_deserialize_invalid("Deserialization failed: encountered end of string before dict was finished"); - s.remove_prefix(1); // Consume the 'e' - } -}; - - -/// Accept anything that looks iterable; value serialization validity isn't checked here (it fails -/// via the base case static assert). -template struct is_bt_input_list_container_impl : std::false_type {}; -template -struct is_bt_input_list_container_impl && !std::is_same_v && !is_bt_input_dict_container, - std::void_t>> -: std::true_type {}; - -template struct is_bt_output_list_container_impl : std::false_type {}; -template -struct is_bt_output_list_container_impl && !is_bt_output_dict_container && is_bt_insertable>> -: std::true_type {}; - -template -constexpr bool is_bt_output_list_container = is_bt_output_list_container_impl::value; -template -constexpr bool is_bt_input_list_container = is_bt_input_list_container_impl::value; - -// Sanity checks: -static_assert(is_bt_input_list_container); -static_assert(is_bt_output_list_container); - -/// List specialization -template -struct bt_serialize>> { - void operator()(std::ostream& os, const T& list) { - os << 'l'; - for (const auto &v : list) - bt_serialize>{}(os, v); - os << 'e'; - } -}; -template -struct bt_deserialize>> { - using value_type = typename T::value_type; - void operator()(std::string_view& s, T& list) { - // Smallest list is 2 bytes "le", for an empty list. - if (s.size() < 2) throw bt_deserialize_invalid("Deserialization failed: end of string found where list expected"); - if (s[0] != 'l') throw bt_deserialize_invalid_type("Deserialization failed: expected 'l', found '"s + s[0] + "'"s); - s.remove_prefix(1); - list.clear(); - bt_deserialize deserializer; - while (!s.empty() && s[0] != 'e') { - value_type v; - deserializer(s, v); - list.insert(list.end(), std::move(v)); - } - if (s.empty()) - throw bt_deserialize_invalid("Deserialization failed: encountered end of string before list was finished"); - s.remove_prefix(1); // Consume the 'e' - } -}; - -/// Serializes a tuple or pair of serializable values (as a list on the wire) - -/// Common implementation for both tuple and pair: -template typename Tuple, typename... T> -struct bt_serialize_tuple { -private: - template - void operator()(std::ostream& os, const Tuple& elems, std::index_sequence) { - os << 'l'; - (bt_serialize{}(os, std::get(elems)), ...); - os << 'e'; - } -public: - void operator()(std::ostream& os, const Tuple& elems) { - operator()(os, elems, std::index_sequence_for{}); - } -}; -template typename Tuple, typename... T> -struct bt_deserialize_tuple { -private: - template - void operator()(std::string_view& s, Tuple& elems, std::index_sequence) { - // Smallest list is 2 bytes "le", for an empty list. - if (s.size() < 2) throw bt_deserialize_invalid("Deserialization failed: end of string found where tuple expected"); - if (s[0] != 'l') throw bt_deserialize_invalid_type("Deserialization of tuple failed: expected 'l', found '"s + s[0] + "'"s); - s.remove_prefix(1); - (bt_deserialize{}(s, std::get(elems)), ...); - if (s.empty()) - throw bt_deserialize_invalid("Deserialization failed: encountered end of string before tuple was finished"); - if (s[0] != 'e') - throw bt_deserialize_invalid("Deserialization failed: expected end of tuple but found something else"); - s.remove_prefix(1); // Consume the 'e' - } -public: - void operator()(std::string_view& s, Tuple& elems) { - operator()(s, elems, std::index_sequence_for{}); - } -}; -template -struct bt_serialize> : bt_serialize_tuple {}; -template -struct bt_deserialize> : bt_deserialize_tuple {}; -template -struct bt_serialize> : bt_serialize_tuple {}; -template -struct bt_deserialize> : bt_deserialize_tuple {}; - -template -constexpr bool is_bt_tuple = false; -template -constexpr bool is_bt_tuple> = true; -template -constexpr bool is_bt_tuple> = true; - - -template -constexpr bool is_bt_deserializable = std::is_same_v || std::is_integral_v || - is_bt_output_dict_container || is_bt_output_list_container || is_bt_tuple; - -// General template and base case; this base will only actually be invoked when Ts... is empty, -// which means we reached the end without finding any variant type capable of holding the value. -template -struct bt_deserialize_try_variant_impl { - void operator()(std::string_view&, Variant&) { - throw bt_deserialize_invalid("Deserialization failed: could not deserialize value into any variant type"); - } -}; - -template -void bt_deserialize_try_variant(std::string_view& s, Variant& variant) { - bt_deserialize_try_variant_impl{}(s, variant); -} - - -template -struct bt_deserialize_try_variant_impl>, Variant, T, Ts...> { - void operator()(std::string_view& s, Variant& variant) { - if ( is_bt_output_list_container ? s[0] == 'l' : - is_bt_tuple ? s[0] == 'l' : - is_bt_output_dict_container ? s[0] == 'd' : - std::is_integral_v ? s[0] == 'i' : - std::is_same_v ? s[0] >= '0' && s[0] <= '9' : - false) { - T val; - bt_deserialize{}(s, val); - variant = std::move(val); - } else { - bt_deserialize_try_variant(s, variant); - } - } -}; - -template -struct bt_deserialize_try_variant_impl>, Variant, T, Ts...> { - void operator()(std::string_view& s, Variant& variant) { - // Unsupported deserialization type, skip it - bt_deserialize_try_variant(s, variant); - } -}; - -// Serialization of a variant; all variant types must be bt-serializable. -template -struct bt_serialize, std::void_t...>> { - void operator()(std::ostream& os, const std::variant& val) { - var::visit( - [&os] (const auto& val) { - using T = std::remove_cv_t>; - bt_serialize{}(os, val); - }, - val); - } -}; - -// Deserialization to a variant; at least one variant type must be bt-deserializble. -template -struct bt_deserialize, std::enable_if_t<(is_bt_deserializable || ...)>> { - void operator()(std::string_view& s, std::variant& val) { - bt_deserialize_try_variant(s, val); - } -}; - -template <> -struct bt_serialize : bt_serialize {}; - -template <> -struct bt_deserialize { - void operator()(std::string_view& s, bt_value& val); -}; - -template -struct bt_stream_serializer { - const T &val; - explicit bt_stream_serializer(const T &val) : val{val} {} - operator std::string() const { - std::ostringstream oss; - oss << *this; - return oss.str(); - } -}; -template -std::ostream &operator<<(std::ostream &os, const bt_stream_serializer &s) { - bt_serialize{}(os, s.val); - return os; -} - -} // namespace detail - - -/// Returns a wrapper around a value reference that can serialize the value directly to an output -/// stream. This class is intended to be used inline (i.e. without being stored) as in: -/// -/// std::list my_list{{1,2,3}}; -/// std::cout << bt_serializer(my_list); -/// -/// While it is possible to store the returned object and use it, such as: -/// -/// auto encoded = bt_serializer(42); -/// std::cout << encoded; -/// -/// this approach is not generally recommended: the returned object stores a reference to the -/// passed-in type, which may not survive. If doing this note that it is the caller's -/// responsibility to ensure the serializer is not used past the end of the lifetime of the value -/// being serialized. -/// -/// Also note that serializing directly to an output stream is more efficient as no intermediate -/// string containing the entire serialization has to be constructed. -/// -template -detail::bt_stream_serializer bt_serializer(const T &val) { return detail::bt_stream_serializer{val}; } - -/// Serializes the given value into a std::string. -/// -/// int number = 42; -/// std::string encoded = bt_serialize(number); -/// // Equivalent: -/// //auto encoded = (std::string) bt_serialize(number); -/// -/// This takes any serializable type: integral types, strings, lists of serializable types, and -/// string->value maps of serializable types. -template -std::string bt_serialize(const T &val) { return bt_serializer(val); } - -/// Deserializes the given string view directly into `val`. Usage: -/// -/// std::string encoded = "i42e"; -/// int value; -/// bt_deserialize(encoded, value); // Sets value to 42 -/// -template , int> = 0> -void bt_deserialize(std::string_view s, T& val) { - return detail::bt_deserialize{}(s, val); -} - - -/// Deserializes the given string_view into a `T`, which is returned. -/// -/// std::string encoded = "li1ei2ei3ee"; // bt-encoded list of ints: [1,2,3] -/// auto mylist = bt_deserialize>(encoded); -/// -template -T bt_deserialize(std::string_view s) { - T val; - bt_deserialize(s, val); - return val; -} - -/// Deserializes the given value into a generic `bt_value` type (wrapped std::variant) which is -/// capable of holding all possible BT-encoded values (including recursion). -/// -/// Example: -/// -/// std::string encoded = "i42e"; -/// auto val = bt_get(encoded); -/// int v = get_int(val); // fails unless the encoded value was actually an integer that -/// // fits into an `int` -/// -inline bt_value bt_get(std::string_view s) { - return bt_deserialize(s); -} - -/// Helper functions to extract a value of some integral type from a bt_value which contains either -/// a int64_t or uint64_t. Does range checking, throwing std::overflow_error if the stored value is -/// outside the range of the target type. -/// -/// Example: -/// -/// std::string encoded = "i123456789e"; -/// auto val = bt_get(encoded); -/// auto v = get_int(val); // throws if the decoded value doesn't fit in a uint32_t -template , int> = 0> -IntType get_int(const bt_value &v) { - if (auto* value = std::get_if(&v)) { - if constexpr (!std::is_same_v) - if (*value > static_cast(std::numeric_limits::max())) - throw std::overflow_error("Unable to extract integer value: stored value is too large for the requested type"); - return static_cast(*value); - } - - int64_t value = var::get(v); // throws if no int contained - if constexpr (!std::is_same_v) - if (value > static_cast(std::numeric_limits::max()) - || value < static_cast(std::numeric_limits::min())) - throw std::overflow_error("Unable to extract integer value: stored value is outside the range of the requested type"); - return static_cast(value); -} - -namespace detail { -template -void get_tuple_impl(Tuple& t, const bt_list& l, std::index_sequence); -} - -/// Converts a bt_list into the given template std::tuple or std::pair. Throws a -/// std::invalid_argument if the list has the wrong size or wrong element types. Supports recursion -/// (i.e. if the tuple itself contains tuples or pairs). The tuple (or nested tuples) may only -/// contain integral types, strings, string_views, bt_list, bt_dict, and tuples/pairs of those. -template -Tuple get_tuple(const bt_list& x) { - Tuple t; - detail::get_tuple_impl(t, x, std::make_index_sequence>{}); - return t; -} -template -Tuple get_tuple(const bt_value& x) { - return get_tuple(var::get(static_cast(x))); -} - -namespace detail { -template -void get_tuple_impl_one(T& t, It& it) { - const bt_variant& v = *it++; - if constexpr (std::is_integral_v) { - t = lokimq::get_int(v); - } else if constexpr (is_bt_tuple) { - if (std::holds_alternative(v)) - throw std::invalid_argument{"Unable to convert tuple: cannot create sub-tuple from non-bt_list"}; - t = get_tuple(var::get(v)); - } else if constexpr (std::is_same_v || std::is_same_v) { - // If we request a string/string_view, we might have the other one and need to copy/view it. - if (std::holds_alternative(v)) - t = var::get(v); - else - t = var::get(v); - } else { - t = var::get(v); - } -} -template -void get_tuple_impl(Tuple& t, const bt_list& l, std::index_sequence) { - if (l.size() != sizeof...(Is)) - throw std::invalid_argument{"Unable to convert tuple: bt_list has wrong size"}; - auto it = l.begin(); - (get_tuple_impl_one(std::get(t), it), ...); -} -} // namespace detail - - - - -/// Class that allows you to walk through a bt-encoded list in memory without copying or allocating -/// memory. It accesses existing memory directly and so the caller must ensure that the referenced -/// memory stays valid for the lifetime of the bt_list_consumer object. -class bt_list_consumer { -protected: - std::string_view data; - bt_list_consumer() = default; -public: - bt_list_consumer(std::string_view data_); - - /// Copy constructor. Making a copy copies the current position so can be used for multipass - /// iteration through a list. - bt_list_consumer(const bt_list_consumer&) = default; - bt_list_consumer& operator=(const bt_list_consumer&) = default; - - /// Get a copy of the current buffer - std::string_view current_buffer() const { return data; } - - /// Returns true if the next value indicates the end of the list - bool is_finished() const { return data.front() == 'e'; } - /// Returns true if the next element looks like an encoded string - bool is_string() const { return data.front() >= '0' && data.front() <= '9'; } - /// Returns true if the next element looks like an encoded integer - bool is_integer() const { return data.front() == 'i'; } - /// Returns true if the next element looks like an encoded negative integer - bool is_negative_integer() const { return is_integer() && data.size() >= 2 && data[1] == '-'; } - /// Returns true if the next element looks like an encoded list - bool is_list() const { return data.front() == 'l'; } - /// Returns true if the next element looks like an encoded dict - bool is_dict() const { return data.front() == 'd'; } - - /// Attempt to parse the next value as a string (and advance just past it). Throws if the next - /// value is not a string. - std::string consume_string(); - std::string_view consume_string_view(); - - /// Attempts to parse the next value as an integer (and advance just past it). Throws if the - /// next value is not an integer. - template - IntType consume_integer() { - if (!is_integer()) throw bt_deserialize_invalid_type{"next value is not an integer"}; - std::string_view next{data}; - IntType ret; - detail::bt_deserialize{}(next, ret); - data = next; - return ret; - } - - /// Consumes a list, return it as a list-like type. Can also be used for tuples/pairs. This - /// typically requires dynamic allocation, but only has to parse the data once. Compare with - /// consume_list_data() which allows alloc-free traversal, but requires parsing twice (if the - /// contents are to be used). - template - T consume_list() { - T list; - consume_list(list); - return list; - } - - /// Same as above, but takes a pre-existing list-like data type. - template - void consume_list(T& list) { - if (!is_list()) throw bt_deserialize_invalid_type{"next bt value is not a list"}; - std::string_view n{data}; - detail::bt_deserialize{}(n, list); - data = n; - } - - /// Consumes a dict, return it as a dict-like type. This typically requires dynamic allocation, - /// but only has to parse the data once. Compare with consume_dict_data() which allows - /// alloc-free traversal, but requires parsing twice (if the contents are to be used). - template - T consume_dict() { - T dict; - consume_dict(dict); - return dict; - } - - /// Same as above, but takes a pre-existing dict-like data type. - template - void consume_dict(T& dict) { - if (!is_dict()) throw bt_deserialize_invalid_type{"next bt value is not a dict"}; - std::string_view n{data}; - detail::bt_deserialize{}(n, dict); - data = n; - } - - /// Consumes a value without returning it. - void skip_value(); - - /// Attempts to parse the next value as a list and returns the string_view that contains the - /// entire thing. This is recursive into both lists and dicts and likely to be quite - /// inefficient for large, nested structures (unless the values only need to be skipped but - /// aren't separately needed). This, however, does not require dynamic memory allocation. - std::string_view consume_list_data(); - - /// Attempts to parse the next value as a dict and returns the string_view that contains the - /// entire thing. This is recursive into both lists and dicts and likely to be quite - /// inefficient for large, nested structures (unless the values only need to be skipped but - /// aren't separately needed). This, however, does not require dynamic memory allocation. - std::string_view consume_dict_data(); -}; - - -/// Class that allows you to walk through key-value pairs of a bt-encoded dict in memory without -/// copying or allocating memory. It accesses existing memory directly and so the caller must -/// ensure that the referenced memory stays valid for the lifetime of the bt_dict_consumer object. -class bt_dict_consumer : private bt_list_consumer { - std::string_view key_; - - /// Consume the key if not already consumed and there is a key present (rather than 'e'). - /// Throws exception if what should be a key isn't a string, or if the key consumes the entire - /// data (i.e. requires that it be followed by something). Returns true if the key was consumed - /// (either now or previously and cached). - bool consume_key(); - - /// Clears the cached key and returns it. Must have already called consume_key directly or - /// indirectly via one of the `is_{...}` methods. - std::string_view flush_key() { - std::string_view k; - k.swap(key_); - return k; - } - -public: - bt_dict_consumer(std::string_view data_); - - /// Copy constructor. Making a copy copies the current position so can be used for multipass - /// iteration through a list. - bt_dict_consumer(const bt_dict_consumer&) = default; - bt_dict_consumer& operator=(const bt_dict_consumer&) = default; - - /// Returns true if the next value indicates the end of the dict - bool is_finished() { return !consume_key() && data.front() == 'e'; } - /// Operator bool is an alias for `!is_finished()` - operator bool() { return !is_finished(); } - /// Returns true if the next value looks like an encoded string - bool is_string() { return consume_key() && data.front() >= '0' && data.front() <= '9'; } - /// Returns true if the next element looks like an encoded integer - bool is_integer() { return consume_key() && data.front() == 'i'; } - /// Returns true if the next element looks like an encoded negative integer - bool is_negative_integer() { return is_integer() && data.size() >= 2 && data[1] == '-'; } - /// Returns true if the next element looks like an encoded list - bool is_list() { return consume_key() && data.front() == 'l'; } - /// Returns true if the next element looks like an encoded dict - bool is_dict() { return consume_key() && data.front() == 'd'; } - /// Returns the key of the next pair. This does not have to be called; it is also returned by - /// all of the other consume_* methods. The value is cached whether called here or by some - /// other method; accessing it multiple times simple accesses the cache until the next value is - /// consumed. - std::string_view key() { - if (!consume_key()) - throw bt_deserialize_invalid{"Cannot access next key: at the end of the dict"}; - return key_; - } - - /// Attempt to parse the next value as a string->string pair (and advance just past it). Throws - /// if the next value is not a string. - std::pair next_string(); - - /// Attempts to parse the next value as an string->integer pair (and advance just past it). - /// Throws if the next value is not an integer. - template - std::pair next_integer() { - if (!is_integer()) throw bt_deserialize_invalid_type{"next bt dict value is not an integer"}; - std::pair ret; - ret.second = bt_list_consumer::consume_integer(); - ret.first = flush_key(); - return ret; - } - - /// Consumes a string->list pair, return it as a list-like type. This typically requires - /// dynamic allocation, but only has to parse the data once. Compare with consume_list_data() - /// which allows alloc-free traversal, but requires parsing twice (if the contents are to be - /// used). - template - std::pair next_list() { - std::pair pair; - pair.first = next_list(pair.second); - return pair; - } - - /// Same as above, but takes a pre-existing list-like data type. Returns the key. - template - std::string_view next_list(T& list) { - if (!is_list()) throw bt_deserialize_invalid_type{"next bt value is not a list"}; - bt_list_consumer::consume_list(list); - return flush_key(); - } - - /// Consumes a string->dict pair, return it as a dict-like type. This typically requires - /// dynamic allocation, but only has to parse the data once. Compare with consume_dict_data() - /// which allows alloc-free traversal, but requires parsing twice (if the contents are to be - /// used). - template - std::pair next_dict() { - std::pair pair; - pair.first = consume_dict(pair.second); - return pair; - } - - /// Same as above, but takes a pre-existing dict-like data type. Returns the key. - template - std::string_view next_dict(T& dict) { - if (!is_dict()) throw bt_deserialize_invalid_type{"next bt value is not a dict"}; - bt_list_consumer::consume_dict(dict); - return flush_key(); - } - - /// Attempts to parse the next value as a string->list pair and returns the string_view that - /// contains the entire thing. This is recursive into both lists and dicts and likely to be - /// quite inefficient for large, nested structures (unless the values only need to be skipped - /// but aren't separately needed). This, however, does not require dynamic memory allocation. - std::pair next_list_data() { - if (data.size() < 2 || !is_list()) throw bt_deserialize_invalid_type{"next bt dict value is not a list"}; - return {flush_key(), bt_list_consumer::consume_list_data()}; - } - - /// Same as next_list_data(), but wraps the value in a bt_list_consumer for convenience - std::pair next_list_consumer() { return next_list_data(); } - - /// Attempts to parse the next value as a string->dict pair and returns the string_view that - /// contains the entire thing. This is recursive into both lists and dicts and likely to be - /// quite inefficient for large, nested structures (unless the values only need to be skipped - /// but aren't separately needed). This, however, does not require dynamic memory allocation. - std::pair next_dict_data() { - if (data.size() < 2 || !is_dict()) throw bt_deserialize_invalid_type{"next bt dict value is not a dict"}; - return {flush_key(), bt_list_consumer::consume_dict_data()}; - } - - /// Same as next_dict_data(), but wraps the value in a bt_dict_consumer for convenience - std::pair next_dict_consumer() { return next_dict_data(); } - - /// Skips ahead until we find the first key >= the given key or reach the end of the dict. - /// Returns true if we found an exact match, false if we reached some greater value or the end. - /// If we didn't hit the end, the next `consumer_*()` call will return the key-value pair we - /// found (either the exact match or the first key greater than the requested key). - /// - /// Two important notes: - /// - /// - properly encoded bt dicts must have lexicographically sorted keys, and this method assumes - /// that the input is correctly sorted (and thus if we find a greater value then your key does - /// not exist). - /// - this is irreversible; you cannot returned to skipped values without reparsing. (You *can* - /// however, make a copy of the bt_dict_consumer before calling and use the copy to return to - /// the pre-skipped position). - bool skip_until(std::string_view find) { - while (consume_key() && key_ < find) { - flush_key(); - skip_value(); - } - return key_ == find; - } - - /// The `consume_*` functions are wrappers around next_whatever that discard the returned key. - /// - /// Intended for use with skip_until such as: - /// - /// std::string value; - /// if (d.skip_until("key")) - /// value = d.consume_string(); - /// - - auto consume_string_view() { return next_string().second; } - auto consume_string() { return std::string{consume_string_view()}; } - - template - auto consume_integer() { return next_integer().second; } - - template - auto consume_list() { return next_list().second; } - - template - void consume_list(T& list) { next_list(list); } - - template - auto consume_dict() { return next_dict().second; } - - template - void consume_dict(T& dict) { next_dict(dict); } - - std::string_view consume_list_data() { return next_list_data().second; } - std::string_view consume_dict_data() { return next_dict_data().second; } - - bt_list_consumer consume_list_consumer() { return consume_list_data(); } - bt_dict_consumer consume_dict_consumer() { return consume_dict_data(); } -}; - - -} // namespace lokimq +namespace lokimq = oxenmq; diff --git a/lokimq/bt_value.h b/lokimq/bt_value.h index 7cf6269..c311b76 100644 --- a/lokimq/bt_value.h +++ b/lokimq/bt_value.h @@ -1,112 +1,4 @@ -// Copyright (c) 2019-2020, The Loki Project -// -// All rights reserved. -// -// Redistribution and use in source and binary forms, with or without modification, are -// permitted provided that the following conditions are met: -// -// 1. Redistributions of source code must retain the above copyright notice, this list of -// conditions and the following disclaimer. -// -// 2. Redistributions in binary form must reproduce the above copyright notice, this list -// of conditions and the following disclaimer in the documentation and/or other -// materials provided with the distribution. -// -// 3. Neither the name of the copyright holder nor the names of its contributors may be -// used to endorse or promote products derived from this software without specific -// prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY -// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF -// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL -// THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, -// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF -// THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - #pragma once +#include "../oxenmq/bt_value.h" -// This header is here to provide just the basic bt_value/bt_dict/bt_list definitions without -// needing to include the full bt_serialize.h header. - -#include -#include -#include -#include -#include -#include - -namespace lokimq { - -struct bt_value; - -/// The type used to store dictionaries inside bt_value. -using bt_dict = std::map; // NB: unordered_map doesn't work because it can't be used with a predeclared type -/// The type used to store list items inside bt_value. -using bt_list = std::list; - -/// The basic variant that can hold anything (recursively). -using bt_variant = std::variant< - std::string, - std::string_view, - int64_t, - uint64_t, - bt_list, - bt_dict ->; - -#ifdef __cpp_lib_remove_cvref // C++20 -using std::remove_cvref_t; -#else -template -using remove_cvref_t = std::remove_cv_t>; -#endif - -template -struct has_alternative; -template -struct has_alternative> : std::bool_constant<(std::is_same_v || ...)> {}; -template -constexpr bool has_alternative_v = has_alternative::value; - -namespace detail { - template - bt_list tuple_to_list(const Tuple& tuple, std::index_sequence) { - return {{bt_value{std::get(tuple)}...}}; - } - template constexpr bool is_tuple = false; - template constexpr bool is_tuple> = true; - template constexpr bool is_tuple> = true; -} - -/// Recursive generic type that can fully represent everything valid for a BT serialization. -/// This is basically just an empty wrapper around the std::variant, except we add some extra -/// converting constructors: -/// - integer constructors so that any unsigned value goes to the uint64_t and any signed value goes -/// to the int64_t. -/// - std::tuple and std::pair constructors that build a bt_list out of the tuple/pair elements. -struct bt_value : bt_variant { - using bt_variant::bt_variant; - using bt_variant::operator=; - - template , std::enable_if_t && std::is_unsigned_v, int> = 0> - bt_value(T&& uint) : bt_variant{static_cast(uint)} {} - - template , std::enable_if_t && std::is_signed_v, int> = 0> - bt_value(T&& sint) : bt_variant{static_cast(sint)} {} - - template - bt_value(const std::tuple& tuple) : bt_variant{detail::tuple_to_list(tuple, std::index_sequence_for{})} {} - - template - bt_value(const std::pair& pair) : bt_variant{detail::tuple_to_list(pair, std::index_sequence_for{})} {} - - template , std::enable_if_t && !detail::is_tuple, int> = 0> - bt_value(T&& v) : bt_variant{std::forward(v)} {} - - bt_value(const char* s) : bt_value{std::string_view{s}} {} -}; - -} +namespace lokimq = oxenmq; diff --git a/lokimq/connections.h b/lokimq/connections.h index f43c567..5775338 100644 --- a/lokimq/connections.h +++ b/lokimq/connections.h @@ -1,96 +1,4 @@ #pragma once -#include "auth.h" -#include "bt_value.h" -#include -#include -#include -#include -#include -#include - -namespace lokimq { - -struct ConnectionID; - -namespace detail { -template -bt_dict build_send(ConnectionID to, std::string_view cmd, T&&... opts); -} - -/// Opaque data structure representing a connection which supports ==, !=, < and std::hash. For -/// connections to service node this is the service node pubkey (and you can pass a 32-byte string -/// anywhere a ConnectionID is called for). For non-SN remote connections you need to keep a copy -/// of the ConnectionID returned by connect_remote(). -struct ConnectionID { - // Default construction; creates a ConnectionID with an invalid internal ID that will not match - // an actual connection. - ConnectionID() : ConnectionID(0) {} - // Construction from a service node pubkey - ConnectionID(std::string pubkey_) : id{SN_ID}, pk{std::move(pubkey_)} { - if (pk.size() != 32) - throw std::runtime_error{"Invalid pubkey: expected 32 bytes"}; - } - // Construction from a service node pubkey - ConnectionID(std::string_view pubkey_) : ConnectionID(std::string{pubkey_}) {} - - ConnectionID(const ConnectionID&) = default; - ConnectionID(ConnectionID&&) = default; - ConnectionID& operator=(const ConnectionID&) = default; - ConnectionID& operator=(ConnectionID&&) = default; - - // Returns true if this is a ConnectionID (false for a default-constructed, invalid id) - explicit operator bool() const { - return id != 0; - } - - // Two ConnectionIDs are equal if they are both SNs and have matching pubkeys, or they are both - // not SNs and have matching internal IDs and routes. (Pubkeys do not have to match for - // non-SNs). - bool operator==(const ConnectionID &o) const { - if (sn() && o.sn()) - return pk == o.pk; - return id == o.id && route == o.route; - } - bool operator!=(const ConnectionID &o) const { return !(*this == o); } - bool operator<(const ConnectionID &o) const { - if (sn() && o.sn()) - return pk < o.pk; - return id < o.id || (id == o.id && route < o.route); - } - - // Returns true if this ConnectionID represents a SN connection - bool sn() const { return id == SN_ID; } - - // Returns this connection's pubkey, if any. (Note that all curve connections have pubkeys, not - // only SNs). - const std::string& pubkey() const { return pk; } - - // Returns a copy of the ConnectionID with the route set to empty. - ConnectionID unrouted() { return ConnectionID{id, pk, ""}; } - -private: - ConnectionID(long long id) : id{id} {} - ConnectionID(long long id, std::string pubkey, std::string route = "") - : id{id}, pk{std::move(pubkey)}, route{std::move(route)} {} - - constexpr static long long SN_ID = -1; - long long id = 0; - std::string pk; - std::string route; - friend class LokiMQ; - friend struct std::hash; - template - friend bt_dict detail::build_send(ConnectionID to, std::string_view cmd, T&&... opts); - friend std::ostream& operator<<(std::ostream& o, const ConnectionID& conn); -}; - -} // namespace lokimq -namespace std { - template <> struct hash { - size_t operator()(const lokimq::ConnectionID &c) const { - return c.sn() ? lokimq::already_hashed{}(c.pk) : - std::hash{}(c.id) + std::hash{}(c.route); - } - }; -} // namespace std +#include "../oxenmq/connections.h" +namespace lokimq = oxenmq; diff --git a/lokimq/hex.h b/lokimq/hex.h index 6998b9a..08c3b58 100644 --- a/lokimq/hex.h +++ b/lokimq/hex.h @@ -1,145 +1,4 @@ -// Copyright (c) 2019-2020, The Loki Project -// -// All rights reserved. -// -// Redistribution and use in source and binary forms, with or without modification, are -// permitted provided that the following conditions are met: -// -// 1. Redistributions of source code must retain the above copyright notice, this list of -// conditions and the following disclaimer. -// -// 2. Redistributions in binary form must reproduce the above copyright notice, this list -// of conditions and the following disclaimer in the documentation and/or other -// materials provided with the distribution. -// -// 3. Neither the name of the copyright holder nor the names of its contributors may be -// used to endorse or promote products derived from this software without specific -// prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY -// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF -// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL -// THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, -// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF -// THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - #pragma once -#include -#include -#include -#include -#include +#include "../oxenmq/hex.h" -namespace lokimq { - -namespace detail { - -/// Compile-time generated lookup tables hex conversion -struct hex_table { - char from_hex_lut[256]; - char to_hex_lut[16]; - constexpr hex_table() noexcept : from_hex_lut{}, to_hex_lut{} { - for (unsigned char c = 0; c < 10; c++) { - from_hex_lut[(unsigned char)('0' + c)] = 0 + c; - to_hex_lut[ (unsigned char)( 0 + c)] = '0' + c; - } - for (unsigned char c = 0; c < 6; c++) { - from_hex_lut[(unsigned char)('a' + c)] = 10 + c; - from_hex_lut[(unsigned char)('A' + c)] = 10 + c; - to_hex_lut[ (unsigned char)(10 + c)] = 'a' + c; - } - } - constexpr char from_hex(unsigned char c) const noexcept { return from_hex_lut[c]; } - constexpr char to_hex(unsigned char b) const noexcept { return to_hex_lut[b]; } -} constexpr hex_lut; - -// This main point of this static assert is to force the compiler to compile-time build the constexpr tables. -static_assert(hex_lut.from_hex('a') == 10 && hex_lut.from_hex('F') == 15 && hex_lut.to_hex(13) == 'd', ""); - -} // namespace detail - -/// Creates hex digits from a character sequence. -template -void to_hex(InputIt begin, InputIt end, OutputIt out) { - static_assert(sizeof(decltype(*begin)) == 1, "to_hex requires chars/bytes"); - for (; begin != end; ++begin) { - uint8_t c = static_cast(*begin); - *out++ = detail::hex_lut.to_hex(c >> 4); - *out++ = detail::hex_lut.to_hex(c & 0x0f); - } -} - -/// Creates a string of hex digits from a character sequence iterator pair -template -std::string to_hex(It begin, It end) { - std::string hex; - if constexpr (std::is_base_of_v::iterator_category>) - hex.reserve(2 * std::distance(begin, end)); - to_hex(begin, end, std::back_inserter(hex)); - return hex; -} - -/// Creates a hex string from an iterable, std::string-like object -template -std::string to_hex(std::basic_string_view s) { return to_hex(s.begin(), s.end()); } -inline std::string to_hex(std::string_view s) { return to_hex<>(s); } - -/// Returns true if all elements in the range are hex characters -template -constexpr bool is_hex(It begin, It end) { - static_assert(sizeof(decltype(*begin)) == 1, "is_hex requires chars/bytes"); - for (; begin != end; ++begin) { - if (detail::hex_lut.from_hex(static_cast(*begin)) == 0 && static_cast(*begin) != '0') - return false; - } - return true; -} - -/// Returns true if all elements in the string-like value are hex characters -template -constexpr bool is_hex(std::basic_string_view s) { return is_hex(s.begin(), s.end()); } -constexpr bool is_hex(std::string_view s) { return is_hex(s.begin(), s.end()); } - -/// Convert a hex digit into its numeric (0-15) value -constexpr char from_hex_digit(unsigned char x) noexcept { - return detail::hex_lut.from_hex(x); -} - -/// Constructs a byte value from a pair of hex digits -constexpr char from_hex_pair(unsigned char a, unsigned char b) noexcept { return (from_hex_digit(a) << 4) | from_hex_digit(b); } - -/// Converts a sequence of hex digits to bytes. Undefined behaviour if any characters are not in -/// [0-9a-fA-F] or if the input sequence length is not even. It is permitted for the input and -/// output ranges to overlap as long as out is no earlier than begin. -template -void from_hex(InputIt begin, InputIt end, OutputIt out) { - using std::distance; - assert(distance(begin, end) % 2 == 0); - while (begin != end) { - auto a = *begin++; - auto b = *begin++; - *out++ = from_hex_pair(static_cast(a), static_cast(b)); - } -} - -/// Converts a sequence of hex digits to a string of bytes and returns it. Undefined behaviour if -/// the input sequence is not an even-length sequence of [0-9a-fA-F] characters. -template -std::string from_hex(It begin, It end) { - std::string bytes; - if constexpr (std::is_base_of_v::iterator_category>) - bytes.reserve(std::distance(begin, end) / 2); - from_hex(begin, end, std::back_inserter(bytes)); - return bytes; -} - -/// Converts hex digits from a std::string-like object into a std::string of bytes. Undefined -/// behaviour if any characters are not in [0-9a-fA-F] or if the input sequence length is not even. -template -std::string from_hex(std::basic_string_view s) { return from_hex(s.begin(), s.end()); } -inline std::string from_hex(std::string_view s) { return from_hex<>(s); } - -} +namespace lokimq = oxenmq; diff --git a/lokimq/lokimq.h b/lokimq/lokimq.h index 48dbb9c..57ed082 100644 --- a/lokimq/lokimq.h +++ b/lokimq/lokimq.h @@ -1,1528 +1,10 @@ -// Copyright (c) 2019-2020, The Loki Project -// -// All rights reserved. -// -// Redistribution and use in source and binary forms, with or without modification, are -// permitted provided that the following conditions are met: -// -// 1. Redistributions of source code must retain the above copyright notice, this list of -// conditions and the following disclaimer. -// -// 2. Redistributions in binary form must reproduce the above copyright notice, this list -// of conditions and the following disclaimer in the documentation and/or other -// materials provided with the distribution. -// -// 3. Neither the name of the copyright holder nor the names of its contributors may be -// used to endorse or promote products derived from this software without specific -// prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY -// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF -// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL -// THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, -// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF -// THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - #pragma once +#include "../oxenmq/oxenmq.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "zmq.hpp" -#include "address.h" -#include "bt_serialize.h" -#include "connections.h" -#include "message.h" -#include "auth.h" +namespace lokimq = oxenmq; -#if ZMQ_VERSION < ZMQ_MAKE_VERSION (4, 3, 0) -// Timers were not added until 4.3.0 -#error "ZMQ >= 4.3.0 required" -#endif +namespace oxenmq { -namespace lokimq { - -using namespace std::literals; - -/// Logging levels passed into LogFunc. (Note that trace does nothing more than debug in a release -/// build). -enum class LogLevel { fatal, error, warn, info, debug, trace }; - -// Forward declarations; see batch.h -namespace detail { class Batch; } -template class Batch; - -/** The keep-alive time for a send() that results in a establishing a new outbound connection. To - * use a longer keep-alive to a host call `connect()` first with the desired keep-alive time or pass - * the send_option::keep_alive. - */ -inline constexpr auto DEFAULT_SEND_KEEP_ALIVE = 30s; - -// The default timeout for connect_remote() -inline constexpr auto REMOTE_CONNECT_TIMEOUT = 10s; - -// The amount of time we wait for a reply to a REQUEST before calling the callback with -// `false` to signal a timeout. -inline constexpr auto DEFAULT_REQUEST_TIMEOUT = 15s; - -/// Maximum length of a category -inline constexpr size_t MAX_CATEGORY_LENGTH = 50; - -/// Maximum length of a command -inline constexpr size_t MAX_COMMAND_LENGTH = 200; - -class CatHelper; - -/// Opaque handle for a tagged thread constructed by add_tagged_thread(...). Not directly -/// constructible, but is safe (and cheap) to copy. -struct TaggedThreadID { -private: - int _id; - explicit constexpr TaggedThreadID(int id) : _id{id} {} - friend class LokiMQ; - template friend class Batch; -}; - -/** - * Class that handles LokiMQ listeners, connections, proxying, and workers. An application - * typically has just one instance of this class. - */ -class LokiMQ { - -private: - - /// The global context - zmq::context_t context; - - /// A unique id for this LokiMQ instance, assigned in a thread-safe manner during construction. - const int object_id; - - /// The x25519 keypair of this connection. For service nodes these are the long-run x25519 keys - /// provided at construction, for non-service-node connections these are generated during - /// construction. - std::string pubkey, privkey; - - /// True if *this* node is running in service node mode (whether or not actually active) - bool local_service_node = false; - - /// The thread in which most of the intermediate work happens (handling external connections - /// and proxying requests between them to worker threads) - std::thread proxy_thread; - - /// Will be true (and is guarded by a mutex) if the proxy thread is quitting; guards against new - /// control sockets from threads trying to talk to the proxy thread. - bool proxy_shutting_down = false; - - /// We have one seldom-used mutex here: it is generally locked just once per thread (the first - /// time the thread calls get_control_socket()) and once more by the proxy thread when it shuts - /// down. - std::mutex control_sockets_mutex; - - /// Called to obtain a "command" socket that attaches to `control` to send commands to the - /// proxy thread from other threads. This socket is unique per thread and LokiMQ instance. - zmq::socket_t& get_control_socket(); - - /// Per-thread control sockets used by lokimq threads to talk to this object's proxy thread. - std::unordered_map> control_sockets; - -public: - - /// Callback type invoked to determine whether the given new incoming connection is allowed to - /// connect to us and to set its authentication level. - /// - /// @param address - the address of the incoming connection. For TCP connections this is an IP - /// address; for UDP connections it's a string such as "localhost:UID:GID:PID". - /// @param pubkey - the x25519 pubkey of the connecting client (32 byte string). Note that this - /// will only be non-empty for incoming connections on `listen_curve` sockets; `listen_plain` - /// sockets do not have a pubkey. - /// @param service_node - will be true if the `pubkey` is in the set of known active service - /// nodes. - /// - /// @returns an `AuthLevel` enum value indicating the default auth level for the incoming - /// connection, or AuthLevel::denied if the connection should be refused. - using AllowFunc = std::function; - - /// Callback that is invoked when we need to send a "strong" message to a SN that we aren't - /// already connected to and need to establish a connection. This callback returns the ZMQ - /// connection string we should use which is typically a string such as `tcp://1.2.3.4:5678`. - using SNRemoteAddress = std::function; - - /// The callback type for registered commands. - using CommandCallback = std::function; - - /// The callback for making requests. This is called with `true` and a (moved) vector of data - /// part strings when we get a reply, or `false` and empty vector on timeout. - using ReplyCallback = std::function data)>; - - /// Called to write a log message. This will only be called if the `level` is >= the current - /// LokiMQ object log level. It must be a raw function pointer (or a capture-less lambda) for - /// performance reasons. Takes four arguments: the log level of the message, the filename and - /// line number where the log message was invoked, and the log message itself. - using Logger = std::function; - - /// Callback for the success case of connect_remote() - using ConnectSuccess = std::function; - /// Callback for the failure case of connect_remote() - using ConnectFailure = std::function; - - /// Explicitly non-copyable, non-movable because most things here aren't copyable, and a few - /// things aren't movable, either. If you need to pass the LokiMQ instance around, wrap it - /// in a unique_ptr or shared_ptr. - LokiMQ(const LokiMQ&) = delete; - LokiMQ& operator=(const LokiMQ&) = delete; - LokiMQ(LokiMQ&&) = delete; - LokiMQ& operator=(LokiMQ&&) = delete; - - /** How long to wait for handshaking to complete on external connections before timing out and - * closing the connection. Setting this only affects new outgoing connections. */ - std::chrono::milliseconds HANDSHAKE_TIME = 10s; - - /** Whether to use a zmq routing ID based on the pubkey for new outgoing connections. This is - * normally desirable as it allows the listener to recognize that the incoming connection is a - * reconnection from the same remote and handover routing to the new socket while closing off - * the (likely dead) old socket. This, however, prevents a single LokiMQ instance from - * establishing multiple connections to the same listening LokiMQ, which is sometimes useful - * (for example when testing), and so this option can be overridden to `false` to use completely - * random zmq routing ids on outgoing connections (which will thus allow multiple connections). - */ - bool PUBKEY_BASED_ROUTING_ID = true; - - /** Maximum incoming message size; if a remote tries sending a message larger than this they get - * disconnected. -1 means no limit. */ - int64_t MAX_MSG_SIZE = 1 * 1024 * 1024; - - /** Maximum open sockets, passed to the ZMQ context during start(). The default here is 10k, - * designed to be enough to be more than enough to allow a full-mesh SN layer connection if - * necessary for the forseeable future. */ - int MAX_SOCKETS = 10000; - - /** Minimum reconnect interval: when a connection fails or dies, wait this long before - * attempting to reconnect. (ZMQ may randomize the value somewhat to avoid reconnection - * storms). See RECONNECT_INTERVAL_MAX as well. The LokiMQ default is 250ms. - */ - std::chrono::milliseconds RECONNECT_INTERVAL = 250ms; - - /** Maximum reconnect interval. When this is set to a value larger than RECONNECT_INTERVAL then - * ZMQ's reconnection logic uses an exponential backoff: each reconnection attempts waits twice - * as long as the previous attempt, up to this maximum. The LokiMQ default is 5 seconds. - */ - std::chrono::milliseconds RECONNECT_INTERVAL_MAX = 5s; - - /** How long (in ms) to linger sockets when closing them; this is the maximum time zmq spends - * trying to sending pending messages before dropping them and closing the underlying socket - * after the high-level zmq socket is closed. */ - std::chrono::milliseconds CLOSE_LINGER = 5s; - - /** How frequently we cleanup connections (closing idle connections, calling connect or request - * failure callbacks). Making this slower results in more "overshoot" before failure callbacks - * are invoked; making it too fast results in more proxy thread overhead. Any change to this - * variable must be set before calling start(). - */ - std::chrono::milliseconds CONN_CHECK_INTERVAL = 250ms; - - /** Whether to enable heartbeats on incoming/outgoing connections. If set to > 0 then we set up - * ZMQ to send a heartbeat ping over the socket this often, which helps keep the connection - * alive and lets failed connections be detected sooner (see the next option). - * - * Only new connections created after changing this are affected, so if changing it is - * recommended to set it before calling `start()`. - */ - std::chrono::milliseconds CONN_HEARTBEAT = 15s; - - /** When CONN_HEARTBEAT is enabled, this sets how long we wait for a reply on a socket before - * considering the socket to have died and closing it. - * - * Only new connections created after changing this are affected, so if changing it is - * recommended to set it before calling `start()`. - */ - std::chrono::milliseconds CONN_HEARTBEAT_TIMEOUT = 30s; - - /// Allows you to set options on the internal zmq context object. For advanced use only. - void set_zmq_context_option(zmq::ctxopt option, int value); - - /** The umask to apply when constructing sockets (which affects any new ipc:// listening sockets - * that get created). Does nothing if set to -1 (the default), and does nothing on Windows. - * Note that the umask is applied temporarily during `start()`, so may affect other threads that - * create files/directories at the same time as the start() call. - */ - int STARTUP_UMASK = -1; - - /** The gid that owns any sockets when constructed (same as umask) - */ - int SOCKET_GID = -1; - /** The uid that owns any sockets when constructed (same as umask but requires root) - */ - int SOCKET_UID = -1; - - /// A special TaggedThreadID value that always refers to the proxy thread; the main use of this is - /// to direct very simple batch completion jobs to be executed directly in the proxy thread. - inline static constexpr TaggedThreadID run_in_proxy{-1}; - - /// Writes a message to the logging system; intended mostly for internal use. - template - void log(LogLevel lvl, const char* filename, int line, const T&... stuff); - -private: - - /// The lookup function that tells us where to connect to a peer, or empty if not found. - SNRemoteAddress sn_lookup; - - /// The log level; this is atomic but we use relaxed order to set and access it (so changing it - /// might not be instantly visible on all threads, but that's okay). - std::atomic log_lvl{LogLevel::warn}; - - /// The callback to call with log messages - Logger logger; - - /////////////////////////////////////////////////////////////////////////////////// - /// NB: The following are all the domain of the proxy thread (once it is started)! - - /// The socket we listen on for handling ZAP authentication requests (the other end is internal - /// to zmq which sends requests to us as needed). - zmq::socket_t zap_auth{context, zmq::socket_type::rep}; - - struct bind_data { - bool curve; - size_t index; - AllowFunc allow; - bind_data(bool curve, AllowFunc allow) - : curve{curve}, index{0}, allow{std::move(allow)} {} - }; - - /// Addresses on which we are listening (or, before start(), on which we will listen). - std::vector> bind; - - /// Info about a peer's established connection with us. Note that "established" means both - /// connected and authenticated. Note that we only store peer info data for SN connections (in - /// or out), and outgoing non-SN connections. Incoming non-SN connections are handled on the - /// fly. - struct peer_info { - /// Pubkey of the remote, if this connection is a curve25519 connection; empty otherwise. - std::string pubkey; - - /// True if we've authenticated this peer as a service node. This gets set on incoming - /// messages when we check the remote's pubkey, and immediately on outgoing connections to - /// SNs (since we know their pubkey -- we'll fail to connect if it doesn't match). - bool service_node = false; - - /// The auth level of this peer, as returned by the AllowFunc for incoming connections or - /// specified during outgoing connections. - AuthLevel auth_level = AuthLevel::none; - - /// The actual internal socket index through which this connection is established - size_t conn_index; - - /// Will be set to a non-empty routing prefix *if* one is necessary on the connection. This - /// is used only for SN peers (non-SN incoming connections don't have a peer_info record, - /// and outgoing connections don't have a route). - std::string route; - - /// Returns true if this is an outgoing connection. (This is simply an alias for - /// route.empty() -- outgoing connections never have a route, incoming connections always - /// do). - bool outgoing() const { return route.empty(); } - - /// The last time we sent or received a message (or had some other relevant activity) on - /// this connection. Used for closing outgoing connections that have reached an inactivity - /// expiry time (closing inactive conns for incoming connections is done by the other end). - std::chrono::steady_clock::time_point last_activity; - - /// Updates last_activity to the current time - void activity() { last_activity = std::chrono::steady_clock::now(); } - - /// After more than this much inactivity we will close an idle (outgoing) connection - std::chrono::milliseconds idle_expiry; - }; - - /// Currently peer connections: id -> peer_info. The ID is as returned by connect_remote or a - /// SN pubkey string. - std::unordered_multimap peers; - - /// Maps connection indices (which can change) to ConnectionID values (which are permanent). - /// This is primarily for outgoing sockets, but incoming sockets are here too (with empty-route - /// (and thus unroutable) ConnectionIDs). - std::vector conn_index_to_id; - - /// Maps listening socket ConnectionIDs to connection index values (these don't have peers - /// entries). The keys here have empty routes (and thus aren't actually routable). - std::unordered_map incoming_conn_index; - - /// The next ConnectionID value we should use (for non-SN connections). - std::atomic next_conn_id{1}; - - /// Remotes we are still trying to connect to (via connect_remote(), not connect_sn()); when - /// we pass handshaking we move them out of here and (if set) trigger the on_connect callback. - /// Unlike regular node-to-node peers, these have an extra "HI"/"HELLO" sequence that we used - /// before we consider ourselves connected to the remote. - std::list> pending_connects; - - /// Pending requests that have been sent out but not yet received a matching "REPLY". The value - /// is the timeout timestamp. - std::unordered_map> - pending_requests; - - /// different polling sockets the proxy handler polls: this always contains some internal - /// sockets for inter-thread communication followed by a pollitem for every connection (both - /// incoming and outgoing) in `connections`. We rebuild this from `connections` whenever - /// `pollitems_stale` is set to true. - std::vector pollitems; - - /// If set then rebuild pollitems before the next poll (set when establishing new connections or - /// closing existing ones). - bool pollitems_stale = true; - - /// Rebuilds pollitems to include the internal sockets + all incoming/outgoing sockets. - void rebuild_pollitems(); - - /// The connections to/from remotes we currently have open, both listening and outgoing. Each - /// element [i] here corresponds to an the pollitem_t at pollitems[i+1+poll_internal_size]. - /// (Ideally we'd use one structure, but zmq requires the pollitems be in contiguous storage). - std::vector connections; - - /// Socket we listen on to receive control messages in the proxy thread. Each thread has its own - /// internal "control" connection (returned by `get_control_socket()`) to this socket used to - /// give instructions to the proxy such as instructing it to initiate a connection to a remote - /// or send a message. - zmq::socket_t command{context, zmq::socket_type::router}; - - /// Timers. TODO: once cppzmq adds an interface around the zmq C timers API then switch to it. - struct TimersDeleter { void operator()(void* timers); }; - struct timer_data { std::function function; bool squelch; bool running; int thread; }; - std::unordered_map timer_jobs; - std::unique_ptr timers; -public: - // This needs to be public because we have to be able to call it from a plain C function. - // Nothing external may call it! - void _queue_timer_job(int); -private: - - /// Router socket to reach internal worker threads from proxy - zmq::socket_t workers_socket{context, zmq::socket_type::router}; - - /// indices of idle, active workers - std::vector idle_workers; - - /// Maximum number of general task workers, specified by g`/during construction - int general_workers = std::max(1, std::thread::hardware_concurrency()); - - /// Maximum number of possible worker threads we can have. This is calculated when starting, - /// and equals general_workers plus the sum of all categories' reserved threads counts plus the - /// reserved batch workers count. This is also used to signal a shutdown; we set it to 0 when - /// quitting. - int max_workers; - - /// Number of active workers - int active_workers() const { return workers.size() - idle_workers.size(); } - - /// Worker thread loop. Tagged and start are provided for a tagged worker thread. - void worker_thread(unsigned int index, std::optional tagged = std::nullopt, std::function start = nullptr); - - /// If set, skip polling for one proxy loop iteration (set when we know we have something - /// processible without having to shove it onto a socket, such as scheduling an internal job). - bool proxy_skip_one_poll = false; - - /// Does the proxying work - void proxy_loop(); - - void proxy_conn_cleanup(); - - void proxy_worker_message(std::vector& parts); - - void proxy_process_queue(); - - void proxy_schedule_reply_job(std::function f); - - /// Looks up a peers element given a connect index (for outgoing connections where we already - /// knew the pubkey and SN status) or an incoming zmq message (which has the pubkey and sn - /// status metadata set during initial connection authentication), creating a new peer element - /// if required. - decltype(peers)::iterator proxy_lookup_peer(int conn_index, zmq::message_t& msg); - - /// Handles built-in primitive commands in the proxy thread for things like "BYE" that have to - /// be done in the proxy thread anyway (if we forwarded to a worker the worker would just have - /// to send an instruction back to the proxy to do it). Returns true if one was handled, false - /// to continue with sending to a worker. - bool proxy_handle_builtin(size_t conn_index, std::vector& parts); - - struct run_info; - /// Gets an idle worker's run_info and removes the worker from the idle worker list. If there - /// is no idle worker this creates a new `workers` element for a new worker (and so you should - /// only call this if new workers are permitted). Note that if this creates a new work info the - /// worker will *not* yet be started, so the caller must create the thread (in `.thread`) after - /// setting up the job if `.thread.joinable()` is false. - run_info& get_idle_worker(); - - /// Runs the worker; called after the `run` object has been set up. If the worker thread hasn't - /// been created then it is spawned; otherwise it is sent a RUN command. - void proxy_run_worker(run_info& run); - - /// Sets up a job for a worker then signals the worker (or starts a worker thread) - void proxy_to_worker(size_t conn_index, std::vector& parts); - - /// proxy thread command handlers for commands sent from the outer object QUIT. This doesn't - /// get called immediately on a QUIT command: the QUIT commands tells workers to quit, then this - /// gets called after all works have done so. - void proxy_quit(); - - // Common setup code for setting up an external (incoming or outgoing) socket. - void setup_external_socket(zmq::socket_t& socket); - - // Sets the various properties on an outgoing socket prior to connection. If remote_pubkey is - // provided then the connection will be curve25519 encrypted and authenticate; otherwise it will - // be unencrypted and unauthenticated. Note that the remote end must be in the same mode (i.e. - // either accepting curve connections, or not accepting curve). - void setup_outgoing_socket(zmq::socket_t& socket, std::string_view remote_pubkey = {}); - - /// Common connection implementation used by proxy_connect/proxy_send. Returns the socket and, - /// if a routing prefix is needed, the required prefix (or an empty string if not needed). For - /// an optional connect that fails (or some other connection failure), returns nullptr for the - /// socket. - /// - /// @param pubkey the pubkey to connect to - /// @param connect_hint if we need a new connection and this is non-empty then we *may* use it - /// instead of doing a call to `sn_lookup()`. - /// @param optional if we don't already have a connection then don't establish a new one - /// @param incoming_only only relay this if we have an established incoming connection from the - /// given SN, otherwise don't connect (like `optional`) - /// @param keep_alive the keep alive for the connection, if we establish a new outgoing - /// connection. If we already have an outgoing connection then its keep-alive gets increased to - /// this if currently less than this. - std::pair proxy_connect_sn(std::string_view pubkey, std::string_view connect_hint, - bool optional, bool incoming_only, bool outgoing_only, std::chrono::milliseconds keep_alive); - - /// CONNECT_SN command telling us to connect to a new pubkey. Returns the socket (which could - /// be existing or a new one). This basically just unpacks arguments and passes them on to - /// proxy_connect_sn(). - std::pair proxy_connect_sn(bt_dict_consumer data); - - /// Opens a new connection to a remote, with callbacks. This is the proxy-side implementation - /// of the `connect_remote()` call. - void proxy_connect_remote(bt_dict_consumer data); - - /// Called to disconnect our remote connection to the given id (if we have one). - void proxy_disconnect(bt_dict_consumer data); - void proxy_disconnect(ConnectionID conn, std::chrono::milliseconds linger); - - /// SEND command. Does a connect first, if necessary. - void proxy_send(bt_dict_consumer data); - - /// REPLY command. Like SEND, but only has a listening socket route to send back to and so is - /// weaker (i.e. it cannot reconnect to the SN if the connection is no longer open). - void proxy_reply(bt_dict_consumer data); - - /// Currently active batch/reply jobs; this is the container that owns the Batch instances - std::unordered_set batches; - /// Individual batch jobs waiting to run; .second is the 0-n batch number or -1 for the - /// completion job - using batch_job = std::pair; - std::queue batch_jobs, reply_jobs; - int batch_jobs_active = 0; - int reply_jobs_active = 0; - int batch_jobs_reserved = -1; - int reply_jobs_reserved = -1; - /// Runs any queued batch jobs - void proxy_run_batch_jobs(std::queue& jobs, int reserved, int& active, bool reply); - - /// BATCH command. Called with a Batch (see lokimq/batch.h) object pointer for the proxy to - /// take over and queue batch jobs. - void proxy_batch(detail::Batch* batch); - - /// TIMER command. Called with a serialized list containing: function pointer to assume - /// ownership of, an interval count (in ms), and whether or not jobs should be squelched (see - /// `add_timer()`). - void proxy_timer(bt_list_consumer timer_data); - - /// Same, but deserialized - void proxy_timer(std::function job, std::chrono::milliseconds interval, bool squelch, int thread); - - /// ZAP (https://rfc.zeromq.org/spec:27/ZAP/) authentication handler; this does non-blocking - /// processing of any waiting authentication requests for new incoming connections. - void process_zap_requests(); - - /// Handles a control message from some outer thread to the proxy - void proxy_control_message(std::vector& parts); - - /// Closing any idle connections that have outlived their idle time. Note that this only - /// affects outgoing connections; incomings connections are the responsibility of the other end. - void proxy_expire_idle_peers(); - - /// Helper method to actually close a remote connection and update the stuff that needs updating. - void proxy_close_connection(size_t removed, std::chrono::milliseconds linger); - - /// Closes an outgoing connection immediately, updates internal variables appropriately. - /// Returns the next iterator (the original may or may not be removed from peers, depending on - /// whether or not it also has an active incoming connection). - decltype(peers)::iterator proxy_close_outgoing(decltype(peers)::iterator it); - - struct category { - Access access; - std::unordered_map> commands; - unsigned int reserved_threads = 0; - unsigned int active_threads = 0; - int max_queue = 200; - int queued = 0; - - category(Access access, unsigned int reserved_threads, int max_queue) - : access{access}, reserved_threads{reserved_threads}, max_queue{max_queue} {} - }; - - /// Categories, mapped by category name. - std::unordered_map categories; - - /// For enabling backwards compatibility with command renaming: this allows mapping one command - /// to another in a different category (which happens before the category and command lookup is - /// done). - std::unordered_map command_aliases; - - using cat_call_t = std::pair*>; - /// Retrieve category and callback from a command name, including alias mapping. Warns on - /// invalid commands and returns nullptrs. The command name will be updated in place if it is - /// aliased to another command. - cat_call_t get_command(std::string& command); - - /// Checks a peer's authentication level. Returns true if allowed, warns and returns false if - /// not. - bool proxy_check_auth(size_t conn_index, bool outgoing, const peer_info& peer, - zmq::message_t& command, const cat_call_t& cat_call, std::vector& data); - - struct injected_task { - category& cat; - std::string command; - std::string remote; - std::function callback; - }; - - /// Injects a external callback to be handled by a worker; this is the proxy side of - /// inject_task(). - void proxy_inject_task(injected_task task); - - - /// Set of active service nodes. - pubkey_set active_service_nodes; - - /// Resets or updates the stored set of active SN pubkeys - void proxy_set_active_sns(std::string_view data); - void proxy_set_active_sns(pubkey_set pubkeys); - void proxy_update_active_sns(bt_list_consumer data); - void proxy_update_active_sns(pubkey_set added, pubkey_set removed); - void proxy_update_active_sns_clean(pubkey_set added, pubkey_set removed); - - /// Details for a pending command; such a command already has authenticated access and is just - /// waiting for a thread to become available to handle it. This also gets used (via the - /// `callback` variant) for injected external jobs to be able to integrate some external - /// interface with the lokimq job queue. - struct pending_command { - category& cat; - std::string command; - std::vector data_parts; - std::variant< - const std::pair*, // Normal command callback - std::function // Injected external callback - > callback; - ConnectionID conn; - Access access; - std::string remote; - - // Normal ctor for an actual lmq command being processed - pending_command(category& cat, std::string command, std::vector data_parts, - const std::pair* callback, ConnectionID conn, Access access, std::string remote) - : cat{cat}, command{std::move(command)}, data_parts{std::move(data_parts)}, - callback{callback}, conn{std::move(conn)}, access{std::move(access)}, remote{std::move(remote)} {} - - // Ctor for an injected external command. - pending_command(category& cat, std::string command, std::function callback, std::string remote) - : cat{cat}, command{std::move(command)}, callback{std::move(callback)}, remote{std::move(remote)} {} - }; - std::list pending_commands; - - - /// End of proxy-specific members - /////////////////////////////////////////////////////////////////////////////////// - - - /// Structure that contains the data for a worker thread - both the thread itself, plus any - /// transient data we are passing into the thread. - struct run_info { - bool is_batch_job = false; - bool is_reply_job = false; - bool is_tagged_thread_job = false; - bool is_injected = false; - - // resets the job type bools, above. - void reset() { is_batch_job = is_reply_job = is_tagged_thread_job = is_injected = false; } - - // If is_batch_job is false then these will be set appropriate (if is_batch_job is true then - // these shouldn't be accessed and likely contain stale data). Note that if the command is - // an external, injected command then conn, access, conn_route, and data_parts will be - // empty/default constructed. - category *cat; - std::string command; - ConnectionID conn; // The connection (or SN pubkey) to reply on/to. - Access access; // The access level of the invoker (actual level, can be higher than the command's requirement) - std::string remote; // The remote address from which we received the request. - std::string conn_route; // if non-empty this is the reply routing prefix (for incoming connections) - std::vector data_parts; - - // If is_batch_job true then these are set (if is_batch_job false then don't access these!): - int batch_jobno; // >= 0 for a job, -1 for the completion job - - // The callback or batch job to run. The first of these is for regular tasks, the second - // for batch jobs, the third for injected external tasks. - std::variant< - const std::pair*, - detail::Batch*, - std::function - > to_run; - - // These belong to the proxy thread and must not be accessed by a worker: - std::thread worker_thread; - size_t worker_id; // The index in `workers` (0-n) or index+1 in `tagged_workers` (1-n) - std::string worker_routing_id; // "w123" where 123 == worker_id; "n123" for tagged threads. - - /// Loads the run info with an incoming command - run_info& load(category* cat, std::string command, ConnectionID conn, Access access, std::string remote, - std::vector data_parts, const std::pair* callback); - - /// Loads the run info with an injected external command - run_info& load(category* cat, std::string command, std::string remote, std::function callback); - - /// Loads the run info with a stored pending command - run_info& load(pending_command&& pending); - - /// Loads the run info with a batch job - run_info& load(batch_job&& bj, bool reply_job = false, int tagged_thread = 0); - }; - /// Data passed to workers for the RUN command. The proxy thread sets elements in this before - /// sending RUN to a worker then the worker uses it to get call info, and only allocates it - /// once, before starting any workers. Workers may only access their own index and may not - /// change it. - std::vector workers; - - /// Workers that are reserved for tagged thread tasks (as created with add_tagged_thread). The - /// queue here is similar to worker_jobs, but contains only the tagged thread's jobs. The bool - /// is whether the worker is currently busy (true) or available (false). - std::vector>> tagged_workers; - -public: - /** - * LokiMQ constructor. This constructs the object but does not start it; you will typically - * want to first add categories and commands, then finish startup by invoking `start()`. - * (Categories and commands cannot be added after startup). - * - * @param pubkey the public key (32-byte binary string). For a service node this is the service - * node x25519 keypair. For non-service nodes this (and privkey) can be empty strings to - * automatically generate an ephemeral keypair. - * - * @param privkey the service node's private key (32-byte binary string), or empty to generate - * one. - * - * @param service_node - true if this instance should be considered a service node for the - * purpose of allowing "Access::local_sn" remote calls. (This should be true if we are - * *capable* of being a service node, whether or not we are currently actively). If specified - * as true then the pubkey and privkey values must not be empty. - * - * @param sn_lookup function that takes a pubkey key (32-byte binary string) and returns a - * connection string such as "tcp://1.2.3.4:23456" to which a connection should be established - * to reach that service node. Note that this function is only called if there is no existing - * connection to that service node, and that the function is never called for a connection to - * self (that uses an internal connection instead). Also note that the service node must be - * listening in curve25519 mode (otherwise we couldn't verify its authenticity). Should return - * empty for not found or if SN lookups are not supported. - * - * @param allow_incoming is a callback that LokiMQ can use to determine whether an incoming - * connection should be allowed at all and, if so, whether the connection is from a known - * service node. Called with the connecting IP, the remote's verified x25519 pubkey, and the - * called on incoming connections with the (verified) incoming connection - * pubkey (32-byte binary string) to determine whether the given SN should be allowed to - * connect. - * - * @param log a function or callable object that writes a log message. If omitted then all log - * messages are suppressed. - * - * @param level the initial log level; defaults to warn. The log level can be changed later by - * calling log_level(...). - */ - LokiMQ( std::string pubkey, - std::string privkey, - bool service_node, - SNRemoteAddress sn_lookup, - Logger logger = [](LogLevel, const char*, int, std::string) { }, - LogLevel level = LogLevel::warn); - - /** - * Simplified LokiMQ constructor for a non-listening client or simple listener without any - * outgoing SN connection lookup capabilities. The LokiMQ object will not be able to establish - * new connections (including reconnections) to service nodes by pubkey. - */ - explicit LokiMQ( - Logger logger = [](LogLevel, const char*, int, std::string) { }, - LogLevel level = LogLevel::warn) - : LokiMQ("", "", false, [](auto) { return ""s; /*no peer lookups*/ }, std::move(logger), level) {} - - /** - * Destructor; instructs the proxy to quit. The proxy tells all workers to quit, waits for them - * to quit and rejoins the threads then quits itself. The outer thread (where the destructor is - * running) rejoins the proxy thread. - */ - ~LokiMQ(); - - /// Sets the log level of the LokiMQ object. - void log_level(LogLevel level); - - /// Gets the log level of the LokiMQ object. - LogLevel log_level() const; - - /** - * Add a new command category. This method may not be invoked after `start()` has been called. - * This method is also not thread safe, and is generally intended to be called (along with - * add_command) immediately after construction and immediately before calling start(). - * - * @param name - the category name which must consist of one or more characters and may not - * contain a ".". - * - * @param access_level the access requirements for remote invocation of the commands inside this - * category. - * - * @param reserved_threads if non-zero then the worker thread pool will ensure there are at at - * least this many threads either current processing or available to process commands in this - * category. This is used to ensure that a category's commands can be invoked even if - * long-running commands in some other category are currently using all worker threads. This - * can increase the number of worker threads above the `general_workers` parameter given in the - * constructor, but will only do so if the need arised: that is, if a command request arrives - * for a category when all workers are busy and no worker is currently processing any command in - * that category. - * - * @param max_queue is the maximum number of incoming messages in this category that we will - * queue up when waiting for a worker to become available for this category. Once the queue for - * a category exceeds this many incoming messages then new messages will be dropped until some - * messages are processed off the queue. -1 means unlimited, 0 means we will never queue (which - * means just dropping messages for this category if no workers are available to instantly - * handle the request). - * - * @returns a CatHelper object that makes adding commands slightly less verbose (see the - * CatHelper describe, above). - */ - CatHelper add_category(std::string name, Access access_level, unsigned int reserved_threads = 0, int max_queue = 200); - - /** - * Adds a new command to an existing category. This method may not be invoked after `start()` - * has been called. - * - * @param category - the category name (must already be created by a call to `add_category`) - * - * @param name - the command name, without the `category.` prefix. - * - * @param callback - a callable object which is callable as `callback(zeromq::Message &)` - */ - void add_command(const std::string& category, std::string name, CommandCallback callback); - - /** - * Adds a new "request" command to an existing category. These commands are just like normal - * commands, but are expected to call `msg.send_reply()` with any data parts on every request, - * while normal commands are more general. - * - * Parameters given here are identical to `add_command()`. - */ - void add_request_command(const std::string& category, std::string name, CommandCallback callback); - - /** - * Adds a command alias; this is intended for temporary backwards compatibility: if any aliases - * are defined then every command (not just aliased ones) has to be checked on invocation to see - * if it is defined in the alias list. May not be invoked after `start()`. - * - * Aliases should follow the `category.command` format for both the from and to names, and - * should only be called for `to` categories that are already defined. The category name is not - * currently enforced on the `from` name (for backwards compatility with Loki's quorumnet code) - * but will be at some point. - * - * Access permissions for an aliased command depend only on the mapped-to value; for example, if - * `cat.meow` is aliased to `dog.bark` then it is the access permissions on `dog` that apply, - * not those of `cat`, even if `cat` is more restrictive than `dog`. - */ - void add_command_alias(std::string from, std::string to); - - /** Creates a "tagged thread" and starts it immediately. A tagged thread is one that batches, - * jobs, and timer jobs can be sent to specifically, typically to perform coordination of some - * thread-unsafe work. - * - * Tagged threads will *only* process jobs sent specifically to them; they do not participate in - * the thread pool used for regular jobs. Each tagged thread also has its own job queue - * completely separate from any other jobs. - * - * Tagged threads must be created *before* `start()` is called. The name will be used to set the - * thread name in the process table (if supported on the OS). - * - * \param name - the name of the thread; will be used in log messages and (if supported by the - * OS) as the system thread name. - * - * \param start - an optional callback to invoke from the thread as soon as LokiMQ itself starts - * up (i.e. after a call to `start()`). - * - * \returns a TaggedThreadID object that can be passed to job(), batch(), or add_timer() to - * direct the task to the tagged thread. - */ - TaggedThreadID add_tagged_thread(std::string name, std::function start = nullptr); - - /** - * Sets the number of worker threads reserved for batch jobs. If not explicitly called then - * this defaults to half the general worker threads configured (rounded up). This works exactly - * like reserved_threads for a category, but allows to batch jobs. See category for details. - * - * Note that some internal jobs are counted as batch jobs: in particular timers added via - * add_timer() are scheduled as batch jobs. - * - * Cannot be called after start()ing the LokiMQ instance. - */ - void set_batch_threads(int threads); - - /** - * Sets the number of worker threads reserved for handling replies from servers; this is - * mostly for responses to `request()` calls, but also gets used for other network-related - * events such as the ConnectSuccess/ConnectFailure callbacks for establishing remote non-SN - * connections. - * - * Defaults to one-eighth of the number of configured general threads, rounded up. - * - * Cannot be changed after start()ing the LokiMQ instance. - */ - void set_reply_threads(int threads); - - /** - * Sets the number of general worker threads. This is the target number of threads to run that - * we generally try not to exceed. These threads can be used for any command, and will be - * created (up to the limit) on demand. Note that individual categories (or batch jobs) with - * reserved threads can create threads in addition to the amount specified here if necessary to - * fulfill the reserved threads count for the category. - * - * Adjusting this also adjusts the default values of batch and reply threads, above. - * - * Defaults to `std::thread::hardware_concurrency()`. - * - * Cannot be called after start()ing the LokiMQ instance. - */ - void set_general_threads(int threads); - - /** - * Finish starting up: binds to the bind locations given in the constructor and launches the - * proxy thread to handle message dispatching between remote nodes and worker threads. - * - * Things you want to do before calling this: - * - Use `add_category`/`add_command` to set up any commands remote connections can invoke. - * - If any commands require SN authentication, specify a list of currently active service node - * pubkeys via `set_active_sns()` (and make sure this gets updated when things change by - * another `set_active_sns()` or a `update_active_sns()` call). It *is* possible to make the - * initial call after calling `start()`, but that creates a window during which incoming - * remote SN connections will be erroneously treated as non-SN connections. - * - If this LMQ instance should accept incoming connections, set up any listening ports via - * `listen_curve()` and/or `listen_plain()`. - */ - void start(); - - /** Start listening on the given bind address using curve authentication/encryption. Incoming - * connections will only be allowed from clients that already have the server's pubkey, and - * will be encrypted. `allow_connection` is invoked for any incoming connections on this - * address to determine the incoming remote's access and authentication level. - * - * @param bind address - can be any string zmq supports; typically a tcp IP/port combination - * such as: "tcp://\*:4567" or "tcp://1.2.3.4:5678". - * - * @param allow_connection function to call to determine whether to allow the connection and, if - * so, the authentication level it receives. If omitted the default returns AuthLevel::none - * access. - */ - void listen_curve(std::string bind, AllowFunc allow_connection = [](auto, auto, auto) { return AuthLevel::none; }); - - /** Start listening on the given bind address in unauthenticated plain text mode. Incoming - * connections can come from anywhere. `allow_connection` is invoked for any incoming - * connections on this address to determine the incoming remote's access and authentication - * level. Note that `allow_connection` here will be called with an empty pubkey. - * - * @param bind address - can be any string zmq supports; typically a tcp IP/port combination - * such as: "tcp://\*:4567" or "tcp://1.2.3.4:5678". - * - * @param allow_connection function to call to determine whether to allow the connection and, if - * so, the authentication level it receives. If omitted the default returns AuthLevel::none - * access. - */ - void listen_plain(std::string bind, AllowFunc allow_connection = [](auto, auto, auto) { return AuthLevel::none; }); - - /** - * Try to initiate a connection to the given SN in anticipation of needing a connection in the - * future. If a connection is already established, the connection's idle timer will be reset - * (so that the connection will not be closed too soon). If the given idle timeout is greater - * than the current idle timeout then the timeout increases to the new value; if less than the - * current timeout it is ignored. (Note that idle timeouts only apply if the existing - * connection is an outgoing connection). - * - * Note that this method (along with send) doesn't block waiting for a connection; it merely - * instructs the proxy thread that it should establish a connection. - * - * @param pubkey - the public key (32-byte binary string) of the service node to connect to - * @param keep_alive - the connection will be kept alive if there was valid activity within - * the past `keep_alive` milliseconds. If an outgoing connection already - * exists, the longer of the existing and the given keep alive is used. - * (Note that the default applied here is much longer than the default for an - * implicit connect() by calling send() directly.) - * @param hint - if non-empty and a new outgoing connection needs to be made this hint value - * may be used instead of calling the lookup function. (Note that there is no - * guarantee that the hint will be used; it is only usefully specified if the - * connection address has already been incidentally determined). - * - * @returns a ConnectionID that identifies an connection with the given SN. Typically you - * *don't* need to worry about this (and can just discard it): you can always simply pass the - * pubkey as a string wherever a ConnectionID is called. - */ - ConnectionID connect_sn(std::string_view pubkey, std::chrono::milliseconds keep_alive = 5min, std::string_view hint = {}); - - /** - * Establish a connection to the given remote with callbacks invoked on a successful or failed - * connection. Returns a ConnectionID associated with the connection being attempted. It is - * possible to send to the remote before the successful callback is invoked, but there is no - * guarantee that the messages will be delivered (e.g. if the connection ultimately fails). - * - * For connections to a service node you generally want connect_sn() instead (which verifies - * that it is talking to the SN and encrypts the connection). - * - * Unlike `connect_sn`, the connection established here will be kept open indefinitely (until - * you call disconnect). - * - * The `on_connect` and `on_failure` callbacks are invoked when a connection has been - * established or failed to establish. - * - * @param remote the remote connection address either as implicitly from a string or as a full - * lokimq::address object; see address.h for details. This specifies both the connection - * address and whether curve encryption should be used. - * @param on_connect called with the identifier after the connection has been established. - * @param on_failure called with the identifier and failure message if we fail to connect. - * @param auth_level determines the authentication level of the remote for issuing commands to - * us. The default is `AuthLevel::none`. - * @param timeout how long to try before aborting the connection attempt and calling the - * on_failure callback. Note that the connection can fail for various reasons before the - * timeout. - * - * @param returns ConnectionID that uniquely identifies the connection to this remote node. In - * order to talk to it you will need the returned value (or a copy of it). - */ - ConnectionID connect_remote(const address& remote, ConnectSuccess on_connect, ConnectFailure on_failure, - AuthLevel auth_level = AuthLevel::none, std::chrono::milliseconds timeout = REMOTE_CONNECT_TIMEOUT); - - /// Same as the above, but takes the address as a string_view and constructs an `address` from - /// it. - ConnectionID connect_remote(std::string_view remote, ConnectSuccess on_connect, ConnectFailure on_failure, - AuthLevel auth_level = AuthLevel::none, std::chrono::milliseconds timeout = REMOTE_CONNECT_TIMEOUT) { - return connect_remote(address{remote}, std::move(on_connect), std::move(on_failure), auth_level, timeout); - } - - /// Deprecated version of the above that takes the remote address and remote pubkey for curve - /// encryption as separate arguments. New code should either use a pubkey-embedded address - /// string, or specify remote address and pubkey with an `address` object such as: - /// connect_remote(address{remote, pubkey}, ...) - [[deprecated("use connect_remote() with a lokimq::address instead")]] - ConnectionID connect_remote(std::string_view remote, ConnectSuccess on_connect, ConnectFailure on_failure, - std::string_view pubkey, - AuthLevel auth_level = AuthLevel::none, - std::chrono::milliseconds timeout = REMOTE_CONNECT_TIMEOUT); - - /** - * Disconnects an established outgoing connection established with `connect_remote()` (or, less - * commonly, `connect_sn()`). - * - * @param id the connection id, as returned by `connect_remote()` or the SN pubkey. - * - * @param linger how long to allow the connection to linger while there are still pending - * outbound messages to it before disconnecting and dropping any pending messages. (Note that - * this lingering is internal; the disconnect_remote() call does not block). The default is 1 - * second. - * - * If given a pubkey, we try to close an outgoing connection to the given SN if one exists; note - * however that this is often not particularly useful as messages to that SN can immediately - * reopen the connection. - */ - void disconnect(ConnectionID id, std::chrono::milliseconds linger = 1s); - - /** - * Queue a message to be relayed to the given service node or remote without requiring a reply. - * LokiMQ will attempt to relay the message (first connecting and handshaking to the remote SN - * if not already connected). - * - * If a new connection is established it will have a relatively short (30s) idle timeout. If - * the connection should stay open longer you should either call `connect(pubkey, IDLETIME)` or - * pass a a `send_option::keep_alive{IDLETIME}` in `opts`. - * - * Note that this method (along with connect) doesn't block waiting for a connection or for the - * message to send; it merely instructs the proxy thread that it should send. ZMQ will - * generally try hard to deliver it (reconnecting if the connection fails), but if the - * connection fails persistently the message will eventually be dropped. - * - * @param remote - either a ConnectionID value returned by connect_remote, or a service node - * pubkey string. In the latter case, sending the message may trigger a new - * connection being established to the service node (i.e. you do not have to - * call connect() first). - * @param cmd - the first data frame value which is almost always the remote "category.command" name - * @param opts - any number of std::string (or string_views) and send options. Each send option - * affects how the send works; each string becomes a message part. - * - * Example: - * - * // Send to a SN, connecting to it if we aren't already connected: - * lmq.send(pubkey, "hello.world", "abc", send_option::hint("tcp://localhost:1234"), "def"); - * - * // Start connecting to a remote and immediately queue a message for it - * auto conn = lmq.connect_remote("tcp://127.0.0.1:1234", - * [](ConnectionID) { std::cout << "connected\n"; }, - * [](ConnectionID, string_view why) { std::cout << "connection failed: " << why << \n"; }); - * lmq.send(conn, "hello.world", "abc", "def"); - * - * Both of these send the command `hello.world` to the given pubkey, containing additional - * message parts "abc" and "def". In the first case, if not currently connected, the given - * connection hint may be used rather than performing a connection address lookup on the pubkey. - */ - template - void send(ConnectionID to, std::string_view cmd, const T&... opts); - - /** Send a command configured as a "REQUEST" command to a service node: the data parts will be - * prefixed with a random identifier. The remote is expected to reply with a ["REPLY", - * , ...] message, at which point we invoke the given callback with any [...] parts - * of the reply. - * - * Like `send()`, a new connection to the service node will be established if not already - * connected. - * - * @param to - the pubkey string or ConnectionID to send this request to - * @param cmd - the command name - * @param callback - the callback to invoke when we get a reply. Called with a true value and - * the data strings when a reply is received, or false with error string(s) indicating the - * failure reason upon failure or timeout. - * @param opts - anything else (i.e. strings, send_options) is forwarded to send(). - * - * Possible error data values: - * - ["TIMEOUT"] - we got no reply within the timeout window - * - ["UNKNOWNCOMMAND"] - the remote did not recognize the given request command - * - ["NO_REPLY_TAG"] - the invoked command is a request command but no reply tag was included - * - ["FORBIDDEN"] - the command requires an authorization level (e.g. Basic or Admin) that we - * do not have. - * - ["FORBIDDEN_SN"] - the command requires service node authentication, but the remote did not - * recognize us as a service node. You *may* want to retry the request a limited number of - * times (but do not retry indefinitely as that can be an infinite loop!) because this is - * typically also followed by a disconnection; a retried message would reconnect and - * reauthenticate which *may* result in picking up the SN authentication. - * - ["NOT_A_SERVICE_NODE"] - this command is only invokable on service nodes, and the remote is - * not running as a service node. - */ - template - void request(ConnectionID to, std::string_view cmd, ReplyCallback callback, const T&... opts); - - /** Injects an external task into the lokimq command queue. This is used to allow connecting - * non-LokiMQ requests into the LokiMQ thread pool as if they were ordinary requests, to be - * scheduled as commands of an individual category. For example, you might support rpc requests - * via LokiMQ as `rpc.some_command` and *also* accept them over HTTP. Using `inject_task()` - * allows you to handle processing the request in the same thread pool with the same priority as - * `rpc.*` commands. - * - * @param category - the category name that should handle the request for the purposes of - * scheduling the job. The category must have been added using add_category(). The category - * can be an actual category with added commands, in which case the injected tasks are queued - * along with LMQ requests for that category, or can have no commands to set up a distinct - * category for the injected jobs. - * - * @param command - a command name; this is mainly used for debugging and does not need to - * actually exist (and, in fact, is often less confusing if it does not). It is recommended for - * clarity purposes to use something that doesn't look like a typical command, for example - * "(http)". - * - * @param remote - some free-form identifier of the remote connection. For example, this could - * be a remote IP address. Can be blank if there is nothing suitable. - * - * @param callback - the function to call from a worker thread when the injected task is - * processed. Takes no arguments. - */ - void inject_task(const std::string& category, std::string command, std::string remote, std::function callback); - - /// The key pair this LokiMQ was created with; if empty keys were given during construction then - /// this returns the generated keys. - const std::string& get_pubkey() const { return pubkey; } - const std::string& get_privkey() const { return privkey; } - - /** Updates (or initially sets) LokiMQ's list of service node pubkeys with the given list. - * - * This has two main effects: - * - * - All commands processed after the update will have SN status determined by the new list. - * - All outgoing connections to service nodes that are no longer on the list will be closed. - * This includes both explicit connections (established by `connect_sn()`) and implicit ones - * (established by sending to a SN that wasn't connected). - * - * As this update is potentially quite heavy it is recommended that this be called only when - * necessary--i.e. when the list has changed (or potentially changed), but *not* on a short - * periodic timer. - * - * This method may (and should!) be called before start() to load an initial set of SNs. - * - * Once a full list has been set, updates on changes can either call this again with the new - * list, or use the more efficient update_active_sns() call if incremental results are - * available. - */ - void set_active_sns(pubkey_set pubkeys); - - /** Updates the list of active pubkeys by adding or removing the given pubkeys from the existing - * list. This is more efficient when the incremental information is already available; if it - * isn't, simply call set_active_sns with a new list to have LokiMQ figure out what was added or - * removed. - * - * \param added new pubkeys that were added since the last set_active_sns or update_active_sns - * call. - * - * \param removed pubkeys that were removed from active SN status since the last call. If a - * pubkey is in both `added` and `removed` for some reason then its presence in `removed` will - * be ignored. - */ - void update_active_sns(pubkey_set added, pubkey_set removed); - - /** - * Batches a set of jobs to be executed by workers, optionally followed by a completion function. - * - * Must include lokimq/batch.h to use. - */ - template - void batch(Batch&& batch); - - /** - * Queues a single job to be executed with no return value. This is a shortcut for creating and - * submitting a single-job, no-completion-function batch job. - * - * \param f the callback to invoke - * \param thread an optional tagged thread in which this job should run. You may *not* pass the - * proxy thread here. - */ - void job(std::function f, std::optional = std::nullopt); - - /** - * Adds a timer that gets scheduled periodically in the job queue. Normally jobs are not - * double-booked: that is, a new timed job will not be scheduled if the timer fires before a - * previously scheduled callback of the job has not yet completed. If you want to override this - * (so that, under heavy load or long jobs, there can be more than one of the same job scheduled - * or running at a time) then specify `squelch` as `false`. - * - * \param thread specifies a thread (added with add_tagged_thread()) on which this timer must run. - */ - void add_timer(std::function job, std::chrono::milliseconds interval, bool squelch = true, std::optional = std::nullopt); -}; - -/// Helper class that slightly simplifies adding commands to a category. -/// -/// This allows simplifying: -/// -/// lmq.add_category("foo", ...); -/// lmq.add_command("foo", "a", ...); -/// lmq.add_command("foo", "b", ...); -/// lmq.add_request_command("foo", "c", ...); -/// -/// to: -/// -/// lmq.add_category("foo", ...) -/// .add_command("a", ...) -/// .add_command("b", ...) -/// .add_request_command("b", ...) -/// ; -class CatHelper { - LokiMQ& lmq; - std::string cat; - -public: - CatHelper(LokiMQ& lmq, std::string cat) : lmq{lmq}, cat{std::move(cat)} {} - - CatHelper& add_command(std::string name, LokiMQ::CommandCallback callback) { - lmq.add_command(cat, std::move(name), std::move(callback)); - return *this; - } - - CatHelper& add_request_command(std::string name, LokiMQ::CommandCallback callback) { - lmq.add_request_command(cat, std::move(name), std::move(callback)); - return *this; - } -}; - - -/// Namespace for options to the send() method -namespace send_option { - -template -struct data_parts_impl { - InputIt begin, end; - data_parts_impl(InputIt begin, InputIt end) : begin{std::move(begin)}, end{std::move(end)} {} -}; - -/// Specifies an iterator pair of data parts to send, for when the number of arguments to send() -/// cannot be determined at compile time. The iterator pair must be over strings or string_view (or -/// something convertible to a string_view). -template ()), std::string_view>>> -data_parts_impl data_parts(InputIt begin, InputIt end) { return {std::move(begin), std::move(end)}; } - -/// Specifies a connection hint when passed in to send(). If there is no current connection to the -/// peer then the hint is used to save a call to the SNRemoteAddress to get the connection location. -/// (Note that there is no guarantee that the given hint will be used or that a SNRemoteAddress call -/// will not also be done.) -struct hint { - std::string connect_hint; - // Constructor taking a hint. If the hint is an empty string then no hint will be used. - explicit hint(std::string connect_hint) : connect_hint{std::move(connect_hint)} {} -}; - -/// Does a send() if we already have a connection (incoming or outgoing) with the given peer, -/// otherwise drops the message. -struct optional { - bool is_optional = true; - // Constructor; default construction gives you an optional, but the bool parameter can be - // specified as false to explicitly make a connection non-optional instead. - explicit optional(bool opt = true) : is_optional{opt} {} -}; - -/// Specifies that the message should be sent only if it can be sent on an existing incoming socket, -/// and dropped otherwise. -struct incoming { - bool is_incoming = true; - // Constructor; default construction gives you an incoming-only, but the bool parameter can be - // specified as false to explicitly disable incoming-only behaviour. - explicit incoming(bool inc = true) : is_incoming{inc} {} -}; - -/// Specifies that the message must use an outgoing connection; for messages to a service node the -/// message will be delivered over an existing outgoing connection, if one is established, and a new -/// outgoing connection opened to deliver the message if none is currently established. For non-SN -/// messages, the message will simply be dropped if it is attempting to be sent on an incoming -/// socket, and send otherwise on an outgoing socket (this option is primarily aimed at SN -/// messages). -struct outgoing { - bool is_outgoing = true; - // Constructor; default construction gives you an outgoing-only, but the bool parameter can be - // specified as false to explicitly disable the outgoing-only flag. - explicit outgoing(bool out = true) : is_outgoing{out} {} -}; - -/// Specifies the idle timeout for the connection - if a new or existing outgoing connection is used -/// for the send and its current idle timeout setting is less than this value then it is updated. -struct keep_alive { - std::chrono::milliseconds time; - explicit keep_alive(std::chrono::milliseconds time) : time{std::move(time)} {} -}; - -/// Specifies the amount of time to wait before triggering a failure callback for a request. If a -/// request reply arrives *after* the failure timeout has been triggered then it will be dropped. -/// (This has no effect if specified on a non-request() call). Note that requests failures are only -/// processed in the CONN_CHECK_INTERVAL timer, so it can be up to that much longer than the time -/// specified here before a failure callback is invoked. -struct request_timeout { - std::chrono::milliseconds time; - explicit request_timeout(std::chrono::milliseconds time) : time{std::move(time)} {} -}; - -/// Specifies a callback to invoke if the message couldn't be queued for delivery. There are -/// generally two failure modes here: a full queue, and a send exception. This callback is invoked -/// for both; to only catch full queues see `queue_full` instead. -/// -/// A full queue means there are too many messages queued for delivery already that haven't been -/// delivered yet (i.e. because the remote is slow); this error is potentially recoverable if the -/// remote end wakes up and receives/acknoledges its messages. -/// -/// A send exception is not recoverable: it indicates some failure such as the remote having -/// disconnected or an internal send error. -/// -/// This callback can be used by a caller to log, attempt to resend, or take other appropriate -/// action. -/// -/// Note that this callback is *not* exhaustive for all possible send failures: there are failure -/// cases (such as when a message is queued but the connection fails before delivery) that do not -/// trigger this failure at all; rather this callback only signals an immediate queuing failure. -struct queue_failure { - using callback_t = std::function; - /// Callback; invoked with nullptr for a queue full failure, otherwise will be set to a copy of - /// the raised exception. - callback_t callback; -}; - -/// This is similar to queue_failure_callback, but is only invoked on a (potentially recoverable) -/// full queue failure. Send failures are simply dropped. -struct queue_full { - using callback_t = std::function; - callback_t callback; -}; +using LokiMQ = OxenMQ; } - -namespace detail { - -/// Takes an rvalue reference, moves it into a new instance then returns a uintptr_t value -/// containing the pointer to be serialized to pass (via lokimq queues) from one thread to another. -/// Must be matched with a deserializer_pointer on the other side to reconstitute the object and -/// destroy the intermediate pointer. -template -uintptr_t serialize_object(T&& obj) { - static_assert(std::is_rvalue_reference::value, "serialize_object must be given an rvalue reference"); - auto* ptr = new T{std::forward(obj)}; - return reinterpret_cast(ptr); -} - -/// Takes a uintptr_t as produced by serialize_pointer and the type, converts the serialized value -/// back into a pointer, moves it into a new instance (to be returned) and destroys the -/// intermediate. -template T deserialize_object(uintptr_t ptrval) { - auto* ptr = reinterpret_cast(ptrval); - T ret{std::move(*ptr)}; - delete ptr; - return ret; -} - -// Sends a control message to the given socket consisting of the command plus optional dict -// data (only sent if the data is non-empty). -void send_control(zmq::socket_t& sock, std::string_view cmd, std::string data = {}); - -/// Base case: takes a string-like value and appends it to the message parts -inline void apply_send_option(bt_list& parts, bt_dict&, std::string_view arg) { - parts.emplace_back(arg); -} - -/// `data_parts` specialization: appends a range of serialized data parts to the parts to send -template -void apply_send_option(bt_list& parts, bt_dict&, const send_option::data_parts_impl data) { - for (auto it = data.begin; it != data.end; ++it) - parts.emplace_back(*it); -} - -/// `hint` specialization: sets the hint in the control data -inline void apply_send_option(bt_list&, bt_dict& control_data, const send_option::hint& hint) { - control_data["hint"] = hint.connect_hint; -} - -/// `optional` specialization: sets the optional flag in the control data -inline void apply_send_option(bt_list&, bt_dict& control_data, const send_option::optional& o) { - control_data["optional"] = o.is_optional; -} - -/// `incoming` specialization: sets the incoming-only flag in the control data -inline void apply_send_option(bt_list&, bt_dict& control_data, const send_option::incoming& i) { - control_data["incoming"] = i.is_incoming; -} - -/// `outgoing` specialization: sets the outgoing-only flag in the control data -inline void apply_send_option(bt_list&, bt_dict& control_data, const send_option::outgoing& o) { - control_data["outgoing"] = o.is_outgoing; -} - -/// `keep_alive` specialization: increases the outgoing socket idle timeout (if shorter) -inline void apply_send_option(bt_list&, bt_dict& control_data, const send_option::keep_alive& timeout) { - control_data["keep_alive"] = timeout.time.count(); -} - -/// `request_timeout` specialization: set the timeout time for a request -inline void apply_send_option(bt_list&, bt_dict& control_data, const send_option::request_timeout& timeout) { - control_data["request_timeout"] = timeout.time.count(); -} - -/// `queue_failure` specialization -inline void apply_send_option(bt_list&, bt_dict& control_data, send_option::queue_failure f) { - control_data["send_fail"] = serialize_object(std::move(f.callback)); -} - -/// `queue_full` specialization -inline void apply_send_option(bt_list&, bt_dict& control_data, send_option::queue_full f) { - control_data["send_full_q"] = serialize_object(std::move(f.callback)); -} - -/// Extracts a pubkey and auth level from a zmq message received on a *listening* socket. -std::pair extract_metadata(zmq::message_t& msg); - -template -bt_dict build_send(ConnectionID to, std::string_view cmd, T&&... opts) { - bt_dict control_data; - bt_list parts{{cmd}}; - (detail::apply_send_option(parts, control_data, std::forward(opts)),...); - - if (to.sn()) - control_data["conn_pubkey"] = std::move(to.pk); - else { - control_data["conn_id"] = to.id; - control_data["conn_route"] = std::move(to.route); - } - control_data["send"] = std::move(parts); - return control_data; - -} - -} // namespace detail - - -template -void LokiMQ::send(ConnectionID to, std::string_view cmd, const T&... opts) { - detail::send_control(get_control_socket(), "SEND", - bt_serialize(detail::build_send(std::move(to), cmd, opts...))); -} - -std::string make_random_string(size_t size); - -template -void LokiMQ::request(ConnectionID to, std::string_view cmd, ReplyCallback callback, const T &...opts) { - const auto reply_tag = make_random_string(15); // 15 random bytes is lots and should keep us in most stl implementations' small string optimization - bt_dict control_data = detail::build_send(std::move(to), cmd, reply_tag, opts...); - control_data["request"] = true; - control_data["request_callback"] = detail::serialize_object(std::move(callback)); - control_data["request_tag"] = std::string_view{reply_tag}; - detail::send_control(get_control_socket(), "SEND", bt_serialize(std::move(control_data))); -} - -template -void Message::send_back(std::string_view command, Args&&... args) { - lokimq.send(conn, command, send_option::optional{!conn.sn()}, std::forward(args)...); -} - -template -void Message::send_reply(Args&&... args) { - assert(!reply_tag.empty()); - lokimq.send(conn, "REPLY", reply_tag, send_option::optional{!conn.sn()}, std::forward(args)...); -} - -template -void Message::send_request(std::string_view cmd, Callback&& callback, Args&&... args) { - lokimq.request(conn, cmd, std::forward(callback), - send_option::optional{!conn.sn()}, std::forward(args)...); -} - -// When log messages are invoked we strip out anything before this in the filename: -constexpr std::string_view LOG_PREFIX{"lokimq/", 7}; -inline std::string_view trim_log_filename(std::string_view local_file) { - auto chop = local_file.rfind(LOG_PREFIX); - if (chop != local_file.npos) - local_file.remove_prefix(chop); - return local_file; -} - -template -void LokiMQ::log(LogLevel lvl, const char* file, int line, const T&... stuff) { - if (log_level() < lvl) - return; - - std::ostringstream os; - (os << ... << stuff); - logger(lvl, trim_log_filename(file).data(), line, os.str()); -} - -std::ostream &operator<<(std::ostream &os, LogLevel lvl); - -} // namespace lokimq - -// vim:sw=4:et diff --git a/lokimq/message.h b/lokimq/message.h index c7abff3..a0e8ad5 100644 --- a/lokimq/message.h +++ b/lokimq/message.h @@ -1,57 +1,4 @@ #pragma once -#include -#include "connections.h" +#include "../oxenmq/message.h" -namespace lokimq { - -class LokiMQ; - -/// Encapsulates an incoming message from a remote connection with message details plus extra -/// info need to send a reply back through the proxy thread via the `reply()` method. Note that -/// this object gets reused: callbacks should use but not store any reference beyond the callback. -class Message { -public: - LokiMQ& lokimq; ///< The owning LokiMQ object - std::vector data; ///< The provided command data parts, if any. - ConnectionID conn; ///< The connection info for routing a reply; also contains the pubkey/sn status. - std::string reply_tag; ///< If the invoked command is a request command this is the required reply tag that will be prepended by `send_reply()`. - Access access; ///< The access level of the invoker. This can be higher than the access level of the command, for example for an admin invoking a basic command. - std::string remote; ///< Some sort of remote address from which the request came. Often "IP" for TCP connections and "localhost:UID:GID:PID" for UDP connections. - - /// Constructor - Message(LokiMQ& lmq, ConnectionID cid, Access access, std::string remote) - : lokimq{lmq}, conn{std::move(cid)}, access{std::move(access)}, remote{std::move(remote)} {} - - // Non-copyable - Message(const Message&) = delete; - Message& operator=(const Message&) = delete; - - /// Sends a command back to whomever sent this message. Arguments are forwarded to send() but - /// with send_option::optional{} added if the originator is not a SN. For SN messages (i.e. - /// where `sn` is true) this is a "strong" reply by default in that the proxy will attempt to - /// establish a new connection to the SN if no longer connected. For non-SN messages the reply - /// will be attempted using the available routing information, but if the connection has already - /// been closed the reply will be dropped. - /// - /// If you want to send a non-strong reply even when the remote is a service node then add - /// an explicit `send_option::optional()` argument. - template - void send_back(std::string_view, Args&&... args); - - /// Sends a reply to a request. This takes no command: the command is always the built-in - /// "REPLY" command, followed by the unique reply tag, then any reply data parts. All other - /// arguments are as in `send_back()`. You should only send one reply for a command expecting - /// replies, though this is not enforced: attempting to send multiple replies will simply be - /// dropped when received by the remote. (Note, however, that it is possible to send multiple - /// messages -- e.g. you could send a reply and then also call send_back() and/or send_request() - /// to send more requests back to the sender). - template - void send_reply(Args&&... args); - - /// Sends a request back to whomever sent this message. This is effectively a wrapper around - /// lmq.request() that takes care of setting up the recipient arguments. - template - void send_request(std::string_view cmd, ReplyCallback&& callback, Args&&... args); -}; - -} +namespace lokimq = oxenmq; diff --git a/lokimq/string_view.h b/lokimq/string_view.h deleted file mode 100644 index a64e015..0000000 --- a/lokimq/string_view.h +++ /dev/null @@ -1,15 +0,0 @@ -#pragma once - -#include - -namespace lokimq { - -// Deprecated type alias for std::string_view -using string_view = std::string_view; - -// Deprecated "foo"_sv literal; you should use "foo"sv (from ) instead. -inline namespace literals { - inline constexpr std::string_view operator""_sv(const char* str, size_t len) { return {str, len}; } -} - -} diff --git a/lokimq/variant.h b/lokimq/variant.h index fb4c9fe..08cb791 100644 --- a/lokimq/variant.h +++ b/lokimq/variant.h @@ -1,103 +1,2 @@ #pragma once -// Workarounds for macos compatibility. On macOS we aren't allowed to touch anything in -// std::variant that could throw if compiling with a target <10.14 because Apple fails hard at -// properly updating their STL. Thus, if compiling in such a mode, we have to introduce -// workarounds. -// -// This header defines a `var` namespace with `var::get` and `var::visit` implementations. On -// everything except broken backwards macos, this is just an alias to `std`. On broken backwards -// macos, we provide implementations that throw std::runtime_error in failure cases since the -// std::bad_variant_access exception can't be touched. -// -// You also get a BROKEN_APPLE_VARIANT macro defined if targetting a problematic mac architecture. - -#include - -#ifdef __APPLE__ -# include -# if defined(__APPLE__) && MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_10_14 -# define BROKEN_APPLE_VARIANT -# endif -#endif - -#ifndef BROKEN_APPLE_VARIANT - -namespace var = std; // Oh look, actual C++17 support - -#else - -// Oh look, apple. - -namespace var { - -// Apple won't let us use std::visit or std::get if targetting some version of macos earlier than -// 10.14 because Apple is awful about not updating their STL. So we have to provide our own, and -// then call these without `std::` -- on crappy macos we'll come here, on everything else we'll ADL -// to the std:: implementation. -template -constexpr T& get(std::variant& var) { - if (auto* v = std::get_if(&var)) return *v; - throw std::runtime_error{"Bad variant access -- variant does not contain the requested type"}; -} -template -constexpr const T& get(const std::variant& var) { - if (auto* v = std::get_if(&var)) return *v; - throw std::runtime_error{"Bad variant access -- variant does not contain the requested type"}; -} -template -constexpr const T&& get(const std::variant&& var) { - if (auto* v = std::get_if(&var)) return std::move(*v); - throw std::runtime_error{"Bad variant access -- variant does not contain the requested type"}; -} -template -constexpr T&& get(std::variant&& var) { - if (auto* v = std::get_if(&var)) return std::move(*v); - throw std::runtime_error{"Bad variant access -- variant does not contain the requested type"}; -} -template -constexpr auto& get(std::variant& var) { - if (auto* v = std::get_if(&var)) return *v; - throw std::runtime_error{"Bad variant access -- variant does not contain the requested type"}; -} -template -constexpr const auto& get(const std::variant& var) { - if (auto* v = std::get_if(&var)) return *v; - throw std::runtime_error{"Bad variant access -- variant does not contain the requested type"}; -} -template -constexpr const auto&& get(const std::variant&& var) { - if (auto* v = std::get_if(&var)) return std::move(*v); - throw std::runtime_error{"Bad variant access -- variant does not contain the requested type"}; -} -template -constexpr auto&& get(std::variant&& var) { - if (auto* v = std::get_if(&var)) return std::move(*v); - throw std::runtime_error{"Bad variant access -- variant does not contain the requested type"}; -} - -template -constexpr auto visit_helper(Visitor&& vis, Variant&& var) { - if (var.index() == I) - return vis(var::get(std::forward(var))); - else if constexpr (sizeof...(More) > 0) - return visit_helper(std::forward(vis), std::forward(var)); - else - throw std::runtime_error{"Bad visit -- variant is valueless"}; -} - -template -constexpr auto visit_helper(Visitor&& vis, Variant&& var, std::index_sequence) { - return visit_helper(std::forward(vis), std::forward(var)); -} - -// Only handle a single variant here because multi-variant invocation is notably harder (and we -// don't need it). -template -constexpr auto visit(Visitor&& vis, Variant&& var) { - return visit_helper(std::forward(vis), std::forward(var), - std::make_index_sequence>>{}); -} - -} // namespace var - -#endif +#include "../oxenmq/variant.h" diff --git a/lokimq/version.h b/lokimq/version.h new file mode 100644 index 0000000..d49e184 --- /dev/null +++ b/lokimq/version.h @@ -0,0 +1,4 @@ +#pragma once +#include "../oxenmq/version.h" + +namespace lokimq = oxenmq; diff --git a/lokimq/version.h.in b/lokimq/version.h.in deleted file mode 100644 index 0c400cb..0000000 --- a/lokimq/version.h.in +++ /dev/null @@ -1,5 +0,0 @@ -namespace lokimq { -constexpr int VERSION_MAJOR = @LOKIMQ_VERSION_MAJOR@; -constexpr int VERSION_MINOR = @LOKIMQ_VERSION_MINOR@; -constexpr int VERSION_PATCH = @LOKIMQ_VERSION_PATCH@; -} diff --git a/lokimq/address.cpp b/oxenmq/address.cpp similarity index 98% rename from lokimq/address.cpp rename to oxenmq/address.cpp index 1808f51..27e3496 100644 --- a/lokimq/address.cpp +++ b/oxenmq/address.cpp @@ -9,7 +9,7 @@ #include "base32z.h" #include "base64.h" -namespace lokimq { +namespace oxenmq { constexpr size_t enc_length(address::encoding enc) { return enc == address::encoding::hex ? 64 : @@ -23,13 +23,13 @@ constexpr size_t enc_length(address::encoding enc) { // given: for QR-friendly we only accept hex or base32z (since QR cannot handle base64's alphabet). std::string decode_pubkey(std::string_view& in, bool qr) { std::string pubkey; - if (in.size() >= 64 && lokimq::is_hex(in.substr(0, 64))) { + if (in.size() >= 64 && is_hex(in.substr(0, 64))) { pubkey = from_hex(in.substr(0, 64)); in.remove_prefix(64); - } else if (in.size() >= 52 && lokimq::is_base32z(in.substr(0, 52))) { + } else if (in.size() >= 52 && is_base32z(in.substr(0, 52))) { pubkey = from_base32z(in.substr(0, 52)); in.remove_prefix(52); - } else if (!qr && in.size() >= 43 && lokimq::is_base64(in.substr(0, 43))) { + } else if (!qr && in.size() >= 43 && is_base64(in.substr(0, 43))) { pubkey = from_base64(in.substr(0, 43)); in.remove_prefix(43); if (!in.empty() && in.front() == '=') diff --git a/oxenmq/address.h b/oxenmq/address.h new file mode 100644 index 0000000..96af840 --- /dev/null +++ b/oxenmq/address.h @@ -0,0 +1,210 @@ +// Copyright (c) 2020-2021, The Oxen Project +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without modification, are +// permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of +// conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list +// of conditions and the following disclaimer in the documentation and/or other +// materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be +// used to endorse or promote products derived from this software without specific +// prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL +// THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, +// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF +// THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#pragma once +#include +#include +#include +#include + +namespace oxenmq { + +using namespace std::literals; + +/** OxenMQ address abstraction class. This class uses and extends standard ZMQ addresses allowing + * extra parameters to be passed in in a relative standard way. + * + * External ZMQ addresses generally have two forms that we are concerned with: one for TCP and one + * for Unix sockets: + * + * tcp://HOST:PORT -- HOST can be a hostname, IPv4 address, or IPv6 address in [...] + * ipc://PATH -- PATH can be absolute (ipc:///path/to/some.sock) or relative (ipc://some.sock) + * + * but this doesn't carry enough info: in particular, we can connect with two very different + * protocols: curve25519-encrypted, or plaintext, but for curve25519-encrypted we require the + * remote's public key as well to verify the connection. + * + * This class, then, handles this by allowing addresses of: + * + * Standard ZMQ address: these carry no pubkey and so the connection will be unencrypted: + * + * tcp://HOSTNAME:PORT + * ipc://PATH + * + * Non-ZMQ address formats that specify that the connection shall be x25519 encrypted: + * + * curve://HOSTNAME:PORT/PUBKEY -- PUBKEY must be specified in hex (64 characters), base32z (52) + * or base64 (43 or 44 with one '=' trailing padding) + * ipc+curve:///path/to/my.sock/PUBKEY -- same requirements on PUBKEY as above. + * tcp+curve://(whatever) -- alias for curve://(whatever) + * + * We also accept special upper-case TCP-only variants which *only* accept uppercase characters and + * a few required symbols (:, /, $, ., and -) in the string: + * + * TCP://HOSTNAME:PORT + * CURVE://HOSTNAME:PORT/B32ZPUBKEY + * + * These versions are explicitly meant to be used with QR codes; the upper-case-only requirement + * allows a smaller QR code by allowing QR's alphanumeric mode (which allows only [A-Z0-9 $%*+./:-]) + * to be used. Such a QR-friendly address can be created from the qr_address() method. To support + * literal IPv6 addresses we surround the address with $...$ instead of the usual [...]. + * + * Note that this class does very little validate the host argument at all, and no socket path + * validation whatsoever. The only constraint on host is when parsing an encoded address: we check + * that it contains no : at all, or must be a [bracketed] expression that contains only hex + * characters, :'s, or .'s. Otherwise, if you pass broken crap into the hostname, expect broken + * crap out. + */ +struct address { + /// Supported address protocols: TCP connections (tcp), or unix sockets (ipc). + enum class proto { + tcp, + tcp_curve, + ipc, + ipc_curve + }; + /// Supported public key encodings (used when regenerating an augmented address). + enum class encoding { + hex, ///< hexadecimal encoded + base32z, ///< base32z encoded + base64, ///< base64 encoded (*without* trailing = padding) + BASE32Z ///< upper-case base32z encoding, meant for QR encoding + }; + + /// The protocol: one of the `protocol` enum values for tcp or ipc (unix sockets), with or + /// without _curve encryption. + proto protocol = proto::tcp; + /// The host for tcp connections; can be a hostname or IP address. If this is an IPv6 it must be surrounded with [ ]. + std::string host; + /// The port (for tcp connections) + uint16_t port = 0; + /// The socket path (for unix socket connections) + std::string socket; + /// If a curve connection, this is the required remote public key (in bytes) + std::string pubkey; + + /// Default constructor; this gives you an unusable address. + address() = default; + + /** + * Constructs an address by parsing a string_view containing one of the formats listed in the + * class description. This is intentionally implicitly constructible so that you can pass a + * string_view into anything expecting an `address`. + * + * Throw std::invalid_argument if the given address is not parseable. + */ + address(std::string_view addr); + + /** Constructs an address from a remote string and a separate pubkey. Typically `remote` is a + * basic ZMQ connect string, though this is not enforced. Any pubkey information embedded in + * the remote string will be discarded and replaced with the given pubkey string. The result + * will be curve encrypted if `pubkey` is non-empty, plaintext if `pubkey` is empty. + * + * Throws an exception if either addr or pubkey is invalid. + * + * Exactly equivalent to `address a{remote}; a.set_pubkey(pubkey);` + */ + address(std::string_view addr, std::string_view pubkey) : address(addr) { set_pubkey(pubkey); } + + /// Replaces the address's pubkey (if any) with the given pubkey (or no pubkey if empty). If + /// changing from pubkey to no-pubkey or no-pubkey to pubkey then the protocol is update to + /// switch to or from curve encryption. + /// + /// pubkey should be the 32-byte binary pubkey, or an empty string to remove an existing pubkey. + /// + /// Returns the object itself, so that you can chain it. + address& set_pubkey(std::string_view pubkey); + + /// Constructs and builds the ZMQ connection address from the stored connection details. This + /// does not contain any of the curve-related details; those must be specified separately when + /// interfacing with ZMQ. + std::string zmq_address() const; + + /// Returns true if the connection was specified as a curve-encryption-enabled connection, false + /// otherwise. + bool curve() const { return protocol == proto::tcp_curve || protocol == proto::ipc_curve; } + + /// True if the protocol is TCP (either with or without curve) + bool tcp() const { return protocol == proto::tcp || protocol == proto::tcp_curve; } + + /// True if the protocol is unix socket (either with or without curve) + bool ipc() const { return !tcp(); } + + /// Returns the full "augmented" address string (i.e. that could be passed in to the + /// constructor). This will be equivalent (but not necessarily identical) to an augmented + /// string passed into the constructor. Takes an optional encoding format for the pubkey (if + /// any), which defaults to base32z. + std::string full_address(encoding enc = encoding::base32z) const; + + /// Returns a QR-code friendly address string. This returns an all-uppercase version of the + /// address with "TCP://" or "CURVE://" for the protocol string, and uses upper-case base32z + /// encoding for the pubkey (for curve addresses). For literal IPv6 addresses we replace the + /// surround the + /// address with $ instead of $ + /// + /// \throws std::logic_error if called on a unix socket address. + std::string qr_address() const; + + /// Returns `.pubkey` but encoded in the given format + std::string encode_pubkey(encoding enc) const; + + /// Returns true if two addresses are identical (i.e. same protocol and relevant protocol + /// arguments). + /// + /// Note that it is possible for addresses to connect to the same socket without being + /// identical: for example, using "foo.sock" and "./foo.sock", or writing IPv6 addresses (or + /// even IPv4 addresses) in slightly different ways). Such equivalent but non-equal values will + /// result in a false return here. + /// + /// Note also that we ignore irrelevant arguments: for example, we don't care whether pubkeys + /// match when comparing two non-curve TCP addresses. + bool operator==(const address& other) const; + /// Negation of == + bool operator!=(const address& other) const { return !operator==(other); } + + /// Factory function that constructs a TCP address from a host and port. The connection will be + /// plaintext. If the host is an IPv6 address it *must* be surrounded with [ and ]. + static address tcp(std::string host, uint16_t port); + + /// Factory function that constructs a curve-encrypted TCP address from a host, port, and remote + /// pubkey. The pubkey must be 32 bytes. As above, IPv6 addresses must be specified as [addr]. + static address tcp_curve(std::string host, uint16_t, std::string pubkey); + + /// Factory function that constructs a unix socket address from a path. The connection will be + /// plaintext (which is usually fine for a socket since unix sockets are local machine). + static address ipc(std::string path); + + /// Factory function that constructs a unix socket address from a path and remote pubkey. The + /// connection will be curve25519 encrypted; the remote pubkey must be 32 bytes. + static address ipc_curve(std::string path, std::string pubkey); +}; + +// Outputs address.full_address() when sent to an ostream. +std::ostream& operator<<(std::ostream& o, const address& a); + +} diff --git a/lokimq/auth.cpp b/oxenmq/auth.cpp similarity index 95% rename from lokimq/auth.cpp rename to oxenmq/auth.cpp index 985ff63..dfe1f6c 100644 --- a/lokimq/auth.cpp +++ b/oxenmq/auth.cpp @@ -1,10 +1,10 @@ -#include "lokimq.h" +#include "oxenmq.h" #include "hex.h" -#include "lokimq-internal.h" +#include "oxenmq-internal.h" #include #include -namespace lokimq { +namespace oxenmq { std::ostream& operator<<(std::ostream& o, AuthLevel a) { return o << to_string(a); @@ -31,7 +31,7 @@ std::string zmtp_metadata(std::string_view key, std::string_view value) { } -bool LokiMQ::proxy_check_auth(size_t conn_index, bool outgoing, const peer_info& peer, +bool OxenMQ::proxy_check_auth(size_t conn_index, bool outgoing, const peer_info& peer, zmq::message_t& cmd, const cat_call_t& cat_call, std::vector& data) { auto command = view(cmd); std::string reply; @@ -45,7 +45,7 @@ bool LokiMQ::proxy_check_auth(size_t conn_index, bool outgoing, const peer_info& reply = "FORBIDDEN"; } else if (cat_call.first->access.local_sn && !local_service_node) { LMQ_LOG(warn, "Access denied to ", command, " for peer [", to_hex(peer.pubkey), "]/", peer_address(cmd), - ": that command is only available when this LokiMQ is running in service node mode"); + ": that command is only available when this OxenMQ is running in service node mode"); reply = "NOT_A_SERVICE_NODE"; } else if (cat_call.first->access.remote_sn && !peer.service_node) { LMQ_LOG(warn, "Access denied to ", command, " for peer [", to_hex(peer.pubkey), "]/", peer_address(cmd), @@ -81,7 +81,7 @@ bool LokiMQ::proxy_check_auth(size_t conn_index, bool outgoing, const peer_info& return false; } -void LokiMQ::set_active_sns(pubkey_set pubkeys) { +void OxenMQ::set_active_sns(pubkey_set pubkeys) { if (proxy_thread.joinable()) { auto data = bt_serialize(detail::serialize_object(std::move(pubkeys))); detail::send_control(get_control_socket(), "SET_SNS", data); @@ -89,10 +89,10 @@ void LokiMQ::set_active_sns(pubkey_set pubkeys) { proxy_set_active_sns(std::move(pubkeys)); } } -void LokiMQ::proxy_set_active_sns(std::string_view data) { +void OxenMQ::proxy_set_active_sns(std::string_view data) { proxy_set_active_sns(detail::deserialize_object(bt_deserialize(data))); } -void LokiMQ::proxy_set_active_sns(pubkey_set pubkeys) { +void OxenMQ::proxy_set_active_sns(pubkey_set pubkeys) { pubkey_set added, removed; for (auto it = pubkeys.begin(); it != pubkeys.end(); ) { auto& pk = *it; @@ -118,7 +118,7 @@ void LokiMQ::proxy_set_active_sns(pubkey_set pubkeys) { proxy_update_active_sns_clean(std::move(added), std::move(removed)); } -void LokiMQ::update_active_sns(pubkey_set added, pubkey_set removed) { +void OxenMQ::update_active_sns(pubkey_set added, pubkey_set removed) { LMQ_LOG(info, "uh, ", added.size()); if (proxy_thread.joinable()) { std::array data; @@ -129,12 +129,12 @@ void LokiMQ::update_active_sns(pubkey_set added, pubkey_set removed) { proxy_update_active_sns(std::move(added), std::move(removed)); } } -void LokiMQ::proxy_update_active_sns(bt_list_consumer data) { +void OxenMQ::proxy_update_active_sns(bt_list_consumer data) { auto added = detail::deserialize_object(data.consume_integer()); auto remed = detail::deserialize_object(data.consume_integer()); proxy_update_active_sns(std::move(added), std::move(remed)); } -void LokiMQ::proxy_update_active_sns(pubkey_set added, pubkey_set removed) { +void OxenMQ::proxy_update_active_sns(pubkey_set added, pubkey_set removed) { // We take a caller-provided set of added/removed then filter out any junk (bad pks, conflicting // values, pubkeys that already(added) or do not(removed) exist), then pass the purified lists // to the _clean version. @@ -167,7 +167,7 @@ void LokiMQ::proxy_update_active_sns(pubkey_set added, pubkey_set removed) { proxy_update_active_sns_clean(std::move(added), std::move(removed)); } -void LokiMQ::proxy_update_active_sns_clean(pubkey_set added, pubkey_set removed) { +void OxenMQ::proxy_update_active_sns_clean(pubkey_set added, pubkey_set removed) { LMQ_LOG(debug, "Updating SN auth status with +", added.size(), "/-", removed.size(), " pubkeys"); // For anything we remove we want close the connection to the SN (if outgoing), and remove the @@ -192,7 +192,7 @@ void LokiMQ::proxy_update_active_sns_clean(pubkey_set added, pubkey_set removed) active_service_nodes.insert(std::move(pk)); } -void LokiMQ::process_zap_requests() { +void OxenMQ::process_zap_requests() { for (std::vector frames; recv_message_parts(zap_auth, frames, zmq::recv_flags::dontwait); frames.clear()) { #ifndef NDEBUG if (log_level() >= LogLevel::trace) { diff --git a/oxenmq/auth.h b/oxenmq/auth.h new file mode 100644 index 0000000..f0fffc7 --- /dev/null +++ b/oxenmq/auth.h @@ -0,0 +1,55 @@ +#pragma once +#include +#include +#include +#include + +namespace oxenmq { + +/// Authentication levels for command categories and connections +enum class AuthLevel { + denied, ///< Not actually an auth level, but can be returned by the AllowFunc to deny an incoming connection. + none, ///< No authentication at all; any random incoming ZMQ connection can invoke this command. + basic, ///< Basic authentication commands require a login, or a node that is specifically configured to be a public node (e.g. for public RPC). + admin, ///< Advanced authentication commands require an admin user, either via explicit login or by implicit login from localhost. This typically protects administrative commands like shutting down, starting mining, or access sensitive data. +}; + +std::ostream& operator<<(std::ostream& os, AuthLevel a); + +/// The access level for a command category +struct Access { + /// Minimum access level required + AuthLevel auth; + /// If true only remote SNs may call the category commands + bool remote_sn; + /// If true the category requires that the local node is a SN + bool local_sn; + + /// Constructor. Intentionally allows implicit conversion from an AuthLevel so that an + /// AuthLevel can be passed anywhere an Access is required (the resulting Access will have both + /// remote and local sn set to false). + Access(AuthLevel auth = AuthLevel::none, bool remote_sn = false, bool local_sn = false) + : auth{auth}, remote_sn{remote_sn}, local_sn{local_sn} {} +}; + +/// Simple hash implementation for a string that is *already* a hash-like value (such as a pubkey). +/// Falls back to std::hash if given a string smaller than a size_t. +struct already_hashed { + size_t operator()(const std::string& s) const { + if (s.size() < sizeof(size_t)) + return std::hash{}(s); + size_t hash; + std::memcpy(&hash, &s[0], sizeof(hash)); + return hash; + } +}; + +/// std::unordered_set specialization for specifying pubkeys (used, in particular, by +/// OxenMQ::set_active_sns and OxenMQ::update_active_sns); this is a std::string unordered_set that +/// also uses a specialized trivial hash function that uses part of the value itself (i.e. the +/// pubkey) directly as a hash value. (This is nice and fast for uniformly distributed values like +/// pubkeys and a terrible hash choice for anything else). +using pubkey_set = std::unordered_set; + + +} diff --git a/oxenmq/base32z.h b/oxenmq/base32z.h new file mode 100644 index 0000000..074e522 --- /dev/null +++ b/oxenmq/base32z.h @@ -0,0 +1,205 @@ +// Copyright (c) 2019-2021, The Oxen Project +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without modification, are +// permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of +// conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list +// of conditions and the following disclaimer in the documentation and/or other +// materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be +// used to endorse or promote products derived from this software without specific +// prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL +// THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, +// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF +// THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#pragma once +#include +#include +#include +#include +#include +#include "byte_type.h" + +namespace oxenmq { + +namespace detail { + +/// Compile-time generated lookup tables for base32z conversion. This is case insensitive (though +/// for byte -> b32z conversion we always produce lower case). +struct b32z_table { + // Store the 0-31 decoded value of every possible char; all the chars that aren't valid are set + // to 0. (If you don't trust your data, check it with is_base32z first, which uses these 0's + // to detect invalid characters -- which is why we want a full 256 element array). + char from_b32z_lut[256]; + // Store the encoded character of every 0-31 (5 bit) value. + char to_b32z_lut[32]; + + // constexpr constructor that fills out the above (and should do it at compile time for any half + // decent compiler). + constexpr b32z_table() noexcept : from_b32z_lut{}, + to_b32z_lut{ + 'y', 'b', 'n', 'd', 'r', 'f', 'g', '8', 'e', 'j', 'k', 'm', 'c', 'p', 'q', 'x', + 'o', 't', '1', 'u', 'w', 'i', 's', 'z', 'a', '3', '4', '5', 'h', '7', '6', '9' + } + { + for (unsigned char c = 0; c < 32; c++) { + unsigned char x = to_b32z_lut[c]; + from_b32z_lut[x] = c; + if (x >= 'a' && x <= 'z') + from_b32z_lut[x - 'a' + 'A'] = c; + } + } + // Convert a b32z encoded character into a 0-31 value + constexpr char from_b32z(unsigned char c) const noexcept { return from_b32z_lut[c]; } + // Convert a 0-31 value into a b32z encoded character + constexpr char to_b32z(unsigned char b) const noexcept { return to_b32z_lut[b]; } +} constexpr b32z_lut; + +// This main point of this static assert is to force the compiler to compile-time build the constexpr tables. +static_assert(b32z_lut.from_b32z('w') == 20 && b32z_lut.from_b32z('T') == 17 && b32z_lut.to_b32z(5) == 'f', ""); + +} // namespace detail + +/// Converts bytes into a base32z encoded character sequence. +template +void to_base32z(InputIt begin, InputIt end, OutputIt out) { + static_assert(sizeof(decltype(*begin)) == 1, "to_base32z requires chars/bytes"); + int bits = 0; // Tracks the number of unconsumed bits held in r, will always be in [0, 4] + std::uint_fast16_t r = 0; + while (begin != end) { + r = r << 8 | static_cast(*begin++); + + // we just added 8 bits, so we can *always* consume 5 to produce one character, so (net) we + // are adding 3 bits. + bits += 3; + *out++ = detail::b32z_lut.to_b32z(r >> bits); // Right-shift off the bits we aren't consuming right now + + // Drop the bits we don't want to keep (because we just consumed them) + r &= (1 << bits) - 1; + + if (bits >= 5) { // We have enough bits to produce a second character; essentially the same as above + bits -= 5; // Except now we are just consuming 5 without having added any more + *out++ = detail::b32z_lut.to_b32z(r >> bits); + r &= (1 << bits) - 1; + } + } + + if (bits > 0) // We hit the end, but still have some unconsumed bits so need one final character to append + *out++ = detail::b32z_lut.to_b32z(r << (5 - bits)); +} + +/// Creates a base32z string from an iterator pair of a byte sequence. +template +std::string to_base32z(It begin, It end) { + std::string base32z; + if constexpr (std::is_base_of_v::iterator_category>) + base32z.reserve((std::distance(begin, end)*8 + 4) / 5); // == bytes*8/5, rounded up. + to_base32z(begin, end, std::back_inserter(base32z)); + return base32z; +} + +/// Creates a base32z string from an iterable, std::string-like object +template +std::string to_base32z(std::basic_string_view s) { return to_base32z(s.begin(), s.end()); } +inline std::string to_base32z(std::string_view s) { return to_base32z<>(s); } + +/// Returns true if all elements in the range are base32z characters +template +constexpr bool is_base32z(It begin, It end) { + static_assert(sizeof(decltype(*begin)) == 1, "is_base32z requires chars/bytes"); + for (; begin != end; ++begin) { + auto c = static_cast(*begin); + if (detail::b32z_lut.from_b32z(c) == 0 && !(c == 'y' || c == 'Y')) + return false; + } + return true; +} + +/// Returns true if all elements in the string-like value are base32z characters +template +constexpr bool is_base32z(std::basic_string_view s) { return is_base32z(s.begin(), s.end()); } +constexpr bool is_base32z(std::string_view s) { return is_base32z<>(s); } + +/// Converts a sequence of base32z digits to bytes. Undefined behaviour if any characters are not +/// valid base32z alphabet characters. It is permitted for the input and output ranges to overlap +/// as long as `out` is no earlier than `begin`. Note that if you pass in a sequence that could not +/// have been created by a base32z encoding of a byte sequence, we treat the excess bits as if they +/// were not provided. +/// +/// For example, "yyy" represents a 15-bit value, but a byte sequence is either 8-bit (requiring 2 +/// characters) or 16-bit (requiring 4). Similarly, "yb" is an impossible encoding because it has +/// its 10th bit set (b = 00001), but a base32z encoded value should have all 0's beyond the 8th (or +/// 16th or 24th or ... bit). We treat any such bits as if they were not specified (even if they +/// are): which means "yy", "yb", "yyy", "yy9", "yd", etc. all decode to the same 1-byte value "\0". +template +void from_base32z(InputIt begin, InputIt end, OutputIt out) { + static_assert(sizeof(decltype(*begin)) == 1, "from_base32z requires chars/bytes"); + uint_fast16_t curr = 0; + int bits = 0; // number of bits we've loaded into val; we always keep this < 8. + while (begin != end) { + curr = curr << 5 | detail::b32z_lut.from_b32z(static_cast(*begin++)); + if (bits >= 3) { + bits -= 3; // Added 5, removing 8 + *out++ = static_cast>( + static_cast(curr >> bits)); + curr &= (1 << bits) - 1; + } else { + bits += 5; + } + } + + // Ignore any trailing bits. base32z encoding always has at least as many bits as the source + // bytes, which means we should not be able to get here from a properly encoded b32z value with + // anything other than 0s: if we have no extra bits (e.g. 5 bytes == 8 b32z chars) then we have + // a 0-bit value; if we have some extra bits (e.g. 6 bytes requires 10 b32z chars, but that + // contains 50 bits > 48 bits) then those extra bits will be 0s (and this covers the bits -= 3 + // case above: it'll leave us with 0-4 extra bits, but those extra bits would be 0 if produced + // from an actual byte sequence). + // + // The "bits += 5" case, then, means that we could end with 5-7 bits. This, however, cannot be + // produced by a valid encoding: + // - 0 bytes gives us 0 chars with 0 leftover bits + // - 1 byte gives us 2 chars with 2 leftover bits + // - 2 bytes gives us 4 chars with 4 leftover bits + // - 3 bytes gives us 5 chars with 1 leftover bit + // - 4 bytes gives us 7 chars with 3 leftover bits + // - 5 bytes gives us 8 chars with 0 leftover bits (this is where the cycle repeats) + // + // So really the only way we can get 5-7 leftover bits is if you took a 0, 2 or 5 char output (or + // any 8n + {0,2,5} char output) and added a base32z character to the end. If you do that, + // well, too bad: you're giving invalid output and so we're just going to pretend that extra + // character you added isn't there by not doing anything here. +} + +/// Convert a base32z sequence into a std::string of bytes. Undefined behaviour if any characters +/// are not valid (case-insensitive) base32z characters. +template +std::string from_base32z(It begin, It end) { + std::string bytes; + if constexpr (std::is_base_of_v::iterator_category>) + bytes.reserve((std::distance(begin, end)*5 + 7) / 8); // == chars*5/8, rounded up. + from_base32z(begin, end, std::back_inserter(bytes)); + return bytes; +} + +/// Converts base32z digits from a std::string-like object into a std::string of bytes. Undefined +/// behaviour if any characters are not valid (case-insensitive) base32z characters. +template +std::string from_base32z(std::basic_string_view s) { return from_base32z(s.begin(), s.end()); } +inline std::string from_base32z(std::string_view s) { return from_base32z<>(s); } + +} diff --git a/oxenmq/base64.h b/oxenmq/base64.h new file mode 100644 index 0000000..de703ae --- /dev/null +++ b/oxenmq/base64.h @@ -0,0 +1,221 @@ +// Copyright (c) 2019-2021, The Oxen Project +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without modification, are +// permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of +// conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list +// of conditions and the following disclaimer in the documentation and/or other +// materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be +// used to endorse or promote products derived from this software without specific +// prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL +// THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, +// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF +// THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#pragma once +#include +#include +#include +#include +#include +#include "byte_type.h" + +namespace oxenmq { + +namespace detail { + +/// Compile-time generated lookup tables for base64 conversion. +struct b64_table { + // Store the 0-63 decoded value of every possible char; all the chars that aren't valid are set + // to 0. (If you don't trust your data, check it with is_base64 first, which uses these 0's + // to detect invalid characters -- which is why we want a full 256 element array). + char from_b64_lut[256]; + // Store the encoded character of every 0-63 (6 bit) value. + char to_b64_lut[64]; + + // constexpr constructor that fills out the above (and should do it at compile time for any half + // decent compiler). + constexpr b64_table() noexcept : from_b64_lut{}, to_b64_lut{} { + for (unsigned char c = 0; c < 26; c++) { + from_b64_lut[(unsigned char)('A' + c)] = 0 + c; + to_b64_lut[ (unsigned char)( 0 + c)] = 'A' + c; + } + for (unsigned char c = 0; c < 26; c++) { + from_b64_lut[(unsigned char)('a' + c)] = 26 + c; + to_b64_lut[ (unsigned char)(26 + c)] = 'a' + c; + } + for (unsigned char c = 0; c < 10; c++) { + from_b64_lut[(unsigned char)('0' + c)] = 52 + c; + to_b64_lut[ (unsigned char)(52 + c)] = '0' + c; + } + to_b64_lut[62] = '+'; from_b64_lut[(unsigned char) '+'] = 62; + to_b64_lut[63] = '/'; from_b64_lut[(unsigned char) '/'] = 63; + } + // Convert a b64 encoded character into a 0-63 value + constexpr char from_b64(unsigned char c) const noexcept { return from_b64_lut[c]; } + // Convert a 0-31 value into a b64 encoded character + constexpr char to_b64(unsigned char b) const noexcept { return to_b64_lut[b]; } +} constexpr b64_lut; + +// This main point of this static assert is to force the compiler to compile-time build the constexpr tables. +static_assert(b64_lut.from_b64('/') == 63 && b64_lut.from_b64('7') == 59 && b64_lut.to_b64(38) == 'm', ""); + +} // namespace detail + +/// Converts bytes into a base64 encoded character sequence. +template +void to_base64(InputIt begin, InputIt end, OutputIt out) { + static_assert(sizeof(decltype(*begin)) == 1, "to_base64 requires chars/bytes"); + int bits = 0; // Tracks the number of unconsumed bits held in r, will always be in {0, 2, 4} + std::uint_fast16_t r = 0; + while (begin != end) { + r = r << 8 | static_cast(*begin++); + + // we just added 8 bits, so we can *always* consume 6 to produce one character, so (net) we + // are adding 2 bits. + bits += 2; + *out++ = detail::b64_lut.to_b64(r >> bits); // Right-shift off the bits we aren't consuming right now + + // Drop the bits we don't want to keep (because we just consumed them) + r &= (1 << bits) - 1; + + if (bits == 6) { // We have enough bits to produce a second character (which means we had 4 before and added 8) + bits = 0; + *out++ = detail::b64_lut.to_b64(r); + r = 0; + } + } + + // If bits == 0 then we ended our 6-bit outputs coinciding with 8-bit values, i.e. at a multiple + // of 24 bits: this means we don't have anything else to output and don't need any padding. + if (bits == 2) { + // We finished with 2 unconsumed bits, which means we ended 1 byte past a 24-bit group (e.g. + // 1 byte, 4 bytes, 301 bytes, etc.); since we need to always be a multiple of 4 output + // characters that means we've produced 1: so we right-fill 0s to get the next char, then + // add two padding ='s. + *out++ = detail::b64_lut.to_b64(r << 4); + *out++ = '='; + *out++ = '='; + } else if (bits == 4) { + // 4 bits left means we produced 2 6-bit values from the first 2 bytes of a 3-byte group. + // Fill 0s to get the last one, plus one padding output. + *out++ = detail::b64_lut.to_b64(r << 2); + *out++ = '='; + } +} + +/// Creates and returns a base64 string from an iterator pair of a character sequence +template +std::string to_base64(It begin, It end) { + std::string base64; + if constexpr (std::is_base_of_v::iterator_category>) + base64.reserve((std::distance(begin, end) + 2) / 3 * 4); // bytes*4/3, rounded up to the next multiple of 4 + to_base64(begin, end, std::back_inserter(base64)); + return base64; +} + +/// Creates a base64 string from an iterable, std::string-like object +template +std::string to_base64(std::basic_string_view s) { return to_base64(s.begin(), s.end()); } +inline std::string to_base64(std::string_view s) { return to_base64<>(s); } + +/// Returns true if the range is a base64 encoded value; we allow (but do not require) '=' padding, +/// but only at the end, only 1 or 2, and only if it pads out the total to a multiple of 4. +template +constexpr bool is_base64(It begin, It end) { + static_assert(sizeof(decltype(*begin)) == 1, "is_base64 requires chars/bytes"); + using std::distance; + using std::prev; + + // Allow 1 or 2 padding chars *if* they pad it to a multiple of 4. + if (begin != end && distance(begin, end) % 4 == 0) { + auto last = prev(end); + if (static_cast(*last) == '=') + end = last--; + if (static_cast(*last) == '=') + end = last; + } + + for (; begin != end; ++begin) { + auto c = static_cast(*begin); + if (detail::b64_lut.from_b64(c) == 0 && c != 'A') + return false; + } + return true; +} + +/// Returns true if the string-like value is a base64 encoded value +template +constexpr bool is_base64(std::basic_string_view s) { return is_base64(s.begin(), s.end()); } +constexpr bool is_base64(std::string_view s) { return is_base64(s.begin(), s.end()); } + +/// Converts a sequence of base64 digits to bytes. Undefined behaviour if any characters are not +/// valid base64 alphabet characters. It is permitted for the input and output ranges to overlap as +/// long as `out` is no earlier than `begin`. Trailing padding characters are permitted but not +/// required. +/// +/// It is possible to provide "impossible" base64 encoded values; for example "YWJja" which has 30 +/// bits of data even though a base64 encoded byte string should have 24 (4 chars) or 36 (6 chars) +/// bits for a 3- and 4-byte input, respectively. We ignore any such "impossible" bits, and +/// similarly ignore impossible bits in the bit "overhang"; that means "YWJjZA==" (the proper +/// encoding of "abcd") and "YWJjZB", "YWJjZC", ..., "YWJjZP" all decode to the same "abcd" value: +/// the last 4 bits of the last character are essentially considered padding. +template +void from_base64(InputIt begin, InputIt end, OutputIt out) { + static_assert(sizeof(decltype(*begin)) == 1, "from_base64 requires chars/bytes"); + uint_fast16_t curr = 0; + int bits = 0; // number of bits we've loaded into val; we always keep this < 8. + while (begin != end) { + auto c = static_cast(*begin++); + + // padding; don't bother checking if we're at the end because is_base64 is a precondition + // and we're allowed UB if it isn't satisfied. + if (c == '=') continue; + + curr = curr << 6 | detail::b64_lut.from_b64(c); + if (bits == 0) + bits = 6; + else { + bits -= 2; // Added 6, removing 8 + *out++ = static_cast>( + static_cast(curr >> bits)); + curr &= (1 << bits) - 1; + } + } + // Don't worry about leftover bits because either they have to be 0, or they can't happen at + // all. See base32z.h for why: the reasoning is exactly the same (except using 6 bits per + // character here instead of 5). +} + +/// Converts base64 digits from a iterator pair of characters into a std::string of bytes. +/// Undefined behaviour if any characters are not valid base64 characters. +template +std::string from_base64(It begin, It end) { + std::string bytes; + if constexpr (std::is_base_of_v::iterator_category>) + bytes.reserve(std::distance(begin, end)*6 / 8); // each digit carries 6 bits; this may overallocate by 1-2 bytes due to padding + from_base64(begin, end, std::back_inserter(bytes)); + return bytes; +} + +/// Converts base64 digits from a std::string-like object into a std::string of bytes. Undefined +/// behaviour if any characters are not valid base64 characters. +template +std::string from_base64(std::basic_string_view s) { return from_base64(s.begin(), s.end()); } +inline std::string from_base64(std::string_view s) { return from_base64<>(s); } + +} diff --git a/oxenmq/batch.h b/oxenmq/batch.h new file mode 100644 index 0000000..0db15ea --- /dev/null +++ b/oxenmq/batch.h @@ -0,0 +1,279 @@ +// Copyright (c) 2020-2021, The Oxen Project +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without modification, are +// permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of +// conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list +// of conditions and the following disclaimer in the documentation and/or other +// materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be +// used to endorse or promote products derived from this software without specific +// prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL +// THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, +// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF +// THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#pragma once +#include +#include +#include +#include "oxenmq.h" + +namespace oxenmq { + +namespace detail { + +enum class BatchState { + running, // there are still jobs to run (or running) + complete, // the batch is complete but still has a completion job to call + done // the batch is complete and has no completion function +}; + +struct BatchStatus { + BatchState state; + int thread; +}; + +// Virtual base class for Batch +class Batch { +public: + // Returns the number of jobs in this batch and whether any of them are thread-specific + virtual std::pair size() const = 0; + // Returns a vector of exactly the same length of size().first containing the tagged thread ids + // of the batch jobs or 0 for general jobs. + virtual std::vector threads() const = 0; + // Called in a worker thread to run the job + virtual void run_job(int i) = 0; + // Called in the main proxy thread when the worker returns from finishing a job. The return + // value tells us whether the current finishing job finishes off the batch: `running` to tell us + // there are more jobs; `complete` to tell us that the jobs are done but the completion function + // needs to be called; and `done` to signal that the jobs are done and there is no completion + // function. + virtual BatchStatus job_finished() = 0; + // Called by a worker; not scheduled until all jobs are done. + virtual void job_completion() = 0; + + virtual ~Batch() = default; +}; + +} + +/** + * Simple class that can either hold a result or an exception and retrieves the result (or raises + * the exception) via a .get() method. + * + * This is designed to be like a very stripped down version of a std::promise/std::future pair. We + * reimplemented it, however, because by ditching all the thread synchronization that promise/future + * guarantees we can substantially reduce call overhead (by a factor of ~8 according to benchmarking + * code). Since OxenMQ's proxy<->worker communication channel already gives us thread that overhead + * would just be wasted. + * + * @tparam R the value type held by the result; must be default constructible. Note, however, that + * there are specializations provided for lvalue references types and `void` (which obviously don't + * satisfy this). + */ +template +class job_result { + R value; + std::exception_ptr exc; + +public: + /// Sets the value. Should be called only once, or not at all if set_exception was called. + void set_value(R&& v) { value = std::move(v); } + + /// Sets the exception, which will be rethrown when `get()` is called. Should be called + /// only once, or not at all if set_value() was called. + void set_exception(std::exception_ptr e) { exc = std::move(e); } + + /// Retrieves the value. If an exception was set instead of a value then that exception is + /// thrown instead. Note that the interval value is moved out of the held value so you should + /// not call this multiple times. + R get() { + if (exc) std::rethrow_exception(exc); + return std::move(value); + } +}; + +/** job_result specialization for reference types */ +template +class job_result::value>> { + std::remove_reference_t* value_ptr; + std::exception_ptr exc; + +public: + void set_value(R v) { value_ptr = &v; } + void set_exception(std::exception_ptr e) { exc = std::move(e); } + R get() { + if (exc) std::rethrow_exception(exc); + return *value_ptr; + } +}; + +/** job_result specialization for void; there is no value, but exceptions are still captured + * (rethrown when `get()` is called). + */ +template<> +class job_result { + std::exception_ptr exc; + +public: + void set_exception(std::exception_ptr e) { exc = std::move(e); } + // Returns nothing, but rethrows if there is a captured exception. + void get() { if (exc) std::rethrow_exception(exc); } +}; + +/// Helper class used to set up batches of jobs to be scheduled via the oxenmq job handler. +/// +/// @tparam R - the return type of the individual jobs +/// +template +class Batch final : private detail::Batch { + friend class OxenMQ; +public: + /// The completion function type, called after all jobs have finished. + using CompletionFunc = std::function> results)>; + + // Default constructor + Batch() = default; + + // movable + Batch(Batch&&) = default; + Batch &operator=(Batch&&) = default; + + // non-copyable + Batch(const Batch&) = delete; + Batch &operator=(const Batch&) = delete; + +private: + std::vector, int>> jobs; + std::vector> results; + CompletionFunc complete; + std::size_t jobs_outstanding = 0; + int complete_in_thread = 0; + bool started = false; + bool tagged_thread_jobs = false; + + void check_not_started() { + if (started) + throw std::logic_error("Cannot add jobs or completion function after starting a oxenmq::Batch!"); + } + +public: + /// Preallocates space in the internal vector that stores jobs. + void reserve(std::size_t num) { + jobs.reserve(num); + results.reserve(num); + } + + /// Adds a job. This takes any callable object that is invoked with no arguments and returns R + /// (the Batch return type). The tasks will be scheduled and run when the next worker thread is + /// available. The called function may throw exceptions (which will be propagated to the + /// completion function through the job_result values). There is no guarantee on the order of + /// invocation of the jobs. + /// + /// \param job the callback + /// \param thread an optional TaggedThreadID indicating a thread in which this job must run + void add_job(std::function job, std::optional thread = std::nullopt) { + check_not_started(); + if (thread && thread->_id == -1) + // There are some special case internal jobs where we allow this, but they use the + // private method below that doesn't have this check. + throw std::logic_error{"Cannot add a proxy thread batch job -- this makes no sense"}; + add_job(std::move(job), thread ? thread->_id : 0); + } + + /// Sets the completion function to invoke after all jobs have finished. If this is not set + /// then jobs simply run and results are discarded. + /// + /// \param comp - function to call when all jobs have finished + /// \param thread - optional tagged thread in which to schedule the completion job. If not + /// provided then the completion job is scheduled in the pool of batch job threads. + /// + /// `thread` can be provided the value &OxenMQ::run_in_proxy to invoke the completion function + /// *IN THE PROXY THREAD* itself after all jobs have finished. Be very, very careful: this + /// should be a nearly trivial job that does not require any substantial CPU time and does not + /// block for any reason. This is only intended for the case where the completion job is so + /// trivial that it will take less time than simply queuing the job to be executed by another + /// thread. + void completion(CompletionFunc comp, std::optional thread = std::nullopt) { + check_not_started(); + if (complete) + throw std::logic_error("Completion function can only be set once"); + complete = std::move(comp); + complete_in_thread = thread ? thread->_id : 0; + } + +private: + + void add_job(std::function job, int thread_id) { + jobs.emplace_back(std::move(job), thread_id); + results.emplace_back(); + jobs_outstanding++; + if (thread_id != 0) + tagged_thread_jobs = true; + } + + std::pair size() const override { + return {jobs.size(), tagged_thread_jobs}; + } + + std::vector threads() const override { + std::vector t; + t.reserve(jobs.size()); + for (auto& j : jobs) + t.push_back(j.second); + return t; + }; + + template + void set_value(job_result& r, std::function& f) { r.set_value(f()); } + void set_value(job_result&, std::function& f) { f(); } + + void run_job(const int i) override { + // called by worker thread + auto& r = results[i]; + try { + set_value(r, jobs[i].first); + } catch (...) { + r.set_exception(std::current_exception()); + } + } + + detail::BatchStatus job_finished() override { + --jobs_outstanding; + if (jobs_outstanding) + return {detail::BatchState::running, 0}; + if (complete) + return {detail::BatchState::complete, complete_in_thread}; + return {detail::BatchState::done, 0}; + } + + void job_completion() override { + return complete(std::move(results)); + } +}; + + +template +void OxenMQ::batch(Batch&& batch) { + if (batch.size().first == 0) + throw std::logic_error("Cannot batch a a job batch with 0 jobs"); + // Need to send this over to the proxy thread via the base class pointer. It assumes ownership. + auto* baseptr = static_cast(new Batch(std::move(batch))); + detail::send_control(get_control_socket(), "BATCH", bt_serialize(reinterpret_cast(baseptr))); +} + +} diff --git a/lokimq/bt_serialize.cpp b/oxenmq/bt_serialize.cpp similarity index 99% rename from lokimq/bt_serialize.cpp rename to oxenmq/bt_serialize.cpp index 0619401..e67b9bc 100644 --- a/lokimq/bt_serialize.cpp +++ b/oxenmq/bt_serialize.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2019-2020, The Loki Project +// Copyright (c) 2019-2021, The Oxen Project // // All rights reserved. // @@ -29,7 +29,7 @@ #include "bt_serialize.h" #include -namespace lokimq { +namespace oxenmq { namespace detail { /// Reads digits into an unsigned 64-bit int. @@ -228,4 +228,4 @@ std::pair bt_dict_consumer::next_string() { } -} // namespace lokimq +} // namespace oxenmq diff --git a/oxenmq/bt_serialize.h b/oxenmq/bt_serialize.h new file mode 100644 index 0000000..6767e79 --- /dev/null +++ b/oxenmq/bt_serialize.h @@ -0,0 +1,915 @@ +// Copyright (c) 2019-2020, The Oxen Project +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without modification, are +// permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of +// conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list +// of conditions and the following disclaimer in the documentation and/or other +// materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be +// used to endorse or promote products derived from this software without specific +// prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL +// THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, +// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF +// THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include "variant.h" +#include +#include +#include +#include +#include +#include +#include + +#include "bt_value.h" + +namespace oxenmq { + +using namespace std::literals; + +/** \file + * OxenMQ serialization for internal commands is very simple: we support two primitive types, + * strings and integers, and two container types, lists and dicts with string keys. On the wire + * these go in BitTorrent byte encoding as described in BEP-0003 + * (https://www.bittorrent.org/beps/bep_0003.html#bencoding). + * + * On the C++ side, on input we allow strings, integral types, STL-like containers of these types, + * and STL-like containers of pairs with a string first value and any of these types as second + * value. We also accept std::variants of these. + * + * One minor deviation from BEP-0003 is that we don't support serializing values that don't fit in a + * 64-bit integer (BEP-0003 specifies arbitrary precision integers). + * + * On deserialization we can either deserialize into a special bt_value type supports everything + * (with arbitrary nesting), or we can fill a container of your given type (though this fails if the + * container isn't compatible with the deserialized data). + * + * There is also a stream deserialization that allows you to deserialize without needing heap + * allocations (as long as you know the precise data structure layout). + */ + +/// Exception throw if deserialization fails +class bt_deserialize_invalid : public std::invalid_argument { + using std::invalid_argument::invalid_argument; +}; + +/// A more specific subclass that is thown if the serialization type is an initial mismatch: for +/// example, trying deserializing an int but the next thing in input is a list. This is not, +/// however, thrown if the type initially looks fine but, say, a nested serialization fails. This +/// error will only be thrown when the input stream has not been advanced (and so can be tried for a +/// different type). +class bt_deserialize_invalid_type : public bt_deserialize_invalid { + using bt_deserialize_invalid::bt_deserialize_invalid; +}; + +namespace detail { + +/// Reads digits into an unsigned 64-bit int. +uint64_t extract_unsigned(std::string_view& s); +// (Provide non-constant lvalue and rvalue ref functions so that we only accept explicit +// string_views but not implicitly converted ones) +inline uint64_t extract_unsigned(std::string_view&& s) { return extract_unsigned(s); } + +// Fallback base case; we only get here if none of the partial specializations below work +template +struct bt_serialize { static_assert(!std::is_same_v, "Cannot serialize T: unsupported type for bt serialization"); }; + +template +struct bt_deserialize { static_assert(!std::is_same_v, "Cannot deserialize T: unsupported type for bt deserialization"); }; + +/// Checks that we aren't at the end of a string view and throws if we are. +inline void bt_need_more(const std::string_view &s) { + if (s.empty()) + throw bt_deserialize_invalid{"Unexpected end of string while deserializing"}; +} + +/// Deserializes a signed or unsigned 64-bit integer from a string. Sets the second bool to true +/// iff the value read was negative, false if positive; in either case the unsigned value is return +/// in .first. Throws an exception if the read value doesn't fit in a int64_t (if negative) or a +/// uint64_t (if positive). Removes consumed characters from the string_view. +std::pair bt_deserialize_integer(std::string_view& s); + +/// Integer specializations +template +struct bt_serialize>> { + static_assert(sizeof(T) <= sizeof(uint64_t), "Serialization of integers larger than uint64_t is not supported"); + void operator()(std::ostream &os, const T &val) { + // Cast 1-byte types to a larger type to avoid iostream interpreting them as single characters + using output_type = std::conditional_t<(sizeof(T) > 1), T, std::conditional_t, int, unsigned>>; + os << 'i' << static_cast(val) << 'e'; + } +}; + +template +struct bt_deserialize>> { + void operator()(std::string_view& s, T &val) { + constexpr uint64_t umax = static_cast(std::numeric_limits::max()); + constexpr int64_t smin = static_cast(std::numeric_limits::min()); + + auto [magnitude, negative] = bt_deserialize_integer(s); + + if (std::is_signed_v) { + if (!negative) { + if (magnitude > umax) + throw bt_deserialize_invalid("Integer deserialization failed: found too-large value " + std::to_string(magnitude) + " > " + std::to_string(umax)); + val = static_cast(magnitude); + } else { + auto sval = -static_cast(magnitude); + if (!std::is_same_v && sval < smin) + throw bt_deserialize_invalid("Integer deserialization failed: found too-low value " + std::to_string(sval) + " < " + std::to_string(smin)); + val = static_cast(sval); + } + } else { + if (negative) + throw bt_deserialize_invalid("Integer deserialization failed: found negative value -" + std::to_string(magnitude) + " but type is unsigned"); + if (!std::is_same_v && magnitude > umax) + throw bt_deserialize_invalid("Integer deserialization failed: found too-large value " + std::to_string(magnitude) + " > " + std::to_string(umax)); + val = static_cast(magnitude); + } + } +}; + +extern template struct bt_deserialize; +extern template struct bt_deserialize; + +template <> +struct bt_serialize { + void operator()(std::ostream &os, const std::string_view &val) { os << val.size(); os.put(':'); os.write(val.data(), val.size()); } +}; +template <> +struct bt_deserialize { + void operator()(std::string_view& s, std::string_view& val); +}; + +/// String specialization +template <> +struct bt_serialize { + void operator()(std::ostream &os, const std::string &val) { bt_serialize{}(os, val); } +}; +template <> +struct bt_deserialize { + void operator()(std::string_view& s, std::string& val) { std::string_view view; bt_deserialize{}(s, view); val = {view.data(), view.size()}; } +}; + +/// char * and string literals -- we allow serialization for convenience, but not deserialization +template <> +struct bt_serialize { + void operator()(std::ostream &os, const char *str) { bt_serialize{}(os, {str, std::strlen(str)}); } +}; +template +struct bt_serialize { + void operator()(std::ostream &os, const char *str) { bt_serialize{}(os, {str, N-1}); } +}; + +/// Partial dict validity; we don't check the second type for serializability, that will be handled +/// via the base case static_assert if invalid. +template struct is_bt_input_dict_container_impl : std::false_type {}; +template +struct is_bt_input_dict_container_impl> || + std::is_same_v>, + std::void_t>> +: std::true_type {}; + +/// Determines whether the type looks like something we can insert into (using `v.insert(v.end(), x)`) +template struct is_bt_insertable_impl : std::false_type {}; +template +struct is_bt_insertable_impl().insert(std::declval().end(), std::declval()))>> +: std::true_type {}; +template +constexpr bool is_bt_insertable = is_bt_insertable_impl::value; + +/// Determines whether the given type looks like a compatible map (i.e. has std::string keys) that +/// we can insert into. +template struct is_bt_output_dict_container_impl : std::false_type {}; +template +struct is_bt_output_dict_container_impl> && is_bt_insertable, + std::void_t>> +: std::true_type {}; + +template +constexpr bool is_bt_output_dict_container = is_bt_output_dict_container_impl::value; +template +constexpr bool is_bt_input_dict_container = is_bt_output_dict_container_impl::value; + +// Sanity checks: +static_assert(is_bt_input_dict_container); +static_assert(is_bt_output_dict_container); + +/// Specialization for a dict-like container (such as an unordered_map). We accept anything for a +/// dict that is const iterable over something that looks like a pair with std::string for first +/// value type. The value (i.e. second element of the pair) also must be serializable. +template +struct bt_serialize>> { + using second_type = typename T::value_type::second_type; + using ref_pair = std::reference_wrapper; + void operator()(std::ostream &os, const T &dict) { + os << 'd'; + std::vector pairs; + pairs.reserve(dict.size()); + for (const auto &pair : dict) + pairs.emplace(pairs.end(), pair); + std::sort(pairs.begin(), pairs.end(), [](ref_pair a, ref_pair b) { return a.get().first < b.get().first; }); + for (auto &ref : pairs) { + bt_serialize{}(os, ref.get().first); + bt_serialize{}(os, ref.get().second); + } + os << 'e'; + } +}; + +template +struct bt_deserialize>> { + using second_type = typename T::value_type::second_type; + void operator()(std::string_view& s, T& dict) { + // Smallest dict is 2 bytes "de", for an empty dict. + if (s.size() < 2) throw bt_deserialize_invalid("Deserialization failed: end of string found where dict expected"); + if (s[0] != 'd') throw bt_deserialize_invalid_type("Deserialization failed: expected 'd', found '"s + s[0] + "'"s); + s.remove_prefix(1); + dict.clear(); + bt_deserialize key_deserializer; + bt_deserialize val_deserializer; + + while (!s.empty() && s[0] != 'e') { + std::string key; + second_type val; + key_deserializer(s, key); + val_deserializer(s, val); + dict.insert(dict.end(), typename T::value_type{std::move(key), std::move(val)}); + } + if (s.empty()) + throw bt_deserialize_invalid("Deserialization failed: encountered end of string before dict was finished"); + s.remove_prefix(1); // Consume the 'e' + } +}; + + +/// Accept anything that looks iterable; value serialization validity isn't checked here (it fails +/// via the base case static assert). +template struct is_bt_input_list_container_impl : std::false_type {}; +template +struct is_bt_input_list_container_impl && !std::is_same_v && !is_bt_input_dict_container, + std::void_t>> +: std::true_type {}; + +template struct is_bt_output_list_container_impl : std::false_type {}; +template +struct is_bt_output_list_container_impl && !is_bt_output_dict_container && is_bt_insertable>> +: std::true_type {}; + +template +constexpr bool is_bt_output_list_container = is_bt_output_list_container_impl::value; +template +constexpr bool is_bt_input_list_container = is_bt_input_list_container_impl::value; + +// Sanity checks: +static_assert(is_bt_input_list_container); +static_assert(is_bt_output_list_container); + +/// List specialization +template +struct bt_serialize>> { + void operator()(std::ostream& os, const T& list) { + os << 'l'; + for (const auto &v : list) + bt_serialize>{}(os, v); + os << 'e'; + } +}; +template +struct bt_deserialize>> { + using value_type = typename T::value_type; + void operator()(std::string_view& s, T& list) { + // Smallest list is 2 bytes "le", for an empty list. + if (s.size() < 2) throw bt_deserialize_invalid("Deserialization failed: end of string found where list expected"); + if (s[0] != 'l') throw bt_deserialize_invalid_type("Deserialization failed: expected 'l', found '"s + s[0] + "'"s); + s.remove_prefix(1); + list.clear(); + bt_deserialize deserializer; + while (!s.empty() && s[0] != 'e') { + value_type v; + deserializer(s, v); + list.insert(list.end(), std::move(v)); + } + if (s.empty()) + throw bt_deserialize_invalid("Deserialization failed: encountered end of string before list was finished"); + s.remove_prefix(1); // Consume the 'e' + } +}; + +/// Serializes a tuple or pair of serializable values (as a list on the wire) + +/// Common implementation for both tuple and pair: +template typename Tuple, typename... T> +struct bt_serialize_tuple { +private: + template + void operator()(std::ostream& os, const Tuple& elems, std::index_sequence) { + os << 'l'; + (bt_serialize{}(os, std::get(elems)), ...); + os << 'e'; + } +public: + void operator()(std::ostream& os, const Tuple& elems) { + operator()(os, elems, std::index_sequence_for{}); + } +}; +template typename Tuple, typename... T> +struct bt_deserialize_tuple { +private: + template + void operator()(std::string_view& s, Tuple& elems, std::index_sequence) { + // Smallest list is 2 bytes "le", for an empty list. + if (s.size() < 2) throw bt_deserialize_invalid("Deserialization failed: end of string found where tuple expected"); + if (s[0] != 'l') throw bt_deserialize_invalid_type("Deserialization of tuple failed: expected 'l', found '"s + s[0] + "'"s); + s.remove_prefix(1); + (bt_deserialize{}(s, std::get(elems)), ...); + if (s.empty()) + throw bt_deserialize_invalid("Deserialization failed: encountered end of string before tuple was finished"); + if (s[0] != 'e') + throw bt_deserialize_invalid("Deserialization failed: expected end of tuple but found something else"); + s.remove_prefix(1); // Consume the 'e' + } +public: + void operator()(std::string_view& s, Tuple& elems) { + operator()(s, elems, std::index_sequence_for{}); + } +}; +template +struct bt_serialize> : bt_serialize_tuple {}; +template +struct bt_deserialize> : bt_deserialize_tuple {}; +template +struct bt_serialize> : bt_serialize_tuple {}; +template +struct bt_deserialize> : bt_deserialize_tuple {}; + +template +constexpr bool is_bt_tuple = false; +template +constexpr bool is_bt_tuple> = true; +template +constexpr bool is_bt_tuple> = true; + + +template +constexpr bool is_bt_deserializable = std::is_same_v || std::is_integral_v || + is_bt_output_dict_container || is_bt_output_list_container || is_bt_tuple; + +// General template and base case; this base will only actually be invoked when Ts... is empty, +// which means we reached the end without finding any variant type capable of holding the value. +template +struct bt_deserialize_try_variant_impl { + void operator()(std::string_view&, Variant&) { + throw bt_deserialize_invalid("Deserialization failed: could not deserialize value into any variant type"); + } +}; + +template +void bt_deserialize_try_variant(std::string_view& s, Variant& variant) { + bt_deserialize_try_variant_impl{}(s, variant); +} + + +template +struct bt_deserialize_try_variant_impl>, Variant, T, Ts...> { + void operator()(std::string_view& s, Variant& variant) { + if ( is_bt_output_list_container ? s[0] == 'l' : + is_bt_tuple ? s[0] == 'l' : + is_bt_output_dict_container ? s[0] == 'd' : + std::is_integral_v ? s[0] == 'i' : + std::is_same_v ? s[0] >= '0' && s[0] <= '9' : + false) { + T val; + bt_deserialize{}(s, val); + variant = std::move(val); + } else { + bt_deserialize_try_variant(s, variant); + } + } +}; + +template +struct bt_deserialize_try_variant_impl>, Variant, T, Ts...> { + void operator()(std::string_view& s, Variant& variant) { + // Unsupported deserialization type, skip it + bt_deserialize_try_variant(s, variant); + } +}; + +// Serialization of a variant; all variant types must be bt-serializable. +template +struct bt_serialize, std::void_t...>> { + void operator()(std::ostream& os, const std::variant& val) { + var::visit( + [&os] (const auto& val) { + using T = std::remove_cv_t>; + bt_serialize{}(os, val); + }, + val); + } +}; + +// Deserialization to a variant; at least one variant type must be bt-deserializble. +template +struct bt_deserialize, std::enable_if_t<(is_bt_deserializable || ...)>> { + void operator()(std::string_view& s, std::variant& val) { + bt_deserialize_try_variant(s, val); + } +}; + +template <> +struct bt_serialize : bt_serialize {}; + +template <> +struct bt_deserialize { + void operator()(std::string_view& s, bt_value& val); +}; + +template +struct bt_stream_serializer { + const T &val; + explicit bt_stream_serializer(const T &val) : val{val} {} + operator std::string() const { + std::ostringstream oss; + oss << *this; + return oss.str(); + } +}; +template +std::ostream &operator<<(std::ostream &os, const bt_stream_serializer &s) { + bt_serialize{}(os, s.val); + return os; +} + +} // namespace detail + + +/// Returns a wrapper around a value reference that can serialize the value directly to an output +/// stream. This class is intended to be used inline (i.e. without being stored) as in: +/// +/// std::list my_list{{1,2,3}}; +/// std::cout << bt_serializer(my_list); +/// +/// While it is possible to store the returned object and use it, such as: +/// +/// auto encoded = bt_serializer(42); +/// std::cout << encoded; +/// +/// this approach is not generally recommended: the returned object stores a reference to the +/// passed-in type, which may not survive. If doing this note that it is the caller's +/// responsibility to ensure the serializer is not used past the end of the lifetime of the value +/// being serialized. +/// +/// Also note that serializing directly to an output stream is more efficient as no intermediate +/// string containing the entire serialization has to be constructed. +/// +template +detail::bt_stream_serializer bt_serializer(const T &val) { return detail::bt_stream_serializer{val}; } + +/// Serializes the given value into a std::string. +/// +/// int number = 42; +/// std::string encoded = bt_serialize(number); +/// // Equivalent: +/// //auto encoded = (std::string) bt_serialize(number); +/// +/// This takes any serializable type: integral types, strings, lists of serializable types, and +/// string->value maps of serializable types. +template +std::string bt_serialize(const T &val) { return bt_serializer(val); } + +/// Deserializes the given string view directly into `val`. Usage: +/// +/// std::string encoded = "i42e"; +/// int value; +/// bt_deserialize(encoded, value); // Sets value to 42 +/// +template , int> = 0> +void bt_deserialize(std::string_view s, T& val) { + return detail::bt_deserialize{}(s, val); +} + + +/// Deserializes the given string_view into a `T`, which is returned. +/// +/// std::string encoded = "li1ei2ei3ee"; // bt-encoded list of ints: [1,2,3] +/// auto mylist = bt_deserialize>(encoded); +/// +template +T bt_deserialize(std::string_view s) { + T val; + bt_deserialize(s, val); + return val; +} + +/// Deserializes the given value into a generic `bt_value` type (wrapped std::variant) which is +/// capable of holding all possible BT-encoded values (including recursion). +/// +/// Example: +/// +/// std::string encoded = "i42e"; +/// auto val = bt_get(encoded); +/// int v = get_int(val); // fails unless the encoded value was actually an integer that +/// // fits into an `int` +/// +inline bt_value bt_get(std::string_view s) { + return bt_deserialize(s); +} + +/// Helper functions to extract a value of some integral type from a bt_value which contains either +/// a int64_t or uint64_t. Does range checking, throwing std::overflow_error if the stored value is +/// outside the range of the target type. +/// +/// Example: +/// +/// std::string encoded = "i123456789e"; +/// auto val = bt_get(encoded); +/// auto v = get_int(val); // throws if the decoded value doesn't fit in a uint32_t +template , int> = 0> +IntType get_int(const bt_value &v) { + if (auto* value = std::get_if(&v)) { + if constexpr (!std::is_same_v) + if (*value > static_cast(std::numeric_limits::max())) + throw std::overflow_error("Unable to extract integer value: stored value is too large for the requested type"); + return static_cast(*value); + } + + int64_t value = var::get(v); // throws if no int contained + if constexpr (!std::is_same_v) + if (value > static_cast(std::numeric_limits::max()) + || value < static_cast(std::numeric_limits::min())) + throw std::overflow_error("Unable to extract integer value: stored value is outside the range of the requested type"); + return static_cast(value); +} + +namespace detail { +template +void get_tuple_impl(Tuple& t, const bt_list& l, std::index_sequence); +} + +/// Converts a bt_list into the given template std::tuple or std::pair. Throws a +/// std::invalid_argument if the list has the wrong size or wrong element types. Supports recursion +/// (i.e. if the tuple itself contains tuples or pairs). The tuple (or nested tuples) may only +/// contain integral types, strings, string_views, bt_list, bt_dict, and tuples/pairs of those. +template +Tuple get_tuple(const bt_list& x) { + Tuple t; + detail::get_tuple_impl(t, x, std::make_index_sequence>{}); + return t; +} +template +Tuple get_tuple(const bt_value& x) { + return get_tuple(var::get(static_cast(x))); +} + +namespace detail { +template +void get_tuple_impl_one(T& t, It& it) { + const bt_variant& v = *it++; + if constexpr (std::is_integral_v) { + t = oxenmq::get_int(v); + } else if constexpr (is_bt_tuple) { + if (std::holds_alternative(v)) + throw std::invalid_argument{"Unable to convert tuple: cannot create sub-tuple from non-bt_list"}; + t = get_tuple(var::get(v)); + } else if constexpr (std::is_same_v || std::is_same_v) { + // If we request a string/string_view, we might have the other one and need to copy/view it. + if (std::holds_alternative(v)) + t = var::get(v); + else + t = var::get(v); + } else { + t = var::get(v); + } +} +template +void get_tuple_impl(Tuple& t, const bt_list& l, std::index_sequence) { + if (l.size() != sizeof...(Is)) + throw std::invalid_argument{"Unable to convert tuple: bt_list has wrong size"}; + auto it = l.begin(); + (get_tuple_impl_one(std::get(t), it), ...); +} +} // namespace detail + + + + +/// Class that allows you to walk through a bt-encoded list in memory without copying or allocating +/// memory. It accesses existing memory directly and so the caller must ensure that the referenced +/// memory stays valid for the lifetime of the bt_list_consumer object. +class bt_list_consumer { +protected: + std::string_view data; + bt_list_consumer() = default; +public: + bt_list_consumer(std::string_view data_); + + /// Copy constructor. Making a copy copies the current position so can be used for multipass + /// iteration through a list. + bt_list_consumer(const bt_list_consumer&) = default; + bt_list_consumer& operator=(const bt_list_consumer&) = default; + + /// Get a copy of the current buffer + std::string_view current_buffer() const { return data; } + + /// Returns true if the next value indicates the end of the list + bool is_finished() const { return data.front() == 'e'; } + /// Returns true if the next element looks like an encoded string + bool is_string() const { return data.front() >= '0' && data.front() <= '9'; } + /// Returns true if the next element looks like an encoded integer + bool is_integer() const { return data.front() == 'i'; } + /// Returns true if the next element looks like an encoded negative integer + bool is_negative_integer() const { return is_integer() && data.size() >= 2 && data[1] == '-'; } + /// Returns true if the next element looks like an encoded list + bool is_list() const { return data.front() == 'l'; } + /// Returns true if the next element looks like an encoded dict + bool is_dict() const { return data.front() == 'd'; } + + /// Attempt to parse the next value as a string (and advance just past it). Throws if the next + /// value is not a string. + std::string consume_string(); + std::string_view consume_string_view(); + + /// Attempts to parse the next value as an integer (and advance just past it). Throws if the + /// next value is not an integer. + template + IntType consume_integer() { + if (!is_integer()) throw bt_deserialize_invalid_type{"next value is not an integer"}; + std::string_view next{data}; + IntType ret; + detail::bt_deserialize{}(next, ret); + data = next; + return ret; + } + + /// Consumes a list, return it as a list-like type. Can also be used for tuples/pairs. This + /// typically requires dynamic allocation, but only has to parse the data once. Compare with + /// consume_list_data() which allows alloc-free traversal, but requires parsing twice (if the + /// contents are to be used). + template + T consume_list() { + T list; + consume_list(list); + return list; + } + + /// Same as above, but takes a pre-existing list-like data type. + template + void consume_list(T& list) { + if (!is_list()) throw bt_deserialize_invalid_type{"next bt value is not a list"}; + std::string_view n{data}; + detail::bt_deserialize{}(n, list); + data = n; + } + + /// Consumes a dict, return it as a dict-like type. This typically requires dynamic allocation, + /// but only has to parse the data once. Compare with consume_dict_data() which allows + /// alloc-free traversal, but requires parsing twice (if the contents are to be used). + template + T consume_dict() { + T dict; + consume_dict(dict); + return dict; + } + + /// Same as above, but takes a pre-existing dict-like data type. + template + void consume_dict(T& dict) { + if (!is_dict()) throw bt_deserialize_invalid_type{"next bt value is not a dict"}; + std::string_view n{data}; + detail::bt_deserialize{}(n, dict); + data = n; + } + + /// Consumes a value without returning it. + void skip_value(); + + /// Attempts to parse the next value as a list and returns the string_view that contains the + /// entire thing. This is recursive into both lists and dicts and likely to be quite + /// inefficient for large, nested structures (unless the values only need to be skipped but + /// aren't separately needed). This, however, does not require dynamic memory allocation. + std::string_view consume_list_data(); + + /// Attempts to parse the next value as a dict and returns the string_view that contains the + /// entire thing. This is recursive into both lists and dicts and likely to be quite + /// inefficient for large, nested structures (unless the values only need to be skipped but + /// aren't separately needed). This, however, does not require dynamic memory allocation. + std::string_view consume_dict_data(); +}; + + +/// Class that allows you to walk through key-value pairs of a bt-encoded dict in memory without +/// copying or allocating memory. It accesses existing memory directly and so the caller must +/// ensure that the referenced memory stays valid for the lifetime of the bt_dict_consumer object. +class bt_dict_consumer : private bt_list_consumer { + std::string_view key_; + + /// Consume the key if not already consumed and there is a key present (rather than 'e'). + /// Throws exception if what should be a key isn't a string, or if the key consumes the entire + /// data (i.e. requires that it be followed by something). Returns true if the key was consumed + /// (either now or previously and cached). + bool consume_key(); + + /// Clears the cached key and returns it. Must have already called consume_key directly or + /// indirectly via one of the `is_{...}` methods. + std::string_view flush_key() { + std::string_view k; + k.swap(key_); + return k; + } + +public: + bt_dict_consumer(std::string_view data_); + + /// Copy constructor. Making a copy copies the current position so can be used for multipass + /// iteration through a list. + bt_dict_consumer(const bt_dict_consumer&) = default; + bt_dict_consumer& operator=(const bt_dict_consumer&) = default; + + /// Returns true if the next value indicates the end of the dict + bool is_finished() { return !consume_key() && data.front() == 'e'; } + /// Operator bool is an alias for `!is_finished()` + operator bool() { return !is_finished(); } + /// Returns true if the next value looks like an encoded string + bool is_string() { return consume_key() && data.front() >= '0' && data.front() <= '9'; } + /// Returns true if the next element looks like an encoded integer + bool is_integer() { return consume_key() && data.front() == 'i'; } + /// Returns true if the next element looks like an encoded negative integer + bool is_negative_integer() { return is_integer() && data.size() >= 2 && data[1] == '-'; } + /// Returns true if the next element looks like an encoded list + bool is_list() { return consume_key() && data.front() == 'l'; } + /// Returns true if the next element looks like an encoded dict + bool is_dict() { return consume_key() && data.front() == 'd'; } + /// Returns the key of the next pair. This does not have to be called; it is also returned by + /// all of the other consume_* methods. The value is cached whether called here or by some + /// other method; accessing it multiple times simple accesses the cache until the next value is + /// consumed. + std::string_view key() { + if (!consume_key()) + throw bt_deserialize_invalid{"Cannot access next key: at the end of the dict"}; + return key_; + } + + /// Attempt to parse the next value as a string->string pair (and advance just past it). Throws + /// if the next value is not a string. + std::pair next_string(); + + /// Attempts to parse the next value as an string->integer pair (and advance just past it). + /// Throws if the next value is not an integer. + template + std::pair next_integer() { + if (!is_integer()) throw bt_deserialize_invalid_type{"next bt dict value is not an integer"}; + std::pair ret; + ret.second = bt_list_consumer::consume_integer(); + ret.first = flush_key(); + return ret; + } + + /// Consumes a string->list pair, return it as a list-like type. This typically requires + /// dynamic allocation, but only has to parse the data once. Compare with consume_list_data() + /// which allows alloc-free traversal, but requires parsing twice (if the contents are to be + /// used). + template + std::pair next_list() { + std::pair pair; + pair.first = next_list(pair.second); + return pair; + } + + /// Same as above, but takes a pre-existing list-like data type. Returns the key. + template + std::string_view next_list(T& list) { + if (!is_list()) throw bt_deserialize_invalid_type{"next bt value is not a list"}; + bt_list_consumer::consume_list(list); + return flush_key(); + } + + /// Consumes a string->dict pair, return it as a dict-like type. This typically requires + /// dynamic allocation, but only has to parse the data once. Compare with consume_dict_data() + /// which allows alloc-free traversal, but requires parsing twice (if the contents are to be + /// used). + template + std::pair next_dict() { + std::pair pair; + pair.first = consume_dict(pair.second); + return pair; + } + + /// Same as above, but takes a pre-existing dict-like data type. Returns the key. + template + std::string_view next_dict(T& dict) { + if (!is_dict()) throw bt_deserialize_invalid_type{"next bt value is not a dict"}; + bt_list_consumer::consume_dict(dict); + return flush_key(); + } + + /// Attempts to parse the next value as a string->list pair and returns the string_view that + /// contains the entire thing. This is recursive into both lists and dicts and likely to be + /// quite inefficient for large, nested structures (unless the values only need to be skipped + /// but aren't separately needed). This, however, does not require dynamic memory allocation. + std::pair next_list_data() { + if (data.size() < 2 || !is_list()) throw bt_deserialize_invalid_type{"next bt dict value is not a list"}; + return {flush_key(), bt_list_consumer::consume_list_data()}; + } + + /// Same as next_list_data(), but wraps the value in a bt_list_consumer for convenience + std::pair next_list_consumer() { return next_list_data(); } + + /// Attempts to parse the next value as a string->dict pair and returns the string_view that + /// contains the entire thing. This is recursive into both lists and dicts and likely to be + /// quite inefficient for large, nested structures (unless the values only need to be skipped + /// but aren't separately needed). This, however, does not require dynamic memory allocation. + std::pair next_dict_data() { + if (data.size() < 2 || !is_dict()) throw bt_deserialize_invalid_type{"next bt dict value is not a dict"}; + return {flush_key(), bt_list_consumer::consume_dict_data()}; + } + + /// Same as next_dict_data(), but wraps the value in a bt_dict_consumer for convenience + std::pair next_dict_consumer() { return next_dict_data(); } + + /// Skips ahead until we find the first key >= the given key or reach the end of the dict. + /// Returns true if we found an exact match, false if we reached some greater value or the end. + /// If we didn't hit the end, the next `consumer_*()` call will return the key-value pair we + /// found (either the exact match or the first key greater than the requested key). + /// + /// Two important notes: + /// + /// - properly encoded bt dicts must have lexicographically sorted keys, and this method assumes + /// that the input is correctly sorted (and thus if we find a greater value then your key does + /// not exist). + /// - this is irreversible; you cannot returned to skipped values without reparsing. (You *can* + /// however, make a copy of the bt_dict_consumer before calling and use the copy to return to + /// the pre-skipped position). + bool skip_until(std::string_view find) { + while (consume_key() && key_ < find) { + flush_key(); + skip_value(); + } + return key_ == find; + } + + /// The `consume_*` functions are wrappers around next_whatever that discard the returned key. + /// + /// Intended for use with skip_until such as: + /// + /// std::string value; + /// if (d.skip_until("key")) + /// value = d.consume_string(); + /// + + auto consume_string_view() { return next_string().second; } + auto consume_string() { return std::string{consume_string_view()}; } + + template + auto consume_integer() { return next_integer().second; } + + template + auto consume_list() { return next_list().second; } + + template + void consume_list(T& list) { next_list(list); } + + template + auto consume_dict() { return next_dict().second; } + + template + void consume_dict(T& dict) { next_dict(dict); } + + std::string_view consume_list_data() { return next_list_data().second; } + std::string_view consume_dict_data() { return next_dict_data().second; } + + bt_list_consumer consume_list_consumer() { return consume_list_data(); } + bt_dict_consumer consume_dict_consumer() { return consume_dict_data(); } +}; + + +} // namespace oxenmq diff --git a/oxenmq/bt_value.h b/oxenmq/bt_value.h new file mode 100644 index 0000000..7ad7579 --- /dev/null +++ b/oxenmq/bt_value.h @@ -0,0 +1,112 @@ +// Copyright (c) 2019-2021, The Oxen Project +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without modification, are +// permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of +// conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list +// of conditions and the following disclaimer in the documentation and/or other +// materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be +// used to endorse or promote products derived from this software without specific +// prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL +// THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, +// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF +// THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#pragma once + +// This header is here to provide just the basic bt_value/bt_dict/bt_list definitions without +// needing to include the full bt_serialize.h header. + +#include +#include +#include +#include +#include +#include + +namespace oxenmq { + +struct bt_value; + +/// The type used to store dictionaries inside bt_value. +using bt_dict = std::map; // NB: unordered_map doesn't work because it can't be used with a predeclared type +/// The type used to store list items inside bt_value. +using bt_list = std::list; + +/// The basic variant that can hold anything (recursively). +using bt_variant = std::variant< + std::string, + std::string_view, + int64_t, + uint64_t, + bt_list, + bt_dict +>; + +#ifdef __cpp_lib_remove_cvref // C++20 +using std::remove_cvref_t; +#else +template +using remove_cvref_t = std::remove_cv_t>; +#endif + +template +struct has_alternative; +template +struct has_alternative> : std::bool_constant<(std::is_same_v || ...)> {}; +template +constexpr bool has_alternative_v = has_alternative::value; + +namespace detail { + template + bt_list tuple_to_list(const Tuple& tuple, std::index_sequence) { + return {{bt_value{std::get(tuple)}...}}; + } + template constexpr bool is_tuple = false; + template constexpr bool is_tuple> = true; + template constexpr bool is_tuple> = true; +} + +/// Recursive generic type that can fully represent everything valid for a BT serialization. +/// This is basically just an empty wrapper around the std::variant, except we add some extra +/// converting constructors: +/// - integer constructors so that any unsigned value goes to the uint64_t and any signed value goes +/// to the int64_t. +/// - std::tuple and std::pair constructors that build a bt_list out of the tuple/pair elements. +struct bt_value : bt_variant { + using bt_variant::bt_variant; + using bt_variant::operator=; + + template , std::enable_if_t && std::is_unsigned_v, int> = 0> + bt_value(T&& uint) : bt_variant{static_cast(uint)} {} + + template , std::enable_if_t && std::is_signed_v, int> = 0> + bt_value(T&& sint) : bt_variant{static_cast(sint)} {} + + template + bt_value(const std::tuple& tuple) : bt_variant{detail::tuple_to_list(tuple, std::index_sequence_for{})} {} + + template + bt_value(const std::pair& pair) : bt_variant{detail::tuple_to_list(pair, std::index_sequence_for{})} {} + + template , std::enable_if_t && !detail::is_tuple, int> = 0> + bt_value(T&& v) : bt_variant{std::forward(v)} {} + + bt_value(const char* s) : bt_value{std::string_view{s}} {} +}; + +} diff --git a/oxenmq/byte_type.h b/oxenmq/byte_type.h new file mode 100644 index 0000000..f582130 --- /dev/null +++ b/oxenmq/byte_type.h @@ -0,0 +1,28 @@ +#pragma once + +// Specializations for assigning from a char into an output iterator, used by hex/base32z/base64 +// decoding to bytes. + +#include +#include + +namespace oxenmq::detail { + +// Fallback - we just try a char +template +struct byte_type { using type = char; }; + +// Support for things like std::back_inserter: +template +struct byte_type> { + using type = typename OutputIt::container_type::value_type; }; + +// iterator, raw pointers: +template +struct byte_type::reference>>> { + using type = std::remove_reference_t::reference>; }; + +template +using byte_type_t = typename byte_type::type; + +} diff --git a/lokimq/connections.cpp b/oxenmq/connections.cpp similarity index 94% rename from lokimq/connections.cpp rename to oxenmq/connections.cpp index 92e106c..8bd7d7d 100644 --- a/lokimq/connections.cpp +++ b/oxenmq/connections.cpp @@ -1,8 +1,8 @@ -#include "lokimq.h" -#include "lokimq-internal.h" +#include "oxenmq.h" +#include "oxenmq-internal.h" #include "hex.h" -namespace lokimq { +namespace oxenmq { std::ostream& operator<<(std::ostream& o, const ConnectionID& conn) { if (!conn.pk.empty()) @@ -24,7 +24,7 @@ void add_pollitem(std::vector& pollitems, zmq::socket_t& sock) } // anonymous namespace -void LokiMQ::rebuild_pollitems() { +void OxenMQ::rebuild_pollitems() { pollitems.clear(); add_pollitem(pollitems, command); add_pollitem(pollitems, workers_socket); @@ -35,7 +35,7 @@ void LokiMQ::rebuild_pollitems() { pollitems_stale = false; } -void LokiMQ::setup_external_socket(zmq::socket_t& socket) { +void OxenMQ::setup_external_socket(zmq::socket_t& socket) { socket.set(zmq::sockopt::reconnect_ivl, (int) RECONNECT_INTERVAL.count()); socket.set(zmq::sockopt::reconnect_ivl_max, (int) RECONNECT_INTERVAL_MAX.count()); socket.set(zmq::sockopt::handshake_ivl, (int) HANDSHAKE_TIME.count()); @@ -47,7 +47,7 @@ void LokiMQ::setup_external_socket(zmq::socket_t& socket) { } } -void LokiMQ::setup_outgoing_socket(zmq::socket_t& socket, std::string_view remote_pubkey) { +void OxenMQ::setup_outgoing_socket(zmq::socket_t& socket, std::string_view remote_pubkey) { setup_external_socket(socket); @@ -67,7 +67,7 @@ void LokiMQ::setup_outgoing_socket(zmq::socket_t& socket, std::string_view remot // else let ZMQ pick a random one } -ConnectionID LokiMQ::connect_sn(std::string_view pubkey, std::chrono::milliseconds keep_alive, std::string_view hint) { +ConnectionID OxenMQ::connect_sn(std::string_view pubkey, std::chrono::milliseconds keep_alive, std::string_view hint) { if (!proxy_thread.joinable()) throw std::logic_error("Cannot call connect_sn() before calling `start()`"); @@ -76,7 +76,7 @@ ConnectionID LokiMQ::connect_sn(std::string_view pubkey, std::chrono::millisecon return pubkey; } -ConnectionID LokiMQ::connect_remote(const address& remote, ConnectSuccess on_connect, ConnectFailure on_failure, +ConnectionID OxenMQ::connect_remote(const address& remote, ConnectSuccess on_connect, ConnectFailure on_failure, AuthLevel auth_level, std::chrono::milliseconds timeout) { if (!proxy_thread.joinable()) throw std::logic_error("Cannot call connect_remote() before calling `start()`"); @@ -96,13 +96,13 @@ ConnectionID LokiMQ::connect_remote(const address& remote, ConnectSuccess on_con return id; } -ConnectionID LokiMQ::connect_remote(std::string_view remote, ConnectSuccess on_connect, ConnectFailure on_failure, +ConnectionID OxenMQ::connect_remote(std::string_view remote, ConnectSuccess on_connect, ConnectFailure on_failure, std::string_view pubkey, AuthLevel auth_level, std::chrono::milliseconds timeout) { return connect_remote(address{remote}.set_pubkey(pubkey), std::move(on_connect), std::move(on_failure), auth_level, timeout); } -void LokiMQ::disconnect(ConnectionID id, std::chrono::milliseconds linger) { +void OxenMQ::disconnect(ConnectionID id, std::chrono::milliseconds linger) { detail::send_control(get_control_socket(), "DISCONNECT", bt_serialize({ {"conn_id", id.id}, {"linger_ms", linger.count()}, @@ -111,7 +111,7 @@ void LokiMQ::disconnect(ConnectionID id, std::chrono::milliseconds linger) { } std::pair -LokiMQ::proxy_connect_sn(std::string_view remote, std::string_view connect_hint, bool optional, bool incoming_only, bool outgoing_only, std::chrono::milliseconds keep_alive) { +OxenMQ::proxy_connect_sn(std::string_view remote, std::string_view connect_hint, bool optional, bool incoming_only, bool outgoing_only, std::chrono::milliseconds keep_alive) { ConnectionID remote_cid{remote}; auto its = peers.equal_range(remote_cid); peer_info* peer = nullptr; @@ -186,7 +186,7 @@ LokiMQ::proxy_connect_sn(std::string_view remote, std::string_view connect_hint, return {&connections.back(), ""s}; } -std::pair LokiMQ::proxy_connect_sn(bt_dict_consumer data) { +std::pair OxenMQ::proxy_connect_sn(bt_dict_consumer data) { std::string_view hint, remote_pk; std::chrono::milliseconds keep_alive; bool optional = false, incoming_only = false, outgoing_only = false; @@ -226,7 +226,7 @@ void update_connection_indices(Container& c, size_t index, AccessIndex get_index /// Closes outgoing connections and removes all references. Note that this will call `erase()` /// which can invalidate iterators on the various connection containers - if you don't want that, /// delete it first so that the container won't contain the element being deleted. -void LokiMQ::proxy_close_connection(size_t index, std::chrono::milliseconds linger) { +void OxenMQ::proxy_close_connection(size_t index, std::chrono::milliseconds linger) { connections[index].set(zmq::sockopt::linger, linger > 0ms ? (int) linger.count() : 0); pollitems_stale = true; connections.erase(connections.begin() + index); @@ -244,7 +244,7 @@ void LokiMQ::proxy_close_connection(size_t index, std::chrono::milliseconds ling conn_index_to_id.erase(conn_index_to_id.begin() + index); } -void LokiMQ::proxy_expire_idle_peers() { +void OxenMQ::proxy_expire_idle_peers() { for (auto it = peers.begin(); it != peers.end(); ) { auto &info = it->second; if (info.outgoing()) { @@ -267,7 +267,7 @@ void LokiMQ::proxy_expire_idle_peers() { } } -void LokiMQ::proxy_conn_cleanup() { +void OxenMQ::proxy_conn_cleanup() { LMQ_TRACE("starting proxy connections cleanup"); // Drop idle connections (if we haven't done it in a while) @@ -307,7 +307,7 @@ void LokiMQ::proxy_conn_cleanup() { LMQ_TRACE("done proxy connections cleanup"); }; -void LokiMQ::proxy_connect_remote(bt_dict_consumer data) { +void OxenMQ::proxy_connect_remote(bt_dict_consumer data) { AuthLevel auth_level = AuthLevel::none; long long conn_id = -1; ConnectSuccess on_connect; @@ -372,7 +372,7 @@ void LokiMQ::proxy_connect_remote(bt_dict_consumer data) { peers.emplace(std::move(conn), std::move(peer)); } -void LokiMQ::proxy_disconnect(bt_dict_consumer data) { +void OxenMQ::proxy_disconnect(bt_dict_consumer data) { ConnectionID connid{-1}; std::chrono::milliseconds linger = 1s; @@ -388,7 +388,7 @@ void LokiMQ::proxy_disconnect(bt_dict_consumer data) { proxy_disconnect(std::move(connid), linger); } -void LokiMQ::proxy_disconnect(ConnectionID conn, std::chrono::milliseconds linger) { +void OxenMQ::proxy_disconnect(ConnectionID conn, std::chrono::milliseconds linger) { LMQ_TRACE("Disconnecting outgoing connection to ", conn); auto pr = peers.equal_range(conn); for (auto it = pr.first; it != pr.second; ++it) { diff --git a/oxenmq/connections.h b/oxenmq/connections.h new file mode 100644 index 0000000..9772845 --- /dev/null +++ b/oxenmq/connections.h @@ -0,0 +1,96 @@ +#pragma once +#include "auth.h" +#include "bt_value.h" +#include +#include +#include +#include +#include +#include + +namespace oxenmq { + +struct ConnectionID; + +namespace detail { +template +bt_dict build_send(ConnectionID to, std::string_view cmd, T&&... opts); +} + +/// Opaque data structure representing a connection which supports ==, !=, < and std::hash. For +/// connections to service node this is the service node pubkey (and you can pass a 32-byte string +/// anywhere a ConnectionID is called for). For non-SN remote connections you need to keep a copy +/// of the ConnectionID returned by connect_remote(). +struct ConnectionID { + // Default construction; creates a ConnectionID with an invalid internal ID that will not match + // an actual connection. + ConnectionID() : ConnectionID(0) {} + // Construction from a service node pubkey + ConnectionID(std::string pubkey_) : id{SN_ID}, pk{std::move(pubkey_)} { + if (pk.size() != 32) + throw std::runtime_error{"Invalid pubkey: expected 32 bytes"}; + } + // Construction from a service node pubkey + ConnectionID(std::string_view pubkey_) : ConnectionID(std::string{pubkey_}) {} + + ConnectionID(const ConnectionID&) = default; + ConnectionID(ConnectionID&&) = default; + ConnectionID& operator=(const ConnectionID&) = default; + ConnectionID& operator=(ConnectionID&&) = default; + + // Returns true if this is a ConnectionID (false for a default-constructed, invalid id) + explicit operator bool() const { + return id != 0; + } + + // Two ConnectionIDs are equal if they are both SNs and have matching pubkeys, or they are both + // not SNs and have matching internal IDs and routes. (Pubkeys do not have to match for + // non-SNs). + bool operator==(const ConnectionID &o) const { + if (sn() && o.sn()) + return pk == o.pk; + return id == o.id && route == o.route; + } + bool operator!=(const ConnectionID &o) const { return !(*this == o); } + bool operator<(const ConnectionID &o) const { + if (sn() && o.sn()) + return pk < o.pk; + return id < o.id || (id == o.id && route < o.route); + } + + // Returns true if this ConnectionID represents a SN connection + bool sn() const { return id == SN_ID; } + + // Returns this connection's pubkey, if any. (Note that all curve connections have pubkeys, not + // only SNs). + const std::string& pubkey() const { return pk; } + + // Returns a copy of the ConnectionID with the route set to empty. + ConnectionID unrouted() { return ConnectionID{id, pk, ""}; } + +private: + ConnectionID(long long id) : id{id} {} + ConnectionID(long long id, std::string pubkey, std::string route = "") + : id{id}, pk{std::move(pubkey)}, route{std::move(route)} {} + + constexpr static long long SN_ID = -1; + long long id = 0; + std::string pk; + std::string route; + friend class OxenMQ; + friend struct std::hash; + template + friend bt_dict detail::build_send(ConnectionID to, std::string_view cmd, T&&... opts); + friend std::ostream& operator<<(std::ostream& o, const ConnectionID& conn); +}; + +} // namespace oxenmq +namespace std { + template <> struct hash { + size_t operator()(const oxenmq::ConnectionID &c) const { + return c.sn() ? oxenmq::already_hashed{}(c.pk) : + std::hash{}(c.id) + std::hash{}(c.route); + } + }; +} // namespace std + diff --git a/oxenmq/hex.h b/oxenmq/hex.h new file mode 100644 index 0000000..fdfad3f --- /dev/null +++ b/oxenmq/hex.h @@ -0,0 +1,165 @@ +// Copyright (c) 2019-2021, The Oxen Project +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without modification, are +// permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of +// conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list +// of conditions and the following disclaimer in the documentation and/or other +// materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be +// used to endorse or promote products derived from this software without specific +// prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL +// THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, +// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF +// THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#pragma once +#include +#include +#include +#include +#include +#include "byte_type.h" + +namespace oxenmq { + +namespace detail { + +/// Compile-time generated lookup tables hex conversion +struct hex_table { + char from_hex_lut[256]; + char to_hex_lut[16]; + constexpr hex_table() noexcept : from_hex_lut{}, to_hex_lut{} { + for (unsigned char c = 0; c < 10; c++) { + from_hex_lut[(unsigned char)('0' + c)] = 0 + c; + to_hex_lut[ (unsigned char)( 0 + c)] = '0' + c; + } + for (unsigned char c = 0; c < 6; c++) { + from_hex_lut[(unsigned char)('a' + c)] = 10 + c; + from_hex_lut[(unsigned char)('A' + c)] = 10 + c; + to_hex_lut[ (unsigned char)(10 + c)] = 'a' + c; + } + } + constexpr char from_hex(unsigned char c) const noexcept { return from_hex_lut[c]; } + constexpr char to_hex(unsigned char b) const noexcept { return to_hex_lut[b]; } +} constexpr hex_lut; + +// This main point of this static assert is to force the compiler to compile-time build the constexpr tables. +static_assert(hex_lut.from_hex('a') == 10 && hex_lut.from_hex('F') == 15 && hex_lut.to_hex(13) == 'd', ""); + +} // namespace detail + +/// Creates hex digits from a character sequence. +template +void to_hex(InputIt begin, InputIt end, OutputIt out) { + static_assert(sizeof(decltype(*begin)) == 1, "to_hex requires chars/bytes"); + for (; begin != end; ++begin) { + uint8_t c = static_cast(*begin); + *out++ = detail::hex_lut.to_hex(c >> 4); + *out++ = detail::hex_lut.to_hex(c & 0x0f); + } +} + +/// Creates a string of hex digits from a character sequence iterator pair +template +std::string to_hex(It begin, It end) { + std::string hex; + if constexpr (std::is_base_of_v::iterator_category>) + hex.reserve(2 * std::distance(begin, end)); + to_hex(begin, end, std::back_inserter(hex)); + return hex; +} + +/// Creates a hex string from an iterable, std::string-like object +template +std::string to_hex(std::basic_string_view s) { return to_hex(s.begin(), s.end()); } +inline std::string to_hex(std::string_view s) { return to_hex<>(s); } + +/// Returns true if the given value is a valid hex digit. +template +constexpr bool is_hex_digit(CharT c) { + static_assert(sizeof(CharT) == 1, "is_hex requires chars/bytes"); + return detail::hex_lut.from_hex(static_cast(c)) != 0 || static_cast(c) == '0'; +} + +/// Returns true if all elements in the range are hex characters *and* the string length is a +/// multiple of 2, and thus suitable to pass to from_hex(). +template +constexpr bool is_hex(It begin, It end) { + static_assert(sizeof(decltype(*begin)) == 1, "is_hex requires chars/bytes"); + constexpr bool ra = std::is_base_of_v::iterator_category>; + if constexpr (ra) + if (std::distance(begin, end) % 2 != 0) + return false; + + size_t count = 0; + for (; begin != end; ++begin) { + if constexpr (!ra) ++count; + if (!is_hex_digit(*begin)) + return false; + } + if constexpr (!ra) + return count % 2 == 0; + return true; +} + +/// Returns true if all elements in the string-like value are hex characters +template +constexpr bool is_hex(std::basic_string_view s) { return is_hex(s.begin(), s.end()); } +constexpr bool is_hex(std::string_view s) { return is_hex(s.begin(), s.end()); } + +/// Convert a hex digit into its numeric (0-15) value +constexpr char from_hex_digit(unsigned char x) noexcept { + return detail::hex_lut.from_hex(x); +} + +/// Constructs a byte value from a pair of hex digits +constexpr char from_hex_pair(unsigned char a, unsigned char b) noexcept { return (from_hex_digit(a) << 4) | from_hex_digit(b); } + +/// Converts a sequence of hex digits to bytes. Undefined behaviour if any characters are not in +/// [0-9a-fA-F] or if the input sequence length is not even: call `is_hex` first if you need to +/// check. It is permitted for the input and output ranges to overlap as long as out is no earlier +/// than begin. +template +void from_hex(InputIt begin, InputIt end, OutputIt out) { + using std::distance; + assert(is_hex(begin, end)); + while (begin != end) { + auto a = *begin++; + auto b = *begin++; + *out++ = static_cast>( + from_hex_pair(static_cast(a), static_cast(b))); + } +} + +/// Converts a sequence of hex digits to a string of bytes and returns it. Undefined behaviour if +/// the input sequence is not an even-length sequence of [0-9a-fA-F] characters. +template +std::string from_hex(It begin, It end) { + std::string bytes; + if constexpr (std::is_base_of_v::iterator_category>) + bytes.reserve(std::distance(begin, end) / 2); + from_hex(begin, end, std::back_inserter(bytes)); + return bytes; +} + +/// Converts hex digits from a std::string-like object into a std::string of bytes. Undefined +/// behaviour if any characters are not in [0-9a-fA-F] or if the input sequence length is not even. +template +std::string from_hex(std::basic_string_view s) { return from_hex(s.begin(), s.end()); } +inline std::string from_hex(std::string_view s) { return from_hex<>(s); } + +} diff --git a/lokimq/jobs.cpp b/oxenmq/jobs.cpp similarity index 87% rename from lokimq/jobs.cpp rename to oxenmq/jobs.cpp index 653593a..deae049 100644 --- a/lokimq/jobs.cpp +++ b/oxenmq/jobs.cpp @@ -1,10 +1,10 @@ -#include "lokimq.h" +#include "oxenmq.h" #include "batch.h" -#include "lokimq-internal.h" +#include "oxenmq-internal.h" -namespace lokimq { +namespace oxenmq { -void LokiMQ::proxy_batch(detail::Batch* batch) { +void OxenMQ::proxy_batch(detail::Batch* batch) { batches.insert(batch); const auto [jobs, tagged_threads] = batch->size(); LMQ_TRACE("proxy queuing batch job with ", jobs, " jobs", tagged_threads ? " (job uses tagged thread(s))" : ""); @@ -26,7 +26,7 @@ void LokiMQ::proxy_batch(detail::Batch* batch) { proxy_skip_one_poll = true; } -void LokiMQ::job(std::function f, std::optional thread) { +void OxenMQ::job(std::function f, std::optional thread) { if (thread && thread->_id == -1) throw std::logic_error{"job() cannot be used to queue an in-proxy job"}; auto* b = new Batch; @@ -35,7 +35,7 @@ void LokiMQ::job(std::function f, std::optional thread) detail::send_control(get_control_socket(), "BATCH", bt_serialize(reinterpret_cast(baseptr))); } -void LokiMQ::proxy_schedule_reply_job(std::function f) { +void OxenMQ::proxy_schedule_reply_job(std::function f) { auto* b = new Batch; b->add_job(std::move(f)); batches.insert(b); @@ -43,7 +43,7 @@ void LokiMQ::proxy_schedule_reply_job(std::function f) { proxy_skip_one_poll = true; } -void LokiMQ::proxy_run_batch_jobs(std::queue& jobs, const int reserved, int& active, bool reply) { +void OxenMQ::proxy_run_batch_jobs(std::queue& jobs, const int reserved, int& active, bool reply) { while (!jobs.empty() && active_workers() < max_workers && (active < reserved || active_workers() < general_workers)) { proxy_run_worker(get_idle_worker().load(std::move(jobs.front()), reply)); @@ -54,20 +54,20 @@ void LokiMQ::proxy_run_batch_jobs(std::queue& jobs, const int reserve // Called either within the proxy thread, or before the proxy thread has been created; actually adds // the timer. If the timer object hasn't been set up yet it gets set up here. -void LokiMQ::proxy_timer(std::function job, std::chrono::milliseconds interval, bool squelch, int thread) { +void OxenMQ::proxy_timer(std::function job, std::chrono::milliseconds interval, bool squelch, int thread) { if (!timers) timers.reset(zmq_timers_new()); int timer_id = zmq_timers_add(timers.get(), interval.count(), - [](int timer_id, void* self) { static_cast(self)->_queue_timer_job(timer_id); }, + [](int timer_id, void* self) { static_cast(self)->_queue_timer_job(timer_id); }, this); if (timer_id == -1) throw zmq::error_t{}; timer_jobs[timer_id] = { std::move(job), squelch, false, thread }; } -void LokiMQ::proxy_timer(bt_list_consumer timer_data) { +void OxenMQ::proxy_timer(bt_list_consumer timer_data) { std::unique_ptr> func{reinterpret_cast*>(timer_data.consume_integer())}; auto interval = std::chrono::milliseconds{timer_data.consume_integer()}; auto squelch = timer_data.consume_integer(); @@ -77,7 +77,7 @@ void LokiMQ::proxy_timer(bt_list_consumer timer_data) { proxy_timer(std::move(*func), interval, squelch, thread); } -void LokiMQ::_queue_timer_job(int timer_id) { +void OxenMQ::_queue_timer_job(int timer_id) { auto it = timer_jobs.find(timer_id); if (it == timer_jobs.end()) { LMQ_LOG(warn, "Could not find timer job ", timer_id); @@ -107,7 +107,7 @@ void LokiMQ::_queue_timer_job(int timer_id) { auto it = timer_jobs.find(timer_id); if (it != timer_jobs.end()) it->second.running = false; - }, LokiMQ::run_in_proxy); + }, OxenMQ::run_in_proxy); } batches.insert(b); LMQ_TRACE("b: ", b->size().first, ", ", b->size().second, "; thread = ", thread); @@ -118,7 +118,7 @@ void LokiMQ::_queue_timer_job(int timer_id) { queue.emplace(static_cast(b), 0); } -void LokiMQ::add_timer(std::function job, std::chrono::milliseconds interval, bool squelch, std::optional thread) { +void OxenMQ::add_timer(std::function job, std::chrono::milliseconds interval, bool squelch, std::optional thread) { int th_id = thread ? thread->_id : 0; if (proxy_thread.joinable()) { detail::send_control(get_control_socket(), "TIMER", bt_serialize(bt_list{{ @@ -131,9 +131,9 @@ void LokiMQ::add_timer(std::function job, std::chrono::milliseconds inte } } -void LokiMQ::TimersDeleter::operator()(void* timers) { zmq_timers_destroy(&timers); } +void OxenMQ::TimersDeleter::operator()(void* timers) { zmq_timers_destroy(&timers); } -TaggedThreadID LokiMQ::add_tagged_thread(std::string name, std::function start) { +TaggedThreadID OxenMQ::add_tagged_thread(std::string name, std::function start) { if (proxy_thread.joinable()) throw std::logic_error{"Cannot add tagged threads after calling `start()`"}; @@ -146,7 +146,7 @@ TaggedThreadID LokiMQ::add_tagged_thread(std::string name, std::function run.worker_routing_id = "t" + std::to_string(run.worker_id); LMQ_TRACE("Created new tagged thread ", name, " with routing id ", run.worker_routing_id); - run.worker_thread = std::thread{&LokiMQ::worker_thread, this, run.worker_id, name, std::move(start)}; + run.worker_thread = std::thread{&OxenMQ::worker_thread, this, run.worker_id, name, std::move(start)}; return TaggedThreadID{static_cast(run.worker_id)}; } diff --git a/oxenmq/message.h b/oxenmq/message.h new file mode 100644 index 0000000..d914561 --- /dev/null +++ b/oxenmq/message.h @@ -0,0 +1,57 @@ +#pragma once +#include +#include "connections.h" + +namespace oxenmq { + +class OxenMQ; + +/// Encapsulates an incoming message from a remote connection with message details plus extra +/// info need to send a reply back through the proxy thread via the `reply()` method. Note that +/// this object gets reused: callbacks should use but not store any reference beyond the callback. +class Message { +public: + OxenMQ& oxenmq; ///< The owning OxenMQ object + std::vector data; ///< The provided command data parts, if any. + ConnectionID conn; ///< The connection info for routing a reply; also contains the pubkey/sn status. + std::string reply_tag; ///< If the invoked command is a request command this is the required reply tag that will be prepended by `send_reply()`. + Access access; ///< The access level of the invoker. This can be higher than the access level of the command, for example for an admin invoking a basic command. + std::string remote; ///< Some sort of remote address from which the request came. Often "IP" for TCP connections and "localhost:UID:GID:PID" for UDP connections. + + /// Constructor + Message(OxenMQ& lmq, ConnectionID cid, Access access, std::string remote) + : oxenmq{lmq}, conn{std::move(cid)}, access{std::move(access)}, remote{std::move(remote)} {} + + // Non-copyable + Message(const Message&) = delete; + Message& operator=(const Message&) = delete; + + /// Sends a command back to whomever sent this message. Arguments are forwarded to send() but + /// with send_option::optional{} added if the originator is not a SN. For SN messages (i.e. + /// where `sn` is true) this is a "strong" reply by default in that the proxy will attempt to + /// establish a new connection to the SN if no longer connected. For non-SN messages the reply + /// will be attempted using the available routing information, but if the connection has already + /// been closed the reply will be dropped. + /// + /// If you want to send a non-strong reply even when the remote is a service node then add + /// an explicit `send_option::optional()` argument. + template + void send_back(std::string_view, Args&&... args); + + /// Sends a reply to a request. This takes no command: the command is always the built-in + /// "REPLY" command, followed by the unique reply tag, then any reply data parts. All other + /// arguments are as in `send_back()`. You should only send one reply for a command expecting + /// replies, though this is not enforced: attempting to send multiple replies will simply be + /// dropped when received by the remote. (Note, however, that it is possible to send multiple + /// messages -- e.g. you could send a reply and then also call send_back() and/or send_request() + /// to send more requests back to the sender). + template + void send_reply(Args&&... args); + + /// Sends a request back to whomever sent this message. This is effectively a wrapper around + /// lmq.request() that takes care of setting up the recipient arguments. + template + void send_request(std::string_view cmd, ReplyCallback&& callback, Args&&... args); +}; + +} diff --git a/lokimq/lokimq-internal.h b/oxenmq/oxenmq-internal.h similarity index 99% rename from lokimq/lokimq-internal.h rename to oxenmq/oxenmq-internal.h index 005a640..f43a8ac 100644 --- a/lokimq/lokimq-internal.h +++ b/oxenmq/oxenmq-internal.h @@ -1,5 +1,5 @@ #pragma once -#include "lokimq.h" +#include "oxenmq.h" // Inside some method: // LMQ_LOG(warn, "bad ", 42, " stuff"); @@ -13,7 +13,7 @@ # define LMQ_TRACE(...) #endif -namespace lokimq { +namespace oxenmq { constexpr char SN_ADDR_COMMAND[] = "inproc://sn-command"; constexpr char SN_ADDR_WORKERS[] = "inproc://sn-workers"; diff --git a/lokimq/lokimq.cpp b/oxenmq/oxenmq.cpp similarity index 89% rename from lokimq/lokimq.cpp rename to oxenmq/oxenmq.cpp index 4d8278a..7747c40 100644 --- a/lokimq/lokimq.cpp +++ b/oxenmq/oxenmq.cpp @@ -1,5 +1,5 @@ -#include "lokimq.h" -#include "lokimq-internal.h" +#include "oxenmq.h" +#include "oxenmq-internal.h" #include "zmq.hpp" #include #include @@ -13,7 +13,7 @@ extern "C" { } #include "hex.h" -namespace lokimq { +namespace oxenmq { namespace { @@ -76,20 +76,20 @@ std::pair extract_metadata(zmq::message_t& msg) { } // namespace detail -void LokiMQ::set_zmq_context_option(zmq::ctxopt option, int value) { +void OxenMQ::set_zmq_context_option(zmq::ctxopt option, int value) { context.set(option, value); } -void LokiMQ::log_level(LogLevel level) { +void OxenMQ::log_level(LogLevel level) { log_lvl.store(level, std::memory_order_relaxed); } -LogLevel LokiMQ::log_level() const { +LogLevel OxenMQ::log_level() const { return log_lvl.load(std::memory_order_relaxed); } -CatHelper LokiMQ::add_category(std::string name, Access access_level, unsigned int reserved_threads, int max_queue) { +CatHelper OxenMQ::add_category(std::string name, Access access_level, unsigned int reserved_threads, int max_queue) { check_not_started(proxy_thread, "add a category"); if (name.size() > MAX_CATEGORY_LENGTH) @@ -107,7 +107,7 @@ CatHelper LokiMQ::add_category(std::string name, Access access_level, unsigned i return ret; } -void LokiMQ::add_command(const std::string& category, std::string name, CommandCallback callback) { +void OxenMQ::add_command(const std::string& category, std::string name, CommandCallback callback) { check_not_started(proxy_thread, "add a command"); if (name.size() > MAX_COMMAND_LENGTH) @@ -126,12 +126,12 @@ void LokiMQ::add_command(const std::string& category, std::string name, CommandC throw std::runtime_error("Cannot add command `" + fullname + "': that command already exists"); } -void LokiMQ::add_request_command(const std::string& category, std::string name, CommandCallback callback) { +void OxenMQ::add_request_command(const std::string& category, std::string name, CommandCallback callback) { add_command(category, name, std::move(callback)); categories.at(category).commands.at(name).second = true; } -void LokiMQ::add_command_alias(std::string from, std::string to) { +void OxenMQ::add_command_alias(std::string from, std::string to) { check_not_started(proxy_thread, "add a command alias"); if (from.empty()) @@ -160,10 +160,10 @@ std::atomic next_id{1}; /// Accesses a thread-local command socket connected to the proxy's command socket used to issue /// commands in a thread-safe manner. A mutex is only required here the first time a thread /// accesses the control socket. -zmq::socket_t& LokiMQ::get_control_socket() { +zmq::socket_t& OxenMQ::get_control_socket() { assert(proxy_thread.joinable()); - // Optimize by caching the last value; LokiMQ is often a singleton and in that case we're + // Optimize by caching the last value; OxenMQ is often a singleton and in that case we're // going to *always* hit this optimization. Even if it isn't, we're probably likely to need the // same control socket from the same thread multiple times sequentially so this may still help. static thread_local int last_id = -1; @@ -174,7 +174,7 @@ zmq::socket_t& LokiMQ::get_control_socket() { std::lock_guard lock{control_sockets_mutex}; if (proxy_shutting_down) - throw std::runtime_error("Unable to obtain LokiMQ control socket: proxy thread is shutting down"); + throw std::runtime_error("Unable to obtain OxenMQ control socket: proxy thread is shutting down"); auto& socket = control_sockets[std::this_thread::get_id()]; if (!socket) { @@ -188,7 +188,7 @@ zmq::socket_t& LokiMQ::get_control_socket() { } -LokiMQ::LokiMQ( +OxenMQ::OxenMQ( std::string pubkey_, std::string privkey_, bool service_node, @@ -199,17 +199,17 @@ LokiMQ::LokiMQ( sn_lookup{std::move(lookup)}, log_lvl{level}, logger{std::move(logger)} { - LMQ_TRACE("Constructing LokiMQ, id=", object_id, ", this=", this); + LMQ_TRACE("Constructing OxenMQ, id=", object_id, ", this=", this); if (sodium_init() == -1) throw std::runtime_error{"libsodium initialization failed"}; if (pubkey.empty() != privkey.empty()) { - throw std::invalid_argument("LokiMQ construction failed: one (and only one) of pubkey/privkey is empty. Both must be specified, or both empty to generate a key."); + throw std::invalid_argument("OxenMQ construction failed: one (and only one) of pubkey/privkey is empty. Both must be specified, or both empty to generate a key."); } else if (pubkey.empty()) { if (service_node) - throw std::invalid_argument("Cannot construct a service node mode LokiMQ without a keypair"); - LMQ_LOG(debug, "generating x25519 keypair for remote-only LokiMQ instance"); + throw std::invalid_argument("Cannot construct a service node mode OxenMQ without a keypair"); + LMQ_LOG(debug, "generating x25519 keypair for remote-only OxenMQ instance"); pubkey.resize(crypto_box_PUBLICKEYBYTES); privkey.resize(crypto_box_SECRETKEYBYTES); crypto_box_keypair(reinterpret_cast(&pubkey[0]), reinterpret_cast(&privkey[0])); @@ -224,11 +224,11 @@ LokiMQ::LokiMQ( std::string verify_pubkey(crypto_box_PUBLICKEYBYTES, 0); crypto_scalarmult_base(reinterpret_cast(&verify_pubkey[0]), reinterpret_cast(&privkey[0])); if (verify_pubkey != pubkey) - throw std::invalid_argument("Invalid pubkey/privkey values given to LokiMQ construction: pubkey verification failed"); + throw std::invalid_argument("Invalid pubkey/privkey values given to OxenMQ construction: pubkey verification failed"); } } -void LokiMQ::start() { +void OxenMQ::start() { if (proxy_thread.joinable()) throw std::logic_error("Cannot call start() multiple times!"); @@ -238,19 +238,19 @@ void LokiMQ::start() { if (bind.empty() && local_service_node) throw std::invalid_argument{"Cannot create a service node listener with no address(es) to bind"}; - LMQ_LOG(info, "Initializing LokiMQ ", bind.empty() ? "remote-only" : "listener", " with pubkey ", to_hex(pubkey)); + LMQ_LOG(info, "Initializing OxenMQ ", bind.empty() ? "remote-only" : "listener", " with pubkey ", to_hex(pubkey)); int zmq_socket_limit = context.get(zmq::ctxopt::socket_limit); if (MAX_SOCKETS > 1 && MAX_SOCKETS <= zmq_socket_limit) context.set(zmq::ctxopt::max_sockets, MAX_SOCKETS); else - LMQ_LOG(error, "Not applying LokiMQ::MAX_SOCKETS setting: ", MAX_SOCKETS, " must be in [1, ", zmq_socket_limit, "]"); + LMQ_LOG(error, "Not applying OxenMQ::MAX_SOCKETS setting: ", MAX_SOCKETS, " must be in [1, ", zmq_socket_limit, "]"); // We bind `command` here so that the `get_control_socket()` below is always connecting to a // bound socket, but we do nothing else here: the proxy thread is responsible for everything // except binding it. command.bind(SN_ADDR_COMMAND); - proxy_thread = std::thread{&LokiMQ::proxy_loop, this}; + proxy_thread = std::thread{&OxenMQ::proxy_loop, this}; LMQ_LOG(debug, "Waiting for proxy thread to get ready..."); auto &control = get_control_socket(); @@ -260,14 +260,14 @@ void LokiMQ::start() { zmq::message_t ready_msg; std::vector parts; try { recv_message_parts(control, parts); } - catch (const zmq::error_t &e) { throw std::runtime_error("Failure reading from LokiMQ::Proxy thread: "s + e.what()); } + catch (const zmq::error_t &e) { throw std::runtime_error("Failure reading from OxenMQ::Proxy thread: "s + e.what()); } if (!(parts.size() == 1 && view(parts.front()) == "READY")) throw std::runtime_error("Invalid startup message from proxy thread (didn't get expected READY message)"); LMQ_LOG(debug, "Proxy thread is ready"); } -void LokiMQ::listen_curve(std::string bind_addr, AllowFunc allow_connection) { +void OxenMQ::listen_curve(std::string bind_addr, AllowFunc allow_connection) { // TODO: there's no particular reason we can't start listening after starting up; just needs to // be implemented. (But if we can start we'll probably also want to be able to stop, so it's // more than just binding that needs implementing). @@ -276,7 +276,7 @@ void LokiMQ::listen_curve(std::string bind_addr, AllowFunc allow_connection) { bind.emplace_back(std::move(bind_addr), bind_data{true, std::move(allow_connection)}); } -void LokiMQ::listen_plain(std::string bind_addr, AllowFunc allow_connection) { +void OxenMQ::listen_plain(std::string bind_addr, AllowFunc allow_connection) { // TODO: As above. check_not_started(proxy_thread, "start listening"); @@ -284,7 +284,7 @@ void LokiMQ::listen_plain(std::string bind_addr, AllowFunc allow_connection) { } -std::pair*> LokiMQ::get_command(std::string& command) { +std::pair*> OxenMQ::get_command(std::string& command) { if (command.size() > MAX_CATEGORY_LENGTH + 1 + MAX_COMMAND_LENGTH) { LMQ_LOG(warn, "Invalid command '", command, "': command too long"); return {}; @@ -320,7 +320,7 @@ std::pair*> Lo return {&catit->second, &callback_it->second}; } -void LokiMQ::set_batch_threads(int threads) { +void OxenMQ::set_batch_threads(int threads) { if (proxy_thread.joinable()) throw std::logic_error("Cannot change reserved batch threads after calling `start()`"); if (threads < -1) // -1 is the default which is based on general threads @@ -328,7 +328,7 @@ void LokiMQ::set_batch_threads(int threads) { batch_jobs_reserved = threads; } -void LokiMQ::set_reply_threads(int threads) { +void OxenMQ::set_reply_threads(int threads) { if (proxy_thread.joinable()) throw std::logic_error("Cannot change reserved reply threads after calling `start()`"); if (threads < -1) // -1 is the default which is based on general threads @@ -336,7 +336,7 @@ void LokiMQ::set_reply_threads(int threads) { reply_jobs_reserved = threads; } -void LokiMQ::set_general_threads(int threads) { +void OxenMQ::set_general_threads(int threads) { if (proxy_thread.joinable()) throw std::logic_error("Cannot change general thread count after calling `start()`"); if (threads < 1) @@ -344,7 +344,7 @@ void LokiMQ::set_general_threads(int threads) { general_workers = threads; } -LokiMQ::run_info& LokiMQ::run_info::load(category* cat_, std::string command_, ConnectionID conn_, Access access_, std::string remote_, +OxenMQ::run_info& OxenMQ::run_info::load(category* cat_, std::string command_, ConnectionID conn_, Access access_, std::string remote_, std::vector data_parts_, const std::pair* callback_) { reset(); cat = cat_; @@ -357,7 +357,7 @@ LokiMQ::run_info& LokiMQ::run_info::load(category* cat_, std::string command_, C return *this; } -LokiMQ::run_info& LokiMQ::run_info::load(category* cat_, std::string command_, std::string remote_, std::function callback) { +OxenMQ::run_info& OxenMQ::run_info::load(category* cat_, std::string command_, std::string remote_, std::function callback) { reset(); is_injected = true; cat = cat_; @@ -369,7 +369,7 @@ LokiMQ::run_info& LokiMQ::run_info::load(category* cat_, std::string command_, s return *this; } -LokiMQ::run_info& LokiMQ::run_info::load(pending_command&& pending) { +OxenMQ::run_info& OxenMQ::run_info::load(pending_command&& pending) { if (auto *f = std::get_if>(&pending.callback)) return load(&pending.cat, std::move(pending.command), std::move(pending.remote), std::move(*f)); @@ -378,7 +378,7 @@ LokiMQ::run_info& LokiMQ::run_info::load(pending_command&& pending) { std::move(pending.remote), std::move(pending.data_parts), var::get<0>(pending.callback)); } -LokiMQ::run_info& LokiMQ::run_info::load(batch_job&& bj, bool reply_job, int tagged_thread) { +OxenMQ::run_info& OxenMQ::run_info::load(batch_job&& bj, bool reply_job, int tagged_thread) { reset(); is_batch_job = true; is_reply_job = reply_job; @@ -389,7 +389,7 @@ LokiMQ::run_info& LokiMQ::run_info::load(batch_job&& bj, bool reply_job, int tag } -LokiMQ::~LokiMQ() { +OxenMQ::~OxenMQ() { if (!proxy_thread.joinable()) { if (!tagged_workers.empty()) { // This is a bit icky: we have tagged workers that are waiting for a signal on @@ -416,10 +416,10 @@ LokiMQ::~LokiMQ() { return; } - LMQ_LOG(info, "LokiMQ shutting down proxy thread"); + LMQ_LOG(info, "OxenMQ shutting down proxy thread"); detail::send_control(get_control_socket(), "QUIT"); proxy_thread.join(); - LMQ_LOG(info, "LokiMQ proxy thread has stopped"); + LMQ_LOG(info, "OxenMQ proxy thread has stopped"); } std::ostream &operator<<(std::ostream &os, LogLevel lvl) { @@ -443,5 +443,5 @@ std::string make_random_string(size_t size) { return rando; } -} // namespace lokimq +} // namespace oxenmq // vim:sw=4:et diff --git a/oxenmq/oxenmq.h b/oxenmq/oxenmq.h new file mode 100644 index 0000000..84cf037 --- /dev/null +++ b/oxenmq/oxenmq.h @@ -0,0 +1,1528 @@ +// Copyright (c) 2019-2021, The Oxen Project +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without modification, are +// permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of +// conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list +// of conditions and the following disclaimer in the documentation and/or other +// materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be +// used to endorse or promote products derived from this software without specific +// prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL +// THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, +// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF +// THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "zmq.hpp" +#include "address.h" +#include "bt_serialize.h" +#include "connections.h" +#include "message.h" +#include "auth.h" + +#if ZMQ_VERSION < ZMQ_MAKE_VERSION (4, 3, 0) +// Timers were not added until 4.3.0 +#error "ZMQ >= 4.3.0 required" +#endif + +namespace oxenmq { + +using namespace std::literals; + +/// Logging levels passed into LogFunc. (Note that trace does nothing more than debug in a release +/// build). +enum class LogLevel { fatal, error, warn, info, debug, trace }; + +// Forward declarations; see batch.h +namespace detail { class Batch; } +template class Batch; + +/** The keep-alive time for a send() that results in a establishing a new outbound connection. To + * use a longer keep-alive to a host call `connect()` first with the desired keep-alive time or pass + * the send_option::keep_alive. + */ +inline constexpr auto DEFAULT_SEND_KEEP_ALIVE = 30s; + +// The default timeout for connect_remote() +inline constexpr auto REMOTE_CONNECT_TIMEOUT = 10s; + +// The amount of time we wait for a reply to a REQUEST before calling the callback with +// `false` to signal a timeout. +inline constexpr auto DEFAULT_REQUEST_TIMEOUT = 15s; + +/// Maximum length of a category +inline constexpr size_t MAX_CATEGORY_LENGTH = 50; + +/// Maximum length of a command +inline constexpr size_t MAX_COMMAND_LENGTH = 200; + +class CatHelper; + +/// Opaque handle for a tagged thread constructed by add_tagged_thread(...). Not directly +/// constructible, but is safe (and cheap) to copy. +struct TaggedThreadID { +private: + int _id; + explicit constexpr TaggedThreadID(int id) : _id{id} {} + friend class OxenMQ; + template friend class Batch; +}; + +/** + * Class that handles OxenMQ listeners, connections, proxying, and workers. An application + * typically has just one instance of this class. + */ +class OxenMQ { + +private: + + /// The global context + zmq::context_t context; + + /// A unique id for this OxenMQ instance, assigned in a thread-safe manner during construction. + const int object_id; + + /// The x25519 keypair of this connection. For service nodes these are the long-run x25519 keys + /// provided at construction, for non-service-node connections these are generated during + /// construction. + std::string pubkey, privkey; + + /// True if *this* node is running in service node mode (whether or not actually active) + bool local_service_node = false; + + /// The thread in which most of the intermediate work happens (handling external connections + /// and proxying requests between them to worker threads) + std::thread proxy_thread; + + /// Will be true (and is guarded by a mutex) if the proxy thread is quitting; guards against new + /// control sockets from threads trying to talk to the proxy thread. + bool proxy_shutting_down = false; + + /// We have one seldom-used mutex here: it is generally locked just once per thread (the first + /// time the thread calls get_control_socket()) and once more by the proxy thread when it shuts + /// down. + std::mutex control_sockets_mutex; + + /// Called to obtain a "command" socket that attaches to `control` to send commands to the + /// proxy thread from other threads. This socket is unique per thread and OxenMQ instance. + zmq::socket_t& get_control_socket(); + + /// Per-thread control sockets used by oxenmq threads to talk to this object's proxy thread. + std::unordered_map> control_sockets; + +public: + + /// Callback type invoked to determine whether the given new incoming connection is allowed to + /// connect to us and to set its authentication level. + /// + /// @param address - the address of the incoming connection. For TCP connections this is an IP + /// address; for UDP connections it's a string such as "localhost:UID:GID:PID". + /// @param pubkey - the x25519 pubkey of the connecting client (32 byte string). Note that this + /// will only be non-empty for incoming connections on `listen_curve` sockets; `listen_plain` + /// sockets do not have a pubkey. + /// @param service_node - will be true if the `pubkey` is in the set of known active service + /// nodes. + /// + /// @returns an `AuthLevel` enum value indicating the default auth level for the incoming + /// connection, or AuthLevel::denied if the connection should be refused. + using AllowFunc = std::function; + + /// Callback that is invoked when we need to send a "strong" message to a SN that we aren't + /// already connected to and need to establish a connection. This callback returns the ZMQ + /// connection string we should use which is typically a string such as `tcp://1.2.3.4:5678`. + using SNRemoteAddress = std::function; + + /// The callback type for registered commands. + using CommandCallback = std::function; + + /// The callback for making requests. This is called with `true` and a (moved) vector of data + /// part strings when we get a reply, or `false` and empty vector on timeout. + using ReplyCallback = std::function data)>; + + /// Called to write a log message. This will only be called if the `level` is >= the current + /// OxenMQ object log level. It must be a raw function pointer (or a capture-less lambda) for + /// performance reasons. Takes four arguments: the log level of the message, the filename and + /// line number where the log message was invoked, and the log message itself. + using Logger = std::function; + + /// Callback for the success case of connect_remote() + using ConnectSuccess = std::function; + /// Callback for the failure case of connect_remote() + using ConnectFailure = std::function; + + /// Explicitly non-copyable, non-movable because most things here aren't copyable, and a few + /// things aren't movable, either. If you need to pass the OxenMQ instance around, wrap it + /// in a unique_ptr or shared_ptr. + OxenMQ(const OxenMQ&) = delete; + OxenMQ& operator=(const OxenMQ&) = delete; + OxenMQ(OxenMQ&&) = delete; + OxenMQ& operator=(OxenMQ&&) = delete; + + /** How long to wait for handshaking to complete on external connections before timing out and + * closing the connection. Setting this only affects new outgoing connections. */ + std::chrono::milliseconds HANDSHAKE_TIME = 10s; + + /** Whether to use a zmq routing ID based on the pubkey for new outgoing connections. This is + * normally desirable as it allows the listener to recognize that the incoming connection is a + * reconnection from the same remote and handover routing to the new socket while closing off + * the (likely dead) old socket. This, however, prevents a single OxenMQ instance from + * establishing multiple connections to the same listening OxenMQ, which is sometimes useful + * (for example when testing), and so this option can be overridden to `false` to use completely + * random zmq routing ids on outgoing connections (which will thus allow multiple connections). + */ + bool PUBKEY_BASED_ROUTING_ID = true; + + /** Maximum incoming message size; if a remote tries sending a message larger than this they get + * disconnected. -1 means no limit. */ + int64_t MAX_MSG_SIZE = 1 * 1024 * 1024; + + /** Maximum open sockets, passed to the ZMQ context during start(). The default here is 10k, + * designed to be enough to be more than enough to allow a full-mesh SN layer connection if + * necessary for the forseeable future. */ + int MAX_SOCKETS = 10000; + + /** Minimum reconnect interval: when a connection fails or dies, wait this long before + * attempting to reconnect. (ZMQ may randomize the value somewhat to avoid reconnection + * storms). See RECONNECT_INTERVAL_MAX as well. The OxenMQ default is 250ms. + */ + std::chrono::milliseconds RECONNECT_INTERVAL = 250ms; + + /** Maximum reconnect interval. When this is set to a value larger than RECONNECT_INTERVAL then + * ZMQ's reconnection logic uses an exponential backoff: each reconnection attempts waits twice + * as long as the previous attempt, up to this maximum. The OxenMQ default is 5 seconds. + */ + std::chrono::milliseconds RECONNECT_INTERVAL_MAX = 5s; + + /** How long (in ms) to linger sockets when closing them; this is the maximum time zmq spends + * trying to sending pending messages before dropping them and closing the underlying socket + * after the high-level zmq socket is closed. */ + std::chrono::milliseconds CLOSE_LINGER = 5s; + + /** How frequently we cleanup connections (closing idle connections, calling connect or request + * failure callbacks). Making this slower results in more "overshoot" before failure callbacks + * are invoked; making it too fast results in more proxy thread overhead. Any change to this + * variable must be set before calling start(). + */ + std::chrono::milliseconds CONN_CHECK_INTERVAL = 250ms; + + /** Whether to enable heartbeats on incoming/outgoing connections. If set to > 0 then we set up + * ZMQ to send a heartbeat ping over the socket this often, which helps keep the connection + * alive and lets failed connections be detected sooner (see the next option). + * + * Only new connections created after changing this are affected, so if changing it is + * recommended to set it before calling `start()`. + */ + std::chrono::milliseconds CONN_HEARTBEAT = 15s; + + /** When CONN_HEARTBEAT is enabled, this sets how long we wait for a reply on a socket before + * considering the socket to have died and closing it. + * + * Only new connections created after changing this are affected, so if changing it is + * recommended to set it before calling `start()`. + */ + std::chrono::milliseconds CONN_HEARTBEAT_TIMEOUT = 30s; + + /// Allows you to set options on the internal zmq context object. For advanced use only. + void set_zmq_context_option(zmq::ctxopt option, int value); + + /** The umask to apply when constructing sockets (which affects any new ipc:// listening sockets + * that get created). Does nothing if set to -1 (the default), and does nothing on Windows. + * Note that the umask is applied temporarily during `start()`, so may affect other threads that + * create files/directories at the same time as the start() call. + */ + int STARTUP_UMASK = -1; + + /** The gid that owns any sockets when constructed (same as umask) + */ + int SOCKET_GID = -1; + /** The uid that owns any sockets when constructed (same as umask but requires root) + */ + int SOCKET_UID = -1; + + /// A special TaggedThreadID value that always refers to the proxy thread; the main use of this is + /// to direct very simple batch completion jobs to be executed directly in the proxy thread. + inline static constexpr TaggedThreadID run_in_proxy{-1}; + + /// Writes a message to the logging system; intended mostly for internal use. + template + void log(LogLevel lvl, const char* filename, int line, const T&... stuff); + +private: + + /// The lookup function that tells us where to connect to a peer, or empty if not found. + SNRemoteAddress sn_lookup; + + /// The log level; this is atomic but we use relaxed order to set and access it (so changing it + /// might not be instantly visible on all threads, but that's okay). + std::atomic log_lvl{LogLevel::warn}; + + /// The callback to call with log messages + Logger logger; + + /////////////////////////////////////////////////////////////////////////////////// + /// NB: The following are all the domain of the proxy thread (once it is started)! + + /// The socket we listen on for handling ZAP authentication requests (the other end is internal + /// to zmq which sends requests to us as needed). + zmq::socket_t zap_auth{context, zmq::socket_type::rep}; + + struct bind_data { + bool curve; + size_t index; + AllowFunc allow; + bind_data(bool curve, AllowFunc allow) + : curve{curve}, index{0}, allow{std::move(allow)} {} + }; + + /// Addresses on which we are listening (or, before start(), on which we will listen). + std::vector> bind; + + /// Info about a peer's established connection with us. Note that "established" means both + /// connected and authenticated. Note that we only store peer info data for SN connections (in + /// or out), and outgoing non-SN connections. Incoming non-SN connections are handled on the + /// fly. + struct peer_info { + /// Pubkey of the remote, if this connection is a curve25519 connection; empty otherwise. + std::string pubkey; + + /// True if we've authenticated this peer as a service node. This gets set on incoming + /// messages when we check the remote's pubkey, and immediately on outgoing connections to + /// SNs (since we know their pubkey -- we'll fail to connect if it doesn't match). + bool service_node = false; + + /// The auth level of this peer, as returned by the AllowFunc for incoming connections or + /// specified during outgoing connections. + AuthLevel auth_level = AuthLevel::none; + + /// The actual internal socket index through which this connection is established + size_t conn_index; + + /// Will be set to a non-empty routing prefix *if* one is necessary on the connection. This + /// is used only for SN peers (non-SN incoming connections don't have a peer_info record, + /// and outgoing connections don't have a route). + std::string route; + + /// Returns true if this is an outgoing connection. (This is simply an alias for + /// route.empty() -- outgoing connections never have a route, incoming connections always + /// do). + bool outgoing() const { return route.empty(); } + + /// The last time we sent or received a message (or had some other relevant activity) on + /// this connection. Used for closing outgoing connections that have reached an inactivity + /// expiry time (closing inactive conns for incoming connections is done by the other end). + std::chrono::steady_clock::time_point last_activity; + + /// Updates last_activity to the current time + void activity() { last_activity = std::chrono::steady_clock::now(); } + + /// After more than this much inactivity we will close an idle (outgoing) connection + std::chrono::milliseconds idle_expiry; + }; + + /// Currently peer connections: id -> peer_info. The ID is as returned by connect_remote or a + /// SN pubkey string. + std::unordered_multimap peers; + + /// Maps connection indices (which can change) to ConnectionID values (which are permanent). + /// This is primarily for outgoing sockets, but incoming sockets are here too (with empty-route + /// (and thus unroutable) ConnectionIDs). + std::vector conn_index_to_id; + + /// Maps listening socket ConnectionIDs to connection index values (these don't have peers + /// entries). The keys here have empty routes (and thus aren't actually routable). + std::unordered_map incoming_conn_index; + + /// The next ConnectionID value we should use (for non-SN connections). + std::atomic next_conn_id{1}; + + /// Remotes we are still trying to connect to (via connect_remote(), not connect_sn()); when + /// we pass handshaking we move them out of here and (if set) trigger the on_connect callback. + /// Unlike regular node-to-node peers, these have an extra "HI"/"HELLO" sequence that we used + /// before we consider ourselves connected to the remote. + std::list> pending_connects; + + /// Pending requests that have been sent out but not yet received a matching "REPLY". The value + /// is the timeout timestamp. + std::unordered_map> + pending_requests; + + /// different polling sockets the proxy handler polls: this always contains some internal + /// sockets for inter-thread communication followed by a pollitem for every connection (both + /// incoming and outgoing) in `connections`. We rebuild this from `connections` whenever + /// `pollitems_stale` is set to true. + std::vector pollitems; + + /// If set then rebuild pollitems before the next poll (set when establishing new connections or + /// closing existing ones). + bool pollitems_stale = true; + + /// Rebuilds pollitems to include the internal sockets + all incoming/outgoing sockets. + void rebuild_pollitems(); + + /// The connections to/from remotes we currently have open, both listening and outgoing. Each + /// element [i] here corresponds to an the pollitem_t at pollitems[i+1+poll_internal_size]. + /// (Ideally we'd use one structure, but zmq requires the pollitems be in contiguous storage). + std::vector connections; + + /// Socket we listen on to receive control messages in the proxy thread. Each thread has its own + /// internal "control" connection (returned by `get_control_socket()`) to this socket used to + /// give instructions to the proxy such as instructing it to initiate a connection to a remote + /// or send a message. + zmq::socket_t command{context, zmq::socket_type::router}; + + /// Timers. TODO: once cppzmq adds an interface around the zmq C timers API then switch to it. + struct TimersDeleter { void operator()(void* timers); }; + struct timer_data { std::function function; bool squelch; bool running; int thread; }; + std::unordered_map timer_jobs; + std::unique_ptr timers; +public: + // This needs to be public because we have to be able to call it from a plain C function. + // Nothing external may call it! + void _queue_timer_job(int); +private: + + /// Router socket to reach internal worker threads from proxy + zmq::socket_t workers_socket{context, zmq::socket_type::router}; + + /// indices of idle, active workers + std::vector idle_workers; + + /// Maximum number of general task workers, specified by g`/during construction + int general_workers = std::max(1, std::thread::hardware_concurrency()); + + /// Maximum number of possible worker threads we can have. This is calculated when starting, + /// and equals general_workers plus the sum of all categories' reserved threads counts plus the + /// reserved batch workers count. This is also used to signal a shutdown; we set it to 0 when + /// quitting. + int max_workers; + + /// Number of active workers + int active_workers() const { return workers.size() - idle_workers.size(); } + + /// Worker thread loop. Tagged and start are provided for a tagged worker thread. + void worker_thread(unsigned int index, std::optional tagged = std::nullopt, std::function start = nullptr); + + /// If set, skip polling for one proxy loop iteration (set when we know we have something + /// processible without having to shove it onto a socket, such as scheduling an internal job). + bool proxy_skip_one_poll = false; + + /// Does the proxying work + void proxy_loop(); + + void proxy_conn_cleanup(); + + void proxy_worker_message(std::vector& parts); + + void proxy_process_queue(); + + void proxy_schedule_reply_job(std::function f); + + /// Looks up a peers element given a connect index (for outgoing connections where we already + /// knew the pubkey and SN status) or an incoming zmq message (which has the pubkey and sn + /// status metadata set during initial connection authentication), creating a new peer element + /// if required. + decltype(peers)::iterator proxy_lookup_peer(int conn_index, zmq::message_t& msg); + + /// Handles built-in primitive commands in the proxy thread for things like "BYE" that have to + /// be done in the proxy thread anyway (if we forwarded to a worker the worker would just have + /// to send an instruction back to the proxy to do it). Returns true if one was handled, false + /// to continue with sending to a worker. + bool proxy_handle_builtin(size_t conn_index, std::vector& parts); + + struct run_info; + /// Gets an idle worker's run_info and removes the worker from the idle worker list. If there + /// is no idle worker this creates a new `workers` element for a new worker (and so you should + /// only call this if new workers are permitted). Note that if this creates a new work info the + /// worker will *not* yet be started, so the caller must create the thread (in `.thread`) after + /// setting up the job if `.thread.joinable()` is false. + run_info& get_idle_worker(); + + /// Runs the worker; called after the `run` object has been set up. If the worker thread hasn't + /// been created then it is spawned; otherwise it is sent a RUN command. + void proxy_run_worker(run_info& run); + + /// Sets up a job for a worker then signals the worker (or starts a worker thread) + void proxy_to_worker(size_t conn_index, std::vector& parts); + + /// proxy thread command handlers for commands sent from the outer object QUIT. This doesn't + /// get called immediately on a QUIT command: the QUIT commands tells workers to quit, then this + /// gets called after all works have done so. + void proxy_quit(); + + // Common setup code for setting up an external (incoming or outgoing) socket. + void setup_external_socket(zmq::socket_t& socket); + + // Sets the various properties on an outgoing socket prior to connection. If remote_pubkey is + // provided then the connection will be curve25519 encrypted and authenticate; otherwise it will + // be unencrypted and unauthenticated. Note that the remote end must be in the same mode (i.e. + // either accepting curve connections, or not accepting curve). + void setup_outgoing_socket(zmq::socket_t& socket, std::string_view remote_pubkey = {}); + + /// Common connection implementation used by proxy_connect/proxy_send. Returns the socket and, + /// if a routing prefix is needed, the required prefix (or an empty string if not needed). For + /// an optional connect that fails (or some other connection failure), returns nullptr for the + /// socket. + /// + /// @param pubkey the pubkey to connect to + /// @param connect_hint if we need a new connection and this is non-empty then we *may* use it + /// instead of doing a call to `sn_lookup()`. + /// @param optional if we don't already have a connection then don't establish a new one + /// @param incoming_only only relay this if we have an established incoming connection from the + /// given SN, otherwise don't connect (like `optional`) + /// @param keep_alive the keep alive for the connection, if we establish a new outgoing + /// connection. If we already have an outgoing connection then its keep-alive gets increased to + /// this if currently less than this. + std::pair proxy_connect_sn(std::string_view pubkey, std::string_view connect_hint, + bool optional, bool incoming_only, bool outgoing_only, std::chrono::milliseconds keep_alive); + + /// CONNECT_SN command telling us to connect to a new pubkey. Returns the socket (which could + /// be existing or a new one). This basically just unpacks arguments and passes them on to + /// proxy_connect_sn(). + std::pair proxy_connect_sn(bt_dict_consumer data); + + /// Opens a new connection to a remote, with callbacks. This is the proxy-side implementation + /// of the `connect_remote()` call. + void proxy_connect_remote(bt_dict_consumer data); + + /// Called to disconnect our remote connection to the given id (if we have one). + void proxy_disconnect(bt_dict_consumer data); + void proxy_disconnect(ConnectionID conn, std::chrono::milliseconds linger); + + /// SEND command. Does a connect first, if necessary. + void proxy_send(bt_dict_consumer data); + + /// REPLY command. Like SEND, but only has a listening socket route to send back to and so is + /// weaker (i.e. it cannot reconnect to the SN if the connection is no longer open). + void proxy_reply(bt_dict_consumer data); + + /// Currently active batch/reply jobs; this is the container that owns the Batch instances + std::unordered_set batches; + /// Individual batch jobs waiting to run; .second is the 0-n batch number or -1 for the + /// completion job + using batch_job = std::pair; + std::queue batch_jobs, reply_jobs; + int batch_jobs_active = 0; + int reply_jobs_active = 0; + int batch_jobs_reserved = -1; + int reply_jobs_reserved = -1; + /// Runs any queued batch jobs + void proxy_run_batch_jobs(std::queue& jobs, int reserved, int& active, bool reply); + + /// BATCH command. Called with a Batch (see oxenmq/batch.h) object pointer for the proxy to + /// take over and queue batch jobs. + void proxy_batch(detail::Batch* batch); + + /// TIMER command. Called with a serialized list containing: function pointer to assume + /// ownership of, an interval count (in ms), and whether or not jobs should be squelched (see + /// `add_timer()`). + void proxy_timer(bt_list_consumer timer_data); + + /// Same, but deserialized + void proxy_timer(std::function job, std::chrono::milliseconds interval, bool squelch, int thread); + + /// ZAP (https://rfc.zeromq.org/spec:27/ZAP/) authentication handler; this does non-blocking + /// processing of any waiting authentication requests for new incoming connections. + void process_zap_requests(); + + /// Handles a control message from some outer thread to the proxy + void proxy_control_message(std::vector& parts); + + /// Closing any idle connections that have outlived their idle time. Note that this only + /// affects outgoing connections; incomings connections are the responsibility of the other end. + void proxy_expire_idle_peers(); + + /// Helper method to actually close a remote connection and update the stuff that needs updating. + void proxy_close_connection(size_t removed, std::chrono::milliseconds linger); + + /// Closes an outgoing connection immediately, updates internal variables appropriately. + /// Returns the next iterator (the original may or may not be removed from peers, depending on + /// whether or not it also has an active incoming connection). + decltype(peers)::iterator proxy_close_outgoing(decltype(peers)::iterator it); + + struct category { + Access access; + std::unordered_map> commands; + unsigned int reserved_threads = 0; + unsigned int active_threads = 0; + int max_queue = 200; + int queued = 0; + + category(Access access, unsigned int reserved_threads, int max_queue) + : access{access}, reserved_threads{reserved_threads}, max_queue{max_queue} {} + }; + + /// Categories, mapped by category name. + std::unordered_map categories; + + /// For enabling backwards compatibility with command renaming: this allows mapping one command + /// to another in a different category (which happens before the category and command lookup is + /// done). + std::unordered_map command_aliases; + + using cat_call_t = std::pair*>; + /// Retrieve category and callback from a command name, including alias mapping. Warns on + /// invalid commands and returns nullptrs. The command name will be updated in place if it is + /// aliased to another command. + cat_call_t get_command(std::string& command); + + /// Checks a peer's authentication level. Returns true if allowed, warns and returns false if + /// not. + bool proxy_check_auth(size_t conn_index, bool outgoing, const peer_info& peer, + zmq::message_t& command, const cat_call_t& cat_call, std::vector& data); + + struct injected_task { + category& cat; + std::string command; + std::string remote; + std::function callback; + }; + + /// Injects a external callback to be handled by a worker; this is the proxy side of + /// inject_task(). + void proxy_inject_task(injected_task task); + + + /// Set of active service nodes. + pubkey_set active_service_nodes; + + /// Resets or updates the stored set of active SN pubkeys + void proxy_set_active_sns(std::string_view data); + void proxy_set_active_sns(pubkey_set pubkeys); + void proxy_update_active_sns(bt_list_consumer data); + void proxy_update_active_sns(pubkey_set added, pubkey_set removed); + void proxy_update_active_sns_clean(pubkey_set added, pubkey_set removed); + + /// Details for a pending command; such a command already has authenticated access and is just + /// waiting for a thread to become available to handle it. This also gets used (via the + /// `callback` variant) for injected external jobs to be able to integrate some external + /// interface with the oxenmq job queue. + struct pending_command { + category& cat; + std::string command; + std::vector data_parts; + std::variant< + const std::pair*, // Normal command callback + std::function // Injected external callback + > callback; + ConnectionID conn; + Access access; + std::string remote; + + // Normal ctor for an actual lmq command being processed + pending_command(category& cat, std::string command, std::vector data_parts, + const std::pair* callback, ConnectionID conn, Access access, std::string remote) + : cat{cat}, command{std::move(command)}, data_parts{std::move(data_parts)}, + callback{callback}, conn{std::move(conn)}, access{std::move(access)}, remote{std::move(remote)} {} + + // Ctor for an injected external command. + pending_command(category& cat, std::string command, std::function callback, std::string remote) + : cat{cat}, command{std::move(command)}, callback{std::move(callback)}, remote{std::move(remote)} {} + }; + std::list pending_commands; + + + /// End of proxy-specific members + /////////////////////////////////////////////////////////////////////////////////// + + + /// Structure that contains the data for a worker thread - both the thread itself, plus any + /// transient data we are passing into the thread. + struct run_info { + bool is_batch_job = false; + bool is_reply_job = false; + bool is_tagged_thread_job = false; + bool is_injected = false; + + // resets the job type bools, above. + void reset() { is_batch_job = is_reply_job = is_tagged_thread_job = is_injected = false; } + + // If is_batch_job is false then these will be set appropriate (if is_batch_job is true then + // these shouldn't be accessed and likely contain stale data). Note that if the command is + // an external, injected command then conn, access, conn_route, and data_parts will be + // empty/default constructed. + category *cat; + std::string command; + ConnectionID conn; // The connection (or SN pubkey) to reply on/to. + Access access; // The access level of the invoker (actual level, can be higher than the command's requirement) + std::string remote; // The remote address from which we received the request. + std::string conn_route; // if non-empty this is the reply routing prefix (for incoming connections) + std::vector data_parts; + + // If is_batch_job true then these are set (if is_batch_job false then don't access these!): + int batch_jobno; // >= 0 for a job, -1 for the completion job + + // The callback or batch job to run. The first of these is for regular tasks, the second + // for batch jobs, the third for injected external tasks. + std::variant< + const std::pair*, + detail::Batch*, + std::function + > to_run; + + // These belong to the proxy thread and must not be accessed by a worker: + std::thread worker_thread; + size_t worker_id; // The index in `workers` (0-n) or index+1 in `tagged_workers` (1-n) + std::string worker_routing_id; // "w123" where 123 == worker_id; "n123" for tagged threads. + + /// Loads the run info with an incoming command + run_info& load(category* cat, std::string command, ConnectionID conn, Access access, std::string remote, + std::vector data_parts, const std::pair* callback); + + /// Loads the run info with an injected external command + run_info& load(category* cat, std::string command, std::string remote, std::function callback); + + /// Loads the run info with a stored pending command + run_info& load(pending_command&& pending); + + /// Loads the run info with a batch job + run_info& load(batch_job&& bj, bool reply_job = false, int tagged_thread = 0); + }; + /// Data passed to workers for the RUN command. The proxy thread sets elements in this before + /// sending RUN to a worker then the worker uses it to get call info, and only allocates it + /// once, before starting any workers. Workers may only access their own index and may not + /// change it. + std::vector workers; + + /// Workers that are reserved for tagged thread tasks (as created with add_tagged_thread). The + /// queue here is similar to worker_jobs, but contains only the tagged thread's jobs. The bool + /// is whether the worker is currently busy (true) or available (false). + std::vector>> tagged_workers; + +public: + /** + * OxenMQ constructor. This constructs the object but does not start it; you will typically + * want to first add categories and commands, then finish startup by invoking `start()`. + * (Categories and commands cannot be added after startup). + * + * @param pubkey the public key (32-byte binary string). For a service node this is the service + * node x25519 keypair. For non-service nodes this (and privkey) can be empty strings to + * automatically generate an ephemeral keypair. + * + * @param privkey the service node's private key (32-byte binary string), or empty to generate + * one. + * + * @param service_node - true if this instance should be considered a service node for the + * purpose of allowing "Access::local_sn" remote calls. (This should be true if we are + * *capable* of being a service node, whether or not we are currently actively). If specified + * as true then the pubkey and privkey values must not be empty. + * + * @param sn_lookup function that takes a pubkey key (32-byte binary string) and returns a + * connection string such as "tcp://1.2.3.4:23456" to which a connection should be established + * to reach that service node. Note that this function is only called if there is no existing + * connection to that service node, and that the function is never called for a connection to + * self (that uses an internal connection instead). Also note that the service node must be + * listening in curve25519 mode (otherwise we couldn't verify its authenticity). Should return + * empty for not found or if SN lookups are not supported. + * + * @param allow_incoming is a callback that OxenMQ can use to determine whether an incoming + * connection should be allowed at all and, if so, whether the connection is from a known + * service node. Called with the connecting IP, the remote's verified x25519 pubkey, and the + * called on incoming connections with the (verified) incoming connection + * pubkey (32-byte binary string) to determine whether the given SN should be allowed to + * connect. + * + * @param log a function or callable object that writes a log message. If omitted then all log + * messages are suppressed. + * + * @param level the initial log level; defaults to warn. The log level can be changed later by + * calling log_level(...). + */ + OxenMQ( std::string pubkey, + std::string privkey, + bool service_node, + SNRemoteAddress sn_lookup, + Logger logger = [](LogLevel, const char*, int, std::string) { }, + LogLevel level = LogLevel::warn); + + /** + * Simplified OxenMQ constructor for a non-listening client or simple listener without any + * outgoing SN connection lookup capabilities. The OxenMQ object will not be able to establish + * new connections (including reconnections) to service nodes by pubkey. + */ + explicit OxenMQ( + Logger logger = [](LogLevel, const char*, int, std::string) { }, + LogLevel level = LogLevel::warn) + : OxenMQ("", "", false, [](auto) { return ""s; /*no peer lookups*/ }, std::move(logger), level) {} + + /** + * Destructor; instructs the proxy to quit. The proxy tells all workers to quit, waits for them + * to quit and rejoins the threads then quits itself. The outer thread (where the destructor is + * running) rejoins the proxy thread. + */ + ~OxenMQ(); + + /// Sets the log level of the OxenMQ object. + void log_level(LogLevel level); + + /// Gets the log level of the OxenMQ object. + LogLevel log_level() const; + + /** + * Add a new command category. This method may not be invoked after `start()` has been called. + * This method is also not thread safe, and is generally intended to be called (along with + * add_command) immediately after construction and immediately before calling start(). + * + * @param name - the category name which must consist of one or more characters and may not + * contain a ".". + * + * @param access_level the access requirements for remote invocation of the commands inside this + * category. + * + * @param reserved_threads if non-zero then the worker thread pool will ensure there are at at + * least this many threads either current processing or available to process commands in this + * category. This is used to ensure that a category's commands can be invoked even if + * long-running commands in some other category are currently using all worker threads. This + * can increase the number of worker threads above the `general_workers` parameter given in the + * constructor, but will only do so if the need arised: that is, if a command request arrives + * for a category when all workers are busy and no worker is currently processing any command in + * that category. + * + * @param max_queue is the maximum number of incoming messages in this category that we will + * queue up when waiting for a worker to become available for this category. Once the queue for + * a category exceeds this many incoming messages then new messages will be dropped until some + * messages are processed off the queue. -1 means unlimited, 0 means we will never queue (which + * means just dropping messages for this category if no workers are available to instantly + * handle the request). + * + * @returns a CatHelper object that makes adding commands slightly less verbose (see the + * CatHelper describe, above). + */ + CatHelper add_category(std::string name, Access access_level, unsigned int reserved_threads = 0, int max_queue = 200); + + /** + * Adds a new command to an existing category. This method may not be invoked after `start()` + * has been called. + * + * @param category - the category name (must already be created by a call to `add_category`) + * + * @param name - the command name, without the `category.` prefix. + * + * @param callback - a callable object which is callable as `callback(zeromq::Message &)` + */ + void add_command(const std::string& category, std::string name, CommandCallback callback); + + /** + * Adds a new "request" command to an existing category. These commands are just like normal + * commands, but are expected to call `msg.send_reply()` with any data parts on every request, + * while normal commands are more general. + * + * Parameters given here are identical to `add_command()`. + */ + void add_request_command(const std::string& category, std::string name, CommandCallback callback); + + /** + * Adds a command alias; this is intended for temporary backwards compatibility: if any aliases + * are defined then every command (not just aliased ones) has to be checked on invocation to see + * if it is defined in the alias list. May not be invoked after `start()`. + * + * Aliases should follow the `category.command` format for both the from and to names, and + * should only be called for `to` categories that are already defined. The category name is not + * currently enforced on the `from` name (for backwards compatility with Oxen's quorumnet code) + * but will be at some point. + * + * Access permissions for an aliased command depend only on the mapped-to value; for example, if + * `cat.meow` is aliased to `dog.bark` then it is the access permissions on `dog` that apply, + * not those of `cat`, even if `cat` is more restrictive than `dog`. + */ + void add_command_alias(std::string from, std::string to); + + /** Creates a "tagged thread" and starts it immediately. A tagged thread is one that batches, + * jobs, and timer jobs can be sent to specifically, typically to perform coordination of some + * thread-unsafe work. + * + * Tagged threads will *only* process jobs sent specifically to them; they do not participate in + * the thread pool used for regular jobs. Each tagged thread also has its own job queue + * completely separate from any other jobs. + * + * Tagged threads must be created *before* `start()` is called. The name will be used to set the + * thread name in the process table (if supported on the OS). + * + * \param name - the name of the thread; will be used in log messages and (if supported by the + * OS) as the system thread name. + * + * \param start - an optional callback to invoke from the thread as soon as OxenMQ itself starts + * up (i.e. after a call to `start()`). + * + * \returns a TaggedThreadID object that can be passed to job(), batch(), or add_timer() to + * direct the task to the tagged thread. + */ + TaggedThreadID add_tagged_thread(std::string name, std::function start = nullptr); + + /** + * Sets the number of worker threads reserved for batch jobs. If not explicitly called then + * this defaults to half the general worker threads configured (rounded up). This works exactly + * like reserved_threads for a category, but allows to batch jobs. See category for details. + * + * Note that some internal jobs are counted as batch jobs: in particular timers added via + * add_timer() are scheduled as batch jobs. + * + * Cannot be called after start()ing the OxenMQ instance. + */ + void set_batch_threads(int threads); + + /** + * Sets the number of worker threads reserved for handling replies from servers; this is + * mostly for responses to `request()` calls, but also gets used for other network-related + * events such as the ConnectSuccess/ConnectFailure callbacks for establishing remote non-SN + * connections. + * + * Defaults to one-eighth of the number of configured general threads, rounded up. + * + * Cannot be changed after start()ing the OxenMQ instance. + */ + void set_reply_threads(int threads); + + /** + * Sets the number of general worker threads. This is the target number of threads to run that + * we generally try not to exceed. These threads can be used for any command, and will be + * created (up to the limit) on demand. Note that individual categories (or batch jobs) with + * reserved threads can create threads in addition to the amount specified here if necessary to + * fulfill the reserved threads count for the category. + * + * Adjusting this also adjusts the default values of batch and reply threads, above. + * + * Defaults to `std::thread::hardware_concurrency()`. + * + * Cannot be called after start()ing the OxenMQ instance. + */ + void set_general_threads(int threads); + + /** + * Finish starting up: binds to the bind locations given in the constructor and launches the + * proxy thread to handle message dispatching between remote nodes and worker threads. + * + * Things you want to do before calling this: + * - Use `add_category`/`add_command` to set up any commands remote connections can invoke. + * - If any commands require SN authentication, specify a list of currently active service node + * pubkeys via `set_active_sns()` (and make sure this gets updated when things change by + * another `set_active_sns()` or a `update_active_sns()` call). It *is* possible to make the + * initial call after calling `start()`, but that creates a window during which incoming + * remote SN connections will be erroneously treated as non-SN connections. + * - If this LMQ instance should accept incoming connections, set up any listening ports via + * `listen_curve()` and/or `listen_plain()`. + */ + void start(); + + /** Start listening on the given bind address using curve authentication/encryption. Incoming + * connections will only be allowed from clients that already have the server's pubkey, and + * will be encrypted. `allow_connection` is invoked for any incoming connections on this + * address to determine the incoming remote's access and authentication level. + * + * @param bind address - can be any string zmq supports; typically a tcp IP/port combination + * such as: "tcp://\*:4567" or "tcp://1.2.3.4:5678". + * + * @param allow_connection function to call to determine whether to allow the connection and, if + * so, the authentication level it receives. If omitted the default returns AuthLevel::none + * access. + */ + void listen_curve(std::string bind, AllowFunc allow_connection = [](auto, auto, auto) { return AuthLevel::none; }); + + /** Start listening on the given bind address in unauthenticated plain text mode. Incoming + * connections can come from anywhere. `allow_connection` is invoked for any incoming + * connections on this address to determine the incoming remote's access and authentication + * level. Note that `allow_connection` here will be called with an empty pubkey. + * + * @param bind address - can be any string zmq supports; typically a tcp IP/port combination + * such as: "tcp://\*:4567" or "tcp://1.2.3.4:5678". + * + * @param allow_connection function to call to determine whether to allow the connection and, if + * so, the authentication level it receives. If omitted the default returns AuthLevel::none + * access. + */ + void listen_plain(std::string bind, AllowFunc allow_connection = [](auto, auto, auto) { return AuthLevel::none; }); + + /** + * Try to initiate a connection to the given SN in anticipation of needing a connection in the + * future. If a connection is already established, the connection's idle timer will be reset + * (so that the connection will not be closed too soon). If the given idle timeout is greater + * than the current idle timeout then the timeout increases to the new value; if less than the + * current timeout it is ignored. (Note that idle timeouts only apply if the existing + * connection is an outgoing connection). + * + * Note that this method (along with send) doesn't block waiting for a connection; it merely + * instructs the proxy thread that it should establish a connection. + * + * @param pubkey - the public key (32-byte binary string) of the service node to connect to + * @param keep_alive - the connection will be kept alive if there was valid activity within + * the past `keep_alive` milliseconds. If an outgoing connection already + * exists, the longer of the existing and the given keep alive is used. + * (Note that the default applied here is much longer than the default for an + * implicit connect() by calling send() directly.) + * @param hint - if non-empty and a new outgoing connection needs to be made this hint value + * may be used instead of calling the lookup function. (Note that there is no + * guarantee that the hint will be used; it is only usefully specified if the + * connection address has already been incidentally determined). + * + * @returns a ConnectionID that identifies an connection with the given SN. Typically you + * *don't* need to worry about this (and can just discard it): you can always simply pass the + * pubkey as a string wherever a ConnectionID is called. + */ + ConnectionID connect_sn(std::string_view pubkey, std::chrono::milliseconds keep_alive = 5min, std::string_view hint = {}); + + /** + * Establish a connection to the given remote with callbacks invoked on a successful or failed + * connection. Returns a ConnectionID associated with the connection being attempted. It is + * possible to send to the remote before the successful callback is invoked, but there is no + * guarantee that the messages will be delivered (e.g. if the connection ultimately fails). + * + * For connections to a service node you generally want connect_sn() instead (which verifies + * that it is talking to the SN and encrypts the connection). + * + * Unlike `connect_sn`, the connection established here will be kept open indefinitely (until + * you call disconnect). + * + * The `on_connect` and `on_failure` callbacks are invoked when a connection has been + * established or failed to establish. + * + * @param remote the remote connection address either as implicitly from a string or as a full + * oxenmq::address object; see address.h for details. This specifies both the connection + * address and whether curve encryption should be used. + * @param on_connect called with the identifier after the connection has been established. + * @param on_failure called with the identifier and failure message if we fail to connect. + * @param auth_level determines the authentication level of the remote for issuing commands to + * us. The default is `AuthLevel::none`. + * @param timeout how long to try before aborting the connection attempt and calling the + * on_failure callback. Note that the connection can fail for various reasons before the + * timeout. + * + * @param returns ConnectionID that uniquely identifies the connection to this remote node. In + * order to talk to it you will need the returned value (or a copy of it). + */ + ConnectionID connect_remote(const address& remote, ConnectSuccess on_connect, ConnectFailure on_failure, + AuthLevel auth_level = AuthLevel::none, std::chrono::milliseconds timeout = REMOTE_CONNECT_TIMEOUT); + + /// Same as the above, but takes the address as a string_view and constructs an `address` from + /// it. + ConnectionID connect_remote(std::string_view remote, ConnectSuccess on_connect, ConnectFailure on_failure, + AuthLevel auth_level = AuthLevel::none, std::chrono::milliseconds timeout = REMOTE_CONNECT_TIMEOUT) { + return connect_remote(address{remote}, std::move(on_connect), std::move(on_failure), auth_level, timeout); + } + + /// Deprecated version of the above that takes the remote address and remote pubkey for curve + /// encryption as separate arguments. New code should either use a pubkey-embedded address + /// string, or specify remote address and pubkey with an `address` object such as: + /// connect_remote(address{remote, pubkey}, ...) + [[deprecated("use connect_remote() with a oxenmq::address instead")]] + ConnectionID connect_remote(std::string_view remote, ConnectSuccess on_connect, ConnectFailure on_failure, + std::string_view pubkey, + AuthLevel auth_level = AuthLevel::none, + std::chrono::milliseconds timeout = REMOTE_CONNECT_TIMEOUT); + + /** + * Disconnects an established outgoing connection established with `connect_remote()` (or, less + * commonly, `connect_sn()`). + * + * @param id the connection id, as returned by `connect_remote()` or the SN pubkey. + * + * @param linger how long to allow the connection to linger while there are still pending + * outbound messages to it before disconnecting and dropping any pending messages. (Note that + * this lingering is internal; the disconnect_remote() call does not block). The default is 1 + * second. + * + * If given a pubkey, we try to close an outgoing connection to the given SN if one exists; note + * however that this is often not particularly useful as messages to that SN can immediately + * reopen the connection. + */ + void disconnect(ConnectionID id, std::chrono::milliseconds linger = 1s); + + /** + * Queue a message to be relayed to the given service node or remote without requiring a reply. + * OxenMQ will attempt to relay the message (first connecting and handshaking to the remote SN + * if not already connected). + * + * If a new connection is established it will have a relatively short (30s) idle timeout. If + * the connection should stay open longer you should either call `connect(pubkey, IDLETIME)` or + * pass a a `send_option::keep_alive{IDLETIME}` in `opts`. + * + * Note that this method (along with connect) doesn't block waiting for a connection or for the + * message to send; it merely instructs the proxy thread that it should send. ZMQ will + * generally try hard to deliver it (reconnecting if the connection fails), but if the + * connection fails persistently the message will eventually be dropped. + * + * @param remote - either a ConnectionID value returned by connect_remote, or a service node + * pubkey string. In the latter case, sending the message may trigger a new + * connection being established to the service node (i.e. you do not have to + * call connect() first). + * @param cmd - the first data frame value which is almost always the remote "category.command" name + * @param opts - any number of std::string (or string_views) and send options. Each send option + * affects how the send works; each string becomes a message part. + * + * Example: + * + * // Send to a SN, connecting to it if we aren't already connected: + * lmq.send(pubkey, "hello.world", "abc", send_option::hint("tcp://localhost:1234"), "def"); + * + * // Start connecting to a remote and immediately queue a message for it + * auto conn = lmq.connect_remote("tcp://127.0.0.1:1234", + * [](ConnectionID) { std::cout << "connected\n"; }, + * [](ConnectionID, string_view why) { std::cout << "connection failed: " << why << \n"; }); + * lmq.send(conn, "hello.world", "abc", "def"); + * + * Both of these send the command `hello.world` to the given pubkey, containing additional + * message parts "abc" and "def". In the first case, if not currently connected, the given + * connection hint may be used rather than performing a connection address lookup on the pubkey. + */ + template + void send(ConnectionID to, std::string_view cmd, const T&... opts); + + /** Send a command configured as a "REQUEST" command to a service node: the data parts will be + * prefixed with a random identifier. The remote is expected to reply with a ["REPLY", + * , ...] message, at which point we invoke the given callback with any [...] parts + * of the reply. + * + * Like `send()`, a new connection to the service node will be established if not already + * connected. + * + * @param to - the pubkey string or ConnectionID to send this request to + * @param cmd - the command name + * @param callback - the callback to invoke when we get a reply. Called with a true value and + * the data strings when a reply is received, or false with error string(s) indicating the + * failure reason upon failure or timeout. + * @param opts - anything else (i.e. strings, send_options) is forwarded to send(). + * + * Possible error data values: + * - ["TIMEOUT"] - we got no reply within the timeout window + * - ["UNKNOWNCOMMAND"] - the remote did not recognize the given request command + * - ["NO_REPLY_TAG"] - the invoked command is a request command but no reply tag was included + * - ["FORBIDDEN"] - the command requires an authorization level (e.g. Basic or Admin) that we + * do not have. + * - ["FORBIDDEN_SN"] - the command requires service node authentication, but the remote did not + * recognize us as a service node. You *may* want to retry the request a limited number of + * times (but do not retry indefinitely as that can be an infinite loop!) because this is + * typically also followed by a disconnection; a retried message would reconnect and + * reauthenticate which *may* result in picking up the SN authentication. + * - ["NOT_A_SERVICE_NODE"] - this command is only invokable on service nodes, and the remote is + * not running as a service node. + */ + template + void request(ConnectionID to, std::string_view cmd, ReplyCallback callback, const T&... opts); + + /** Injects an external task into the oxenmq command queue. This is used to allow connecting + * non-OxenMQ requests into the OxenMQ thread pool as if they were ordinary requests, to be + * scheduled as commands of an individual category. For example, you might support rpc requests + * via OxenMQ as `rpc.some_command` and *also* accept them over HTTP. Using `inject_task()` + * allows you to handle processing the request in the same thread pool with the same priority as + * `rpc.*` commands. + * + * @param category - the category name that should handle the request for the purposes of + * scheduling the job. The category must have been added using add_category(). The category + * can be an actual category with added commands, in which case the injected tasks are queued + * along with LMQ requests for that category, or can have no commands to set up a distinct + * category for the injected jobs. + * + * @param command - a command name; this is mainly used for debugging and does not need to + * actually exist (and, in fact, is often less confusing if it does not). It is recommended for + * clarity purposes to use something that doesn't look like a typical command, for example + * "(http)". + * + * @param remote - some free-form identifier of the remote connection. For example, this could + * be a remote IP address. Can be blank if there is nothing suitable. + * + * @param callback - the function to call from a worker thread when the injected task is + * processed. Takes no arguments. + */ + void inject_task(const std::string& category, std::string command, std::string remote, std::function callback); + + /// The key pair this OxenMQ was created with; if empty keys were given during construction then + /// this returns the generated keys. + const std::string& get_pubkey() const { return pubkey; } + const std::string& get_privkey() const { return privkey; } + + /** Updates (or initially sets) OxenMQ's list of service node pubkeys with the given list. + * + * This has two main effects: + * + * - All commands processed after the update will have SN status determined by the new list. + * - All outgoing connections to service nodes that are no longer on the list will be closed. + * This includes both explicit connections (established by `connect_sn()`) and implicit ones + * (established by sending to a SN that wasn't connected). + * + * As this update is potentially quite heavy it is recommended that this be called only when + * necessary--i.e. when the list has changed (or potentially changed), but *not* on a short + * periodic timer. + * + * This method may (and should!) be called before start() to load an initial set of SNs. + * + * Once a full list has been set, updates on changes can either call this again with the new + * list, or use the more efficient update_active_sns() call if incremental results are + * available. + */ + void set_active_sns(pubkey_set pubkeys); + + /** Updates the list of active pubkeys by adding or removing the given pubkeys from the existing + * list. This is more efficient when the incremental information is already available; if it + * isn't, simply call set_active_sns with a new list to have OxenMQ figure out what was added or + * removed. + * + * \param added new pubkeys that were added since the last set_active_sns or update_active_sns + * call. + * + * \param removed pubkeys that were removed from active SN status since the last call. If a + * pubkey is in both `added` and `removed` for some reason then its presence in `removed` will + * be ignored. + */ + void update_active_sns(pubkey_set added, pubkey_set removed); + + /** + * Batches a set of jobs to be executed by workers, optionally followed by a completion function. + * + * Must include oxenmq/batch.h to use. + */ + template + void batch(Batch&& batch); + + /** + * Queues a single job to be executed with no return value. This is a shortcut for creating and + * submitting a single-job, no-completion-function batch job. + * + * \param f the callback to invoke + * \param thread an optional tagged thread in which this job should run. You may *not* pass the + * proxy thread here. + */ + void job(std::function f, std::optional = std::nullopt); + + /** + * Adds a timer that gets scheduled periodically in the job queue. Normally jobs are not + * double-booked: that is, a new timed job will not be scheduled if the timer fires before a + * previously scheduled callback of the job has not yet completed. If you want to override this + * (so that, under heavy load or long jobs, there can be more than one of the same job scheduled + * or running at a time) then specify `squelch` as `false`. + * + * \param thread specifies a thread (added with add_tagged_thread()) on which this timer must run. + */ + void add_timer(std::function job, std::chrono::milliseconds interval, bool squelch = true, std::optional = std::nullopt); +}; + +/// Helper class that slightly simplifies adding commands to a category. +/// +/// This allows simplifying: +/// +/// lmq.add_category("foo", ...); +/// lmq.add_command("foo", "a", ...); +/// lmq.add_command("foo", "b", ...); +/// lmq.add_request_command("foo", "c", ...); +/// +/// to: +/// +/// lmq.add_category("foo", ...) +/// .add_command("a", ...) +/// .add_command("b", ...) +/// .add_request_command("b", ...) +/// ; +class CatHelper { + OxenMQ& lmq; + std::string cat; + +public: + CatHelper(OxenMQ& lmq, std::string cat) : lmq{lmq}, cat{std::move(cat)} {} + + CatHelper& add_command(std::string name, OxenMQ::CommandCallback callback) { + lmq.add_command(cat, std::move(name), std::move(callback)); + return *this; + } + + CatHelper& add_request_command(std::string name, OxenMQ::CommandCallback callback) { + lmq.add_request_command(cat, std::move(name), std::move(callback)); + return *this; + } +}; + + +/// Namespace for options to the send() method +namespace send_option { + +template +struct data_parts_impl { + InputIt begin, end; + data_parts_impl(InputIt begin, InputIt end) : begin{std::move(begin)}, end{std::move(end)} {} +}; + +/// Specifies an iterator pair of data parts to send, for when the number of arguments to send() +/// cannot be determined at compile time. The iterator pair must be over strings or string_view (or +/// something convertible to a string_view). +template ()), std::string_view>>> +data_parts_impl data_parts(InputIt begin, InputIt end) { return {std::move(begin), std::move(end)}; } + +/// Specifies a connection hint when passed in to send(). If there is no current connection to the +/// peer then the hint is used to save a call to the SNRemoteAddress to get the connection location. +/// (Note that there is no guarantee that the given hint will be used or that a SNRemoteAddress call +/// will not also be done.) +struct hint { + std::string connect_hint; + // Constructor taking a hint. If the hint is an empty string then no hint will be used. + explicit hint(std::string connect_hint) : connect_hint{std::move(connect_hint)} {} +}; + +/// Does a send() if we already have a connection (incoming or outgoing) with the given peer, +/// otherwise drops the message. +struct optional { + bool is_optional = true; + // Constructor; default construction gives you an optional, but the bool parameter can be + // specified as false to explicitly make a connection non-optional instead. + explicit optional(bool opt = true) : is_optional{opt} {} +}; + +/// Specifies that the message should be sent only if it can be sent on an existing incoming socket, +/// and dropped otherwise. +struct incoming { + bool is_incoming = true; + // Constructor; default construction gives you an incoming-only, but the bool parameter can be + // specified as false to explicitly disable incoming-only behaviour. + explicit incoming(bool inc = true) : is_incoming{inc} {} +}; + +/// Specifies that the message must use an outgoing connection; for messages to a service node the +/// message will be delivered over an existing outgoing connection, if one is established, and a new +/// outgoing connection opened to deliver the message if none is currently established. For non-SN +/// messages, the message will simply be dropped if it is attempting to be sent on an incoming +/// socket, and send otherwise on an outgoing socket (this option is primarily aimed at SN +/// messages). +struct outgoing { + bool is_outgoing = true; + // Constructor; default construction gives you an outgoing-only, but the bool parameter can be + // specified as false to explicitly disable the outgoing-only flag. + explicit outgoing(bool out = true) : is_outgoing{out} {} +}; + +/// Specifies the idle timeout for the connection - if a new or existing outgoing connection is used +/// for the send and its current idle timeout setting is less than this value then it is updated. +struct keep_alive { + std::chrono::milliseconds time; + explicit keep_alive(std::chrono::milliseconds time) : time{std::move(time)} {} +}; + +/// Specifies the amount of time to wait before triggering a failure callback for a request. If a +/// request reply arrives *after* the failure timeout has been triggered then it will be dropped. +/// (This has no effect if specified on a non-request() call). Note that requests failures are only +/// processed in the CONN_CHECK_INTERVAL timer, so it can be up to that much longer than the time +/// specified here before a failure callback is invoked. +struct request_timeout { + std::chrono::milliseconds time; + explicit request_timeout(std::chrono::milliseconds time) : time{std::move(time)} {} +}; + +/// Specifies a callback to invoke if the message couldn't be queued for delivery. There are +/// generally two failure modes here: a full queue, and a send exception. This callback is invoked +/// for both; to only catch full queues see `queue_full` instead. +/// +/// A full queue means there are too many messages queued for delivery already that haven't been +/// delivered yet (i.e. because the remote is slow); this error is potentially recoverable if the +/// remote end wakes up and receives/acknoledges its messages. +/// +/// A send exception is not recoverable: it indicates some failure such as the remote having +/// disconnected or an internal send error. +/// +/// This callback can be used by a caller to log, attempt to resend, or take other appropriate +/// action. +/// +/// Note that this callback is *not* exhaustive for all possible send failures: there are failure +/// cases (such as when a message is queued but the connection fails before delivery) that do not +/// trigger this failure at all; rather this callback only signals an immediate queuing failure. +struct queue_failure { + using callback_t = std::function; + /// Callback; invoked with nullptr for a queue full failure, otherwise will be set to a copy of + /// the raised exception. + callback_t callback; +}; + +/// This is similar to queue_failure_callback, but is only invoked on a (potentially recoverable) +/// full queue failure. Send failures are simply dropped. +struct queue_full { + using callback_t = std::function; + callback_t callback; +}; + +} + +namespace detail { + +/// Takes an rvalue reference, moves it into a new instance then returns a uintptr_t value +/// containing the pointer to be serialized to pass (via oxenmq queues) from one thread to another. +/// Must be matched with a deserializer_pointer on the other side to reconstitute the object and +/// destroy the intermediate pointer. +template +uintptr_t serialize_object(T&& obj) { + static_assert(std::is_rvalue_reference::value, "serialize_object must be given an rvalue reference"); + auto* ptr = new T{std::forward(obj)}; + return reinterpret_cast(ptr); +} + +/// Takes a uintptr_t as produced by serialize_pointer and the type, converts the serialized value +/// back into a pointer, moves it into a new instance (to be returned) and destroys the +/// intermediate. +template T deserialize_object(uintptr_t ptrval) { + auto* ptr = reinterpret_cast(ptrval); + T ret{std::move(*ptr)}; + delete ptr; + return ret; +} + +// Sends a control message to the given socket consisting of the command plus optional dict +// data (only sent if the data is non-empty). +void send_control(zmq::socket_t& sock, std::string_view cmd, std::string data = {}); + +/// Base case: takes a string-like value and appends it to the message parts +inline void apply_send_option(bt_list& parts, bt_dict&, std::string_view arg) { + parts.emplace_back(arg); +} + +/// `data_parts` specialization: appends a range of serialized data parts to the parts to send +template +void apply_send_option(bt_list& parts, bt_dict&, const send_option::data_parts_impl data) { + for (auto it = data.begin; it != data.end; ++it) + parts.emplace_back(*it); +} + +/// `hint` specialization: sets the hint in the control data +inline void apply_send_option(bt_list&, bt_dict& control_data, const send_option::hint& hint) { + control_data["hint"] = hint.connect_hint; +} + +/// `optional` specialization: sets the optional flag in the control data +inline void apply_send_option(bt_list&, bt_dict& control_data, const send_option::optional& o) { + control_data["optional"] = o.is_optional; +} + +/// `incoming` specialization: sets the incoming-only flag in the control data +inline void apply_send_option(bt_list&, bt_dict& control_data, const send_option::incoming& i) { + control_data["incoming"] = i.is_incoming; +} + +/// `outgoing` specialization: sets the outgoing-only flag in the control data +inline void apply_send_option(bt_list&, bt_dict& control_data, const send_option::outgoing& o) { + control_data["outgoing"] = o.is_outgoing; +} + +/// `keep_alive` specialization: increases the outgoing socket idle timeout (if shorter) +inline void apply_send_option(bt_list&, bt_dict& control_data, const send_option::keep_alive& timeout) { + control_data["keep_alive"] = timeout.time.count(); +} + +/// `request_timeout` specialization: set the timeout time for a request +inline void apply_send_option(bt_list&, bt_dict& control_data, const send_option::request_timeout& timeout) { + control_data["request_timeout"] = timeout.time.count(); +} + +/// `queue_failure` specialization +inline void apply_send_option(bt_list&, bt_dict& control_data, send_option::queue_failure f) { + control_data["send_fail"] = serialize_object(std::move(f.callback)); +} + +/// `queue_full` specialization +inline void apply_send_option(bt_list&, bt_dict& control_data, send_option::queue_full f) { + control_data["send_full_q"] = serialize_object(std::move(f.callback)); +} + +/// Extracts a pubkey and auth level from a zmq message received on a *listening* socket. +std::pair extract_metadata(zmq::message_t& msg); + +template +bt_dict build_send(ConnectionID to, std::string_view cmd, T&&... opts) { + bt_dict control_data; + bt_list parts{{cmd}}; + (detail::apply_send_option(parts, control_data, std::forward(opts)),...); + + if (to.sn()) + control_data["conn_pubkey"] = std::move(to.pk); + else { + control_data["conn_id"] = to.id; + control_data["conn_route"] = std::move(to.route); + } + control_data["send"] = std::move(parts); + return control_data; + +} + +} // namespace detail + + +template +void OxenMQ::send(ConnectionID to, std::string_view cmd, const T&... opts) { + detail::send_control(get_control_socket(), "SEND", + bt_serialize(detail::build_send(std::move(to), cmd, opts...))); +} + +std::string make_random_string(size_t size); + +template +void OxenMQ::request(ConnectionID to, std::string_view cmd, ReplyCallback callback, const T &...opts) { + const auto reply_tag = make_random_string(15); // 15 random bytes is lots and should keep us in most stl implementations' small string optimization + bt_dict control_data = detail::build_send(std::move(to), cmd, reply_tag, opts...); + control_data["request"] = true; + control_data["request_callback"] = detail::serialize_object(std::move(callback)); + control_data["request_tag"] = std::string_view{reply_tag}; + detail::send_control(get_control_socket(), "SEND", bt_serialize(std::move(control_data))); +} + +template +void Message::send_back(std::string_view command, Args&&... args) { + oxenmq.send(conn, command, send_option::optional{!conn.sn()}, std::forward(args)...); +} + +template +void Message::send_reply(Args&&... args) { + assert(!reply_tag.empty()); + oxenmq.send(conn, "REPLY", reply_tag, send_option::optional{!conn.sn()}, std::forward(args)...); +} + +template +void Message::send_request(std::string_view cmd, Callback&& callback, Args&&... args) { + oxenmq.request(conn, cmd, std::forward(callback), + send_option::optional{!conn.sn()}, std::forward(args)...); +} + +// When log messages are invoked we strip out anything before this in the filename: +constexpr std::string_view LOG_PREFIX{"oxenmq/", 7}; +inline std::string_view trim_log_filename(std::string_view local_file) { + auto chop = local_file.rfind(LOG_PREFIX); + if (chop != local_file.npos) + local_file.remove_prefix(chop); + return local_file; +} + +template +void OxenMQ::log(LogLevel lvl, const char* file, int line, const T&... stuff) { + if (log_level() < lvl) + return; + + std::ostringstream os; + (os << ... << stuff); + logger(lvl, trim_log_filename(file).data(), line, os.str()); +} + +std::ostream &operator<<(std::ostream &os, LogLevel lvl); + +} // namespace oxenmq + +// vim:sw=4:et diff --git a/lokimq/proxy.cpp b/oxenmq/proxy.cpp similarity index 97% rename from lokimq/proxy.cpp rename to oxenmq/proxy.cpp index bd7978c..f1a7598 100644 --- a/lokimq/proxy.cpp +++ b/oxenmq/proxy.cpp @@ -1,5 +1,5 @@ -#include "lokimq.h" -#include "lokimq-internal.h" +#include "oxenmq.h" +#include "oxenmq-internal.h" #include "hex.h" #if defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__) @@ -17,9 +17,9 @@ extern "C" { } #endif -namespace lokimq { +namespace oxenmq { -void LokiMQ::proxy_quit() { +void OxenMQ::proxy_quit() { LMQ_LOG(debug, "Received quit command, shutting down proxy thread"); assert(std::none_of(workers.begin(), workers.end(), [](auto& worker) { return worker.worker_thread.joinable(); })); @@ -41,7 +41,7 @@ void LokiMQ::proxy_quit() { LMQ_LOG(debug, "Proxy thread teardown complete"); } -void LokiMQ::proxy_send(bt_dict_consumer data) { +void OxenMQ::proxy_send(bt_dict_consumer data) { // NB: bt_dict_consumer goes in alphabetical order std::string_view hint; std::chrono::milliseconds keep_alive{DEFAULT_SEND_KEEP_ALIVE}; @@ -205,7 +205,7 @@ void LokiMQ::proxy_send(bt_dict_consumer data) { } } -void LokiMQ::proxy_reply(bt_dict_consumer data) { +void OxenMQ::proxy_reply(bt_dict_consumer data) { bool have_conn_id = false; ConnectionID conn_id{0}; if (data.skip_until("conn_id")) { @@ -250,11 +250,11 @@ void LokiMQ::proxy_reply(bt_dict_consumer data) { } } -void LokiMQ::proxy_control_message(std::vector& parts) { +void OxenMQ::proxy_control_message(std::vector& parts) { // We throw an uncaught exception here because we only generate control messages internally in - // lokimq code: if one of these condition fail it's a lokimq bug. + // oxenmq code: if one of these condition fail it's a oxenmq bug. if (parts.size() < 2) - throw std::logic_error("LokiMQ bug: Expected 2-3 message parts for a proxy control message"); + throw std::logic_error("OxenMQ bug: Expected 2-3 message parts for a proxy control message"); auto route = view(parts[0]), cmd = view(parts[1]); LMQ_TRACE("control message: ", cmd); if (parts.size() == 3) { @@ -306,11 +306,11 @@ void LokiMQ::proxy_control_message(std::vector& parts) { return; } } - throw std::runtime_error("LokiMQ bug: Proxy received invalid control command: " + + throw std::runtime_error("OxenMQ bug: Proxy received invalid control command: " + std::string{cmd} + " (" + std::to_string(parts.size()) + ")"); } -void LokiMQ::proxy_loop() { +void OxenMQ::proxy_loop() { #if defined(__linux__) || defined(__sun) || defined(__MINGW32__) pthread_setname_np(pthread_self(), "lmq-proxy"); @@ -371,7 +371,7 @@ void LokiMQ::proxy_loop() { listener.set(zmq::sockopt::router_mandatory, true); listener.bind(bind[i].first); - LMQ_LOG(info, "LokiMQ listening on ", bind[i].first); + LMQ_LOG(info, "OxenMQ listening on ", bind[i].first); connections.push_back(std::move(listener)); auto conn_id = next_conn_id++; @@ -547,7 +547,7 @@ static bool is_error_response(std::string_view cmd) { // Return true if we recognized/handled the builtin command (even if we reject it for whatever // reason) -bool LokiMQ::proxy_handle_builtin(size_t conn_index, std::vector& parts) { +bool OxenMQ::proxy_handle_builtin(size_t conn_index, std::vector& parts) { // Doubling as a bool and an offset: size_t incoming = connections[conn_index].get(zmq::sockopt::type) == ZMQ_ROUTER; @@ -644,7 +644,7 @@ bool LokiMQ::proxy_handle_builtin(size_t conn_index, std::vector // pre-1.1.0 sent just a plain UNKNOWNCOMMAND (without the actual command); this was not // useful, but also this response is *expected* for things 1.0.5 didn't understand, like // FORBIDDEN_SN: so log it only at debug level and move on. - LMQ_LOG(debug, "Received plain UNKNOWNCOMMAND; remote is probably an older lokimq. Ignoring."); + LMQ_LOG(debug, "Received plain UNKNOWNCOMMAND; remote is probably an older oxenmq. Ignoring."); return true; } @@ -669,7 +669,7 @@ bool LokiMQ::proxy_handle_builtin(size_t conn_index, std::vector return false; } -void LokiMQ::proxy_process_queue() { +void OxenMQ::proxy_process_queue() { if (max_workers == 0) // shutting down return; diff --git a/oxenmq/variant.h b/oxenmq/variant.h new file mode 100644 index 0000000..fb4c9fe --- /dev/null +++ b/oxenmq/variant.h @@ -0,0 +1,103 @@ +#pragma once +// Workarounds for macos compatibility. On macOS we aren't allowed to touch anything in +// std::variant that could throw if compiling with a target <10.14 because Apple fails hard at +// properly updating their STL. Thus, if compiling in such a mode, we have to introduce +// workarounds. +// +// This header defines a `var` namespace with `var::get` and `var::visit` implementations. On +// everything except broken backwards macos, this is just an alias to `std`. On broken backwards +// macos, we provide implementations that throw std::runtime_error in failure cases since the +// std::bad_variant_access exception can't be touched. +// +// You also get a BROKEN_APPLE_VARIANT macro defined if targetting a problematic mac architecture. + +#include + +#ifdef __APPLE__ +# include +# if defined(__APPLE__) && MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_10_14 +# define BROKEN_APPLE_VARIANT +# endif +#endif + +#ifndef BROKEN_APPLE_VARIANT + +namespace var = std; // Oh look, actual C++17 support + +#else + +// Oh look, apple. + +namespace var { + +// Apple won't let us use std::visit or std::get if targetting some version of macos earlier than +// 10.14 because Apple is awful about not updating their STL. So we have to provide our own, and +// then call these without `std::` -- on crappy macos we'll come here, on everything else we'll ADL +// to the std:: implementation. +template +constexpr T& get(std::variant& var) { + if (auto* v = std::get_if(&var)) return *v; + throw std::runtime_error{"Bad variant access -- variant does not contain the requested type"}; +} +template +constexpr const T& get(const std::variant& var) { + if (auto* v = std::get_if(&var)) return *v; + throw std::runtime_error{"Bad variant access -- variant does not contain the requested type"}; +} +template +constexpr const T&& get(const std::variant&& var) { + if (auto* v = std::get_if(&var)) return std::move(*v); + throw std::runtime_error{"Bad variant access -- variant does not contain the requested type"}; +} +template +constexpr T&& get(std::variant&& var) { + if (auto* v = std::get_if(&var)) return std::move(*v); + throw std::runtime_error{"Bad variant access -- variant does not contain the requested type"}; +} +template +constexpr auto& get(std::variant& var) { + if (auto* v = std::get_if(&var)) return *v; + throw std::runtime_error{"Bad variant access -- variant does not contain the requested type"}; +} +template +constexpr const auto& get(const std::variant& var) { + if (auto* v = std::get_if(&var)) return *v; + throw std::runtime_error{"Bad variant access -- variant does not contain the requested type"}; +} +template +constexpr const auto&& get(const std::variant&& var) { + if (auto* v = std::get_if(&var)) return std::move(*v); + throw std::runtime_error{"Bad variant access -- variant does not contain the requested type"}; +} +template +constexpr auto&& get(std::variant&& var) { + if (auto* v = std::get_if(&var)) return std::move(*v); + throw std::runtime_error{"Bad variant access -- variant does not contain the requested type"}; +} + +template +constexpr auto visit_helper(Visitor&& vis, Variant&& var) { + if (var.index() == I) + return vis(var::get(std::forward(var))); + else if constexpr (sizeof...(More) > 0) + return visit_helper(std::forward(vis), std::forward(var)); + else + throw std::runtime_error{"Bad visit -- variant is valueless"}; +} + +template +constexpr auto visit_helper(Visitor&& vis, Variant&& var, std::index_sequence) { + return visit_helper(std::forward(vis), std::forward(var)); +} + +// Only handle a single variant here because multi-variant invocation is notably harder (and we +// don't need it). +template +constexpr auto visit(Visitor&& vis, Variant&& var) { + return visit_helper(std::forward(vis), std::forward(var), + std::make_index_sequence>>{}); +} + +} // namespace var + +#endif diff --git a/oxenmq/version.h.in b/oxenmq/version.h.in new file mode 100644 index 0000000..a5d6aec --- /dev/null +++ b/oxenmq/version.h.in @@ -0,0 +1,5 @@ +namespace oxenmq { +constexpr int VERSION_MAJOR = @OXENMQ_VERSION_MAJOR@; +constexpr int VERSION_MINOR = @OXENMQ_VERSION_MINOR@; +constexpr int VERSION_PATCH = @OXENMQ_VERSION_PATCH@; +} diff --git a/lokimq/worker.cpp b/oxenmq/worker.cpp similarity index 96% rename from lokimq/worker.cpp rename to oxenmq/worker.cpp index 548fb4d..09685ea 100644 --- a/lokimq/worker.cpp +++ b/oxenmq/worker.cpp @@ -1,6 +1,6 @@ -#include "lokimq.h" +#include "oxenmq.h" #include "batch.h" -#include "lokimq-internal.h" +#include "oxenmq-internal.h" #if defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__) extern "C" { @@ -9,7 +9,7 @@ extern "C" { } #endif -namespace lokimq { +namespace oxenmq { namespace { @@ -17,7 +17,7 @@ namespace { // received. If "QUIT" was received, replies with "QUITTING" on the socket and closes it, then // returns false. [[gnu::always_inline]] inline -bool worker_wait_for(LokiMQ& lmq, zmq::socket_t& sock, std::vector& parts, const std::string_view worker_id, const std::string_view expect) { +bool worker_wait_for(OxenMQ& lmq, zmq::socket_t& sock, std::vector& parts, const std::string_view worker_id, const std::string_view expect) { while (true) { lmq.log(LogLevel::debug, __FILE__, __LINE__, "worker ", worker_id, " waiting for ", expect); parts.clear(); @@ -46,7 +46,7 @@ bool worker_wait_for(LokiMQ& lmq, zmq::socket_t& sock, std::vector tagged, std::function start) { +void OxenMQ::worker_thread(unsigned int index, std::optional tagged, std::function start) { std::string routing_id = (tagged ? "t" : "w") + std::to_string(index); // for routing std::string_view worker_id{tagged ? *tagged : routing_id}; // for debug @@ -72,8 +72,8 @@ void LokiMQ::worker_thread(unsigned int index, std::optional tagged bool waiting_for_command; if (tagged) { - // If we're a tagged worker then we got started up before LokiMQ started, so we need to wait - // for an all-clear signal from LokiMQ first, then we fire our `start` callback, then we can + // If we're a tagged worker then we got started up before OxenMQ started, so we need to wait + // for an all-clear signal from OxenMQ first, then we fire our `start` callback, then we can // start waiting for commands in the main loop further down. (We also can't get the // reference to our `tagged_workers` element until the main proxy threads is running). @@ -159,7 +159,7 @@ void LokiMQ::worker_thread(unsigned int index, std::optional tagged } -LokiMQ::run_info& LokiMQ::get_idle_worker() { +OxenMQ::run_info& OxenMQ::get_idle_worker() { if (idle_workers.empty()) { size_t id = workers.size(); assert(workers.capacity() > id); @@ -174,7 +174,7 @@ LokiMQ::run_info& LokiMQ::get_idle_worker() { return workers[id]; } -void LokiMQ::proxy_worker_message(std::vector& parts) { +void OxenMQ::proxy_worker_message(std::vector& parts) { // Process messages sent by workers if (parts.size() != 2) { LMQ_LOG(error, "Received send invalid ", parts.size(), "-part message"); @@ -268,14 +268,14 @@ void LokiMQ::proxy_worker_message(std::vector& parts) { } } -void LokiMQ::proxy_run_worker(run_info& run) { +void OxenMQ::proxy_run_worker(run_info& run) { if (!run.worker_thread.joinable()) run.worker_thread = std::thread{[this, id=run.worker_id] { worker_thread(id); }}; else send_routed_message(workers_socket, run.worker_routing_id, "RUN"); } -void LokiMQ::proxy_to_worker(size_t conn_index, std::vector& parts) { +void OxenMQ::proxy_to_worker(size_t conn_index, std::vector& parts) { bool outgoing = connections[conn_index].get(zmq::sockopt::type) == ZMQ_DEALER; peer_info tmp_peer; @@ -377,7 +377,7 @@ void LokiMQ::proxy_to_worker(size_t conn_index, std::vector& par category.active_threads++; } -void LokiMQ::inject_task(const std::string& category, std::string command, std::string remote, std::function callback) { +void OxenMQ::inject_task(const std::string& category, std::string command, std::string remote, std::function callback) { if (!callback) return; auto it = categories.find(category); if (it == categories.end()) @@ -386,7 +386,7 @@ void LokiMQ::inject_task(const std::string& category, std::string command, std:: injected_task{it->second, std::move(command), std::move(remote), std::move(callback)}))); } -void LokiMQ::proxy_inject_task(injected_task task) { +void OxenMQ::proxy_inject_task(injected_task task) { auto& category = task.cat; if (category.active_threads >= category.reserved_threads && active_workers() >= general_workers) { // No free worker slot, queue for later diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 2b377ce..96d5710 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -19,7 +19,7 @@ add_executable(tests ${LMQ_TEST_SRC}) find_package(Threads) -target_link_libraries(tests Catch2::Catch2 lokimq Threads::Threads) +target_link_libraries(tests Catch2::Catch2 oxenmq Threads::Threads) set_target_properties(tests PROPERTIES CXX_STANDARD 17 diff --git a/tests/common.h b/tests/common.h index 9a3ce2b..6dda8f6 100644 --- a/tests/common.h +++ b/tests/common.h @@ -1,8 +1,8 @@ #pragma once -#include "lokimq/lokimq.h" +#include "oxenmq/oxenmq.h" #include -using namespace lokimq; +using namespace oxenmq; static auto startup = std::chrono::steady_clock::now(); @@ -41,7 +41,7 @@ inline std::unique_lock catch_lock() { return std::unique_lock{mutex}; } -inline LokiMQ::Logger get_logger(std::string prefix = "") { +inline OxenMQ::Logger get_logger(std::string prefix = "") { std::string me = "tests/common.h"; std::string strip = __FILE__; if (strip.substr(strip.size() - me.size()) == me) diff --git a/tests/test_address.cpp b/tests/test_address.cpp index 6b7cb3f..4e9bf52 100644 --- a/tests/test_address.cpp +++ b/tests/test_address.cpp @@ -1,4 +1,4 @@ -#include "lokimq/address.h" +#include "oxenmq/address.h" #include "common.h" const std::string pk = "\xf1\x6b\xa5\x59\x10\x39\xf0\x89\xb4\x2a\x83\x41\x75\x09\x30\x94\x07\x4d\x0d\x93\x7a\x79\xe5\x3e\x5c\xe7\x30\xf9\x46\xe1\x4b\x88"; diff --git a/tests/test_batch.cpp b/tests/test_batch.cpp index f850e6b..30ed270 100644 --- a/tests/test_batch.cpp +++ b/tests/test_batch.cpp @@ -1,4 +1,4 @@ -#include "lokimq/batch.h" +#include "oxenmq/batch.h" #include "common.h" #include @@ -12,7 +12,7 @@ double do_my_task(int input) { std::promise> done; -void continue_big_task(std::vector> results) { +void continue_big_task(std::vector> results) { double sum = 0; int exc_count = 0; for (auto& r : results) { @@ -25,10 +25,10 @@ void continue_big_task(std::vector> results) { done.set_value({sum, exc_count}); } -void start_big_task(lokimq::LokiMQ& lmq) { +void start_big_task(oxenmq::OxenMQ& lmq) { size_t num_jobs = 32; - lokimq::Batch batch; + oxenmq::Batch batch; batch.reserve(num_jobs); for (size_t i = 0; i < num_jobs; i++) @@ -41,7 +41,7 @@ void start_big_task(lokimq::LokiMQ& lmq) { TEST_CASE("batching many small jobs", "[batch-many]") { - lokimq::LokiMQ lmq{ + oxenmq::OxenMQ lmq{ "", "", // generate ephemeral keys false, // not a service node [](auto) { return ""; }, @@ -58,7 +58,7 @@ TEST_CASE("batching many small jobs", "[batch-many]") { } TEST_CASE("batch exception propagation", "[batch-exceptions]") { - lokimq::LokiMQ lmq{ + oxenmq::OxenMQ lmq{ "", "", // generate ephemeral keys false, // not a service node [](auto) { return ""; }, @@ -73,7 +73,7 @@ TEST_CASE("batch exception propagation", "[batch-exceptions]") { using Catch::Matchers::Message; SECTION( "value return" ) { - lokimq::Batch batch; + oxenmq::Batch batch; for (int i : {1, 2}) batch.add_job([i]() { if (i == 1) return 42; throw std::domain_error("bad value " + std::to_string(i)); }); batch.completion([&done_promise](auto results) { @@ -88,7 +88,7 @@ TEST_CASE("batch exception propagation", "[batch-exceptions]") { } SECTION( "lvalue return" ) { - lokimq::Batch batch; + oxenmq::Batch batch; int forty_two = 42; for (int i : {1, 2}) batch.add_job([i,&forty_two]() -> int& { @@ -110,7 +110,7 @@ TEST_CASE("batch exception propagation", "[batch-exceptions]") { } SECTION( "void return" ) { - lokimq::Batch batch; + oxenmq::Batch batch; for (int i : {1, 2}) batch.add_job([i]() { if (i != 1) throw std::domain_error("bad value " + std::to_string(i)); }); batch.completion([&done_promise](auto results) { diff --git a/tests/test_bt.cpp b/tests/test_bt.cpp index fa907a9..b3f8b0b 100644 --- a/tests/test_bt.cpp +++ b/tests/test_bt.cpp @@ -1,4 +1,4 @@ -#include "lokimq/bt_serialize.h" +#include "oxenmq/bt_serialize.h" #include "common.h" #include #include @@ -129,10 +129,10 @@ TEST_CASE("bt_value deserialization", "[bt][deserialization][bt_value]") { REQUIRE( var::get(dna2) == -42 ); REQUIRE_THROWS( var::get(dna1) ); REQUIRE_THROWS( var::get(dna2) ); - REQUIRE( lokimq::get_int(dna1) == 42 ); - REQUIRE( lokimq::get_int(dna2) == -42 ); - REQUIRE( lokimq::get_int(dna1) == 42 ); - REQUIRE_THROWS( lokimq::get_int(dna2) ); + REQUIRE( oxenmq::get_int(dna1) == 42 ); + REQUIRE( oxenmq::get_int(dna2) == -42 ); + REQUIRE( oxenmq::get_int(dna1) == 42 ); + REQUIRE_THROWS( oxenmq::get_int(dna2) ); bt_value x = bt_deserialize("d3:barle3:foold1:ali1ei2ei3ee1:bleed1:cli-5ei4eeeee"); REQUIRE( std::holds_alternative(x) ); @@ -150,9 +150,9 @@ TEST_CASE("bt_value deserialization", "[bt][deserialization][bt_value]") { bt_list& foo1b = var::get(foo1.at("b")); bt_list& foo2c = var::get(foo2.at("c")); std::list foo1a_vals, foo1b_vals, foo2c_vals; - for (auto& v : foo1a) foo1a_vals.push_back(lokimq::get_int(v)); - for (auto& v : foo1b) foo1b_vals.push_back(lokimq::get_int(v)); - for (auto& v : foo2c) foo2c_vals.push_back(lokimq::get_int(v)); + for (auto& v : foo1a) foo1a_vals.push_back(oxenmq::get_int(v)); + for (auto& v : foo1b) foo1b_vals.push_back(oxenmq::get_int(v)); + for (auto& v : foo2c) foo2c_vals.push_back(oxenmq::get_int(v)); REQUIRE( foo1a_vals == std::list{{1,2,3}} ); REQUIRE( foo1b_vals == std::list{} ); REQUIRE( foo2c_vals == std::list{{-5, 4}} ); diff --git a/tests/test_commands.cpp b/tests/test_commands.cpp index 70b7f36..fe5c598 100644 --- a/tests/test_commands.cpp +++ b/tests/test_commands.cpp @@ -1,13 +1,13 @@ #include "common.h" -#include +#include #include #include -using namespace lokimq; +using namespace oxenmq; TEST_CASE("basic commands", "[commands]") { std::string listen = random_localhost(); - LokiMQ server{ + OxenMQ server{ "", "", // generate ephemeral keys false, // not a service node [](auto) { return ""; }, @@ -31,7 +31,7 @@ TEST_CASE("basic commands", "[commands]") { server.start(); - LokiMQ client{get_logger("C» "), LogLevel::trace}; + OxenMQ client{get_logger("C» "), LogLevel::trace}; client.add_category("public", Access{AuthLevel::none}); client.add_command("public", "hi", [&](auto&) { his++; }); @@ -77,7 +77,7 @@ TEST_CASE("basic commands", "[commands]") { TEST_CASE("outgoing auth level", "[commands][auth]") { std::string listen = random_localhost(); - LokiMQ server{ + OxenMQ server{ "", "", // generate ephemeral keys false, // not a service node [](auto) { return ""; }, @@ -93,7 +93,7 @@ TEST_CASE("outgoing auth level", "[commands][auth]") { server.start(); - LokiMQ client{get_logger("C» "), LogLevel::trace}; + OxenMQ client{get_logger("C» "), LogLevel::trace}; std::atomic public_hi{0}, basic_hi{0}, admin_hi{0}; client.add_category("public", Access{AuthLevel::none}); @@ -159,7 +159,7 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]") // original node. std::string listen = random_localhost(); - LokiMQ server{ + OxenMQ server{ "", "", // generate ephemeral keys false, // not a service node [](auto) { return ""; }, @@ -178,7 +178,7 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]") m.send_reply("Okay, I'll remember that."); if (backdoor) - m.lokimq.send(backdoor, "backdoor.data", m.data[0]); + m.oxenmq.send(backdoor, "backdoor.data", m.data[0]); }); server.add_command("hey google", "recall", [&](Message& m) { auto l = catch_lock(); @@ -199,7 +199,7 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]") std::set backdoor_details; - LokiMQ nsa{get_logger("NSA» ")}; + OxenMQ nsa{get_logger("NSA» ")}; nsa.add_category("backdoor", Access{AuthLevel::admin}); nsa.add_command("backdoor", "data", [&](Message& m) { auto l = catch_lock(); @@ -215,7 +215,7 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]") REQUIRE( backdoor ); } - std::vector> clients; + std::vector> clients; std::vector conns; std::map> personal_details{ {0, {"Loretta"s, "photos"s}}, @@ -231,7 +231,7 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]") std::map> google_knows; int things_remembered{0}; for (int i = 0; i < 5; i++) { - clients.push_back(std::make_unique( + clients.push_back(std::make_unique( get_logger("C" + std::to_string(i) + "» "), LogLevel::trace )); auto& c = clients.back(); @@ -271,7 +271,7 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]") TEST_CASE("send failure callbacks", "[commands][queue_full]") { std::string listen = random_localhost(); - LokiMQ server{ + OxenMQ server{ "", "", // generate ephemeral keys false, // not a service node [](auto) { return ""; }, @@ -298,7 +298,7 @@ TEST_CASE("send failure callbacks", "[commands][queue_full]") { server.start(); // Use a raw socket here because I want to stall it by not reading from it at all, and that is - // hard with LokiMQ. + // hard with OxenMQ. zmq::context_t client_ctx; zmq::socket_t client{client_ctx, zmq::socket_type::dealer}; client.connect(listen); @@ -365,7 +365,7 @@ TEST_CASE("send failure callbacks", "[commands][queue_full]") { TEST_CASE("data parts", "[send][data_parts]") { std::string listen = random_localhost(); - LokiMQ server{ + OxenMQ server{ "", "", // generate ephemeral keys false, // not a service node [](auto) { return ""; }, @@ -385,7 +385,7 @@ TEST_CASE("data parts", "[send][data_parts]") { }); server.start(); - LokiMQ client{get_logger("C» "), LogLevel::trace}; + OxenMQ client{get_logger("C» "), LogLevel::trace}; client.start(); std::atomic got{false}; @@ -406,7 +406,7 @@ TEST_CASE("data parts", "[send][data_parts]") { } std::vector some_data{{"abc"s, "def"s, "omg123\0zzz"s}}; - client.send(c, "public.hello", lokimq::send_option::data_parts(some_data.begin(), some_data.end())); + client.send(c, "public.hello", oxenmq::send_option::data_parts(some_data.begin(), some_data.end())); reply_sleep(); { auto lock = catch_lock(); @@ -418,10 +418,10 @@ TEST_CASE("data parts", "[send][data_parts]") { std::vector some_data2{{"a"sv, "b"sv, "\0"sv}}; client.send(c, "public.hello", "hi", - lokimq::send_option::data_parts(some_data2.begin(), some_data2.end()), + oxenmq::send_option::data_parts(some_data2.begin(), some_data2.end()), "another", "string"sv, - lokimq::send_option::data_parts(some_data.begin(), some_data.end())); + oxenmq::send_option::data_parts(some_data.begin(), some_data.end())); std::vector expected; expected.push_back("hi"); diff --git a/tests/test_connect.cpp b/tests/test_connect.cpp index b9f30d8..ca66fe1 100644 --- a/tests/test_connect.cpp +++ b/tests/test_connect.cpp @@ -1,5 +1,5 @@ #include "common.h" -#include +#include extern "C" { #include } @@ -7,7 +7,7 @@ extern "C" { TEST_CASE("connections with curve authentication", "[curve][connect]") { std::string listen = random_localhost(); - LokiMQ server{ + OxenMQ server{ "", "", // generate ephemeral keys false, // not a service node [](auto) { return ""; }, @@ -20,7 +20,7 @@ TEST_CASE("connections with curve authentication", "[curve][connect]") { server.add_request_command("public", "hello", [&](Message& m) { m.send_reply("hi"); }); server.start(); - LokiMQ client{get_logger("C» "), LogLevel::trace}; + OxenMQ client{get_logger("C» "), LogLevel::trace}; client.start(); @@ -55,7 +55,7 @@ TEST_CASE("self-connection SN optimization", "[connect][self]") { REQUIRE(sodium_init() != -1); auto listen_addr = random_localhost(); crypto_box_keypair(reinterpret_cast(&pubkey[0]), reinterpret_cast(&privkey[0])); - LokiMQ sn{ + OxenMQ sn{ pubkey, privkey, true, [&](auto pk) { if (pk == pubkey) return listen_addr; else return ""s; }, @@ -92,7 +92,7 @@ TEST_CASE("self-connection SN optimization", "[connect][self]") { TEST_CASE("plain-text connections", "[plaintext][connect]") { std::string listen = random_localhost(); - LokiMQ server{get_logger("S» "), LogLevel::trace}; + OxenMQ server{get_logger("S» "), LogLevel::trace}; server.add_category("public", Access{AuthLevel::none}); server.add_request_command("public", "hello", [&](Message& m) { m.send_reply("hi"); }); @@ -101,7 +101,7 @@ TEST_CASE("plain-text connections", "[plaintext][connect]") { server.start(); - LokiMQ client{get_logger("C» "), LogLevel::trace}; + OxenMQ client{get_logger("C» "), LogLevel::trace}; client.start(); @@ -131,7 +131,7 @@ TEST_CASE("plain-text connections", "[plaintext][connect]") { TEST_CASE("unique connection IDs", "[connect][id]") { std::string listen = random_localhost(); - LokiMQ server{get_logger("S» "), LogLevel::trace}; + OxenMQ server{get_logger("S» "), LogLevel::trace}; ConnectionID first, second; server.add_category("x", Access{AuthLevel::none}) @@ -143,8 +143,8 @@ TEST_CASE("unique connection IDs", "[connect][id]") { server.start(); - LokiMQ client1{get_logger("C1» "), LogLevel::trace}; - LokiMQ client2{get_logger("C2» "), LogLevel::trace}; + OxenMQ client1{get_logger("C1» "), LogLevel::trace}; + OxenMQ client2{get_logger("C2» "), LogLevel::trace}; client1.start(); client2.start(); @@ -186,7 +186,7 @@ TEST_CASE("unique connection IDs", "[connect][id]") { TEST_CASE("SN disconnections", "[connect][disconnect]") { - std::vector> lmq; + std::vector> lmq; std::vector pubkey, privkey; std::unordered_map conn; REQUIRE(sodium_init() != -1); @@ -200,7 +200,7 @@ TEST_CASE("SN disconnections", "[connect][disconnect]") { } std::atomic his{0}; for (int i = 0; i < pubkey.size(); i++) { - lmq.push_back(std::make_unique( + lmq.push_back(std::make_unique( pubkey[i], privkey[i], true, [conn](auto pk) { auto it = conn.find((std::string) pk); if (it != conn.end()) return it->second; return ""s; }, get_logger("S" + std::to_string(i) + "» "), @@ -238,7 +238,7 @@ TEST_CASE("SN auth checks", "[sandwich][auth]") { privkey.resize(crypto_box_SECRETKEYBYTES); REQUIRE(sodium_init() != -1); crypto_box_keypair(reinterpret_cast(&pubkey[0]), reinterpret_cast(&privkey[0])); - LokiMQ server{ + OxenMQ server{ pubkey, privkey, true, // service node [](auto) { return ""; }, @@ -265,7 +265,7 @@ TEST_CASE("SN auth checks", "[sandwich][auth]") { .add_request_command("make", [&](Message& m) { m.send_reply("okay"); }); server.start(); - LokiMQ client{ + OxenMQ client{ "", "", false, [&](auto remote_pk) { if (remote_pk == pubkey) return listen; return ""s; }, get_logger("B» "), LogLevel::trace}; @@ -352,7 +352,7 @@ TEST_CASE("SN single worker test", "[connect][worker]") { // Tests a failure case that could trigger when all workers are allocated (here we make that // simpler by just having one worker). std::string listen = random_localhost(); - LokiMQ server{ + OxenMQ server{ "", "", false, // service node [](auto) { return ""; }, @@ -368,7 +368,7 @@ TEST_CASE("SN single worker test", "[connect][worker]") { ; server.start(); - LokiMQ client{get_logger("B» "), LogLevel::trace}; + OxenMQ client{get_logger("B» "), LogLevel::trace}; client.start(); auto conn = client.connect_remote(listen, [](auto) {}, [](auto, auto) {}); diff --git a/tests/test_encoding.cpp b/tests/test_encoding.cpp index c2b3b7a..f506d8c 100644 --- a/tests/test_encoding.cpp +++ b/tests/test_encoding.cpp @@ -1,7 +1,8 @@ -#include "lokimq/hex.h" -#include "lokimq/base32z.h" -#include "lokimq/base64.h" +#include "oxenmq/hex.h" +#include "oxenmq/base32z.h" +#include "oxenmq/base64.h" #include "common.h" +#include using namespace std::literals; @@ -11,116 +12,123 @@ const std::string pk_b32z = "6fi4kseo88aeupbkopyzknjo1odw4dcuxjh6kx1hhhax1tzbjqr const std::string pk_b64 = "8WulWRA58Im0KoNBdQkwlAdNDZN6eeU+XOcw+UbhS4g="; TEST_CASE("hex encoding/decoding", "[encoding][decoding][hex]") { - REQUIRE( lokimq::to_hex("\xff\x42\x12\x34") == "ff421234"s ); + REQUIRE( oxenmq::to_hex("\xff\x42\x12\x34") == "ff421234"s ); std::vector chars{{1, 10, 100, 254}}; std::array out; std::array expected{{'0', '1', '0', 'a', '6', '4', 'f', 'e'}}; - lokimq::to_hex(chars.begin(), chars.end(), out.begin()); + oxenmq::to_hex(chars.begin(), chars.end(), out.begin()); REQUIRE( out == expected ); - REQUIRE( lokimq::to_hex(chars.begin(), chars.end()) == "010a64fe" ); + REQUIRE( oxenmq::to_hex(chars.begin(), chars.end()) == "010a64fe" ); - REQUIRE( lokimq::from_hex("12345678ffEDbca9") == "\x12\x34\x56\x78\xff\xed\xbc\xa9"s ); + REQUIRE( oxenmq::from_hex("12345678ffEDbca9") == "\x12\x34\x56\x78\xff\xed\xbc\xa9"s ); - REQUIRE( lokimq::is_hex("1234567890abcdefABCDEF1234567890abcdefABCDEF") ); - REQUIRE_FALSE( lokimq::is_hex("1234567890abcdefABCDEF1234567890aGcdefABCDEF") ); - REQUIRE_FALSE( lokimq::is_hex("1234567890abcdefABCDEF1234567890agcdefABCDEF") ); - REQUIRE_FALSE( lokimq::is_hex("\x11\xff") ); + REQUIRE( oxenmq::is_hex("1234567890abcdefABCDEF1234567890abcdefABCDEF") ); + REQUIRE_FALSE( oxenmq::is_hex("1234567890abcdefABCDEF1234567890aGcdefABCDEF") ); + // ^ + REQUIRE_FALSE( oxenmq::is_hex("1234567890abcdefABCDEF1234567890agcdefABCDEF") ); + // ^ + REQUIRE_FALSE( oxenmq::is_hex("\x11\xff") ); + constexpr auto odd_hex = "1234567890abcdefABCDEF1234567890abcdefABCDE"sv; + REQUIRE_FALSE( oxenmq::is_hex(odd_hex) ); + REQUIRE_FALSE( oxenmq::is_hex("0") ); - REQUIRE( lokimq::from_hex(pk_hex) == pk ); - REQUIRE( lokimq::to_hex(pk) == pk_hex ); + REQUIRE( std::all_of(odd_hex.begin(), odd_hex.end(), oxenmq::is_hex_digit) ); - REQUIRE( lokimq::from_hex(pk_hex.begin(), pk_hex.end()) == pk ); + REQUIRE( oxenmq::from_hex(pk_hex) == pk ); + REQUIRE( oxenmq::to_hex(pk) == pk_hex ); + + REQUIRE( oxenmq::from_hex(pk_hex.begin(), pk_hex.end()) == pk ); std::vector bytes{{std::byte{0xff}, std::byte{0x42}, std::byte{0x12}, std::byte{0x34}}}; std::basic_string_view b{bytes.data(), bytes.size()}; - REQUIRE( lokimq::to_hex(b) == "ff421234"s ); + REQUIRE( oxenmq::to_hex(b) == "ff421234"s ); bytes.resize(8); bytes[0] = std::byte{'f'}; bytes[1] = std::byte{'f'}; bytes[2] = std::byte{'4'}; bytes[3] = std::byte{'2'}; bytes[4] = std::byte{'1'}; bytes[5] = std::byte{'2'}; bytes[6] = std::byte{'3'}; bytes[7] = std::byte{'4'}; std::basic_string_view hex_bytes{bytes.data(), bytes.size()}; - REQUIRE( lokimq::is_hex(hex_bytes) ); - REQUIRE( lokimq::from_hex(hex_bytes) == "\xff\x42\x12\x34" ); + REQUIRE( oxenmq::is_hex(hex_bytes) ); + REQUIRE( oxenmq::from_hex(hex_bytes) == "\xff\x42\x12\x34" ); } TEST_CASE("base32z encoding/decoding", "[encoding][decoding][base32z]") { - REQUIRE( lokimq::to_base32z("\0\0\0\0\0"s) == "yyyyyyyy" ); - REQUIRE( lokimq::to_base32z("\x01\x23\x45\x67\x89\xab\xcd\xef\x01\x23\x45\x67\x89\xab\xcd\xef\x01\x23\x45\x67\x89\xab\xcd\xef\x01\x23\x45\x67\x89\xab\xcd\xef"sv) + REQUIRE( oxenmq::to_base32z("\0\0\0\0\0"s) == "yyyyyyyy" ); + REQUIRE( oxenmq::to_base32z("\x01\x23\x45\x67\x89\xab\xcd\xef\x01\x23\x45\x67\x89\xab\xcd\xef\x01\x23\x45\x67\x89\xab\xcd\xef\x01\x23\x45\x67\x89\xab\xcd\xef"sv) == "yrtwk3hjixg66yjdeiuauk6p7hy1gtm8tgih55abrpnsxnpm3zzo"); - REQUIRE( lokimq::from_base32z("yrtwk3hjixg66yjdeiuauk6p7hy1gtm8tgih55abrpnsxnpm3zzo") + REQUIRE( oxenmq::from_base32z("yrtwk3hjixg66yjdeiuauk6p7hy1gtm8tgih55abrpnsxnpm3zzo") == "\x01\x23\x45\x67\x89\xab\xcd\xef\x01\x23\x45\x67\x89\xab\xcd\xef\x01\x23\x45\x67\x89\xab\xcd\xef\x01\x23\x45\x67\x89\xab\xcd\xef"sv); - REQUIRE( lokimq::from_base32z("YRTWK3HJIXG66YJDEIUAUK6P7HY1GTM8TGIH55ABRPNSXNPM3ZZO") + REQUIRE( oxenmq::from_base32z("YRTWK3HJIXG66YJDEIUAUK6P7HY1GTM8TGIH55ABRPNSXNPM3ZZO") == "\x01\x23\x45\x67\x89\xab\xcd\xef\x01\x23\x45\x67\x89\xab\xcd\xef\x01\x23\x45\x67\x89\xab\xcd\xef\x01\x23\x45\x67\x89\xab\xcd\xef"sv); - auto five_nulls = lokimq::from_base32z("yyyyyyyy"); + auto five_nulls = oxenmq::from_base32z("yyyyyyyy"); REQUIRE( five_nulls.size() == 5 ); REQUIRE( five_nulls == "\0\0\0\0\0"s ); // 00000 00001 00010 00011 00100 00101 00110 00111 // == // 00000000 01000100 00110010 00010100 11000111 - REQUIRE( lokimq::from_base32z("ybndrfg8") == "\x00\x44\x32\x14\xc7"s ); + REQUIRE( oxenmq::from_base32z("ybndrfg8") == "\x00\x44\x32\x14\xc7"s ); // Special case 1: 7 base32z digits with 3 trailing 0 bits -> 4 bytes (the trailing 0s are dropped) // 00000 00001 00010 00011 00100 00101 11000 // == // 00000000 01000100 00110010 00010111 - REQUIRE( lokimq::from_base32z("ybndrfa") == "\x00\x44\x32\x17"s ); + REQUIRE( oxenmq::from_base32z("ybndrfa") == "\x00\x44\x32\x17"s ); // Round-trip it: - REQUIRE( lokimq::from_base32z(lokimq::to_base32z("\x00\x44\x32\x17"sv)) == "\x00\x44\x32\x17"sv ); - REQUIRE( lokimq::to_base32z(lokimq::from_base32z("ybndrfa")) == "ybndrfa" ); + REQUIRE( oxenmq::from_base32z(oxenmq::to_base32z("\x00\x44\x32\x17"sv)) == "\x00\x44\x32\x17"sv ); + REQUIRE( oxenmq::to_base32z(oxenmq::from_base32z("ybndrfa")) == "ybndrfa" ); // Special case 2: 7 base32z digits with 3 trailing bits 010; we just ignore the trailing stuff, // as if it was specified as 0. (The last digit here is 11010 instead of 11000). - REQUIRE( lokimq::from_base32z("ybndrf4") == "\x00\x44\x32\x17"s ); + REQUIRE( oxenmq::from_base32z("ybndrf4") == "\x00\x44\x32\x17"s ); // This one won't round-trip to the same value since it has ignored garbage bytes at the end - REQUIRE( lokimq::to_base32z(lokimq::from_base32z("ybndrf4"s)) == "ybndrfa" ); + REQUIRE( oxenmq::to_base32z(oxenmq::from_base32z("ybndrf4"s)) == "ybndrfa" ); - REQUIRE( lokimq::to_base32z(pk) == pk_b32z ); - REQUIRE( lokimq::to_base32z(pk.begin(), pk.end()) == pk_b32z ); - REQUIRE( lokimq::from_base32z(pk_b32z) == pk ); - REQUIRE( lokimq::from_base32z(pk_b32z.begin(), pk_b32z.end()) == pk ); + REQUIRE( oxenmq::to_base32z(pk) == pk_b32z ); + REQUIRE( oxenmq::to_base32z(pk.begin(), pk.end()) == pk_b32z ); + REQUIRE( oxenmq::from_base32z(pk_b32z) == pk ); + REQUIRE( oxenmq::from_base32z(pk_b32z.begin(), pk_b32z.end()) == pk ); std::string pk_b32z_again, pk_again; - lokimq::to_base32z(pk.begin(), pk.end(), std::back_inserter(pk_b32z_again)); - lokimq::from_base32z(pk_b32z.begin(), pk_b32z.end(), std::back_inserter(pk_again)); + oxenmq::to_base32z(pk.begin(), pk.end(), std::back_inserter(pk_b32z_again)); + oxenmq::from_base32z(pk_b32z.begin(), pk_b32z.end(), std::back_inserter(pk_again)); REQUIRE( pk_b32z_again == pk_b32z ); REQUIRE( pk_again == pk ); std::vector bytes{{std::byte{0}, std::byte{255}}}; std::basic_string_view b{bytes.data(), bytes.size()}; - REQUIRE( lokimq::to_base32z(b) == "yd9o" ); + REQUIRE( oxenmq::to_base32z(b) == "yd9o" ); bytes.resize(4); bytes[0] = std::byte{'y'}; bytes[1] = std::byte{'d'}; bytes[2] = std::byte{'9'}; bytes[3] = std::byte{'o'}; std::basic_string_view b32_bytes{bytes.data(), bytes.size()}; - REQUIRE( lokimq::is_base32z(b32_bytes) ); - REQUIRE( lokimq::from_base32z(b32_bytes) == "\x00\xff"sv ); + REQUIRE( oxenmq::is_base32z(b32_bytes) ); + REQUIRE( oxenmq::from_base32z(b32_bytes) == "\x00\xff"sv ); } TEST_CASE("base64 encoding/decoding", "[encoding][decoding][base64]") { // 00000000 00000000 00000000 -> 000000 000000 000000 000000 - REQUIRE( lokimq::to_base64("\0\0\0"s) == "AAAA" ); + REQUIRE( oxenmq::to_base64("\0\0\0"s) == "AAAA" ); // 00000001 00000002 00000003 -> 000000 010000 000200 000003 - REQUIRE( lokimq::to_base64("\x01\x02\x03"s) == "AQID" ); - REQUIRE( lokimq::to_base64("\0\0\0\0"s) == "AAAAAA==" ); + REQUIRE( oxenmq::to_base64("\x01\x02\x03"s) == "AQID" ); + REQUIRE( oxenmq::to_base64("\0\0\0\0"s) == "AAAAAA==" ); // 00000000 00000000 00000000 11111111 -> // 000000 000000 000000 000000 111111 110000 (pad) (pad) - REQUIRE( lokimq::to_base64("a") == "YQ==" ); - REQUIRE( lokimq::to_base64("ab") == "YWI=" ); - REQUIRE( lokimq::to_base64("abc") == "YWJj" ); - REQUIRE( lokimq::to_base64("abcd") == "YWJjZA==" ); - REQUIRE( lokimq::to_base64("abcde") == "YWJjZGU=" ); - REQUIRE( lokimq::to_base64("abcdef") == "YWJjZGVm" ); + REQUIRE( oxenmq::to_base64("a") == "YQ==" ); + REQUIRE( oxenmq::to_base64("ab") == "YWI=" ); + REQUIRE( oxenmq::to_base64("abc") == "YWJj" ); + REQUIRE( oxenmq::to_base64("abcd") == "YWJjZA==" ); + REQUIRE( oxenmq::to_base64("abcde") == "YWJjZGU=" ); + REQUIRE( oxenmq::to_base64("abcdef") == "YWJjZGVm" ); - REQUIRE( lokimq::to_base64("\0\0\0\xff"s) == "AAAA/w==" ); - REQUIRE( lokimq::to_base64("\0\0\0\xff\xff"s) == "AAAA//8=" ); - REQUIRE( lokimq::to_base64("\0\0\0\xff\xff\xff"s) == "AAAA////" ); - REQUIRE( lokimq::to_base64( + REQUIRE( oxenmq::to_base64("\0\0\0\xff"s) == "AAAA/w==" ); + REQUIRE( oxenmq::to_base64("\0\0\0\xff\xff"s) == "AAAA//8=" ); + REQUIRE( oxenmq::to_base64("\0\0\0\xff\xff\xff"s) == "AAAA////" ); + REQUIRE( oxenmq::to_base64( "Man is distinguished, not only by his reason, but by this singular passion from other " "animals, which is a lust of the mind, that by a perseverance of delight in the " "continued and indefatigable generation of knowledge, exceeds the short vehemence of " @@ -132,33 +140,33 @@ TEST_CASE("base64 encoding/decoding", "[encoding][decoding][base64]") { "dWVkIGFuZCBpbmRlZmF0aWdhYmxlIGdlbmVyYXRpb24gb2Yga25vd2xlZGdlLCBleGNlZWRzIHRo" "ZSBzaG9ydCB2ZWhlbWVuY2Ugb2YgYW55IGNhcm5hbCBwbGVhc3VyZS4=" ); - REQUIRE( lokimq::from_base64("A+/A") == "\x03\xef\xc0" ); - REQUIRE( lokimq::from_base64("YWJj") == "abc" ); - REQUIRE( lokimq::from_base64("YWJjZA==") == "abcd" ); - REQUIRE( lokimq::from_base64("YWJjZA") == "abcd" ); - REQUIRE( lokimq::from_base64("YWJjZB") == "abcd" ); // ignore superfluous bits - REQUIRE( lokimq::from_base64("YWJjZB") == "abcd" ); // ignore superfluous bits - REQUIRE( lokimq::from_base64("YWJj+") == "abc" ); // ignore superfluous bits - REQUIRE( lokimq::from_base64("YWJjZGU=") == "abcde" ); - REQUIRE( lokimq::from_base64("YWJjZGU") == "abcde" ); - REQUIRE( lokimq::from_base64("YWJjZGVm") == "abcdef" ); + REQUIRE( oxenmq::from_base64("A+/A") == "\x03\xef\xc0" ); + REQUIRE( oxenmq::from_base64("YWJj") == "abc" ); + REQUIRE( oxenmq::from_base64("YWJjZA==") == "abcd" ); + REQUIRE( oxenmq::from_base64("YWJjZA") == "abcd" ); + REQUIRE( oxenmq::from_base64("YWJjZB") == "abcd" ); // ignore superfluous bits + REQUIRE( oxenmq::from_base64("YWJjZB") == "abcd" ); // ignore superfluous bits + REQUIRE( oxenmq::from_base64("YWJj+") == "abc" ); // ignore superfluous bits + REQUIRE( oxenmq::from_base64("YWJjZGU=") == "abcde" ); + REQUIRE( oxenmq::from_base64("YWJjZGU") == "abcde" ); + REQUIRE( oxenmq::from_base64("YWJjZGVm") == "abcdef" ); - REQUIRE( lokimq::is_base64("YWJjZGVm") ); - REQUIRE( lokimq::is_base64("YWJjZGU") ); - REQUIRE( lokimq::is_base64("YWJjZGU=") ); - REQUIRE( lokimq::is_base64("YWJjZA==") ); - REQUIRE( lokimq::is_base64("YWJjZA") ); - REQUIRE( lokimq::is_base64("YWJjZB") ); // not really valid, but we explicitly accept it + REQUIRE( oxenmq::is_base64("YWJjZGVm") ); + REQUIRE( oxenmq::is_base64("YWJjZGU") ); + REQUIRE( oxenmq::is_base64("YWJjZGU=") ); + REQUIRE( oxenmq::is_base64("YWJjZA==") ); + REQUIRE( oxenmq::is_base64("YWJjZA") ); + REQUIRE( oxenmq::is_base64("YWJjZB") ); // not really valid, but we explicitly accept it - REQUIRE_FALSE( lokimq::is_base64("YWJjZ=") ); // invalid padding (padding can only be 4th or 3rd+4th of a 4-char block) - REQUIRE_FALSE( lokimq::is_base64("YWJj=") ); - REQUIRE_FALSE( lokimq::is_base64("YWJj=A") ); - REQUIRE_FALSE( lokimq::is_base64("YWJjA===") ); - REQUIRE_FALSE( lokimq::is_base64("YWJ[") ); - REQUIRE_FALSE( lokimq::is_base64("YWJ.") ); - REQUIRE_FALSE( lokimq::is_base64("_YWJ") ); + REQUIRE_FALSE( oxenmq::is_base64("YWJjZ=") ); // invalid padding (padding can only be 4th or 3rd+4th of a 4-char block) + REQUIRE_FALSE( oxenmq::is_base64("YWJj=") ); + REQUIRE_FALSE( oxenmq::is_base64("YWJj=A") ); + REQUIRE_FALSE( oxenmq::is_base64("YWJjA===") ); + REQUIRE_FALSE( oxenmq::is_base64("YWJ[") ); + REQUIRE_FALSE( oxenmq::is_base64("YWJ.") ); + REQUIRE_FALSE( oxenmq::is_base64("_YWJ") ); - REQUIRE( lokimq::from_base64( + REQUIRE( oxenmq::from_base64( "TWFuIGlzIGRpc3Rpbmd1aXNoZWQsIG5vdCBvbmx5IGJ5IGhpcyByZWFzb24sIGJ1dCBieSB0aGlz" "IHNpbmd1bGFyIHBhc3Npb24gZnJvbSBvdGhlciBhbmltYWxzLCB3aGljaCBpcyBhIGx1c3Qgb2Yg" "dGhlIG1pbmQsIHRoYXQgYnkgYSBwZXJzZXZlcmFuY2Ugb2YgZGVsaWdodCBpbiB0aGUgY29udGlu" @@ -170,24 +178,63 @@ TEST_CASE("base64 encoding/decoding", "[encoding][decoding][base64]") { "continued and indefatigable generation of knowledge, exceeds the short vehemence of " "any carnal pleasure."); - REQUIRE( lokimq::to_base64(pk) == pk_b64 ); - REQUIRE( lokimq::to_base64(pk.begin(), pk.end()) == pk_b64 ); - REQUIRE( lokimq::from_base64(pk_b64) == pk ); - REQUIRE( lokimq::from_base64(pk_b64.begin(), pk_b64.end()) == pk ); + REQUIRE( oxenmq::to_base64(pk) == pk_b64 ); + REQUIRE( oxenmq::to_base64(pk.begin(), pk.end()) == pk_b64 ); + REQUIRE( oxenmq::from_base64(pk_b64) == pk ); + REQUIRE( oxenmq::from_base64(pk_b64.begin(), pk_b64.end()) == pk ); std::string pk_b64_again, pk_again; - lokimq::to_base64(pk.begin(), pk.end(), std::back_inserter(pk_b64_again)); - lokimq::from_base64(pk_b64.begin(), pk_b64.end(), std::back_inserter(pk_again)); + oxenmq::to_base64(pk.begin(), pk.end(), std::back_inserter(pk_b64_again)); + oxenmq::from_base64(pk_b64.begin(), pk_b64.end(), std::back_inserter(pk_again)); REQUIRE( pk_b64_again == pk_b64 ); REQUIRE( pk_again == pk ); std::vector bytes{{std::byte{0}, std::byte{255}}}; std::basic_string_view b{bytes.data(), bytes.size()}; - REQUIRE( lokimq::to_base64(b) == "AP8=" ); + REQUIRE( oxenmq::to_base64(b) == "AP8=" ); bytes.resize(4); bytes[0] = std::byte{'/'}; bytes[1] = std::byte{'w'}; bytes[2] = std::byte{'A'}; bytes[3] = std::byte{'='}; std::basic_string_view b64_bytes{bytes.data(), bytes.size()}; - REQUIRE( lokimq::is_base64(b64_bytes) ); - REQUIRE( lokimq::from_base64(b64_bytes) == "\xff\x00"sv ); + REQUIRE( oxenmq::is_base64(b64_bytes) ); + REQUIRE( oxenmq::from_base64(b64_bytes) == "\xff\x00"sv ); +} + +TEST_CASE("std::byte decoding", "[decoding][hex][base32z][base64]") { + // Decoding to std::byte is a little trickier because you can't assign to a byte without an + // explicit cast, which means we have to properly detect that output is going to a std::byte + // output. + + // hex + auto b_in = "ff42"s; + std::vector b_out; + oxenmq::from_hex(b_in.begin(), b_in.end(), std::back_inserter(b_out)); + REQUIRE( b_out == std::vector{std::byte{0xff}, std::byte{0x42}} ); + b_out.emplace_back(); + oxenmq::from_hex(b_in.begin(), b_in.end(), b_out.begin() + 1); + REQUIRE( b_out == std::vector{std::byte{0xff}, std::byte{0xff}, std::byte{0x42}} ); + oxenmq::from_hex(b_in.begin(), b_in.end(), b_out.data()); + REQUIRE( b_out == std::vector{std::byte{0xff}, std::byte{0x42}, std::byte{0x42}} ); + + // base32z + b_in = "yojky"s; + b_out.clear(); + oxenmq::from_base32z(b_in.begin(), b_in.end(), std::back_inserter(b_out)); + REQUIRE( b_out == std::vector{std::byte{0x04}, std::byte{0x12}, std::byte{0xa0}} ); + b_out.emplace_back(); + oxenmq::from_base32z(b_in.begin(), b_in.end(), b_out.begin() + 1); + REQUIRE( b_out == std::vector{std::byte{0x04}, std::byte{0x04}, std::byte{0x12}, std::byte{0xa0}} ); + oxenmq::from_base32z(b_in.begin(), b_in.end(), b_out.data()); + REQUIRE( b_out == std::vector{std::byte{0x04}, std::byte{0x12}, std::byte{0xa0}, std::byte{0xa0}} ); + + // base64 + b_in = "yojk"s; + b_out.clear(); + oxenmq::from_base64(b_in.begin(), b_in.end(), std::back_inserter(b_out)); + REQUIRE( b_out == std::vector{std::byte{0xca}, std::byte{0x88}, std::byte{0xe4}} ); + b_out.emplace_back(); + oxenmq::from_base64(b_in.begin(), b_in.end(), b_out.begin() + 1); + REQUIRE( b_out == std::vector{std::byte{0xca}, std::byte{0xca}, std::byte{0x88}, std::byte{0xe4}} ); + oxenmq::from_base64(b_in.begin(), b_in.end(), b_out.data()); + REQUIRE( b_out == std::vector{std::byte{0xca}, std::byte{0x88}, std::byte{0xe4}, std::byte{0xe4}} ); } diff --git a/tests/test_failures.cpp b/tests/test_failures.cpp index 1701b28..8861ca6 100644 --- a/tests/test_failures.cpp +++ b/tests/test_failures.cpp @@ -1,13 +1,13 @@ #include "common.h" -#include +#include #include #include -using namespace lokimq; +using namespace oxenmq; TEST_CASE("failure responses - UNKNOWNCOMMAND", "[failure][UNKNOWNCOMMAND]") { std::string listen = random_localhost(); - LokiMQ server{ + OxenMQ server{ "", "", // generate ephemeral keys false, // not a service node [](auto) { return ""; }, @@ -48,7 +48,7 @@ TEST_CASE("failure responses - UNKNOWNCOMMAND", "[failure][UNKNOWNCOMMAND]") { TEST_CASE("failure responses - NO_REPLY_TAG", "[failure][NO_REPLY_TAG]") { std::string listen = random_localhost(); - LokiMQ server{ + OxenMQ server{ "", "", // generate ephemeral keys false, // not a service node [](auto) { return ""; }, @@ -109,7 +109,7 @@ TEST_CASE("failure responses - NO_REPLY_TAG", "[failure][NO_REPLY_TAG]") { TEST_CASE("failure responses - FORBIDDEN", "[failure][FORBIDDEN]") { std::string listen = random_localhost(); - LokiMQ server{ + OxenMQ server{ "", "", // generate ephemeral keys false, // not a service node [](auto) { return ""; }, @@ -192,7 +192,7 @@ TEST_CASE("failure responses - FORBIDDEN", "[failure][FORBIDDEN]") { TEST_CASE("failure responses - NOT_A_SERVICE_NODE", "[failure][NOT_A_SERVICE_NODE]") { std::string listen = random_localhost(); - LokiMQ server{ + OxenMQ server{ "", "", // generate ephemeral keys false, // not a service node [](auto) { return ""; }, @@ -259,7 +259,7 @@ TEST_CASE("failure responses - NOT_A_SERVICE_NODE", "[failure][NOT_A_SERVICE_NOD TEST_CASE("failure responses - FORBIDDEN_SN", "[failure][FORBIDDEN_SN]") { std::string listen = random_localhost(); - LokiMQ server{ + OxenMQ server{ "", "", // generate ephemeral keys false, // not a service node [](auto) { return ""; }, diff --git a/tests/test_inject.cpp b/tests/test_inject.cpp index cffd20a..caa2956 100644 --- a/tests/test_inject.cpp +++ b/tests/test_inject.cpp @@ -1,10 +1,10 @@ #include "common.h" -using namespace lokimq; +using namespace oxenmq; TEST_CASE("injected external commands", "[injected]") { std::string listen = random_localhost(); - LokiMQ server{ + OxenMQ server{ "", "", // generate ephemeral keys false, // not a service node [](auto) { return ""; }, @@ -24,17 +24,26 @@ TEST_CASE("injected external commands", "[injected]") { server.start(); - LokiMQ client{get_logger("C» "), LogLevel::trace}; + OxenMQ client{get_logger("C» "), LogLevel::trace}; client.start(); std::atomic got{false}; bool success = false; +// Deliberately using a deprecated command here, disable -Wdeprecated-declarations +#ifdef __GNUG__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" +#endif auto c = client.connect_remote(listen, [&](auto conn) { success = true; got = true; }, [&](auto conn, std::string_view) { got = true; }, server.get_pubkey()); +#ifdef __GNUG__ +#pragma GCC diagnostic pop +#endif + wait_for_conn(got); { auto lock = catch_lock(); diff --git a/tests/test_requests.cpp b/tests/test_requests.cpp index af31e9e..06aa28c 100644 --- a/tests/test_requests.cpp +++ b/tests/test_requests.cpp @@ -1,11 +1,11 @@ #include "common.h" -#include +#include -using namespace lokimq; +using namespace oxenmq; TEST_CASE("basic requests", "[requests]") { std::string listen = random_localhost(); - LokiMQ server{ + OxenMQ server{ "", "", // generate ephemeral keys false, // not a service node [](auto) { return ""; }, @@ -20,7 +20,7 @@ TEST_CASE("basic requests", "[requests]") { }); server.start(); - LokiMQ client( + OxenMQ client( [](LogLevel, const char* file, int line, std::string msg) { std::cerr << file << ":" << line << " --C-- " << msg << "\n"; } ); //client.log_level(LogLevel::trace); @@ -62,7 +62,7 @@ TEST_CASE("basic requests", "[requests]") { TEST_CASE("request from server to client", "[requests]") { std::string listen = random_localhost(); - LokiMQ server{ + OxenMQ server{ "", "", // generate ephemeral keys false, // not a service node [](auto) { return ""; }, @@ -77,7 +77,7 @@ TEST_CASE("request from server to client", "[requests]") { }); server.start(); - LokiMQ client( + OxenMQ client( [](LogLevel, const char* file, int line, std::string msg) { std::cerr << file << ":" << line << " --C-- " << msg << "\n"; } ); //client.log_level(LogLevel::trace); @@ -125,7 +125,7 @@ TEST_CASE("request from server to client", "[requests]") { TEST_CASE("request timeouts", "[requests][timeout]") { std::string listen = random_localhost(); - LokiMQ server{ + OxenMQ server{ "", "", // generate ephemeral keys false, // not a service node [](auto) { return ""; }, @@ -138,7 +138,7 @@ TEST_CASE("request timeouts", "[requests][timeout]") { server.add_request_command("public", "blackhole", [&](Message& m) { /* doesn't reply */ }); server.start(); - LokiMQ client( + OxenMQ client( [](LogLevel, const char* file, int line, std::string msg) { std::cerr << file << ":" << line << " --C-- " << msg << "\n"; } ); //client.log_level(LogLevel::trace); @@ -167,7 +167,7 @@ TEST_CASE("request timeouts", "[requests][timeout]") { success = ok; data = std::move(data_); }, - lokimq::send_option::request_timeout{10ms} + oxenmq::send_option::request_timeout{10ms} ); std::atomic got_triggered2{false}; @@ -176,7 +176,7 @@ TEST_CASE("request timeouts", "[requests][timeout]") { success = ok; data = std::move(data_); }, - lokimq::send_option::request_timeout{200ms} + oxenmq::send_option::request_timeout{200ms} ); std::this_thread::sleep_for(100ms); diff --git a/tests/test_tagged_threads.cpp b/tests/test_tagged_threads.cpp index 914a58c..7ad0196 100644 --- a/tests/test_tagged_threads.cpp +++ b/tests/test_tagged_threads.cpp @@ -1,9 +1,9 @@ -#include "lokimq/batch.h" +#include "oxenmq/batch.h" #include "common.h" #include TEST_CASE("tagged thread start functions", "[tagged][start]") { - lokimq::LokiMQ lmq{get_logger(""), LogLevel::trace}; + oxenmq::OxenMQ lmq{get_logger(""), LogLevel::trace}; lmq.set_general_threads(2); lmq.set_batch_threads(2); @@ -26,13 +26,13 @@ TEST_CASE("tagged thread start functions", "[tagged][start]") { } TEST_CASE("tagged threads quit-before-start", "[tagged][quit]") { - auto lmq = std::make_unique(get_logger(""), LogLevel::trace); + auto lmq = std::make_unique(get_logger(""), LogLevel::trace); auto t_abc = lmq->add_tagged_thread("abc"); REQUIRE_NOTHROW(lmq.reset()); } TEST_CASE("batch jobs to tagged threads", "[tagged][batch]") { - lokimq::LokiMQ lmq{get_logger(""), LogLevel::trace}; + oxenmq::OxenMQ lmq{get_logger(""), LogLevel::trace}; lmq.set_general_threads(2); lmq.set_batch_threads(2); @@ -111,7 +111,7 @@ TEST_CASE("batch jobs to tagged threads", "[tagged][batch]") { } TEST_CASE("batch job completion on tagged threads", "[tagged][batch-completion]") { - lokimq::LokiMQ lmq{get_logger(""), LogLevel::trace}; + oxenmq::OxenMQ lmq{get_logger(""), LogLevel::trace}; lmq.set_general_threads(4); lmq.set_batch_threads(4); @@ -119,7 +119,7 @@ TEST_CASE("batch job completion on tagged threads", "[tagged][batch-completion]" auto t_abc = lmq.add_tagged_thread("abc", [&] { id_abc = std::this_thread::get_id(); }); lmq.start(); - lokimq::Batch batch; + oxenmq::Batch batch; for (int i = 1; i < 10; i++) batch.add_job([i, &id_abc]() { if (std::this_thread::get_id() == id_abc) return 0; return i; }); @@ -140,7 +140,7 @@ TEST_CASE("batch job completion on tagged threads", "[tagged][batch-completion]" TEST_CASE("timer job completion on tagged threads", "[tagged][timer]") { - lokimq::LokiMQ lmq{get_logger(""), LogLevel::trace}; + oxenmq::OxenMQ lmq{get_logger(""), LogLevel::trace}; lmq.set_general_threads(4); lmq.set_batch_threads(4);