mirror of https://github.com/oxen-io/oxen-mq.git
Merge remote-tracking branch 'origin/dev' into master
This commit is contained in:
commit
53481cdfa9
|
@ -0,0 +1,63 @@
|
|||
local debian_pipeline(name, image, arch='amd64', deps='g++ libsodium-dev libzmq3-dev', cmake_extra='', build_type='Release', extra_cmds=[], allow_fail=false) = {
|
||||
kind: 'pipeline',
|
||||
type: 'docker',
|
||||
name: name,
|
||||
platform: { arch: arch },
|
||||
environment: { CLICOLOR_FORCE: '1' }, // Lets color through ninja (1.9+)
|
||||
steps: [
|
||||
{
|
||||
name: 'build',
|
||||
image: image,
|
||||
[if allow_fail then "failure"]: "ignore",
|
||||
commands: [
|
||||
'apt-get update',
|
||||
'apt-get install -y eatmydata',
|
||||
'eatmydata apt-get dist-upgrade -y',
|
||||
'eatmydata apt-get install -y cmake git ninja-build pkg-config ccache ' + deps,
|
||||
'git submodule update --init --recursive',
|
||||
'mkdir build',
|
||||
'cd build',
|
||||
'cmake .. -G Ninja -DCMAKE_CXX_FLAGS=-fdiagnostics-color=always -DCMAKE_BUILD_TYPE='+build_type+' -DCMAKE_CXX_COMPILER_LAUNCHER=ccache ' + cmake_extra,
|
||||
'ninja -v',
|
||||
'./tests/tests --use-colour yes'
|
||||
] + extra_cmds,
|
||||
}
|
||||
]
|
||||
};
|
||||
|
||||
[
|
||||
debian_pipeline("Ubuntu focal (amd64)", "ubuntu:focal"),
|
||||
debian_pipeline("Ubuntu bionic (amd64)", "ubuntu:bionic", deps='libsodium-dev g++-8',
|
||||
cmake_extra='-DCMAKE_C_COMPILER=gcc-8 -DCMAKE_CXX_COMPILER=g++-8'),
|
||||
debian_pipeline("Debian sid (amd64)", "debian:sid"),
|
||||
debian_pipeline("Debian sid/Debug (amd64)", "debian:sid", build_type='Debug'),
|
||||
debian_pipeline("Debian sid/clang-10 (amd64)", "debian:sid", deps='clang-10 lld-10 libsodium-dev libzmq3-dev',
|
||||
cmake_extra='-DCMAKE_C_COMPILER=clang-10 -DCMAKE_CXX_COMPILER=clang++-10 ' + std.join(' ', [
|
||||
'-DCMAKE_'+type+'_LINKER_FLAGS=-fuse-ld=lld-10' for type in ['EXE','MODULE','SHARED','STATIC']])),
|
||||
debian_pipeline("Debian buster (amd64)", "debian:buster"),
|
||||
debian_pipeline("Debian buster (i386)", "i386/debian:buster"),
|
||||
debian_pipeline("Ubuntu bionic (ARM64)", "ubuntu:bionic", arch="arm64", deps='libsodium-dev g++-8',
|
||||
cmake_extra='-DCMAKE_C_COMPILER=gcc-8 -DCMAKE_CXX_COMPILER=g++-8'),
|
||||
debian_pipeline("Debian sid (ARM64)", "debian:sid", arch="arm64"),
|
||||
debian_pipeline("Debian buster (armhf)", "arm32v7/debian:buster", arch="arm64"),
|
||||
{
|
||||
kind: 'pipeline',
|
||||
type: 'exec',
|
||||
name: 'macOS (Catalina w/macports)',
|
||||
platform: { os: 'darwin', arch: 'amd64' },
|
||||
environment: { CLICOLOR_FORCE: '1' }, // Lets color through ninja (1.9+)
|
||||
steps: [
|
||||
{
|
||||
name: 'build',
|
||||
commands: [
|
||||
'git submodule update --init --recursive',
|
||||
'mkdir build',
|
||||
'cd build',
|
||||
'cmake .. -G Ninja -DCMAKE_CXX_FLAGS=-fcolor-diagnostics -DCMAKE_BUILD_TYPE=Release -DCMAKE_CXX_COMPILER_LAUNCHER=ccache',
|
||||
'ninja -v',
|
||||
'./tests/tests --use-colour yes'
|
||||
],
|
||||
}
|
||||
]
|
||||
},
|
||||
]
|
|
@ -1,6 +1,3 @@
|
|||
[submodule "mapbox-variant"]
|
||||
path = mapbox-variant
|
||||
url = https://github.com/mapbox/variant.git
|
||||
[submodule "cppzmq"]
|
||||
path = cppzmq
|
||||
url = https://github.com/zeromq/cppzmq.git
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
cmake_minimum_required(VERSION 3.7)
|
||||
|
||||
project(liblokimq CXX)
|
||||
project(liblokimq CXX C)
|
||||
|
||||
include(GNUInstallDirs)
|
||||
|
||||
set(LOKIMQ_VERSION_MAJOR 1)
|
||||
set(LOKIMQ_VERSION_MINOR 1)
|
||||
set(LOKIMQ_VERSION_PATCH 4)
|
||||
set(LOKIMQ_VERSION_MINOR 2)
|
||||
set(LOKIMQ_VERSION_PATCH 0)
|
||||
set(LOKIMQ_VERSION "${LOKIMQ_VERSION_MAJOR}.${LOKIMQ_VERSION_MINOR}.${LOKIMQ_VERSION_PATCH}")
|
||||
message(STATUS "lokimq v${LOKIMQ_VERSION}")
|
||||
|
||||
|
@ -22,6 +22,7 @@ configure_file(lokimq/version.h.in lokimq/version.h @ONLY)
|
|||
configure_file(liblokimq.pc.in liblokimq.pc @ONLY)
|
||||
|
||||
add_library(lokimq
|
||||
lokimq/address.cpp
|
||||
lokimq/auth.cpp
|
||||
lokimq/bt_serialize.cpp
|
||||
lokimq/connections.cpp
|
||||
|
@ -46,17 +47,18 @@ if(TARGET libzmq)
|
|||
elseif(BUILD_SHARED_LIBS)
|
||||
include(FindPkgConfig)
|
||||
pkg_check_modules(libzmq libzmq>=4.3 IMPORTED_TARGET)
|
||||
# Debian sid includes a -isystem in the mit-krb package that, starting with pkg-config 0.29.2,
|
||||
# breaks cmake's pkgconfig module because it stupidly thinks "-isystem" is a path, so if we find
|
||||
# -isystem in the include dirs then hack it out.
|
||||
get_property(zmq_inc TARGET PkgConfig::libzmq PROPERTY INTERFACE_INCLUDE_DIRECTORIES)
|
||||
list(FIND zmq_inc "-isystem" broken_isystem)
|
||||
if(NOT broken_isystem EQUAL -1)
|
||||
list(REMOVE_AT zmq_inc ${broken_isystem})
|
||||
set_property(TARGET PkgConfig::libzmq PROPERTY INTERFACE_INCLUDE_DIRECTORIES ${zmq_inc})
|
||||
endif()
|
||||
|
||||
if(libzmq_FOUND)
|
||||
# Debian sid includes a -isystem in the mit-krb package that, starting with pkg-config 0.29.2,
|
||||
# breaks cmake's pkgconfig module because it stupidly thinks "-isystem" is a path, so if we find
|
||||
# -isystem in the include dirs then hack it out.
|
||||
get_property(zmq_inc TARGET PkgConfig::libzmq PROPERTY INTERFACE_INCLUDE_DIRECTORIES)
|
||||
list(FIND zmq_inc "-isystem" broken_isystem)
|
||||
if(NOT broken_isystem EQUAL -1)
|
||||
list(REMOVE_AT zmq_inc ${broken_isystem})
|
||||
set_property(TARGET PkgConfig::libzmq PROPERTY INTERFACE_INCLUDE_DIRECTORIES ${zmq_inc})
|
||||
endif()
|
||||
|
||||
target_link_libraries(lokimq PUBLIC PkgConfig::libzmq)
|
||||
else()
|
||||
set(lokimq_build_static_libzmq ON)
|
||||
|
@ -66,7 +68,7 @@ else()
|
|||
endif()
|
||||
|
||||
if(lokimq_build_static_libzmq)
|
||||
message(STATUS "libzmq >= 4.3 not found or static build requested, building bundled 4.3.2")
|
||||
message(STATUS "libzmq >= 4.3 not found or static build requested, building bundled version")
|
||||
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/local-libzmq")
|
||||
include(LocalLibzmq)
|
||||
target_link_libraries(lokimq PUBLIC libzmq_vendor)
|
||||
|
@ -77,12 +79,11 @@ target_include_directories(lokimq
|
|||
$<INSTALL_INTERFACE:>
|
||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}>
|
||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/cppzmq>
|
||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/mapbox-variant/include>
|
||||
)
|
||||
|
||||
target_compile_options(lokimq PRIVATE -Wall -Wextra -Werror)
|
||||
set_target_properties(lokimq PROPERTIES
|
||||
CXX_STANDARD 14
|
||||
CXX_STANDARD 17
|
||||
CXX_STANDARD_REQUIRED ON
|
||||
CXX_EXTENSIONS OFF
|
||||
POSITION_INDEPENDENT_CODE ON
|
||||
|
@ -101,7 +102,7 @@ endfunction()
|
|||
# If the caller has already set up a sodium target then we will just link to it, otherwise we go
|
||||
# looking for it.
|
||||
if(TARGET sodium)
|
||||
target_link_libraries(lokimq PRIVATE sodium)
|
||||
target_link_libraries(lokimq PUBLIC sodium)
|
||||
if(lokimq_build_static_libzmq)
|
||||
target_link_libraries(libzmq_vendor INTERFACE sodium)
|
||||
endif()
|
||||
|
@ -109,13 +110,13 @@ else()
|
|||
pkg_check_modules(sodium REQUIRED libsodium IMPORTED_TARGET)
|
||||
|
||||
if(BUILD_SHARED_LIBS)
|
||||
target_link_libraries(lokimq PRIVATE PkgConfig::sodium)
|
||||
target_link_libraries(lokimq PUBLIC PkgConfig::sodium)
|
||||
if(lokimq_build_static_libzmq)
|
||||
target_link_libraries(libzmq_vendor INTERFACE PkgConfig::sodium)
|
||||
endif()
|
||||
else()
|
||||
link_dep_libs(lokimq PRIVATE "${sodium_STATIC_LIBRARY_DIRS}" ${sodium_STATIC_LIBRARIES})
|
||||
target_include_directories(lokimq PRIVATE ${sodium_STATIC_INCLUDE_DIRS})
|
||||
link_dep_libs(lokimq PUBLIC "${sodium_STATIC_LIBRARY_DIRS}" ${sodium_STATIC_LIBRARIES})
|
||||
target_include_directories(lokimq PUBLIC ${sodium_STATIC_INCLUDE_DIRS})
|
||||
if(lokimq_build_static_libzmq)
|
||||
link_dep_libs(libzmq_vendor INTERFACE "${sodium_STATIC_LIBRARY_DIRS}" ${sodium_STATIC_LIBRARIES})
|
||||
target_link_libraries(libzmq_vendor INTERFACE ${sodium_STATIC_INCLUDE_DIRS})
|
||||
|
@ -137,9 +138,13 @@ install(
|
|||
)
|
||||
|
||||
install(
|
||||
FILES lokimq/auth.h
|
||||
FILES lokimq/address.h
|
||||
lokimq/auth.h
|
||||
lokimq/base32z.h
|
||||
lokimq/base64.h
|
||||
lokimq/batch.h
|
||||
lokimq/bt_serialize.h
|
||||
lokimq/bt_value.h
|
||||
lokimq/connections.h
|
||||
lokimq/hex.h
|
||||
lokimq/lokimq.h
|
||||
|
@ -148,17 +153,6 @@ install(
|
|||
${CMAKE_CURRENT_BINARY_DIR}/lokimq/version.h
|
||||
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/lokimq
|
||||
)
|
||||
option(LOKIMQ_INSTALL_MAPBOX_VARIANT "Install mapbox-variant headers with lokimq/ headers" ON)
|
||||
if(LOKIMQ_INSTALL_MAPBOX_VARIANT)
|
||||
install(
|
||||
FILES mapbox-variant/include/mapbox/variant.hpp
|
||||
mapbox-variant/include/mapbox/variant_cast.hpp
|
||||
mapbox-variant/include/mapbox/variant_io.hpp
|
||||
mapbox-variant/include/mapbox/variant_visitor.hpp
|
||||
mapbox-variant/include/mapbox/recursive_wrapper.hpp
|
||||
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/lokimq/mapbox
|
||||
)
|
||||
endif()
|
||||
|
||||
option(LOKIMQ_INSTALL_CPPZMQ "Install cppzmq header with lokimq/ headers" ON)
|
||||
if(LOKIMQ_INSTALL_CPPZMQ)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# LokiMQ - zeromq-based message passing for Loki projects
|
||||
|
||||
This C++14 library contains an abstraction layer around ZeroMQ to support integration with Loki
|
||||
This C++17 library contains an abstraction layer around ZeroMQ to support integration with Loki
|
||||
authentication, RPC, and message passing. It is designed to be usable as the underlying
|
||||
communication mechanism of SN-to-SN communication ("quorumnet"), the RPC interface used by wallets
|
||||
and local daemon commands, communication channels between lokid and auxiliary services (storage
|
||||
|
@ -123,6 +123,7 @@ The connection ID generally has two possible values:
|
|||
places to get one, such as from the `Message` object passed to a command: see the following
|
||||
section).
|
||||
|
||||
```C++
|
||||
// Send to a service node, establishing a connection if necessary:
|
||||
std::string my_sn = ...; // 32-byte pubkey of a known SN
|
||||
lmq.send(my_sn, "sn.explode", "{ \"seconds\": 30 }");
|
||||
|
@ -137,6 +138,7 @@ The connection ID generally has two possible values:
|
|||
else
|
||||
std::cout << "Timeout fetching height!";
|
||||
});
|
||||
```
|
||||
|
||||
## Command invocation
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
set(LIBZMQ_PREFIX ${CMAKE_BINARY_DIR}/libzmq)
|
||||
set(ZeroMQ_VERSION 4.3.2)
|
||||
set(ZeroMQ_VERSION 4.3.3)
|
||||
set(LIBZMQ_URL https://github.com/zeromq/libzmq/releases/download/v${ZeroMQ_VERSION}/zeromq-${ZeroMQ_VERSION}.tar.gz)
|
||||
set(LIBZMQ_HASH SHA512=b6251641e884181db9e6b0b705cced7ea4038d404bdae812ff47bdd0eed12510b6af6846b85cb96898e253ccbac71eca7fe588673300ddb9c3109c973250c8e4)
|
||||
set(LIBZMQ_HASH SHA512=4c18d784085179c5b1fcb753a93813095a12c8d34970f2e1bfca6499be6c9d67769c71c68b7ca54ff181b20390043170e89733c22f76ff1ea46494814f7095b1)
|
||||
|
||||
message(${LIBZMQ_URL})
|
||||
|
||||
|
|
|
@ -0,0 +1,351 @@
|
|||
#include "address.h"
|
||||
#include <tuple>
|
||||
#include <limits>
|
||||
#include <cstddef>
|
||||
#include <utility>
|
||||
#include <stdexcept>
|
||||
#include <ostream>
|
||||
#include "hex.h"
|
||||
#include "base32z.h"
|
||||
#include "base64.h"
|
||||
|
||||
namespace lokimq {
|
||||
|
||||
constexpr size_t enc_length(address::encoding enc) {
|
||||
return enc == address::encoding::hex ? 64 :
|
||||
enc == address::encoding::base64 ? 43 : // this can be 44 with a padding byte, but we don't need it
|
||||
52 /*base32z*/;
|
||||
};
|
||||
|
||||
// Parses an encoding pubkey from the given string_view. Advanced the string_view beyond the
|
||||
// consumed pubkey data, and returns the pubkey (as a 32-byte string). Throws if no valid pubkey
|
||||
// was found at the beginning of addr. We look for hex, base32z, or base64 pubkeys *unless* qr is
|
||||
// given: for QR-friendly we only accept hex or base32z (since QR cannot handle base64's alphabet).
|
||||
std::string decode_pubkey(std::string_view& in, bool qr) {
|
||||
std::string pubkey;
|
||||
if (in.size() >= 64 && lokimq::is_hex(in.substr(0, 64))) {
|
||||
pubkey = from_hex(in.substr(0, 64));
|
||||
in.remove_prefix(64);
|
||||
} else if (in.size() >= 52 && lokimq::is_base32z(in.substr(0, 52))) {
|
||||
pubkey = from_base32z(in.substr(0, 52));
|
||||
in.remove_prefix(52);
|
||||
} else if (!qr && in.size() >= 43 && lokimq::is_base64(in.substr(0, 43))) {
|
||||
pubkey = from_base64(in.substr(0, 43));
|
||||
in.remove_prefix(43);
|
||||
if (!in.empty() && in.front() == '=')
|
||||
in.remove_prefix(1); // allow (and eat) a padding byte at the end
|
||||
} else {
|
||||
throw std::invalid_argument{"No pubkey found"};
|
||||
}
|
||||
return pubkey;
|
||||
}
|
||||
|
||||
// Parse the host, port, and optionally pubkey from a string view, mutating it to remove the parsed
|
||||
// sections. qr should be true if we should accept $IPv6$ as a QR-encoding-friendly alternative to
|
||||
// [IPv6] (the returned host will have the $ replaced, i.e. [IPv6]).
|
||||
std::tuple<std::string, uint16_t, std::string> parse_tcp(std::string_view& addr, bool qr, bool expect_pubkey) {
|
||||
std::tuple<std::string, uint16_t, std::string> result;
|
||||
auto& host = std::get<0>(result);
|
||||
if (addr.front() == '[' || (qr && addr.front() == '$')) { // IPv6 addr (though this is far from complete validation)
|
||||
auto pos = addr.find_first_not_of(":.1234567890abcdefABCDEF", 1);
|
||||
if (pos == std::string_view::npos)
|
||||
throw std::invalid_argument("Could not find terminating ] while parsing an IPv6 address");
|
||||
if (!(addr[pos] == ']' || (qr && addr[pos] == '$')))
|
||||
throw std::invalid_argument{"Expected " + (qr ? "$"s : "]"s) + " to close IPv6 address but found " + std::string(1, addr[pos])};
|
||||
host = std::string{addr.substr(0, pos+1)};
|
||||
if (qr) {
|
||||
if (host.front() == '$')
|
||||
host.front() = '[';
|
||||
if (host.back() == '$')
|
||||
host.back() = ']';
|
||||
}
|
||||
addr.remove_prefix(pos+1);
|
||||
} else {
|
||||
auto port_pos = addr.find(':');
|
||||
if (port_pos == std::string_view::npos)
|
||||
throw std::invalid_argument{"Could not determine host (no following ':port' found)"};
|
||||
if (port_pos == 0)
|
||||
throw std::invalid_argument{"Host cannot be empty"};
|
||||
host = std::string{addr.substr(0, port_pos)};
|
||||
addr.remove_prefix(port_pos);
|
||||
}
|
||||
|
||||
if (qr)
|
||||
// Lower-case the host because upper case hostnames are ugly
|
||||
for (char& c : host)
|
||||
if (c >= 'A' && c <= 'Z')
|
||||
c = c - 'A' + 'a';
|
||||
|
||||
if (addr.size() < 2 || addr[0] != ':')
|
||||
throw std::invalid_argument{"Could not find :port in address string"};
|
||||
addr.remove_prefix(1);
|
||||
auto pos = addr.find_first_not_of("1234567890");
|
||||
if (pos == 0)
|
||||
throw std::invalid_argument{"Could not find numeric port in address string"};
|
||||
if (pos == std::string_view::npos)
|
||||
pos = addr.size();
|
||||
size_t processed;
|
||||
int port_int = std::stoi(std::string{addr.substr(0, pos)}, &processed);
|
||||
if (port_int == 0 || processed != pos)
|
||||
throw std::invalid_argument{"Could not parse numeric port in address string"};
|
||||
if (port_int < 0 || port_int > std::numeric_limits<uint16_t>::max())
|
||||
throw std::invalid_argument{"Invalid port: port must be in range 1-65535"};
|
||||
std::get<1>(result) = static_cast<uint16_t>(port_int);
|
||||
addr.remove_prefix(pos);
|
||||
|
||||
if (expect_pubkey) {
|
||||
if (addr.size() < 1 + enc_length(qr ? address::encoding::base32z : address::encoding::base64)
|
||||
|| addr.front() != '/')
|
||||
throw std::invalid_argument{"Invalid address: expected /PUBKEY after port"};
|
||||
addr.remove_prefix(1);
|
||||
|
||||
std::get<2>(result) = decode_pubkey(addr, qr);
|
||||
if (!addr.empty())
|
||||
throw std::invalid_argument{"Invalid address: found unexpected trailing data after pubkey"};
|
||||
} else if (!addr.empty()) {
|
||||
throw std::invalid_argument{"Invalid address: found unexpected trailing data after port"};
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// Parse the socket path and (possibly) pubkey, mutating it to remove the parsed sections.
|
||||
// Currently the /pubkey *must* be at the end of the string, but this might not always be the case
|
||||
// (e.g. we could in the future support query string-like arguments).
|
||||
std::pair<std::string, std::string> parse_unix(std::string_view& addr, bool expect_pubkey) {
|
||||
std::pair<std::string, std::string> result;
|
||||
if (expect_pubkey) {
|
||||
size_t b64_len = addr.size() > 0 && addr.back() == '=' ? 44 : 43;
|
||||
if (addr.size() > 64 && addr[addr.size() - 65] == '/' && is_hex(addr.substr(addr.size() - 64))) {
|
||||
result.first = std::string{addr.substr(0, addr.size() - 65)};
|
||||
result.second = from_hex(addr.substr(addr.size() - 64));
|
||||
} else if (addr.size() > 52 && addr[addr.size() - 53] == '/' && is_base32z(addr.substr(addr.size() - 52))) {
|
||||
result.first = std::string{addr.substr(0, addr.size() - 53)};
|
||||
result.second = from_base32z(addr.substr(addr.size() - 52));
|
||||
} else if (addr.size() > b64_len && addr[addr.size() - b64_len - 1] == '/' && is_base64(addr.substr(addr.size() - b64_len))) {
|
||||
result.first = std::string{addr.substr(0, addr.size() - b64_len - 1)};
|
||||
result.second = from_base64(addr.substr(addr.size() - b64_len));
|
||||
} else {
|
||||
throw std::invalid_argument{"icp+curve:// requires a trailing /PUBKEY value, got: " + std::string{addr}};
|
||||
}
|
||||
} else {
|
||||
// Anything goes
|
||||
result.first = std::string{addr};
|
||||
}
|
||||
|
||||
// Any path above consumes everything:
|
||||
addr.remove_prefix(addr.size());
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
address::address(std::string_view addr) {
|
||||
auto protoend = addr.find("://"sv);
|
||||
if (protoend == std::string_view::npos || protoend == 0)
|
||||
throw std::invalid_argument("Invalid address: no protocol found");
|
||||
auto pro = addr.substr(0, protoend);
|
||||
addr.remove_prefix(protoend + 3);
|
||||
if (addr.empty())
|
||||
throw std::invalid_argument("Invalid address: no value specified after protocol");
|
||||
bool qr = false;
|
||||
if (pro == "tcp") protocol = proto::tcp;
|
||||
else if (pro == "tcp+curve" || pro == "curve") protocol = proto::tcp_curve;
|
||||
else if (pro == "ipc") protocol = proto::ipc;
|
||||
else if (pro == "ipc+curve") protocol = proto::ipc_curve;
|
||||
else if (pro == "TCP") {
|
||||
protocol = proto::tcp;
|
||||
qr = true;
|
||||
} else if (pro == "CURVE") {
|
||||
protocol = proto::tcp_curve;
|
||||
qr = true;
|
||||
} else {
|
||||
throw std::invalid_argument("Invalid protocol '" + std::string{pro} + "'");
|
||||
}
|
||||
|
||||
if (qr) {
|
||||
// The QR variations only allow QR-alphanumeric characters (upper-case letters, numbers, and
|
||||
// a few symbols):
|
||||
for (char c : addr) {
|
||||
// QR alphanumeric also allows space, %, *, +, but we don't need or allow any of those here.
|
||||
if (!((c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '$' || c == ':' || c == '/' || c == '.' || c == '-'))
|
||||
throw std::invalid_argument("Found non-QR-alphanumeric value in QR TCP:// or CURVE:// address");
|
||||
}
|
||||
}
|
||||
|
||||
if (tcp())
|
||||
std::tie(host, port, pubkey) = parse_tcp(addr, qr, curve());
|
||||
else
|
||||
std::tie(socket, pubkey) = parse_unix(addr, curve());
|
||||
|
||||
if (!addr.empty())
|
||||
throw std::invalid_argument{"Invalid trailing garbage '" + std::string{addr} + "' in address"};
|
||||
}
|
||||
|
||||
address& address::set_pubkey(std::string_view pk) {
|
||||
if (pk.size() == 0) {
|
||||
if (protocol == proto::tcp_curve) protocol = proto::tcp;
|
||||
else if (protocol == proto::ipc_curve) protocol = proto::ipc;
|
||||
} else if (pk.size() == 32) {
|
||||
if (protocol == proto::tcp) protocol = proto::tcp_curve;
|
||||
else if (protocol == proto::ipc) protocol = proto::ipc_curve;
|
||||
} else {
|
||||
throw std::invalid_argument{"Invalid pubkey passed to set_pubkey(): require 0- or 32-byte pubkey"};
|
||||
}
|
||||
pubkey = pk;
|
||||
return *this;
|
||||
}
|
||||
|
||||
std::string address::encode_pubkey(encoding enc) const {
|
||||
std::string pk;
|
||||
if (enc == encoding::hex)
|
||||
pk = to_hex(pubkey);
|
||||
else if (enc == encoding::base32z)
|
||||
pk = to_base32z(pubkey);
|
||||
else if (enc == encoding::BASE32Z) {
|
||||
pk = to_base32z(pubkey);
|
||||
for (char& c : pk)
|
||||
if (c >= 'a' && c <= 'z')
|
||||
c = c - 'a' + 'A';
|
||||
} else if (enc == encoding::base64) {
|
||||
pk = to_base64(pubkey);
|
||||
if (pk.size() == 44 && pk.back() == '=')
|
||||
pk.resize(43);
|
||||
} else {
|
||||
throw std::logic_error{"Invalid encoding"};
|
||||
}
|
||||
return pk;
|
||||
}
|
||||
|
||||
std::string address::full_address(encoding enc) const {
|
||||
std::string result;
|
||||
std::string pk;
|
||||
if (curve())
|
||||
pk = encode_pubkey(enc);
|
||||
|
||||
if (protocol == proto::tcp) {
|
||||
result.reserve(6 /*tcp:// */ + host.size() + 6 /*:port*/);
|
||||
result += "tcp://";
|
||||
result += host;
|
||||
result += ':';
|
||||
result += std::to_string(port);
|
||||
} else if (protocol == proto::tcp_curve) {
|
||||
result.reserve(8 /*curve:// */ + host.size() + 6 /*:port*/ + 1 /* / */ + pk.size());
|
||||
result += "curve://";
|
||||
result += host;
|
||||
result += ':';
|
||||
result += std::to_string(port);
|
||||
result += '/';
|
||||
result += pk;
|
||||
} else if (protocol == proto::ipc) {
|
||||
result.reserve(6 /*ipc:// */ + socket.size());
|
||||
result += "ipc://";
|
||||
result += socket;
|
||||
} else if (protocol == proto::ipc_curve) {
|
||||
result.reserve(12 /*ipc+curve:// */ + socket.size() + 1 /* / */ + pk.size());
|
||||
result += "ipc+curve://";
|
||||
result += socket;
|
||||
result += '/';
|
||||
result += pk;
|
||||
} else {
|
||||
throw std::logic_error{"Invalid protocol"};
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string address::zmq_address() const {
|
||||
std::string result;
|
||||
if (tcp()) {
|
||||
result.reserve(6 /*tcp:// */ + host.size() + 6 /*:port*/);
|
||||
result += "tcp://";
|
||||
result += host;
|
||||
result += ':';
|
||||
result += std::to_string(port);
|
||||
} else {
|
||||
result.reserve(6 /*ipc:// */ + socket.size());
|
||||
result += "ipc://";
|
||||
result += socket;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string address::qr_address() const {
|
||||
if (protocol != proto::tcp && protocol != proto::tcp_curve)
|
||||
throw std::logic_error("Cannot construct a QR-friendly address for a non-TCP address");
|
||||
if (host.empty())
|
||||
throw std::logic_error("Cannot construct a QR-friendly address with an empty TCP host");
|
||||
std::string result;
|
||||
result.reserve((curve() ? 8 /*CURVE:// */ : 6 /*TCP:// */) + host.size() + 6 /*:port*/ +
|
||||
(curve() ? 1 + enc_length(encoding::BASE32Z) : 0));
|
||||
result += curve() ? "CURVE://" : "TCP://";
|
||||
std::string uc_host = host;
|
||||
for (auto& c : uc_host)
|
||||
if (c >= 'a' && c <= 'z')
|
||||
c = c - 'a' + 'A';
|
||||
|
||||
if (uc_host.front() == '[' && uc_host.back() == ']') {
|
||||
uc_host.front() = '$';
|
||||
uc_host.back() = '$';
|
||||
}
|
||||
result += uc_host;
|
||||
result += ':';
|
||||
result += std::to_string(port);
|
||||
|
||||
if (curve()) {
|
||||
result += '/';
|
||||
result += encode_pubkey(encoding::BASE32Z);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
bool address::operator==(const address& other) const {
|
||||
if (protocol != other.protocol)
|
||||
return false;
|
||||
if (tcp())
|
||||
if (host != other.host || port != other.port)
|
||||
return false;
|
||||
if (ipc())
|
||||
if (socket != other.socket)
|
||||
return false;
|
||||
if (curve())
|
||||
if (pubkey != other.pubkey)
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
address address::tcp(std::string host, uint16_t port) {
|
||||
address a;
|
||||
a.protocol = proto::tcp;
|
||||
a.host = std::move(host);
|
||||
a.port = port;
|
||||
return a;
|
||||
}
|
||||
|
||||
address address::tcp_curve(std::string host, uint16_t port, std::string pubkey) {
|
||||
address a;
|
||||
a.protocol = proto::tcp_curve;
|
||||
a.host = std::move(host);
|
||||
a.port = port;
|
||||
a.pubkey = std::move(pubkey);
|
||||
return a;
|
||||
}
|
||||
|
||||
address address::ipc(std::string path) {
|
||||
address a;
|
||||
a.protocol = proto::ipc;
|
||||
a.socket = std::move(path);
|
||||
return a;
|
||||
}
|
||||
|
||||
address address::ipc_curve(std::string path, std::string pubkey) {
|
||||
address a;
|
||||
a.protocol = proto::ipc_curve;
|
||||
a.socket = std::move(path);
|
||||
a.pubkey = std::move(pubkey);
|
||||
return a;
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& o, const address& a) { return o << a.full_address(); }
|
||||
|
||||
}
|
|
@ -0,0 +1,210 @@
|
|||
// Copyright (c) 2020, The Loki Project
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without modification, are
|
||||
// permitted provided that the following conditions are met:
|
||||
//
|
||||
// 1. Redistributions of source code must retain the above copyright notice, this list of
|
||||
// conditions and the following disclaimer.
|
||||
//
|
||||
// 2. Redistributions in binary form must reproduce the above copyright notice, this list
|
||||
// of conditions and the following disclaimer in the documentation and/or other
|
||||
// materials provided with the distribution.
|
||||
//
|
||||
// 3. Neither the name of the copyright holder nor the names of its contributors may be
|
||||
// used to endorse or promote products derived from this software without specific
|
||||
// prior written permission.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
|
||||
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
|
||||
// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL
|
||||
// THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF
|
||||
// THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
#pragma once
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <cstdint>
|
||||
#include <iosfwd>
|
||||
|
||||
namespace lokimq {
|
||||
|
||||
using namespace std::literals;
|
||||
|
||||
/** LokiMQ address abstraction class. This class uses and extends standard ZMQ addresses allowing
|
||||
* extra parameters to be passed in in a relative standard way.
|
||||
*
|
||||
* External ZMQ addresses generally have two forms that we are concerned with: one for TCP and one
|
||||
* for Unix sockets:
|
||||
*
|
||||
* tcp://HOST:PORT -- HOST can be a hostname, IPv4 address, or IPv6 address in [...]
|
||||
* ipc://PATH -- PATH can be absolute (ipc:///path/to/some.sock) or relative (ipc://some.sock)
|
||||
*
|
||||
* but this doesn't carry enough info: in particular, we can connect with two very different
|
||||
* protocols: curve25519-encrypted, or plaintext, but for curve25519-encrypted we require the
|
||||
* remote's public key as well to verify the connection.
|
||||
*
|
||||
* This class, then, handles this by allowing addresses of:
|
||||
*
|
||||
* Standard ZMQ address: these carry no pubkey and so the connection will be unencrypted:
|
||||
*
|
||||
* tcp://HOSTNAME:PORT
|
||||
* ipc://PATH
|
||||
*
|
||||
* Non-ZMQ address formats that specify that the connection shall be x25519 encrypted:
|
||||
*
|
||||
* curve://HOSTNAME:PORT/PUBKEY -- PUBKEY must be specified in hex (64 characters), base32z (52)
|
||||
* or base64 (43 or 44 with one '=' trailing padding)
|
||||
* ipc+curve:///path/to/my.sock/PUBKEY -- same requirements on PUBKEY as above.
|
||||
* tcp+curve://(whatever) -- alias for curve://(whatever)
|
||||
*
|
||||
* We also accept special upper-case TCP-only variants which *only* accept uppercase characters and
|
||||
* a few required symbols (:, /, $, ., and -) in the string:
|
||||
*
|
||||
* TCP://HOSTNAME:PORT
|
||||
* CURVE://HOSTNAME:PORT/B32ZPUBKEY
|
||||
*
|
||||
* These versions are explicitly meant to be used with QR codes; the upper-case-only requirement
|
||||
* allows a smaller QR code by allowing QR's alphanumeric mode (which allows only [A-Z0-9 $%*+./:-])
|
||||
* to be used. Such a QR-friendly address can be created from the qr_address() method. To support
|
||||
* literal IPv6 addresses we surround the address with $...$ instead of the usual [...].
|
||||
*
|
||||
* Note that this class does very little validate the host argument at all, and no socket path
|
||||
* validation whatsoever. The only constraint on host is when parsing an encoded address: we check
|
||||
* that it contains no : at all, or must be a [bracketed] expression that contains only hex
|
||||
* characters, :'s, or .'s. Otherwise, if you pass broken crap into the hostname, expect broken
|
||||
* crap out.
|
||||
*/
|
||||
struct address {
|
||||
/// Supported address protocols: TCP connections (tcp), or unix sockets (ipc).
|
||||
enum class proto {
|
||||
tcp,
|
||||
tcp_curve,
|
||||
ipc,
|
||||
ipc_curve
|
||||
};
|
||||
/// Supported public key encodings (used when regenerating an augmented address).
|
||||
enum class encoding {
|
||||
hex, ///< hexadecimal encoded
|
||||
base32z, ///< base32z encoded
|
||||
base64, ///< base64 encoded (*without* trailing = padding)
|
||||
BASE32Z ///< upper-case base32z encoding, meant for QR encoding
|
||||
};
|
||||
|
||||
/// The protocol: one of the `protocol` enum values for tcp or ipc (unix sockets), with or
|
||||
/// without _curve encryption.
|
||||
proto protocol = proto::tcp;
|
||||
/// The host for tcp connections; can be a hostname or IP address. If this is an IPv6 it must be surrounded with [ ].
|
||||
std::string host;
|
||||
/// The port (for tcp connections)
|
||||
uint16_t port = 0;
|
||||
/// The socket path (for unix socket connections)
|
||||
std::string socket;
|
||||
/// If a curve connection, this is the required remote public key (in bytes)
|
||||
std::string pubkey;
|
||||
|
||||
/// Default constructor; this gives you an unusable address.
|
||||
address() = default;
|
||||
|
||||
/**
|
||||
* Constructs an address by parsing a string_view containing one of the formats listed in the
|
||||
* class description. This is intentionally implicitly constructible so that you can pass a
|
||||
* string_view into anything expecting an `address`.
|
||||
*
|
||||
* Throw std::invalid_argument if the given address is not parseable.
|
||||
*/
|
||||
address(std::string_view addr);
|
||||
|
||||
/** Constructs an address from a remote string and a separate pubkey. Typically `remote` is a
|
||||
* basic ZMQ connect string, though this is not enforced. Any pubkey information embedded in
|
||||
* the remote string will be discarded and replaced with the given pubkey string. The result
|
||||
* will be curve encrypted if `pubkey` is non-empty, plaintext if `pubkey` is empty.
|
||||
*
|
||||
* Throws an exception if either addr or pubkey is invalid.
|
||||
*
|
||||
* Exactly equivalent to `address a{remote}; a.set_pubkey(pubkey);`
|
||||
*/
|
||||
address(std::string_view addr, std::string_view pubkey) : address(addr) { set_pubkey(pubkey); }
|
||||
|
||||
/// Replaces the address's pubkey (if any) with the given pubkey (or no pubkey if empty). If
|
||||
/// changing from pubkey to no-pubkey or no-pubkey to pubkey then the protocol is update to
|
||||
/// switch to or from curve encryption.
|
||||
///
|
||||
/// pubkey should be the 32-byte binary pubkey, or an empty string to remove an existing pubkey.
|
||||
///
|
||||
/// Returns the object itself, so that you can chain it.
|
||||
address& set_pubkey(std::string_view pubkey);
|
||||
|
||||
/// Constructs and builds the ZMQ connection address from the stored connection details. This
|
||||
/// does not contain any of the curve-related details; those must be specified separately when
|
||||
/// interfacing with ZMQ.
|
||||
std::string zmq_address() const;
|
||||
|
||||
/// Returns true if the connection was specified as a curve-encryption-enabled connection, false
|
||||
/// otherwise.
|
||||
bool curve() const { return protocol == proto::tcp_curve || protocol == proto::ipc_curve; }
|
||||
|
||||
/// True if the protocol is TCP (either with or without curve)
|
||||
bool tcp() const { return protocol == proto::tcp || protocol == proto::tcp_curve; }
|
||||
|
||||
/// True if the protocol is unix socket (either with or without curve)
|
||||
bool ipc() const { return !tcp(); }
|
||||
|
||||
/// Returns the full "augmented" address string (i.e. that could be passed in to the
|
||||
/// constructor). This will be equivalent (but not necessarily identical) to an augmented
|
||||
/// string passed into the constructor. Takes an optional encoding format for the pubkey (if
|
||||
/// any), which defaults to base32z.
|
||||
std::string full_address(encoding enc = encoding::base32z) const;
|
||||
|
||||
/// Returns a QR-code friendly address string. This returns an all-uppercase version of the
|
||||
/// address with "TCP://" or "CURVE://" for the protocol string, and uses upper-case base32z
|
||||
/// encoding for the pubkey (for curve addresses). For literal IPv6 addresses we replace the
|
||||
/// surround the
|
||||
/// address with $ instead of $
|
||||
///
|
||||
/// \throws std::logic_error if called on a unix socket address.
|
||||
std::string qr_address() const;
|
||||
|
||||
/// Returns `.pubkey` but encoded in the given format
|
||||
std::string encode_pubkey(encoding enc) const;
|
||||
|
||||
/// Returns true if two addresses are identical (i.e. same protocol and relevant protocol
|
||||
/// arguments).
|
||||
///
|
||||
/// Note that it is possible for addresses to connect to the same socket without being
|
||||
/// identical: for example, using "foo.sock" and "./foo.sock", or writing IPv6 addresses (or
|
||||
/// even IPv4 addresses) in slightly different ways). Such equivalent but non-equal values will
|
||||
/// result in a false return here.
|
||||
///
|
||||
/// Note also that we ignore irrelevant arguments: for example, we don't care whether pubkeys
|
||||
/// match when comparing two non-curve TCP addresses.
|
||||
bool operator==(const address& other) const;
|
||||
/// Negation of ==
|
||||
bool operator!=(const address& other) const { return !operator==(other); }
|
||||
|
||||
/// Factory function that constructs a TCP address from a host and port. The connection will be
|
||||
/// plaintext. If the host is an IPv6 address it *must* be surrounded with [ and ].
|
||||
static address tcp(std::string host, uint16_t port);
|
||||
|
||||
/// Factory function that constructs a curve-encrypted TCP address from a host, port, and remote
|
||||
/// pubkey. The pubkey must be 32 bytes. As above, IPv6 addresses must be specified as [addr].
|
||||
static address tcp_curve(std::string host, uint16_t, std::string pubkey);
|
||||
|
||||
/// Factory function that constructs a unix socket address from a path. The connection will be
|
||||
/// plaintext (which is usually fine for a socket since unix sockets are local machine).
|
||||
static address ipc(std::string path);
|
||||
|
||||
/// Factory function that constructs a unix socket address from a path and remote pubkey. The
|
||||
/// connection will be curve25519 encrypted; the remote pubkey must be 32 bytes.
|
||||
static address ipc_curve(std::string path, std::string pubkey);
|
||||
};
|
||||
|
||||
// Outputs address.full_address() when sent to an ostream.
|
||||
std::ostream& operator<<(std::ostream& o, const address& a);
|
||||
|
||||
}
|
|
@ -1,6 +1,8 @@
|
|||
#include "lokimq.h"
|
||||
#include "hex.h"
|
||||
#include "lokimq-internal.h"
|
||||
#include <ostream>
|
||||
#include <sstream>
|
||||
|
||||
namespace lokimq {
|
||||
|
||||
|
@ -12,7 +14,7 @@ namespace {
|
|||
|
||||
// Builds a ZMTP metadata key-value pair. These will be available on every message from that peer.
|
||||
// Keys must start with X- and be <= 255 characters.
|
||||
std::string zmtp_metadata(string_view key, string_view value) {
|
||||
std::string zmtp_metadata(std::string_view key, std::string_view value) {
|
||||
assert(key.size() > 2 && key.size() <= 255 && key[0] == 'X' && key[1] == '-');
|
||||
|
||||
std::string result;
|
||||
|
@ -63,7 +65,7 @@ bool LokiMQ::proxy_check_auth(size_t conn_index, bool outgoing, const peer_info&
|
|||
msgs.push_back(create_message(peer.route));
|
||||
msgs.push_back(create_message(reply));
|
||||
if (cat_call.second && cat_call.second->second /*request command*/ && !data.empty()) {
|
||||
msgs.push_back(create_message("REPLY"_sv));
|
||||
msgs.push_back(create_message("REPLY"sv));
|
||||
msgs.push_back(create_message(view(data.front()))); // reply tag
|
||||
} else {
|
||||
msgs.push_back(create_message(view(cmd)));
|
||||
|
@ -87,7 +89,7 @@ void LokiMQ::set_active_sns(pubkey_set pubkeys) {
|
|||
proxy_set_active_sns(std::move(pubkeys));
|
||||
}
|
||||
}
|
||||
void LokiMQ::proxy_set_active_sns(string_view data) {
|
||||
void LokiMQ::proxy_set_active_sns(std::string_view data) {
|
||||
proxy_set_active_sns(detail::deserialize_object<pubkey_set>(bt_deserialize<uintptr_t>(data)));
|
||||
}
|
||||
void LokiMQ::proxy_set_active_sns(pubkey_set pubkeys) {
|
||||
|
@ -204,7 +206,7 @@ void LokiMQ::process_zap_requests() {
|
|||
else
|
||||
o << v;
|
||||
}
|
||||
log_(LogLevel::trace, __FILE__, __LINE__, o.str());
|
||||
log(LogLevel::trace, __FILE__, __LINE__, o.str());
|
||||
} else
|
||||
#endif
|
||||
LMQ_LOG(debug, "Processing ZAP authentication request");
|
||||
|
@ -267,7 +269,7 @@ void LokiMQ::process_zap_requests() {
|
|||
status_text = "Invalid public key size for CURVE authentication";
|
||||
} else {
|
||||
auto ip = view(frames[3]);
|
||||
string_view pubkey;
|
||||
std::string_view pubkey;
|
||||
bool sn = false;
|
||||
if (bind[bind_id].second.curve) {
|
||||
pubkey = view(frames[6]);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
#pragma once
|
||||
#include <iostream>
|
||||
#include <iosfwd>
|
||||
#include <string>
|
||||
#include <cstring>
|
||||
#include <unordered_set>
|
||||
|
|
|
@ -0,0 +1,203 @@
|
|||
// Copyright (c) 2019-2020, The Loki Project
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without modification, are
|
||||
// permitted provided that the following conditions are met:
|
||||
//
|
||||
// 1. Redistributions of source code must retain the above copyright notice, this list of
|
||||
// conditions and the following disclaimer.
|
||||
//
|
||||
// 2. Redistributions in binary form must reproduce the above copyright notice, this list
|
||||
// of conditions and the following disclaimer in the documentation and/or other
|
||||
// materials provided with the distribution.
|
||||
//
|
||||
// 3. Neither the name of the copyright holder nor the names of its contributors may be
|
||||
// used to endorse or promote products derived from this software without specific
|
||||
// prior written permission.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
|
||||
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
|
||||
// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL
|
||||
// THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF
|
||||
// THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
#pragma once
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <array>
|
||||
#include <iterator>
|
||||
#include <cassert>
|
||||
|
||||
namespace lokimq {
|
||||
|
||||
namespace detail {
|
||||
|
||||
/// Compile-time generated lookup tables for base32z conversion. This is case insensitive (though
|
||||
/// for byte -> b32z conversion we always produce lower case).
|
||||
struct b32z_table {
|
||||
// Store the 0-31 decoded value of every possible char; all the chars that aren't valid are set
|
||||
// to 0. (If you don't trust your data, check it with is_base32z first, which uses these 0's
|
||||
// to detect invalid characters -- which is why we want a full 256 element array).
|
||||
char from_b32z_lut[256];
|
||||
// Store the encoded character of every 0-31 (5 bit) value.
|
||||
char to_b32z_lut[32];
|
||||
|
||||
// constexpr constructor that fills out the above (and should do it at compile time for any half
|
||||
// decent compiler).
|
||||
constexpr b32z_table() noexcept : from_b32z_lut{},
|
||||
to_b32z_lut{
|
||||
'y', 'b', 'n', 'd', 'r', 'f', 'g', '8', 'e', 'j', 'k', 'm', 'c', 'p', 'q', 'x',
|
||||
'o', 't', '1', 'u', 'w', 'i', 's', 'z', 'a', '3', '4', '5', 'h', '7', '6', '9'
|
||||
}
|
||||
{
|
||||
for (unsigned char c = 0; c < 32; c++) {
|
||||
unsigned char x = to_b32z_lut[c];
|
||||
from_b32z_lut[x] = c;
|
||||
if (x >= 'a' && x <= 'z')
|
||||
from_b32z_lut[x - 'a' + 'A'] = c;
|
||||
}
|
||||
}
|
||||
// Convert a b32z encoded character into a 0-31 value
|
||||
constexpr char from_b32z(unsigned char c) const noexcept { return from_b32z_lut[c]; }
|
||||
// Convert a 0-31 value into a b32z encoded character
|
||||
constexpr char to_b32z(unsigned char b) const noexcept { return to_b32z_lut[b]; }
|
||||
} constexpr b32z_lut;
|
||||
|
||||
// This main point of this static assert is to force the compiler to compile-time build the constexpr tables.
|
||||
static_assert(b32z_lut.from_b32z('w') == 20 && b32z_lut.from_b32z('T') == 17 && b32z_lut.to_b32z(5) == 'f', "");
|
||||
|
||||
} // namespace detail
|
||||
|
||||
/// Converts bytes into a base32z encoded character sequence.
|
||||
template <typename InputIt, typename OutputIt>
|
||||
void to_base32z(InputIt begin, InputIt end, OutputIt out) {
|
||||
static_assert(sizeof(decltype(*begin)) == 1, "to_base32z requires chars/bytes");
|
||||
int bits = 0; // Tracks the number of unconsumed bits held in r, will always be in [0, 4]
|
||||
std::uint_fast16_t r = 0;
|
||||
while (begin != end) {
|
||||
r = r << 8 | static_cast<unsigned char>(*begin++);
|
||||
|
||||
// we just added 8 bits, so we can *always* consume 5 to produce one character, so (net) we
|
||||
// are adding 3 bits.
|
||||
bits += 3;
|
||||
*out++ = detail::b32z_lut.to_b32z(r >> bits); // Right-shift off the bits we aren't consuming right now
|
||||
|
||||
// Drop the bits we don't want to keep (because we just consumed them)
|
||||
r &= (1 << bits) - 1;
|
||||
|
||||
if (bits >= 5) { // We have enough bits to produce a second character; essentially the same as above
|
||||
bits -= 5; // Except now we are just consuming 5 without having added any more
|
||||
*out++ = detail::b32z_lut.to_b32z(r >> bits);
|
||||
r &= (1 << bits) - 1;
|
||||
}
|
||||
}
|
||||
|
||||
if (bits > 0) // We hit the end, but still have some unconsumed bits so need one final character to append
|
||||
*out++ = detail::b32z_lut.to_b32z(r << (5 - bits));
|
||||
}
|
||||
|
||||
/// Creates a base32z string from an iterator pair of a byte sequence.
|
||||
template <typename It>
|
||||
std::string to_base32z(It begin, It end) {
|
||||
std::string base32z;
|
||||
if constexpr (std::is_base_of_v<std::random_access_iterator_tag, typename std::iterator_traits<It>::iterator_category>)
|
||||
base32z.reserve((std::distance(begin, end)*8 + 4) / 5); // == bytes*8/5, rounded up.
|
||||
to_base32z(begin, end, std::back_inserter(base32z));
|
||||
return base32z;
|
||||
}
|
||||
|
||||
/// Creates a base32z string from an iterable, std::string-like object
|
||||
template <typename CharT>
|
||||
std::string to_base32z(std::basic_string_view<CharT> s) { return to_base32z(s.begin(), s.end()); }
|
||||
inline std::string to_base32z(std::string_view s) { return to_base32z<>(s); }
|
||||
|
||||
/// Returns true if all elements in the range are base32z characters
|
||||
template <typename It>
|
||||
constexpr bool is_base32z(It begin, It end) {
|
||||
static_assert(sizeof(decltype(*begin)) == 1, "is_base32z requires chars/bytes");
|
||||
for (; begin != end; ++begin) {
|
||||
auto c = static_cast<unsigned char>(*begin);
|
||||
if (detail::b32z_lut.from_b32z(c) == 0 && !(c == 'y' || c == 'Y'))
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Returns true if all elements in the string-like value are base32z characters
|
||||
template <typename CharT>
|
||||
constexpr bool is_base32z(std::basic_string_view<CharT> s) { return is_base32z(s.begin(), s.end()); }
|
||||
constexpr bool is_base32z(std::string_view s) { return is_base32z<>(s); }
|
||||
|
||||
/// Converts a sequence of base32z digits to bytes. Undefined behaviour if any characters are not
|
||||
/// valid base32z alphabet characters. It is permitted for the input and output ranges to overlap
|
||||
/// as long as `out` is no earlier than `begin`. Note that if you pass in a sequence that could not
|
||||
/// have been created by a base32z encoding of a byte sequence, we treat the excess bits as if they
|
||||
/// were not provided.
|
||||
///
|
||||
/// For example, "yyy" represents a 15-bit value, but a byte sequence is either 8-bit (requiring 2
|
||||
/// characters) or 16-bit (requiring 4). Similarly, "yb" is an impossible encoding because it has
|
||||
/// its 10th bit set (b = 00001), but a base32z encoded value should have all 0's beyond the 8th (or
|
||||
/// 16th or 24th or ... bit). We treat any such bits as if they were not specified (even if they
|
||||
/// are): which means "yy", "yb", "yyy", "yy9", "yd", etc. all decode to the same 1-byte value "\0".
|
||||
template <typename InputIt, typename OutputIt>
|
||||
void from_base32z(InputIt begin, InputIt end, OutputIt out) {
|
||||
static_assert(sizeof(decltype(*begin)) == 1, "from_base32z requires chars/bytes");
|
||||
uint_fast16_t curr = 0;
|
||||
int bits = 0; // number of bits we've loaded into val; we always keep this < 8.
|
||||
while (begin != end) {
|
||||
curr = curr << 5 | detail::b32z_lut.from_b32z(static_cast<unsigned char>(*begin++));
|
||||
if (bits >= 3) {
|
||||
bits -= 3; // Added 5, removing 8
|
||||
*out++ = static_cast<uint8_t>(curr >> bits);
|
||||
curr &= (1 << bits) - 1;
|
||||
} else {
|
||||
bits += 5;
|
||||
}
|
||||
}
|
||||
|
||||
// Ignore any trailing bits. base32z encoding always has at least as many bits as the source
|
||||
// bytes, which means we should not be able to get here from a properly encoded b32z value with
|
||||
// anything other than 0s: if we have no extra bits (e.g. 5 bytes == 8 b32z chars) then we have
|
||||
// a 0-bit value; if we have some extra bits (e.g. 6 bytes requires 10 b32z chars, but that
|
||||
// contains 50 bits > 48 bits) then those extra bits will be 0s (and this covers the bits -= 3
|
||||
// case above: it'll leave us with 0-4 extra bits, but those extra bits would be 0 if produced
|
||||
// from an actual byte sequence).
|
||||
//
|
||||
// The "bits += 5" case, then, means that we could end with 5-7 bits. This, however, cannot be
|
||||
// produced by a valid encoding:
|
||||
// - 0 bytes gives us 0 chars with 0 leftover bits
|
||||
// - 1 byte gives us 2 chars with 2 leftover bits
|
||||
// - 2 bytes gives us 4 chars with 4 leftover bits
|
||||
// - 3 bytes gives us 5 chars with 1 leftover bit
|
||||
// - 4 bytes gives us 7 chars with 3 leftover bits
|
||||
// - 5 bytes gives us 8 chars with 0 leftover bits (this is where the cycle repeats)
|
||||
//
|
||||
// So really the only way we can get 5-7 leftover bits is if you took a 0, 2 or 5 char output (or
|
||||
// any 8n + {0,2,5} char output) and added a base32z character to the end. If you do that,
|
||||
// well, too bad: you're giving invalid output and so we're just going to pretend that extra
|
||||
// character you added isn't there by not doing anything here.
|
||||
}
|
||||
|
||||
/// Convert a base32z sequence into a std::string of bytes. Undefined behaviour if any characters
|
||||
/// are not valid (case-insensitive) base32z characters.
|
||||
template <typename It>
|
||||
std::string from_base32z(It begin, It end) {
|
||||
std::string bytes;
|
||||
if constexpr (std::is_base_of_v<std::random_access_iterator_tag, typename std::iterator_traits<It>::iterator_category>)
|
||||
bytes.reserve((std::distance(begin, end)*5 + 7) / 8); // == chars*5/8, rounded up.
|
||||
from_base32z(begin, end, std::back_inserter(bytes));
|
||||
return bytes;
|
||||
}
|
||||
|
||||
/// Converts base32z digits from a std::string-like object into a std::string of bytes. Undefined
|
||||
/// behaviour if any characters are not valid (case-insensitive) base32z characters.
|
||||
template <typename CharT>
|
||||
std::string from_base32z(std::basic_string_view<CharT> s) { return from_base32z(s.begin(), s.end()); }
|
||||
inline std::string from_base32z(std::string_view s) { return from_base32z<>(s); }
|
||||
|
||||
}
|
|
@ -0,0 +1,219 @@
|
|||
// Copyright (c) 2019-2020, The Loki Project
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without modification, are
|
||||
// permitted provided that the following conditions are met:
|
||||
//
|
||||
// 1. Redistributions of source code must retain the above copyright notice, this list of
|
||||
// conditions and the following disclaimer.
|
||||
//
|
||||
// 2. Redistributions in binary form must reproduce the above copyright notice, this list
|
||||
// of conditions and the following disclaimer in the documentation and/or other
|
||||
// materials provided with the distribution.
|
||||
//
|
||||
// 3. Neither the name of the copyright holder nor the names of its contributors may be
|
||||
// used to endorse or promote products derived from this software without specific
|
||||
// prior written permission.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
|
||||
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
|
||||
// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL
|
||||
// THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF
|
||||
// THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
#pragma once
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <array>
|
||||
#include <iterator>
|
||||
#include <cassert>
|
||||
|
||||
namespace lokimq {
|
||||
|
||||
namespace detail {
|
||||
|
||||
/// Compile-time generated lookup tables for base64 conversion.
|
||||
struct b64_table {
|
||||
// Store the 0-63 decoded value of every possible char; all the chars that aren't valid are set
|
||||
// to 0. (If you don't trust your data, check it with is_base64 first, which uses these 0's
|
||||
// to detect invalid characters -- which is why we want a full 256 element array).
|
||||
char from_b64_lut[256];
|
||||
// Store the encoded character of every 0-63 (6 bit) value.
|
||||
char to_b64_lut[64];
|
||||
|
||||
// constexpr constructor that fills out the above (and should do it at compile time for any half
|
||||
// decent compiler).
|
||||
constexpr b64_table() noexcept : from_b64_lut{}, to_b64_lut{} {
|
||||
for (unsigned char c = 0; c < 26; c++) {
|
||||
from_b64_lut[(unsigned char)('A' + c)] = 0 + c;
|
||||
to_b64_lut[ (unsigned char)( 0 + c)] = 'A' + c;
|
||||
}
|
||||
for (unsigned char c = 0; c < 26; c++) {
|
||||
from_b64_lut[(unsigned char)('a' + c)] = 26 + c;
|
||||
to_b64_lut[ (unsigned char)(26 + c)] = 'a' + c;
|
||||
}
|
||||
for (unsigned char c = 0; c < 10; c++) {
|
||||
from_b64_lut[(unsigned char)('0' + c)] = 52 + c;
|
||||
to_b64_lut[ (unsigned char)(52 + c)] = '0' + c;
|
||||
}
|
||||
to_b64_lut[62] = '+'; from_b64_lut[(unsigned char) '+'] = 62;
|
||||
to_b64_lut[63] = '/'; from_b64_lut[(unsigned char) '/'] = 63;
|
||||
}
|
||||
// Convert a b64 encoded character into a 0-63 value
|
||||
constexpr char from_b64(unsigned char c) const noexcept { return from_b64_lut[c]; }
|
||||
// Convert a 0-31 value into a b64 encoded character
|
||||
constexpr char to_b64(unsigned char b) const noexcept { return to_b64_lut[b]; }
|
||||
} constexpr b64_lut;
|
||||
|
||||
// This main point of this static assert is to force the compiler to compile-time build the constexpr tables.
|
||||
static_assert(b64_lut.from_b64('/') == 63 && b64_lut.from_b64('7') == 59 && b64_lut.to_b64(38) == 'm', "");
|
||||
|
||||
} // namespace detail
|
||||
|
||||
/// Converts bytes into a base64 encoded character sequence.
|
||||
template <typename InputIt, typename OutputIt>
|
||||
void to_base64(InputIt begin, InputIt end, OutputIt out) {
|
||||
static_assert(sizeof(decltype(*begin)) == 1, "to_base64 requires chars/bytes");
|
||||
int bits = 0; // Tracks the number of unconsumed bits held in r, will always be in {0, 2, 4}
|
||||
std::uint_fast16_t r = 0;
|
||||
while (begin != end) {
|
||||
r = r << 8 | static_cast<unsigned char>(*begin++);
|
||||
|
||||
// we just added 8 bits, so we can *always* consume 6 to produce one character, so (net) we
|
||||
// are adding 2 bits.
|
||||
bits += 2;
|
||||
*out++ = detail::b64_lut.to_b64(r >> bits); // Right-shift off the bits we aren't consuming right now
|
||||
|
||||
// Drop the bits we don't want to keep (because we just consumed them)
|
||||
r &= (1 << bits) - 1;
|
||||
|
||||
if (bits == 6) { // We have enough bits to produce a second character (which means we had 4 before and added 8)
|
||||
bits = 0;
|
||||
*out++ = detail::b64_lut.to_b64(r);
|
||||
r = 0;
|
||||
}
|
||||
}
|
||||
|
||||
// If bits == 0 then we ended our 6-bit outputs coinciding with 8-bit values, i.e. at a multiple
|
||||
// of 24 bits: this means we don't have anything else to output and don't need any padding.
|
||||
if (bits == 2) {
|
||||
// We finished with 2 unconsumed bits, which means we ended 1 byte past a 24-bit group (e.g.
|
||||
// 1 byte, 4 bytes, 301 bytes, etc.); since we need to always be a multiple of 4 output
|
||||
// characters that means we've produced 1: so we right-fill 0s to get the next char, then
|
||||
// add two padding ='s.
|
||||
*out++ = detail::b64_lut.to_b64(r << 4);
|
||||
*out++ = '=';
|
||||
*out++ = '=';
|
||||
} else if (bits == 4) {
|
||||
// 4 bits left means we produced 2 6-bit values from the first 2 bytes of a 3-byte group.
|
||||
// Fill 0s to get the last one, plus one padding output.
|
||||
*out++ = detail::b64_lut.to_b64(r << 2);
|
||||
*out++ = '=';
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates and returns a base64 string from an iterator pair of a character sequence
|
||||
template <typename It>
|
||||
std::string to_base64(It begin, It end) {
|
||||
std::string base64;
|
||||
if constexpr (std::is_base_of_v<std::random_access_iterator_tag, typename std::iterator_traits<It>::iterator_category>)
|
||||
base64.reserve((std::distance(begin, end) + 2) / 3 * 4); // bytes*4/3, rounded up to the next multiple of 4
|
||||
to_base64(begin, end, std::back_inserter(base64));
|
||||
return base64;
|
||||
}
|
||||
|
||||
/// Creates a base64 string from an iterable, std::string-like object
|
||||
template <typename CharT>
|
||||
std::string to_base64(std::basic_string_view<CharT> s) { return to_base64(s.begin(), s.end()); }
|
||||
inline std::string to_base64(std::string_view s) { return to_base64<>(s); }
|
||||
|
||||
/// Returns true if the range is a base64 encoded value; we allow (but do not require) '=' padding,
|
||||
/// but only at the end, only 1 or 2, and only if it pads out the total to a multiple of 4.
|
||||
template <typename It>
|
||||
constexpr bool is_base64(It begin, It end) {
|
||||
static_assert(sizeof(decltype(*begin)) == 1, "is_base64 requires chars/bytes");
|
||||
using std::distance;
|
||||
using std::prev;
|
||||
|
||||
// Allow 1 or 2 padding chars *if* they pad it to a multiple of 4.
|
||||
if (begin != end && distance(begin, end) % 4 == 0) {
|
||||
auto last = prev(end);
|
||||
if (static_cast<unsigned char>(*last) == '=')
|
||||
end = last--;
|
||||
if (static_cast<unsigned char>(*last) == '=')
|
||||
end = last;
|
||||
}
|
||||
|
||||
for (; begin != end; ++begin) {
|
||||
auto c = static_cast<unsigned char>(*begin);
|
||||
if (detail::b64_lut.from_b64(c) == 0 && c != 'A')
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Returns true if the string-like value is a base64 encoded value
|
||||
template <typename CharT>
|
||||
constexpr bool is_base64(std::basic_string_view<CharT> s) { return is_base64(s.begin(), s.end()); }
|
||||
constexpr bool is_base64(std::string_view s) { return is_base64(s.begin(), s.end()); }
|
||||
|
||||
/// Converts a sequence of base64 digits to bytes. Undefined behaviour if any characters are not
|
||||
/// valid base64 alphabet characters. It is permitted for the input and output ranges to overlap as
|
||||
/// long as `out` is no earlier than `begin`. Trailing padding characters are permitted but not
|
||||
/// required.
|
||||
///
|
||||
/// It is possible to provide "impossible" base64 encoded values; for example "YWJja" which has 30
|
||||
/// bits of data even though a base64 encoded byte string should have 24 (4 chars) or 36 (6 chars)
|
||||
/// bits for a 3- and 4-byte input, respectively. We ignore any such "impossible" bits, and
|
||||
/// similarly ignore impossible bits in the bit "overhang"; that means "YWJjZA==" (the proper
|
||||
/// encoding of "abcd") and "YWJjZB", "YWJjZC", ..., "YWJjZP" all decode to the same "abcd" value:
|
||||
/// the last 4 bits of the last character are essentially considered padding.
|
||||
template <typename InputIt, typename OutputIt>
|
||||
void from_base64(InputIt begin, InputIt end, OutputIt out) {
|
||||
static_assert(sizeof(decltype(*begin)) == 1, "from_base64 requires chars/bytes");
|
||||
uint_fast16_t curr = 0;
|
||||
int bits = 0; // number of bits we've loaded into val; we always keep this < 8.
|
||||
while (begin != end) {
|
||||
auto c = static_cast<unsigned char>(*begin++);
|
||||
|
||||
// padding; don't bother checking if we're at the end because is_base64 is a precondition
|
||||
// and we're allowed UB if it isn't satisfied.
|
||||
if (c == '=') continue;
|
||||
|
||||
curr = curr << 6 | detail::b64_lut.from_b64(c);
|
||||
if (bits == 0)
|
||||
bits = 6;
|
||||
else {
|
||||
bits -= 2; // Added 6, removing 8
|
||||
*out++ = static_cast<uint8_t>(curr >> bits);
|
||||
curr &= (1 << bits) - 1;
|
||||
}
|
||||
}
|
||||
// Don't worry about leftover bits because either they have to be 0, or they can't happen at
|
||||
// all. See base32z.h for why: the reasoning is exactly the same (except using 6 bits per
|
||||
// character here instead of 5).
|
||||
}
|
||||
|
||||
/// Converts base64 digits from a iterator pair of characters into a std::string of bytes.
|
||||
/// Undefined behaviour if any characters are not valid base64 characters.
|
||||
template <typename It>
|
||||
std::string from_base64(It begin, It end) {
|
||||
std::string bytes;
|
||||
if constexpr (std::is_base_of_v<std::random_access_iterator_tag, typename std::iterator_traits<It>::iterator_category>)
|
||||
bytes.reserve(std::distance(begin, end)*6 / 8); // each digit carries 6 bits; this may overallocate by 1-2 bytes due to padding
|
||||
from_base64(begin, end, std::back_inserter(bytes));
|
||||
return bytes;
|
||||
}
|
||||
|
||||
/// Converts base64 digits from a std::string-like object into a std::string of bytes. Undefined
|
||||
/// behaviour if any characters are not valid base64 characters.
|
||||
template <typename CharT>
|
||||
std::string from_base64(std::basic_string_view<CharT> s) { return from_base64(s.begin(), s.end()); }
|
||||
inline std::string from_base64(std::string_view s) { return from_base64<>(s); }
|
||||
|
||||
}
|
|
@ -36,18 +36,25 @@ namespace lokimq {
|
|||
|
||||
namespace detail {
|
||||
|
||||
enum class BatchStatus {
|
||||
enum class BatchState {
|
||||
running, // there are still jobs to run (or running)
|
||||
complete, // the batch is complete but still has a completion job to call
|
||||
complete_proxy, // same as `complete`, but the completion job should be invoked immediately in the proxy thread (be very careful)
|
||||
done // the batch is complete and has no completion function
|
||||
};
|
||||
|
||||
struct BatchStatus {
|
||||
BatchState state;
|
||||
int thread;
|
||||
};
|
||||
|
||||
// Virtual base class for Batch<R>
|
||||
class Batch {
|
||||
public:
|
||||
// Returns the number of jobs in this batch
|
||||
virtual size_t size() const = 0;
|
||||
// Returns the number of jobs in this batch and whether any of them are thread-specific
|
||||
virtual std::pair<size_t, bool> size() const = 0;
|
||||
// Returns a vector of exactly the same length of size().first containing the tagged thread ids
|
||||
// of the batch jobs or 0 for general jobs.
|
||||
virtual std::vector<int> threads() const = 0;
|
||||
// Called in a worker thread to run the job
|
||||
virtual void run_job(int i) = 0;
|
||||
// Called in the main proxy thread when the worker returns from finishing a job. The return
|
||||
|
@ -151,12 +158,13 @@ public:
|
|||
Batch &operator=(const Batch&) = delete;
|
||||
|
||||
private:
|
||||
std::vector<std::function<R()>> jobs;
|
||||
std::vector<std::pair<std::function<R()>, int>> jobs;
|
||||
std::vector<job_result<R>> results;
|
||||
CompletionFunc complete;
|
||||
std::size_t jobs_outstanding = 0;
|
||||
bool complete_in_proxy = false;
|
||||
int complete_in_thread = 0;
|
||||
bool started = false;
|
||||
bool tagged_thread_jobs = false;
|
||||
|
||||
void check_not_started() {
|
||||
if (started)
|
||||
|
@ -175,39 +183,61 @@ public:
|
|||
/// available. The called function may throw exceptions (which will be propagated to the
|
||||
/// completion function through the job_result values). There is no guarantee on the order of
|
||||
/// invocation of the jobs.
|
||||
void add_job(std::function<R()> job) {
|
||||
///
|
||||
/// \param job the callback
|
||||
/// \param thread an optional TaggedThreadID indicating a thread in which this job must run
|
||||
void add_job(std::function<R()> job, std::optional<TaggedThreadID> thread = std::nullopt) {
|
||||
check_not_started();
|
||||
jobs.emplace_back(std::move(job));
|
||||
results.emplace_back();
|
||||
jobs_outstanding++;
|
||||
if (thread && thread->_id == -1)
|
||||
// There are some special case internal jobs where we allow this, but they use the
|
||||
// private method below that doesn't have this check.
|
||||
throw std::logic_error{"Cannot add a proxy thread batch job -- this makes no sense"};
|
||||
add_job(std::move(job), thread ? thread->_id : 0);
|
||||
}
|
||||
|
||||
/// Sets the completion function to invoke after all jobs have finished. If this is not set
|
||||
/// then jobs simply run and results are discarded.
|
||||
void completion(CompletionFunc comp) {
|
||||
///
|
||||
/// \param comp - function to call when all jobs have finished
|
||||
/// \param thread - optional tagged thread in which to schedule the completion job. If not
|
||||
/// provided then the completion job is scheduled in the pool of batch job threads.
|
||||
///
|
||||
/// `thread` can be provided the value &LokiMQ::run_in_proxy to invoke the completion function
|
||||
/// *IN THE PROXY THREAD* itself after all jobs have finished. Be very, very careful: this
|
||||
/// should be a nearly trivial job that does not require any substantial CPU time and does not
|
||||
/// block for any reason. This is only intended for the case where the completion job is so
|
||||
/// trivial that it will take less time than simply queuing the job to be executed by another
|
||||
/// thread.
|
||||
void completion(CompletionFunc comp, std::optional<TaggedThreadID> thread = std::nullopt) {
|
||||
check_not_started();
|
||||
if (complete)
|
||||
throw std::logic_error("Completion function can only be set once");
|
||||
complete = std::move(comp);
|
||||
}
|
||||
|
||||
/// Sets a completion function to invoke *IN THE PROXY THREAD* after all jobs have finished. Be
|
||||
/// very, very careful: this should not be a job that takes any significant amount of CPU time
|
||||
/// or can block for any reason (NO MUTEXES).
|
||||
void completion_proxy(CompletionFunc comp) {
|
||||
check_not_started();
|
||||
if (complete)
|
||||
throw std::logic_error("Completion function can only be set once");
|
||||
complete = std::move(comp);
|
||||
complete_in_proxy = true;
|
||||
complete_in_thread = thread ? thread->_id : 0;
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
std::size_t size() const override {
|
||||
return jobs.size();
|
||||
void add_job(std::function<R()> job, int thread_id) {
|
||||
jobs.emplace_back(std::move(job), thread_id);
|
||||
results.emplace_back();
|
||||
jobs_outstanding++;
|
||||
if (thread_id != 0)
|
||||
tagged_thread_jobs = true;
|
||||
}
|
||||
|
||||
std::pair<std::size_t, bool> size() const override {
|
||||
return {jobs.size(), tagged_thread_jobs};
|
||||
}
|
||||
|
||||
std::vector<int> threads() const override {
|
||||
std::vector<int> t;
|
||||
t.reserve(jobs.size());
|
||||
for (auto& j : jobs)
|
||||
t.push_back(j.second);
|
||||
return t;
|
||||
};
|
||||
|
||||
template <typename S = R>
|
||||
void set_value(job_result<S>& r, std::function<S()>& f) { r.set_value(f()); }
|
||||
void set_value(job_result<void>&, std::function<void()>& f) { f(); }
|
||||
|
@ -216,7 +246,7 @@ private:
|
|||
// called by worker thread
|
||||
auto& r = results[i];
|
||||
try {
|
||||
set_value(r, jobs[i]);
|
||||
set_value(r, jobs[i].first);
|
||||
} catch (...) {
|
||||
r.set_exception(std::current_exception());
|
||||
}
|
||||
|
@ -225,12 +255,10 @@ private:
|
|||
detail::BatchStatus job_finished() override {
|
||||
--jobs_outstanding;
|
||||
if (jobs_outstanding)
|
||||
return detail::BatchStatus::running;
|
||||
return {detail::BatchState::running, 0};
|
||||
if (complete)
|
||||
return complete_in_proxy
|
||||
? detail::BatchStatus::complete_proxy
|
||||
: detail::BatchStatus::complete;
|
||||
return detail::BatchStatus::done;
|
||||
return {detail::BatchState::complete, complete_in_thread};
|
||||
return {detail::BatchState::done, 0};
|
||||
}
|
||||
|
||||
void job_completion() override {
|
||||
|
@ -241,7 +269,7 @@ private:
|
|||
|
||||
template <typename R>
|
||||
void LokiMQ::batch(Batch<R>&& batch) {
|
||||
if (batch.size() == 0)
|
||||
if (batch.size().first == 0)
|
||||
throw std::logic_error("Cannot batch a a job batch with 0 jobs");
|
||||
// Need to send this over to the proxy thread via the base class pointer. It assumes ownership.
|
||||
auto* baseptr = static_cast<detail::Batch*>(new Batch<R>(std::move(batch)));
|
||||
|
|
|
@ -27,12 +27,13 @@
|
|||
// THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
#include "bt_serialize.h"
|
||||
#include <iterator>
|
||||
|
||||
namespace lokimq {
|
||||
namespace detail {
|
||||
|
||||
/// Reads digits into an unsigned 64-bit int.
|
||||
uint64_t extract_unsigned(string_view& s) {
|
||||
uint64_t extract_unsigned(std::string_view& s) {
|
||||
if (s.empty())
|
||||
throw bt_deserialize_invalid{"Expected 0-9 but found end of string"};
|
||||
if (s[0] < '0' || s[0] > '9')
|
||||
|
@ -48,7 +49,7 @@ uint64_t extract_unsigned(string_view& s) {
|
|||
return uval;
|
||||
}
|
||||
|
||||
void bt_deserialize<string_view>::operator()(string_view& s, string_view& val) {
|
||||
void bt_deserialize<std::string_view>::operator()(std::string_view& s, std::string_view& val) {
|
||||
if (s.size() < 2) throw bt_deserialize_invalid{"Deserialize failed: given data is not an bt-encoded string"};
|
||||
if (s[0] < '0' || s[0] > '9')
|
||||
throw bt_deserialize_invalid_type{"Expected 0-9 but found '"s + s[0] + "'"};
|
||||
|
@ -72,37 +73,33 @@ static_assert(std::numeric_limits<int64_t>::min() + std::numeric_limits<int64_t>
|
|||
static_cast<uint64_t>(std::numeric_limits<int64_t>::max()) + uint64_t{1} == (uint64_t{1} << 63),
|
||||
"Non 2s-complement architecture not supported!");
|
||||
|
||||
std::pair<maybe_signed_int64_t, bool> bt_deserialize_integer(string_view& s) {
|
||||
std::pair<uint64_t, bool> bt_deserialize_integer(std::string_view& s) {
|
||||
// Smallest possible encoded integer is 3 chars: "i0e"
|
||||
if (s.size() < 3) throw bt_deserialize_invalid("Deserialization failed: end of string found where integer expected");
|
||||
if (s[0] != 'i') throw bt_deserialize_invalid_type("Deserialization failed: expected 'i', found '"s + s[0] + '\'');
|
||||
s.remove_prefix(1);
|
||||
std::pair<maybe_signed_int64_t, bool> result;
|
||||
std::pair<uint64_t, bool> result;
|
||||
if (s[0] == '-') {
|
||||
result.second = true;
|
||||
s.remove_prefix(1);
|
||||
}
|
||||
|
||||
uint64_t uval = extract_unsigned(s);
|
||||
result.first = extract_unsigned(s);
|
||||
if (s.empty())
|
||||
throw bt_deserialize_invalid("Integer deserialization failed: encountered end of string before integer was finished");
|
||||
if (s[0] != 'e')
|
||||
throw bt_deserialize_invalid("Integer deserialization failed: expected digit or 'e', found '"s + s[0] + '\'');
|
||||
s.remove_prefix(1);
|
||||
if (result.second) { // negative
|
||||
if (uval > (uint64_t{1} << 63))
|
||||
throw bt_deserialize_invalid("Deserialization of integer failed: negative integer value is too large for a 64-bit signed int");
|
||||
result.first.i64 = -uval;
|
||||
} else {
|
||||
result.first.u64 = uval;
|
||||
}
|
||||
if (result.second /*negative*/ && result.first > (uint64_t{1} << 63))
|
||||
throw bt_deserialize_invalid("Deserialization of integer failed: negative integer value is too large for a 64-bit signed int");
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
template struct bt_deserialize<int64_t>;
|
||||
template struct bt_deserialize<uint64_t>;
|
||||
|
||||
void bt_deserialize<bt_value, void>::operator()(string_view& s, bt_value& val) {
|
||||
void bt_deserialize<bt_value, void>::operator()(std::string_view& s, bt_value& val) {
|
||||
if (s.size() < 2) throw bt_deserialize_invalid("Deserialization failed: end of string found where bt-encoded value expected");
|
||||
|
||||
switch (s[0]) {
|
||||
|
@ -119,8 +116,9 @@ void bt_deserialize<bt_value, void>::operator()(string_view& s, bt_value& val) {
|
|||
break;
|
||||
}
|
||||
case 'i': {
|
||||
auto read = bt_deserialize_integer(s);
|
||||
val = read.first.i64; // We only store an i64, but can get a u64 out of it via get<uint64_t>(val)
|
||||
auto [magnitude, negative] = bt_deserialize_integer(s);
|
||||
if (negative) val = -static_cast<int64_t>(magnitude);
|
||||
else val = magnitude;
|
||||
break;
|
||||
}
|
||||
case '0': case '1': case '2': case '3': case '4': case '5': case '6': case '7': case '8': case '9': {
|
||||
|
@ -137,7 +135,7 @@ void bt_deserialize<bt_value, void>::operator()(string_view& s, bt_value& val) {
|
|||
} // namespace detail
|
||||
|
||||
|
||||
bt_list_consumer::bt_list_consumer(string_view data_) : data{std::move(data_)} {
|
||||
bt_list_consumer::bt_list_consumer(std::string_view data_) : data{std::move(data_)} {
|
||||
if (data.empty()) throw std::runtime_error{"Cannot create a bt_list_consumer with an empty string_view"};
|
||||
if (data[0] != 'l') throw std::runtime_error{"Cannot create a bt_list_consumer with non-list data"};
|
||||
data.remove_prefix(1);
|
||||
|
@ -145,13 +143,13 @@ bt_list_consumer::bt_list_consumer(string_view data_) : data{std::move(data_)} {
|
|||
|
||||
/// Attempt to parse the next value as a string (and advance just past it). Throws if the next
|
||||
/// value is not a string.
|
||||
string_view bt_list_consumer::consume_string_view() {
|
||||
std::string_view bt_list_consumer::consume_string_view() {
|
||||
if (data.empty())
|
||||
throw bt_deserialize_invalid{"expected a string, but reached end of data"};
|
||||
else if (!is_string())
|
||||
throw bt_deserialize_invalid_type{"expected a string, but found "s + data.front()};
|
||||
string_view next{data}, result;
|
||||
detail::bt_deserialize<string_view>{}(next, result);
|
||||
std::string_view next{data}, result;
|
||||
detail::bt_deserialize<std::string_view>{}(next, result);
|
||||
data = next;
|
||||
return result;
|
||||
}
|
||||
|
@ -174,7 +172,7 @@ void bt_list_consumer::skip_value() {
|
|||
throw bt_deserialize_invalid_type{"next bt value has unknown type"};
|
||||
}
|
||||
|
||||
string_view bt_list_consumer::consume_list_data() {
|
||||
std::string_view bt_list_consumer::consume_list_data() {
|
||||
auto start = data.begin();
|
||||
if (data.size() < 2 || !is_list()) throw bt_deserialize_invalid_type{"next bt value is not a list"};
|
||||
data.remove_prefix(1); // Descend into the sublist, consume the "l"
|
||||
|
@ -187,7 +185,7 @@ string_view bt_list_consumer::consume_list_data() {
|
|||
return {start, static_cast<size_t>(std::distance(start, data.begin()))};
|
||||
}
|
||||
|
||||
string_view bt_list_consumer::consume_dict_data() {
|
||||
std::string_view bt_list_consumer::consume_dict_data() {
|
||||
auto start = data.begin();
|
||||
if (data.size() < 2 || !is_dict()) throw bt_deserialize_invalid_type{"next bt value is not a dict"};
|
||||
data.remove_prefix(1); // Descent into the dict, consumer the "d"
|
||||
|
@ -202,7 +200,7 @@ string_view bt_list_consumer::consume_dict_data() {
|
|||
return {start, static_cast<size_t>(std::distance(start, data.begin()))};
|
||||
}
|
||||
|
||||
bt_dict_consumer::bt_dict_consumer(string_view data_) {
|
||||
bt_dict_consumer::bt_dict_consumer(std::string_view data_) {
|
||||
data = std::move(data_);
|
||||
if (data.empty()) throw std::runtime_error{"Cannot create a bt_dict_consumer with an empty string_view"};
|
||||
if (data.size() < 2 || data[0] != 'd') throw std::runtime_error{"Cannot create a bt_dict_consumer with non-dict data"};
|
||||
|
@ -220,10 +218,10 @@ bool bt_dict_consumer::consume_key() {
|
|||
return true;
|
||||
}
|
||||
|
||||
std::pair<string_view, string_view> bt_dict_consumer::next_string() {
|
||||
std::pair<std::string_view, std::string_view> bt_dict_consumer::next_string() {
|
||||
if (!is_string())
|
||||
throw bt_deserialize_invalid_type{"expected a string, but found "s + data.front()};
|
||||
std::pair<string_view, string_view> ret;
|
||||
std::pair<std::string_view, std::string_view> ret;
|
||||
ret.second = bt_list_consumer::consume_string_view();
|
||||
ret.first = flush_key();
|
||||
return ret;
|
||||
|
|
|
@ -26,21 +26,25 @@
|
|||
// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF
|
||||
// THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <list>
|
||||
#include <unordered_map>
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <cstring>
|
||||
#include <ostream>
|
||||
#include <sstream>
|
||||
#include "string_view.h"
|
||||
#include "mapbox/variant.hpp"
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <variant>
|
||||
#include <cstdint>
|
||||
#include <limits>
|
||||
#include <stdexcept>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <tuple>
|
||||
#include <algorithm>
|
||||
|
||||
#include "bt_value.h"
|
||||
|
||||
namespace lokimq {
|
||||
|
||||
|
@ -54,16 +58,17 @@ using namespace std::literals;
|
|||
*
|
||||
* On the C++ side, on input we allow strings, integral types, STL-like containers of these types,
|
||||
* and STL-like containers of pairs with a string first value and any of these types as second
|
||||
* value. We also accept std::variants (if compiled with std::variant support, i.e. in C++17 mode)
|
||||
* that contain any of these, and mapbox::util::variants (the internal type used for its recursive
|
||||
* support).
|
||||
* value. We also accept std::variants of these.
|
||||
*
|
||||
* One minor deviation from BEP-0003 is that we don't support serializing values that don't fit in a
|
||||
* 64-bit integer (BEP-0003 specifies arbitrary precision integers).
|
||||
*
|
||||
* On deserialization we can either deserialize into a mapbox::util::variant that supports everything, or
|
||||
* we can fill a container of your given type (though this fails if the container isn't compatible
|
||||
* with the deserialized data).
|
||||
* On deserialization we can either deserialize into a special bt_value type supports everything
|
||||
* (with arbitrary nesting), or we can fill a container of your given type (though this fails if the
|
||||
* container isn't compatible with the deserialized data).
|
||||
*
|
||||
* There is also a stream deserialization that allows you to deserialize without needing heap
|
||||
* allocations (as long as you know the precise data structure layout).
|
||||
*/
|
||||
|
||||
/// Exception throw if deserialization fails
|
||||
|
@ -80,107 +85,69 @@ class bt_deserialize_invalid_type : public bt_deserialize_invalid {
|
|||
using bt_deserialize_invalid::bt_deserialize_invalid;
|
||||
};
|
||||
|
||||
class bt_list;
|
||||
class bt_dict;
|
||||
|
||||
/// Special type wrapper for storing a uint64_t value that may need to be larger than an int64_t.
|
||||
/// You *can* shove a uint64_t directly into a bt_value, but it will end up on the wire as its
|
||||
/// 2s-complement int64_t value; using this wrapper instead allows you to force a 64-bit positive
|
||||
/// integer onto the wire.
|
||||
struct bt_u64 { uint64_t val; explicit bt_u64(uint64_t val) : val{val} {} };
|
||||
|
||||
/// Recursive generic type that can fully represent everything valid for a BT serialization.
|
||||
using bt_value = mapbox::util::variant<
|
||||
std::string,
|
||||
string_view,
|
||||
int64_t,
|
||||
bt_u64,
|
||||
mapbox::util::recursive_wrapper<bt_list>,
|
||||
mapbox::util::recursive_wrapper<bt_dict>
|
||||
>;
|
||||
|
||||
/// Very thin wrapper around a std::list<bt_value> that holds a list of generic values (though *any*
|
||||
/// compatible data type can be used).
|
||||
class bt_list : public std::list<bt_value> {
|
||||
using std::list<bt_value>::list;
|
||||
};
|
||||
/// Very thin wrapper around a std::unordered_map<bt_value> that holds a list of string -> generic
|
||||
/// value pairs (though *any* compatible data type can be used).
|
||||
class bt_dict : public std::unordered_map<std::string, bt_value> {
|
||||
using std::unordered_map<std::string, bt_value>::unordered_map;
|
||||
};
|
||||
|
||||
#ifdef __cpp_lib_void_t
|
||||
using std::void_t;
|
||||
#else
|
||||
/// C++17 void_t backport
|
||||
template <typename... Ts> struct void_t_impl { using type = void; };
|
||||
template <typename... Ts> using void_t = typename void_t_impl<Ts...>::type;
|
||||
#endif
|
||||
|
||||
namespace detail {
|
||||
|
||||
/// Reads digits into an unsigned 64-bit int.
|
||||
uint64_t extract_unsigned(string_view& s);
|
||||
inline uint64_t extract_unsigned(string_view&& s) { return extract_unsigned(s); }
|
||||
uint64_t extract_unsigned(std::string_view& s);
|
||||
// (Provide non-constant lvalue and rvalue ref functions so that we only accept explicit
|
||||
// string_views but not implicitly converted ones)
|
||||
inline uint64_t extract_unsigned(std::string_view&& s) { return extract_unsigned(s); }
|
||||
|
||||
// Fallback base case; we only get here if none of the partial specializations below work
|
||||
template <typename T, typename SFINAE = void>
|
||||
struct bt_serialize { static_assert(!std::is_same<T, T>::value, "Cannot serialize T: unsupported type for bt serialization"); };
|
||||
struct bt_serialize { static_assert(!std::is_same_v<T, T>, "Cannot serialize T: unsupported type for bt serialization"); };
|
||||
|
||||
template <typename T, typename SFINAE = void>
|
||||
struct bt_deserialize { static_assert(!std::is_same<T, T>::value, "Cannot deserialize T: unsupported type for bt deserialization"); };
|
||||
struct bt_deserialize { static_assert(!std::is_same_v<T, T>, "Cannot deserialize T: unsupported type for bt deserialization"); };
|
||||
|
||||
/// Checks that we aren't at the end of a string view and throws if we are.
|
||||
inline void bt_need_more(const string_view &s) {
|
||||
inline void bt_need_more(const std::string_view &s) {
|
||||
if (s.empty())
|
||||
throw bt_deserialize_invalid{"Unexpected end of string while deserializing"};
|
||||
}
|
||||
|
||||
union maybe_signed_int64_t { int64_t i64; uint64_t u64; };
|
||||
|
||||
/// Deserializes a signed or unsigned 64-bit integer from a string. Sets the second bool to true
|
||||
/// iff the value is int64_t because a negative value was read. Throws an exception if the read
|
||||
/// value doesn't fit in a int64_t (if negative) or a uint64_t (if positive). Removes consumed
|
||||
/// characters from the string_view.
|
||||
std::pair<maybe_signed_int64_t, bool> bt_deserialize_integer(string_view& s);
|
||||
/// iff the value read was negative, false if positive; in either case the unsigned value is return
|
||||
/// in .first. Throws an exception if the read value doesn't fit in a int64_t (if negative) or a
|
||||
/// uint64_t (if positive). Removes consumed characters from the string_view.
|
||||
std::pair<uint64_t, bool> bt_deserialize_integer(std::string_view& s);
|
||||
|
||||
/// Integer specializations
|
||||
template <typename T>
|
||||
struct bt_serialize<T, std::enable_if_t<std::is_integral<T>::value>> {
|
||||
struct bt_serialize<T, std::enable_if_t<std::is_integral_v<T>>> {
|
||||
static_assert(sizeof(T) <= sizeof(uint64_t), "Serialization of integers larger than uint64_t is not supported");
|
||||
void operator()(std::ostream &os, const T &val) {
|
||||
// Cast 1-byte types to a larger type to avoid iostream interpreting them as single characters
|
||||
using output_type = std::conditional_t<(sizeof(T) > 1), T, std::conditional_t<std::is_signed<T>::value, int, unsigned>>;
|
||||
using output_type = std::conditional_t<(sizeof(T) > 1), T, std::conditional_t<std::is_signed_v<T>, int, unsigned>>;
|
||||
os << 'i' << static_cast<output_type>(val) << 'e';
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct bt_deserialize<T, std::enable_if_t<std::is_integral<T>::value>> {
|
||||
void operator()(string_view& s, T &val) {
|
||||
struct bt_deserialize<T, std::enable_if_t<std::is_integral_v<T>>> {
|
||||
void operator()(std::string_view& s, T &val) {
|
||||
constexpr uint64_t umax = static_cast<uint64_t>(std::numeric_limits<T>::max());
|
||||
constexpr int64_t smin = static_cast<int64_t>(std::numeric_limits<T>::min()),
|
||||
smax = static_cast<int64_t>(std::numeric_limits<T>::max());
|
||||
constexpr int64_t smin = static_cast<int64_t>(std::numeric_limits<T>::min());
|
||||
|
||||
auto read = bt_deserialize_integer(s);
|
||||
if (std::is_signed<T>::value) {
|
||||
if (!read.second) { // read a positive value
|
||||
if (read.first.u64 > umax)
|
||||
throw bt_deserialize_invalid("Integer deserialization failed: found too-large value " + std::to_string(read.first.u64) + " > " + std::to_string(umax));
|
||||
val = static_cast<T>(read.first.u64);
|
||||
auto [magnitude, negative] = bt_deserialize_integer(s);
|
||||
|
||||
if (std::is_signed_v<T>) {
|
||||
if (!negative) {
|
||||
if (magnitude > umax)
|
||||
throw bt_deserialize_invalid("Integer deserialization failed: found too-large value " + std::to_string(magnitude) + " > " + std::to_string(umax));
|
||||
val = static_cast<T>(magnitude);
|
||||
} else {
|
||||
bool oob = read.first.i64 < smin || read.first.i64 > smax;
|
||||
if (sizeof(T) < sizeof(int64_t) && oob)
|
||||
throw bt_deserialize_invalid("Integer deserialization failed: found out-of-range value " + std::to_string(read.first.i64) + " not in [" + std::to_string(smin) + "," + std::to_string(smax) + "]");
|
||||
val = static_cast<T>(read.first.i64);
|
||||
auto sval = -static_cast<int64_t>(magnitude);
|
||||
if (!std::is_same_v<T, int64_t> && sval < smin)
|
||||
throw bt_deserialize_invalid("Integer deserialization failed: found too-low value " + std::to_string(sval) + " < " + std::to_string(smin));
|
||||
val = static_cast<T>(sval);
|
||||
}
|
||||
} else {
|
||||
if (read.second)
|
||||
throw bt_deserialize_invalid("Integer deserialization failed: found negative value " + std::to_string(read.first.i64) + " but type is unsigned");
|
||||
if (sizeof(T) < sizeof(uint64_t) && read.first.u64 > umax)
|
||||
throw bt_deserialize_invalid("Integer deserialization failed: found too-large value " + std::to_string(read.first.u64) + " > " + std::to_string(umax));
|
||||
val = static_cast<T>(read.first.u64);
|
||||
if (negative)
|
||||
throw bt_deserialize_invalid("Integer deserialization failed: found negative value -" + std::to_string(magnitude) + " but type is unsigned");
|
||||
if (!std::is_same_v<T, uint64_t> && magnitude > umax)
|
||||
throw bt_deserialize_invalid("Integer deserialization failed: found too-large value " + std::to_string(magnitude) + " > " + std::to_string(umax));
|
||||
val = static_cast<T>(magnitude);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
@ -188,73 +155,78 @@ struct bt_deserialize<T, std::enable_if_t<std::is_integral<T>::value>> {
|
|||
extern template struct bt_deserialize<int64_t>;
|
||||
extern template struct bt_deserialize<uint64_t>;
|
||||
|
||||
template<>
|
||||
struct bt_serialize<bt_u64> { void operator()(std::ostream& os, bt_u64 val) { bt_serialize<uint64_t>{}(os, val.val); } };
|
||||
template<>
|
||||
struct bt_deserialize<bt_u64> { void operator()(string_view& s, bt_u64& val) { bt_deserialize<uint64_t>{}(s, val.val); } };
|
||||
|
||||
template <>
|
||||
struct bt_serialize<string_view> {
|
||||
void operator()(std::ostream &os, const string_view &val) { os << val.size(); os.put(':'); os.write(val.data(), val.size()); }
|
||||
struct bt_serialize<std::string_view> {
|
||||
void operator()(std::ostream &os, const std::string_view &val) { os << val.size(); os.put(':'); os.write(val.data(), val.size()); }
|
||||
};
|
||||
template <>
|
||||
struct bt_deserialize<string_view> {
|
||||
void operator()(string_view& s, string_view& val);
|
||||
struct bt_deserialize<std::string_view> {
|
||||
void operator()(std::string_view& s, std::string_view& val);
|
||||
};
|
||||
|
||||
/// String specialization
|
||||
template <>
|
||||
struct bt_serialize<std::string> {
|
||||
void operator()(std::ostream &os, const std::string &val) { bt_serialize<string_view>{}(os, val); }
|
||||
void operator()(std::ostream &os, const std::string &val) { bt_serialize<std::string_view>{}(os, val); }
|
||||
};
|
||||
template <>
|
||||
struct bt_deserialize<std::string> {
|
||||
void operator()(string_view& s, std::string& val) { string_view view; bt_deserialize<string_view>{}(s, view); val = {view.data(), view.size()}; }
|
||||
void operator()(std::string_view& s, std::string& val) { std::string_view view; bt_deserialize<std::string_view>{}(s, view); val = {view.data(), view.size()}; }
|
||||
};
|
||||
|
||||
/// char * and string literals -- we allow serialization for convenience, but not deserialization
|
||||
template <>
|
||||
struct bt_serialize<char *> {
|
||||
void operator()(std::ostream &os, const char *str) { bt_serialize<string_view>{}(os, {str, std::strlen(str)}); }
|
||||
void operator()(std::ostream &os, const char *str) { bt_serialize<std::string_view>{}(os, {str, std::strlen(str)}); }
|
||||
};
|
||||
template <size_t N>
|
||||
struct bt_serialize<char[N]> {
|
||||
void operator()(std::ostream &os, const char *str) { bt_serialize<string_view>{}(os, {str, N-1}); }
|
||||
void operator()(std::ostream &os, const char *str) { bt_serialize<std::string_view>{}(os, {str, N-1}); }
|
||||
};
|
||||
|
||||
/// Partial dict validity; we don't check the second type for serializability, that will be handled
|
||||
/// via the base case static_assert if invalid.
|
||||
template <typename T, typename = void> struct is_bt_input_dict_container : std::false_type {};
|
||||
template <typename T, typename = void> struct is_bt_input_dict_container_impl : std::false_type {};
|
||||
template <typename T>
|
||||
struct is_bt_input_dict_container<T, std::enable_if_t<
|
||||
std::is_same<std::string, std::remove_cv_t<typename T::value_type::first_type>>::value,
|
||||
void_t<typename T::const_iterator /* is const iterable */,
|
||||
struct is_bt_input_dict_container_impl<T, std::enable_if_t<
|
||||
std::is_same_v<std::string, std::remove_cv_t<typename T::value_type::first_type>> ||
|
||||
std::is_same_v<std::string_view, std::remove_cv_t<typename T::value_type::first_type>>,
|
||||
std::void_t<typename T::const_iterator /* is const iterable */,
|
||||
typename T::value_type::second_type /* has a second type */>>>
|
||||
: std::true_type {};
|
||||
|
||||
/// Determines whether the type looks like something we can insert into (using `v.insert(v.end(), x)`)
|
||||
template <typename T, typename = void> struct is_bt_insertable : std::false_type {};
|
||||
template <typename T, typename = void> struct is_bt_insertable_impl : std::false_type {};
|
||||
template <typename T>
|
||||
struct is_bt_insertable<T,
|
||||
void_t<decltype(std::declval<T>().insert(std::declval<T>().end(), std::declval<typename T::value_type>()))>>
|
||||
struct is_bt_insertable_impl<T,
|
||||
std::void_t<decltype(std::declval<T>().insert(std::declval<T>().end(), std::declval<typename T::value_type>()))>>
|
||||
: std::true_type {};
|
||||
template <typename T>
|
||||
constexpr bool is_bt_insertable = is_bt_insertable_impl<T>::value;
|
||||
|
||||
/// Determines whether the given type looks like a compatible map (i.e. has std::string keys) that
|
||||
/// we can insert into.
|
||||
template <typename T, typename = void> struct is_bt_output_dict_container : std::false_type {};
|
||||
template <typename T, typename = void> struct is_bt_output_dict_container_impl : std::false_type {};
|
||||
template <typename T>
|
||||
struct is_bt_output_dict_container<T, std::enable_if_t<
|
||||
std::is_same<std::string, std::remove_cv_t<typename T::key_type>>::value &&
|
||||
is_bt_insertable<T>::value,
|
||||
void_t<typename T::value_type::second_type /* has a second type */>>>
|
||||
struct is_bt_output_dict_container_impl<T, std::enable_if_t<
|
||||
std::is_same_v<std::string, std::remove_cv_t<typename T::value_type::first_type>> && is_bt_insertable<T>,
|
||||
std::void_t<typename T::value_type::second_type /* has a second type */>>>
|
||||
: std::true_type {};
|
||||
|
||||
template <typename T>
|
||||
constexpr bool is_bt_output_dict_container = is_bt_output_dict_container_impl<T>::value;
|
||||
template <typename T>
|
||||
constexpr bool is_bt_input_dict_container = is_bt_output_dict_container_impl<T>::value;
|
||||
|
||||
// Sanity checks:
|
||||
static_assert(is_bt_input_dict_container<bt_dict>);
|
||||
static_assert(is_bt_output_dict_container<bt_dict>);
|
||||
|
||||
/// Specialization for a dict-like container (such as an unordered_map). We accept anything for a
|
||||
/// dict that is const iterable over something that looks like a pair with std::string for first
|
||||
/// value type. The value (i.e. second element of the pair) also must be serializable.
|
||||
template <typename T>
|
||||
struct bt_serialize<T, std::enable_if_t<is_bt_input_dict_container<T>::value>> {
|
||||
struct bt_serialize<T, std::enable_if_t<is_bt_input_dict_container<T>>> {
|
||||
using second_type = typename T::value_type::second_type;
|
||||
using ref_pair = std::reference_wrapper<const typename T::value_type>;
|
||||
void operator()(std::ostream &os, const T &dict) {
|
||||
|
@ -273,9 +245,9 @@ struct bt_serialize<T, std::enable_if_t<is_bt_input_dict_container<T>::value>> {
|
|||
};
|
||||
|
||||
template <typename T>
|
||||
struct bt_deserialize<T, std::enable_if_t<is_bt_output_dict_container<T>::value>> {
|
||||
struct bt_deserialize<T, std::enable_if_t<is_bt_output_dict_container<T>>> {
|
||||
using second_type = typename T::value_type::second_type;
|
||||
void operator()(string_view& s, T& dict) {
|
||||
void operator()(std::string_view& s, T& dict) {
|
||||
// Smallest dict is 2 bytes "de", for an empty dict.
|
||||
if (s.size() < 2) throw bt_deserialize_invalid("Deserialization failed: end of string found where dict expected");
|
||||
if (s[0] != 'd') throw bt_deserialize_invalid_type("Deserialization failed: expected 'd', found '"s + s[0] + "'"s);
|
||||
|
@ -300,26 +272,31 @@ struct bt_deserialize<T, std::enable_if_t<is_bt_output_dict_container<T>::value>
|
|||
|
||||
/// Accept anything that looks iterable; value serialization validity isn't checked here (it fails
|
||||
/// via the base case static assert).
|
||||
template <typename T, typename = void> struct is_bt_input_list_container : std::false_type {};
|
||||
template <typename T, typename = void> struct is_bt_input_list_container_impl : std::false_type {};
|
||||
template <typename T>
|
||||
struct is_bt_input_list_container<T, std::enable_if_t<
|
||||
!std::is_same<T, std::string>::value &&
|
||||
!is_bt_input_dict_container<T>::value,
|
||||
void_t<typename T::const_iterator, typename T::value_type>>>
|
||||
struct is_bt_input_list_container_impl<T, std::enable_if_t<
|
||||
!std::is_same_v<T, std::string> && !std::is_same_v<T, std::string_view> && !is_bt_input_dict_container<T>,
|
||||
std::void_t<typename T::const_iterator, typename T::value_type>>>
|
||||
: std::true_type {};
|
||||
|
||||
template <typename T, typename = void> struct is_bt_output_list_container : std::false_type {};
|
||||
template <typename T, typename = void> struct is_bt_output_list_container_impl : std::false_type {};
|
||||
template <typename T>
|
||||
struct is_bt_output_list_container<T, std::enable_if_t<
|
||||
!std::is_same<T, std::string>::value &&
|
||||
!is_bt_output_dict_container<T>::value &&
|
||||
is_bt_insertable<T>::value>>
|
||||
struct is_bt_output_list_container_impl<T, std::enable_if_t<
|
||||
!std::is_same_v<T, std::string> && !is_bt_output_dict_container<T> && is_bt_insertable<T>>>
|
||||
: std::true_type {};
|
||||
|
||||
template <typename T>
|
||||
constexpr bool is_bt_output_list_container = is_bt_output_list_container_impl<T>::value;
|
||||
template <typename T>
|
||||
constexpr bool is_bt_input_list_container = is_bt_input_list_container_impl<T>::value;
|
||||
|
||||
// Sanity checks:
|
||||
static_assert(is_bt_input_list_container<bt_list>);
|
||||
static_assert(is_bt_output_list_container<bt_list>);
|
||||
|
||||
/// List specialization
|
||||
template <typename T>
|
||||
struct bt_serialize<T, std::enable_if_t<is_bt_input_list_container<T>::value>> {
|
||||
struct bt_serialize<T, std::enable_if_t<is_bt_input_list_container<T>>> {
|
||||
void operator()(std::ostream& os, const T& list) {
|
||||
os << 'l';
|
||||
for (const auto &v : list)
|
||||
|
@ -328,9 +305,9 @@ struct bt_serialize<T, std::enable_if_t<is_bt_input_list_container<T>::value>> {
|
|||
}
|
||||
};
|
||||
template <typename T>
|
||||
struct bt_deserialize<T, std::enable_if_t<is_bt_output_list_container<T>::value>> {
|
||||
struct bt_deserialize<T, std::enable_if_t<is_bt_output_list_container<T>>> {
|
||||
using value_type = typename T::value_type;
|
||||
void operator()(string_view& s, T& list) {
|
||||
void operator()(std::string_view& s, T& list) {
|
||||
// Smallest list is 2 bytes "le", for an empty list.
|
||||
if (s.size() < 2) throw bt_deserialize_invalid("Deserialization failed: end of string found where list expected");
|
||||
if (s[0] != 'l') throw bt_deserialize_invalid_type("Deserialization failed: expected 'l', found '"s + s[0] + "'"s);
|
||||
|
@ -348,44 +325,88 @@ struct bt_deserialize<T, std::enable_if_t<is_bt_output_list_container<T>::value>
|
|||
}
|
||||
};
|
||||
|
||||
/// variant visitor; serializes whatever is contained
|
||||
class bt_serialize_visitor {
|
||||
std::ostream &os;
|
||||
/// Serializes a tuple or pair of serializable values (as a list on the wire)
|
||||
|
||||
/// Common implementation for both tuple and pair:
|
||||
template <template<typename...> typename Tuple, typename... T>
|
||||
struct bt_serialize_tuple {
|
||||
private:
|
||||
template <size_t... Is>
|
||||
void operator()(std::ostream& os, const Tuple<T...>& elems, std::index_sequence<Is...>) {
|
||||
os << 'l';
|
||||
(bt_serialize<T>{}(os, std::get<Is>(elems)), ...);
|
||||
os << 'e';
|
||||
}
|
||||
public:
|
||||
using result_type = void;
|
||||
bt_serialize_visitor(std::ostream &os) : os{os} {}
|
||||
template <typename T> void operator()(const T &val) const {
|
||||
bt_serialize<T>{}(os, val);
|
||||
void operator()(std::ostream& os, const Tuple<T...>& elems) {
|
||||
operator()(os, elems, std::index_sequence_for<T...>{});
|
||||
}
|
||||
};
|
||||
template <template<typename...> typename Tuple, typename... T>
|
||||
struct bt_deserialize_tuple {
|
||||
private:
|
||||
template <size_t... Is>
|
||||
void operator()(std::string_view& s, Tuple<T...>& elems, std::index_sequence<Is...>) {
|
||||
// Smallest list is 2 bytes "le", for an empty list.
|
||||
if (s.size() < 2) throw bt_deserialize_invalid("Deserialization failed: end of string found where tuple expected");
|
||||
if (s[0] != 'l') throw bt_deserialize_invalid_type("Deserialization of tuple failed: expected 'l', found '"s + s[0] + "'"s);
|
||||
s.remove_prefix(1);
|
||||
(bt_deserialize<T>{}(s, std::get<Is>(elems)), ...);
|
||||
if (s.empty())
|
||||
throw bt_deserialize_invalid("Deserialization failed: encountered end of string before tuple was finished");
|
||||
if (s[0] != 'e')
|
||||
throw bt_deserialize_invalid("Deserialization failed: expected end of tuple but found something else");
|
||||
s.remove_prefix(1); // Consume the 'e'
|
||||
}
|
||||
public:
|
||||
void operator()(std::string_view& s, Tuple<T...>& elems) {
|
||||
operator()(s, elems, std::index_sequence_for<T...>{});
|
||||
}
|
||||
};
|
||||
template <typename... T>
|
||||
struct bt_serialize<std::tuple<T...>> : bt_serialize_tuple<std::tuple, T...> {};
|
||||
template <typename... T>
|
||||
struct bt_deserialize<std::tuple<T...>> : bt_deserialize_tuple<std::tuple, T...> {};
|
||||
template <typename S, typename T>
|
||||
struct bt_serialize<std::pair<S, T>> : bt_serialize_tuple<std::pair, S, T> {};
|
||||
template <typename S, typename T>
|
||||
struct bt_deserialize<std::pair<S, T>> : bt_deserialize_tuple<std::pair, S, T> {};
|
||||
|
||||
template <typename T>
|
||||
using is_bt_deserializable = std::integral_constant<bool,
|
||||
std::is_same<T, std::string>::value || std::is_integral<T>::value ||
|
||||
is_bt_output_dict_container<T>::value || is_bt_output_list_container<T>::value>;
|
||||
constexpr bool is_bt_tuple = false;
|
||||
template <typename... T>
|
||||
constexpr bool is_bt_tuple<std::tuple<T...>> = true;
|
||||
template <typename S, typename T>
|
||||
constexpr bool is_bt_tuple<std::pair<S, T>> = true;
|
||||
|
||||
|
||||
template <typename T>
|
||||
constexpr bool is_bt_deserializable = std::is_same_v<T, std::string> || std::is_integral_v<T> ||
|
||||
is_bt_output_dict_container<T> || is_bt_output_list_container<T> || is_bt_tuple<T>;
|
||||
|
||||
// General template and base case; this base will only actually be invoked when Ts... is empty,
|
||||
// which means we reached the end without finding any variant type capable of holding the value.
|
||||
template <typename SFINAE, typename Variant, typename... Ts>
|
||||
struct bt_deserialize_try_variant_impl {
|
||||
void operator()(string_view&, Variant&) {
|
||||
void operator()(std::string_view&, Variant&) {
|
||||
throw bt_deserialize_invalid("Deserialization failed: could not deserialize value into any variant type");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename... Ts, typename Variant>
|
||||
void bt_deserialize_try_variant(string_view& s, Variant& variant) {
|
||||
void bt_deserialize_try_variant(std::string_view& s, Variant& variant) {
|
||||
bt_deserialize_try_variant_impl<void, Variant, Ts...>{}(s, variant);
|
||||
}
|
||||
|
||||
|
||||
template <typename Variant, typename T, typename... Ts>
|
||||
struct bt_deserialize_try_variant_impl<std::enable_if_t<is_bt_deserializable<T>::value>, Variant, T, Ts...> {
|
||||
void operator()(string_view& s, Variant& variant) {
|
||||
if ( is_bt_output_list_container<T>::value ? s[0] == 'l' :
|
||||
is_bt_output_dict_container<T>::value ? s[0] == 'd' :
|
||||
std::is_integral<T>::value ? s[0] == 'i' :
|
||||
std::is_same<T, std::string>::value ? s[0] >= '0' && s[0] <= '9' :
|
||||
struct bt_deserialize_try_variant_impl<std::enable_if_t<is_bt_deserializable<T>>, Variant, T, Ts...> {
|
||||
void operator()(std::string_view& s, Variant& variant) {
|
||||
if ( is_bt_output_list_container<T> ? s[0] == 'l' :
|
||||
is_bt_tuple<T> ? s[0] == 'l' :
|
||||
is_bt_output_dict_container<T> ? s[0] == 'd' :
|
||||
std::is_integral_v<T> ? s[0] == 'i' :
|
||||
std::is_same_v<T, std::string> ? s[0] >= '0' && s[0] <= '9' :
|
||||
false) {
|
||||
T val;
|
||||
bt_deserialize<T>{}(s, val);
|
||||
|
@ -397,49 +418,41 @@ struct bt_deserialize_try_variant_impl<std::enable_if_t<is_bt_deserializable<T>:
|
|||
};
|
||||
|
||||
template <typename Variant, typename T, typename... Ts>
|
||||
struct bt_deserialize_try_variant_impl<std::enable_if_t<!is_bt_deserializable<T>::value>, Variant, T, Ts...> {
|
||||
void operator()(string_view& s, Variant& variant) {
|
||||
struct bt_deserialize_try_variant_impl<std::enable_if_t<!is_bt_deserializable<T>>, Variant, T, Ts...> {
|
||||
void operator()(std::string_view& s, Variant& variant) {
|
||||
// Unsupported deserialization type, skip it
|
||||
bt_deserialize_try_variant<Ts...>(s, variant);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct bt_deserialize<bt_value, void> {
|
||||
void operator()(string_view& s, bt_value& val);
|
||||
};
|
||||
|
||||
// Serialization of a variant; all variant types must be bt-serializable.
|
||||
template <typename... Ts>
|
||||
struct bt_serialize<mapbox::util::variant<Ts...>> {
|
||||
void operator()(std::ostream& os, const mapbox::util::variant<Ts...>& val) {
|
||||
mapbox::util::apply_visitor(bt_serialize_visitor{os}, val);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename... Ts>
|
||||
struct bt_deserialize<mapbox::util::variant<Ts...>> {
|
||||
void operator()(string_view& s, mapbox::util::variant<Ts...>& val) {
|
||||
bt_deserialize_try_variant<Ts...>(s, val);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
#ifdef __cpp_lib_variant
|
||||
/// C++17 std::variant support
|
||||
template <typename... Ts>
|
||||
struct bt_serialize<std::variant<Ts...>> {
|
||||
struct bt_serialize<std::variant<Ts...>, std::void_t<bt_serialize<Ts>...>> {
|
||||
void operator()(std::ostream &os, const std::variant<Ts...>& val) {
|
||||
mapbox::util::apply_visitor(bt_serialize_visitor{os}, val);
|
||||
std::visit(
|
||||
[&os] (const auto& val) {
|
||||
using T = std::remove_cv_t<std::remove_reference_t<decltype(val)>>;
|
||||
bt_serialize<T>{}(os, val);
|
||||
},
|
||||
val);
|
||||
}
|
||||
};
|
||||
|
||||
// Deserialization to a variant; at least one variant type must be bt-deserializble.
|
||||
template <typename... Ts>
|
||||
struct bt_deserialize<std::variant<Ts...>> {
|
||||
void operator()(string_view& s, std::variant<Ts...>& val) {
|
||||
struct bt_deserialize<std::variant<Ts...>, std::enable_if_t<(is_bt_deserializable<Ts> || ...)>> {
|
||||
void operator()(std::string_view& s, std::variant<Ts...>& val) {
|
||||
bt_deserialize_try_variant<Ts...>(s, val);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
template <>
|
||||
struct bt_serialize<bt_value> : bt_serialize<bt_variant> {};
|
||||
|
||||
template <>
|
||||
struct bt_deserialize<bt_value> {
|
||||
void operator()(std::string_view& s, bt_value& val);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct bt_stream_serializer {
|
||||
|
@ -500,8 +513,8 @@ std::string bt_serialize(const T &val) { return bt_serializer(val); }
|
|||
/// int value;
|
||||
/// bt_deserialize(encoded, value); // Sets value to 42
|
||||
///
|
||||
template <typename T, std::enable_if_t<!std::is_const<T>::value, int> = 0>
|
||||
void bt_deserialize(string_view s, T& val) {
|
||||
template <typename T, std::enable_if_t<!std::is_const_v<T>, int> = 0>
|
||||
void bt_deserialize(std::string_view s, T& val) {
|
||||
return detail::bt_deserialize<T>{}(s, val);
|
||||
}
|
||||
|
||||
|
@ -512,14 +525,14 @@ void bt_deserialize(string_view s, T& val) {
|
|||
/// auto mylist = bt_deserialize<std::list<int>>(encoded);
|
||||
///
|
||||
template <typename T>
|
||||
T bt_deserialize(string_view s) {
|
||||
T bt_deserialize(std::string_view s) {
|
||||
T val;
|
||||
bt_deserialize(s, val);
|
||||
return val;
|
||||
}
|
||||
|
||||
/// Deserializes the given value into a generic `bt_value` type (mapbox::util::variant) which is capable
|
||||
/// of holding all possible BT-encoded values (including recursion).
|
||||
/// Deserializes the given value into a generic `bt_value` type (wrapped std::variant) which is
|
||||
/// capable of holding all possible BT-encoded values (including recursion).
|
||||
///
|
||||
/// Example:
|
||||
///
|
||||
|
@ -528,50 +541,107 @@ T bt_deserialize(string_view s) {
|
|||
/// int v = get_int<int>(val); // fails unless the encoded value was actually an integer that
|
||||
/// // fits into an `int`
|
||||
///
|
||||
inline bt_value bt_get(string_view s) {
|
||||
inline bt_value bt_get(std::string_view s) {
|
||||
return bt_deserialize<bt_value>(s);
|
||||
}
|
||||
|
||||
/// Helper functions to extract a value of some integral type from a bt_value which contains an
|
||||
/// integer. Does range checking, throwing std::overflow_error if the stored value is outside the
|
||||
/// range of the target type.
|
||||
/// Helper functions to extract a value of some integral type from a bt_value which contains either
|
||||
/// a int64_t or uint64_t. Does range checking, throwing std::overflow_error if the stored value is
|
||||
/// outside the range of the target type.
|
||||
///
|
||||
/// Example:
|
||||
///
|
||||
/// std::string encoded = "i123456789e";
|
||||
/// auto val = bt_get(encoded);
|
||||
/// auto v = get_int<uint32_t>(val); // throws if the decoded value doesn't fit in a uint32_t
|
||||
template <typename IntType, std::enable_if_t<std::is_integral<IntType>::value, int> = 0>
|
||||
template <typename IntType, std::enable_if_t<std::is_integral_v<IntType>, int> = 0>
|
||||
IntType get_int(const bt_value &v) {
|
||||
// It's highly unlikely that this code ever runs on a non-2s-complement architecture, but check
|
||||
// at compile time if converting to a uint64_t (because while int64_t -> uint64_t is
|
||||
// well-defined, uint64_t -> int64_t only does the right thing under 2's complement).
|
||||
static_assert(!std::is_unsigned<IntType>::value || sizeof(IntType) != sizeof(int64_t) || -1 == ~0,
|
||||
"Non 2s-complement architecture not supported!");
|
||||
int64_t value = mapbox::util::get<int64_t>(v);
|
||||
if (sizeof(IntType) < sizeof(int64_t)) {
|
||||
if (std::holds_alternative<uint64_t>(v)) {
|
||||
uint64_t value = std::get<uint64_t>(v);
|
||||
if constexpr (!std::is_same_v<IntType, uint64_t>)
|
||||
if (value > static_cast<uint64_t>(std::numeric_limits<IntType>::max()))
|
||||
throw std::overflow_error("Unable to extract integer value: stored value is too large for the requested type");
|
||||
return static_cast<IntType>(value);
|
||||
}
|
||||
|
||||
int64_t value = std::get<int64_t>(v);
|
||||
if constexpr (!std::is_same_v<IntType, int64_t>)
|
||||
if (value > static_cast<int64_t>(std::numeric_limits<IntType>::max())
|
||||
|| value < static_cast<int64_t>(std::numeric_limits<IntType>::min()))
|
||||
throw std::overflow_error("Unable to extract integer value: stored value is outside the range of the requested type");
|
||||
}
|
||||
return static_cast<IntType>(value);
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
template <typename Tuple, size_t... Is>
|
||||
void get_tuple_impl(Tuple& t, const bt_list& l, std::index_sequence<Is...>);
|
||||
}
|
||||
|
||||
/// Converts a bt_list into the given template std::tuple or std::pair. Throws a
|
||||
/// std::invalid_argument if the list has the wrong size or wrong element types. Supports recursion
|
||||
/// (i.e. if the tuple itself contains tuples or pairs). The tuple (or nested tuples) may only
|
||||
/// contain integral types, strings, string_views, bt_list, bt_dict, and tuples/pairs of those.
|
||||
template <typename Tuple>
|
||||
Tuple get_tuple(const bt_list& x) {
|
||||
Tuple t;
|
||||
detail::get_tuple_impl(t, x, std::make_index_sequence<std::tuple_size_v<Tuple>>{});
|
||||
return t;
|
||||
}
|
||||
template <typename Tuple>
|
||||
Tuple get_tuple(const bt_value& x) {
|
||||
return get_tuple<Tuple>(std::get<bt_list>(static_cast<const bt_variant&>(x)));
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
template <typename T, typename It>
|
||||
void get_tuple_impl_one(T& t, It& it) {
|
||||
const bt_variant& v = *it++;
|
||||
if constexpr (std::is_integral_v<T>) {
|
||||
t = lokimq::get_int<T>(v);
|
||||
} else if constexpr (is_bt_tuple<T>) {
|
||||
if (std::holds_alternative<bt_list>(v))
|
||||
throw std::invalid_argument{"Unable to convert tuple: cannot create sub-tuple from non-bt_list"};
|
||||
t = get_tuple<T>(std::get<bt_list>(v));
|
||||
} else if constexpr (std::is_same_v<std::string, T> || std::is_same_v<std::string_view, T>) {
|
||||
// If we request a string/string_view, we might have the other one and need to copy/view it.
|
||||
if (std::holds_alternative<std::string_view>(v))
|
||||
t = std::get<std::string_view>(v);
|
||||
else
|
||||
t = std::get<std::string>(v);
|
||||
} else {
|
||||
t = std::get<T>(v);
|
||||
}
|
||||
}
|
||||
template <typename Tuple, size_t... Is>
|
||||
void get_tuple_impl(Tuple& t, const bt_list& l, std::index_sequence<Is...>) {
|
||||
if (l.size() != sizeof...(Is))
|
||||
throw std::invalid_argument{"Unable to convert tuple: bt_list has wrong size"};
|
||||
auto it = l.begin();
|
||||
(get_tuple_impl_one(std::get<Is>(t), it), ...);
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
|
||||
|
||||
|
||||
/// Class that allows you to walk through a bt-encoded list in memory without copying or allocating
|
||||
/// memory. It accesses existing memory directly and so the caller must ensure that the referenced
|
||||
/// memory stays valid for the lifetime of the bt_list_consumer object.
|
||||
class bt_list_consumer {
|
||||
protected:
|
||||
string_view data;
|
||||
std::string_view data;
|
||||
bt_list_consumer() = default;
|
||||
public:
|
||||
bt_list_consumer(string_view data_);
|
||||
bt_list_consumer(std::string_view data_);
|
||||
|
||||
/// Copy constructor. Making a copy copies the current position so can be used for multipass
|
||||
/// iteration through a list.
|
||||
bt_list_consumer(const bt_list_consumer&) = default;
|
||||
bt_list_consumer& operator=(const bt_list_consumer&) = default;
|
||||
|
||||
/// Get a copy of the current buffer
|
||||
std::string_view current_buffer() const { return data; }
|
||||
|
||||
/// Returns true if the next value indicates the end of the list
|
||||
bool is_finished() const { return data.front() == 'e'; }
|
||||
/// Returns true if the next element looks like an encoded string
|
||||
|
@ -588,23 +658,24 @@ public:
|
|||
/// Attempt to parse the next value as a string (and advance just past it). Throws if the next
|
||||
/// value is not a string.
|
||||
std::string consume_string();
|
||||
string_view consume_string_view();
|
||||
std::string_view consume_string_view();
|
||||
|
||||
/// Attempts to parse the next value as an integer (and advance just past it). Throws if the
|
||||
/// next value is not an integer.
|
||||
template <typename IntType>
|
||||
IntType consume_integer() {
|
||||
if (!is_integer()) throw bt_deserialize_invalid_type{"next value is not an integer"};
|
||||
string_view next{data};
|
||||
std::string_view next{data};
|
||||
IntType ret;
|
||||
detail::bt_deserialize<IntType>{}(next, ret);
|
||||
data = next;
|
||||
return ret;
|
||||
}
|
||||
|
||||
/// Consumes a list, return it as a list-like type. This typically requires dynamic allocation,
|
||||
/// but only has to parse the data once. Compare with consume_list_data() which allows
|
||||
/// alloc-free traversal, but requires parsing twice (if the contents are to be used).
|
||||
/// Consumes a list, return it as a list-like type. Can also be used for tuples/pairs. This
|
||||
/// typically requires dynamic allocation, but only has to parse the data once. Compare with
|
||||
/// consume_list_data() which allows alloc-free traversal, but requires parsing twice (if the
|
||||
/// contents are to be used).
|
||||
template <typename T = bt_list>
|
||||
T consume_list() {
|
||||
T list;
|
||||
|
@ -616,7 +687,7 @@ public:
|
|||
template <typename T>
|
||||
void consume_list(T& list) {
|
||||
if (!is_list()) throw bt_deserialize_invalid_type{"next bt value is not a list"};
|
||||
string_view n{data};
|
||||
std::string_view n{data};
|
||||
detail::bt_deserialize<T>{}(n, list);
|
||||
data = n;
|
||||
}
|
||||
|
@ -635,7 +706,7 @@ public:
|
|||
template <typename T>
|
||||
void consume_dict(T& dict) {
|
||||
if (!is_dict()) throw bt_deserialize_invalid_type{"next bt value is not a dict"};
|
||||
string_view n{data};
|
||||
std::string_view n{data};
|
||||
detail::bt_deserialize<T>{}(n, dict);
|
||||
data = n;
|
||||
}
|
||||
|
@ -647,13 +718,13 @@ public:
|
|||
/// entire thing. This is recursive into both lists and dicts and likely to be quite
|
||||
/// inefficient for large, nested structures (unless the values only need to be skipped but
|
||||
/// aren't separately needed). This, however, does not require dynamic memory allocation.
|
||||
string_view consume_list_data();
|
||||
std::string_view consume_list_data();
|
||||
|
||||
/// Attempts to parse the next value as a dict and returns the string_view that contains the
|
||||
/// entire thing. This is recursive into both lists and dicts and likely to be quite
|
||||
/// inefficient for large, nested structures (unless the values only need to be skipped but
|
||||
/// aren't separately needed). This, however, does not require dynamic memory allocation.
|
||||
string_view consume_dict_data();
|
||||
std::string_view consume_dict_data();
|
||||
};
|
||||
|
||||
|
||||
|
@ -661,7 +732,7 @@ public:
|
|||
/// copying or allocating memory. It accesses existing memory directly and so the caller must
|
||||
/// ensure that the referenced memory stays valid for the lifetime of the bt_dict_consumer object.
|
||||
class bt_dict_consumer : private bt_list_consumer {
|
||||
string_view key_;
|
||||
std::string_view key_;
|
||||
|
||||
/// Consume the key if not already consumed and there is a key present (rather than 'e').
|
||||
/// Throws exception if what should be a key isn't a string, or if the key consumes the entire
|
||||
|
@ -671,14 +742,14 @@ class bt_dict_consumer : private bt_list_consumer {
|
|||
|
||||
/// Clears the cached key and returns it. Must have already called consume_key directly or
|
||||
/// indirectly via one of the `is_{...}` methods.
|
||||
string_view flush_key() {
|
||||
string_view k;
|
||||
std::string_view flush_key() {
|
||||
std::string_view k;
|
||||
k.swap(key_);
|
||||
return k;
|
||||
}
|
||||
|
||||
public:
|
||||
bt_dict_consumer(string_view data_);
|
||||
bt_dict_consumer(std::string_view data_);
|
||||
|
||||
/// Copy constructor. Making a copy copies the current position so can be used for multipass
|
||||
/// iteration through a list.
|
||||
|
@ -703,7 +774,7 @@ public:
|
|||
/// all of the other consume_* methods. The value is cached whether called here or by some
|
||||
/// other method; accessing it multiple times simple accesses the cache until the next value is
|
||||
/// consumed.
|
||||
string_view key() {
|
||||
std::string_view key() {
|
||||
if (!consume_key())
|
||||
throw bt_deserialize_invalid{"Cannot access next key: at the end of the dict"};
|
||||
return key_;
|
||||
|
@ -711,14 +782,14 @@ public:
|
|||
|
||||
/// Attempt to parse the next value as a string->string pair (and advance just past it). Throws
|
||||
/// if the next value is not a string.
|
||||
std::pair<string_view, string_view> next_string();
|
||||
std::pair<std::string_view, std::string_view> next_string();
|
||||
|
||||
/// Attempts to parse the next value as an string->integer pair (and advance just past it).
|
||||
/// Throws if the next value is not an integer.
|
||||
template <typename IntType>
|
||||
std::pair<string_view, IntType> next_integer() {
|
||||
std::pair<std::string_view, IntType> next_integer() {
|
||||
if (!is_integer()) throw bt_deserialize_invalid_type{"next bt dict value is not an integer"};
|
||||
std::pair<string_view, IntType> ret;
|
||||
std::pair<std::string_view, IntType> ret;
|
||||
ret.second = bt_list_consumer::consume_integer<IntType>();
|
||||
ret.first = flush_key();
|
||||
return ret;
|
||||
|
@ -729,15 +800,15 @@ public:
|
|||
/// which allows alloc-free traversal, but requires parsing twice (if the contents are to be
|
||||
/// used).
|
||||
template <typename T = bt_list>
|
||||
std::pair<string_view, T> next_list() {
|
||||
std::pair<string_view, T> pair;
|
||||
pair.first = consume_list(pair.second);
|
||||
std::pair<std::string_view, T> next_list() {
|
||||
std::pair<std::string_view, T> pair;
|
||||
pair.first = next_list(pair.second);
|
||||
return pair;
|
||||
}
|
||||
|
||||
/// Same as above, but takes a pre-existing list-like data type. Returns the key.
|
||||
template <typename T>
|
||||
string_view next_list(T& list) {
|
||||
std::string_view next_list(T& list) {
|
||||
if (!is_list()) throw bt_deserialize_invalid_type{"next bt value is not a list"};
|
||||
bt_list_consumer::consume_list(list);
|
||||
return flush_key();
|
||||
|
@ -748,15 +819,15 @@ public:
|
|||
/// which allows alloc-free traversal, but requires parsing twice (if the contents are to be
|
||||
/// used).
|
||||
template <typename T = bt_dict>
|
||||
std::pair<string_view, T> next_dict() {
|
||||
std::pair<string_view, T> pair;
|
||||
std::pair<std::string_view, T> next_dict() {
|
||||
std::pair<std::string_view, T> pair;
|
||||
pair.first = consume_dict(pair.second);
|
||||
return pair;
|
||||
}
|
||||
|
||||
/// Same as above, but takes a pre-existing dict-like data type. Returns the key.
|
||||
template <typename T>
|
||||
string_view next_dict(T& dict) {
|
||||
std::string_view next_dict(T& dict) {
|
||||
if (!is_dict()) throw bt_deserialize_invalid_type{"next bt value is not a dict"};
|
||||
bt_list_consumer::consume_dict(dict);
|
||||
return flush_key();
|
||||
|
@ -766,25 +837,25 @@ public:
|
|||
/// contains the entire thing. This is recursive into both lists and dicts and likely to be
|
||||
/// quite inefficient for large, nested structures (unless the values only need to be skipped
|
||||
/// but aren't separately needed). This, however, does not require dynamic memory allocation.
|
||||
std::pair<string_view, string_view> next_list_data() {
|
||||
std::pair<std::string_view, std::string_view> next_list_data() {
|
||||
if (data.size() < 2 || !is_list()) throw bt_deserialize_invalid_type{"next bt dict value is not a list"};
|
||||
return {flush_key(), bt_list_consumer::consume_list_data()};
|
||||
}
|
||||
|
||||
/// Same as next_list_data(), but wraps the value in a bt_list_consumer for convenience
|
||||
std::pair<string_view, bt_list_consumer> next_list_consumer() { return next_list_data(); }
|
||||
std::pair<std::string_view, bt_list_consumer> next_list_consumer() { return next_list_data(); }
|
||||
|
||||
/// Attempts to parse the next value as a string->dict pair and returns the string_view that
|
||||
/// contains the entire thing. This is recursive into both lists and dicts and likely to be
|
||||
/// quite inefficient for large, nested structures (unless the values only need to be skipped
|
||||
/// but aren't separately needed). This, however, does not require dynamic memory allocation.
|
||||
std::pair<string_view, string_view> next_dict_data() {
|
||||
std::pair<std::string_view, std::string_view> next_dict_data() {
|
||||
if (data.size() < 2 || !is_dict()) throw bt_deserialize_invalid_type{"next bt dict value is not a dict"};
|
||||
return {flush_key(), bt_list_consumer::consume_dict_data()};
|
||||
}
|
||||
|
||||
/// Same as next_dict_data(), but wraps the value in a bt_dict_consumer for convenience
|
||||
std::pair<string_view, bt_dict_consumer> next_dict_consumer() { return next_dict_data(); }
|
||||
std::pair<std::string_view, bt_dict_consumer> next_dict_consumer() { return next_dict_data(); }
|
||||
|
||||
/// Skips ahead until we find the first key >= the given key or reach the end of the dict.
|
||||
/// Returns true if we found an exact match, false if we reached some greater value or the end.
|
||||
|
@ -799,7 +870,7 @@ public:
|
|||
/// - this is irreversible; you cannot returned to skipped values without reparsing. (You *can*
|
||||
/// however, make a copy of the bt_dict_consumer before calling and use the copy to return to
|
||||
/// the pre-skipped position).
|
||||
bool skip_until(string_view find) {
|
||||
bool skip_until(std::string_view find) {
|
||||
while (consume_key() && key_ < find) {
|
||||
flush_key();
|
||||
skip_value();
|
||||
|
@ -834,8 +905,8 @@ public:
|
|||
template <typename T>
|
||||
void consume_dict(T& dict) { next_dict(dict); }
|
||||
|
||||
string_view consume_list_data() { return next_list_data().second; }
|
||||
string_view consume_dict_data() { return next_dict_data().second; }
|
||||
std::string_view consume_list_data() { return next_list_data().second; }
|
||||
std::string_view consume_dict_data() { return next_dict_data().second; }
|
||||
|
||||
bt_list_consumer consume_list_consumer() { return consume_list_data(); }
|
||||
bt_dict_consumer consume_dict_consumer() { return consume_dict_data(); }
|
||||
|
|
|
@ -0,0 +1,112 @@
|
|||
// Copyright (c) 2019-2020, The Loki Project
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without modification, are
|
||||
// permitted provided that the following conditions are met:
|
||||
//
|
||||
// 1. Redistributions of source code must retain the above copyright notice, this list of
|
||||
// conditions and the following disclaimer.
|
||||
//
|
||||
// 2. Redistributions in binary form must reproduce the above copyright notice, this list
|
||||
// of conditions and the following disclaimer in the documentation and/or other
|
||||
// materials provided with the distribution.
|
||||
//
|
||||
// 3. Neither the name of the copyright holder nor the names of its contributors may be
|
||||
// used to endorse or promote products derived from this software without specific
|
||||
// prior written permission.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
|
||||
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
|
||||
// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL
|
||||
// THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF
|
||||
// THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
#pragma once
|
||||
|
||||
// This header is here to provide just the basic bt_value/bt_dict/bt_list definitions without
|
||||
// needing to include the full bt_serialize.h header.
|
||||
|
||||
#include <map>
|
||||
#include <list>
|
||||
#include <cstdint>
|
||||
#include <variant>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
|
||||
namespace lokimq {
|
||||
|
||||
struct bt_value;
|
||||
|
||||
/// The type used to store dictionaries inside bt_value.
|
||||
using bt_dict = std::map<std::string, bt_value>; // NB: unordered_map doesn't work because it can't be used with a predeclared type
|
||||
/// The type used to store list items inside bt_value.
|
||||
using bt_list = std::list<bt_value>;
|
||||
|
||||
/// The basic variant that can hold anything (recursively).
|
||||
using bt_variant = std::variant<
|
||||
std::string,
|
||||
std::string_view,
|
||||
int64_t,
|
||||
uint64_t,
|
||||
bt_list,
|
||||
bt_dict
|
||||
>;
|
||||
|
||||
#ifdef __cpp_lib_remove_cvref // C++20
|
||||
using std::remove_cvref_t;
|
||||
#else
|
||||
template <typename T>
|
||||
using remove_cvref_t = std::remove_cv_t<std::remove_reference_t<T>>;
|
||||
#endif
|
||||
|
||||
template <typename T, typename Variant>
|
||||
struct has_alternative;
|
||||
template <typename T, typename... V>
|
||||
struct has_alternative<T, std::variant<V...>> : std::bool_constant<(std::is_same_v<T, V> || ...)> {};
|
||||
template <typename T, typename Variant>
|
||||
constexpr bool has_alternative_v = has_alternative<T, Variant>::value;
|
||||
|
||||
namespace detail {
|
||||
template <typename Tuple, size_t... Is>
|
||||
bt_list tuple_to_list(const Tuple& tuple, std::index_sequence<Is...>) {
|
||||
return {{bt_value{std::get<Is>(tuple)}...}};
|
||||
}
|
||||
template <typename T> constexpr bool is_tuple = false;
|
||||
template <typename... T> constexpr bool is_tuple<std::tuple<T...>> = true;
|
||||
template <typename S, typename T> constexpr bool is_tuple<std::pair<S, T>> = true;
|
||||
}
|
||||
|
||||
/// Recursive generic type that can fully represent everything valid for a BT serialization.
|
||||
/// This is basically just an empty wrapper around the std::variant, except we add some extra
|
||||
/// converting constructors:
|
||||
/// - integer constructors so that any unsigned value goes to the uint64_t and any signed value goes
|
||||
/// to the int64_t.
|
||||
/// - std::tuple and std::pair constructors that build a bt_list out of the tuple/pair elements.
|
||||
struct bt_value : bt_variant {
|
||||
using bt_variant::bt_variant;
|
||||
using bt_variant::operator=;
|
||||
|
||||
template <typename T, typename U = std::remove_reference_t<T>, std::enable_if_t<std::is_integral_v<U> && std::is_unsigned_v<U>, int> = 0>
|
||||
bt_value(T&& uint) : bt_variant{static_cast<uint64_t>(uint)} {}
|
||||
|
||||
template <typename T, typename U = std::remove_reference_t<T>, std::enable_if_t<std::is_integral_v<U> && std::is_signed_v<U>, int> = 0>
|
||||
bt_value(T&& sint) : bt_variant{static_cast<int64_t>(sint)} {}
|
||||
|
||||
template <typename... T>
|
||||
bt_value(const std::tuple<T...>& tuple) : bt_variant{detail::tuple_to_list(tuple, std::index_sequence_for<T...>{})} {}
|
||||
|
||||
template <typename S, typename T>
|
||||
bt_value(const std::pair<S, T>& pair) : bt_variant{detail::tuple_to_list(pair, std::index_sequence_for<S, T>{})} {}
|
||||
|
||||
template <typename T, typename U = std::remove_reference_t<T>, std::enable_if_t<!std::is_integral_v<U> && !detail::is_tuple<U>, int> = 0>
|
||||
bt_value(T&& v) : bt_variant{std::forward<T>(v)} {}
|
||||
|
||||
bt_value(const char* s) : bt_value{std::string_view{s}} {}
|
||||
};
|
||||
|
||||
}
|
|
@ -47,7 +47,7 @@ void LokiMQ::setup_external_socket(zmq::socket_t& socket) {
|
|||
}
|
||||
}
|
||||
|
||||
void LokiMQ::setup_outgoing_socket(zmq::socket_t& socket, string_view remote_pubkey) {
|
||||
void LokiMQ::setup_outgoing_socket(zmq::socket_t& socket, std::string_view remote_pubkey) {
|
||||
|
||||
setup_external_socket(socket);
|
||||
|
||||
|
@ -67,7 +67,7 @@ void LokiMQ::setup_outgoing_socket(zmq::socket_t& socket, string_view remote_pub
|
|||
// else let ZMQ pick a random one
|
||||
}
|
||||
|
||||
ConnectionID LokiMQ::connect_sn(string_view pubkey, std::chrono::milliseconds keep_alive, string_view hint) {
|
||||
ConnectionID LokiMQ::connect_sn(std::string_view pubkey, std::chrono::milliseconds keep_alive, std::string_view hint) {
|
||||
if (!proxy_thread.joinable())
|
||||
throw std::logic_error("Cannot call connect_sn() before calling `start()`");
|
||||
|
||||
|
@ -76,30 +76,32 @@ ConnectionID LokiMQ::connect_sn(string_view pubkey, std::chrono::milliseconds ke
|
|||
return pubkey;
|
||||
}
|
||||
|
||||
ConnectionID LokiMQ::connect_remote(string_view remote, ConnectSuccess on_connect, ConnectFailure on_failure,
|
||||
string_view pubkey, AuthLevel auth_level, std::chrono::milliseconds timeout) {
|
||||
ConnectionID LokiMQ::connect_remote(const address& remote, ConnectSuccess on_connect, ConnectFailure on_failure,
|
||||
AuthLevel auth_level, std::chrono::milliseconds timeout) {
|
||||
if (!proxy_thread.joinable())
|
||||
throw std::logic_error("Cannot call connect_remote() before calling `start()`");
|
||||
|
||||
if (remote.size() < 7 || !(remote.substr(0, 6) == "tcp://" || remote.substr(0, 6) == "ipc://" /* unix domain sockets */))
|
||||
throw std::runtime_error("Invalid connect_remote: remote address '" + std::string{remote} + "' is not a valid or supported zmq connect string");
|
||||
|
||||
auto id = next_conn_id++;
|
||||
LMQ_TRACE("telling proxy to connect to ", remote, ", id ", id,
|
||||
pubkey.empty() ? "using NULL auth" : ", using CURVE with remote pubkey [" + to_hex(pubkey) + "]");
|
||||
LMQ_TRACE("telling proxy to connect to ", remote, ", id ", id);
|
||||
detail::send_control(get_control_socket(), "CONNECT_REMOTE", bt_serialize<bt_dict>({
|
||||
{"auth_level", static_cast<std::underlying_type_t<AuthLevel>>(auth_level)},
|
||||
{"conn_id", id},
|
||||
{"connect", detail::serialize_object(std::move(on_connect))},
|
||||
{"failure", detail::serialize_object(std::move(on_failure))},
|
||||
{"pubkey", pubkey},
|
||||
{"remote", remote},
|
||||
{"pubkey", remote.curve() ? remote.pubkey : ""},
|
||||
{"remote", remote.zmq_address()},
|
||||
{"timeout", timeout.count()},
|
||||
}));
|
||||
|
||||
return id;
|
||||
}
|
||||
|
||||
ConnectionID LokiMQ::connect_remote(std::string_view remote, ConnectSuccess on_connect, ConnectFailure on_failure,
|
||||
std::string_view pubkey, AuthLevel auth_level, std::chrono::milliseconds timeout) {
|
||||
return connect_remote(address{remote}.set_pubkey(pubkey),
|
||||
std::move(on_connect), std::move(on_failure), auth_level, timeout);
|
||||
}
|
||||
|
||||
void LokiMQ::disconnect(ConnectionID id, std::chrono::milliseconds linger) {
|
||||
detail::send_control(get_control_socket(), "DISCONNECT", bt_serialize<bt_dict>({
|
||||
{"conn_id", id.id},
|
||||
|
@ -109,7 +111,7 @@ void LokiMQ::disconnect(ConnectionID id, std::chrono::milliseconds linger) {
|
|||
}
|
||||
|
||||
std::pair<zmq::socket_t *, std::string>
|
||||
LokiMQ::proxy_connect_sn(string_view remote, string_view connect_hint, bool optional, bool incoming_only, bool outgoing_only, std::chrono::milliseconds keep_alive) {
|
||||
LokiMQ::proxy_connect_sn(std::string_view remote, std::string_view connect_hint, bool optional, bool incoming_only, bool outgoing_only, std::chrono::milliseconds keep_alive) {
|
||||
ConnectionID remote_cid{remote};
|
||||
auto its = peers.equal_range(remote_cid);
|
||||
peer_info* peer = nullptr;
|
||||
|
@ -185,7 +187,7 @@ LokiMQ::proxy_connect_sn(string_view remote, string_view connect_hint, bool opti
|
|||
}
|
||||
|
||||
std::pair<zmq::socket_t *, std::string> LokiMQ::proxy_connect_sn(bt_dict_consumer data) {
|
||||
string_view hint, remote_pk;
|
||||
std::string_view hint, remote_pk;
|
||||
std::chrono::milliseconds keep_alive;
|
||||
bool optional = false, incoming_only = false, outgoing_only = false;
|
||||
|
||||
|
|
|
@ -1,15 +1,20 @@
|
|||
#pragma once
|
||||
#include "auth.h"
|
||||
#include "string_view.h"
|
||||
#include "bt_value.h"
|
||||
#include <string_view>
|
||||
#include <iosfwd>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
|
||||
namespace lokimq {
|
||||
|
||||
class bt_dict;
|
||||
struct ConnectionID;
|
||||
|
||||
namespace detail {
|
||||
template <typename... T>
|
||||
bt_dict build_send(ConnectionID to, string_view cmd, T&&... opts);
|
||||
bt_dict build_send(ConnectionID to, std::string_view cmd, T&&... opts);
|
||||
}
|
||||
|
||||
/// Opaque data structure representing a connection which supports ==, !=, < and std::hash. For
|
||||
|
@ -26,7 +31,7 @@ struct ConnectionID {
|
|||
throw std::runtime_error{"Invalid pubkey: expected 32 bytes"};
|
||||
}
|
||||
// Construction from a service node pubkey
|
||||
ConnectionID(string_view pubkey_) : ConnectionID(std::string{pubkey_}) {}
|
||||
ConnectionID(std::string_view pubkey_) : ConnectionID(std::string{pubkey_}) {}
|
||||
|
||||
ConnectionID(const ConnectionID&) = default;
|
||||
ConnectionID(ConnectionID&&) = default;
|
||||
|
@ -75,7 +80,7 @@ private:
|
|||
friend class LokiMQ;
|
||||
friend struct std::hash<ConnectionID>;
|
||||
template <typename... T>
|
||||
friend bt_dict detail::build_send(ConnectionID to, string_view cmd, T&&... opts);
|
||||
friend bt_dict detail::build_send(ConnectionID to, std::string_view cmd, T&&... opts);
|
||||
friend std::ostream& operator<<(std::ostream& o, const ConnectionID& conn);
|
||||
};
|
||||
|
||||
|
|
68
lokimq/hex.h
68
lokimq/hex.h
|
@ -27,7 +27,8 @@
|
|||
// THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
#pragma once
|
||||
#include "string_view.h"
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <array>
|
||||
#include <iterator>
|
||||
#include <cassert>
|
||||
|
@ -55,46 +56,52 @@ struct hex_table {
|
|||
constexpr char to_hex(unsigned char b) const noexcept { return to_hex_lut[b]; }
|
||||
} constexpr hex_lut;
|
||||
|
||||
// This main point of this static assert is to force the compiler to compile-time build the constexpr tables.
|
||||
static_assert(hex_lut.from_hex('a') == 10 && hex_lut.from_hex('F') == 15 && hex_lut.to_hex(13) == 'd', "");
|
||||
|
||||
} // namespace detail
|
||||
|
||||
/// Creates hex digits from a character sequence.
|
||||
template <typename InputIt, typename OutputIt>
|
||||
void to_hex(InputIt begin, InputIt end, OutputIt out) {
|
||||
static_assert(sizeof(decltype(*begin)) == 1, "to_hex requires chars/bytes");
|
||||
for (; begin != end; ++begin) {
|
||||
auto c = *begin;
|
||||
*out++ = detail::hex_lut.to_hex((c & 0xf0) >> 4);
|
||||
uint8_t c = static_cast<uint8_t>(*begin);
|
||||
*out++ = detail::hex_lut.to_hex(c >> 4);
|
||||
*out++ = detail::hex_lut.to_hex(c & 0x0f);
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a hex string from an iterable, std::string-like object
|
||||
inline std::string to_hex(string_view s) {
|
||||
/// Creates a string of hex digits from a character sequence iterator pair
|
||||
template <typename It>
|
||||
std::string to_hex(It begin, It end) {
|
||||
std::string hex;
|
||||
hex.reserve(s.size() * 2);
|
||||
to_hex(s.begin(), s.end(), std::back_inserter(hex));
|
||||
if constexpr (std::is_base_of_v<std::random_access_iterator_tag, typename std::iterator_traits<It>::iterator_category>)
|
||||
hex.reserve(2 * std::distance(begin, end));
|
||||
to_hex(begin, end, std::back_inserter(hex));
|
||||
return hex;
|
||||
}
|
||||
|
||||
inline std::string to_hex(ustring_view s) {
|
||||
std::string hex;
|
||||
hex.reserve(s.size() * 2);
|
||||
to_hex(s.begin(), s.end(), std::back_inserter(hex));
|
||||
return hex;
|
||||
}
|
||||
/// Creates a hex string from an iterable, std::string-like object
|
||||
template <typename CharT>
|
||||
std::string to_hex(std::basic_string_view<CharT> s) { return to_hex(s.begin(), s.end()); }
|
||||
inline std::string to_hex(std::string_view s) { return to_hex<>(s); }
|
||||
|
||||
/// Returns true if all elements in the range are hex characters
|
||||
template <typename It>
|
||||
constexpr bool is_hex(It begin, It end) {
|
||||
static_assert(sizeof(decltype(*begin)) == 1, "is_hex requires chars/bytes");
|
||||
for (; begin != end; ++begin) {
|
||||
if (detail::hex_lut.from_hex(*begin) == 0 && *begin != '0')
|
||||
if (detail::hex_lut.from_hex(static_cast<unsigned char>(*begin)) == 0 && static_cast<unsigned char>(*begin) != '0')
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Returns true if all elements in the string-like value are hex characters
|
||||
constexpr bool is_hex(string_view s) { return is_hex(s.begin(), s.end()); }
|
||||
constexpr bool is_hex(ustring_view s) { return is_hex(s.begin(), s.end()); }
|
||||
template <typename CharT>
|
||||
constexpr bool is_hex(std::basic_string_view<CharT> s) { return is_hex(s.begin(), s.end()); }
|
||||
constexpr bool is_hex(std::string_view s) { return is_hex(s.begin(), s.end()); }
|
||||
|
||||
/// Convert a hex digit into its numeric (0-15) value
|
||||
constexpr char from_hex_digit(unsigned char x) noexcept {
|
||||
|
@ -114,24 +121,25 @@ void from_hex(InputIt begin, InputIt end, OutputIt out) {
|
|||
while (begin != end) {
|
||||
auto a = *begin++;
|
||||
auto b = *begin++;
|
||||
*out++ = from_hex_pair(a, b);
|
||||
*out++ = from_hex_pair(static_cast<unsigned char>(a), static_cast<unsigned char>(b));
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts a sequence of hex digits to a string of bytes and returns it. Undefined behaviour if
|
||||
/// the input sequence is not an even-length sequence of [0-9a-fA-F] characters.
|
||||
template <typename It>
|
||||
std::string from_hex(It begin, It end) {
|
||||
std::string bytes;
|
||||
if constexpr (std::is_base_of_v<std::random_access_iterator_tag, typename std::iterator_traits<It>::iterator_category>)
|
||||
bytes.reserve(std::distance(begin, end) / 2);
|
||||
from_hex(begin, end, std::back_inserter(bytes));
|
||||
return bytes;
|
||||
}
|
||||
|
||||
/// Converts hex digits from a std::string-like object into a std::string of bytes. Undefined
|
||||
/// behaviour if any characters are not in [0-9a-fA-F] or if the input sequence length is not even.
|
||||
inline std::string from_hex(string_view s) {
|
||||
std::string bytes;
|
||||
bytes.reserve(s.size() / 2);
|
||||
from_hex(s.begin(), s.end(), std::back_inserter(bytes));
|
||||
return bytes;
|
||||
}
|
||||
|
||||
inline std::string from_hex(ustring_view s) {
|
||||
std::string bytes;
|
||||
bytes.reserve(s.size() / 2);
|
||||
from_hex(s.begin(), s.end(), std::back_inserter(bytes));
|
||||
return bytes;
|
||||
}
|
||||
template <typename CharT>
|
||||
std::string from_hex(std::basic_string_view<CharT> s) { return from_hex(s.begin(), s.end()); }
|
||||
inline std::string from_hex(std::string_view s) { return from_hex<>(s); }
|
||||
|
||||
}
|
||||
|
|
|
@ -6,15 +6,31 @@ namespace lokimq {
|
|||
|
||||
void LokiMQ::proxy_batch(detail::Batch* batch) {
|
||||
batches.insert(batch);
|
||||
const int jobs = batch->size();
|
||||
for (int i = 0; i < jobs; i++)
|
||||
batch_jobs.emplace(batch, i);
|
||||
const auto [jobs, tagged_threads] = batch->size();
|
||||
LMQ_TRACE("proxy queuing batch job with ", jobs, " jobs", tagged_threads ? " (job uses tagged thread(s))" : "");
|
||||
if (!tagged_threads) {
|
||||
for (size_t i = 0; i < jobs; i++)
|
||||
batch_jobs.emplace(batch, i);
|
||||
} else {
|
||||
// Some (or all) jobs have a specific thread target so queue any such jobs in the tagged
|
||||
// worker queue.
|
||||
auto threads = batch->threads();
|
||||
for (size_t i = 0; i < jobs; i++) {
|
||||
auto& jobs = threads[i] > 0
|
||||
? std::get<std::queue<batch_job>>(tagged_workers[threads[i] - 1])
|
||||
: batch_jobs;
|
||||
jobs.emplace(batch, i);
|
||||
}
|
||||
}
|
||||
|
||||
proxy_skip_one_poll = true;
|
||||
}
|
||||
|
||||
void LokiMQ::job(std::function<void()> f) {
|
||||
void LokiMQ::job(std::function<void()> f, std::optional<TaggedThreadID> thread) {
|
||||
if (thread && thread->_id == -1)
|
||||
throw std::logic_error{"job() cannot be used to queue an in-proxy job"};
|
||||
auto* b = new Batch<void>;
|
||||
b->add_job(std::move(f));
|
||||
b->add_job(std::move(f), thread);
|
||||
auto* baseptr = static_cast<detail::Batch*>(b);
|
||||
detail::send_control(get_control_socket(), "BATCH", bt_serialize(reinterpret_cast<uintptr_t>(baseptr)));
|
||||
}
|
||||
|
@ -38,7 +54,7 @@ void LokiMQ::proxy_run_batch_jobs(std::queue<batch_job>& jobs, const int reserve
|
|||
|
||||
// Called either within the proxy thread, or before the proxy thread has been created; actually adds
|
||||
// the timer. If the timer object hasn't been set up yet it gets set up here.
|
||||
void LokiMQ::proxy_timer(std::function<void()> job, std::chrono::milliseconds interval, bool squelch) {
|
||||
void LokiMQ::proxy_timer(std::function<void()> job, std::chrono::milliseconds interval, bool squelch, int thread) {
|
||||
if (!timers)
|
||||
timers.reset(zmq_timers_new());
|
||||
|
||||
|
@ -48,16 +64,17 @@ void LokiMQ::proxy_timer(std::function<void()> job, std::chrono::milliseconds in
|
|||
this);
|
||||
if (timer_id == -1)
|
||||
throw zmq::error_t{};
|
||||
timer_jobs[timer_id] = std::make_tuple(std::move(job), squelch, false);
|
||||
timer_jobs[timer_id] = { std::move(job), squelch, false, thread };
|
||||
}
|
||||
|
||||
void LokiMQ::proxy_timer(bt_list_consumer timer_data) {
|
||||
std::unique_ptr<std::function<void()>> func{reinterpret_cast<std::function<void()>*>(timer_data.consume_integer<uintptr_t>())};
|
||||
auto interval = std::chrono::milliseconds{timer_data.consume_integer<uint64_t>()};
|
||||
auto squelch = timer_data.consume_integer<bool>();
|
||||
auto thread = timer_data.consume_integer<int>();
|
||||
if (!timer_data.is_finished())
|
||||
throw std::runtime_error("Internal error: proxied timer request contains unexpected data");
|
||||
proxy_timer(std::move(*func), interval, squelch);
|
||||
proxy_timer(std::move(*func), interval, squelch, thread);
|
||||
}
|
||||
|
||||
void LokiMQ::_queue_timer_job(int timer_id) {
|
||||
|
@ -66,43 +83,72 @@ void LokiMQ::_queue_timer_job(int timer_id) {
|
|||
LMQ_LOG(warn, "Could not find timer job ", timer_id);
|
||||
return;
|
||||
}
|
||||
auto& timer = it->second;
|
||||
auto& squelch = std::get<1>(timer);
|
||||
auto& running = std::get<2>(timer);
|
||||
auto& [func, squelch, running, thread] = it->second;
|
||||
if (squelch && running) {
|
||||
LMQ_LOG(debug, "Not running timer job ", timer_id, " because a job for that timer is still running");
|
||||
return;
|
||||
}
|
||||
|
||||
if (thread == -1) { // Run directly in proxy thread
|
||||
try { func(); }
|
||||
catch (const std::exception &e) { LMQ_LOG(warn, "timer job ", timer_id, " raised an exception: ", e.what()); }
|
||||
catch (...) { LMQ_LOG(warn, "timer job ", timer_id, " raised a non-std exception"); }
|
||||
return;
|
||||
}
|
||||
|
||||
auto* b = new Batch<void>;
|
||||
b->add_job(std::get<0>(timer));
|
||||
b->add_job(func, thread);
|
||||
if (squelch) {
|
||||
running = true;
|
||||
b->completion_proxy([this,timer_id](auto results) {
|
||||
b->completion([this,timer_id](auto results) {
|
||||
try { results[0].get(); }
|
||||
catch (const std::exception &e) { LMQ_LOG(warn, "timer job ", timer_id, " raised an exception: ", e.what()); }
|
||||
catch (...) { LMQ_LOG(warn, "timer job ", timer_id, " raised a non-std exception"); }
|
||||
auto it = timer_jobs.find(timer_id);
|
||||
if (it != timer_jobs.end())
|
||||
std::get<2>(it->second)/*running*/ = false;
|
||||
});
|
||||
it->second.running = false;
|
||||
}, LokiMQ::run_in_proxy);
|
||||
}
|
||||
batches.insert(b);
|
||||
batch_jobs.emplace(static_cast<detail::Batch*>(b), 0);
|
||||
assert(b->size() == 1);
|
||||
LMQ_TRACE("b: ", b->size().first, ", ", b->size().second, "; thread = ", thread);
|
||||
assert(b->size() == std::make_pair(size_t{1}, thread > 0));
|
||||
auto& queue = thread > 0
|
||||
? std::get<std::queue<batch_job>>(tagged_workers[thread - 1])
|
||||
: batch_jobs;
|
||||
queue.emplace(static_cast<detail::Batch*>(b), 0);
|
||||
}
|
||||
|
||||
void LokiMQ::add_timer(std::function<void()> job, std::chrono::milliseconds interval, bool squelch) {
|
||||
void LokiMQ::add_timer(std::function<void()> job, std::chrono::milliseconds interval, bool squelch, std::optional<TaggedThreadID> thread) {
|
||||
int th_id = thread ? thread->_id : 0;
|
||||
if (proxy_thread.joinable()) {
|
||||
detail::send_control(get_control_socket(), "TIMER", bt_serialize(bt_list{{
|
||||
detail::serialize_object(std::move(job)),
|
||||
interval.count(),
|
||||
squelch}}));
|
||||
squelch,
|
||||
th_id}}));
|
||||
} else {
|
||||
proxy_timer(std::move(job), interval, squelch);
|
||||
proxy_timer(std::move(job), interval, squelch, th_id);
|
||||
}
|
||||
}
|
||||
|
||||
void LokiMQ::TimersDeleter::operator()(void* timers) { zmq_timers_destroy(&timers); }
|
||||
|
||||
TaggedThreadID LokiMQ::add_tagged_thread(std::string name, std::function<void()> start) {
|
||||
if (proxy_thread.joinable())
|
||||
throw std::logic_error{"Cannot add tagged threads after calling `start()`"};
|
||||
|
||||
if (name == "_proxy"sv || name.empty() || name.find('\0') != std::string::npos)
|
||||
throw std::logic_error{"Invalid tagged thread name `" + name + "'"};
|
||||
|
||||
auto& [run, busy, queue] = tagged_workers.emplace_back();
|
||||
busy = false;
|
||||
run.worker_id = tagged_workers.size(); // We want index + 1 (b/c 0 is used for non-tagged jobs)
|
||||
run.worker_routing_id = "t" + std::to_string(run.worker_id);
|
||||
LMQ_TRACE("Created new tagged thread ", name, " with routing id ", run.worker_routing_id);
|
||||
|
||||
run.worker_thread = std::thread{&LokiMQ::worker_thread, this, run.worker_id, name, std::move(start)};
|
||||
|
||||
return TaggedThreadID{static_cast<int>(run.worker_id)};
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -4,12 +4,11 @@
|
|||
// Inside some method:
|
||||
// LMQ_LOG(warn, "bad ", 42, " stuff");
|
||||
//
|
||||
// (The "this->" is here to work around gcc 5 bugginess when called in a `this`-capturing lambda.)
|
||||
#define LMQ_LOG(level, ...) this->log_(LogLevel::level, __FILE__, __LINE__, __VA_ARGS__)
|
||||
#define LMQ_LOG(level, ...) log(LogLevel::level, __FILE__, __LINE__, __VA_ARGS__)
|
||||
|
||||
#ifndef NDEBUG
|
||||
// Same as LMQ_LOG(trace, ...) when not doing a release build; nothing under a release build.
|
||||
# define LMQ_TRACE(...) this->log_(LogLevel::trace, __FILE__, __LINE__, __VA_ARGS__)
|
||||
# define LMQ_TRACE(...) log(LogLevel::trace, __FILE__, __LINE__, __VA_ARGS__)
|
||||
#else
|
||||
# define LMQ_TRACE(...)
|
||||
#endif
|
||||
|
@ -33,7 +32,7 @@ inline zmq::message_t create_message(std::string&& data) {
|
|||
};
|
||||
|
||||
/// Create a message copying from a string_view
|
||||
inline zmq::message_t create_message(string_view data) {
|
||||
inline zmq::message_t create_message(std::string_view data) {
|
||||
return zmq::message_t{data.begin(), data.end()};
|
||||
}
|
||||
|
||||
|
@ -94,7 +93,7 @@ inline const char* peer_address(zmq::message_t& msg) {
|
|||
|
||||
// Returns a string view of the given message data. It's the caller's responsibility to keep the
|
||||
// referenced message alive. If you want a std::string instead just call `m.to_string()`
|
||||
inline string_view view(const zmq::message_t& m) {
|
||||
inline std::string_view view(const zmq::message_t& m) {
|
||||
return {m.data<char>(), m.size()};
|
||||
}
|
||||
|
||||
|
@ -108,7 +107,7 @@ inline std::string to_string(AuthLevel a) {
|
|||
}
|
||||
}
|
||||
|
||||
inline AuthLevel auth_from_string(string_view a) {
|
||||
inline AuthLevel auth_from_string(std::string_view a) {
|
||||
if (a == "none") return AuthLevel::none;
|
||||
if (a == "basic") return AuthLevel::basic;
|
||||
if (a == "admin") return AuthLevel::admin;
|
||||
|
@ -116,7 +115,7 @@ inline AuthLevel auth_from_string(string_view a) {
|
|||
}
|
||||
|
||||
// Extracts and builds the "send" part of a message for proxy_send/proxy_reply
|
||||
inline std::list<zmq::message_t> build_send_parts(bt_list_consumer send, string_view route) {
|
||||
inline std::list<zmq::message_t> build_send_parts(bt_list_consumer send, std::string_view route) {
|
||||
std::list<zmq::message_t> parts;
|
||||
if (!route.empty())
|
||||
parts.push_back(create_message(route));
|
||||
|
@ -128,7 +127,7 @@ inline std::list<zmq::message_t> build_send_parts(bt_list_consumer send, string_
|
|||
/// Sends a control message to a specific destination by prefixing the worker name (or identity)
|
||||
/// then appending the command and optional data (if non-empty). (This is needed when sending the control message
|
||||
/// to a router socket, i.e. inside the proxy thread).
|
||||
inline void route_control(zmq::socket_t& sock, string_view identity, string_view cmd, const std::string& data = {}) {
|
||||
inline void route_control(zmq::socket_t& sock, std::string_view identity, std::string_view cmd, const std::string& data = {}) {
|
||||
sock.send(create_message(identity), zmq::send_flags::sndmore);
|
||||
detail::send_control(sock, cmd, data);
|
||||
}
|
||||
|
|
|
@ -2,9 +2,12 @@
|
|||
#include "lokimq-internal.h"
|
||||
#include <map>
|
||||
#include <random>
|
||||
#include <ostream>
|
||||
|
||||
extern "C" {
|
||||
#include <sodium.h>
|
||||
#include <sodium/core.h>
|
||||
#include <sodium/crypto_box.h>
|
||||
#include <sodium/crypto_scalarmult.h>
|
||||
}
|
||||
#include "hex.h"
|
||||
|
||||
|
@ -38,7 +41,7 @@ namespace detail {
|
|||
|
||||
// Sends a control messages between proxy and threads or between proxy and workers consisting of a
|
||||
// single command codes with an optional data part (the data frame is omitted if empty).
|
||||
void send_control(zmq::socket_t& sock, string_view cmd, std::string data) {
|
||||
void send_control(zmq::socket_t& sock, std::string_view cmd, std::string data) {
|
||||
auto c = create_message(std::move(cmd));
|
||||
if (data.empty()) {
|
||||
sock.send(c, zmq::send_flags::none);
|
||||
|
@ -53,7 +56,7 @@ void send_control(zmq::socket_t& sock, string_view cmd, std::string data) {
|
|||
std::pair<std::string, AuthLevel> extract_metadata(zmq::message_t& msg) {
|
||||
auto result = std::make_pair(""s, AuthLevel::none);
|
||||
try {
|
||||
string_view pubkey_hex{msg.gets("User-Id")};
|
||||
std::string_view pubkey_hex{msg.gets("User-Id")};
|
||||
if (pubkey_hex.size() != 64)
|
||||
throw std::logic_error("bad user-id");
|
||||
assert(is_hex(pubkey_hex.begin(), pubkey_hex.end()));
|
||||
|
@ -174,7 +177,7 @@ zmq::socket_t& LokiMQ::get_control_socket() {
|
|||
return *last.second;
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> lock{control_sockets_mutex};
|
||||
std::lock_guard lock{control_sockets_mutex};
|
||||
if (proxy_shutting_down)
|
||||
throw std::runtime_error("Unable to obtain LokiMQ control socket: proxy thread is shutting down");
|
||||
auto control = std::make_shared<zmq::socket_t>(context, zmq::socket_type::dealer);
|
||||
|
@ -201,6 +204,9 @@ LokiMQ::LokiMQ(
|
|||
|
||||
LMQ_TRACE("Constructing LokiMQ, id=", object_id, ", this=", this);
|
||||
|
||||
if (sodium_init() == -1)
|
||||
throw std::runtime_error{"libsodium initialization failed"};
|
||||
|
||||
if (pubkey.empty() != privkey.empty()) {
|
||||
throw std::invalid_argument("LokiMQ construction failed: one (and only one) of pubkey/privkey is empty. Both must be specified, or both empty to generate a key.");
|
||||
} else if (pubkey.empty()) {
|
||||
|
@ -343,35 +349,75 @@ void LokiMQ::set_general_threads(int threads) {
|
|||
|
||||
LokiMQ::run_info& LokiMQ::run_info::load(category* cat_, std::string command_, ConnectionID conn_, Access access_, std::string remote_,
|
||||
std::vector<zmq::message_t> data_parts_, const std::pair<CommandCallback, bool>* callback_) {
|
||||
is_batch_job = false;
|
||||
is_reply_job = false;
|
||||
reset();
|
||||
cat = cat_;
|
||||
command = std::move(command_);
|
||||
conn = std::move(conn_);
|
||||
access = std::move(access_);
|
||||
remote = std::move(remote_);
|
||||
data_parts = std::move(data_parts_);
|
||||
callback = callback_;
|
||||
to_run = callback_;
|
||||
return *this;
|
||||
}
|
||||
|
||||
LokiMQ::run_info& LokiMQ::run_info::load(category* cat_, std::string command_, std::string remote_, std::function<void()> callback) {
|
||||
reset();
|
||||
is_injected = true;
|
||||
cat = cat_;
|
||||
command = std::move(command_);
|
||||
conn = {};
|
||||
access = {};
|
||||
remote = std::move(remote_);
|
||||
to_run = std::move(callback);
|
||||
return *this;
|
||||
}
|
||||
|
||||
LokiMQ::run_info& LokiMQ::run_info::load(pending_command&& pending) {
|
||||
if (auto *f = std::get_if<std::function<void()>>(&pending.callback))
|
||||
return load(&pending.cat, std::move(pending.command), std::move(pending.remote), std::move(*f));
|
||||
|
||||
assert(pending.callback.index() == 0);
|
||||
return load(&pending.cat, std::move(pending.command), std::move(pending.conn), std::move(pending.access),
|
||||
std::move(pending.remote), std::move(pending.data_parts), pending.callback);
|
||||
std::move(pending.remote), std::move(pending.data_parts), std::get<0>(pending.callback));
|
||||
}
|
||||
|
||||
LokiMQ::run_info& LokiMQ::run_info::load(batch_job&& bj, bool reply_job) {
|
||||
LokiMQ::run_info& LokiMQ::run_info::load(batch_job&& bj, bool reply_job, int tagged_thread) {
|
||||
reset();
|
||||
is_batch_job = true;
|
||||
is_reply_job = reply_job;
|
||||
is_tagged_thread_job = tagged_thread > 0;
|
||||
batch_jobno = bj.second;
|
||||
batch = bj.first;
|
||||
to_run = bj.first;
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
||||
LokiMQ::~LokiMQ() {
|
||||
if (!proxy_thread.joinable())
|
||||
if (!proxy_thread.joinable()) {
|
||||
if (!tagged_workers.empty()) {
|
||||
// This is a bit icky: we have tagged workers that are waiting for a signal on
|
||||
// workers_socket, but the listening end of workers_socket doesn't get set up until the
|
||||
// proxy thread starts (and we're getting destructed here without a proxy thread). So
|
||||
// we need to start listening on it here in the destructor so that we establish a
|
||||
// connection and send the QUITs to the tagged worker threads.
|
||||
workers_socket.setsockopt<int>(ZMQ_ROUTER_MANDATORY, 1);
|
||||
workers_socket.bind(SN_ADDR_WORKERS);
|
||||
for (auto& [run, busy, queue] : tagged_workers) {
|
||||
while (true) {
|
||||
try {
|
||||
route_control(workers_socket, run.worker_routing_id, "QUIT");
|
||||
break;
|
||||
} catch (const zmq::error_t&) {
|
||||
std::this_thread::sleep_for(5ms);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (auto& [run, busy, queue] : tagged_workers)
|
||||
run.worker_thread.join();
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
LMQ_LOG(info, "LokiMQ shutting down proxy thread");
|
||||
detail::send_control(get_control_socket(), "QUIT");
|
||||
|
|
265
lokimq/lokimq.h
265
lokimq/lokimq.h
|
@ -29,8 +29,10 @@
|
|||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <list>
|
||||
#include <queue>
|
||||
#include <type_traits>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <memory>
|
||||
|
@ -41,9 +43,10 @@
|
|||
#include <chrono>
|
||||
#include <atomic>
|
||||
#include <cassert>
|
||||
#include <cstdint>
|
||||
#include "zmq.hpp"
|
||||
#include "address.h"
|
||||
#include "bt_serialize.h"
|
||||
#include "string_view.h"
|
||||
#include "connections.h"
|
||||
#include "message.h"
|
||||
#include "auth.h"
|
||||
|
@ -69,23 +72,33 @@ template <typename R> class Batch;
|
|||
* use a longer keep-alive to a host call `connect()` first with the desired keep-alive time or pass
|
||||
* the send_option::keep_alive.
|
||||
*/
|
||||
static constexpr auto DEFAULT_SEND_KEEP_ALIVE = 30s;
|
||||
inline constexpr auto DEFAULT_SEND_KEEP_ALIVE = 30s;
|
||||
|
||||
// The default timeout for connect_remote()
|
||||
static constexpr auto REMOTE_CONNECT_TIMEOUT = 10s;
|
||||
inline constexpr auto REMOTE_CONNECT_TIMEOUT = 10s;
|
||||
|
||||
// The amount of time we wait for a reply to a REQUEST before calling the callback with
|
||||
// `false` to signal a timeout.
|
||||
static constexpr auto DEFAULT_REQUEST_TIMEOUT = 15s;
|
||||
inline constexpr auto DEFAULT_REQUEST_TIMEOUT = 15s;
|
||||
|
||||
/// Maximum length of a category
|
||||
static constexpr size_t MAX_CATEGORY_LENGTH = 50;
|
||||
inline constexpr size_t MAX_CATEGORY_LENGTH = 50;
|
||||
|
||||
/// Maximum length of a command
|
||||
static constexpr size_t MAX_COMMAND_LENGTH = 200;
|
||||
inline constexpr size_t MAX_COMMAND_LENGTH = 200;
|
||||
|
||||
class CatHelper;
|
||||
|
||||
/// Opaque handle for a tagged thread constructed by add_tagged_thread(...). Not directly
|
||||
/// constructible, but is safe (and cheap) to copy.
|
||||
struct TaggedThreadID {
|
||||
private:
|
||||
int _id;
|
||||
explicit constexpr TaggedThreadID(int id) : _id{id} {}
|
||||
friend class LokiMQ;
|
||||
template <typename R> friend class Batch;
|
||||
};
|
||||
|
||||
/**
|
||||
* Class that handles LokiMQ listeners, connections, proxying, and workers. An application
|
||||
* typically has just one instance of this class.
|
||||
|
@ -146,12 +159,12 @@ public:
|
|||
///
|
||||
/// @returns an `AuthLevel` enum value indicating the default auth level for the incoming
|
||||
/// connection, or AuthLevel::denied if the connection should be refused.
|
||||
using AllowFunc = std::function<AuthLevel(string_view address, string_view pubkey, bool service_node)>;
|
||||
using AllowFunc = std::function<AuthLevel(std::string_view address, std::string_view pubkey, bool service_node)>;
|
||||
|
||||
/// Callback that is invoked when we need to send a "strong" message to a SN that we aren't
|
||||
/// already connected to and need to establish a connection. This callback returns the ZMQ
|
||||
/// connection string we should use which is typically a string such as `tcp://1.2.3.4:5678`.
|
||||
using SNRemoteAddress = std::function<std::string(string_view pubkey)>;
|
||||
using SNRemoteAddress = std::function<std::string(std::string_view pubkey)>;
|
||||
|
||||
/// The callback type for registered commands.
|
||||
using CommandCallback = std::function<void(Message& message)>;
|
||||
|
@ -169,7 +182,7 @@ public:
|
|||
/// Callback for the success case of connect_remote()
|
||||
using ConnectSuccess = std::function<void(ConnectionID)>;
|
||||
/// Callback for the failure case of connect_remote()
|
||||
using ConnectFailure = std::function<void(ConnectionID, string_view)>;
|
||||
using ConnectFailure = std::function<void(ConnectionID, std::string_view)>;
|
||||
|
||||
/// Explicitly non-copyable, non-movable because most things here aren't copyable, and a few
|
||||
/// things aren't movable, either. If you need to pass the LokiMQ instance around, wrap it
|
||||
|
@ -246,6 +259,28 @@ public:
|
|||
/// Allows you to set options on the internal zmq context object. For advanced use only.
|
||||
int set_zmq_context_option(int option, int value);
|
||||
|
||||
/** The umask to apply when constructing sockets (which affects any new ipc:// listening sockets
|
||||
* that get created). Does nothing if set to -1 (the default), and does nothing on Windows.
|
||||
* Note that the umask is applied temporarily during `start()`, so may affect other threads that
|
||||
* create files/directories at the same time as the start() call.
|
||||
*/
|
||||
int STARTUP_UMASK = -1;
|
||||
|
||||
/** The gid that owns any sockets when constructed (same as umask)
|
||||
*/
|
||||
int SOCKET_GID = -1;
|
||||
/** The uid that owns any sockets when constructed (same as umask but requires root)
|
||||
*/
|
||||
int SOCKET_UID = -1;
|
||||
|
||||
/// A special TaggedThreadID value that always refers to the proxy thread; the main use of this is
|
||||
/// to direct very simple batch completion jobs to be executed directly in the proxy thread.
|
||||
inline static constexpr TaggedThreadID run_in_proxy{-1};
|
||||
|
||||
/// Writes a message to the logging system; intended mostly for internal use.
|
||||
template <typename... T>
|
||||
void log(LogLevel lvl, const char* filename, int line, const T&... stuff);
|
||||
|
||||
private:
|
||||
|
||||
/// The lookup function that tells us where to connect to a peer, or empty if not found.
|
||||
|
@ -258,10 +293,6 @@ private:
|
|||
/// The callback to call with log messages
|
||||
Logger logger;
|
||||
|
||||
/// Logging implementation
|
||||
template <typename... T>
|
||||
void log_(LogLevel lvl, const char* filename, int line, const T&... stuff);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////
|
||||
/// NB: The following are all the domain of the proxy thread (once it is started)!
|
||||
|
||||
|
@ -375,7 +406,8 @@ private:
|
|||
|
||||
/// Timers. TODO: once cppzmq adds an interface around the zmq C timers API then switch to it.
|
||||
struct TimersDeleter { void operator()(void* timers); };
|
||||
std::unordered_map<int, std::tuple<std::function<void()>, bool, bool>> timer_jobs; // id => {func, squelch, running}
|
||||
struct timer_data { std::function<void()> function; bool squelch; bool running; int thread; };
|
||||
std::unordered_map<int, timer_data> timer_jobs;
|
||||
std::unique_ptr<void, TimersDeleter> timers;
|
||||
public:
|
||||
// This needs to be public because we have to be able to call it from a plain C function.
|
||||
|
@ -401,8 +433,8 @@ private:
|
|||
/// Number of active workers
|
||||
int active_workers() const { return workers.size() - idle_workers.size(); }
|
||||
|
||||
/// Worker thread loop
|
||||
void worker_thread(unsigned int index);
|
||||
/// Worker thread loop. Tagged and start are provided for a tagged worker thread.
|
||||
void worker_thread(unsigned int index, std::optional<std::string> tagged = std::nullopt, std::function<void()> start = nullptr);
|
||||
|
||||
/// If set, skip polling for one proxy loop iteration (set when we know we have something
|
||||
/// processible without having to shove it onto a socket, such as scheduling an internal job).
|
||||
|
@ -458,7 +490,7 @@ private:
|
|||
// provided then the connection will be curve25519 encrypted and authenticate; otherwise it will
|
||||
// be unencrypted and unauthenticated. Note that the remote end must be in the same mode (i.e.
|
||||
// either accepting curve connections, or not accepting curve).
|
||||
void setup_outgoing_socket(zmq::socket_t& socket, string_view remote_pubkey = {});
|
||||
void setup_outgoing_socket(zmq::socket_t& socket, std::string_view remote_pubkey = {});
|
||||
|
||||
/// Common connection implementation used by proxy_connect/proxy_send. Returns the socket and,
|
||||
/// if a routing prefix is needed, the required prefix (or an empty string if not needed). For
|
||||
|
@ -474,7 +506,7 @@ private:
|
|||
/// @param keep_alive the keep alive for the connection, if we establish a new outgoing
|
||||
/// connection. If we already have an outgoing connection then its keep-alive gets increased to
|
||||
/// this if currently less than this.
|
||||
std::pair<zmq::socket_t*, std::string> proxy_connect_sn(string_view pubkey, string_view connect_hint,
|
||||
std::pair<zmq::socket_t*, std::string> proxy_connect_sn(std::string_view pubkey, std::string_view connect_hint,
|
||||
bool optional, bool incoming_only, bool outgoing_only, std::chrono::milliseconds keep_alive);
|
||||
|
||||
/// CONNECT_SN command telling us to connect to a new pubkey. Returns the socket (which could
|
||||
|
@ -499,7 +531,8 @@ private:
|
|||
|
||||
/// Currently active batch/reply jobs; this is the container that owns the Batch instances
|
||||
std::unordered_set<detail::Batch*> batches;
|
||||
/// Individual batch jobs waiting to run
|
||||
/// Individual batch jobs waiting to run; .second is the 0-n batch number or -1 for the
|
||||
/// completion job
|
||||
using batch_job = std::pair<detail::Batch*, int>;
|
||||
std::queue<batch_job> batch_jobs, reply_jobs;
|
||||
int batch_jobs_active = 0;
|
||||
|
@ -519,7 +552,7 @@ private:
|
|||
void proxy_timer(bt_list_consumer timer_data);
|
||||
|
||||
/// Same, but deserialized
|
||||
void proxy_timer(std::function<void()> job, std::chrono::milliseconds interval, bool squelch);
|
||||
void proxy_timer(std::function<void()> job, std::chrono::milliseconds interval, bool squelch, int thread);
|
||||
|
||||
/// ZAP (https://rfc.zeromq.org/spec:27/ZAP/) authentication handler; this does non-blocking
|
||||
/// processing of any waiting authentication requests for new incoming connections.
|
||||
|
@ -571,31 +604,53 @@ private:
|
|||
bool proxy_check_auth(size_t conn_index, bool outgoing, const peer_info& peer,
|
||||
zmq::message_t& command, const cat_call_t& cat_call, std::vector<zmq::message_t>& data);
|
||||
|
||||
struct injected_task {
|
||||
category& cat;
|
||||
std::string command;
|
||||
std::string remote;
|
||||
std::function<void()> callback;
|
||||
};
|
||||
|
||||
/// Injects a external callback to be handled by a worker; this is the proxy side of
|
||||
/// inject_task().
|
||||
void proxy_inject_task(injected_task task);
|
||||
|
||||
|
||||
/// Set of active service nodes.
|
||||
pubkey_set active_service_nodes;
|
||||
|
||||
/// Resets or updates the stored set of active SN pubkeys
|
||||
void proxy_set_active_sns(string_view data);
|
||||
void proxy_set_active_sns(std::string_view data);
|
||||
void proxy_set_active_sns(pubkey_set pubkeys);
|
||||
void proxy_update_active_sns(bt_list_consumer data);
|
||||
void proxy_update_active_sns(pubkey_set added, pubkey_set removed);
|
||||
void proxy_update_active_sns_clean(pubkey_set added, pubkey_set removed);
|
||||
|
||||
/// Details for a pending command; such a command already has authenticated access and is just
|
||||
/// waiting for a thread to become available to handle it.
|
||||
/// waiting for a thread to become available to handle it. This also gets used (via the
|
||||
/// `callback` variant) for injected external jobs to be able to integrate some external
|
||||
/// interface with the lokimq job queue.
|
||||
struct pending_command {
|
||||
category& cat;
|
||||
std::string command;
|
||||
std::vector<zmq::message_t> data_parts;
|
||||
const std::pair<CommandCallback, bool>* callback;
|
||||
std::variant<
|
||||
const std::pair<CommandCallback, bool>*, // Normal command callback
|
||||
std::function<void()> // Injected external callback
|
||||
> callback;
|
||||
ConnectionID conn;
|
||||
Access access;
|
||||
std::string remote;
|
||||
|
||||
// Normal ctor for an actual lmq command being processed
|
||||
pending_command(category& cat, std::string command, std::vector<zmq::message_t> data_parts,
|
||||
const std::pair<CommandCallback, bool>* callback, ConnectionID conn, Access access, std::string remote)
|
||||
: cat{cat}, command{std::move(command)}, data_parts{std::move(data_parts)},
|
||||
callback{callback}, conn{std::move(conn)}, access{std::move(access)}, remote{std::move(remote)} {}
|
||||
|
||||
// Ctor for an injected external command.
|
||||
pending_command(category& cat, std::string command, std::function<void()> callback, std::string remote)
|
||||
: cat{cat}, command{std::move(command)}, callback{std::move(callback)}, remote{std::move(remote)} {}
|
||||
};
|
||||
std::list<pending_command> pending_commands;
|
||||
|
||||
|
@ -609,9 +664,16 @@ private:
|
|||
struct run_info {
|
||||
bool is_batch_job = false;
|
||||
bool is_reply_job = false;
|
||||
bool is_tagged_thread_job = false;
|
||||
bool is_injected = false;
|
||||
|
||||
// resets the job type bools, above.
|
||||
void reset() { is_batch_job = is_reply_job = is_tagged_thread_job = is_injected = false; }
|
||||
|
||||
// If is_batch_job is false then these will be set appropriate (if is_batch_job is true then
|
||||
// these shouldn't be accessed and likely contain stale data).
|
||||
// these shouldn't be accessed and likely contain stale data). Note that if the command is
|
||||
// an external, injected command then conn, access, conn_route, and data_parts will be
|
||||
// empty/default constructed.
|
||||
category *cat;
|
||||
std::string command;
|
||||
ConnectionID conn; // The connection (or SN pubkey) to reply on/to.
|
||||
|
@ -623,24 +685,31 @@ private:
|
|||
// If is_batch_job true then these are set (if is_batch_job false then don't access these!):
|
||||
int batch_jobno; // >= 0 for a job, -1 for the completion job
|
||||
|
||||
union {
|
||||
const std::pair<CommandCallback, bool>* callback; // set if !is_batch_job
|
||||
detail::Batch* batch; // set if is_batch_job
|
||||
};
|
||||
// The callback or batch job to run. The first of these is for regular tasks, the second
|
||||
// for batch jobs, the third for injected external tasks.
|
||||
std::variant<
|
||||
const std::pair<CommandCallback, bool>*,
|
||||
detail::Batch*,
|
||||
std::function<void()>
|
||||
> to_run;
|
||||
|
||||
// These belong to the proxy thread and must not be accessed by a worker:
|
||||
std::thread worker_thread;
|
||||
size_t worker_id; // The index in `workers`
|
||||
std::string worker_routing_id; // "w123" where 123 == worker_id
|
||||
size_t worker_id; // The index in `workers` (0-n) or index+1 in `tagged_workers` (1-n)
|
||||
std::string worker_routing_id; // "w123" where 123 == worker_id; "n123" for tagged threads.
|
||||
|
||||
/// Loads the run info with an incoming command
|
||||
run_info& load(category* cat, std::string command, ConnectionID conn, Access access, std::string remote,
|
||||
std::vector<zmq::message_t> data_parts, const std::pair<CommandCallback, bool>* callback);
|
||||
|
||||
/// Loads the run info with an injected external command
|
||||
run_info& load(category* cat, std::string command, std::string remote, std::function<void()> callback);
|
||||
|
||||
/// Loads the run info with a stored pending command
|
||||
run_info& load(pending_command&& pending);
|
||||
|
||||
/// Loads the run info with a batch job
|
||||
run_info& load(batch_job&& bj, bool reply_job = false);
|
||||
run_info& load(batch_job&& bj, bool reply_job = false, int tagged_thread = 0);
|
||||
};
|
||||
/// Data passed to workers for the RUN command. The proxy thread sets elements in this before
|
||||
/// sending RUN to a worker then the worker uses it to get call info, and only allocates it
|
||||
|
@ -648,6 +717,11 @@ private:
|
|||
/// change it.
|
||||
std::vector<run_info> workers;
|
||||
|
||||
/// Workers that are reserved for tagged thread tasks (as created with add_tagged_thread). The
|
||||
/// queue here is similar to worker_jobs, but contains only the tagged thread's jobs. The bool
|
||||
/// is whether the worker is currently busy (true) or available (false).
|
||||
std::vector<std::tuple<run_info, bool, std::queue<batch_job>>> tagged_workers;
|
||||
|
||||
public:
|
||||
/**
|
||||
* LokiMQ constructor. This constructs the object but does not start it; you will typically
|
||||
|
@ -786,6 +860,28 @@ public:
|
|||
*/
|
||||
void add_command_alias(std::string from, std::string to);
|
||||
|
||||
/** Creates a "tagged thread" and starts it immediately. A tagged thread is one that batches,
|
||||
* jobs, and timer jobs can be sent to specifically, typically to perform coordination of some
|
||||
* thread-unsafe work.
|
||||
*
|
||||
* Tagged threads will *only* process jobs sent specifically to them; they do not participate in
|
||||
* the thread pool used for regular jobs. Each tagged thread also has its own job queue
|
||||
* completely separate from any other jobs.
|
||||
*
|
||||
* Tagged threads must be created *before* `start()` is called. The name will be used to set the
|
||||
* thread name in the process table (if supported on the OS).
|
||||
*
|
||||
* \param name - the name of the thread; will be used in log messages and (if supported by the
|
||||
* OS) as the system thread name.
|
||||
*
|
||||
* \param start - an optional callback to invoke from the thread as soon as LokiMQ itself starts
|
||||
* up (i.e. after a call to `start()`).
|
||||
*
|
||||
* \returns a TaggedThreadID object that can be passed to job(), batch(), or add_timer() to
|
||||
* direct the task to the tagged thread.
|
||||
*/
|
||||
TaggedThreadID add_tagged_thread(std::string name, std::function<void()> start = nullptr);
|
||||
|
||||
/**
|
||||
* Sets the number of worker threads reserved for batch jobs. If not explicitly called then
|
||||
* this defaults to half the general worker threads configured (rounded up). This works exactly
|
||||
|
@ -895,7 +991,7 @@ public:
|
|||
* *don't* need to worry about this (and can just discard it): you can always simply pass the
|
||||
* pubkey as a string wherever a ConnectionID is called.
|
||||
*/
|
||||
ConnectionID connect_sn(string_view pubkey, std::chrono::milliseconds keep_alive = 5min, string_view hint = {});
|
||||
ConnectionID connect_sn(std::string_view pubkey, std::chrono::milliseconds keep_alive = 5min, std::string_view hint = {});
|
||||
|
||||
/**
|
||||
* Establish a connection to the given remote with callbacks invoked on a successful or failed
|
||||
|
@ -912,12 +1008,11 @@ public:
|
|||
* The `on_connect` and `on_failure` callbacks are invoked when a connection has been
|
||||
* established or failed to establish.
|
||||
*
|
||||
* @param remote the remote connection address, such as `tcp://localhost:1234`.
|
||||
* @param remote the remote connection address either as implicitly from a string or as a full
|
||||
* lokimq::address object; see address.h for details. This specifies both the connection
|
||||
* address and whether curve encryption should be used.
|
||||
* @param on_connect called with the identifier after the connection has been established.
|
||||
* @param on_failure called with the identifier and failure message if we fail to connect.
|
||||
* @param pubkey if non-empty then connect securely (using curve encryption) and verify that the
|
||||
* remote's pubkey equals the given value. Specifying this is similar to using connect_sn()
|
||||
* except that we do not treat the remote as a SN for command authorization purposes.
|
||||
* @param auth_level determines the authentication level of the remote for issuing commands to
|
||||
* us. The default is `AuthLevel::none`.
|
||||
* @param timeout how long to try before aborting the connection attempt and calling the
|
||||
|
@ -927,8 +1022,23 @@ public:
|
|||
* @param returns ConnectionID that uniquely identifies the connection to this remote node. In
|
||||
* order to talk to it you will need the returned value (or a copy of it).
|
||||
*/
|
||||
ConnectionID connect_remote(string_view remote, ConnectSuccess on_connect, ConnectFailure on_failure,
|
||||
string_view pubkey = {},
|
||||
ConnectionID connect_remote(const address& remote, ConnectSuccess on_connect, ConnectFailure on_failure,
|
||||
AuthLevel auth_level = AuthLevel::none, std::chrono::milliseconds timeout = REMOTE_CONNECT_TIMEOUT);
|
||||
|
||||
/// Same as the above, but takes the address as a string_view and constructs an `address` from
|
||||
/// it.
|
||||
ConnectionID connect_remote(std::string_view remote, ConnectSuccess on_connect, ConnectFailure on_failure,
|
||||
AuthLevel auth_level = AuthLevel::none, std::chrono::milliseconds timeout = REMOTE_CONNECT_TIMEOUT) {
|
||||
return connect_remote(address{remote}, std::move(on_connect), std::move(on_failure), auth_level, timeout);
|
||||
}
|
||||
|
||||
/// Deprecated version of the above that takes the remote address and remote pubkey for curve
|
||||
/// encryption as separate arguments. New code should either use a pubkey-embedded address
|
||||
/// string, or specify remote address and pubkey with an `address` object such as:
|
||||
/// connect_remote(address{remote, pubkey}, ...)
|
||||
[[deprecated("use connect_remote() with a lokimq::address instead")]]
|
||||
ConnectionID connect_remote(std::string_view remote, ConnectSuccess on_connect, ConnectFailure on_failure,
|
||||
std::string_view pubkey,
|
||||
AuthLevel auth_level = AuthLevel::none,
|
||||
std::chrono::milliseconds timeout = REMOTE_CONNECT_TIMEOUT);
|
||||
|
||||
|
@ -987,7 +1097,7 @@ public:
|
|||
* connection hint may be used rather than performing a connection address lookup on the pubkey.
|
||||
*/
|
||||
template <typename... T>
|
||||
void send(ConnectionID to, string_view cmd, const T&... opts);
|
||||
void send(ConnectionID to, std::string_view cmd, const T&... opts);
|
||||
|
||||
/** Send a command configured as a "REQUEST" command to a service node: the data parts will be
|
||||
* prefixed with a random identifier. The remote is expected to reply with a ["REPLY",
|
||||
|
@ -1019,7 +1129,33 @@ public:
|
|||
* not running as a service node.
|
||||
*/
|
||||
template <typename... T>
|
||||
void request(ConnectionID to, string_view cmd, ReplyCallback callback, const T&... opts);
|
||||
void request(ConnectionID to, std::string_view cmd, ReplyCallback callback, const T&... opts);
|
||||
|
||||
/** Injects an external task into the lokimq command queue. This is used to allow connecting
|
||||
* non-LokiMQ requests into the LokiMQ thread pool as if they were ordinary requests, to be
|
||||
* scheduled as commands of an individual category. For example, you might support rpc requests
|
||||
* via LokiMQ as `rpc.some_command` and *also* accept them over HTTP. Using `inject_task()`
|
||||
* allows you to handle processing the request in the same thread pool with the same priority as
|
||||
* `rpc.*` commands.
|
||||
*
|
||||
* @param category - the category name that should handle the request for the purposes of
|
||||
* scheduling the job. The category must have been added using add_category(). The category
|
||||
* can be an actual category with added commands, in which case the injected tasks are queued
|
||||
* along with LMQ requests for that category, or can have no commands to set up a distinct
|
||||
* category for the injected jobs.
|
||||
*
|
||||
* @param command - a command name; this is mainly used for debugging and does not need to
|
||||
* actually exist (and, in fact, is often less confusing if it does not). It is recommended for
|
||||
* clarity purposes to use something that doesn't look like a typical command, for example
|
||||
* "(http)".
|
||||
*
|
||||
* @param remote - some free-form identifier of the remote connection. For example, this could
|
||||
* be a remote IP address. Can be blank if there is nothing suitable.
|
||||
*
|
||||
* @param callback - the function to call from a worker thread when the injected task is
|
||||
* processed. Takes no arguments.
|
||||
*/
|
||||
void inject_task(const std::string& category, std::string command, std::string remote, std::function<void()> callback);
|
||||
|
||||
/// The key pair this LokiMQ was created with; if empty keys were given during construction then
|
||||
/// this returns the generated keys.
|
||||
|
@ -1072,8 +1208,12 @@ public:
|
|||
/**
|
||||
* Queues a single job to be executed with no return value. This is a shortcut for creating and
|
||||
* submitting a single-job, no-completion-function batch job.
|
||||
*
|
||||
* \param f the callback to invoke
|
||||
* \param thread an optional tagged thread in which this job should run. You may *not* pass the
|
||||
* proxy thread here.
|
||||
*/
|
||||
void job(std::function<void()> f);
|
||||
void job(std::function<void()> f, std::optional<TaggedThreadID> = std::nullopt);
|
||||
|
||||
/**
|
||||
* Adds a timer that gets scheduled periodically in the job queue. Normally jobs are not
|
||||
|
@ -1081,8 +1221,10 @@ public:
|
|||
* previously scheduled callback of the job has not yet completed. If you want to override this
|
||||
* (so that, under heavy load or long jobs, there can be more than one of the same job scheduled
|
||||
* or running at a time) then specify `squelch` as `false`.
|
||||
*
|
||||
* \param thread specifies a thread (added with add_tagged_thread()) on which this timer must run.
|
||||
*/
|
||||
void add_timer(std::function<void()> job, std::chrono::milliseconds interval, bool squelch = true);
|
||||
void add_timer(std::function<void()> job, std::chrono::milliseconds interval, bool squelch = true, std::optional<TaggedThreadID> = std::nullopt);
|
||||
};
|
||||
|
||||
/// Helper class that slightly simplifies adding commands to a category.
|
||||
|
@ -1129,9 +1271,10 @@ struct data_parts_impl {
|
|||
data_parts_impl(InputIt begin, InputIt end) : begin{std::move(begin)}, end{std::move(end)} {}
|
||||
};
|
||||
|
||||
/// Specifies an iterator pair of data options to send, for when the number of arguments to send()
|
||||
/// cannot be determined at compile time.
|
||||
template <typename InputIt>
|
||||
/// Specifies an iterator pair of data parts to send, for when the number of arguments to send()
|
||||
/// cannot be determined at compile time. The iterator pair must be over strings or string_view (or
|
||||
/// something convertible to a string_view).
|
||||
template <typename InputIt, typename = std::enable_if_t<std::is_convertible_v<decltype(*std::declval<InputIt>()), std::string_view>>>
|
||||
data_parts_impl<InputIt> data_parts(InputIt begin, InputIt end) { return {std::move(begin), std::move(end)}; }
|
||||
|
||||
/// Specifies a connection hint when passed in to send(). If there is no current connection to the
|
||||
|
@ -1250,10 +1393,10 @@ template <typename T> T deserialize_object(uintptr_t ptrval) {
|
|||
|
||||
// Sends a control message to the given socket consisting of the command plus optional dict
|
||||
// data (only sent if the data is non-empty).
|
||||
void send_control(zmq::socket_t& sock, string_view cmd, std::string data = {});
|
||||
void send_control(zmq::socket_t& sock, std::string_view cmd, std::string data = {});
|
||||
|
||||
/// Base case: takes a string-like value and appends it to the message parts
|
||||
inline void apply_send_option(bt_list& parts, bt_dict&, string_view arg) {
|
||||
inline void apply_send_option(bt_list& parts, bt_dict&, std::string_view arg) {
|
||||
parts.emplace_back(arg);
|
||||
}
|
||||
|
||||
|
@ -1261,7 +1404,7 @@ inline void apply_send_option(bt_list& parts, bt_dict&, string_view arg) {
|
|||
template <typename InputIt>
|
||||
void apply_send_option(bt_list& parts, bt_dict&, const send_option::data_parts_impl<InputIt> data) {
|
||||
for (auto it = data.begin; it != data.end; ++it)
|
||||
parts.push_back(lokimq::bt_deserialize(*it));
|
||||
parts.emplace_back(*it);
|
||||
}
|
||||
|
||||
/// `hint` specialization: sets the hint in the control data
|
||||
|
@ -1308,14 +1451,10 @@ inline void apply_send_option(bt_list&, bt_dict& control_data, send_option::queu
|
|||
std::pair<std::string, AuthLevel> extract_metadata(zmq::message_t& msg);
|
||||
|
||||
template <typename... T>
|
||||
bt_dict build_send(ConnectionID to, string_view cmd, T&&... opts) {
|
||||
bt_dict build_send(ConnectionID to, std::string_view cmd, T&&... opts) {
|
||||
bt_dict control_data;
|
||||
bt_list parts{{cmd}};
|
||||
#ifdef __cpp_fold_expressions
|
||||
(detail::apply_send_option(parts, control_data, std::forward<T>(opts)),...);
|
||||
#else
|
||||
(void) std::initializer_list<int>{(detail::apply_send_option(parts, control_data, std::forward<T>(opts)), 0)...};
|
||||
#endif
|
||||
|
||||
if (to.sn())
|
||||
control_data["conn_pubkey"] = std::move(to.pk);
|
||||
|
@ -1332,7 +1471,7 @@ bt_dict build_send(ConnectionID to, string_view cmd, T&&... opts) {
|
|||
|
||||
|
||||
template <typename... T>
|
||||
void LokiMQ::send(ConnectionID to, string_view cmd, const T&... opts) {
|
||||
void LokiMQ::send(ConnectionID to, std::string_view cmd, const T&... opts) {
|
||||
detail::send_control(get_control_socket(), "SEND",
|
||||
bt_serialize(detail::build_send(std::move(to), cmd, opts...)));
|
||||
}
|
||||
|
@ -1340,17 +1479,17 @@ void LokiMQ::send(ConnectionID to, string_view cmd, const T&... opts) {
|
|||
std::string make_random_string(size_t size);
|
||||
|
||||
template <typename... T>
|
||||
void LokiMQ::request(ConnectionID to, string_view cmd, ReplyCallback callback, const T &...opts) {
|
||||
void LokiMQ::request(ConnectionID to, std::string_view cmd, ReplyCallback callback, const T &...opts) {
|
||||
const auto reply_tag = make_random_string(15); // 15 random bytes is lots and should keep us in most stl implementations' small string optimization
|
||||
bt_dict control_data = detail::build_send(std::move(to), cmd, reply_tag, opts...);
|
||||
control_data["request"] = true;
|
||||
control_data["request_callback"] = detail::serialize_object(std::move(callback));
|
||||
control_data["request_tag"] = string_view{reply_tag};
|
||||
control_data["request_tag"] = std::string_view{reply_tag};
|
||||
detail::send_control(get_control_socket(), "SEND", bt_serialize(std::move(control_data)));
|
||||
}
|
||||
|
||||
template <typename... Args>
|
||||
void Message::send_back(string_view command, Args&&... args) {
|
||||
void Message::send_back(std::string_view command, Args&&... args) {
|
||||
lokimq.send(conn, command, send_option::optional{!conn.sn()}, std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
|
@ -1361,14 +1500,14 @@ void Message::send_reply(Args&&... args) {
|
|||
}
|
||||
|
||||
template <typename Callback, typename... Args>
|
||||
void Message::send_request(string_view cmd, Callback&& callback, Args&&... args) {
|
||||
void Message::send_request(std::string_view cmd, Callback&& callback, Args&&... args) {
|
||||
lokimq.request(conn, cmd, std::forward<Callback>(callback),
|
||||
send_option::optional{!conn.sn()}, std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
// When log messages are invoked we strip out anything before this in the filename:
|
||||
constexpr string_view LOG_PREFIX{"lokimq/", 7};
|
||||
inline string_view trim_log_filename(string_view local_file) {
|
||||
constexpr std::string_view LOG_PREFIX{"lokimq/", 7};
|
||||
inline std::string_view trim_log_filename(std::string_view local_file) {
|
||||
auto chop = local_file.rfind(LOG_PREFIX);
|
||||
if (chop != local_file.npos)
|
||||
local_file.remove_prefix(chop);
|
||||
|
@ -1376,16 +1515,12 @@ inline string_view trim_log_filename(string_view local_file) {
|
|||
}
|
||||
|
||||
template <typename... T>
|
||||
void LokiMQ::log_(LogLevel lvl, const char* file, int line, const T&... stuff) {
|
||||
void LokiMQ::log(LogLevel lvl, const char* file, int line, const T&... stuff) {
|
||||
if (log_level() < lvl)
|
||||
return;
|
||||
|
||||
std::ostringstream os;
|
||||
#ifdef __cpp_fold_expressions
|
||||
(os << ... << stuff);
|
||||
#else
|
||||
(void) std::initializer_list<int>{(os << stuff, 0)...};
|
||||
#endif
|
||||
logger(lvl, trim_log_filename(file).data(), line, os.str());
|
||||
}
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@ class LokiMQ;
|
|||
class Message {
|
||||
public:
|
||||
LokiMQ& lokimq; ///< The owning LokiMQ object
|
||||
std::vector<string_view> data; ///< The provided command data parts, if any.
|
||||
std::vector<std::string_view> data; ///< The provided command data parts, if any.
|
||||
ConnectionID conn; ///< The connection info for routing a reply; also contains the pubkey/sn status.
|
||||
std::string reply_tag; ///< If the invoked command is a request command this is the required reply tag that will be prepended by `send_reply()`.
|
||||
Access access; ///< The access level of the invoker. This can be higher than the access level of the command, for example for an admin invoking a basic command.
|
||||
|
@ -36,7 +36,7 @@ public:
|
|||
/// If you want to send a non-strong reply even when the remote is a service node then add
|
||||
/// an explicit `send_option::optional()` argument.
|
||||
template <typename... Args>
|
||||
void send_back(string_view, Args&&... args);
|
||||
void send_back(std::string_view, Args&&... args);
|
||||
|
||||
/// Sends a reply to a request. This takes no command: the command is always the built-in
|
||||
/// "REPLY" command, followed by the unique reply tag, then any reply data parts. All other
|
||||
|
@ -51,7 +51,7 @@ public:
|
|||
/// Sends a request back to whomever sent this message. This is effectively a wrapper around
|
||||
/// lmq.request() that takes care of setting up the recipient arguments.
|
||||
template <typename ReplyCallback, typename... Args>
|
||||
void send_request(string_view cmd, ReplyCallback&& callback, Args&&... args);
|
||||
void send_request(std::string_view cmd, ReplyCallback&& callback, Args&&... args);
|
||||
};
|
||||
|
||||
}
|
||||
|
|
|
@ -9,6 +9,13 @@ extern "C" {
|
|||
}
|
||||
#endif
|
||||
|
||||
#ifndef _WIN32
|
||||
extern "C" {
|
||||
#include <sys/stat.h>
|
||||
#include <sys/types.h>
|
||||
#include <unistd.h>
|
||||
}
|
||||
#endif
|
||||
|
||||
namespace lokimq {
|
||||
|
||||
|
@ -16,11 +23,12 @@ void LokiMQ::proxy_quit() {
|
|||
LMQ_LOG(debug, "Received quit command, shutting down proxy thread");
|
||||
|
||||
assert(std::none_of(workers.begin(), workers.end(), [](auto& worker) { return worker.worker_thread.joinable(); }));
|
||||
assert(std::none_of(tagged_workers.begin(), tagged_workers.end(), [](auto& worker) { return std::get<0>(worker).worker_thread.joinable(); }));
|
||||
|
||||
command.setsockopt<int>(ZMQ_LINGER, 0);
|
||||
command.close();
|
||||
{
|
||||
std::lock_guard<std::mutex> lock{control_sockets_mutex};
|
||||
std::lock_guard lock{control_sockets_mutex};
|
||||
for (auto &control : thread_control_sockets)
|
||||
control->close();
|
||||
proxy_shutting_down = true; // To prevent threads from opening new control sockets
|
||||
|
@ -37,7 +45,7 @@ void LokiMQ::proxy_quit() {
|
|||
|
||||
void LokiMQ::proxy_send(bt_dict_consumer data) {
|
||||
// NB: bt_dict_consumer goes in alphabetical order
|
||||
string_view hint;
|
||||
std::string_view hint;
|
||||
std::chrono::milliseconds keep_alive{DEFAULT_SEND_KEEP_ALIVE};
|
||||
std::chrono::milliseconds request_timeout{DEFAULT_REQUEST_TIMEOUT};
|
||||
bool optional = false;
|
||||
|
@ -106,7 +114,6 @@ void LokiMQ::proxy_send(bt_dict_consumer data) {
|
|||
// connections open to that SN (e.g. one out + one in) so if one fails we can clean up that
|
||||
// connection and try the next one.
|
||||
bool retry = true, sent = false, nowarn = false;
|
||||
std::unique_ptr<zmq::error_t> send_error;
|
||||
while (retry) {
|
||||
retry = false;
|
||||
zmq::socket_t *send_to;
|
||||
|
@ -265,6 +272,9 @@ void LokiMQ::proxy_control_message(std::vector<zmq::message_t>& parts) {
|
|||
LMQ_TRACE("proxy batch jobs");
|
||||
auto ptrval = bt_deserialize<uintptr_t>(data);
|
||||
return proxy_batch(reinterpret_cast<detail::Batch*>(ptrval));
|
||||
} else if (cmd == "INJECT") {
|
||||
LMQ_TRACE("proxy inject");
|
||||
return proxy_inject_task(detail::deserialize_object<injected_task>(bt_deserialize<uintptr_t>(data)));
|
||||
} else if (cmd == "SET_SNS") {
|
||||
return proxy_set_active_sns(data);
|
||||
} else if (cmd == "UPDATE_SNS") {
|
||||
|
@ -292,6 +302,9 @@ void LokiMQ::proxy_control_message(std::vector<zmq::message_t>& parts) {
|
|||
for (const auto &route : idle_workers)
|
||||
route_control(workers_socket, workers[route].worker_routing_id, "QUIT");
|
||||
idle_workers.clear();
|
||||
for (auto& [run, busy, queue] : tagged_workers)
|
||||
if (!busy)
|
||||
route_control(workers_socket, run.worker_routing_id, "QUIT");
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
@ -332,12 +345,19 @@ void LokiMQ::proxy_loop() {
|
|||
LMQ_LOG(debug, " - ", cat.first, ": ", cat.second.reserved_threads);
|
||||
LMQ_LOG(debug, " - (batch jobs): ", batch_jobs_reserved);
|
||||
LMQ_LOG(debug, " - (reply jobs): ", reply_jobs_reserved);
|
||||
LMQ_LOG(debug, "Plus ", tagged_workers.size(), " tagged worker threads");
|
||||
}
|
||||
|
||||
workers.reserve(max_workers);
|
||||
if (!workers.empty())
|
||||
throw std::logic_error("Internal error: proxy thread started with active worker threads");
|
||||
|
||||
#ifndef _WIN32
|
||||
int saved_umask = -1;
|
||||
if (STARTUP_UMASK >= 0)
|
||||
saved_umask = umask(STARTUP_UMASK);
|
||||
#endif
|
||||
|
||||
for (size_t i = 0; i < bind.size(); i++) {
|
||||
auto& b = bind[i].second;
|
||||
zmq::socket_t listener{context, zmq::socket_type::router};
|
||||
|
@ -362,6 +382,24 @@ void LokiMQ::proxy_loop() {
|
|||
incoming_conn_index[conn_id] = connections.size() - 1;
|
||||
b.index = connections.size() - 1;
|
||||
}
|
||||
|
||||
#ifndef _WIN32
|
||||
if (saved_umask != -1)
|
||||
umask(saved_umask);
|
||||
|
||||
// set socket gid / uid if it is provided
|
||||
if (SOCKET_GID != -1 or SOCKET_UID != -1) {
|
||||
for(size_t i = 0; i < bind.size(); i++) {
|
||||
const address addr(bind[i].first);
|
||||
if(addr.ipc()) {
|
||||
if(chown(addr.socket.c_str(), SOCKET_UID, SOCKET_GID) == -1) {
|
||||
throw std::runtime_error("cannot set group on " + addr.socket + ": " + strerror(errno));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
pollitems_stale = true;
|
||||
|
||||
// Also add an internal connection to self so that calling code can avoid needing to
|
||||
|
@ -385,10 +423,38 @@ void LokiMQ::proxy_loop() {
|
|||
|
||||
std::vector<zmq::message_t> parts;
|
||||
|
||||
// Wait for tagged worker threads to get ready and connect to us (we get a "STARTING" message)
|
||||
// and send them back a "START" to let them know to go ahead with startup. We need this
|
||||
// synchronization dance to guarantee that the workers are routable before we can proceed.
|
||||
if (!tagged_workers.empty()) {
|
||||
LMQ_LOG(debug, "Waiting for tagged workers");
|
||||
std::unordered_set<std::string_view> waiting_on;
|
||||
for (auto& w : tagged_workers)
|
||||
waiting_on.emplace(std::get<run_info>(w).worker_routing_id);
|
||||
for (; !waiting_on.empty(); parts.clear()) {
|
||||
recv_message_parts(workers_socket, parts);
|
||||
if (parts.size() != 2 || view(parts[1]) != "STARTING"sv) {
|
||||
LMQ_LOG(error, "Received invalid message on worker socket while waiting for tagged thread startup");
|
||||
continue;
|
||||
}
|
||||
LMQ_LOG(debug, "Received STARTING message from ", view(parts[0]));
|
||||
if (auto it = waiting_on.find(view(parts[0])); it != waiting_on.end())
|
||||
waiting_on.erase(it);
|
||||
else
|
||||
LMQ_LOG(error, "Received STARTING message from unknown worker ", view(parts[0]));
|
||||
}
|
||||
|
||||
for (auto&w : tagged_workers) {
|
||||
LMQ_LOG(debug, "Telling tagged thread worker ", std::get<run_info>(w).worker_routing_id, " to finish startup");
|
||||
route_control(workers_socket, std::get<run_info>(w).worker_routing_id, "START");
|
||||
}
|
||||
}
|
||||
|
||||
while (true) {
|
||||
std::chrono::milliseconds poll_timeout;
|
||||
if (max_workers == 0) { // Will be 0 only if we are quitting
|
||||
if (std::none_of(workers.begin(), workers.end(), [](auto &w) { return w.worker_thread.joinable(); })) {
|
||||
if (std::none_of(workers.begin(), workers.end(), [](auto &w) { return w.worker_thread.joinable(); }) &&
|
||||
std::none_of(tagged_workers.begin(), tagged_workers.end(), [](auto &w) { return std::get<0>(w).worker_thread.joinable(); })) {
|
||||
// All the workers have finished, so we can finish shutting down
|
||||
return proxy_quit();
|
||||
}
|
||||
|
@ -478,7 +544,7 @@ void LokiMQ::proxy_loop() {
|
|||
}
|
||||
}
|
||||
|
||||
static bool is_error_response(string_view cmd) {
|
||||
static bool is_error_response(std::string_view cmd) {
|
||||
return cmd == "FORBIDDEN" || cmd == "FORBIDDEN_SN" || cmd == "NOT_A_SERVICE_NODE" || cmd == "UNKNOWNCOMMAND" || cmd == "NO_REPLY_TAG";
|
||||
}
|
||||
|
||||
|
@ -488,7 +554,7 @@ bool LokiMQ::proxy_handle_builtin(size_t conn_index, std::vector<zmq::message_t>
|
|||
// Doubling as a bool and an offset:
|
||||
size_t incoming = connections[conn_index].getsockopt<int>(ZMQ_TYPE) == ZMQ_ROUTER;
|
||||
|
||||
string_view route, cmd;
|
||||
std::string_view route, cmd;
|
||||
if (parts.size() < 1 + incoming) {
|
||||
LMQ_LOG(warn, "Received empty message; ignoring");
|
||||
return true;
|
||||
|
@ -598,7 +664,7 @@ bool LokiMQ::proxy_handle_builtin(size_t conn_index, std::vector<zmq::message_t>
|
|||
LMQ_LOG(warn, "Received REPLY with unknown or already handled reply tag (", to_hex(reply_tag), "); ignoring");
|
||||
}
|
||||
} else {
|
||||
LMQ_LOG(warn, "Received ", cmd, ':', (parts.size() > 1 + incoming ? view(parts[1 + incoming]) : "(unknown command)"_sv),
|
||||
LMQ_LOG(warn, "Received ", cmd, ':', (parts.size() > 1 + incoming ? view(parts[1 + incoming]) : "(unknown command)"sv),
|
||||
" from ", peer_address(parts.back()));
|
||||
}
|
||||
return true;
|
||||
|
@ -607,7 +673,19 @@ bool LokiMQ::proxy_handle_builtin(size_t conn_index, std::vector<zmq::message_t>
|
|||
}
|
||||
|
||||
void LokiMQ::proxy_process_queue() {
|
||||
// First up: process any batch jobs; since these are internal they are given higher priority.
|
||||
if (max_workers == 0) // shutting down
|
||||
return;
|
||||
|
||||
// First: send any tagged thread tasks to the tagged threads, if idle
|
||||
for (auto& [run, busy, queue] : tagged_workers) {
|
||||
if (!busy && !queue.empty()) {
|
||||
busy = true;
|
||||
proxy_run_worker(run.load(std::move(queue.front()), false, run.worker_id));
|
||||
queue.pop();
|
||||
}
|
||||
}
|
||||
|
||||
// Second: process any batch jobs; since these are internal they are given higher priority.
|
||||
proxy_run_batch_jobs(batch_jobs, batch_jobs_reserved, batch_jobs_active, false);
|
||||
|
||||
// Next any reply batch jobs (which are a bit different from the above, since they are
|
||||
|
@ -630,5 +708,4 @@ void LokiMQ::proxy_process_queue() {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
|
|
@ -1,310 +1,15 @@
|
|||
// Copyright (c) 2019-2020, The Loki Project
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without modification, are
|
||||
// permitted provided that the following conditions are met:
|
||||
//
|
||||
// 1. Redistributions of source code must retain the above copyright notice, this list of
|
||||
// conditions and the following disclaimer.
|
||||
//
|
||||
// 2. Redistributions in binary form must reproduce the above copyright notice, this list
|
||||
// of conditions and the following disclaimer in the documentation and/or other
|
||||
// materials provided with the distribution.
|
||||
//
|
||||
// 3. Neither the name of the copyright holder nor the names of its contributors may be
|
||||
// used to endorse or promote products derived from this software without specific
|
||||
// prior written permission.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
|
||||
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
|
||||
// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL
|
||||
// THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF
|
||||
// THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
#ifdef __cpp_lib_string_view
|
||||
|
||||
#include <string_view>
|
||||
|
||||
namespace lokimq {
|
||||
|
||||
// Deprecated type alias for std::string_view
|
||||
using string_view = std::string_view;
|
||||
using ustring_view = std::basic_string_view<unsigned char>;
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
#include <ostream>
|
||||
#include <limits>
|
||||
|
||||
namespace lokimq {
|
||||
|
||||
/// Basic implementation of std::string_view (except for std::hash support).
|
||||
template <typename CharT>
|
||||
class simple_string_view {
|
||||
const CharT *data_;
|
||||
size_t size_;
|
||||
public:
|
||||
using traits_type = std::char_traits<CharT>;
|
||||
using value_type = CharT;
|
||||
using pointer = CharT*;
|
||||
using const_pointer = const CharT*;
|
||||
using reference = CharT&;
|
||||
using const_reference = const CharT&;
|
||||
using const_iterator = const_pointer;
|
||||
using iterator = const_iterator;
|
||||
using const_reverse_iterator = std::reverse_iterator<const_iterator>;
|
||||
using reverse_iterator = const_reverse_iterator;
|
||||
using size_type = std::size_t;
|
||||
using different_type = std::ptrdiff_t;
|
||||
|
||||
static constexpr auto& npos = std::string::npos;
|
||||
|
||||
constexpr simple_string_view() noexcept : data_{nullptr}, size_{0} {}
|
||||
constexpr simple_string_view(const simple_string_view&) noexcept = default;
|
||||
simple_string_view(const std::basic_string<CharT>& str) : data_{str.data()}, size_{str.size()} {}
|
||||
constexpr simple_string_view(const CharT* data, size_t size) noexcept : data_{data}, size_{size} {}
|
||||
simple_string_view(const CharT* data) : data_{data}, size_{traits_type::length(data)} {}
|
||||
simple_string_view& operator=(const simple_string_view&) = default;
|
||||
constexpr const CharT* data() const noexcept { return data_; }
|
||||
constexpr size_t size() const noexcept { return size_; }
|
||||
constexpr size_t length() const noexcept { return size_; }
|
||||
constexpr size_t max_size() const noexcept { return std::numeric_limits<size_t>::max(); }
|
||||
constexpr bool empty() const noexcept { return size_ == 0; }
|
||||
explicit operator std::basic_string<CharT>() const { return {data_, size_}; }
|
||||
constexpr const CharT* begin() const noexcept { return data_; }
|
||||
constexpr const CharT* cbegin() const noexcept { return data_; }
|
||||
constexpr const CharT* end() const noexcept { return data_ + size_; }
|
||||
constexpr const CharT* cend() const noexcept { return data_ + size_; }
|
||||
reverse_iterator rbegin() const { return reverse_iterator{end()}; }
|
||||
reverse_iterator crbegin() const { return reverse_iterator{end()}; }
|
||||
reverse_iterator rend() const { return reverse_iterator{begin()}; }
|
||||
reverse_iterator crend() const { return reverse_iterator{begin()}; }
|
||||
constexpr const CharT& operator[](size_t pos) const { return data_[pos]; }
|
||||
constexpr const CharT& front() const { return *data_; }
|
||||
constexpr const CharT& back() const { return data_[size_ - 1]; }
|
||||
int compare(simple_string_view s) const;
|
||||
constexpr void remove_prefix(size_t n) { data_ += n; size_ -= n; }
|
||||
constexpr void remove_suffix(size_t n) { size_ -= n; }
|
||||
void swap(simple_string_view &s) noexcept { std::swap(data_, s.data_); std::swap(size_, s.size_); }
|
||||
|
||||
#if defined(__clang__) || !defined(__GNUG__) || __GNUC__ >= 6
|
||||
constexpr // GCC 5.x is buggy wrt constexpr throwing
|
||||
#endif
|
||||
const CharT& at(size_t pos) const {
|
||||
if (pos >= size())
|
||||
throw std::out_of_range{"invalid string_view index"};
|
||||
return data_[pos];
|
||||
};
|
||||
|
||||
size_t copy(CharT* dest, size_t count, size_t pos = 0) const {
|
||||
if (pos > size()) throw std::out_of_range{"invalid copy pos"};
|
||||
size_t rcount = std::min(count, size_ - pos);
|
||||
traits_type::copy(dest, data_ + pos, rcount);
|
||||
return rcount;
|
||||
}
|
||||
|
||||
#if defined(__clang__) || !defined(__GNUG__) || __GNUC__ >= 6
|
||||
constexpr // GCC 5.x is buggy wrt constexpr throwing
|
||||
#endif
|
||||
simple_string_view substr(size_t pos = 0, size_t count = npos) const {
|
||||
if (pos > size()) throw std::out_of_range{"invalid substr range"};
|
||||
simple_string_view result = *this;
|
||||
if (pos > 0) result.remove_prefix(pos);
|
||||
if (count < result.size()) result.remove_suffix(result.size() - count);
|
||||
return result;
|
||||
}
|
||||
|
||||
size_t find(simple_string_view v, size_t pos = 0) const {
|
||||
if (pos > size_ || v.size_ > size_) return npos;
|
||||
for (const size_t max_pos = size_ - v.size_; pos <= max_pos; ++pos) {
|
||||
if (0 == traits_type::compare(v.data_, data_ + pos, v.size_))
|
||||
return pos;
|
||||
}
|
||||
return npos;
|
||||
}
|
||||
size_t find(CharT c, size_t pos = 0) const { return find({&c, 1}, pos); }
|
||||
size_t find(const CharT* c, size_t pos, size_t count) const { return find({c, count}, pos); }
|
||||
size_t find(const CharT* c, size_t pos = 0) const { return find(simple_string_view(c), pos); }
|
||||
|
||||
size_t rfind(simple_string_view v, size_t pos = npos) const {
|
||||
if (v.size_ > size_) return npos;
|
||||
const size_t max_pos = size_ - v.size_;
|
||||
for (pos = std::min(pos, max_pos); pos <= max_pos; --pos) {
|
||||
if (0 == traits_type::compare(v.data_, data_ + pos, v.size_))
|
||||
return pos;
|
||||
}
|
||||
return npos;
|
||||
}
|
||||
size_t rfind(CharT c, size_t pos = npos) const { return rfind({&c, 1}, pos); }
|
||||
size_t rfind(const CharT* c, size_t pos, size_t count) const { return rfind({c, count}, pos); }
|
||||
size_t rfind(const CharT* c, size_t pos = npos) const { return rfind(simple_string_view(c), pos); }
|
||||
|
||||
constexpr size_t find_first_of(simple_string_view v, size_t pos = 0) const noexcept {
|
||||
for (; pos < size_; ++pos)
|
||||
for (CharT c : v)
|
||||
if (data_[pos] == c)
|
||||
return pos;
|
||||
return npos;
|
||||
}
|
||||
constexpr size_t find_first_of(CharT c, size_t pos = 0) const noexcept { return find_first_of({&c, 1}, pos); }
|
||||
constexpr size_t find_first_of(const CharT* c, size_t pos, size_t count) const { return find_first_of({c, count}, pos); }
|
||||
size_t find_first_of(const CharT* c, size_t pos = 0) const { return find_first_of(simple_string_view(c), pos); }
|
||||
|
||||
constexpr size_t find_last_of(simple_string_view v, const size_t pos = npos) const noexcept {
|
||||
if (size_ == 0) return npos;
|
||||
const size_t last_pos = std::min(pos, size_-1);
|
||||
for (size_t i = last_pos; i <= last_pos; --i)
|
||||
for (CharT c : v)
|
||||
if (data_[i] == c)
|
||||
return i;
|
||||
return npos;
|
||||
}
|
||||
constexpr size_t find_last_of(CharT c, size_t pos = npos) const noexcept { return find_last_of({&c, 1}, pos); }
|
||||
constexpr size_t find_last_of(const CharT* c, size_t pos, size_t count) const { return find_last_of({c, count}, pos); }
|
||||
size_t find_last_of(const CharT* c, size_t pos = npos) const { return find_last_of(simple_string_view(c), pos); }
|
||||
|
||||
constexpr size_t find_first_not_of(simple_string_view v, size_t pos = 0) const noexcept {
|
||||
for (; pos < size_; ++pos) {
|
||||
bool none = true;
|
||||
for (CharT c : v) {
|
||||
if (data_[pos] == c) {
|
||||
none = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (none) return pos;
|
||||
}
|
||||
return npos;
|
||||
}
|
||||
constexpr size_t find_first_not_of(CharT c, size_t pos = 0) const noexcept { return find_first_not_of({&c, 1}, pos); }
|
||||
constexpr size_t find_first_not_of(const CharT* c, size_t pos, size_t count) const { return find_first_not_of({c, count}, pos); }
|
||||
size_t find_first_not_of(const CharT* c, size_t pos = 0) const { return find_first_not_of(simple_string_view(c), pos); }
|
||||
|
||||
constexpr size_t find_last_not_of(simple_string_view v, const size_t pos = npos) const noexcept {
|
||||
if (size_ == 0) return npos;
|
||||
const size_t last_pos = std::min(pos, size_-1);
|
||||
for (size_t i = last_pos; i <= last_pos; --i) {
|
||||
bool none = true;
|
||||
for (CharT c : v) {
|
||||
if (data_[i] == c) {
|
||||
none = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (none) return i;
|
||||
}
|
||||
return npos;
|
||||
}
|
||||
constexpr size_t find_last_not_of(CharT c, size_t pos = npos) const noexcept { return find_last_not_of({&c, 1}, pos); }
|
||||
constexpr size_t find_last_not_of(const CharT* c, size_t pos, size_t count) const { return find_last_not_of({c, count}, pos); }
|
||||
size_t find_last_not_of(const CharT* c, size_t pos = npos) const { return find_last_not_of(simple_string_view(c), pos); }
|
||||
};
|
||||
/// We have three of each of these: one with two string views, one with RHS argument deduction, and
|
||||
/// one with LHS argument deduction, so that you can do (sv == sv), (sv == "foo"), and ("foo" == sv)
|
||||
template <typename CharT>
|
||||
inline bool operator==(simple_string_view<CharT> lhs, simple_string_view<CharT> rhs) {
|
||||
return lhs.size() == rhs.size() && 0 == std::char_traits<CharT>::compare(lhs.data(), rhs.data(), lhs.size());
|
||||
};
|
||||
template <typename CharT>
|
||||
inline bool operator==(simple_string_view<CharT> lhs, std::common_type_t<simple_string_view<CharT>> rhs) {
|
||||
return lhs.size() == rhs.size() && 0 == std::char_traits<CharT>::compare(lhs.data(), rhs.data(), lhs.size());
|
||||
};
|
||||
template <typename CharT>
|
||||
inline bool operator==(std::common_type_t<simple_string_view<CharT>> lhs, simple_string_view<CharT> rhs) {
|
||||
return lhs.size() == rhs.size() && 0 == std::char_traits<CharT>::compare(lhs.data(), rhs.data(), lhs.size());
|
||||
};
|
||||
template <typename CharT>
|
||||
inline bool operator!=(simple_string_view<CharT> lhs, simple_string_view<CharT> rhs) {
|
||||
return !(lhs == rhs);
|
||||
}
|
||||
template <typename CharT>
|
||||
inline bool operator!=(simple_string_view<CharT> lhs, std::common_type_t<simple_string_view<CharT>> rhs) {
|
||||
return !(lhs == rhs);
|
||||
}
|
||||
template <typename CharT>
|
||||
inline bool operator!=(std::common_type_t<simple_string_view<CharT>> lhs, simple_string_view<CharT> rhs) {
|
||||
return !(lhs == rhs);
|
||||
}
|
||||
template <typename CharT>
|
||||
inline int simple_string_view<CharT>::compare(simple_string_view s) const {
|
||||
int cmp = std::char_traits<CharT>::compare(data_, s.data(), std::min(size_, s.size()));
|
||||
if (cmp) return cmp;
|
||||
if (size_ < s.size()) return -1;
|
||||
else if (size_ > s.size()) return 1;
|
||||
return 0;
|
||||
}
|
||||
template <typename CharT>
|
||||
inline bool operator<(simple_string_view<CharT> lhs, simple_string_view<CharT> rhs) {
|
||||
return lhs.compare(rhs) < 0;
|
||||
};
|
||||
template <typename CharT>
|
||||
inline bool operator<(simple_string_view<CharT> lhs, std::common_type_t<simple_string_view<CharT>> rhs) {
|
||||
return lhs.compare(rhs) < 0;
|
||||
};
|
||||
template <typename CharT>
|
||||
inline bool operator<(std::common_type_t<simple_string_view<CharT>> lhs, simple_string_view<CharT> rhs) {
|
||||
return lhs.compare(rhs) < 0;
|
||||
};
|
||||
template <typename CharT>
|
||||
inline bool operator<=(simple_string_view<CharT> lhs, simple_string_view<CharT> rhs) {
|
||||
return lhs.compare(rhs) <= 0;
|
||||
};
|
||||
template <typename CharT>
|
||||
inline bool operator<=(simple_string_view<CharT> lhs, std::common_type_t<simple_string_view<CharT>> rhs) {
|
||||
return lhs.compare(rhs) <= 0;
|
||||
};
|
||||
template <typename CharT>
|
||||
inline bool operator<=(std::common_type_t<simple_string_view<CharT>> lhs, simple_string_view<CharT> rhs) {
|
||||
return lhs.compare(rhs) <= 0;
|
||||
};
|
||||
template <typename CharT>
|
||||
inline bool operator>(simple_string_view<CharT> lhs, simple_string_view<CharT> rhs) {
|
||||
return lhs.compare(rhs) > 0;
|
||||
};
|
||||
template <typename CharT>
|
||||
inline bool operator>(simple_string_view<CharT> lhs, std::common_type_t<simple_string_view<CharT>> rhs) {
|
||||
return lhs.compare(rhs) > 0;
|
||||
};
|
||||
template <typename CharT>
|
||||
inline bool operator>(std::common_type_t<simple_string_view<CharT>> lhs, simple_string_view<CharT> rhs) {
|
||||
return lhs.compare(rhs) > 0;
|
||||
};
|
||||
template <typename CharT>
|
||||
inline bool operator>=(simple_string_view<CharT> lhs, simple_string_view<CharT> rhs) {
|
||||
return lhs.compare(rhs) >= 0;
|
||||
};
|
||||
template <typename CharT>
|
||||
inline bool operator>=(simple_string_view<CharT> lhs, std::common_type_t<simple_string_view<CharT>> rhs) {
|
||||
return lhs.compare(rhs) >= 0;
|
||||
};
|
||||
template <typename CharT>
|
||||
inline bool operator>=(std::common_type_t<simple_string_view<CharT>> lhs, simple_string_view<CharT> rhs) {
|
||||
return lhs.compare(rhs) >= 0;
|
||||
};
|
||||
template <typename CharT>
|
||||
inline std::basic_ostream<CharT>& operator<<(std::basic_ostream<CharT>& os, const simple_string_view<CharT>& s) {
|
||||
os.write(s.data(), s.size());
|
||||
return os;
|
||||
}
|
||||
|
||||
using string_view = simple_string_view<char>;
|
||||
using ustring_view = simple_string_view<unsigned char>;
|
||||
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
// Add a "foo"_sv literal that works exactly like the C++17 "foo"sv literal, but works with our
|
||||
// implementation in pre-C++17.
|
||||
namespace lokimq {
|
||||
// Deprecated "foo"_sv literal; you should use "foo"sv (from <string_view>) instead.
|
||||
inline namespace literals {
|
||||
inline constexpr string_view operator""_sv(const char* str, size_t len) { return {str, len}; }
|
||||
inline constexpr std::string_view operator""_sv(const char* str, size_t len) { return {str, len}; }
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
#include "lokimq.h"
|
||||
#include "batch.h"
|
||||
#include "hex.h"
|
||||
#include "lokimq-internal.h"
|
||||
|
||||
#if defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__)
|
||||
|
@ -12,36 +11,107 @@ extern "C" {
|
|||
|
||||
namespace lokimq {
|
||||
|
||||
void LokiMQ::worker_thread(unsigned int index) {
|
||||
std::string worker_id = "w" + std::to_string(index);
|
||||
namespace {
|
||||
|
||||
// Waits for a specific command or "QUIT" on the given socket. Returns true if the command was
|
||||
// received. If "QUIT" was received, replies with "QUITTING" on the socket and closes it, then
|
||||
// returns false.
|
||||
[[gnu::always_inline]] inline
|
||||
bool worker_wait_for(LokiMQ& lmq, zmq::socket_t& sock, std::vector<zmq::message_t>& parts, const std::string_view worker_id, const std::string_view expect) {
|
||||
while (true) {
|
||||
lmq.log(LogLevel::debug, __FILE__, __LINE__, "worker ", worker_id, " waiting for ", expect);
|
||||
parts.clear();
|
||||
recv_message_parts(sock, parts);
|
||||
if (parts.size() != 1) {
|
||||
lmq.log(LogLevel::error, __FILE__, __LINE__, "Internal error: worker ", worker_id, " received invalid ", parts.size(), "-part control msg");
|
||||
continue;
|
||||
}
|
||||
auto command = view(parts[0]);
|
||||
if (command == expect) {
|
||||
#ifndef NDEBUG
|
||||
lmq.log(LogLevel::trace, __FILE__, __LINE__, "Worker ", worker_id, " received waited-for ", expect, " command");
|
||||
#endif
|
||||
return true;
|
||||
} else if (command == "QUIT"sv) {
|
||||
lmq.log(LogLevel::debug, __FILE__, __LINE__, "Worker ", worker_id, " received QUIT command, shutting down");
|
||||
detail::send_control(sock, "QUITTING");
|
||||
sock.setsockopt<int>(ZMQ_LINGER, 1000);
|
||||
sock.close();
|
||||
return false;
|
||||
} else {
|
||||
lmq.log(LogLevel::error, __FILE__, __LINE__, "Internal error: worker ", worker_id, " received invalid command: `", command, "'");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void LokiMQ::worker_thread(unsigned int index, std::optional<std::string> tagged, std::function<void()> start) {
|
||||
std::string routing_id = (tagged ? "t" : "w") + std::to_string(index); // for routing
|
||||
std::string_view worker_id{tagged ? *tagged : routing_id}; // for debug
|
||||
|
||||
[[maybe_unused]] std::string thread_name = tagged.value_or("lmq-" + routing_id);
|
||||
#if defined(__linux__) || defined(__sun) || defined(__MINGW32__)
|
||||
pthread_setname_np(pthread_self(), ("lmq-" + worker_id).c_str());
|
||||
if (thread_name.size() > 15) thread_name.resize(15);
|
||||
pthread_setname_np(pthread_self(), thread_name.c_str());
|
||||
#elif defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__)
|
||||
pthread_set_name_np(pthread_self(), ("lmq-" + worker_id).c_str());
|
||||
pthread_set_name_np(pthread_self(), thread_name.c_str());
|
||||
#elif defined(__MACH__)
|
||||
pthread_setname_np(("lmq-" + worker_id).c_str());
|
||||
pthread_setname_np(thread_name.c_str());
|
||||
#endif
|
||||
|
||||
zmq::socket_t sock{context, zmq::socket_type::dealer};
|
||||
sock.setsockopt(ZMQ_ROUTING_ID, worker_id.data(), worker_id.size());
|
||||
LMQ_LOG(debug, "New worker thread ", worker_id, " started");
|
||||
sock.setsockopt(ZMQ_ROUTING_ID, routing_id.data(), routing_id.size());
|
||||
LMQ_LOG(debug, "New worker thread ", worker_id, " (", routing_id, ") started");
|
||||
sock.connect(SN_ADDR_WORKERS);
|
||||
if (tagged)
|
||||
detail::send_control(sock, "STARTING");
|
||||
|
||||
Message message{*this, 0, AuthLevel::none, ""s};
|
||||
std::vector<zmq::message_t> parts;
|
||||
run_info& run = workers[index]; // This contains our first job, and will be updated later with subsequent jobs
|
||||
|
||||
bool waiting_for_command;
|
||||
if (tagged) {
|
||||
// If we're a tagged worker then we got started up before LokiMQ started, so we need to wait
|
||||
// for an all-clear signal from LokiMQ first, then we fire our `start` callback, then we can
|
||||
// start waiting for commands in the main loop further down. (We also can't get the
|
||||
// reference to our `tagged_workers` element until the main proxy threads is running).
|
||||
|
||||
waiting_for_command = true;
|
||||
|
||||
if (!worker_wait_for(*this, sock, parts, worker_id, "START"sv))
|
||||
return;
|
||||
if (start) start();
|
||||
} else {
|
||||
// Otherwise for a regular worker we can only be started by an active main proxy thread
|
||||
// which will have preloaded our first job so we can start off right away.
|
||||
waiting_for_command = false;
|
||||
}
|
||||
|
||||
// This will always contains the current job, and is guaranteed to never be invalidated.
|
||||
run_info& run = tagged ? std::get<run_info>(tagged_workers[index - 1]) : workers[index];
|
||||
|
||||
while (true) {
|
||||
if (waiting_for_command) {
|
||||
if (!worker_wait_for(*this, sock, parts, worker_id, "RUN"sv))
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
if (run.is_batch_job) {
|
||||
auto* batch = std::get<detail::Batch*>(run.to_run);
|
||||
if (run.batch_jobno >= 0) {
|
||||
LMQ_TRACE("worker thread ", worker_id, " running batch ", run.batch, "#", run.batch_jobno);
|
||||
run.batch->run_job(run.batch_jobno);
|
||||
LMQ_TRACE("worker thread ", worker_id, " running batch ", batch, "#", run.batch_jobno);
|
||||
batch->run_job(run.batch_jobno);
|
||||
} else if (run.batch_jobno == -1) {
|
||||
LMQ_TRACE("worker thread ", worker_id, " running batch ", run.batch, " completion");
|
||||
run.batch->job_completion();
|
||||
LMQ_TRACE("worker thread ", worker_id, " running batch ", batch, " completion");
|
||||
batch->job_completion();
|
||||
}
|
||||
} else if (run.is_injected) {
|
||||
auto& func = std::get<std::function<void()>>(run.to_run);
|
||||
LMQ_TRACE("worker thread ", worker_id, " invoking injected command ", run.command);
|
||||
func();
|
||||
func = nullptr;
|
||||
} else {
|
||||
message.conn = run.conn;
|
||||
message.access = run.access;
|
||||
|
@ -50,7 +120,8 @@ void LokiMQ::worker_thread(unsigned int index) {
|
|||
|
||||
LMQ_TRACE("Got incoming command from ", message.remote, "/", message.conn, message.conn.route.empty() ? " (outgoing)" : " (incoming)");
|
||||
|
||||
if (run.callback->second /*is_request*/) {
|
||||
auto& [callback, is_request] = *std::get<const std::pair<CommandCallback, bool>*>(run.to_run);
|
||||
if (is_request) {
|
||||
message.reply_tag = {run.data_parts[0].data<char>(), run.data_parts[0].size()};
|
||||
for (auto it = run.data_parts.begin() + 1; it != run.data_parts.end(); ++it)
|
||||
message.data.emplace_back(it->data<char>(), it->size());
|
||||
|
@ -60,13 +131,13 @@ void LokiMQ::worker_thread(unsigned int index) {
|
|||
}
|
||||
|
||||
LMQ_TRACE("worker thread ", worker_id, " invoking ", run.command, " callback with ", message.data.size(), " message parts");
|
||||
run.callback->first(message);
|
||||
callback(message);
|
||||
}
|
||||
}
|
||||
catch (const bt_deserialize_invalid& e) {
|
||||
LMQ_LOG(warn, worker_id, " deserialization failed: ", e.what(), "; ignoring request");
|
||||
}
|
||||
catch (const mapbox::util::bad_variant_access& e) {
|
||||
catch (const std::bad_variant_access& e) {
|
||||
LMQ_LOG(warn, worker_id, " deserialization failed: found unexpected serialized type (", e.what(), "); ignoring request");
|
||||
}
|
||||
catch (const std::out_of_range& e) {
|
||||
|
@ -79,32 +150,9 @@ void LokiMQ::worker_thread(unsigned int index) {
|
|||
LMQ_LOG(warn, worker_id, " caught non-standard exception when processing command");
|
||||
}
|
||||
|
||||
while (true) {
|
||||
// Signal that we are ready for another job and wait for it. (We do this down here
|
||||
// because our first job gets set up when the thread is started).
|
||||
detail::send_control(sock, "RAN");
|
||||
LMQ_TRACE("worker ", worker_id, " waiting for requests");
|
||||
parts.clear();
|
||||
recv_message_parts(sock, parts);
|
||||
|
||||
if (parts.size() != 1) {
|
||||
LMQ_LOG(error, "Internal error: worker ", worker_id, " received invalid ", parts.size(), "-part worker instruction");
|
||||
continue;
|
||||
}
|
||||
auto command = view(parts[0]);
|
||||
if (command == "RUN") {
|
||||
LMQ_LOG(debug, "worker ", worker_id, " running command ", run.command);
|
||||
break; // proxy has set up a command for us, go back and run it.
|
||||
} else if (command == "QUIT") {
|
||||
LMQ_LOG(debug, "worker ", worker_id, " shutting down");
|
||||
detail::send_control(sock, "QUITTING");
|
||||
sock.setsockopt<int>(ZMQ_LINGER, 1000);
|
||||
sock.close();
|
||||
return;
|
||||
} else {
|
||||
LMQ_LOG(error, "Internal error: worker ", worker_id, " received invalid command: `", command, "'");
|
||||
}
|
||||
}
|
||||
// Tell the proxy thread that we are ready for another job
|
||||
detail::send_control(sock, "RAN");
|
||||
waiting_for_command = true;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -132,52 +180,73 @@ void LokiMQ::proxy_worker_message(std::vector<zmq::message_t>& parts) {
|
|||
}
|
||||
auto route = view(parts[0]), cmd = view(parts[1]);
|
||||
LMQ_TRACE("worker message from ", route);
|
||||
assert(route.size() >= 2 && route[0] == 'w' && route[1] >= '0' && route[1] <= '9');
|
||||
string_view worker_id_str{&route[1], route.size()-1}; // Chop off the leading "w"
|
||||
assert(route.size() >= 2 && (route[0] == 'w' || route[0] == 't') && route[1] >= '0' && route[1] <= '9');
|
||||
bool tagged_worker = route[0] == 't';
|
||||
std::string_view worker_id_str{&route[1], route.size()-1}; // Chop off the leading "w" (or "t")
|
||||
unsigned int worker_id = detail::extract_unsigned(worker_id_str);
|
||||
if (!worker_id_str.empty() /* didn't consume everything */ || worker_id >= workers.size()) {
|
||||
if (!worker_id_str.empty() /* didn't consume everything */ ||
|
||||
(tagged_worker
|
||||
? 0 == worker_id || worker_id > tagged_workers.size() // tagged worker ids are indexed from 1 to N (0 means untagged)
|
||||
: worker_id >= workers.size() // regular worker ids are indexed from 0 to N-1
|
||||
)
|
||||
) {
|
||||
LMQ_LOG(error, "Worker id '", route, "' is invalid, unable to process worker command");
|
||||
return;
|
||||
}
|
||||
|
||||
auto& run = workers[worker_id];
|
||||
auto& run = tagged_worker ? std::get<run_info>(tagged_workers[worker_id - 1]) : workers[worker_id];
|
||||
|
||||
LMQ_TRACE("received ", cmd, " command from ", route);
|
||||
if (cmd == "RAN") {
|
||||
LMQ_LOG(debug, "Worker ", route, " finished ", run.command);
|
||||
if (cmd == "RAN"sv) {
|
||||
LMQ_TRACE("Worker ", route, " finished ", run.is_batch_job ? "batch job" : run.command);
|
||||
if (run.is_batch_job) {
|
||||
auto& jobs = run.is_reply_job ? reply_jobs : batch_jobs;
|
||||
auto& active = run.is_reply_job ? reply_jobs_active : batch_jobs_active;
|
||||
assert(active > 0);
|
||||
active--;
|
||||
if (tagged_worker) {
|
||||
std::get<bool>(tagged_workers[worker_id - 1]) = false;
|
||||
} else {
|
||||
auto& active = run.is_reply_job ? reply_jobs_active : batch_jobs_active;
|
||||
assert(active > 0);
|
||||
active--;
|
||||
}
|
||||
bool clear_job = false;
|
||||
auto* batch = std::get<detail::Batch*>(run.to_run);
|
||||
if (run.batch_jobno == -1) {
|
||||
// Returned from the completion function
|
||||
clear_job = true;
|
||||
} else {
|
||||
auto status = run.batch->job_finished();
|
||||
if (status == detail::BatchStatus::complete) {
|
||||
jobs.emplace(run.batch, -1);
|
||||
} else if (status == detail::BatchStatus::complete_proxy) {
|
||||
try {
|
||||
run.batch->job_completion(); // RUN DIRECTLY IN PROXY THREAD
|
||||
} catch (const std::exception &e) {
|
||||
// Raise these to error levels: the caller really shouldn't be doing
|
||||
// anything non-trivial in an in-proxy completion function!
|
||||
LMQ_LOG(error, "proxy thread caught exception when processing in-proxy completion command: ", e.what());
|
||||
} catch (...) {
|
||||
LMQ_LOG(error, "proxy thread caught non-standard exception when processing in-proxy completion command");
|
||||
auto [state, thread] = batch->job_finished();
|
||||
if (state == detail::BatchState::complete) {
|
||||
if (thread == -1) { // run directly in proxy
|
||||
LMQ_TRACE("Completion job running directly in proxy");
|
||||
try {
|
||||
batch->job_completion(); // RUN DIRECTLY IN PROXY THREAD
|
||||
} catch (const std::exception &e) {
|
||||
// Raise these to error levels: the caller really shouldn't be doing
|
||||
// anything non-trivial in an in-proxy completion function!
|
||||
LMQ_LOG(error, "proxy thread caught exception when processing in-proxy completion command: ", e.what());
|
||||
} catch (...) {
|
||||
LMQ_LOG(error, "proxy thread caught non-standard exception when processing in-proxy completion command");
|
||||
}
|
||||
clear_job = true;
|
||||
} else {
|
||||
auto& jobs =
|
||||
thread > 0
|
||||
? std::get<std::queue<batch_job>>(tagged_workers[thread - 1]) // run in tagged thread
|
||||
: run.is_reply_job
|
||||
? reply_jobs
|
||||
: batch_jobs;
|
||||
jobs.emplace(batch, -1);
|
||||
}
|
||||
clear_job = true;
|
||||
} else if (status == detail::BatchStatus::done) {
|
||||
} else if (state == detail::BatchState::done) {
|
||||
// No completion job
|
||||
clear_job = true;
|
||||
}
|
||||
// else the job is still running
|
||||
}
|
||||
|
||||
if (clear_job) {
|
||||
batches.erase(run.batch);
|
||||
delete run.batch;
|
||||
run.batch = nullptr;
|
||||
batches.erase(batch);
|
||||
delete batch;
|
||||
run.to_run = static_cast<detail::Batch*>(nullptr);
|
||||
}
|
||||
} else {
|
||||
assert(run.cat->active_threads > 0);
|
||||
|
@ -186,11 +255,11 @@ void LokiMQ::proxy_worker_message(std::vector<zmq::message_t>& parts) {
|
|||
if (max_workers == 0) { // Shutting down
|
||||
LMQ_TRACE("Telling worker ", route, " to quit");
|
||||
route_control(workers_socket, route, "QUIT");
|
||||
} else {
|
||||
} else if (!tagged_worker) {
|
||||
idle_workers.push_back(worker_id);
|
||||
}
|
||||
} else if (cmd == "QUITTING") {
|
||||
workers[worker_id].worker_thread.join();
|
||||
} else if (cmd == "QUITTING"sv) {
|
||||
run.worker_thread.join();
|
||||
LMQ_LOG(debug, "Worker ", route, " exited normally");
|
||||
} else {
|
||||
LMQ_LOG(error, "Worker ", route, " sent unknown control message: `", cmd, "'");
|
||||
|
@ -199,7 +268,7 @@ void LokiMQ::proxy_worker_message(std::vector<zmq::message_t>& parts) {
|
|||
|
||||
void LokiMQ::proxy_run_worker(run_info& run) {
|
||||
if (!run.worker_thread.joinable())
|
||||
run.worker_thread = std::thread{&LokiMQ::worker_thread, this, run.worker_id};
|
||||
run.worker_thread = std::thread{[this, id=run.worker_id] { worker_thread(id); }};
|
||||
else
|
||||
send_routed_message(workers_socket, run.worker_routing_id, "RUN");
|
||||
}
|
||||
|
@ -306,5 +375,38 @@ void LokiMQ::proxy_to_worker(size_t conn_index, std::vector<zmq::message_t>& par
|
|||
category.active_threads++;
|
||||
}
|
||||
|
||||
void LokiMQ::inject_task(const std::string& category, std::string command, std::string remote, std::function<void()> callback) {
|
||||
if (!callback) return;
|
||||
auto it = categories.find(category);
|
||||
if (it == categories.end())
|
||||
throw std::out_of_range{"Invalid category `" + category + "': category does not exist"};
|
||||
detail::send_control(get_control_socket(), "INJECT", bt_serialize(detail::serialize_object(
|
||||
injected_task{it->second, std::move(command), std::move(remote), std::move(callback)})));
|
||||
}
|
||||
|
||||
void LokiMQ::proxy_inject_task(injected_task task) {
|
||||
auto& category = task.cat;
|
||||
if (category.active_threads >= category.reserved_threads && active_workers() >= general_workers) {
|
||||
// No free worker slot, queue for later
|
||||
if (category.max_queue >= 0 && category.queued >= category.max_queue) {
|
||||
LMQ_LOG(warn, "No space to queue injected task ", task.command, "; already have ", category.queued,
|
||||
"commands queued in that category (max ", category.max_queue, "); dropping task");
|
||||
return;
|
||||
}
|
||||
LMQ_LOG(debug, "No available free workers for injected task ", task.command, "; queuing for later");
|
||||
pending_commands.emplace_back(category, std::move(task.command), std::move(task.callback), std::move(task.remote));
|
||||
category.queued++;
|
||||
return;
|
||||
}
|
||||
|
||||
auto& run = get_idle_worker();
|
||||
LMQ_TRACE("Forwarding incoming injected task ", task.command, " from ", task.remote, " to worker ", run.worker_routing_id);
|
||||
run.load(&category, std::move(task.command), std::move(task.remote), std::move(task.callback));
|
||||
|
||||
proxy_run_worker(run);
|
||||
category.active_threads++;
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
|
|
@ -1 +0,0 @@
|
|||
Subproject commit c94634bbd294204c9ba3f5b267a39582a52e8e5a
|
|
@ -3,24 +3,26 @@ add_subdirectory(Catch2)
|
|||
|
||||
set(LMQ_TEST_SRC
|
||||
main.cpp
|
||||
test_address.cpp
|
||||
test_batch.cpp
|
||||
test_bt.cpp
|
||||
test_connect.cpp
|
||||
test_commands.cpp
|
||||
test_encoding.cpp
|
||||
test_failures.cpp
|
||||
test_inject.cpp
|
||||
test_requests.cpp
|
||||
test_string_view.cpp
|
||||
test_tagged_threads.cpp
|
||||
)
|
||||
|
||||
add_executable(tests ${LMQ_TEST_SRC})
|
||||
|
||||
find_package(Threads)
|
||||
find_package(PkgConfig REQUIRED)
|
||||
pkg_check_modules(SODIUM REQUIRED libsodium)
|
||||
|
||||
target_link_libraries(tests Catch2::Catch2 lokimq ${SODIUM_LIBRARIES} Threads::Threads)
|
||||
target_link_libraries(tests Catch2::Catch2 lokimq Threads::Threads)
|
||||
|
||||
set_target_properties(tests PROPERTIES
|
||||
CXX_STANDARD 14
|
||||
CXX_STANDARD 17
|
||||
CXX_STANDARD_REQUIRED ON
|
||||
CXX_EXTENSIONS OFF
|
||||
)
|
||||
|
|
|
@ -0,0 +1,131 @@
|
|||
#include "lokimq/address.h"
|
||||
#include "common.h"
|
||||
|
||||
const std::string pk = "\xf1\x6b\xa5\x59\x10\x39\xf0\x89\xb4\x2a\x83\x41\x75\x09\x30\x94\x07\x4d\x0d\x93\x7a\x79\xe5\x3e\x5c\xe7\x30\xf9\x46\xe1\x4b\x88";
|
||||
const std::string pk_hex = "f16ba5591039f089b42a834175093094074d0d937a79e53e5ce730f946e14b88";
|
||||
const std::string pk_HEX = "F16BA5591039F089B42A834175093094074D0D937A79E53E5CE730F946E14B88";
|
||||
const std::string pk_b32z = "6fi4kseo88aeupbkopyzknjo1odw4dcuxjh6kx1hhhax1tzbjqry";
|
||||
const std::string pk_B32Z = "6FI4KSEO88AEUPBKOPYZKNJO1ODW4DCUXJH6KX1HHHAX1TZBJQRY";
|
||||
const std::string pk_b64 = "8WulWRA58Im0KoNBdQkwlAdNDZN6eeU+XOcw+UbhS4g"; // NB: padding '=' omitted
|
||||
|
||||
TEST_CASE("tcp addresses", "[address][tcp]") {
|
||||
address a{"tcp://1.2.3.4:5678"};
|
||||
REQUIRE( a.host == "1.2.3.4" );
|
||||
REQUIRE( a.port == 5678 );
|
||||
REQUIRE_FALSE( a.curve() );
|
||||
REQUIRE( a.tcp() );
|
||||
REQUIRE( a.zmq_address() == "tcp://1.2.3.4:5678" );
|
||||
REQUIRE( a.full_address() == "tcp://1.2.3.4:5678" );
|
||||
REQUIRE( a.qr_address() == "TCP://1.2.3.4:5678" );
|
||||
|
||||
REQUIRE_THROWS_AS( address{"tcp://1:1:1"}, std::invalid_argument );
|
||||
REQUIRE_THROWS_AS( address{"tcpz://localhost:123"}, std::invalid_argument );
|
||||
REQUIRE_THROWS_AS( address{"tcp://abc"}, std::invalid_argument );
|
||||
REQUIRE_THROWS_AS( address{"tcpz://localhost:0"}, std::invalid_argument );
|
||||
REQUIRE_THROWS_AS( address{"tcpz://[::1:1080"}, std::invalid_argument );
|
||||
|
||||
address b = address::tcp("example.com", 80);
|
||||
REQUIRE( b.host == "example.com" );
|
||||
REQUIRE( b.port == 80 );
|
||||
REQUIRE_FALSE( b.curve() );
|
||||
REQUIRE( b.tcp() );
|
||||
REQUIRE( b.zmq_address() == "tcp://example.com:80" );
|
||||
REQUIRE( b.full_address() == "tcp://example.com:80" );
|
||||
REQUIRE( b.qr_address() == "TCP://EXAMPLE.COM:80" );
|
||||
|
||||
address c{"tcp://[::1]:1111"};
|
||||
REQUIRE( c.host == "[::1]" );
|
||||
REQUIRE( c.port == 1111 );
|
||||
}
|
||||
|
||||
TEST_CASE("unix sockets", "[address][ipc]") {
|
||||
address a{"ipc:///path/to/foo"};
|
||||
REQUIRE( a.socket == "/path/to/foo" );
|
||||
REQUIRE_FALSE( a.curve() );
|
||||
REQUIRE_FALSE( a.tcp() );
|
||||
REQUIRE( a.zmq_address() == "ipc:///path/to/foo" );
|
||||
REQUIRE( a.full_address() == "ipc:///path/to/foo" );
|
||||
|
||||
address b = address::ipc("../foo");
|
||||
REQUIRE( b.socket == "../foo" );
|
||||
REQUIRE_FALSE( b.curve() );
|
||||
REQUIRE_FALSE( b.tcp() );
|
||||
REQUIRE( b.zmq_address() == "ipc://../foo" );
|
||||
REQUIRE( b.full_address() == "ipc://../foo" );
|
||||
}
|
||||
|
||||
TEST_CASE("pubkey formats", "[address][curve][pubkey]") {
|
||||
address a{"tcp+curve://a:1/" + pk_hex};
|
||||
address b{"curve://a:1/" + pk_b32z};
|
||||
address c{"curve://a:1/" + pk_b64};
|
||||
address d{"CURVE://A:1/" + pk_B32Z};
|
||||
REQUIRE( a.curve() );
|
||||
REQUIRE( a.host == "a" );
|
||||
REQUIRE( a.port == 1 );
|
||||
REQUIRE((b.curve() && c.curve() && d.curve()));
|
||||
REQUIRE( a.pubkey == pk );
|
||||
REQUIRE( b.pubkey == pk );
|
||||
REQUIRE( c.pubkey == pk );
|
||||
REQUIRE( d.pubkey == pk );
|
||||
|
||||
address e{"ipc+curve://my.sock/" + pk_hex};
|
||||
address f{"ipc+curve://../my.sock/" + pk_b32z};
|
||||
address g{"ipc+curve:///my.sock/" + pk_B32Z};
|
||||
address h{"ipc+curve://./my.sock/" + pk_b64};
|
||||
REQUIRE( e.curve() );
|
||||
REQUIRE( e.ipc() );
|
||||
REQUIRE_FALSE( e.tcp() );
|
||||
REQUIRE((f.curve() && g.curve() && h.curve()));
|
||||
REQUIRE( e.socket == "my.sock" );
|
||||
REQUIRE( f.socket == "../my.sock" );
|
||||
REQUIRE( g.socket == "/my.sock" );
|
||||
REQUIRE( h.socket == "./my.sock" );
|
||||
REQUIRE( e.pubkey == pk );
|
||||
REQUIRE( f.pubkey == pk );
|
||||
REQUIRE( g.pubkey == pk );
|
||||
REQUIRE( h.pubkey == pk );
|
||||
|
||||
REQUIRE( d.full_address(address::encoding::hex) == "curve://a:1/" + pk_hex );
|
||||
REQUIRE( c.full_address(address::encoding::base32z) == "curve://a:1/" + pk_b32z );
|
||||
REQUIRE( b.full_address(address::encoding::BASE32Z) == "curve://a:1/" + pk_B32Z );
|
||||
REQUIRE( a.full_address(address::encoding::base64) == "curve://a:1/" + pk_b64 );
|
||||
|
||||
REQUIRE( h.full_address(address::encoding::hex) == "ipc+curve://./my.sock/" + pk_hex );
|
||||
REQUIRE( g.full_address(address::encoding::base32z) == "ipc+curve:///my.sock/" + pk_b32z );
|
||||
REQUIRE( f.full_address(address::encoding::BASE32Z) == "ipc+curve://../my.sock/" + pk_B32Z );
|
||||
REQUIRE( e.full_address(address::encoding::base64) == "ipc+curve://my.sock/" + pk_b64 );
|
||||
|
||||
REQUIRE_THROWS_AS(address{"ipc+curve://my.sock/" + pk_hex.substr(0, 63)}, std::invalid_argument);
|
||||
REQUIRE_THROWS_AS(address{"ipc+curve://my.sock/" + pk_b32z.substr(0, 51)}, std::invalid_argument);
|
||||
REQUIRE_THROWS_AS(address{"ipc+curve://my.sock/" + pk_B32Z.substr(0, 51)}, std::invalid_argument);
|
||||
REQUIRE_THROWS_AS(address{"ipc+curve://my.sock/" + pk_b64.substr(0, 42)}, std::invalid_argument);
|
||||
REQUIRE_THROWS_AS(address{"ipc+curve://my.sock"}, std::invalid_argument);
|
||||
REQUIRE_THROWS_AS(address{"ipc+curve://my.sock/"}, std::invalid_argument);
|
||||
}
|
||||
|
||||
|
||||
TEST_CASE("tcp QR-code friendly addresses", "[address][tcp][qr]") {
|
||||
address a{"tcp://public.loki.foundation:12345"};
|
||||
address a_qr{"TCP://PUBLIC.LOKI.FOUNDATION:12345"};
|
||||
address b{"tcp://PUBLIC.LOKI.FOUNDATION:12345"};
|
||||
REQUIRE( a == a_qr );
|
||||
REQUIRE( a != b );
|
||||
REQUIRE( a.host == "public.loki.foundation" );
|
||||
REQUIRE( a.qr_address() == "TCP://PUBLIC.LOKI.FOUNDATION:12345" );
|
||||
|
||||
address c = address::tcp_curve("public.loki.foundation", 12345, pk);
|
||||
REQUIRE( c.qr_address() == "CURVE://PUBLIC.LOKI.FOUNDATION:12345/" + pk_B32Z );
|
||||
REQUIRE( address{"CURVE://PUBLIC.LOKI.FOUNDATION:12345/" + pk_B32Z} == c );
|
||||
// We don't produce with upper-case hex, but we accept it:
|
||||
REQUIRE( address{"CURVE://PUBLIC.LOKI.FOUNDATION:12345/" + pk_HEX} == c );
|
||||
|
||||
// lower case not permitted: ▾
|
||||
REQUIRE_THROWS_AS(address{"CURVE://PUBLIC.LOKI.FOUNDATiON:12345/" + pk_B32Z}, std::invalid_argument);
|
||||
// also only accept upper-base base32z and hex:
|
||||
REQUIRE_THROWS_AS(address{"CURVE://PUBLIC.LOKI.FOUNDATION:12345/" + pk_b32z}, std::invalid_argument);
|
||||
REQUIRE_THROWS_AS(address{"CURVE://PUBLIC.LOKI.FOUNDATION:12345/" + pk_hex}, std::invalid_argument);
|
||||
// don't accept base64 even if it's upper-case (because case-converting it changes the value)
|
||||
REQUIRE_THROWS_AS(address{"CURVE://PUBLIC.LOKI.FOUNDATION:12345/AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"}, std::invalid_argument);
|
||||
REQUIRE_THROWS_AS(address{"CURVE://PUBLIC.LOKI.FOUNDATION:12345/AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA="}, std::invalid_argument);
|
||||
}
|
||||
|
|
@ -0,0 +1,257 @@
|
|||
#include "lokimq/bt_serialize.h"
|
||||
#include "common.h"
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <limits>
|
||||
|
||||
TEST_CASE("bt basic value serialization", "[bt][serialization]") {
|
||||
int x = 42;
|
||||
std::string x_ = bt_serialize(x);
|
||||
REQUIRE( bt_serialize(x) == "i42e" );
|
||||
|
||||
int64_t ibig = -8'000'000'000'000'000'000LL;
|
||||
uint64_t ubig = 10'000'000'000'000'000'000ULL;
|
||||
REQUIRE( bt_serialize(ibig) == "i-8000000000000000000e" );
|
||||
REQUIRE( bt_serialize(std::numeric_limits<int64_t>::min()) == "i-9223372036854775808e" );
|
||||
REQUIRE( bt_serialize(ubig) == "i10000000000000000000e" );
|
||||
REQUIRE( bt_serialize(std::numeric_limits<uint64_t>::max()) == "i18446744073709551615e" );
|
||||
|
||||
std::unordered_map<std::string, int> m;
|
||||
m["hi"] = 123;
|
||||
m["omg"] = -7890;
|
||||
m["bye"] = 456;
|
||||
m["zap"] = 0;
|
||||
// bt values are always sorted:
|
||||
REQUIRE( bt_serialize(m) == "d3:byei456e2:hii123e3:omgi-7890e3:zapi0ee" );
|
||||
|
||||
// Dict-like list serializes as a dict (and get sorted, as above)
|
||||
std::list<std::pair<std::string, std::string>> d{{
|
||||
{"c", "x"},
|
||||
{"a", "z"},
|
||||
{"b", "y"},
|
||||
}};
|
||||
REQUIRE( bt_serialize(d) == "d1:a1:z1:b1:y1:c1:xe" );
|
||||
|
||||
std::vector<std::string> v{{"a", "", "\x00"s, "\x00\x00\x00goo"s}};
|
||||
REQUIRE( bt_serialize(v) == "l1:a0:1:\0006:\x00\x00\x00gooe"sv );
|
||||
|
||||
std::array v2 = {"a"sv, ""sv, "\x00"sv, "\x00\x00\x00goo"sv};
|
||||
REQUIRE( bt_serialize(v2) == "l1:a0:1:\0006:\x00\x00\x00gooe"sv );
|
||||
}
|
||||
|
||||
TEST_CASE("bt nested value serialization", "[bt][serialization]") {
|
||||
std::unordered_map<std::string, std::list<std::map<std::string, std::set<int>>>> x{{
|
||||
{"foo", {{{"a", {1,2,3}}, {"b", {}}}, {{"c", {4,-5}}}}},
|
||||
{"bar", {}}
|
||||
}};
|
||||
REQUIRE( bt_serialize(x) == "d3:barle3:foold1:ali1ei2ei3ee1:bleed1:cli-5ei4eeeee" );
|
||||
}
|
||||
|
||||
TEST_CASE("bt basic value deserialization", "[bt][deserialization]") {
|
||||
REQUIRE( bt_deserialize<int>("i42e") == 42 );
|
||||
|
||||
int64_t ibig = -8'000'000'000'000'000'000LL;
|
||||
uint64_t ubig = 10'000'000'000'000'000'000ULL;
|
||||
REQUIRE( bt_deserialize<int64_t>("i-8000000000000000000e") == ibig );
|
||||
REQUIRE( bt_deserialize<uint64_t>("i10000000000000000000e") == ubig );
|
||||
REQUIRE( bt_deserialize<int64_t>("i-9223372036854775808e") == std::numeric_limits<int64_t>::min() );
|
||||
REQUIRE( bt_deserialize<uint64_t>("i18446744073709551615e") == std::numeric_limits<uint64_t>::max() );
|
||||
REQUIRE( bt_deserialize<uint32_t>("i4294967295e") == std::numeric_limits<uint32_t>::max() );
|
||||
|
||||
REQUIRE_THROWS( bt_deserialize<int64_t>("i-9223372036854775809e") );
|
||||
REQUIRE_THROWS( bt_deserialize<uint64_t>("i-1e") );
|
||||
REQUIRE_THROWS( bt_deserialize<uint32_t>("i4294967296e") );
|
||||
|
||||
std::unordered_map<std::string, int> m;
|
||||
m["hi"] = 123;
|
||||
m["omg"] = -7890;
|
||||
m["bye"] = 456;
|
||||
m["zap"] = 0;
|
||||
// bt values are always sorted:
|
||||
REQUIRE( bt_deserialize<std::unordered_map<std::string, int>>("d3:byei456e2:hii123e3:omgi-7890e3:zapi0ee") == m );
|
||||
|
||||
// Dict-like list can be used for deserialization
|
||||
std::list<std::pair<std::string, std::string>> d{{
|
||||
{"a", "z"},
|
||||
{"b", "y"},
|
||||
{"c", "x"},
|
||||
}};
|
||||
REQUIRE( bt_deserialize<std::list<std::pair<std::string, std::string>>>("d1:a1:z1:b1:y1:c1:xe") == d );
|
||||
|
||||
std::vector<std::string> v{{"a", "", "\x00"s, "\x00\x00\x00goo"s}};
|
||||
REQUIRE( bt_deserialize<std::vector<std::string>>("l1:a0:1:\0006:\x00\x00\x00gooe"sv) == v );
|
||||
|
||||
std::vector v2 = {"a"sv, ""sv, "\x00"sv, "\x00\x00\x00goo"sv};
|
||||
REQUIRE( bt_deserialize<decltype(v2)>("l1:a0:1:\0006:\x00\x00\x00gooe"sv) == v2 );
|
||||
}
|
||||
|
||||
TEST_CASE("bt_value serialization", "[bt][serialization][bt_value]") {
|
||||
bt_value dna{42};
|
||||
std::string x_ = bt_serialize(dna);
|
||||
REQUIRE( bt_serialize(dna) == "i42e" );
|
||||
|
||||
bt_value foo{"foo"};
|
||||
REQUIRE( bt_serialize(foo) == "3:foo" );
|
||||
|
||||
bt_value ibig{-8'000'000'000'000'000'000LL};
|
||||
bt_value ubig{10'000'000'000'000'000'000ULL};
|
||||
int16_t ismall = -123;
|
||||
uint16_t usmall = 123;
|
||||
bt_dict nums{
|
||||
{"a", 0},
|
||||
{"b", -8'000'000'000'000'000'000LL},
|
||||
{"c", 10'000'000'000'000'000'000ULL},
|
||||
{"d", ismall},
|
||||
{"e", usmall},
|
||||
};
|
||||
|
||||
REQUIRE( bt_serialize(ibig) == "i-8000000000000000000e" );
|
||||
REQUIRE( bt_serialize(ubig) == "i10000000000000000000e" );
|
||||
REQUIRE( bt_serialize(nums) == "d1:ai0e1:bi-8000000000000000000e1:ci10000000000000000000e1:di-123e1:ei123ee" );
|
||||
|
||||
// Same as nested test, above, but with bt_* types
|
||||
bt_dict x{{
|
||||
{"foo", bt_list{{bt_dict{{ {"a", bt_list{{1,2,3}}}, {"b", bt_list{}}}}, bt_dict{{{"c", bt_list{{-5, 4}}}}}}}},
|
||||
{"bar", bt_list{}}
|
||||
}};
|
||||
REQUIRE( bt_serialize(x) == "d3:barle3:foold1:ali1ei2ei3ee1:bleed1:cli-5ei4eeeee" );
|
||||
std::vector<std::string> v{{"a", "", "\x00"s, "\x00\x00\x00goo"s}};
|
||||
REQUIRE( bt_serialize(v) == "l1:a0:1:\0006:\x00\x00\x00gooe"sv );
|
||||
|
||||
std::array v2 = {"a"sv, ""sv, "\x00"sv, "\x00\x00\x00goo"sv};
|
||||
REQUIRE( bt_serialize(v2) == "l1:a0:1:\0006:\x00\x00\x00gooe"sv );
|
||||
}
|
||||
|
||||
TEST_CASE("bt_value deserialization", "[bt][deserialization][bt_value]") {
|
||||
auto dna1 = bt_deserialize<bt_value>("i42e");
|
||||
auto dna2 = bt_deserialize<bt_value>("i-42e");
|
||||
REQUIRE( std::get<uint64_t>(dna1) == 42 );
|
||||
REQUIRE( std::get<int64_t>(dna2) == -42 );
|
||||
REQUIRE_THROWS( std::get<int64_t>(dna1) );
|
||||
REQUIRE_THROWS( std::get<uint64_t>(dna2) );
|
||||
REQUIRE( lokimq::get_int<int>(dna1) == 42 );
|
||||
REQUIRE( lokimq::get_int<int>(dna2) == -42 );
|
||||
REQUIRE( lokimq::get_int<unsigned>(dna1) == 42 );
|
||||
REQUIRE_THROWS( lokimq::get_int<unsigned>(dna2) );
|
||||
|
||||
bt_value x = bt_deserialize<bt_value>("d3:barle3:foold1:ali1ei2ei3ee1:bleed1:cli-5ei4eeeee");
|
||||
REQUIRE( std::holds_alternative<bt_dict>(x) );
|
||||
bt_dict& a = std::get<bt_dict>(x);
|
||||
REQUIRE( a.count("bar") );
|
||||
REQUIRE( a.count("foo") );
|
||||
REQUIRE( a.size() == 2 );
|
||||
bt_list& foo = std::get<bt_list>(a["foo"]);
|
||||
REQUIRE( foo.size() == 2 );
|
||||
bt_dict& foo1 = std::get<bt_dict>(foo.front());
|
||||
bt_dict& foo2 = std::get<bt_dict>(foo.back());
|
||||
REQUIRE( foo1.size() == 2 );
|
||||
REQUIRE( foo2.size() == 1 );
|
||||
bt_list& foo1a = std::get<bt_list>(foo1.at("a"));
|
||||
bt_list& foo1b = std::get<bt_list>(foo1.at("b"));
|
||||
bt_list& foo2c = std::get<bt_list>(foo2.at("c"));
|
||||
std::list<int> foo1a_vals, foo1b_vals, foo2c_vals;
|
||||
for (auto& v : foo1a) foo1a_vals.push_back(lokimq::get_int<int>(v));
|
||||
for (auto& v : foo1b) foo1b_vals.push_back(lokimq::get_int<int>(v));
|
||||
for (auto& v : foo2c) foo2c_vals.push_back(lokimq::get_int<int>(v));
|
||||
REQUIRE( foo1a_vals == std::list{{1,2,3}} );
|
||||
REQUIRE( foo1b_vals == std::list<int>{} );
|
||||
REQUIRE( foo2c_vals == std::list{{-5, 4}} );
|
||||
|
||||
REQUIRE( std::get<bt_list>(a.at("bar")).empty() );
|
||||
}
|
||||
|
||||
TEST_CASE("bt tuple serialization", "[bt][tuple][serialization]") {
|
||||
// Deserializing directly into a tuple:
|
||||
std::tuple<int, std::string, std::vector<int>> x{42, "hi", {{1,2,3,4,5}}};
|
||||
REQUIRE( bt_serialize(x) == "li42e2:hili1ei2ei3ei4ei5eee" );
|
||||
|
||||
using Y = std::tuple<std::string, std::string, std::unordered_map<std::string, int>>;
|
||||
REQUIRE( bt_deserialize<Y>("l5:hello3:omgd1:ai1e1:bi2eee")
|
||||
== Y{"hello", "omg", {{"a",1}, {"b",2}}} );
|
||||
|
||||
using Z = std::tuple<std::tuple<int, std::string, std::string>, std::pair<int, int>>;
|
||||
Z z{{3, "abc", "def"}, {4, 5}};
|
||||
REQUIRE( bt_serialize(z) == "lli3e3:abc3:defeli4ei5eee" );
|
||||
REQUIRE( bt_deserialize<Z>("lli6e3:ghi3:jkleli7ei8eee") == Z{{6, "ghi", "jkl"}, {7, 8}} );
|
||||
|
||||
using W = std::pair<std::string, std::pair<int, unsigned>>;
|
||||
REQUIRE( bt_serialize(W{"zzzzzzzzzz", {42, 42}}) == "l10:zzzzzzzzzzli42ei42eee" );
|
||||
|
||||
REQUIRE_THROWS( bt_deserialize<std::tuple<int>>("li1e") ); // missing closing e
|
||||
REQUIRE_THROWS( bt_deserialize<std::pair<int, int>>("li1ei-4e") ); // missing closing e
|
||||
REQUIRE_THROWS( bt_deserialize<std::tuple<int>>("li1ei2ee") ); // too many elements
|
||||
REQUIRE_THROWS( bt_deserialize<std::pair<int, int>>("li1ei-2e0:e") ); // too many elements
|
||||
REQUIRE_THROWS( bt_deserialize<std::tuple<int, int>>("li1ee") ); // too few elements
|
||||
REQUIRE_THROWS( bt_deserialize<std::pair<int, int>>("li1ee") ); // too few elements
|
||||
REQUIRE_THROWS( bt_deserialize<std::tuple<std::string>>("li1ee") ); // wrong element type
|
||||
REQUIRE_THROWS( bt_deserialize<std::pair<int, std::string>>("li1ei8ee") ); // wrong element type
|
||||
REQUIRE_THROWS( bt_deserialize<std::pair<int, std::string>>("l1:x1:xe") ); // wrong element type
|
||||
|
||||
// Converting from a generic bt_value/bt_list:
|
||||
bt_value a = bt_get("l5:hello3:omgi12345ee");
|
||||
using V1 = std::tuple<std::string, std::string_view, uint16_t>;
|
||||
REQUIRE( get_tuple<V1>(a) == V1{"hello", "omg"sv, 12345} );
|
||||
|
||||
bt_value b = bt_get("l5:hellod1:ai1e1:bi2eee");
|
||||
using V2 = std::pair<std::string_view, bt_dict>;
|
||||
REQUIRE( get_tuple<V2>(b) == V2{"hello", {{"a",1U}, {"b",2U}}} );
|
||||
|
||||
bt_value c = bt_get("l5:helloi-4ed1:ai-1e1:bi-2eee");
|
||||
using V3 = std::tuple<std::string, int, bt_dict>;
|
||||
REQUIRE( get_tuple<V3>(c) == V3{"hello", -4, {{"a",-1}, {"b",-2}}} );
|
||||
|
||||
REQUIRE_THROWS( get_tuple<V1>(bt_get("l5:hello3:omge")) ); // too few
|
||||
REQUIRE_THROWS( get_tuple<V1>(bt_get("l5:hello3:omgi1ei1ee")) ); // too many
|
||||
REQUIRE_THROWS( get_tuple<V1>(bt_get("l5:helloi1ei1ee")) ); // wrong type
|
||||
|
||||
// Construct a bt_value from tuples:
|
||||
bt_value l{std::make_tuple(3, 4, "hi"sv)};
|
||||
REQUIRE( bt_serialize(l) == "li3ei4e2:hie" );
|
||||
bt_list m{{1, 2, std::make_tuple(3, 4, "hi"sv), std::make_pair("foo"s, "bar"sv), -4}};
|
||||
REQUIRE( bt_serialize(m) == "li1ei2eli3ei4e2:hiel3:foo3:barei-4ee" );
|
||||
|
||||
// Consumer deserialization:
|
||||
bt_list_consumer lc{"li1ei2eli3ei4e2:hiel3:foo3:barei-4ee"};
|
||||
REQUIRE( lc.consume_integer<int>() == 1 );
|
||||
REQUIRE( lc.consume_integer<int>() == 2 );
|
||||
REQUIRE( lc.consume_list<std::tuple<int, int, std::string>>() == std::make_tuple(3, 4, "hi"s) );
|
||||
REQUIRE( lc.consume_list<std::pair<std::string_view, std::string_view>>() == std::make_pair("foo"sv, "bar"sv) );
|
||||
REQUIRE( lc.consume_integer<int>() == -4 );
|
||||
|
||||
bt_dict_consumer dc{"d1:Ai0e1:ali1e3:omge1:bli1ei2ei3eee"};
|
||||
REQUIRE( dc.key() == "A" );
|
||||
REQUIRE( dc.skip_until("a") );
|
||||
REQUIRE( dc.next_list<std::pair<int8_t, std::string_view>>() ==
|
||||
std::make_pair("a"sv, std::make_pair(int8_t{1}, "omg"sv)) );
|
||||
REQUIRE( dc.next_list<std::tuple<int, int, int>>() ==
|
||||
std::make_pair("b"sv, std::make_tuple(1, 2, 3)) );
|
||||
}
|
||||
|
||||
#if 0
|
||||
{
|
||||
std::cout << "zomg consumption\n";
|
||||
bt_dict_consumer dc{zomg_};
|
||||
for (int i = 0; i < 5; i++)
|
||||
if (!dc.skip_until("b"))
|
||||
throw std::runtime_error("Couldn't find b, but I know it's there!");
|
||||
|
||||
auto dc1 = dc;
|
||||
if (dc.skip_until("z")) {
|
||||
auto v = dc.consume_integer<int>();
|
||||
std::cout << " - " << v.first << ": " << v.second << "\n";
|
||||
} else {
|
||||
std::cout << " - no z (bad!)\n";
|
||||
}
|
||||
|
||||
std::cout << "zomg (second pass)\n";
|
||||
for (auto &p : dc1.consume_dict().second) {
|
||||
std::cout << " - " << p.first << " = (whatever)\n";
|
||||
}
|
||||
while (dc1) {
|
||||
auto v = dc1.consume_integer<int>();
|
||||
std::cout << " - " << v.first << ": " << v.second << "\n";
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
|
|
@ -41,10 +41,9 @@ TEST_CASE("basic commands", "[commands]") {
|
|||
bool success = false, failed = false;
|
||||
std::string pubkey;
|
||||
|
||||
auto c = client.connect_remote(listen,
|
||||
auto c = client.connect_remote(address{listen, server.get_pubkey()},
|
||||
[&](auto conn) { pubkey = conn.pubkey(); success = true; got = true; },
|
||||
[&](auto conn, string_view) { failed = true; got = true; },
|
||||
server.get_pubkey());
|
||||
[&](auto conn, std::string_view) { failed = true; got = true; });
|
||||
|
||||
wait_for_conn(got);
|
||||
{
|
||||
|
@ -107,9 +106,10 @@ TEST_CASE("outgoing auth level", "[commands][auth]") {
|
|||
|
||||
client.PUBKEY_BASED_ROUTING_ID = false; // establishing multiple connections below, so we need unique routing ids
|
||||
|
||||
auto public_c = client.connect_remote(listen, [](...) {}, [](...) {}, server.get_pubkey());
|
||||
auto basic_c = client.connect_remote(listen, [](...) {}, [](...) {}, server.get_pubkey(), AuthLevel::basic);
|
||||
auto admin_c = client.connect_remote(listen, [](...) {}, [](...) {}, server.get_pubkey(), AuthLevel::admin);
|
||||
address server_addr{listen, server.get_pubkey()};
|
||||
auto public_c = client.connect_remote(server_addr, [](auto&&...) {}, [](auto&&...) {});
|
||||
auto basic_c = client.connect_remote(server_addr, [](auto&&...) {}, [](auto&&...) {}, AuthLevel::basic);
|
||||
auto admin_c = client.connect_remote(server_addr, [](auto&&...) {}, [](auto&&...) {}, AuthLevel::admin);
|
||||
|
||||
client.send(public_c, "public.reflect", "public.hi");
|
||||
wait_for([&] { return public_hi == 1; });
|
||||
|
@ -193,8 +193,8 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]")
|
|||
|
||||
server.start();
|
||||
|
||||
auto connect_success = [&](...) { auto l = catch_lock(); REQUIRE(true); };
|
||||
auto connect_failure = [&](...) { auto l = catch_lock(); REQUIRE(false); };
|
||||
auto connect_success = [&](auto&&...) { auto l = catch_lock(); REQUIRE(true); };
|
||||
auto connect_failure = [&](auto&&...) { auto l = catch_lock(); REQUIRE(false); };
|
||||
|
||||
|
||||
std::set<std::string> backdoor_details;
|
||||
|
@ -202,10 +202,11 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]")
|
|||
LokiMQ nsa{get_logger("NSA» ")};
|
||||
nsa.add_category("backdoor", Access{AuthLevel::admin});
|
||||
nsa.add_command("backdoor", "data", [&](Message& m) {
|
||||
backdoor_details.emplace(m.data[0]);
|
||||
auto l = catch_lock();
|
||||
backdoor_details.emplace(m.data[0]);
|
||||
});
|
||||
nsa.start();
|
||||
auto nsa_c = nsa.connect_remote(listen, connect_success, connect_failure, server.get_pubkey(), AuthLevel::admin);
|
||||
auto nsa_c = nsa.connect_remote(address{listen, server.get_pubkey()}, connect_success, connect_failure, AuthLevel::admin);
|
||||
nsa.send(nsa_c, "hey google.install backdoor");
|
||||
|
||||
wait_for([&] { auto lock = catch_lock(); return (bool) backdoor; });
|
||||
|
@ -226,6 +227,7 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]")
|
|||
std::set<std::string> all_the_things;
|
||||
for (auto& pd : personal_details) all_the_things.insert(pd.second.begin(), pd.second.end());
|
||||
|
||||
address server_addr{listen, server.get_pubkey()};
|
||||
std::map<int, std::set<std::string>> google_knows;
|
||||
int things_remembered{0};
|
||||
for (int i = 0; i < 5; i++) {
|
||||
|
@ -240,7 +242,7 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]")
|
|||
});
|
||||
c->start();
|
||||
conns.push_back(
|
||||
c->connect_remote(listen, connect_success, connect_failure, server.get_pubkey(), AuthLevel::basic));
|
||||
c->connect_remote(server_addr, connect_success, connect_failure, AuthLevel::basic));
|
||||
for (auto& personal_detail : personal_details[i])
|
||||
c->request(conns.back(), "hey google.remember",
|
||||
[&](bool success, std::vector<std::string> data) {
|
||||
|
@ -252,7 +254,7 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]")
|
|||
},
|
||||
personal_detail);
|
||||
}
|
||||
wait_for([&] { auto lock = catch_lock(); return things_remembered == all_the_things.size(); });
|
||||
wait_for([&] { auto lock = catch_lock(); return things_remembered == all_the_things.size() && things_remembered == backdoor_details.size(); });
|
||||
{
|
||||
auto l = catch_lock();
|
||||
REQUIRE( things_remembered == all_the_things.size() );
|
||||
|
@ -303,10 +305,11 @@ TEST_CASE("send failure callbacks", "[commands][queue_full]") {
|
|||
// Handshake: we send HI, they reply HELLO.
|
||||
client.send(zmq::message_t{"HI", 2}, zmq::send_flags::none);
|
||||
zmq::message_t hello;
|
||||
client.recv(hello);
|
||||
string_view hello_sv{hello.data<char>(), hello.size()};
|
||||
auto recvd = client.recv(hello);
|
||||
std::string_view hello_sv{hello.data<char>(), hello.size()};
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( recvd );
|
||||
REQUIRE( hello_sv == "HELLO" );
|
||||
REQUIRE_FALSE( hello.more() );
|
||||
}
|
||||
|
@ -359,3 +362,78 @@ TEST_CASE("send failure callbacks", "[commands][queue_full]") {
|
|||
REQUIRE( send_failures.load() > 0 );
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("data parts", "[send][data_parts]") {
|
||||
std::string listen = "tcp://127.0.0.1:4567";
|
||||
LokiMQ server{
|
||||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
get_logger("S» "),
|
||||
LogLevel::trace
|
||||
};
|
||||
server.listen_curve(listen);
|
||||
|
||||
std::mutex mut;
|
||||
std::vector<std::string> r;
|
||||
|
||||
server.add_category("public", Access{AuthLevel::none});
|
||||
server.add_command("public", "hello", [&](Message& m) {
|
||||
std::lock_guard l{mut};
|
||||
for (const auto& s : m.data)
|
||||
r.emplace_back(s);
|
||||
});
|
||||
server.start();
|
||||
|
||||
LokiMQ client{get_logger("C» "), LogLevel::trace};
|
||||
client.start();
|
||||
|
||||
std::atomic<bool> got{false};
|
||||
bool success = false, failed = false;
|
||||
std::string pubkey;
|
||||
|
||||
auto c = client.connect_remote(address{listen, server.get_pubkey()},
|
||||
[&](auto conn) { pubkey = conn.pubkey(); success = true; got = true; },
|
||||
[&](auto conn, std::string_view) { failed = true; got = true; });
|
||||
|
||||
wait_for_conn(got);
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( got );
|
||||
REQUIRE( success );
|
||||
REQUIRE_FALSE( failed );
|
||||
REQUIRE( to_hex(pubkey) == to_hex(server.get_pubkey()) );
|
||||
}
|
||||
|
||||
std::vector some_data{{"abc"s, "def"s, "omg123\0zzz"s}};
|
||||
client.send(c, "public.hello", lokimq::send_option::data_parts(some_data.begin(), some_data.end()));
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
std::lock_guard l{mut};
|
||||
REQUIRE( r == some_data );
|
||||
r.clear();
|
||||
}
|
||||
|
||||
std::vector some_data2{{"a"sv, "b"sv, "\0"sv}};
|
||||
client.send(c, "public.hello",
|
||||
"hi",
|
||||
lokimq::send_option::data_parts(some_data2.begin(), some_data2.end()),
|
||||
"another",
|
||||
"string"sv,
|
||||
lokimq::send_option::data_parts(some_data.begin(), some_data.end()));
|
||||
|
||||
std::vector<std::string> expected;
|
||||
expected.push_back("hi");
|
||||
expected.insert(expected.end(), some_data2.begin(), some_data2.end());
|
||||
expected.push_back("another");
|
||||
expected.push_back("string");
|
||||
expected.insert(expected.end(), some_data.begin(), some_data.end());
|
||||
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
std::lock_guard l{mut};
|
||||
REQUIRE( r == expected );
|
||||
}
|
||||
}
|
||||
|
|
|
@ -27,10 +27,9 @@ TEST_CASE("connections with curve authentication", "[curve][connect]") {
|
|||
auto pubkey = server.get_pubkey();
|
||||
std::atomic<bool> got{false};
|
||||
bool success = false;
|
||||
auto server_conn = client.connect_remote(listen,
|
||||
auto server_conn = client.connect_remote(address{listen, pubkey},
|
||||
[&](auto conn) { success = true; got = true; },
|
||||
[&](auto conn, string_view reason) { auto lock = catch_lock(); INFO("connection failed: " << reason); got = true; },
|
||||
pubkey);
|
||||
[&](auto conn, std::string_view reason) { auto lock = catch_lock(); INFO("connection failed: " << reason); got = true; });
|
||||
|
||||
wait_for_conn(got);
|
||||
{
|
||||
|
@ -53,6 +52,7 @@ TEST_CASE("self-connection SN optimization", "[connect][self]") {
|
|||
std::string pubkey, privkey;
|
||||
pubkey.resize(crypto_box_PUBLICKEYBYTES);
|
||||
privkey.resize(crypto_box_SECRETKEYBYTES);
|
||||
REQUIRE(sodium_init() != -1);
|
||||
crypto_box_keypair(reinterpret_cast<unsigned char*>(&pubkey[0]), reinterpret_cast<unsigned char*>(&privkey[0]));
|
||||
LokiMQ sn{
|
||||
pubkey, privkey,
|
||||
|
@ -108,7 +108,7 @@ TEST_CASE("plain-text connections", "[plaintext][connect]") {
|
|||
bool success = false;
|
||||
auto c = client.connect_remote(listen,
|
||||
[&](auto conn) { success = true; got = true; },
|
||||
[&](auto conn, string_view reason) { auto lock = catch_lock(); INFO("connection failed: " << reason); got = true; }
|
||||
[&](auto conn, std::string_view reason) { auto lock = catch_lock(); INFO("connection failed: " << reason); got = true; }
|
||||
);
|
||||
|
||||
wait_for_conn(got);
|
||||
|
@ -150,11 +150,11 @@ TEST_CASE("unique connection IDs", "[connect][id]") {
|
|||
std::atomic<bool> good1{false}, good2{false};
|
||||
auto r1 = client1.connect_remote(listen,
|
||||
[&](auto conn) { good1 = true; },
|
||||
[&](auto conn, string_view reason) { auto lock = catch_lock(); INFO("connection failed: " << reason); }
|
||||
[&](auto conn, std::string_view reason) { auto lock = catch_lock(); INFO("connection failed: " << reason); }
|
||||
);
|
||||
auto r2 = client2.connect_remote(listen,
|
||||
[&](auto conn) { good2 = true; },
|
||||
[&](auto conn, string_view reason) { auto lock = catch_lock(); INFO("connection failed: " << reason); }
|
||||
[&](auto conn, std::string_view reason) { auto lock = catch_lock(); INFO("connection failed: " << reason); }
|
||||
);
|
||||
|
||||
wait_for_conn(good1);
|
||||
|
@ -188,6 +188,7 @@ TEST_CASE("SN disconnections", "[connect][disconnect]") {
|
|||
std::vector<std::unique_ptr<LokiMQ>> lmq;
|
||||
std::vector<std::string> pubkey, privkey;
|
||||
std::unordered_map<std::string, std::string> conn;
|
||||
REQUIRE(sodium_init() != -1);
|
||||
for (int i = 0; i < 3; i++) {
|
||||
pubkey.emplace_back();
|
||||
privkey.emplace_back();
|
||||
|
@ -234,6 +235,7 @@ TEST_CASE("SN auth checks", "[sandwich][auth]") {
|
|||
std::string pubkey, privkey;
|
||||
pubkey.resize(crypto_box_PUBLICKEYBYTES);
|
||||
privkey.resize(crypto_box_SECRETKEYBYTES);
|
||||
REQUIRE(sodium_init() != -1);
|
||||
crypto_box_keypair(reinterpret_cast<unsigned char*>(&pubkey[0]), reinterpret_cast<unsigned char*>(&privkey[0]));
|
||||
LokiMQ server{
|
||||
pubkey, privkey,
|
||||
|
|
|
@ -0,0 +1,193 @@
|
|||
#include "lokimq/hex.h"
|
||||
#include "lokimq/base32z.h"
|
||||
#include "lokimq/base64.h"
|
||||
#include "common.h"
|
||||
|
||||
using namespace std::literals;
|
||||
|
||||
const std::string pk = "\xf1\x6b\xa5\x59\x10\x39\xf0\x89\xb4\x2a\x83\x41\x75\x09\x30\x94\x07\x4d\x0d\x93\x7a\x79\xe5\x3e\x5c\xe7\x30\xf9\x46\xe1\x4b\x88";
|
||||
const std::string pk_hex = "f16ba5591039f089b42a834175093094074d0d937a79e53e5ce730f946e14b88";
|
||||
const std::string pk_b32z = "6fi4kseo88aeupbkopyzknjo1odw4dcuxjh6kx1hhhax1tzbjqry";
|
||||
const std::string pk_b64 = "8WulWRA58Im0KoNBdQkwlAdNDZN6eeU+XOcw+UbhS4g=";
|
||||
|
||||
TEST_CASE("hex encoding/decoding", "[encoding][decoding][hex]") {
|
||||
REQUIRE( lokimq::to_hex("\xff\x42\x12\x34") == "ff421234"s );
|
||||
std::vector<uint8_t> chars{{1, 10, 100, 254}};
|
||||
std::array<uint8_t, 8> out;
|
||||
std::array<uint8_t, 8> expected{{'0', '1', '0', 'a', '6', '4', 'f', 'e'}};
|
||||
lokimq::to_hex(chars.begin(), chars.end(), out.begin());
|
||||
REQUIRE( out == expected );
|
||||
|
||||
REQUIRE( lokimq::to_hex(chars.begin(), chars.end()) == "010a64fe" );
|
||||
|
||||
REQUIRE( lokimq::from_hex("12345678ffEDbca9") == "\x12\x34\x56\x78\xff\xed\xbc\xa9"s );
|
||||
|
||||
REQUIRE( lokimq::is_hex("1234567890abcdefABCDEF1234567890abcdefABCDEF") );
|
||||
REQUIRE_FALSE( lokimq::is_hex("1234567890abcdefABCDEF1234567890aGcdefABCDEF") );
|
||||
REQUIRE_FALSE( lokimq::is_hex("1234567890abcdefABCDEF1234567890agcdefABCDEF") );
|
||||
REQUIRE_FALSE( lokimq::is_hex("\x11\xff") );
|
||||
|
||||
REQUIRE( lokimq::from_hex(pk_hex) == pk );
|
||||
REQUIRE( lokimq::to_hex(pk) == pk_hex );
|
||||
|
||||
REQUIRE( lokimq::from_hex(pk_hex.begin(), pk_hex.end()) == pk );
|
||||
|
||||
std::vector<std::byte> bytes{{std::byte{0xff}, std::byte{0x42}, std::byte{0x12}, std::byte{0x34}}};
|
||||
std::basic_string_view<std::byte> b{bytes.data(), bytes.size()};
|
||||
REQUIRE( lokimq::to_hex(b) == "ff421234"s );
|
||||
|
||||
bytes.resize(8);
|
||||
bytes[0] = std::byte{'f'}; bytes[1] = std::byte{'f'}; bytes[2] = std::byte{'4'}; bytes[3] = std::byte{'2'};
|
||||
bytes[4] = std::byte{'1'}; bytes[5] = std::byte{'2'}; bytes[6] = std::byte{'3'}; bytes[7] = std::byte{'4'};
|
||||
std::basic_string_view<std::byte> hex_bytes{bytes.data(), bytes.size()};
|
||||
REQUIRE( lokimq::is_hex(hex_bytes) );
|
||||
REQUIRE( lokimq::from_hex(hex_bytes) == "\xff\x42\x12\x34" );
|
||||
}
|
||||
|
||||
TEST_CASE("base32z encoding/decoding", "[encoding][decoding][base32z]") {
|
||||
REQUIRE( lokimq::to_base32z("\0\0\0\0\0"s) == "yyyyyyyy" );
|
||||
REQUIRE( lokimq::to_base32z("\x01\x23\x45\x67\x89\xab\xcd\xef\x01\x23\x45\x67\x89\xab\xcd\xef\x01\x23\x45\x67\x89\xab\xcd\xef\x01\x23\x45\x67\x89\xab\xcd\xef"sv)
|
||||
== "yrtwk3hjixg66yjdeiuauk6p7hy1gtm8tgih55abrpnsxnpm3zzo");
|
||||
|
||||
REQUIRE( lokimq::from_base32z("yrtwk3hjixg66yjdeiuauk6p7hy1gtm8tgih55abrpnsxnpm3zzo")
|
||||
== "\x01\x23\x45\x67\x89\xab\xcd\xef\x01\x23\x45\x67\x89\xab\xcd\xef\x01\x23\x45\x67\x89\xab\xcd\xef\x01\x23\x45\x67\x89\xab\xcd\xef"sv);
|
||||
|
||||
REQUIRE( lokimq::from_base32z("YRTWK3HJIXG66YJDEIUAUK6P7HY1GTM8TGIH55ABRPNSXNPM3ZZO")
|
||||
== "\x01\x23\x45\x67\x89\xab\xcd\xef\x01\x23\x45\x67\x89\xab\xcd\xef\x01\x23\x45\x67\x89\xab\xcd\xef\x01\x23\x45\x67\x89\xab\xcd\xef"sv);
|
||||
|
||||
auto five_nulls = lokimq::from_base32z("yyyyyyyy");
|
||||
REQUIRE( five_nulls.size() == 5 );
|
||||
REQUIRE( five_nulls == "\0\0\0\0\0"s );
|
||||
|
||||
// 00000 00001 00010 00011 00100 00101 00110 00111
|
||||
// ==
|
||||
// 00000000 01000100 00110010 00010100 11000111
|
||||
REQUIRE( lokimq::from_base32z("ybndrfg8") == "\x00\x44\x32\x14\xc7"s );
|
||||
|
||||
// Special case 1: 7 base32z digits with 3 trailing 0 bits -> 4 bytes (the trailing 0s are dropped)
|
||||
// 00000 00001 00010 00011 00100 00101 11000
|
||||
// ==
|
||||
// 00000000 01000100 00110010 00010111
|
||||
REQUIRE( lokimq::from_base32z("ybndrfa") == "\x00\x44\x32\x17"s );
|
||||
|
||||
// Round-trip it:
|
||||
REQUIRE( lokimq::from_base32z(lokimq::to_base32z("\x00\x44\x32\x17"sv)) == "\x00\x44\x32\x17"sv );
|
||||
REQUIRE( lokimq::to_base32z(lokimq::from_base32z("ybndrfa")) == "ybndrfa" );
|
||||
|
||||
// Special case 2: 7 base32z digits with 3 trailing bits 010; we just ignore the trailing stuff,
|
||||
// as if it was specified as 0. (The last digit here is 11010 instead of 11000).
|
||||
REQUIRE( lokimq::from_base32z("ybndrf4") == "\x00\x44\x32\x17"s );
|
||||
// This one won't round-trip to the same value since it has ignored garbage bytes at the end
|
||||
REQUIRE( lokimq::to_base32z(lokimq::from_base32z("ybndrf4"s)) == "ybndrfa" );
|
||||
|
||||
REQUIRE( lokimq::to_base32z(pk) == pk_b32z );
|
||||
REQUIRE( lokimq::to_base32z(pk.begin(), pk.end()) == pk_b32z );
|
||||
REQUIRE( lokimq::from_base32z(pk_b32z) == pk );
|
||||
REQUIRE( lokimq::from_base32z(pk_b32z.begin(), pk_b32z.end()) == pk );
|
||||
|
||||
std::string pk_b32z_again, pk_again;
|
||||
lokimq::to_base32z(pk.begin(), pk.end(), std::back_inserter(pk_b32z_again));
|
||||
lokimq::from_base32z(pk_b32z.begin(), pk_b32z.end(), std::back_inserter(pk_again));
|
||||
REQUIRE( pk_b32z_again == pk_b32z );
|
||||
REQUIRE( pk_again == pk );
|
||||
|
||||
std::vector<std::byte> bytes{{std::byte{0}, std::byte{255}}};
|
||||
std::basic_string_view<std::byte> b{bytes.data(), bytes.size()};
|
||||
REQUIRE( lokimq::to_base32z(b) == "yd9o" );
|
||||
|
||||
bytes.resize(4);
|
||||
bytes[0] = std::byte{'y'}; bytes[1] = std::byte{'d'}; bytes[2] = std::byte{'9'}; bytes[3] = std::byte{'o'};
|
||||
std::basic_string_view<std::byte> b32_bytes{bytes.data(), bytes.size()};
|
||||
REQUIRE( lokimq::is_base32z(b32_bytes) );
|
||||
REQUIRE( lokimq::from_base32z(b32_bytes) == "\x00\xff"sv );
|
||||
}
|
||||
|
||||
TEST_CASE("base64 encoding/decoding", "[encoding][decoding][base64]") {
|
||||
// 00000000 00000000 00000000 -> 000000 000000 000000 000000
|
||||
REQUIRE( lokimq::to_base64("\0\0\0"s) == "AAAA" );
|
||||
// 00000001 00000002 00000003 -> 000000 010000 000200 000003
|
||||
REQUIRE( lokimq::to_base64("\x01\x02\x03"s) == "AQID" );
|
||||
REQUIRE( lokimq::to_base64("\0\0\0\0"s) == "AAAAAA==" );
|
||||
// 00000000 00000000 00000000 11111111 ->
|
||||
// 000000 000000 000000 000000 111111 110000 (pad) (pad)
|
||||
REQUIRE( lokimq::to_base64("a") == "YQ==" );
|
||||
REQUIRE( lokimq::to_base64("ab") == "YWI=" );
|
||||
REQUIRE( lokimq::to_base64("abc") == "YWJj" );
|
||||
REQUIRE( lokimq::to_base64("abcd") == "YWJjZA==" );
|
||||
REQUIRE( lokimq::to_base64("abcde") == "YWJjZGU=" );
|
||||
REQUIRE( lokimq::to_base64("abcdef") == "YWJjZGVm" );
|
||||
|
||||
REQUIRE( lokimq::to_base64("\0\0\0\xff"s) == "AAAA/w==" );
|
||||
REQUIRE( lokimq::to_base64("\0\0\0\xff\xff"s) == "AAAA//8=" );
|
||||
REQUIRE( lokimq::to_base64("\0\0\0\xff\xff\xff"s) == "AAAA////" );
|
||||
REQUIRE( lokimq::to_base64(
|
||||
"Man is distinguished, not only by his reason, but by this singular passion from other "
|
||||
"animals, which is a lust of the mind, that by a perseverance of delight in the "
|
||||
"continued and indefatigable generation of knowledge, exceeds the short vehemence of "
|
||||
"any carnal pleasure.")
|
||||
==
|
||||
"TWFuIGlzIGRpc3Rpbmd1aXNoZWQsIG5vdCBvbmx5IGJ5IGhpcyByZWFzb24sIGJ1dCBieSB0aGlz"
|
||||
"IHNpbmd1bGFyIHBhc3Npb24gZnJvbSBvdGhlciBhbmltYWxzLCB3aGljaCBpcyBhIGx1c3Qgb2Yg"
|
||||
"dGhlIG1pbmQsIHRoYXQgYnkgYSBwZXJzZXZlcmFuY2Ugb2YgZGVsaWdodCBpbiB0aGUgY29udGlu"
|
||||
"dWVkIGFuZCBpbmRlZmF0aWdhYmxlIGdlbmVyYXRpb24gb2Yga25vd2xlZGdlLCBleGNlZWRzIHRo"
|
||||
"ZSBzaG9ydCB2ZWhlbWVuY2Ugb2YgYW55IGNhcm5hbCBwbGVhc3VyZS4=" );
|
||||
|
||||
REQUIRE( lokimq::from_base64("A+/A") == "\x03\xef\xc0" );
|
||||
REQUIRE( lokimq::from_base64("YWJj") == "abc" );
|
||||
REQUIRE( lokimq::from_base64("YWJjZA==") == "abcd" );
|
||||
REQUIRE( lokimq::from_base64("YWJjZA") == "abcd" );
|
||||
REQUIRE( lokimq::from_base64("YWJjZB") == "abcd" ); // ignore superfluous bits
|
||||
REQUIRE( lokimq::from_base64("YWJjZB") == "abcd" ); // ignore superfluous bits
|
||||
REQUIRE( lokimq::from_base64("YWJj+") == "abc" ); // ignore superfluous bits
|
||||
REQUIRE( lokimq::from_base64("YWJjZGU=") == "abcde" );
|
||||
REQUIRE( lokimq::from_base64("YWJjZGU") == "abcde" );
|
||||
REQUIRE( lokimq::from_base64("YWJjZGVm") == "abcdef" );
|
||||
|
||||
REQUIRE( lokimq::is_base64("YWJjZGVm") );
|
||||
REQUIRE( lokimq::is_base64("YWJjZGU") );
|
||||
REQUIRE( lokimq::is_base64("YWJjZGU=") );
|
||||
REQUIRE( lokimq::is_base64("YWJjZA==") );
|
||||
REQUIRE( lokimq::is_base64("YWJjZA") );
|
||||
REQUIRE( lokimq::is_base64("YWJjZB") ); // not really valid, but we explicitly accept it
|
||||
|
||||
REQUIRE_FALSE( lokimq::is_base64("YWJjZ=") ); // invalid padding (padding can only be 4th or 3rd+4th of a 4-char block)
|
||||
REQUIRE_FALSE( lokimq::is_base64("YWJj=") );
|
||||
REQUIRE_FALSE( lokimq::is_base64("YWJj=A") );
|
||||
REQUIRE_FALSE( lokimq::is_base64("YWJjA===") );
|
||||
REQUIRE_FALSE( lokimq::is_base64("YWJ[") );
|
||||
REQUIRE_FALSE( lokimq::is_base64("YWJ.") );
|
||||
REQUIRE_FALSE( lokimq::is_base64("_YWJ") );
|
||||
|
||||
REQUIRE( lokimq::from_base64(
|
||||
"TWFuIGlzIGRpc3Rpbmd1aXNoZWQsIG5vdCBvbmx5IGJ5IGhpcyByZWFzb24sIGJ1dCBieSB0aGlz"
|
||||
"IHNpbmd1bGFyIHBhc3Npb24gZnJvbSBvdGhlciBhbmltYWxzLCB3aGljaCBpcyBhIGx1c3Qgb2Yg"
|
||||
"dGhlIG1pbmQsIHRoYXQgYnkgYSBwZXJzZXZlcmFuY2Ugb2YgZGVsaWdodCBpbiB0aGUgY29udGlu"
|
||||
"dWVkIGFuZCBpbmRlZmF0aWdhYmxlIGdlbmVyYXRpb24gb2Yga25vd2xlZGdlLCBleGNlZWRzIHRo"
|
||||
"ZSBzaG9ydCB2ZWhlbWVuY2Ugb2YgYW55IGNhcm5hbCBwbGVhc3VyZS4=" )
|
||||
==
|
||||
"Man is distinguished, not only by his reason, but by this singular passion from other "
|
||||
"animals, which is a lust of the mind, that by a perseverance of delight in the "
|
||||
"continued and indefatigable generation of knowledge, exceeds the short vehemence of "
|
||||
"any carnal pleasure.");
|
||||
|
||||
REQUIRE( lokimq::to_base64(pk) == pk_b64 );
|
||||
REQUIRE( lokimq::to_base64(pk.begin(), pk.end()) == pk_b64 );
|
||||
REQUIRE( lokimq::from_base64(pk_b64) == pk );
|
||||
REQUIRE( lokimq::from_base64(pk_b64.begin(), pk_b64.end()) == pk );
|
||||
|
||||
std::string pk_b64_again, pk_again;
|
||||
lokimq::to_base64(pk.begin(), pk.end(), std::back_inserter(pk_b64_again));
|
||||
lokimq::from_base64(pk_b64.begin(), pk_b64.end(), std::back_inserter(pk_again));
|
||||
REQUIRE( pk_b64_again == pk_b64 );
|
||||
REQUIRE( pk_again == pk );
|
||||
|
||||
std::vector<std::byte> bytes{{std::byte{0}, std::byte{255}}};
|
||||
std::basic_string_view<std::byte> b{bytes.data(), bytes.size()};
|
||||
REQUIRE( lokimq::to_base64(b) == "AP8=" );
|
||||
|
||||
bytes.resize(4);
|
||||
bytes[0] = std::byte{'/'}; bytes[1] = std::byte{'w'}; bytes[2] = std::byte{'A'}; bytes[3] = std::byte{'='};
|
||||
std::basic_string_view<std::byte> b64_bytes{bytes.data(), bytes.size()};
|
||||
REQUIRE( lokimq::is_base64(b64_bytes) );
|
||||
REQUIRE( lokimq::from_base64(b64_bytes) == "\xff\x00"sv );
|
||||
}
|
|
@ -25,21 +25,23 @@ TEST_CASE("failure responses - UNKNOWNCOMMAND", "[failure][UNKNOWNCOMMAND]") {
|
|||
client.send(zmq::message_t{"HI", 2}, zmq::send_flags::none);
|
||||
{
|
||||
zmq::message_t hello;
|
||||
client.recv(hello);
|
||||
auto recvd = client.recv(hello);
|
||||
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( recvd );
|
||||
REQUIRE( hello.to_string() == "HELLO" );
|
||||
REQUIRE_FALSE( hello.more() );
|
||||
}
|
||||
|
||||
client.send(zmq::message_t{"a.a", 3}, zmq::send_flags::none);
|
||||
zmq::message_t resp;
|
||||
client.recv(resp);
|
||||
auto recvd = client.recv(resp);
|
||||
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( recvd );
|
||||
REQUIRE( resp.to_string() == "UNKNOWNCOMMAND" );
|
||||
REQUIRE( resp.more() );
|
||||
client.recv(resp);
|
||||
REQUIRE( client.recv(resp) );
|
||||
REQUIRE( resp.to_string() == "a.a" );
|
||||
REQUIRE_FALSE( resp.more() );
|
||||
}
|
||||
|
@ -66,37 +68,40 @@ TEST_CASE("failure responses - NO_REPLY_TAG", "[failure][NO_REPLY_TAG]") {
|
|||
client.send(zmq::message_t{"HI", 2}, zmq::send_flags::none);
|
||||
{
|
||||
zmq::message_t hello;
|
||||
client.recv(hello);
|
||||
auto recvd = client.recv(hello);
|
||||
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( recvd );
|
||||
REQUIRE( hello.to_string() == "HELLO" );
|
||||
REQUIRE_FALSE( hello.more() );
|
||||
}
|
||||
|
||||
client.send(zmq::message_t{"x.r", 3}, zmq::send_flags::none);
|
||||
zmq::message_t resp;
|
||||
client.recv(resp);
|
||||
auto recvd = client.recv(resp);
|
||||
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( recvd );
|
||||
REQUIRE( resp.to_string() == "NO_REPLY_TAG" );
|
||||
REQUIRE( resp.more() );
|
||||
client.recv(resp);
|
||||
REQUIRE( client.recv(resp) );
|
||||
REQUIRE( resp.to_string() == "x.r" );
|
||||
REQUIRE_FALSE( resp.more() );
|
||||
}
|
||||
|
||||
client.send(zmq::message_t{"x.r", 3}, zmq::send_flags::sndmore);
|
||||
client.send(zmq::message_t{"foo", 3}, zmq::send_flags::none);
|
||||
client.recv(resp);
|
||||
recvd = client.recv(resp);
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( recvd );
|
||||
REQUIRE( resp.to_string() == "REPLY" );
|
||||
REQUIRE( resp.more() );
|
||||
client.recv(resp);
|
||||
REQUIRE( client.recv(resp) );
|
||||
REQUIRE( resp.to_string() == "foo" );
|
||||
REQUIRE( resp.more() );
|
||||
client.recv(resp);
|
||||
REQUIRE( client.recv(resp) );
|
||||
REQUIRE( resp.to_string() == "a" );
|
||||
REQUIRE_FALSE( resp.more() );
|
||||
}
|
||||
|
@ -132,9 +137,10 @@ TEST_CASE("failure responses - FORBIDDEN", "[failure][FORBIDDEN]") {
|
|||
client.send(zmq::message_t{"HI", 2}, zmq::send_flags::none);
|
||||
{
|
||||
zmq::message_t hello;
|
||||
client.recv(hello);
|
||||
auto recvd = client.recv(hello);
|
||||
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( recvd );
|
||||
REQUIRE( hello.to_string() == "HELLO" );
|
||||
REQUIRE_FALSE( hello.more() );
|
||||
}
|
||||
|
@ -144,18 +150,20 @@ TEST_CASE("failure responses - FORBIDDEN", "[failure][FORBIDDEN]") {
|
|||
c.send(zmq::message_t{"x.x", 3}, zmq::send_flags::none);
|
||||
|
||||
zmq::message_t resp;
|
||||
clients[0].recv(resp);
|
||||
auto recvd = clients[0].recv(resp);
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( recvd );
|
||||
REQUIRE( resp.to_string() == "FORBIDDEN" );
|
||||
REQUIRE( resp.more() );
|
||||
clients[0].recv(resp);
|
||||
REQUIRE( clients[0].recv(resp) );
|
||||
REQUIRE( resp.to_string() == "x.x" );
|
||||
REQUIRE_FALSE( resp.more() );
|
||||
}
|
||||
for (int i : {1, 2}) {
|
||||
clients[i].recv(resp);
|
||||
recvd = clients[i].recv(resp);
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( recvd );
|
||||
REQUIRE( resp.to_string() == "a" );
|
||||
REQUIRE_FALSE( resp.more() );
|
||||
}
|
||||
|
@ -164,17 +172,19 @@ TEST_CASE("failure responses - FORBIDDEN", "[failure][FORBIDDEN]") {
|
|||
c.send(zmq::message_t{"y.x", 3}, zmq::send_flags::none);
|
||||
|
||||
for (int i : {0, 1}) {
|
||||
clients[i].recv(resp);
|
||||
recvd = clients[i].recv(resp);
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( recvd );
|
||||
REQUIRE( resp.to_string() == "FORBIDDEN" );
|
||||
REQUIRE( resp.more() );
|
||||
clients[i].recv(resp);
|
||||
REQUIRE( clients[i].recv(resp) );
|
||||
REQUIRE( resp.to_string() == "y.x" );
|
||||
REQUIRE_FALSE( resp.more() );
|
||||
}
|
||||
clients[2].recv(resp);
|
||||
recvd = clients[2].recv(resp);
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( recvd );
|
||||
REQUIRE( resp.to_string() == "b" );
|
||||
REQUIRE_FALSE( resp.more() );
|
||||
}
|
||||
|
@ -207,9 +217,10 @@ TEST_CASE("failure responses - NOT_A_SERVICE_NODE", "[failure][NOT_A_SERVICE_NOD
|
|||
client.send(zmq::message_t{"HI", 2}, zmq::send_flags::none);
|
||||
{
|
||||
zmq::message_t hello;
|
||||
client.recv(hello);
|
||||
auto recvd = client.recv(hello);
|
||||
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( recvd );
|
||||
REQUIRE( hello.to_string() == "HELLO" );
|
||||
REQUIRE_FALSE( hello.more() );
|
||||
}
|
||||
|
@ -217,12 +228,13 @@ TEST_CASE("failure responses - NOT_A_SERVICE_NODE", "[failure][NOT_A_SERVICE_NOD
|
|||
client.send(zmq::message_t{"x.x", 3}, zmq::send_flags::none);
|
||||
|
||||
zmq::message_t resp;
|
||||
client.recv(resp);
|
||||
auto recvd = client.recv(resp);
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( recvd );
|
||||
REQUIRE( resp.to_string() == "NOT_A_SERVICE_NODE" );
|
||||
REQUIRE( resp.more() );
|
||||
client.recv(resp);
|
||||
REQUIRE( client.recv(resp) );
|
||||
REQUIRE( resp.to_string() == "x.x" );
|
||||
REQUIRE_FALSE( resp.more() );
|
||||
}
|
||||
|
@ -230,15 +242,16 @@ TEST_CASE("failure responses - NOT_A_SERVICE_NODE", "[failure][NOT_A_SERVICE_NOD
|
|||
client.send(zmq::message_t{"x.r", 3}, zmq::send_flags::sndmore);
|
||||
client.send(zmq::message_t{"xyz123", 6}, zmq::send_flags::none); // reply tag
|
||||
|
||||
client.recv(resp);
|
||||
recvd = client.recv(resp);
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( recvd );
|
||||
REQUIRE( resp.to_string() == "NOT_A_SERVICE_NODE" );
|
||||
REQUIRE( resp.more() );
|
||||
client.recv(resp);
|
||||
REQUIRE( client.recv(resp) );
|
||||
REQUIRE( resp.to_string() == "REPLY" );
|
||||
REQUIRE( resp.more() );
|
||||
client.recv(resp);
|
||||
REQUIRE( client.recv(resp) );
|
||||
REQUIRE( resp.to_string() == "xyz123" );
|
||||
REQUIRE_FALSE( resp.more() );
|
||||
}
|
||||
|
@ -271,9 +284,10 @@ TEST_CASE("failure responses - FORBIDDEN_SN", "[failure][FORBIDDEN_SN]") {
|
|||
client.send(zmq::message_t{"HI", 2}, zmq::send_flags::none);
|
||||
{
|
||||
zmq::message_t hello;
|
||||
client.recv(hello);
|
||||
auto recvd = client.recv(hello);
|
||||
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( recvd );
|
||||
REQUIRE( hello.to_string() == "HELLO" );
|
||||
REQUIRE_FALSE( hello.more() );
|
||||
}
|
||||
|
@ -281,12 +295,13 @@ TEST_CASE("failure responses - FORBIDDEN_SN", "[failure][FORBIDDEN_SN]") {
|
|||
client.send(zmq::message_t{"x.x", 3}, zmq::send_flags::none);
|
||||
|
||||
zmq::message_t resp;
|
||||
client.recv(resp);
|
||||
auto recvd = client.recv(resp);
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( recvd );
|
||||
REQUIRE( resp.to_string() == "FORBIDDEN_SN" );
|
||||
REQUIRE( resp.more() );
|
||||
client.recv(resp);
|
||||
REQUIRE( client.recv(resp) );
|
||||
REQUIRE( resp.to_string() == "x.x" );
|
||||
REQUIRE_FALSE( resp.more() );
|
||||
}
|
||||
|
@ -294,15 +309,16 @@ TEST_CASE("failure responses - FORBIDDEN_SN", "[failure][FORBIDDEN_SN]") {
|
|||
client.send(zmq::message_t{"x.r", 3}, zmq::send_flags::sndmore);
|
||||
client.send(zmq::message_t{"xyz123", 6}, zmq::send_flags::none); // reply tag
|
||||
|
||||
client.recv(resp);
|
||||
recvd = client.recv(resp);
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( recvd );
|
||||
REQUIRE( resp.to_string() == "FORBIDDEN_SN" );
|
||||
REQUIRE( resp.more() );
|
||||
client.recv(resp);
|
||||
REQUIRE( client.recv(resp) );
|
||||
REQUIRE( resp.to_string() == "REPLY" );
|
||||
REQUIRE( resp.more() );
|
||||
client.recv(resp);
|
||||
REQUIRE( client.recv(resp) );
|
||||
REQUIRE( resp.to_string() == "xyz123" );
|
||||
REQUIRE_FALSE( resp.more() );
|
||||
}
|
||||
|
|
|
@ -0,0 +1,87 @@
|
|||
#include "common.h"
|
||||
|
||||
using namespace lokimq;
|
||||
|
||||
TEST_CASE("injected external commands", "[injected]") {
|
||||
std::string listen = "tcp://127.0.0.1:4567";
|
||||
LokiMQ server{
|
||||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
get_logger("S» "),
|
||||
LogLevel::trace
|
||||
};
|
||||
server.set_general_threads(1);
|
||||
server.listen_curve(listen);
|
||||
|
||||
std::atomic<int> hellos = 0;
|
||||
std::atomic<bool> done = false;
|
||||
server.add_category("public", AuthLevel::none, 3);
|
||||
server.add_command("public", "hello", [&](Message& m) {
|
||||
hellos++;
|
||||
while (!done) std::this_thread::sleep_for(10ms);
|
||||
});
|
||||
|
||||
server.start();
|
||||
|
||||
LokiMQ client{get_logger("C» "), LogLevel::trace};
|
||||
client.start();
|
||||
|
||||
std::atomic<bool> got{false};
|
||||
bool success = false;
|
||||
|
||||
auto c = client.connect_remote(listen,
|
||||
[&](auto conn) { success = true; got = true; },
|
||||
[&](auto conn, std::string_view) { got = true; },
|
||||
server.get_pubkey());
|
||||
|
||||
wait_for_conn(got);
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( got );
|
||||
REQUIRE( success );
|
||||
}
|
||||
|
||||
// First make sure that basic message respects the 3 thread limit
|
||||
client.send(c, "public.hello");
|
||||
client.send(c, "public.hello");
|
||||
client.send(c, "public.hello");
|
||||
client.send(c, "public.hello");
|
||||
wait_for([&] { return hellos >= 3; });
|
||||
std::this_thread::sleep_for(20ms);
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( hellos == 3 );
|
||||
}
|
||||
done = true;
|
||||
wait_for([&] { return hellos >= 4; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( hellos == 4 );
|
||||
}
|
||||
|
||||
// Now try injecting external commands
|
||||
done = false;
|
||||
hellos = 0;
|
||||
client.send(c, "public.hello");
|
||||
wait_for([&] { return hellos >= 1; });
|
||||
server.inject_task("public", "(injected)", "localhost", [&] { hellos += 10; while (!done) std::this_thread::sleep_for(10ms); });
|
||||
wait_for([&] { return hellos >= 11; });
|
||||
client.send(c, "public.hello");
|
||||
wait_for([&] { return hellos >= 12; });
|
||||
server.inject_task("public", "(injected)", "localhost", [&] { hellos += 10; while (!done) std::this_thread::sleep_for(10ms); });
|
||||
server.inject_task("public", "(injected)", "localhost", [&] { hellos += 10; while (!done) std::this_thread::sleep_for(10ms); });
|
||||
server.inject_task("public", "(injected)", "localhost", [&] { hellos += 10; while (!done) std::this_thread::sleep_for(10ms); });
|
||||
wait_for([&] { return hellos >= 12; });
|
||||
std::this_thread::sleep_for(20ms);
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( hellos == 12 );
|
||||
}
|
||||
done = true;
|
||||
wait_for([&] { return hellos >= 42; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( hellos == 42 );
|
||||
}
|
||||
}
|
|
@ -30,10 +30,9 @@ TEST_CASE("basic requests", "[requests]") {
|
|||
std::atomic<bool> connected{false}, failed{false};
|
||||
std::string pubkey;
|
||||
|
||||
auto c = client.connect_remote(listen,
|
||||
auto c = client.connect_remote(address{listen, server.get_pubkey()},
|
||||
[&](auto conn) { pubkey = conn.pubkey(); connected = true; },
|
||||
[&](auto, auto) { failed = true; },
|
||||
server.get_pubkey());
|
||||
[&](auto, auto) { failed = true; });
|
||||
|
||||
wait_for([&] { return connected || failed; });
|
||||
{
|
||||
|
@ -88,10 +87,9 @@ TEST_CASE("request from server to client", "[requests]") {
|
|||
std::atomic<bool> connected{false}, failed{false};
|
||||
std::string pubkey;
|
||||
|
||||
auto c = client.connect_remote(listen,
|
||||
auto c = client.connect_remote(address{listen, server.get_pubkey()},
|
||||
[&](auto conn) { pubkey = conn.pubkey(); connected = true; },
|
||||
[&](auto, auto) { failed = true; },
|
||||
server.get_pubkey());
|
||||
[&](auto, auto) { failed = true; });
|
||||
|
||||
int i;
|
||||
for (i = 0; i < 5; i++) {
|
||||
|
@ -151,10 +149,9 @@ TEST_CASE("request timeouts", "[requests][timeout]") {
|
|||
std::atomic<bool> connected{false}, failed{false};
|
||||
std::string pubkey;
|
||||
|
||||
auto c = client.connect_remote(listen,
|
||||
auto c = client.connect_remote(address{listen, server.get_pubkey()},
|
||||
[&](auto conn) { pubkey = conn.pubkey(); connected = true; },
|
||||
[&](auto, auto) { failed = true; },
|
||||
server.get_pubkey());
|
||||
[&](auto, auto) { failed = true; });
|
||||
|
||||
wait_for([&] { return connected || failed; });
|
||||
|
||||
|
@ -170,7 +167,7 @@ TEST_CASE("request timeouts", "[requests][timeout]") {
|
|||
success = ok;
|
||||
data = std::move(data_);
|
||||
},
|
||||
lokimq::send_option::request_timeout{20ms}
|
||||
lokimq::send_option::request_timeout{10ms}
|
||||
);
|
||||
|
||||
std::atomic<bool> got_triggered2{false};
|
||||
|
@ -179,10 +176,10 @@ TEST_CASE("request timeouts", "[requests][timeout]") {
|
|||
success = ok;
|
||||
data = std::move(data_);
|
||||
},
|
||||
lokimq::send_option::request_timeout{100ms}
|
||||
lokimq::send_option::request_timeout{200ms}
|
||||
);
|
||||
|
||||
std::this_thread::sleep_for(40ms);
|
||||
std::this_thread::sleep_for(100ms);
|
||||
REQUIRE( got_triggered );
|
||||
REQUIRE_FALSE( got_triggered2 );
|
||||
REQUIRE_FALSE( success );
|
||||
|
|
|
@ -1,187 +0,0 @@
|
|||
#include <catch2/catch.hpp>
|
||||
#include "lokimq/string_view.h"
|
||||
#include <future>
|
||||
|
||||
using namespace lokimq;
|
||||
|
||||
using namespace std::literals;
|
||||
|
||||
TEST_CASE("string view", "[string_view]") {
|
||||
std::string foo = "abc 123 xyz";
|
||||
string_view f1{foo};
|
||||
string_view f2{"def 789 uvw"};
|
||||
string_view f3{"nu\0ll", 5};
|
||||
|
||||
REQUIRE( f1 == "abc 123 xyz" );
|
||||
REQUIRE( f2 == "def 789 uvw" );
|
||||
REQUIRE( f3.size() == 5 );
|
||||
REQUIRE( f3 == std::string{"nu\0ll", 5} );
|
||||
REQUIRE( f3 != "nu" );
|
||||
REQUIRE( f3.data() == "nu"s );
|
||||
REQUIRE( string_view(f3) == f3 );
|
||||
|
||||
auto f4 = f3;
|
||||
REQUIRE( f4 == f3 );
|
||||
f4 = f2;
|
||||
REQUIRE( f4 == "def 789 uvw" );
|
||||
|
||||
REQUIRE( f1.size() == 11 );
|
||||
REQUIRE( f3.length() == 5 );
|
||||
|
||||
string_view f5{""};
|
||||
REQUIRE( !f3.empty() );
|
||||
REQUIRE( f5.empty() );
|
||||
|
||||
REQUIRE( f1[5] == '2' );
|
||||
size_t i = 0;
|
||||
for (auto c : f3)
|
||||
REQUIRE(c == f3[i++]);
|
||||
|
||||
std::string backwards;
|
||||
for (auto it = std::rbegin(f2); it != f2.crend(); ++it)
|
||||
backwards += *it;
|
||||
|
||||
REQUIRE( backwards == "wvu 987 fed" );
|
||||
|
||||
REQUIRE( f1.at(10) == 'z' );
|
||||
REQUIRE_THROWS_AS( f1.at(15), std::out_of_range );
|
||||
REQUIRE_THROWS_AS( f1.at(11), std::out_of_range );
|
||||
|
||||
f4 = f1;
|
||||
f4.remove_prefix(2);
|
||||
REQUIRE( f4 == "c 123 xyz" );
|
||||
f4.remove_prefix(2);
|
||||
f4.remove_suffix(4);
|
||||
REQUIRE( f4 == "123" );
|
||||
f4.remove_prefix(1);
|
||||
REQUIRE( f4 == "23" );
|
||||
REQUIRE( f1 == "abc 123 xyz" );
|
||||
f4.swap(f1);
|
||||
REQUIRE( f1 == "23" );
|
||||
REQUIRE( f4 == "abc 123 xyz" );
|
||||
f1.remove_suffix(2);
|
||||
REQUIRE( f1.empty() );
|
||||
REQUIRE( f4 == "abc 123 xyz" );
|
||||
f1.swap(f4);
|
||||
REQUIRE( f4.empty() );
|
||||
REQUIRE( f1 == "abc 123 xyz" );
|
||||
|
||||
REQUIRE( f1.front() == 'a' );
|
||||
REQUIRE( f1.back() == 'z' );
|
||||
REQUIRE( f1.compare("abc") > 0 );
|
||||
REQUIRE( f1.compare("abd") < 0 );
|
||||
REQUIRE( f1.compare("abc 123 xyz") == 0 );
|
||||
REQUIRE( f1.compare("abc 123 xyza") < 0 );
|
||||
REQUIRE( f1.compare("abc 123 xy") > 0 );
|
||||
|
||||
std::string buf;
|
||||
buf.resize(5);
|
||||
f1.copy(&buf[0], 5, 2);
|
||||
REQUIRE( buf == "c 123" );
|
||||
buf.resize(100, 'X');
|
||||
REQUIRE( f1.copy(&buf[0], 100) == 11 );
|
||||
REQUIRE( buf.substr(0, 11) == f1 );
|
||||
REQUIRE( buf.substr(11) == std::string(89, 'X') );
|
||||
REQUIRE( f1.substr(4) == "123 xyz" );
|
||||
REQUIRE( f1.substr(4, 3) == "123" );
|
||||
REQUIRE_THROWS_AS( f1.substr(500, 3), std::out_of_range );
|
||||
REQUIRE( f1.substr(11, 2) == "" );
|
||||
REQUIRE( f1.substr(8, 500) == "xyz" );
|
||||
REQUIRE( f1.find("123") == 4 );
|
||||
REQUIRE( f1.find("abc") == 0 );
|
||||
REQUIRE( f1.find("xyz") == 8 );
|
||||
REQUIRE( f1.find("abc 123 xyz 7") == string_view::npos );
|
||||
REQUIRE( f1.find("23") == 5 );
|
||||
REQUIRE( f1.find("234") == string_view::npos );
|
||||
|
||||
string_view f6{"zz abc abcd abcde abcdef"};
|
||||
REQUIRE( f6.find("abc") == 3 );
|
||||
REQUIRE( f6.find("abc", 3) == 3 );
|
||||
REQUIRE( f6.find("abc", 4) == 7 );
|
||||
REQUIRE( f6.find("abc", 7) == 7 );
|
||||
REQUIRE( f6.find("abc", 8) == 12 );
|
||||
REQUIRE( f6.find("abc", 18) == 18 );
|
||||
REQUIRE( f6.find("abc", 19) == string_view::npos );
|
||||
REQUIRE( f6.find("abcd") == 7 );
|
||||
REQUIRE( f6.rfind("abc") == 18 );
|
||||
REQUIRE( f6.rfind("abcd") == 18 );
|
||||
REQUIRE( f6.rfind("bcd") == 19 );
|
||||
REQUIRE( f6.rfind("abc", 19) == 18 );
|
||||
REQUIRE( f6.rfind("abc", 18) == 18 );
|
||||
REQUIRE( f6.rfind("abc", 17) == 12 );
|
||||
REQUIRE( f6.rfind("abc", 17) == 12 );
|
||||
REQUIRE( f6.rfind("abc", 8) == 7 );
|
||||
REQUIRE( f6.rfind("abc", 7) == 7 );
|
||||
REQUIRE( f6.rfind("abc", 6) == 3 );
|
||||
REQUIRE( f6.rfind("abc", 3) == 3 );
|
||||
REQUIRE( f6.rfind("abc", 2) == string_view::npos );
|
||||
|
||||
REQUIRE( f6.find('a') == 3 );
|
||||
REQUIRE( f6.find('a', 17) == 18 );
|
||||
REQUIRE( f6.find('a', 20) == string_view::npos );
|
||||
|
||||
REQUIRE( f6.rfind('a') == 18 );
|
||||
REQUIRE( f6.rfind('a', 17) == 12 );
|
||||
REQUIRE( f6.rfind('a', 2) == string_view::npos );
|
||||
|
||||
string_view f7{"abc\0def", 7};
|
||||
REQUIRE( f7.find("c\0d", 0, 3) == 2 );
|
||||
REQUIRE( f7.find("c\0e", 0, 3) == string_view::npos );
|
||||
REQUIRE( f7.rfind("c\0d", string_view::npos, 3) == 2 );
|
||||
REQUIRE( f7.rfind("c\0e", 0, 3) == string_view::npos );
|
||||
|
||||
REQUIRE( f6.find_first_of("c789b") == 4 );
|
||||
REQUIRE( f6.find_first_of("c789") == 5 );
|
||||
REQUIRE( f2.find_first_of("c789b") == 4 );
|
||||
REQUIRE( f6.find_first_of("c789b", 6) == 8 );
|
||||
|
||||
REQUIRE( f6.find_last_of("c789b") == 20 );
|
||||
REQUIRE( f6.find_last_of("789b") == 19 );
|
||||
REQUIRE( f2.find_last_of("c789b") == 6 );
|
||||
REQUIRE( f6.find_last_of("c789b", 6) == 5 );
|
||||
REQUIRE( f6.find_last_of("c789b", 5) == 5 );
|
||||
REQUIRE( f6.find_last_of("c789b", 4) == 4 );
|
||||
REQUIRE( f6.find_last_of("c789b", 3) == string_view::npos );
|
||||
|
||||
REQUIRE( f2.find_first_of(f7) == 0 );
|
||||
REQUIRE( f3.find_first_of(f7) == 2 );
|
||||
REQUIRE( f3.find_first_of('\0') == 2 );
|
||||
REQUIRE( f3.find_first_of("jk\0", 0, 3) == 2 );
|
||||
|
||||
REQUIRE( f1.find_first_not_of("abc") == 3 );
|
||||
REQUIRE( f1.find_first_not_of("abc ", 3) == 4 );
|
||||
REQUIRE( f1.find_first_not_of(" 123", 3) == 8 );
|
||||
REQUIRE( f1.find_last_not_of("abc") == 10 );
|
||||
REQUIRE( f1.find_last_not_of("xyz") == 7 );
|
||||
REQUIRE( f1.find_last_not_of("xyz 321") == 2 );
|
||||
REQUIRE( f1.find_last_not_of("xay z1b2c3") == string_view::npos );
|
||||
REQUIRE( f6.find_last_not_of("def") == 20 );
|
||||
REQUIRE( f6.find_last_not_of("abcdef") == 17 );
|
||||
REQUIRE( f6.find_last_not_of("abcdef ") == 1 );
|
||||
REQUIRE( f6.find_first_not_of('z') == 2 );
|
||||
REQUIRE( f6.find_first_not_of("z ") == 3 );
|
||||
REQUIRE( f6.find_first_not_of("a ", 2) == 4 );
|
||||
REQUIRE( f6.find_last_not_of("abc ", 9) == 1 );
|
||||
|
||||
REQUIRE( string_view{"abc"} == string_view{"abc"} );
|
||||
REQUIRE_FALSE( string_view{"abc"} == string_view{"abd"} );
|
||||
REQUIRE_FALSE( string_view{"abc"} == string_view{"abcd"} );
|
||||
REQUIRE( string_view{"abc"} != string_view{"abd"} );
|
||||
REQUIRE( string_view{"abc"} != string_view{"abcd"} );
|
||||
REQUIRE( string_view{"abc"} < string_view{"abcd"} );
|
||||
REQUIRE( string_view{"abc"} < string_view{"abd"} );
|
||||
REQUIRE_FALSE( string_view{"abd"} < string_view{"abc"} );
|
||||
REQUIRE_FALSE( string_view{"abcd"} < string_view{"abc"} );
|
||||
REQUIRE_FALSE( string_view{"abc"} < string_view{"abc"} );
|
||||
REQUIRE( string_view{"abd"} > string_view{"abc"} );
|
||||
REQUIRE( string_view{"abcd"} > string_view{"abc"} );
|
||||
REQUIRE_FALSE( string_view{"abc"} > string_view{"abd"} );
|
||||
REQUIRE_FALSE( string_view{"abc"} > string_view{"abcd"} );
|
||||
REQUIRE_FALSE( string_view{"abc"} > string_view{"abc"} );
|
||||
REQUIRE( string_view{"abc"} <= string_view{"abcd"} );
|
||||
REQUIRE( string_view{"abc"} <= string_view{"abc"} );
|
||||
REQUIRE( string_view{"abc"} <= string_view{"abd"} );
|
||||
REQUIRE( string_view{"abd"} >= string_view{"abc"} );
|
||||
REQUIRE( string_view{"abc"} >= string_view{"abc"} );
|
||||
REQUIRE( string_view{"abcd"} >= string_view{"abc"} );
|
||||
}
|
|
@ -0,0 +1,165 @@
|
|||
#include "lokimq/batch.h"
|
||||
#include "common.h"
|
||||
#include <future>
|
||||
|
||||
TEST_CASE("tagged thread start functions", "[tagged][start]") {
|
||||
lokimq::LokiMQ lmq{get_logger(""), LogLevel::trace};
|
||||
|
||||
lmq.set_general_threads(2);
|
||||
lmq.set_batch_threads(2);
|
||||
auto t_abc = lmq.add_tagged_thread("abc");
|
||||
std::atomic<bool> start_called = false;
|
||||
auto t_def = lmq.add_tagged_thread("def", [&] { start_called = true; });
|
||||
|
||||
std::this_thread::sleep_for(20ms);
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE_FALSE( start_called );
|
||||
}
|
||||
|
||||
lmq.start();
|
||||
wait_for([&] { return start_called.load(); });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( start_called );
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("tagged threads quit-before-start", "[tagged][quit]") {
|
||||
auto lmq = std::make_unique<lokimq::LokiMQ>(get_logger(""), LogLevel::trace);
|
||||
auto t_abc = lmq->add_tagged_thread("abc");
|
||||
REQUIRE_NOTHROW(lmq.reset());
|
||||
}
|
||||
|
||||
TEST_CASE("batch jobs to tagged threads", "[tagged][batch]") {
|
||||
lokimq::LokiMQ lmq{get_logger(""), LogLevel::trace};
|
||||
|
||||
lmq.set_general_threads(2);
|
||||
lmq.set_batch_threads(2);
|
||||
std::thread::id id_abc, id_def;
|
||||
auto t_abc = lmq.add_tagged_thread("abc", [&] { id_abc = std::this_thread::get_id(); });
|
||||
auto t_def = lmq.add_tagged_thread("def", [&] { id_def = std::this_thread::get_id(); });
|
||||
lmq.start();
|
||||
|
||||
std::atomic<bool> done = false;
|
||||
std::thread::id id;
|
||||
lmq.job([&] { id = std::this_thread::get_id(); done = true; });
|
||||
wait_for([&] { return done.load(); });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( id != id_abc );
|
||||
REQUIRE( id != id_def );
|
||||
}
|
||||
|
||||
done = false;
|
||||
lmq.job([&] { id = std::this_thread::get_id(); done = true; }, t_abc);
|
||||
wait_for([&] { return done.load(); });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( id == id_abc );
|
||||
}
|
||||
|
||||
done = false;
|
||||
lmq.job([&] { id = std::this_thread::get_id(); done = true; }, t_def);
|
||||
wait_for([&] { return done.load(); });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( id == id_def );
|
||||
}
|
||||
|
||||
std::atomic<bool> sleep = true;
|
||||
auto sleeper = [&] { for (int i = 0; sleep && i < 10; i++) { std::this_thread::sleep_for(25ms); } };
|
||||
lmq.job(sleeper);
|
||||
lmq.job(sleeper);
|
||||
// This one should stall:
|
||||
std::atomic<bool> bad = false;
|
||||
lmq.job([&] { bad = true; });
|
||||
|
||||
std::this_thread::sleep_for(50ms);
|
||||
|
||||
done = false;
|
||||
lmq.job([&] { id = std::this_thread::get_id(); done = true; }, t_abc);
|
||||
wait_for([&] { return done.load(); });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( done.load() );
|
||||
REQUIRE_FALSE( bad.load() );
|
||||
}
|
||||
|
||||
done = false;
|
||||
// We can queue up a bunch of jobs which should all happen in order, and all on the abc thread.
|
||||
std::vector<int> v;
|
||||
for (int i = 0; i < 100; i++) {
|
||||
lmq.job([&] { if (std::this_thread::get_id() == id_abc) v.push_back(v.size()); }, t_abc);
|
||||
}
|
||||
lmq.job([&] { done = true; }, t_abc);
|
||||
wait_for([&] { return done.load(); });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( done.load() );
|
||||
REQUIRE_FALSE( bad.load() );
|
||||
REQUIRE( v.size() == 100 );
|
||||
for (int i = 0; i < 100; i++)
|
||||
REQUIRE( v[i] == i );
|
||||
}
|
||||
sleep = false;
|
||||
wait_for([&] { return bad.load(); });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( bad.load() );
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("batch job completion on tagged threads", "[tagged][batch-completion]") {
|
||||
lokimq::LokiMQ lmq{get_logger(""), LogLevel::trace};
|
||||
|
||||
lmq.set_general_threads(4);
|
||||
lmq.set_batch_threads(4);
|
||||
std::thread::id id_abc;
|
||||
auto t_abc = lmq.add_tagged_thread("abc", [&] { id_abc = std::this_thread::get_id(); });
|
||||
lmq.start();
|
||||
|
||||
lokimq::Batch<int> batch;
|
||||
for (int i = 1; i < 10; i++)
|
||||
batch.add_job([i, &id_abc]() { if (std::this_thread::get_id() == id_abc) return 0; return i; });
|
||||
|
||||
std::atomic<int> result_sum = -1;
|
||||
batch.completion([&](auto result) {
|
||||
int sum = 0;
|
||||
for (auto& r : result)
|
||||
sum += r.get();
|
||||
result_sum = std::this_thread::get_id() == id_abc ? sum : -sum;
|
||||
}, t_abc);
|
||||
lmq.batch(std::move(batch));
|
||||
wait_for([&] { return result_sum.load() != -1; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( result_sum == 45 );
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
TEST_CASE("timer job completion on tagged threads", "[tagged][timer]") {
|
||||
lokimq::LokiMQ lmq{get_logger(""), LogLevel::trace};
|
||||
|
||||
lmq.set_general_threads(4);
|
||||
lmq.set_batch_threads(4);
|
||||
|
||||
std::thread::id id_abc;
|
||||
auto t_abc = lmq.add_tagged_thread("abc", [&] { id_abc = std::this_thread::get_id(); });
|
||||
lmq.start();
|
||||
|
||||
std::atomic<int> ticks = 0;
|
||||
std::atomic<int> abc_ticks = 0;
|
||||
lmq.add_timer([&] { ticks++; }, 10ms);
|
||||
lmq.add_timer([&] { if (std::this_thread::get_id() == id_abc) abc_ticks++; }, 10ms, true, t_abc);
|
||||
|
||||
wait_for([&] { return ticks.load() > 2 && abc_ticks > 2; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( ticks.load() > 2 );
|
||||
REQUIRE( abc_ticks.load() > 2 );
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue