Merge remote-tracking branch 'origin/dev' into master

This commit is contained in:
Jason Rhinelander 2020-09-30 17:00:18 -03:00
commit 53481cdfa9
38 changed files with 3227 additions and 1112 deletions

63
.drone.jsonnet Normal file
View File

@ -0,0 +1,63 @@
local debian_pipeline(name, image, arch='amd64', deps='g++ libsodium-dev libzmq3-dev', cmake_extra='', build_type='Release', extra_cmds=[], allow_fail=false) = {
kind: 'pipeline',
type: 'docker',
name: name,
platform: { arch: arch },
environment: { CLICOLOR_FORCE: '1' }, // Lets color through ninja (1.9+)
steps: [
{
name: 'build',
image: image,
[if allow_fail then "failure"]: "ignore",
commands: [
'apt-get update',
'apt-get install -y eatmydata',
'eatmydata apt-get dist-upgrade -y',
'eatmydata apt-get install -y cmake git ninja-build pkg-config ccache ' + deps,
'git submodule update --init --recursive',
'mkdir build',
'cd build',
'cmake .. -G Ninja -DCMAKE_CXX_FLAGS=-fdiagnostics-color=always -DCMAKE_BUILD_TYPE='+build_type+' -DCMAKE_CXX_COMPILER_LAUNCHER=ccache ' + cmake_extra,
'ninja -v',
'./tests/tests --use-colour yes'
] + extra_cmds,
}
]
};
[
debian_pipeline("Ubuntu focal (amd64)", "ubuntu:focal"),
debian_pipeline("Ubuntu bionic (amd64)", "ubuntu:bionic", deps='libsodium-dev g++-8',
cmake_extra='-DCMAKE_C_COMPILER=gcc-8 -DCMAKE_CXX_COMPILER=g++-8'),
debian_pipeline("Debian sid (amd64)", "debian:sid"),
debian_pipeline("Debian sid/Debug (amd64)", "debian:sid", build_type='Debug'),
debian_pipeline("Debian sid/clang-10 (amd64)", "debian:sid", deps='clang-10 lld-10 libsodium-dev libzmq3-dev',
cmake_extra='-DCMAKE_C_COMPILER=clang-10 -DCMAKE_CXX_COMPILER=clang++-10 ' + std.join(' ', [
'-DCMAKE_'+type+'_LINKER_FLAGS=-fuse-ld=lld-10' for type in ['EXE','MODULE','SHARED','STATIC']])),
debian_pipeline("Debian buster (amd64)", "debian:buster"),
debian_pipeline("Debian buster (i386)", "i386/debian:buster"),
debian_pipeline("Ubuntu bionic (ARM64)", "ubuntu:bionic", arch="arm64", deps='libsodium-dev g++-8',
cmake_extra='-DCMAKE_C_COMPILER=gcc-8 -DCMAKE_CXX_COMPILER=g++-8'),
debian_pipeline("Debian sid (ARM64)", "debian:sid", arch="arm64"),
debian_pipeline("Debian buster (armhf)", "arm32v7/debian:buster", arch="arm64"),
{
kind: 'pipeline',
type: 'exec',
name: 'macOS (Catalina w/macports)',
platform: { os: 'darwin', arch: 'amd64' },
environment: { CLICOLOR_FORCE: '1' }, // Lets color through ninja (1.9+)
steps: [
{
name: 'build',
commands: [
'git submodule update --init --recursive',
'mkdir build',
'cd build',
'cmake .. -G Ninja -DCMAKE_CXX_FLAGS=-fcolor-diagnostics -DCMAKE_BUILD_TYPE=Release -DCMAKE_CXX_COMPILER_LAUNCHER=ccache',
'ninja -v',
'./tests/tests --use-colour yes'
],
}
]
},
]

3
.gitmodules vendored
View File

@ -1,6 +1,3 @@
[submodule "mapbox-variant"]
path = mapbox-variant
url = https://github.com/mapbox/variant.git
[submodule "cppzmq"]
path = cppzmq
url = https://github.com/zeromq/cppzmq.git

View File

@ -1,12 +1,12 @@
cmake_minimum_required(VERSION 3.7)
project(liblokimq CXX)
project(liblokimq CXX C)
include(GNUInstallDirs)
set(LOKIMQ_VERSION_MAJOR 1)
set(LOKIMQ_VERSION_MINOR 1)
set(LOKIMQ_VERSION_PATCH 4)
set(LOKIMQ_VERSION_MINOR 2)
set(LOKIMQ_VERSION_PATCH 0)
set(LOKIMQ_VERSION "${LOKIMQ_VERSION_MAJOR}.${LOKIMQ_VERSION_MINOR}.${LOKIMQ_VERSION_PATCH}")
message(STATUS "lokimq v${LOKIMQ_VERSION}")
@ -22,6 +22,7 @@ configure_file(lokimq/version.h.in lokimq/version.h @ONLY)
configure_file(liblokimq.pc.in liblokimq.pc @ONLY)
add_library(lokimq
lokimq/address.cpp
lokimq/auth.cpp
lokimq/bt_serialize.cpp
lokimq/connections.cpp
@ -46,17 +47,18 @@ if(TARGET libzmq)
elseif(BUILD_SHARED_LIBS)
include(FindPkgConfig)
pkg_check_modules(libzmq libzmq>=4.3 IMPORTED_TARGET)
# Debian sid includes a -isystem in the mit-krb package that, starting with pkg-config 0.29.2,
# breaks cmake's pkgconfig module because it stupidly thinks "-isystem" is a path, so if we find
# -isystem in the include dirs then hack it out.
get_property(zmq_inc TARGET PkgConfig::libzmq PROPERTY INTERFACE_INCLUDE_DIRECTORIES)
list(FIND zmq_inc "-isystem" broken_isystem)
if(NOT broken_isystem EQUAL -1)
list(REMOVE_AT zmq_inc ${broken_isystem})
set_property(TARGET PkgConfig::libzmq PROPERTY INTERFACE_INCLUDE_DIRECTORIES ${zmq_inc})
endif()
if(libzmq_FOUND)
# Debian sid includes a -isystem in the mit-krb package that, starting with pkg-config 0.29.2,
# breaks cmake's pkgconfig module because it stupidly thinks "-isystem" is a path, so if we find
# -isystem in the include dirs then hack it out.
get_property(zmq_inc TARGET PkgConfig::libzmq PROPERTY INTERFACE_INCLUDE_DIRECTORIES)
list(FIND zmq_inc "-isystem" broken_isystem)
if(NOT broken_isystem EQUAL -1)
list(REMOVE_AT zmq_inc ${broken_isystem})
set_property(TARGET PkgConfig::libzmq PROPERTY INTERFACE_INCLUDE_DIRECTORIES ${zmq_inc})
endif()
target_link_libraries(lokimq PUBLIC PkgConfig::libzmq)
else()
set(lokimq_build_static_libzmq ON)
@ -66,7 +68,7 @@ else()
endif()
if(lokimq_build_static_libzmq)
message(STATUS "libzmq >= 4.3 not found or static build requested, building bundled 4.3.2")
message(STATUS "libzmq >= 4.3 not found or static build requested, building bundled version")
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/local-libzmq")
include(LocalLibzmq)
target_link_libraries(lokimq PUBLIC libzmq_vendor)
@ -77,12 +79,11 @@ target_include_directories(lokimq
$<INSTALL_INTERFACE:>
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}>
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/cppzmq>
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/mapbox-variant/include>
)
target_compile_options(lokimq PRIVATE -Wall -Wextra -Werror)
set_target_properties(lokimq PROPERTIES
CXX_STANDARD 14
CXX_STANDARD 17
CXX_STANDARD_REQUIRED ON
CXX_EXTENSIONS OFF
POSITION_INDEPENDENT_CODE ON
@ -101,7 +102,7 @@ endfunction()
# If the caller has already set up a sodium target then we will just link to it, otherwise we go
# looking for it.
if(TARGET sodium)
target_link_libraries(lokimq PRIVATE sodium)
target_link_libraries(lokimq PUBLIC sodium)
if(lokimq_build_static_libzmq)
target_link_libraries(libzmq_vendor INTERFACE sodium)
endif()
@ -109,13 +110,13 @@ else()
pkg_check_modules(sodium REQUIRED libsodium IMPORTED_TARGET)
if(BUILD_SHARED_LIBS)
target_link_libraries(lokimq PRIVATE PkgConfig::sodium)
target_link_libraries(lokimq PUBLIC PkgConfig::sodium)
if(lokimq_build_static_libzmq)
target_link_libraries(libzmq_vendor INTERFACE PkgConfig::sodium)
endif()
else()
link_dep_libs(lokimq PRIVATE "${sodium_STATIC_LIBRARY_DIRS}" ${sodium_STATIC_LIBRARIES})
target_include_directories(lokimq PRIVATE ${sodium_STATIC_INCLUDE_DIRS})
link_dep_libs(lokimq PUBLIC "${sodium_STATIC_LIBRARY_DIRS}" ${sodium_STATIC_LIBRARIES})
target_include_directories(lokimq PUBLIC ${sodium_STATIC_INCLUDE_DIRS})
if(lokimq_build_static_libzmq)
link_dep_libs(libzmq_vendor INTERFACE "${sodium_STATIC_LIBRARY_DIRS}" ${sodium_STATIC_LIBRARIES})
target_link_libraries(libzmq_vendor INTERFACE ${sodium_STATIC_INCLUDE_DIRS})
@ -137,9 +138,13 @@ install(
)
install(
FILES lokimq/auth.h
FILES lokimq/address.h
lokimq/auth.h
lokimq/base32z.h
lokimq/base64.h
lokimq/batch.h
lokimq/bt_serialize.h
lokimq/bt_value.h
lokimq/connections.h
lokimq/hex.h
lokimq/lokimq.h
@ -148,17 +153,6 @@ install(
${CMAKE_CURRENT_BINARY_DIR}/lokimq/version.h
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/lokimq
)
option(LOKIMQ_INSTALL_MAPBOX_VARIANT "Install mapbox-variant headers with lokimq/ headers" ON)
if(LOKIMQ_INSTALL_MAPBOX_VARIANT)
install(
FILES mapbox-variant/include/mapbox/variant.hpp
mapbox-variant/include/mapbox/variant_cast.hpp
mapbox-variant/include/mapbox/variant_io.hpp
mapbox-variant/include/mapbox/variant_visitor.hpp
mapbox-variant/include/mapbox/recursive_wrapper.hpp
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/lokimq/mapbox
)
endif()
option(LOKIMQ_INSTALL_CPPZMQ "Install cppzmq header with lokimq/ headers" ON)
if(LOKIMQ_INSTALL_CPPZMQ)

View File

@ -1,6 +1,6 @@
# LokiMQ - zeromq-based message passing for Loki projects
This C++14 library contains an abstraction layer around ZeroMQ to support integration with Loki
This C++17 library contains an abstraction layer around ZeroMQ to support integration with Loki
authentication, RPC, and message passing. It is designed to be usable as the underlying
communication mechanism of SN-to-SN communication ("quorumnet"), the RPC interface used by wallets
and local daemon commands, communication channels between lokid and auxiliary services (storage
@ -123,6 +123,7 @@ The connection ID generally has two possible values:
places to get one, such as from the `Message` object passed to a command: see the following
section).
```C++
// Send to a service node, establishing a connection if necessary:
std::string my_sn = ...; // 32-byte pubkey of a known SN
lmq.send(my_sn, "sn.explode", "{ \"seconds\": 30 }");
@ -137,6 +138,7 @@ The connection ID generally has two possible values:
else
std::cout << "Timeout fetching height!";
});
```
## Command invocation

View File

@ -1,7 +1,7 @@
set(LIBZMQ_PREFIX ${CMAKE_BINARY_DIR}/libzmq)
set(ZeroMQ_VERSION 4.3.2)
set(ZeroMQ_VERSION 4.3.3)
set(LIBZMQ_URL https://github.com/zeromq/libzmq/releases/download/v${ZeroMQ_VERSION}/zeromq-${ZeroMQ_VERSION}.tar.gz)
set(LIBZMQ_HASH SHA512=b6251641e884181db9e6b0b705cced7ea4038d404bdae812ff47bdd0eed12510b6af6846b85cb96898e253ccbac71eca7fe588673300ddb9c3109c973250c8e4)
set(LIBZMQ_HASH SHA512=4c18d784085179c5b1fcb753a93813095a12c8d34970f2e1bfca6499be6c9d67769c71c68b7ca54ff181b20390043170e89733c22f76ff1ea46494814f7095b1)
message(${LIBZMQ_URL})

351
lokimq/address.cpp Normal file
View File

@ -0,0 +1,351 @@
#include "address.h"
#include <tuple>
#include <limits>
#include <cstddef>
#include <utility>
#include <stdexcept>
#include <ostream>
#include "hex.h"
#include "base32z.h"
#include "base64.h"
namespace lokimq {
constexpr size_t enc_length(address::encoding enc) {
return enc == address::encoding::hex ? 64 :
enc == address::encoding::base64 ? 43 : // this can be 44 with a padding byte, but we don't need it
52 /*base32z*/;
};
// Parses an encoding pubkey from the given string_view. Advanced the string_view beyond the
// consumed pubkey data, and returns the pubkey (as a 32-byte string). Throws if no valid pubkey
// was found at the beginning of addr. We look for hex, base32z, or base64 pubkeys *unless* qr is
// given: for QR-friendly we only accept hex or base32z (since QR cannot handle base64's alphabet).
std::string decode_pubkey(std::string_view& in, bool qr) {
std::string pubkey;
if (in.size() >= 64 && lokimq::is_hex(in.substr(0, 64))) {
pubkey = from_hex(in.substr(0, 64));
in.remove_prefix(64);
} else if (in.size() >= 52 && lokimq::is_base32z(in.substr(0, 52))) {
pubkey = from_base32z(in.substr(0, 52));
in.remove_prefix(52);
} else if (!qr && in.size() >= 43 && lokimq::is_base64(in.substr(0, 43))) {
pubkey = from_base64(in.substr(0, 43));
in.remove_prefix(43);
if (!in.empty() && in.front() == '=')
in.remove_prefix(1); // allow (and eat) a padding byte at the end
} else {
throw std::invalid_argument{"No pubkey found"};
}
return pubkey;
}
// Parse the host, port, and optionally pubkey from a string view, mutating it to remove the parsed
// sections. qr should be true if we should accept $IPv6$ as a QR-encoding-friendly alternative to
// [IPv6] (the returned host will have the $ replaced, i.e. [IPv6]).
std::tuple<std::string, uint16_t, std::string> parse_tcp(std::string_view& addr, bool qr, bool expect_pubkey) {
std::tuple<std::string, uint16_t, std::string> result;
auto& host = std::get<0>(result);
if (addr.front() == '[' || (qr && addr.front() == '$')) { // IPv6 addr (though this is far from complete validation)
auto pos = addr.find_first_not_of(":.1234567890abcdefABCDEF", 1);
if (pos == std::string_view::npos)
throw std::invalid_argument("Could not find terminating ] while parsing an IPv6 address");
if (!(addr[pos] == ']' || (qr && addr[pos] == '$')))
throw std::invalid_argument{"Expected " + (qr ? "$"s : "]"s) + " to close IPv6 address but found " + std::string(1, addr[pos])};
host = std::string{addr.substr(0, pos+1)};
if (qr) {
if (host.front() == '$')
host.front() = '[';
if (host.back() == '$')
host.back() = ']';
}
addr.remove_prefix(pos+1);
} else {
auto port_pos = addr.find(':');
if (port_pos == std::string_view::npos)
throw std::invalid_argument{"Could not determine host (no following ':port' found)"};
if (port_pos == 0)
throw std::invalid_argument{"Host cannot be empty"};
host = std::string{addr.substr(0, port_pos)};
addr.remove_prefix(port_pos);
}
if (qr)
// Lower-case the host because upper case hostnames are ugly
for (char& c : host)
if (c >= 'A' && c <= 'Z')
c = c - 'A' + 'a';
if (addr.size() < 2 || addr[0] != ':')
throw std::invalid_argument{"Could not find :port in address string"};
addr.remove_prefix(1);
auto pos = addr.find_first_not_of("1234567890");
if (pos == 0)
throw std::invalid_argument{"Could not find numeric port in address string"};
if (pos == std::string_view::npos)
pos = addr.size();
size_t processed;
int port_int = std::stoi(std::string{addr.substr(0, pos)}, &processed);
if (port_int == 0 || processed != pos)
throw std::invalid_argument{"Could not parse numeric port in address string"};
if (port_int < 0 || port_int > std::numeric_limits<uint16_t>::max())
throw std::invalid_argument{"Invalid port: port must be in range 1-65535"};
std::get<1>(result) = static_cast<uint16_t>(port_int);
addr.remove_prefix(pos);
if (expect_pubkey) {
if (addr.size() < 1 + enc_length(qr ? address::encoding::base32z : address::encoding::base64)
|| addr.front() != '/')
throw std::invalid_argument{"Invalid address: expected /PUBKEY after port"};
addr.remove_prefix(1);
std::get<2>(result) = decode_pubkey(addr, qr);
if (!addr.empty())
throw std::invalid_argument{"Invalid address: found unexpected trailing data after pubkey"};
} else if (!addr.empty()) {
throw std::invalid_argument{"Invalid address: found unexpected trailing data after port"};
}
return result;
}
// Parse the socket path and (possibly) pubkey, mutating it to remove the parsed sections.
// Currently the /pubkey *must* be at the end of the string, but this might not always be the case
// (e.g. we could in the future support query string-like arguments).
std::pair<std::string, std::string> parse_unix(std::string_view& addr, bool expect_pubkey) {
std::pair<std::string, std::string> result;
if (expect_pubkey) {
size_t b64_len = addr.size() > 0 && addr.back() == '=' ? 44 : 43;
if (addr.size() > 64 && addr[addr.size() - 65] == '/' && is_hex(addr.substr(addr.size() - 64))) {
result.first = std::string{addr.substr(0, addr.size() - 65)};
result.second = from_hex(addr.substr(addr.size() - 64));
} else if (addr.size() > 52 && addr[addr.size() - 53] == '/' && is_base32z(addr.substr(addr.size() - 52))) {
result.first = std::string{addr.substr(0, addr.size() - 53)};
result.second = from_base32z(addr.substr(addr.size() - 52));
} else if (addr.size() > b64_len && addr[addr.size() - b64_len - 1] == '/' && is_base64(addr.substr(addr.size() - b64_len))) {
result.first = std::string{addr.substr(0, addr.size() - b64_len - 1)};
result.second = from_base64(addr.substr(addr.size() - b64_len));
} else {
throw std::invalid_argument{"icp+curve:// requires a trailing /PUBKEY value, got: " + std::string{addr}};
}
} else {
// Anything goes
result.first = std::string{addr};
}
// Any path above consumes everything:
addr.remove_prefix(addr.size());
return result;
}
address::address(std::string_view addr) {
auto protoend = addr.find("://"sv);
if (protoend == std::string_view::npos || protoend == 0)
throw std::invalid_argument("Invalid address: no protocol found");
auto pro = addr.substr(0, protoend);
addr.remove_prefix(protoend + 3);
if (addr.empty())
throw std::invalid_argument("Invalid address: no value specified after protocol");
bool qr = false;
if (pro == "tcp") protocol = proto::tcp;
else if (pro == "tcp+curve" || pro == "curve") protocol = proto::tcp_curve;
else if (pro == "ipc") protocol = proto::ipc;
else if (pro == "ipc+curve") protocol = proto::ipc_curve;
else if (pro == "TCP") {
protocol = proto::tcp;
qr = true;
} else if (pro == "CURVE") {
protocol = proto::tcp_curve;
qr = true;
} else {
throw std::invalid_argument("Invalid protocol '" + std::string{pro} + "'");
}
if (qr) {
// The QR variations only allow QR-alphanumeric characters (upper-case letters, numbers, and
// a few symbols):
for (char c : addr) {
// QR alphanumeric also allows space, %, *, +, but we don't need or allow any of those here.
if (!((c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '$' || c == ':' || c == '/' || c == '.' || c == '-'))
throw std::invalid_argument("Found non-QR-alphanumeric value in QR TCP:// or CURVE:// address");
}
}
if (tcp())
std::tie(host, port, pubkey) = parse_tcp(addr, qr, curve());
else
std::tie(socket, pubkey) = parse_unix(addr, curve());
if (!addr.empty())
throw std::invalid_argument{"Invalid trailing garbage '" + std::string{addr} + "' in address"};
}
address& address::set_pubkey(std::string_view pk) {
if (pk.size() == 0) {
if (protocol == proto::tcp_curve) protocol = proto::tcp;
else if (protocol == proto::ipc_curve) protocol = proto::ipc;
} else if (pk.size() == 32) {
if (protocol == proto::tcp) protocol = proto::tcp_curve;
else if (protocol == proto::ipc) protocol = proto::ipc_curve;
} else {
throw std::invalid_argument{"Invalid pubkey passed to set_pubkey(): require 0- or 32-byte pubkey"};
}
pubkey = pk;
return *this;
}
std::string address::encode_pubkey(encoding enc) const {
std::string pk;
if (enc == encoding::hex)
pk = to_hex(pubkey);
else if (enc == encoding::base32z)
pk = to_base32z(pubkey);
else if (enc == encoding::BASE32Z) {
pk = to_base32z(pubkey);
for (char& c : pk)
if (c >= 'a' && c <= 'z')
c = c - 'a' + 'A';
} else if (enc == encoding::base64) {
pk = to_base64(pubkey);
if (pk.size() == 44 && pk.back() == '=')
pk.resize(43);
} else {
throw std::logic_error{"Invalid encoding"};
}
return pk;
}
std::string address::full_address(encoding enc) const {
std::string result;
std::string pk;
if (curve())
pk = encode_pubkey(enc);
if (protocol == proto::tcp) {
result.reserve(6 /*tcp:// */ + host.size() + 6 /*:port*/);
result += "tcp://";
result += host;
result += ':';
result += std::to_string(port);
} else if (protocol == proto::tcp_curve) {
result.reserve(8 /*curve:// */ + host.size() + 6 /*:port*/ + 1 /* / */ + pk.size());
result += "curve://";
result += host;
result += ':';
result += std::to_string(port);
result += '/';
result += pk;
} else if (protocol == proto::ipc) {
result.reserve(6 /*ipc:// */ + socket.size());
result += "ipc://";
result += socket;
} else if (protocol == proto::ipc_curve) {
result.reserve(12 /*ipc+curve:// */ + socket.size() + 1 /* / */ + pk.size());
result += "ipc+curve://";
result += socket;
result += '/';
result += pk;
} else {
throw std::logic_error{"Invalid protocol"};
}
return result;
}
std::string address::zmq_address() const {
std::string result;
if (tcp()) {
result.reserve(6 /*tcp:// */ + host.size() + 6 /*:port*/);
result += "tcp://";
result += host;
result += ':';
result += std::to_string(port);
} else {
result.reserve(6 /*ipc:// */ + socket.size());
result += "ipc://";
result += socket;
}
return result;
}
std::string address::qr_address() const {
if (protocol != proto::tcp && protocol != proto::tcp_curve)
throw std::logic_error("Cannot construct a QR-friendly address for a non-TCP address");
if (host.empty())
throw std::logic_error("Cannot construct a QR-friendly address with an empty TCP host");
std::string result;
result.reserve((curve() ? 8 /*CURVE:// */ : 6 /*TCP:// */) + host.size() + 6 /*:port*/ +
(curve() ? 1 + enc_length(encoding::BASE32Z) : 0));
result += curve() ? "CURVE://" : "TCP://";
std::string uc_host = host;
for (auto& c : uc_host)
if (c >= 'a' && c <= 'z')
c = c - 'a' + 'A';
if (uc_host.front() == '[' && uc_host.back() == ']') {
uc_host.front() = '$';
uc_host.back() = '$';
}
result += uc_host;
result += ':';
result += std::to_string(port);
if (curve()) {
result += '/';
result += encode_pubkey(encoding::BASE32Z);
}
return result;
}
bool address::operator==(const address& other) const {
if (protocol != other.protocol)
return false;
if (tcp())
if (host != other.host || port != other.port)
return false;
if (ipc())
if (socket != other.socket)
return false;
if (curve())
if (pubkey != other.pubkey)
return false;
return true;
}
address address::tcp(std::string host, uint16_t port) {
address a;
a.protocol = proto::tcp;
a.host = std::move(host);
a.port = port;
return a;
}
address address::tcp_curve(std::string host, uint16_t port, std::string pubkey) {
address a;
a.protocol = proto::tcp_curve;
a.host = std::move(host);
a.port = port;
a.pubkey = std::move(pubkey);
return a;
}
address address::ipc(std::string path) {
address a;
a.protocol = proto::ipc;
a.socket = std::move(path);
return a;
}
address address::ipc_curve(std::string path, std::string pubkey) {
address a;
a.protocol = proto::ipc_curve;
a.socket = std::move(path);
a.pubkey = std::move(pubkey);
return a;
}
std::ostream& operator<<(std::ostream& o, const address& a) { return o << a.full_address(); }
}

210
lokimq/address.h Normal file
View File

@ -0,0 +1,210 @@
// Copyright (c) 2020, The Loki Project
//
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without modification, are
// permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this list of
// conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice, this list
// of conditions and the following disclaimer in the documentation and/or other
// materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its contributors may be
// used to endorse or promote products derived from this software without specific
// prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL
// THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF
// THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include <string>
#include <string_view>
#include <cstdint>
#include <iosfwd>
namespace lokimq {
using namespace std::literals;
/** LokiMQ address abstraction class. This class uses and extends standard ZMQ addresses allowing
* extra parameters to be passed in in a relative standard way.
*
* External ZMQ addresses generally have two forms that we are concerned with: one for TCP and one
* for Unix sockets:
*
* tcp://HOST:PORT -- HOST can be a hostname, IPv4 address, or IPv6 address in [...]
* ipc://PATH -- PATH can be absolute (ipc:///path/to/some.sock) or relative (ipc://some.sock)
*
* but this doesn't carry enough info: in particular, we can connect with two very different
* protocols: curve25519-encrypted, or plaintext, but for curve25519-encrypted we require the
* remote's public key as well to verify the connection.
*
* This class, then, handles this by allowing addresses of:
*
* Standard ZMQ address: these carry no pubkey and so the connection will be unencrypted:
*
* tcp://HOSTNAME:PORT
* ipc://PATH
*
* Non-ZMQ address formats that specify that the connection shall be x25519 encrypted:
*
* curve://HOSTNAME:PORT/PUBKEY -- PUBKEY must be specified in hex (64 characters), base32z (52)
* or base64 (43 or 44 with one '=' trailing padding)
* ipc+curve:///path/to/my.sock/PUBKEY -- same requirements on PUBKEY as above.
* tcp+curve://(whatever) -- alias for curve://(whatever)
*
* We also accept special upper-case TCP-only variants which *only* accept uppercase characters and
* a few required symbols (:, /, $, ., and -) in the string:
*
* TCP://HOSTNAME:PORT
* CURVE://HOSTNAME:PORT/B32ZPUBKEY
*
* These versions are explicitly meant to be used with QR codes; the upper-case-only requirement
* allows a smaller QR code by allowing QR's alphanumeric mode (which allows only [A-Z0-9 $%*+./:-])
* to be used. Such a QR-friendly address can be created from the qr_address() method. To support
* literal IPv6 addresses we surround the address with $...$ instead of the usual [...].
*
* Note that this class does very little validate the host argument at all, and no socket path
* validation whatsoever. The only constraint on host is when parsing an encoded address: we check
* that it contains no : at all, or must be a [bracketed] expression that contains only hex
* characters, :'s, or .'s. Otherwise, if you pass broken crap into the hostname, expect broken
* crap out.
*/
struct address {
/// Supported address protocols: TCP connections (tcp), or unix sockets (ipc).
enum class proto {
tcp,
tcp_curve,
ipc,
ipc_curve
};
/// Supported public key encodings (used when regenerating an augmented address).
enum class encoding {
hex, ///< hexadecimal encoded
base32z, ///< base32z encoded
base64, ///< base64 encoded (*without* trailing = padding)
BASE32Z ///< upper-case base32z encoding, meant for QR encoding
};
/// The protocol: one of the `protocol` enum values for tcp or ipc (unix sockets), with or
/// without _curve encryption.
proto protocol = proto::tcp;
/// The host for tcp connections; can be a hostname or IP address. If this is an IPv6 it must be surrounded with [ ].
std::string host;
/// The port (for tcp connections)
uint16_t port = 0;
/// The socket path (for unix socket connections)
std::string socket;
/// If a curve connection, this is the required remote public key (in bytes)
std::string pubkey;
/// Default constructor; this gives you an unusable address.
address() = default;
/**
* Constructs an address by parsing a string_view containing one of the formats listed in the
* class description. This is intentionally implicitly constructible so that you can pass a
* string_view into anything expecting an `address`.
*
* Throw std::invalid_argument if the given address is not parseable.
*/
address(std::string_view addr);
/** Constructs an address from a remote string and a separate pubkey. Typically `remote` is a
* basic ZMQ connect string, though this is not enforced. Any pubkey information embedded in
* the remote string will be discarded and replaced with the given pubkey string. The result
* will be curve encrypted if `pubkey` is non-empty, plaintext if `pubkey` is empty.
*
* Throws an exception if either addr or pubkey is invalid.
*
* Exactly equivalent to `address a{remote}; a.set_pubkey(pubkey);`
*/
address(std::string_view addr, std::string_view pubkey) : address(addr) { set_pubkey(pubkey); }
/// Replaces the address's pubkey (if any) with the given pubkey (or no pubkey if empty). If
/// changing from pubkey to no-pubkey or no-pubkey to pubkey then the protocol is update to
/// switch to or from curve encryption.
///
/// pubkey should be the 32-byte binary pubkey, or an empty string to remove an existing pubkey.
///
/// Returns the object itself, so that you can chain it.
address& set_pubkey(std::string_view pubkey);
/// Constructs and builds the ZMQ connection address from the stored connection details. This
/// does not contain any of the curve-related details; those must be specified separately when
/// interfacing with ZMQ.
std::string zmq_address() const;
/// Returns true if the connection was specified as a curve-encryption-enabled connection, false
/// otherwise.
bool curve() const { return protocol == proto::tcp_curve || protocol == proto::ipc_curve; }
/// True if the protocol is TCP (either with or without curve)
bool tcp() const { return protocol == proto::tcp || protocol == proto::tcp_curve; }
/// True if the protocol is unix socket (either with or without curve)
bool ipc() const { return !tcp(); }
/// Returns the full "augmented" address string (i.e. that could be passed in to the
/// constructor). This will be equivalent (but not necessarily identical) to an augmented
/// string passed into the constructor. Takes an optional encoding format for the pubkey (if
/// any), which defaults to base32z.
std::string full_address(encoding enc = encoding::base32z) const;
/// Returns a QR-code friendly address string. This returns an all-uppercase version of the
/// address with "TCP://" or "CURVE://" for the protocol string, and uses upper-case base32z
/// encoding for the pubkey (for curve addresses). For literal IPv6 addresses we replace the
/// surround the
/// address with $ instead of $
///
/// \throws std::logic_error if called on a unix socket address.
std::string qr_address() const;
/// Returns `.pubkey` but encoded in the given format
std::string encode_pubkey(encoding enc) const;
/// Returns true if two addresses are identical (i.e. same protocol and relevant protocol
/// arguments).
///
/// Note that it is possible for addresses to connect to the same socket without being
/// identical: for example, using "foo.sock" and "./foo.sock", or writing IPv6 addresses (or
/// even IPv4 addresses) in slightly different ways). Such equivalent but non-equal values will
/// result in a false return here.
///
/// Note also that we ignore irrelevant arguments: for example, we don't care whether pubkeys
/// match when comparing two non-curve TCP addresses.
bool operator==(const address& other) const;
/// Negation of ==
bool operator!=(const address& other) const { return !operator==(other); }
/// Factory function that constructs a TCP address from a host and port. The connection will be
/// plaintext. If the host is an IPv6 address it *must* be surrounded with [ and ].
static address tcp(std::string host, uint16_t port);
/// Factory function that constructs a curve-encrypted TCP address from a host, port, and remote
/// pubkey. The pubkey must be 32 bytes. As above, IPv6 addresses must be specified as [addr].
static address tcp_curve(std::string host, uint16_t, std::string pubkey);
/// Factory function that constructs a unix socket address from a path. The connection will be
/// plaintext (which is usually fine for a socket since unix sockets are local machine).
static address ipc(std::string path);
/// Factory function that constructs a unix socket address from a path and remote pubkey. The
/// connection will be curve25519 encrypted; the remote pubkey must be 32 bytes.
static address ipc_curve(std::string path, std::string pubkey);
};
// Outputs address.full_address() when sent to an ostream.
std::ostream& operator<<(std::ostream& o, const address& a);
}

View File

@ -1,6 +1,8 @@
#include "lokimq.h"
#include "hex.h"
#include "lokimq-internal.h"
#include <ostream>
#include <sstream>
namespace lokimq {
@ -12,7 +14,7 @@ namespace {
// Builds a ZMTP metadata key-value pair. These will be available on every message from that peer.
// Keys must start with X- and be <= 255 characters.
std::string zmtp_metadata(string_view key, string_view value) {
std::string zmtp_metadata(std::string_view key, std::string_view value) {
assert(key.size() > 2 && key.size() <= 255 && key[0] == 'X' && key[1] == '-');
std::string result;
@ -63,7 +65,7 @@ bool LokiMQ::proxy_check_auth(size_t conn_index, bool outgoing, const peer_info&
msgs.push_back(create_message(peer.route));
msgs.push_back(create_message(reply));
if (cat_call.second && cat_call.second->second /*request command*/ && !data.empty()) {
msgs.push_back(create_message("REPLY"_sv));
msgs.push_back(create_message("REPLY"sv));
msgs.push_back(create_message(view(data.front()))); // reply tag
} else {
msgs.push_back(create_message(view(cmd)));
@ -87,7 +89,7 @@ void LokiMQ::set_active_sns(pubkey_set pubkeys) {
proxy_set_active_sns(std::move(pubkeys));
}
}
void LokiMQ::proxy_set_active_sns(string_view data) {
void LokiMQ::proxy_set_active_sns(std::string_view data) {
proxy_set_active_sns(detail::deserialize_object<pubkey_set>(bt_deserialize<uintptr_t>(data)));
}
void LokiMQ::proxy_set_active_sns(pubkey_set pubkeys) {
@ -204,7 +206,7 @@ void LokiMQ::process_zap_requests() {
else
o << v;
}
log_(LogLevel::trace, __FILE__, __LINE__, o.str());
log(LogLevel::trace, __FILE__, __LINE__, o.str());
} else
#endif
LMQ_LOG(debug, "Processing ZAP authentication request");
@ -267,7 +269,7 @@ void LokiMQ::process_zap_requests() {
status_text = "Invalid public key size for CURVE authentication";
} else {
auto ip = view(frames[3]);
string_view pubkey;
std::string_view pubkey;
bool sn = false;
if (bind[bind_id].second.curve) {
pubkey = view(frames[6]);

View File

@ -1,5 +1,5 @@
#pragma once
#include <iostream>
#include <iosfwd>
#include <string>
#include <cstring>
#include <unordered_set>

203
lokimq/base32z.h Normal file
View File

@ -0,0 +1,203 @@
// Copyright (c) 2019-2020, The Loki Project
//
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without modification, are
// permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this list of
// conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice, this list
// of conditions and the following disclaimer in the documentation and/or other
// materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its contributors may be
// used to endorse or promote products derived from this software without specific
// prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL
// THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF
// THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include <string>
#include <string_view>
#include <array>
#include <iterator>
#include <cassert>
namespace lokimq {
namespace detail {
/// Compile-time generated lookup tables for base32z conversion. This is case insensitive (though
/// for byte -> b32z conversion we always produce lower case).
struct b32z_table {
// Store the 0-31 decoded value of every possible char; all the chars that aren't valid are set
// to 0. (If you don't trust your data, check it with is_base32z first, which uses these 0's
// to detect invalid characters -- which is why we want a full 256 element array).
char from_b32z_lut[256];
// Store the encoded character of every 0-31 (5 bit) value.
char to_b32z_lut[32];
// constexpr constructor that fills out the above (and should do it at compile time for any half
// decent compiler).
constexpr b32z_table() noexcept : from_b32z_lut{},
to_b32z_lut{
'y', 'b', 'n', 'd', 'r', 'f', 'g', '8', 'e', 'j', 'k', 'm', 'c', 'p', 'q', 'x',
'o', 't', '1', 'u', 'w', 'i', 's', 'z', 'a', '3', '4', '5', 'h', '7', '6', '9'
}
{
for (unsigned char c = 0; c < 32; c++) {
unsigned char x = to_b32z_lut[c];
from_b32z_lut[x] = c;
if (x >= 'a' && x <= 'z')
from_b32z_lut[x - 'a' + 'A'] = c;
}
}
// Convert a b32z encoded character into a 0-31 value
constexpr char from_b32z(unsigned char c) const noexcept { return from_b32z_lut[c]; }
// Convert a 0-31 value into a b32z encoded character
constexpr char to_b32z(unsigned char b) const noexcept { return to_b32z_lut[b]; }
} constexpr b32z_lut;
// This main point of this static assert is to force the compiler to compile-time build the constexpr tables.
static_assert(b32z_lut.from_b32z('w') == 20 && b32z_lut.from_b32z('T') == 17 && b32z_lut.to_b32z(5) == 'f', "");
} // namespace detail
/// Converts bytes into a base32z encoded character sequence.
template <typename InputIt, typename OutputIt>
void to_base32z(InputIt begin, InputIt end, OutputIt out) {
static_assert(sizeof(decltype(*begin)) == 1, "to_base32z requires chars/bytes");
int bits = 0; // Tracks the number of unconsumed bits held in r, will always be in [0, 4]
std::uint_fast16_t r = 0;
while (begin != end) {
r = r << 8 | static_cast<unsigned char>(*begin++);
// we just added 8 bits, so we can *always* consume 5 to produce one character, so (net) we
// are adding 3 bits.
bits += 3;
*out++ = detail::b32z_lut.to_b32z(r >> bits); // Right-shift off the bits we aren't consuming right now
// Drop the bits we don't want to keep (because we just consumed them)
r &= (1 << bits) - 1;
if (bits >= 5) { // We have enough bits to produce a second character; essentially the same as above
bits -= 5; // Except now we are just consuming 5 without having added any more
*out++ = detail::b32z_lut.to_b32z(r >> bits);
r &= (1 << bits) - 1;
}
}
if (bits > 0) // We hit the end, but still have some unconsumed bits so need one final character to append
*out++ = detail::b32z_lut.to_b32z(r << (5 - bits));
}
/// Creates a base32z string from an iterator pair of a byte sequence.
template <typename It>
std::string to_base32z(It begin, It end) {
std::string base32z;
if constexpr (std::is_base_of_v<std::random_access_iterator_tag, typename std::iterator_traits<It>::iterator_category>)
base32z.reserve((std::distance(begin, end)*8 + 4) / 5); // == bytes*8/5, rounded up.
to_base32z(begin, end, std::back_inserter(base32z));
return base32z;
}
/// Creates a base32z string from an iterable, std::string-like object
template <typename CharT>
std::string to_base32z(std::basic_string_view<CharT> s) { return to_base32z(s.begin(), s.end()); }
inline std::string to_base32z(std::string_view s) { return to_base32z<>(s); }
/// Returns true if all elements in the range are base32z characters
template <typename It>
constexpr bool is_base32z(It begin, It end) {
static_assert(sizeof(decltype(*begin)) == 1, "is_base32z requires chars/bytes");
for (; begin != end; ++begin) {
auto c = static_cast<unsigned char>(*begin);
if (detail::b32z_lut.from_b32z(c) == 0 && !(c == 'y' || c == 'Y'))
return false;
}
return true;
}
/// Returns true if all elements in the string-like value are base32z characters
template <typename CharT>
constexpr bool is_base32z(std::basic_string_view<CharT> s) { return is_base32z(s.begin(), s.end()); }
constexpr bool is_base32z(std::string_view s) { return is_base32z<>(s); }
/// Converts a sequence of base32z digits to bytes. Undefined behaviour if any characters are not
/// valid base32z alphabet characters. It is permitted for the input and output ranges to overlap
/// as long as `out` is no earlier than `begin`. Note that if you pass in a sequence that could not
/// have been created by a base32z encoding of a byte sequence, we treat the excess bits as if they
/// were not provided.
///
/// For example, "yyy" represents a 15-bit value, but a byte sequence is either 8-bit (requiring 2
/// characters) or 16-bit (requiring 4). Similarly, "yb" is an impossible encoding because it has
/// its 10th bit set (b = 00001), but a base32z encoded value should have all 0's beyond the 8th (or
/// 16th or 24th or ... bit). We treat any such bits as if they were not specified (even if they
/// are): which means "yy", "yb", "yyy", "yy9", "yd", etc. all decode to the same 1-byte value "\0".
template <typename InputIt, typename OutputIt>
void from_base32z(InputIt begin, InputIt end, OutputIt out) {
static_assert(sizeof(decltype(*begin)) == 1, "from_base32z requires chars/bytes");
uint_fast16_t curr = 0;
int bits = 0; // number of bits we've loaded into val; we always keep this < 8.
while (begin != end) {
curr = curr << 5 | detail::b32z_lut.from_b32z(static_cast<unsigned char>(*begin++));
if (bits >= 3) {
bits -= 3; // Added 5, removing 8
*out++ = static_cast<uint8_t>(curr >> bits);
curr &= (1 << bits) - 1;
} else {
bits += 5;
}
}
// Ignore any trailing bits. base32z encoding always has at least as many bits as the source
// bytes, which means we should not be able to get here from a properly encoded b32z value with
// anything other than 0s: if we have no extra bits (e.g. 5 bytes == 8 b32z chars) then we have
// a 0-bit value; if we have some extra bits (e.g. 6 bytes requires 10 b32z chars, but that
// contains 50 bits > 48 bits) then those extra bits will be 0s (and this covers the bits -= 3
// case above: it'll leave us with 0-4 extra bits, but those extra bits would be 0 if produced
// from an actual byte sequence).
//
// The "bits += 5" case, then, means that we could end with 5-7 bits. This, however, cannot be
// produced by a valid encoding:
// - 0 bytes gives us 0 chars with 0 leftover bits
// - 1 byte gives us 2 chars with 2 leftover bits
// - 2 bytes gives us 4 chars with 4 leftover bits
// - 3 bytes gives us 5 chars with 1 leftover bit
// - 4 bytes gives us 7 chars with 3 leftover bits
// - 5 bytes gives us 8 chars with 0 leftover bits (this is where the cycle repeats)
//
// So really the only way we can get 5-7 leftover bits is if you took a 0, 2 or 5 char output (or
// any 8n + {0,2,5} char output) and added a base32z character to the end. If you do that,
// well, too bad: you're giving invalid output and so we're just going to pretend that extra
// character you added isn't there by not doing anything here.
}
/// Convert a base32z sequence into a std::string of bytes. Undefined behaviour if any characters
/// are not valid (case-insensitive) base32z characters.
template <typename It>
std::string from_base32z(It begin, It end) {
std::string bytes;
if constexpr (std::is_base_of_v<std::random_access_iterator_tag, typename std::iterator_traits<It>::iterator_category>)
bytes.reserve((std::distance(begin, end)*5 + 7) / 8); // == chars*5/8, rounded up.
from_base32z(begin, end, std::back_inserter(bytes));
return bytes;
}
/// Converts base32z digits from a std::string-like object into a std::string of bytes. Undefined
/// behaviour if any characters are not valid (case-insensitive) base32z characters.
template <typename CharT>
std::string from_base32z(std::basic_string_view<CharT> s) { return from_base32z(s.begin(), s.end()); }
inline std::string from_base32z(std::string_view s) { return from_base32z<>(s); }
}

219
lokimq/base64.h Normal file
View File

@ -0,0 +1,219 @@
// Copyright (c) 2019-2020, The Loki Project
//
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without modification, are
// permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this list of
// conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice, this list
// of conditions and the following disclaimer in the documentation and/or other
// materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its contributors may be
// used to endorse or promote products derived from this software without specific
// prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL
// THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF
// THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include <string>
#include <string_view>
#include <array>
#include <iterator>
#include <cassert>
namespace lokimq {
namespace detail {
/// Compile-time generated lookup tables for base64 conversion.
struct b64_table {
// Store the 0-63 decoded value of every possible char; all the chars that aren't valid are set
// to 0. (If you don't trust your data, check it with is_base64 first, which uses these 0's
// to detect invalid characters -- which is why we want a full 256 element array).
char from_b64_lut[256];
// Store the encoded character of every 0-63 (6 bit) value.
char to_b64_lut[64];
// constexpr constructor that fills out the above (and should do it at compile time for any half
// decent compiler).
constexpr b64_table() noexcept : from_b64_lut{}, to_b64_lut{} {
for (unsigned char c = 0; c < 26; c++) {
from_b64_lut[(unsigned char)('A' + c)] = 0 + c;
to_b64_lut[ (unsigned char)( 0 + c)] = 'A' + c;
}
for (unsigned char c = 0; c < 26; c++) {
from_b64_lut[(unsigned char)('a' + c)] = 26 + c;
to_b64_lut[ (unsigned char)(26 + c)] = 'a' + c;
}
for (unsigned char c = 0; c < 10; c++) {
from_b64_lut[(unsigned char)('0' + c)] = 52 + c;
to_b64_lut[ (unsigned char)(52 + c)] = '0' + c;
}
to_b64_lut[62] = '+'; from_b64_lut[(unsigned char) '+'] = 62;
to_b64_lut[63] = '/'; from_b64_lut[(unsigned char) '/'] = 63;
}
// Convert a b64 encoded character into a 0-63 value
constexpr char from_b64(unsigned char c) const noexcept { return from_b64_lut[c]; }
// Convert a 0-31 value into a b64 encoded character
constexpr char to_b64(unsigned char b) const noexcept { return to_b64_lut[b]; }
} constexpr b64_lut;
// This main point of this static assert is to force the compiler to compile-time build the constexpr tables.
static_assert(b64_lut.from_b64('/') == 63 && b64_lut.from_b64('7') == 59 && b64_lut.to_b64(38) == 'm', "");
} // namespace detail
/// Converts bytes into a base64 encoded character sequence.
template <typename InputIt, typename OutputIt>
void to_base64(InputIt begin, InputIt end, OutputIt out) {
static_assert(sizeof(decltype(*begin)) == 1, "to_base64 requires chars/bytes");
int bits = 0; // Tracks the number of unconsumed bits held in r, will always be in {0, 2, 4}
std::uint_fast16_t r = 0;
while (begin != end) {
r = r << 8 | static_cast<unsigned char>(*begin++);
// we just added 8 bits, so we can *always* consume 6 to produce one character, so (net) we
// are adding 2 bits.
bits += 2;
*out++ = detail::b64_lut.to_b64(r >> bits); // Right-shift off the bits we aren't consuming right now
// Drop the bits we don't want to keep (because we just consumed them)
r &= (1 << bits) - 1;
if (bits == 6) { // We have enough bits to produce a second character (which means we had 4 before and added 8)
bits = 0;
*out++ = detail::b64_lut.to_b64(r);
r = 0;
}
}
// If bits == 0 then we ended our 6-bit outputs coinciding with 8-bit values, i.e. at a multiple
// of 24 bits: this means we don't have anything else to output and don't need any padding.
if (bits == 2) {
// We finished with 2 unconsumed bits, which means we ended 1 byte past a 24-bit group (e.g.
// 1 byte, 4 bytes, 301 bytes, etc.); since we need to always be a multiple of 4 output
// characters that means we've produced 1: so we right-fill 0s to get the next char, then
// add two padding ='s.
*out++ = detail::b64_lut.to_b64(r << 4);
*out++ = '=';
*out++ = '=';
} else if (bits == 4) {
// 4 bits left means we produced 2 6-bit values from the first 2 bytes of a 3-byte group.
// Fill 0s to get the last one, plus one padding output.
*out++ = detail::b64_lut.to_b64(r << 2);
*out++ = '=';
}
}
/// Creates and returns a base64 string from an iterator pair of a character sequence
template <typename It>
std::string to_base64(It begin, It end) {
std::string base64;
if constexpr (std::is_base_of_v<std::random_access_iterator_tag, typename std::iterator_traits<It>::iterator_category>)
base64.reserve((std::distance(begin, end) + 2) / 3 * 4); // bytes*4/3, rounded up to the next multiple of 4
to_base64(begin, end, std::back_inserter(base64));
return base64;
}
/// Creates a base64 string from an iterable, std::string-like object
template <typename CharT>
std::string to_base64(std::basic_string_view<CharT> s) { return to_base64(s.begin(), s.end()); }
inline std::string to_base64(std::string_view s) { return to_base64<>(s); }
/// Returns true if the range is a base64 encoded value; we allow (but do not require) '=' padding,
/// but only at the end, only 1 or 2, and only if it pads out the total to a multiple of 4.
template <typename It>
constexpr bool is_base64(It begin, It end) {
static_assert(sizeof(decltype(*begin)) == 1, "is_base64 requires chars/bytes");
using std::distance;
using std::prev;
// Allow 1 or 2 padding chars *if* they pad it to a multiple of 4.
if (begin != end && distance(begin, end) % 4 == 0) {
auto last = prev(end);
if (static_cast<unsigned char>(*last) == '=')
end = last--;
if (static_cast<unsigned char>(*last) == '=')
end = last;
}
for (; begin != end; ++begin) {
auto c = static_cast<unsigned char>(*begin);
if (detail::b64_lut.from_b64(c) == 0 && c != 'A')
return false;
}
return true;
}
/// Returns true if the string-like value is a base64 encoded value
template <typename CharT>
constexpr bool is_base64(std::basic_string_view<CharT> s) { return is_base64(s.begin(), s.end()); }
constexpr bool is_base64(std::string_view s) { return is_base64(s.begin(), s.end()); }
/// Converts a sequence of base64 digits to bytes. Undefined behaviour if any characters are not
/// valid base64 alphabet characters. It is permitted for the input and output ranges to overlap as
/// long as `out` is no earlier than `begin`. Trailing padding characters are permitted but not
/// required.
///
/// It is possible to provide "impossible" base64 encoded values; for example "YWJja" which has 30
/// bits of data even though a base64 encoded byte string should have 24 (4 chars) or 36 (6 chars)
/// bits for a 3- and 4-byte input, respectively. We ignore any such "impossible" bits, and
/// similarly ignore impossible bits in the bit "overhang"; that means "YWJjZA==" (the proper
/// encoding of "abcd") and "YWJjZB", "YWJjZC", ..., "YWJjZP" all decode to the same "abcd" value:
/// the last 4 bits of the last character are essentially considered padding.
template <typename InputIt, typename OutputIt>
void from_base64(InputIt begin, InputIt end, OutputIt out) {
static_assert(sizeof(decltype(*begin)) == 1, "from_base64 requires chars/bytes");
uint_fast16_t curr = 0;
int bits = 0; // number of bits we've loaded into val; we always keep this < 8.
while (begin != end) {
auto c = static_cast<unsigned char>(*begin++);
// padding; don't bother checking if we're at the end because is_base64 is a precondition
// and we're allowed UB if it isn't satisfied.
if (c == '=') continue;
curr = curr << 6 | detail::b64_lut.from_b64(c);
if (bits == 0)
bits = 6;
else {
bits -= 2; // Added 6, removing 8
*out++ = static_cast<uint8_t>(curr >> bits);
curr &= (1 << bits) - 1;
}
}
// Don't worry about leftover bits because either they have to be 0, or they can't happen at
// all. See base32z.h for why: the reasoning is exactly the same (except using 6 bits per
// character here instead of 5).
}
/// Converts base64 digits from a iterator pair of characters into a std::string of bytes.
/// Undefined behaviour if any characters are not valid base64 characters.
template <typename It>
std::string from_base64(It begin, It end) {
std::string bytes;
if constexpr (std::is_base_of_v<std::random_access_iterator_tag, typename std::iterator_traits<It>::iterator_category>)
bytes.reserve(std::distance(begin, end)*6 / 8); // each digit carries 6 bits; this may overallocate by 1-2 bytes due to padding
from_base64(begin, end, std::back_inserter(bytes));
return bytes;
}
/// Converts base64 digits from a std::string-like object into a std::string of bytes. Undefined
/// behaviour if any characters are not valid base64 characters.
template <typename CharT>
std::string from_base64(std::basic_string_view<CharT> s) { return from_base64(s.begin(), s.end()); }
inline std::string from_base64(std::string_view s) { return from_base64<>(s); }
}

View File

@ -36,18 +36,25 @@ namespace lokimq {
namespace detail {
enum class BatchStatus {
enum class BatchState {
running, // there are still jobs to run (or running)
complete, // the batch is complete but still has a completion job to call
complete_proxy, // same as `complete`, but the completion job should be invoked immediately in the proxy thread (be very careful)
done // the batch is complete and has no completion function
};
struct BatchStatus {
BatchState state;
int thread;
};
// Virtual base class for Batch<R>
class Batch {
public:
// Returns the number of jobs in this batch
virtual size_t size() const = 0;
// Returns the number of jobs in this batch and whether any of them are thread-specific
virtual std::pair<size_t, bool> size() const = 0;
// Returns a vector of exactly the same length of size().first containing the tagged thread ids
// of the batch jobs or 0 for general jobs.
virtual std::vector<int> threads() const = 0;
// Called in a worker thread to run the job
virtual void run_job(int i) = 0;
// Called in the main proxy thread when the worker returns from finishing a job. The return
@ -151,12 +158,13 @@ public:
Batch &operator=(const Batch&) = delete;
private:
std::vector<std::function<R()>> jobs;
std::vector<std::pair<std::function<R()>, int>> jobs;
std::vector<job_result<R>> results;
CompletionFunc complete;
std::size_t jobs_outstanding = 0;
bool complete_in_proxy = false;
int complete_in_thread = 0;
bool started = false;
bool tagged_thread_jobs = false;
void check_not_started() {
if (started)
@ -175,39 +183,61 @@ public:
/// available. The called function may throw exceptions (which will be propagated to the
/// completion function through the job_result values). There is no guarantee on the order of
/// invocation of the jobs.
void add_job(std::function<R()> job) {
///
/// \param job the callback
/// \param thread an optional TaggedThreadID indicating a thread in which this job must run
void add_job(std::function<R()> job, std::optional<TaggedThreadID> thread = std::nullopt) {
check_not_started();
jobs.emplace_back(std::move(job));
results.emplace_back();
jobs_outstanding++;
if (thread && thread->_id == -1)
// There are some special case internal jobs where we allow this, but they use the
// private method below that doesn't have this check.
throw std::logic_error{"Cannot add a proxy thread batch job -- this makes no sense"};
add_job(std::move(job), thread ? thread->_id : 0);
}
/// Sets the completion function to invoke after all jobs have finished. If this is not set
/// then jobs simply run and results are discarded.
void completion(CompletionFunc comp) {
///
/// \param comp - function to call when all jobs have finished
/// \param thread - optional tagged thread in which to schedule the completion job. If not
/// provided then the completion job is scheduled in the pool of batch job threads.
///
/// `thread` can be provided the value &LokiMQ::run_in_proxy to invoke the completion function
/// *IN THE PROXY THREAD* itself after all jobs have finished. Be very, very careful: this
/// should be a nearly trivial job that does not require any substantial CPU time and does not
/// block for any reason. This is only intended for the case where the completion job is so
/// trivial that it will take less time than simply queuing the job to be executed by another
/// thread.
void completion(CompletionFunc comp, std::optional<TaggedThreadID> thread = std::nullopt) {
check_not_started();
if (complete)
throw std::logic_error("Completion function can only be set once");
complete = std::move(comp);
}
/// Sets a completion function to invoke *IN THE PROXY THREAD* after all jobs have finished. Be
/// very, very careful: this should not be a job that takes any significant amount of CPU time
/// or can block for any reason (NO MUTEXES).
void completion_proxy(CompletionFunc comp) {
check_not_started();
if (complete)
throw std::logic_error("Completion function can only be set once");
complete = std::move(comp);
complete_in_proxy = true;
complete_in_thread = thread ? thread->_id : 0;
}
private:
std::size_t size() const override {
return jobs.size();
void add_job(std::function<R()> job, int thread_id) {
jobs.emplace_back(std::move(job), thread_id);
results.emplace_back();
jobs_outstanding++;
if (thread_id != 0)
tagged_thread_jobs = true;
}
std::pair<std::size_t, bool> size() const override {
return {jobs.size(), tagged_thread_jobs};
}
std::vector<int> threads() const override {
std::vector<int> t;
t.reserve(jobs.size());
for (auto& j : jobs)
t.push_back(j.second);
return t;
};
template <typename S = R>
void set_value(job_result<S>& r, std::function<S()>& f) { r.set_value(f()); }
void set_value(job_result<void>&, std::function<void()>& f) { f(); }
@ -216,7 +246,7 @@ private:
// called by worker thread
auto& r = results[i];
try {
set_value(r, jobs[i]);
set_value(r, jobs[i].first);
} catch (...) {
r.set_exception(std::current_exception());
}
@ -225,12 +255,10 @@ private:
detail::BatchStatus job_finished() override {
--jobs_outstanding;
if (jobs_outstanding)
return detail::BatchStatus::running;
return {detail::BatchState::running, 0};
if (complete)
return complete_in_proxy
? detail::BatchStatus::complete_proxy
: detail::BatchStatus::complete;
return detail::BatchStatus::done;
return {detail::BatchState::complete, complete_in_thread};
return {detail::BatchState::done, 0};
}
void job_completion() override {
@ -241,7 +269,7 @@ private:
template <typename R>
void LokiMQ::batch(Batch<R>&& batch) {
if (batch.size() == 0)
if (batch.size().first == 0)
throw std::logic_error("Cannot batch a a job batch with 0 jobs");
// Need to send this over to the proxy thread via the base class pointer. It assumes ownership.
auto* baseptr = static_cast<detail::Batch*>(new Batch<R>(std::move(batch)));

View File

@ -27,12 +27,13 @@
// THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "bt_serialize.h"
#include <iterator>
namespace lokimq {
namespace detail {
/// Reads digits into an unsigned 64-bit int.
uint64_t extract_unsigned(string_view& s) {
uint64_t extract_unsigned(std::string_view& s) {
if (s.empty())
throw bt_deserialize_invalid{"Expected 0-9 but found end of string"};
if (s[0] < '0' || s[0] > '9')
@ -48,7 +49,7 @@ uint64_t extract_unsigned(string_view& s) {
return uval;
}
void bt_deserialize<string_view>::operator()(string_view& s, string_view& val) {
void bt_deserialize<std::string_view>::operator()(std::string_view& s, std::string_view& val) {
if (s.size() < 2) throw bt_deserialize_invalid{"Deserialize failed: given data is not an bt-encoded string"};
if (s[0] < '0' || s[0] > '9')
throw bt_deserialize_invalid_type{"Expected 0-9 but found '"s + s[0] + "'"};
@ -72,37 +73,33 @@ static_assert(std::numeric_limits<int64_t>::min() + std::numeric_limits<int64_t>
static_cast<uint64_t>(std::numeric_limits<int64_t>::max()) + uint64_t{1} == (uint64_t{1} << 63),
"Non 2s-complement architecture not supported!");
std::pair<maybe_signed_int64_t, bool> bt_deserialize_integer(string_view& s) {
std::pair<uint64_t, bool> bt_deserialize_integer(std::string_view& s) {
// Smallest possible encoded integer is 3 chars: "i0e"
if (s.size() < 3) throw bt_deserialize_invalid("Deserialization failed: end of string found where integer expected");
if (s[0] != 'i') throw bt_deserialize_invalid_type("Deserialization failed: expected 'i', found '"s + s[0] + '\'');
s.remove_prefix(1);
std::pair<maybe_signed_int64_t, bool> result;
std::pair<uint64_t, bool> result;
if (s[0] == '-') {
result.second = true;
s.remove_prefix(1);
}
uint64_t uval = extract_unsigned(s);
result.first = extract_unsigned(s);
if (s.empty())
throw bt_deserialize_invalid("Integer deserialization failed: encountered end of string before integer was finished");
if (s[0] != 'e')
throw bt_deserialize_invalid("Integer deserialization failed: expected digit or 'e', found '"s + s[0] + '\'');
s.remove_prefix(1);
if (result.second) { // negative
if (uval > (uint64_t{1} << 63))
throw bt_deserialize_invalid("Deserialization of integer failed: negative integer value is too large for a 64-bit signed int");
result.first.i64 = -uval;
} else {
result.first.u64 = uval;
}
if (result.second /*negative*/ && result.first > (uint64_t{1} << 63))
throw bt_deserialize_invalid("Deserialization of integer failed: negative integer value is too large for a 64-bit signed int");
return result;
}
template struct bt_deserialize<int64_t>;
template struct bt_deserialize<uint64_t>;
void bt_deserialize<bt_value, void>::operator()(string_view& s, bt_value& val) {
void bt_deserialize<bt_value, void>::operator()(std::string_view& s, bt_value& val) {
if (s.size() < 2) throw bt_deserialize_invalid("Deserialization failed: end of string found where bt-encoded value expected");
switch (s[0]) {
@ -119,8 +116,9 @@ void bt_deserialize<bt_value, void>::operator()(string_view& s, bt_value& val) {
break;
}
case 'i': {
auto read = bt_deserialize_integer(s);
val = read.first.i64; // We only store an i64, but can get a u64 out of it via get<uint64_t>(val)
auto [magnitude, negative] = bt_deserialize_integer(s);
if (negative) val = -static_cast<int64_t>(magnitude);
else val = magnitude;
break;
}
case '0': case '1': case '2': case '3': case '4': case '5': case '6': case '7': case '8': case '9': {
@ -137,7 +135,7 @@ void bt_deserialize<bt_value, void>::operator()(string_view& s, bt_value& val) {
} // namespace detail
bt_list_consumer::bt_list_consumer(string_view data_) : data{std::move(data_)} {
bt_list_consumer::bt_list_consumer(std::string_view data_) : data{std::move(data_)} {
if (data.empty()) throw std::runtime_error{"Cannot create a bt_list_consumer with an empty string_view"};
if (data[0] != 'l') throw std::runtime_error{"Cannot create a bt_list_consumer with non-list data"};
data.remove_prefix(1);
@ -145,13 +143,13 @@ bt_list_consumer::bt_list_consumer(string_view data_) : data{std::move(data_)} {
/// Attempt to parse the next value as a string (and advance just past it). Throws if the next
/// value is not a string.
string_view bt_list_consumer::consume_string_view() {
std::string_view bt_list_consumer::consume_string_view() {
if (data.empty())
throw bt_deserialize_invalid{"expected a string, but reached end of data"};
else if (!is_string())
throw bt_deserialize_invalid_type{"expected a string, but found "s + data.front()};
string_view next{data}, result;
detail::bt_deserialize<string_view>{}(next, result);
std::string_view next{data}, result;
detail::bt_deserialize<std::string_view>{}(next, result);
data = next;
return result;
}
@ -174,7 +172,7 @@ void bt_list_consumer::skip_value() {
throw bt_deserialize_invalid_type{"next bt value has unknown type"};
}
string_view bt_list_consumer::consume_list_data() {
std::string_view bt_list_consumer::consume_list_data() {
auto start = data.begin();
if (data.size() < 2 || !is_list()) throw bt_deserialize_invalid_type{"next bt value is not a list"};
data.remove_prefix(1); // Descend into the sublist, consume the "l"
@ -187,7 +185,7 @@ string_view bt_list_consumer::consume_list_data() {
return {start, static_cast<size_t>(std::distance(start, data.begin()))};
}
string_view bt_list_consumer::consume_dict_data() {
std::string_view bt_list_consumer::consume_dict_data() {
auto start = data.begin();
if (data.size() < 2 || !is_dict()) throw bt_deserialize_invalid_type{"next bt value is not a dict"};
data.remove_prefix(1); // Descent into the dict, consumer the "d"
@ -202,7 +200,7 @@ string_view bt_list_consumer::consume_dict_data() {
return {start, static_cast<size_t>(std::distance(start, data.begin()))};
}
bt_dict_consumer::bt_dict_consumer(string_view data_) {
bt_dict_consumer::bt_dict_consumer(std::string_view data_) {
data = std::move(data_);
if (data.empty()) throw std::runtime_error{"Cannot create a bt_dict_consumer with an empty string_view"};
if (data.size() < 2 || data[0] != 'd') throw std::runtime_error{"Cannot create a bt_dict_consumer with non-dict data"};
@ -220,10 +218,10 @@ bool bt_dict_consumer::consume_key() {
return true;
}
std::pair<string_view, string_view> bt_dict_consumer::next_string() {
std::pair<std::string_view, std::string_view> bt_dict_consumer::next_string() {
if (!is_string())
throw bt_deserialize_invalid_type{"expected a string, but found "s + data.front()};
std::pair<string_view, string_view> ret;
std::pair<std::string_view, std::string_view> ret;
ret.second = bt_list_consumer::consume_string_view();
ret.first = flush_key();
return ret;

View File

@ -26,21 +26,25 @@
// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF
// THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <iostream>
#pragma once
#include <vector>
#include <list>
#include <unordered_map>
#include <algorithm>
#include <functional>
#include <cstring>
#include <ostream>
#include <sstream>
#include "string_view.h"
#include "mapbox/variant.hpp"
#include <ostream>
#include <string>
#include <string_view>
#include <variant>
#include <cstdint>
#include <limits>
#include <stdexcept>
#include <type_traits>
#include <utility>
#include <tuple>
#include <algorithm>
#include "bt_value.h"
namespace lokimq {
@ -54,16 +58,17 @@ using namespace std::literals;
*
* On the C++ side, on input we allow strings, integral types, STL-like containers of these types,
* and STL-like containers of pairs with a string first value and any of these types as second
* value. We also accept std::variants (if compiled with std::variant support, i.e. in C++17 mode)
* that contain any of these, and mapbox::util::variants (the internal type used for its recursive
* support).
* value. We also accept std::variants of these.
*
* One minor deviation from BEP-0003 is that we don't support serializing values that don't fit in a
* 64-bit integer (BEP-0003 specifies arbitrary precision integers).
*
* On deserialization we can either deserialize into a mapbox::util::variant that supports everything, or
* we can fill a container of your given type (though this fails if the container isn't compatible
* with the deserialized data).
* On deserialization we can either deserialize into a special bt_value type supports everything
* (with arbitrary nesting), or we can fill a container of your given type (though this fails if the
* container isn't compatible with the deserialized data).
*
* There is also a stream deserialization that allows you to deserialize without needing heap
* allocations (as long as you know the precise data structure layout).
*/
/// Exception throw if deserialization fails
@ -80,107 +85,69 @@ class bt_deserialize_invalid_type : public bt_deserialize_invalid {
using bt_deserialize_invalid::bt_deserialize_invalid;
};
class bt_list;
class bt_dict;
/// Special type wrapper for storing a uint64_t value that may need to be larger than an int64_t.
/// You *can* shove a uint64_t directly into a bt_value, but it will end up on the wire as its
/// 2s-complement int64_t value; using this wrapper instead allows you to force a 64-bit positive
/// integer onto the wire.
struct bt_u64 { uint64_t val; explicit bt_u64(uint64_t val) : val{val} {} };
/// Recursive generic type that can fully represent everything valid for a BT serialization.
using bt_value = mapbox::util::variant<
std::string,
string_view,
int64_t,
bt_u64,
mapbox::util::recursive_wrapper<bt_list>,
mapbox::util::recursive_wrapper<bt_dict>
>;
/// Very thin wrapper around a std::list<bt_value> that holds a list of generic values (though *any*
/// compatible data type can be used).
class bt_list : public std::list<bt_value> {
using std::list<bt_value>::list;
};
/// Very thin wrapper around a std::unordered_map<bt_value> that holds a list of string -> generic
/// value pairs (though *any* compatible data type can be used).
class bt_dict : public std::unordered_map<std::string, bt_value> {
using std::unordered_map<std::string, bt_value>::unordered_map;
};
#ifdef __cpp_lib_void_t
using std::void_t;
#else
/// C++17 void_t backport
template <typename... Ts> struct void_t_impl { using type = void; };
template <typename... Ts> using void_t = typename void_t_impl<Ts...>::type;
#endif
namespace detail {
/// Reads digits into an unsigned 64-bit int.
uint64_t extract_unsigned(string_view& s);
inline uint64_t extract_unsigned(string_view&& s) { return extract_unsigned(s); }
uint64_t extract_unsigned(std::string_view& s);
// (Provide non-constant lvalue and rvalue ref functions so that we only accept explicit
// string_views but not implicitly converted ones)
inline uint64_t extract_unsigned(std::string_view&& s) { return extract_unsigned(s); }
// Fallback base case; we only get here if none of the partial specializations below work
template <typename T, typename SFINAE = void>
struct bt_serialize { static_assert(!std::is_same<T, T>::value, "Cannot serialize T: unsupported type for bt serialization"); };
struct bt_serialize { static_assert(!std::is_same_v<T, T>, "Cannot serialize T: unsupported type for bt serialization"); };
template <typename T, typename SFINAE = void>
struct bt_deserialize { static_assert(!std::is_same<T, T>::value, "Cannot deserialize T: unsupported type for bt deserialization"); };
struct bt_deserialize { static_assert(!std::is_same_v<T, T>, "Cannot deserialize T: unsupported type for bt deserialization"); };
/// Checks that we aren't at the end of a string view and throws if we are.
inline void bt_need_more(const string_view &s) {
inline void bt_need_more(const std::string_view &s) {
if (s.empty())
throw bt_deserialize_invalid{"Unexpected end of string while deserializing"};
}
union maybe_signed_int64_t { int64_t i64; uint64_t u64; };
/// Deserializes a signed or unsigned 64-bit integer from a string. Sets the second bool to true
/// iff the value is int64_t because a negative value was read. Throws an exception if the read
/// value doesn't fit in a int64_t (if negative) or a uint64_t (if positive). Removes consumed
/// characters from the string_view.
std::pair<maybe_signed_int64_t, bool> bt_deserialize_integer(string_view& s);
/// iff the value read was negative, false if positive; in either case the unsigned value is return
/// in .first. Throws an exception if the read value doesn't fit in a int64_t (if negative) or a
/// uint64_t (if positive). Removes consumed characters from the string_view.
std::pair<uint64_t, bool> bt_deserialize_integer(std::string_view& s);
/// Integer specializations
template <typename T>
struct bt_serialize<T, std::enable_if_t<std::is_integral<T>::value>> {
struct bt_serialize<T, std::enable_if_t<std::is_integral_v<T>>> {
static_assert(sizeof(T) <= sizeof(uint64_t), "Serialization of integers larger than uint64_t is not supported");
void operator()(std::ostream &os, const T &val) {
// Cast 1-byte types to a larger type to avoid iostream interpreting them as single characters
using output_type = std::conditional_t<(sizeof(T) > 1), T, std::conditional_t<std::is_signed<T>::value, int, unsigned>>;
using output_type = std::conditional_t<(sizeof(T) > 1), T, std::conditional_t<std::is_signed_v<T>, int, unsigned>>;
os << 'i' << static_cast<output_type>(val) << 'e';
}
};
template <typename T>
struct bt_deserialize<T, std::enable_if_t<std::is_integral<T>::value>> {
void operator()(string_view& s, T &val) {
struct bt_deserialize<T, std::enable_if_t<std::is_integral_v<T>>> {
void operator()(std::string_view& s, T &val) {
constexpr uint64_t umax = static_cast<uint64_t>(std::numeric_limits<T>::max());
constexpr int64_t smin = static_cast<int64_t>(std::numeric_limits<T>::min()),
smax = static_cast<int64_t>(std::numeric_limits<T>::max());
constexpr int64_t smin = static_cast<int64_t>(std::numeric_limits<T>::min());
auto read = bt_deserialize_integer(s);
if (std::is_signed<T>::value) {
if (!read.second) { // read a positive value
if (read.first.u64 > umax)
throw bt_deserialize_invalid("Integer deserialization failed: found too-large value " + std::to_string(read.first.u64) + " > " + std::to_string(umax));
val = static_cast<T>(read.first.u64);
auto [magnitude, negative] = bt_deserialize_integer(s);
if (std::is_signed_v<T>) {
if (!negative) {
if (magnitude > umax)
throw bt_deserialize_invalid("Integer deserialization failed: found too-large value " + std::to_string(magnitude) + " > " + std::to_string(umax));
val = static_cast<T>(magnitude);
} else {
bool oob = read.first.i64 < smin || read.first.i64 > smax;
if (sizeof(T) < sizeof(int64_t) && oob)
throw bt_deserialize_invalid("Integer deserialization failed: found out-of-range value " + std::to_string(read.first.i64) + " not in [" + std::to_string(smin) + "," + std::to_string(smax) + "]");
val = static_cast<T>(read.first.i64);
auto sval = -static_cast<int64_t>(magnitude);
if (!std::is_same_v<T, int64_t> && sval < smin)
throw bt_deserialize_invalid("Integer deserialization failed: found too-low value " + std::to_string(sval) + " < " + std::to_string(smin));
val = static_cast<T>(sval);
}
} else {
if (read.second)
throw bt_deserialize_invalid("Integer deserialization failed: found negative value " + std::to_string(read.first.i64) + " but type is unsigned");
if (sizeof(T) < sizeof(uint64_t) && read.first.u64 > umax)
throw bt_deserialize_invalid("Integer deserialization failed: found too-large value " + std::to_string(read.first.u64) + " > " + std::to_string(umax));
val = static_cast<T>(read.first.u64);
if (negative)
throw bt_deserialize_invalid("Integer deserialization failed: found negative value -" + std::to_string(magnitude) + " but type is unsigned");
if (!std::is_same_v<T, uint64_t> && magnitude > umax)
throw bt_deserialize_invalid("Integer deserialization failed: found too-large value " + std::to_string(magnitude) + " > " + std::to_string(umax));
val = static_cast<T>(magnitude);
}
}
};
@ -188,73 +155,78 @@ struct bt_deserialize<T, std::enable_if_t<std::is_integral<T>::value>> {
extern template struct bt_deserialize<int64_t>;
extern template struct bt_deserialize<uint64_t>;
template<>
struct bt_serialize<bt_u64> { void operator()(std::ostream& os, bt_u64 val) { bt_serialize<uint64_t>{}(os, val.val); } };
template<>
struct bt_deserialize<bt_u64> { void operator()(string_view& s, bt_u64& val) { bt_deserialize<uint64_t>{}(s, val.val); } };
template <>
struct bt_serialize<string_view> {
void operator()(std::ostream &os, const string_view &val) { os << val.size(); os.put(':'); os.write(val.data(), val.size()); }
struct bt_serialize<std::string_view> {
void operator()(std::ostream &os, const std::string_view &val) { os << val.size(); os.put(':'); os.write(val.data(), val.size()); }
};
template <>
struct bt_deserialize<string_view> {
void operator()(string_view& s, string_view& val);
struct bt_deserialize<std::string_view> {
void operator()(std::string_view& s, std::string_view& val);
};
/// String specialization
template <>
struct bt_serialize<std::string> {
void operator()(std::ostream &os, const std::string &val) { bt_serialize<string_view>{}(os, val); }
void operator()(std::ostream &os, const std::string &val) { bt_serialize<std::string_view>{}(os, val); }
};
template <>
struct bt_deserialize<std::string> {
void operator()(string_view& s, std::string& val) { string_view view; bt_deserialize<string_view>{}(s, view); val = {view.data(), view.size()}; }
void operator()(std::string_view& s, std::string& val) { std::string_view view; bt_deserialize<std::string_view>{}(s, view); val = {view.data(), view.size()}; }
};
/// char * and string literals -- we allow serialization for convenience, but not deserialization
template <>
struct bt_serialize<char *> {
void operator()(std::ostream &os, const char *str) { bt_serialize<string_view>{}(os, {str, std::strlen(str)}); }
void operator()(std::ostream &os, const char *str) { bt_serialize<std::string_view>{}(os, {str, std::strlen(str)}); }
};
template <size_t N>
struct bt_serialize<char[N]> {
void operator()(std::ostream &os, const char *str) { bt_serialize<string_view>{}(os, {str, N-1}); }
void operator()(std::ostream &os, const char *str) { bt_serialize<std::string_view>{}(os, {str, N-1}); }
};
/// Partial dict validity; we don't check the second type for serializability, that will be handled
/// via the base case static_assert if invalid.
template <typename T, typename = void> struct is_bt_input_dict_container : std::false_type {};
template <typename T, typename = void> struct is_bt_input_dict_container_impl : std::false_type {};
template <typename T>
struct is_bt_input_dict_container<T, std::enable_if_t<
std::is_same<std::string, std::remove_cv_t<typename T::value_type::first_type>>::value,
void_t<typename T::const_iterator /* is const iterable */,
struct is_bt_input_dict_container_impl<T, std::enable_if_t<
std::is_same_v<std::string, std::remove_cv_t<typename T::value_type::first_type>> ||
std::is_same_v<std::string_view, std::remove_cv_t<typename T::value_type::first_type>>,
std::void_t<typename T::const_iterator /* is const iterable */,
typename T::value_type::second_type /* has a second type */>>>
: std::true_type {};
/// Determines whether the type looks like something we can insert into (using `v.insert(v.end(), x)`)
template <typename T, typename = void> struct is_bt_insertable : std::false_type {};
template <typename T, typename = void> struct is_bt_insertable_impl : std::false_type {};
template <typename T>
struct is_bt_insertable<T,
void_t<decltype(std::declval<T>().insert(std::declval<T>().end(), std::declval<typename T::value_type>()))>>
struct is_bt_insertable_impl<T,
std::void_t<decltype(std::declval<T>().insert(std::declval<T>().end(), std::declval<typename T::value_type>()))>>
: std::true_type {};
template <typename T>
constexpr bool is_bt_insertable = is_bt_insertable_impl<T>::value;
/// Determines whether the given type looks like a compatible map (i.e. has std::string keys) that
/// we can insert into.
template <typename T, typename = void> struct is_bt_output_dict_container : std::false_type {};
template <typename T, typename = void> struct is_bt_output_dict_container_impl : std::false_type {};
template <typename T>
struct is_bt_output_dict_container<T, std::enable_if_t<
std::is_same<std::string, std::remove_cv_t<typename T::key_type>>::value &&
is_bt_insertable<T>::value,
void_t<typename T::value_type::second_type /* has a second type */>>>
struct is_bt_output_dict_container_impl<T, std::enable_if_t<
std::is_same_v<std::string, std::remove_cv_t<typename T::value_type::first_type>> && is_bt_insertable<T>,
std::void_t<typename T::value_type::second_type /* has a second type */>>>
: std::true_type {};
template <typename T>
constexpr bool is_bt_output_dict_container = is_bt_output_dict_container_impl<T>::value;
template <typename T>
constexpr bool is_bt_input_dict_container = is_bt_output_dict_container_impl<T>::value;
// Sanity checks:
static_assert(is_bt_input_dict_container<bt_dict>);
static_assert(is_bt_output_dict_container<bt_dict>);
/// Specialization for a dict-like container (such as an unordered_map). We accept anything for a
/// dict that is const iterable over something that looks like a pair with std::string for first
/// value type. The value (i.e. second element of the pair) also must be serializable.
template <typename T>
struct bt_serialize<T, std::enable_if_t<is_bt_input_dict_container<T>::value>> {
struct bt_serialize<T, std::enable_if_t<is_bt_input_dict_container<T>>> {
using second_type = typename T::value_type::second_type;
using ref_pair = std::reference_wrapper<const typename T::value_type>;
void operator()(std::ostream &os, const T &dict) {
@ -273,9 +245,9 @@ struct bt_serialize<T, std::enable_if_t<is_bt_input_dict_container<T>::value>> {
};
template <typename T>
struct bt_deserialize<T, std::enable_if_t<is_bt_output_dict_container<T>::value>> {
struct bt_deserialize<T, std::enable_if_t<is_bt_output_dict_container<T>>> {
using second_type = typename T::value_type::second_type;
void operator()(string_view& s, T& dict) {
void operator()(std::string_view& s, T& dict) {
// Smallest dict is 2 bytes "de", for an empty dict.
if (s.size() < 2) throw bt_deserialize_invalid("Deserialization failed: end of string found where dict expected");
if (s[0] != 'd') throw bt_deserialize_invalid_type("Deserialization failed: expected 'd', found '"s + s[0] + "'"s);
@ -300,26 +272,31 @@ struct bt_deserialize<T, std::enable_if_t<is_bt_output_dict_container<T>::value>
/// Accept anything that looks iterable; value serialization validity isn't checked here (it fails
/// via the base case static assert).
template <typename T, typename = void> struct is_bt_input_list_container : std::false_type {};
template <typename T, typename = void> struct is_bt_input_list_container_impl : std::false_type {};
template <typename T>
struct is_bt_input_list_container<T, std::enable_if_t<
!std::is_same<T, std::string>::value &&
!is_bt_input_dict_container<T>::value,
void_t<typename T::const_iterator, typename T::value_type>>>
struct is_bt_input_list_container_impl<T, std::enable_if_t<
!std::is_same_v<T, std::string> && !std::is_same_v<T, std::string_view> && !is_bt_input_dict_container<T>,
std::void_t<typename T::const_iterator, typename T::value_type>>>
: std::true_type {};
template <typename T, typename = void> struct is_bt_output_list_container : std::false_type {};
template <typename T, typename = void> struct is_bt_output_list_container_impl : std::false_type {};
template <typename T>
struct is_bt_output_list_container<T, std::enable_if_t<
!std::is_same<T, std::string>::value &&
!is_bt_output_dict_container<T>::value &&
is_bt_insertable<T>::value>>
struct is_bt_output_list_container_impl<T, std::enable_if_t<
!std::is_same_v<T, std::string> && !is_bt_output_dict_container<T> && is_bt_insertable<T>>>
: std::true_type {};
template <typename T>
constexpr bool is_bt_output_list_container = is_bt_output_list_container_impl<T>::value;
template <typename T>
constexpr bool is_bt_input_list_container = is_bt_input_list_container_impl<T>::value;
// Sanity checks:
static_assert(is_bt_input_list_container<bt_list>);
static_assert(is_bt_output_list_container<bt_list>);
/// List specialization
template <typename T>
struct bt_serialize<T, std::enable_if_t<is_bt_input_list_container<T>::value>> {
struct bt_serialize<T, std::enable_if_t<is_bt_input_list_container<T>>> {
void operator()(std::ostream& os, const T& list) {
os << 'l';
for (const auto &v : list)
@ -328,9 +305,9 @@ struct bt_serialize<T, std::enable_if_t<is_bt_input_list_container<T>::value>> {
}
};
template <typename T>
struct bt_deserialize<T, std::enable_if_t<is_bt_output_list_container<T>::value>> {
struct bt_deserialize<T, std::enable_if_t<is_bt_output_list_container<T>>> {
using value_type = typename T::value_type;
void operator()(string_view& s, T& list) {
void operator()(std::string_view& s, T& list) {
// Smallest list is 2 bytes "le", for an empty list.
if (s.size() < 2) throw bt_deserialize_invalid("Deserialization failed: end of string found where list expected");
if (s[0] != 'l') throw bt_deserialize_invalid_type("Deserialization failed: expected 'l', found '"s + s[0] + "'"s);
@ -348,44 +325,88 @@ struct bt_deserialize<T, std::enable_if_t<is_bt_output_list_container<T>::value>
}
};
/// variant visitor; serializes whatever is contained
class bt_serialize_visitor {
std::ostream &os;
/// Serializes a tuple or pair of serializable values (as a list on the wire)
/// Common implementation for both tuple and pair:
template <template<typename...> typename Tuple, typename... T>
struct bt_serialize_tuple {
private:
template <size_t... Is>
void operator()(std::ostream& os, const Tuple<T...>& elems, std::index_sequence<Is...>) {
os << 'l';
(bt_serialize<T>{}(os, std::get<Is>(elems)), ...);
os << 'e';
}
public:
using result_type = void;
bt_serialize_visitor(std::ostream &os) : os{os} {}
template <typename T> void operator()(const T &val) const {
bt_serialize<T>{}(os, val);
void operator()(std::ostream& os, const Tuple<T...>& elems) {
operator()(os, elems, std::index_sequence_for<T...>{});
}
};
template <template<typename...> typename Tuple, typename... T>
struct bt_deserialize_tuple {
private:
template <size_t... Is>
void operator()(std::string_view& s, Tuple<T...>& elems, std::index_sequence<Is...>) {
// Smallest list is 2 bytes "le", for an empty list.
if (s.size() < 2) throw bt_deserialize_invalid("Deserialization failed: end of string found where tuple expected");
if (s[0] != 'l') throw bt_deserialize_invalid_type("Deserialization of tuple failed: expected 'l', found '"s + s[0] + "'"s);
s.remove_prefix(1);
(bt_deserialize<T>{}(s, std::get<Is>(elems)), ...);
if (s.empty())
throw bt_deserialize_invalid("Deserialization failed: encountered end of string before tuple was finished");
if (s[0] != 'e')
throw bt_deserialize_invalid("Deserialization failed: expected end of tuple but found something else");
s.remove_prefix(1); // Consume the 'e'
}
public:
void operator()(std::string_view& s, Tuple<T...>& elems) {
operator()(s, elems, std::index_sequence_for<T...>{});
}
};
template <typename... T>
struct bt_serialize<std::tuple<T...>> : bt_serialize_tuple<std::tuple, T...> {};
template <typename... T>
struct bt_deserialize<std::tuple<T...>> : bt_deserialize_tuple<std::tuple, T...> {};
template <typename S, typename T>
struct bt_serialize<std::pair<S, T>> : bt_serialize_tuple<std::pair, S, T> {};
template <typename S, typename T>
struct bt_deserialize<std::pair<S, T>> : bt_deserialize_tuple<std::pair, S, T> {};
template <typename T>
using is_bt_deserializable = std::integral_constant<bool,
std::is_same<T, std::string>::value || std::is_integral<T>::value ||
is_bt_output_dict_container<T>::value || is_bt_output_list_container<T>::value>;
constexpr bool is_bt_tuple = false;
template <typename... T>
constexpr bool is_bt_tuple<std::tuple<T...>> = true;
template <typename S, typename T>
constexpr bool is_bt_tuple<std::pair<S, T>> = true;
template <typename T>
constexpr bool is_bt_deserializable = std::is_same_v<T, std::string> || std::is_integral_v<T> ||
is_bt_output_dict_container<T> || is_bt_output_list_container<T> || is_bt_tuple<T>;
// General template and base case; this base will only actually be invoked when Ts... is empty,
// which means we reached the end without finding any variant type capable of holding the value.
template <typename SFINAE, typename Variant, typename... Ts>
struct bt_deserialize_try_variant_impl {
void operator()(string_view&, Variant&) {
void operator()(std::string_view&, Variant&) {
throw bt_deserialize_invalid("Deserialization failed: could not deserialize value into any variant type");
}
};
template <typename... Ts, typename Variant>
void bt_deserialize_try_variant(string_view& s, Variant& variant) {
void bt_deserialize_try_variant(std::string_view& s, Variant& variant) {
bt_deserialize_try_variant_impl<void, Variant, Ts...>{}(s, variant);
}
template <typename Variant, typename T, typename... Ts>
struct bt_deserialize_try_variant_impl<std::enable_if_t<is_bt_deserializable<T>::value>, Variant, T, Ts...> {
void operator()(string_view& s, Variant& variant) {
if ( is_bt_output_list_container<T>::value ? s[0] == 'l' :
is_bt_output_dict_container<T>::value ? s[0] == 'd' :
std::is_integral<T>::value ? s[0] == 'i' :
std::is_same<T, std::string>::value ? s[0] >= '0' && s[0] <= '9' :
struct bt_deserialize_try_variant_impl<std::enable_if_t<is_bt_deserializable<T>>, Variant, T, Ts...> {
void operator()(std::string_view& s, Variant& variant) {
if ( is_bt_output_list_container<T> ? s[0] == 'l' :
is_bt_tuple<T> ? s[0] == 'l' :
is_bt_output_dict_container<T> ? s[0] == 'd' :
std::is_integral_v<T> ? s[0] == 'i' :
std::is_same_v<T, std::string> ? s[0] >= '0' && s[0] <= '9' :
false) {
T val;
bt_deserialize<T>{}(s, val);
@ -397,49 +418,41 @@ struct bt_deserialize_try_variant_impl<std::enable_if_t<is_bt_deserializable<T>:
};
template <typename Variant, typename T, typename... Ts>
struct bt_deserialize_try_variant_impl<std::enable_if_t<!is_bt_deserializable<T>::value>, Variant, T, Ts...> {
void operator()(string_view& s, Variant& variant) {
struct bt_deserialize_try_variant_impl<std::enable_if_t<!is_bt_deserializable<T>>, Variant, T, Ts...> {
void operator()(std::string_view& s, Variant& variant) {
// Unsupported deserialization type, skip it
bt_deserialize_try_variant<Ts...>(s, variant);
}
};
template <>
struct bt_deserialize<bt_value, void> {
void operator()(string_view& s, bt_value& val);
};
// Serialization of a variant; all variant types must be bt-serializable.
template <typename... Ts>
struct bt_serialize<mapbox::util::variant<Ts...>> {
void operator()(std::ostream& os, const mapbox::util::variant<Ts...>& val) {
mapbox::util::apply_visitor(bt_serialize_visitor{os}, val);
}
};
template <typename... Ts>
struct bt_deserialize<mapbox::util::variant<Ts...>> {
void operator()(string_view& s, mapbox::util::variant<Ts...>& val) {
bt_deserialize_try_variant<Ts...>(s, val);
}
};
#ifdef __cpp_lib_variant
/// C++17 std::variant support
template <typename... Ts>
struct bt_serialize<std::variant<Ts...>> {
struct bt_serialize<std::variant<Ts...>, std::void_t<bt_serialize<Ts>...>> {
void operator()(std::ostream &os, const std::variant<Ts...>& val) {
mapbox::util::apply_visitor(bt_serialize_visitor{os}, val);
std::visit(
[&os] (const auto& val) {
using T = std::remove_cv_t<std::remove_reference_t<decltype(val)>>;
bt_serialize<T>{}(os, val);
},
val);
}
};
// Deserialization to a variant; at least one variant type must be bt-deserializble.
template <typename... Ts>
struct bt_deserialize<std::variant<Ts...>> {
void operator()(string_view& s, std::variant<Ts...>& val) {
struct bt_deserialize<std::variant<Ts...>, std::enable_if_t<(is_bt_deserializable<Ts> || ...)>> {
void operator()(std::string_view& s, std::variant<Ts...>& val) {
bt_deserialize_try_variant<Ts...>(s, val);
}
};
#endif
template <>
struct bt_serialize<bt_value> : bt_serialize<bt_variant> {};
template <>
struct bt_deserialize<bt_value> {
void operator()(std::string_view& s, bt_value& val);
};
template <typename T>
struct bt_stream_serializer {
@ -500,8 +513,8 @@ std::string bt_serialize(const T &val) { return bt_serializer(val); }
/// int value;
/// bt_deserialize(encoded, value); // Sets value to 42
///
template <typename T, std::enable_if_t<!std::is_const<T>::value, int> = 0>
void bt_deserialize(string_view s, T& val) {
template <typename T, std::enable_if_t<!std::is_const_v<T>, int> = 0>
void bt_deserialize(std::string_view s, T& val) {
return detail::bt_deserialize<T>{}(s, val);
}
@ -512,14 +525,14 @@ void bt_deserialize(string_view s, T& val) {
/// auto mylist = bt_deserialize<std::list<int>>(encoded);
///
template <typename T>
T bt_deserialize(string_view s) {
T bt_deserialize(std::string_view s) {
T val;
bt_deserialize(s, val);
return val;
}
/// Deserializes the given value into a generic `bt_value` type (mapbox::util::variant) which is capable
/// of holding all possible BT-encoded values (including recursion).
/// Deserializes the given value into a generic `bt_value` type (wrapped std::variant) which is
/// capable of holding all possible BT-encoded values (including recursion).
///
/// Example:
///
@ -528,50 +541,107 @@ T bt_deserialize(string_view s) {
/// int v = get_int<int>(val); // fails unless the encoded value was actually an integer that
/// // fits into an `int`
///
inline bt_value bt_get(string_view s) {
inline bt_value bt_get(std::string_view s) {
return bt_deserialize<bt_value>(s);
}
/// Helper functions to extract a value of some integral type from a bt_value which contains an
/// integer. Does range checking, throwing std::overflow_error if the stored value is outside the
/// range of the target type.
/// Helper functions to extract a value of some integral type from a bt_value which contains either
/// a int64_t or uint64_t. Does range checking, throwing std::overflow_error if the stored value is
/// outside the range of the target type.
///
/// Example:
///
/// std::string encoded = "i123456789e";
/// auto val = bt_get(encoded);
/// auto v = get_int<uint32_t>(val); // throws if the decoded value doesn't fit in a uint32_t
template <typename IntType, std::enable_if_t<std::is_integral<IntType>::value, int> = 0>
template <typename IntType, std::enable_if_t<std::is_integral_v<IntType>, int> = 0>
IntType get_int(const bt_value &v) {
// It's highly unlikely that this code ever runs on a non-2s-complement architecture, but check
// at compile time if converting to a uint64_t (because while int64_t -> uint64_t is
// well-defined, uint64_t -> int64_t only does the right thing under 2's complement).
static_assert(!std::is_unsigned<IntType>::value || sizeof(IntType) != sizeof(int64_t) || -1 == ~0,
"Non 2s-complement architecture not supported!");
int64_t value = mapbox::util::get<int64_t>(v);
if (sizeof(IntType) < sizeof(int64_t)) {
if (std::holds_alternative<uint64_t>(v)) {
uint64_t value = std::get<uint64_t>(v);
if constexpr (!std::is_same_v<IntType, uint64_t>)
if (value > static_cast<uint64_t>(std::numeric_limits<IntType>::max()))
throw std::overflow_error("Unable to extract integer value: stored value is too large for the requested type");
return static_cast<IntType>(value);
}
int64_t value = std::get<int64_t>(v);
if constexpr (!std::is_same_v<IntType, int64_t>)
if (value > static_cast<int64_t>(std::numeric_limits<IntType>::max())
|| value < static_cast<int64_t>(std::numeric_limits<IntType>::min()))
throw std::overflow_error("Unable to extract integer value: stored value is outside the range of the requested type");
}
return static_cast<IntType>(value);
}
namespace detail {
template <typename Tuple, size_t... Is>
void get_tuple_impl(Tuple& t, const bt_list& l, std::index_sequence<Is...>);
}
/// Converts a bt_list into the given template std::tuple or std::pair. Throws a
/// std::invalid_argument if the list has the wrong size or wrong element types. Supports recursion
/// (i.e. if the tuple itself contains tuples or pairs). The tuple (or nested tuples) may only
/// contain integral types, strings, string_views, bt_list, bt_dict, and tuples/pairs of those.
template <typename Tuple>
Tuple get_tuple(const bt_list& x) {
Tuple t;
detail::get_tuple_impl(t, x, std::make_index_sequence<std::tuple_size_v<Tuple>>{});
return t;
}
template <typename Tuple>
Tuple get_tuple(const bt_value& x) {
return get_tuple<Tuple>(std::get<bt_list>(static_cast<const bt_variant&>(x)));
}
namespace detail {
template <typename T, typename It>
void get_tuple_impl_one(T& t, It& it) {
const bt_variant& v = *it++;
if constexpr (std::is_integral_v<T>) {
t = lokimq::get_int<T>(v);
} else if constexpr (is_bt_tuple<T>) {
if (std::holds_alternative<bt_list>(v))
throw std::invalid_argument{"Unable to convert tuple: cannot create sub-tuple from non-bt_list"};
t = get_tuple<T>(std::get<bt_list>(v));
} else if constexpr (std::is_same_v<std::string, T> || std::is_same_v<std::string_view, T>) {
// If we request a string/string_view, we might have the other one and need to copy/view it.
if (std::holds_alternative<std::string_view>(v))
t = std::get<std::string_view>(v);
else
t = std::get<std::string>(v);
} else {
t = std::get<T>(v);
}
}
template <typename Tuple, size_t... Is>
void get_tuple_impl(Tuple& t, const bt_list& l, std::index_sequence<Is...>) {
if (l.size() != sizeof...(Is))
throw std::invalid_argument{"Unable to convert tuple: bt_list has wrong size"};
auto it = l.begin();
(get_tuple_impl_one(std::get<Is>(t), it), ...);
}
} // namespace detail
/// Class that allows you to walk through a bt-encoded list in memory without copying or allocating
/// memory. It accesses existing memory directly and so the caller must ensure that the referenced
/// memory stays valid for the lifetime of the bt_list_consumer object.
class bt_list_consumer {
protected:
string_view data;
std::string_view data;
bt_list_consumer() = default;
public:
bt_list_consumer(string_view data_);
bt_list_consumer(std::string_view data_);
/// Copy constructor. Making a copy copies the current position so can be used for multipass
/// iteration through a list.
bt_list_consumer(const bt_list_consumer&) = default;
bt_list_consumer& operator=(const bt_list_consumer&) = default;
/// Get a copy of the current buffer
std::string_view current_buffer() const { return data; }
/// Returns true if the next value indicates the end of the list
bool is_finished() const { return data.front() == 'e'; }
/// Returns true if the next element looks like an encoded string
@ -588,23 +658,24 @@ public:
/// Attempt to parse the next value as a string (and advance just past it). Throws if the next
/// value is not a string.
std::string consume_string();
string_view consume_string_view();
std::string_view consume_string_view();
/// Attempts to parse the next value as an integer (and advance just past it). Throws if the
/// next value is not an integer.
template <typename IntType>
IntType consume_integer() {
if (!is_integer()) throw bt_deserialize_invalid_type{"next value is not an integer"};
string_view next{data};
std::string_view next{data};
IntType ret;
detail::bt_deserialize<IntType>{}(next, ret);
data = next;
return ret;
}
/// Consumes a list, return it as a list-like type. This typically requires dynamic allocation,
/// but only has to parse the data once. Compare with consume_list_data() which allows
/// alloc-free traversal, but requires parsing twice (if the contents are to be used).
/// Consumes a list, return it as a list-like type. Can also be used for tuples/pairs. This
/// typically requires dynamic allocation, but only has to parse the data once. Compare with
/// consume_list_data() which allows alloc-free traversal, but requires parsing twice (if the
/// contents are to be used).
template <typename T = bt_list>
T consume_list() {
T list;
@ -616,7 +687,7 @@ public:
template <typename T>
void consume_list(T& list) {
if (!is_list()) throw bt_deserialize_invalid_type{"next bt value is not a list"};
string_view n{data};
std::string_view n{data};
detail::bt_deserialize<T>{}(n, list);
data = n;
}
@ -635,7 +706,7 @@ public:
template <typename T>
void consume_dict(T& dict) {
if (!is_dict()) throw bt_deserialize_invalid_type{"next bt value is not a dict"};
string_view n{data};
std::string_view n{data};
detail::bt_deserialize<T>{}(n, dict);
data = n;
}
@ -647,13 +718,13 @@ public:
/// entire thing. This is recursive into both lists and dicts and likely to be quite
/// inefficient for large, nested structures (unless the values only need to be skipped but
/// aren't separately needed). This, however, does not require dynamic memory allocation.
string_view consume_list_data();
std::string_view consume_list_data();
/// Attempts to parse the next value as a dict and returns the string_view that contains the
/// entire thing. This is recursive into both lists and dicts and likely to be quite
/// inefficient for large, nested structures (unless the values only need to be skipped but
/// aren't separately needed). This, however, does not require dynamic memory allocation.
string_view consume_dict_data();
std::string_view consume_dict_data();
};
@ -661,7 +732,7 @@ public:
/// copying or allocating memory. It accesses existing memory directly and so the caller must
/// ensure that the referenced memory stays valid for the lifetime of the bt_dict_consumer object.
class bt_dict_consumer : private bt_list_consumer {
string_view key_;
std::string_view key_;
/// Consume the key if not already consumed and there is a key present (rather than 'e').
/// Throws exception if what should be a key isn't a string, or if the key consumes the entire
@ -671,14 +742,14 @@ class bt_dict_consumer : private bt_list_consumer {
/// Clears the cached key and returns it. Must have already called consume_key directly or
/// indirectly via one of the `is_{...}` methods.
string_view flush_key() {
string_view k;
std::string_view flush_key() {
std::string_view k;
k.swap(key_);
return k;
}
public:
bt_dict_consumer(string_view data_);
bt_dict_consumer(std::string_view data_);
/// Copy constructor. Making a copy copies the current position so can be used for multipass
/// iteration through a list.
@ -703,7 +774,7 @@ public:
/// all of the other consume_* methods. The value is cached whether called here or by some
/// other method; accessing it multiple times simple accesses the cache until the next value is
/// consumed.
string_view key() {
std::string_view key() {
if (!consume_key())
throw bt_deserialize_invalid{"Cannot access next key: at the end of the dict"};
return key_;
@ -711,14 +782,14 @@ public:
/// Attempt to parse the next value as a string->string pair (and advance just past it). Throws
/// if the next value is not a string.
std::pair<string_view, string_view> next_string();
std::pair<std::string_view, std::string_view> next_string();
/// Attempts to parse the next value as an string->integer pair (and advance just past it).
/// Throws if the next value is not an integer.
template <typename IntType>
std::pair<string_view, IntType> next_integer() {
std::pair<std::string_view, IntType> next_integer() {
if (!is_integer()) throw bt_deserialize_invalid_type{"next bt dict value is not an integer"};
std::pair<string_view, IntType> ret;
std::pair<std::string_view, IntType> ret;
ret.second = bt_list_consumer::consume_integer<IntType>();
ret.first = flush_key();
return ret;
@ -729,15 +800,15 @@ public:
/// which allows alloc-free traversal, but requires parsing twice (if the contents are to be
/// used).
template <typename T = bt_list>
std::pair<string_view, T> next_list() {
std::pair<string_view, T> pair;
pair.first = consume_list(pair.second);
std::pair<std::string_view, T> next_list() {
std::pair<std::string_view, T> pair;
pair.first = next_list(pair.second);
return pair;
}
/// Same as above, but takes a pre-existing list-like data type. Returns the key.
template <typename T>
string_view next_list(T& list) {
std::string_view next_list(T& list) {
if (!is_list()) throw bt_deserialize_invalid_type{"next bt value is not a list"};
bt_list_consumer::consume_list(list);
return flush_key();
@ -748,15 +819,15 @@ public:
/// which allows alloc-free traversal, but requires parsing twice (if the contents are to be
/// used).
template <typename T = bt_dict>
std::pair<string_view, T> next_dict() {
std::pair<string_view, T> pair;
std::pair<std::string_view, T> next_dict() {
std::pair<std::string_view, T> pair;
pair.first = consume_dict(pair.second);
return pair;
}
/// Same as above, but takes a pre-existing dict-like data type. Returns the key.
template <typename T>
string_view next_dict(T& dict) {
std::string_view next_dict(T& dict) {
if (!is_dict()) throw bt_deserialize_invalid_type{"next bt value is not a dict"};
bt_list_consumer::consume_dict(dict);
return flush_key();
@ -766,25 +837,25 @@ public:
/// contains the entire thing. This is recursive into both lists and dicts and likely to be
/// quite inefficient for large, nested structures (unless the values only need to be skipped
/// but aren't separately needed). This, however, does not require dynamic memory allocation.
std::pair<string_view, string_view> next_list_data() {
std::pair<std::string_view, std::string_view> next_list_data() {
if (data.size() < 2 || !is_list()) throw bt_deserialize_invalid_type{"next bt dict value is not a list"};
return {flush_key(), bt_list_consumer::consume_list_data()};
}
/// Same as next_list_data(), but wraps the value in a bt_list_consumer for convenience
std::pair<string_view, bt_list_consumer> next_list_consumer() { return next_list_data(); }
std::pair<std::string_view, bt_list_consumer> next_list_consumer() { return next_list_data(); }
/// Attempts to parse the next value as a string->dict pair and returns the string_view that
/// contains the entire thing. This is recursive into both lists and dicts and likely to be
/// quite inefficient for large, nested structures (unless the values only need to be skipped
/// but aren't separately needed). This, however, does not require dynamic memory allocation.
std::pair<string_view, string_view> next_dict_data() {
std::pair<std::string_view, std::string_view> next_dict_data() {
if (data.size() < 2 || !is_dict()) throw bt_deserialize_invalid_type{"next bt dict value is not a dict"};
return {flush_key(), bt_list_consumer::consume_dict_data()};
}
/// Same as next_dict_data(), but wraps the value in a bt_dict_consumer for convenience
std::pair<string_view, bt_dict_consumer> next_dict_consumer() { return next_dict_data(); }
std::pair<std::string_view, bt_dict_consumer> next_dict_consumer() { return next_dict_data(); }
/// Skips ahead until we find the first key >= the given key or reach the end of the dict.
/// Returns true if we found an exact match, false if we reached some greater value or the end.
@ -799,7 +870,7 @@ public:
/// - this is irreversible; you cannot returned to skipped values without reparsing. (You *can*
/// however, make a copy of the bt_dict_consumer before calling and use the copy to return to
/// the pre-skipped position).
bool skip_until(string_view find) {
bool skip_until(std::string_view find) {
while (consume_key() && key_ < find) {
flush_key();
skip_value();
@ -834,8 +905,8 @@ public:
template <typename T>
void consume_dict(T& dict) { next_dict(dict); }
string_view consume_list_data() { return next_list_data().second; }
string_view consume_dict_data() { return next_dict_data().second; }
std::string_view consume_list_data() { return next_list_data().second; }
std::string_view consume_dict_data() { return next_dict_data().second; }
bt_list_consumer consume_list_consumer() { return consume_list_data(); }
bt_dict_consumer consume_dict_consumer() { return consume_dict_data(); }

112
lokimq/bt_value.h Normal file
View File

@ -0,0 +1,112 @@
// Copyright (c) 2019-2020, The Loki Project
//
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without modification, are
// permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this list of
// conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice, this list
// of conditions and the following disclaimer in the documentation and/or other
// materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its contributors may be
// used to endorse or promote products derived from this software without specific
// prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL
// THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF
// THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
// This header is here to provide just the basic bt_value/bt_dict/bt_list definitions without
// needing to include the full bt_serialize.h header.
#include <map>
#include <list>
#include <cstdint>
#include <variant>
#include <string>
#include <string_view>
namespace lokimq {
struct bt_value;
/// The type used to store dictionaries inside bt_value.
using bt_dict = std::map<std::string, bt_value>; // NB: unordered_map doesn't work because it can't be used with a predeclared type
/// The type used to store list items inside bt_value.
using bt_list = std::list<bt_value>;
/// The basic variant that can hold anything (recursively).
using bt_variant = std::variant<
std::string,
std::string_view,
int64_t,
uint64_t,
bt_list,
bt_dict
>;
#ifdef __cpp_lib_remove_cvref // C++20
using std::remove_cvref_t;
#else
template <typename T>
using remove_cvref_t = std::remove_cv_t<std::remove_reference_t<T>>;
#endif
template <typename T, typename Variant>
struct has_alternative;
template <typename T, typename... V>
struct has_alternative<T, std::variant<V...>> : std::bool_constant<(std::is_same_v<T, V> || ...)> {};
template <typename T, typename Variant>
constexpr bool has_alternative_v = has_alternative<T, Variant>::value;
namespace detail {
template <typename Tuple, size_t... Is>
bt_list tuple_to_list(const Tuple& tuple, std::index_sequence<Is...>) {
return {{bt_value{std::get<Is>(tuple)}...}};
}
template <typename T> constexpr bool is_tuple = false;
template <typename... T> constexpr bool is_tuple<std::tuple<T...>> = true;
template <typename S, typename T> constexpr bool is_tuple<std::pair<S, T>> = true;
}
/// Recursive generic type that can fully represent everything valid for a BT serialization.
/// This is basically just an empty wrapper around the std::variant, except we add some extra
/// converting constructors:
/// - integer constructors so that any unsigned value goes to the uint64_t and any signed value goes
/// to the int64_t.
/// - std::tuple and std::pair constructors that build a bt_list out of the tuple/pair elements.
struct bt_value : bt_variant {
using bt_variant::bt_variant;
using bt_variant::operator=;
template <typename T, typename U = std::remove_reference_t<T>, std::enable_if_t<std::is_integral_v<U> && std::is_unsigned_v<U>, int> = 0>
bt_value(T&& uint) : bt_variant{static_cast<uint64_t>(uint)} {}
template <typename T, typename U = std::remove_reference_t<T>, std::enable_if_t<std::is_integral_v<U> && std::is_signed_v<U>, int> = 0>
bt_value(T&& sint) : bt_variant{static_cast<int64_t>(sint)} {}
template <typename... T>
bt_value(const std::tuple<T...>& tuple) : bt_variant{detail::tuple_to_list(tuple, std::index_sequence_for<T...>{})} {}
template <typename S, typename T>
bt_value(const std::pair<S, T>& pair) : bt_variant{detail::tuple_to_list(pair, std::index_sequence_for<S, T>{})} {}
template <typename T, typename U = std::remove_reference_t<T>, std::enable_if_t<!std::is_integral_v<U> && !detail::is_tuple<U>, int> = 0>
bt_value(T&& v) : bt_variant{std::forward<T>(v)} {}
bt_value(const char* s) : bt_value{std::string_view{s}} {}
};
}

View File

@ -47,7 +47,7 @@ void LokiMQ::setup_external_socket(zmq::socket_t& socket) {
}
}
void LokiMQ::setup_outgoing_socket(zmq::socket_t& socket, string_view remote_pubkey) {
void LokiMQ::setup_outgoing_socket(zmq::socket_t& socket, std::string_view remote_pubkey) {
setup_external_socket(socket);
@ -67,7 +67,7 @@ void LokiMQ::setup_outgoing_socket(zmq::socket_t& socket, string_view remote_pub
// else let ZMQ pick a random one
}
ConnectionID LokiMQ::connect_sn(string_view pubkey, std::chrono::milliseconds keep_alive, string_view hint) {
ConnectionID LokiMQ::connect_sn(std::string_view pubkey, std::chrono::milliseconds keep_alive, std::string_view hint) {
if (!proxy_thread.joinable())
throw std::logic_error("Cannot call connect_sn() before calling `start()`");
@ -76,30 +76,32 @@ ConnectionID LokiMQ::connect_sn(string_view pubkey, std::chrono::milliseconds ke
return pubkey;
}
ConnectionID LokiMQ::connect_remote(string_view remote, ConnectSuccess on_connect, ConnectFailure on_failure,
string_view pubkey, AuthLevel auth_level, std::chrono::milliseconds timeout) {
ConnectionID LokiMQ::connect_remote(const address& remote, ConnectSuccess on_connect, ConnectFailure on_failure,
AuthLevel auth_level, std::chrono::milliseconds timeout) {
if (!proxy_thread.joinable())
throw std::logic_error("Cannot call connect_remote() before calling `start()`");
if (remote.size() < 7 || !(remote.substr(0, 6) == "tcp://" || remote.substr(0, 6) == "ipc://" /* unix domain sockets */))
throw std::runtime_error("Invalid connect_remote: remote address '" + std::string{remote} + "' is not a valid or supported zmq connect string");
auto id = next_conn_id++;
LMQ_TRACE("telling proxy to connect to ", remote, ", id ", id,
pubkey.empty() ? "using NULL auth" : ", using CURVE with remote pubkey [" + to_hex(pubkey) + "]");
LMQ_TRACE("telling proxy to connect to ", remote, ", id ", id);
detail::send_control(get_control_socket(), "CONNECT_REMOTE", bt_serialize<bt_dict>({
{"auth_level", static_cast<std::underlying_type_t<AuthLevel>>(auth_level)},
{"conn_id", id},
{"connect", detail::serialize_object(std::move(on_connect))},
{"failure", detail::serialize_object(std::move(on_failure))},
{"pubkey", pubkey},
{"remote", remote},
{"pubkey", remote.curve() ? remote.pubkey : ""},
{"remote", remote.zmq_address()},
{"timeout", timeout.count()},
}));
return id;
}
ConnectionID LokiMQ::connect_remote(std::string_view remote, ConnectSuccess on_connect, ConnectFailure on_failure,
std::string_view pubkey, AuthLevel auth_level, std::chrono::milliseconds timeout) {
return connect_remote(address{remote}.set_pubkey(pubkey),
std::move(on_connect), std::move(on_failure), auth_level, timeout);
}
void LokiMQ::disconnect(ConnectionID id, std::chrono::milliseconds linger) {
detail::send_control(get_control_socket(), "DISCONNECT", bt_serialize<bt_dict>({
{"conn_id", id.id},
@ -109,7 +111,7 @@ void LokiMQ::disconnect(ConnectionID id, std::chrono::milliseconds linger) {
}
std::pair<zmq::socket_t *, std::string>
LokiMQ::proxy_connect_sn(string_view remote, string_view connect_hint, bool optional, bool incoming_only, bool outgoing_only, std::chrono::milliseconds keep_alive) {
LokiMQ::proxy_connect_sn(std::string_view remote, std::string_view connect_hint, bool optional, bool incoming_only, bool outgoing_only, std::chrono::milliseconds keep_alive) {
ConnectionID remote_cid{remote};
auto its = peers.equal_range(remote_cid);
peer_info* peer = nullptr;
@ -185,7 +187,7 @@ LokiMQ::proxy_connect_sn(string_view remote, string_view connect_hint, bool opti
}
std::pair<zmq::socket_t *, std::string> LokiMQ::proxy_connect_sn(bt_dict_consumer data) {
string_view hint, remote_pk;
std::string_view hint, remote_pk;
std::chrono::milliseconds keep_alive;
bool optional = false, incoming_only = false, outgoing_only = false;

View File

@ -1,15 +1,20 @@
#pragma once
#include "auth.h"
#include "string_view.h"
#include "bt_value.h"
#include <string_view>
#include <iosfwd>
#include <stdexcept>
#include <string>
#include <utility>
#include <variant>
namespace lokimq {
class bt_dict;
struct ConnectionID;
namespace detail {
template <typename... T>
bt_dict build_send(ConnectionID to, string_view cmd, T&&... opts);
bt_dict build_send(ConnectionID to, std::string_view cmd, T&&... opts);
}
/// Opaque data structure representing a connection which supports ==, !=, < and std::hash. For
@ -26,7 +31,7 @@ struct ConnectionID {
throw std::runtime_error{"Invalid pubkey: expected 32 bytes"};
}
// Construction from a service node pubkey
ConnectionID(string_view pubkey_) : ConnectionID(std::string{pubkey_}) {}
ConnectionID(std::string_view pubkey_) : ConnectionID(std::string{pubkey_}) {}
ConnectionID(const ConnectionID&) = default;
ConnectionID(ConnectionID&&) = default;
@ -75,7 +80,7 @@ private:
friend class LokiMQ;
friend struct std::hash<ConnectionID>;
template <typename... T>
friend bt_dict detail::build_send(ConnectionID to, string_view cmd, T&&... opts);
friend bt_dict detail::build_send(ConnectionID to, std::string_view cmd, T&&... opts);
friend std::ostream& operator<<(std::ostream& o, const ConnectionID& conn);
};

View File

@ -27,7 +27,8 @@
// THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include "string_view.h"
#include <string>
#include <string_view>
#include <array>
#include <iterator>
#include <cassert>
@ -55,46 +56,52 @@ struct hex_table {
constexpr char to_hex(unsigned char b) const noexcept { return to_hex_lut[b]; }
} constexpr hex_lut;
// This main point of this static assert is to force the compiler to compile-time build the constexpr tables.
static_assert(hex_lut.from_hex('a') == 10 && hex_lut.from_hex('F') == 15 && hex_lut.to_hex(13) == 'd', "");
} // namespace detail
/// Creates hex digits from a character sequence.
template <typename InputIt, typename OutputIt>
void to_hex(InputIt begin, InputIt end, OutputIt out) {
static_assert(sizeof(decltype(*begin)) == 1, "to_hex requires chars/bytes");
for (; begin != end; ++begin) {
auto c = *begin;
*out++ = detail::hex_lut.to_hex((c & 0xf0) >> 4);
uint8_t c = static_cast<uint8_t>(*begin);
*out++ = detail::hex_lut.to_hex(c >> 4);
*out++ = detail::hex_lut.to_hex(c & 0x0f);
}
}
/// Creates a hex string from an iterable, std::string-like object
inline std::string to_hex(string_view s) {
/// Creates a string of hex digits from a character sequence iterator pair
template <typename It>
std::string to_hex(It begin, It end) {
std::string hex;
hex.reserve(s.size() * 2);
to_hex(s.begin(), s.end(), std::back_inserter(hex));
if constexpr (std::is_base_of_v<std::random_access_iterator_tag, typename std::iterator_traits<It>::iterator_category>)
hex.reserve(2 * std::distance(begin, end));
to_hex(begin, end, std::back_inserter(hex));
return hex;
}
inline std::string to_hex(ustring_view s) {
std::string hex;
hex.reserve(s.size() * 2);
to_hex(s.begin(), s.end(), std::back_inserter(hex));
return hex;
}
/// Creates a hex string from an iterable, std::string-like object
template <typename CharT>
std::string to_hex(std::basic_string_view<CharT> s) { return to_hex(s.begin(), s.end()); }
inline std::string to_hex(std::string_view s) { return to_hex<>(s); }
/// Returns true if all elements in the range are hex characters
template <typename It>
constexpr bool is_hex(It begin, It end) {
static_assert(sizeof(decltype(*begin)) == 1, "is_hex requires chars/bytes");
for (; begin != end; ++begin) {
if (detail::hex_lut.from_hex(*begin) == 0 && *begin != '0')
if (detail::hex_lut.from_hex(static_cast<unsigned char>(*begin)) == 0 && static_cast<unsigned char>(*begin) != '0')
return false;
}
return true;
}
/// Returns true if all elements in the string-like value are hex characters
constexpr bool is_hex(string_view s) { return is_hex(s.begin(), s.end()); }
constexpr bool is_hex(ustring_view s) { return is_hex(s.begin(), s.end()); }
template <typename CharT>
constexpr bool is_hex(std::basic_string_view<CharT> s) { return is_hex(s.begin(), s.end()); }
constexpr bool is_hex(std::string_view s) { return is_hex(s.begin(), s.end()); }
/// Convert a hex digit into its numeric (0-15) value
constexpr char from_hex_digit(unsigned char x) noexcept {
@ -114,24 +121,25 @@ void from_hex(InputIt begin, InputIt end, OutputIt out) {
while (begin != end) {
auto a = *begin++;
auto b = *begin++;
*out++ = from_hex_pair(a, b);
*out++ = from_hex_pair(static_cast<unsigned char>(a), static_cast<unsigned char>(b));
}
}
/// Converts a sequence of hex digits to a string of bytes and returns it. Undefined behaviour if
/// the input sequence is not an even-length sequence of [0-9a-fA-F] characters.
template <typename It>
std::string from_hex(It begin, It end) {
std::string bytes;
if constexpr (std::is_base_of_v<std::random_access_iterator_tag, typename std::iterator_traits<It>::iterator_category>)
bytes.reserve(std::distance(begin, end) / 2);
from_hex(begin, end, std::back_inserter(bytes));
return bytes;
}
/// Converts hex digits from a std::string-like object into a std::string of bytes. Undefined
/// behaviour if any characters are not in [0-9a-fA-F] or if the input sequence length is not even.
inline std::string from_hex(string_view s) {
std::string bytes;
bytes.reserve(s.size() / 2);
from_hex(s.begin(), s.end(), std::back_inserter(bytes));
return bytes;
}
inline std::string from_hex(ustring_view s) {
std::string bytes;
bytes.reserve(s.size() / 2);
from_hex(s.begin(), s.end(), std::back_inserter(bytes));
return bytes;
}
template <typename CharT>
std::string from_hex(std::basic_string_view<CharT> s) { return from_hex(s.begin(), s.end()); }
inline std::string from_hex(std::string_view s) { return from_hex<>(s); }
}

View File

@ -6,15 +6,31 @@ namespace lokimq {
void LokiMQ::proxy_batch(detail::Batch* batch) {
batches.insert(batch);
const int jobs = batch->size();
for (int i = 0; i < jobs; i++)
batch_jobs.emplace(batch, i);
const auto [jobs, tagged_threads] = batch->size();
LMQ_TRACE("proxy queuing batch job with ", jobs, " jobs", tagged_threads ? " (job uses tagged thread(s))" : "");
if (!tagged_threads) {
for (size_t i = 0; i < jobs; i++)
batch_jobs.emplace(batch, i);
} else {
// Some (or all) jobs have a specific thread target so queue any such jobs in the tagged
// worker queue.
auto threads = batch->threads();
for (size_t i = 0; i < jobs; i++) {
auto& jobs = threads[i] > 0
? std::get<std::queue<batch_job>>(tagged_workers[threads[i] - 1])
: batch_jobs;
jobs.emplace(batch, i);
}
}
proxy_skip_one_poll = true;
}
void LokiMQ::job(std::function<void()> f) {
void LokiMQ::job(std::function<void()> f, std::optional<TaggedThreadID> thread) {
if (thread && thread->_id == -1)
throw std::logic_error{"job() cannot be used to queue an in-proxy job"};
auto* b = new Batch<void>;
b->add_job(std::move(f));
b->add_job(std::move(f), thread);
auto* baseptr = static_cast<detail::Batch*>(b);
detail::send_control(get_control_socket(), "BATCH", bt_serialize(reinterpret_cast<uintptr_t>(baseptr)));
}
@ -38,7 +54,7 @@ void LokiMQ::proxy_run_batch_jobs(std::queue<batch_job>& jobs, const int reserve
// Called either within the proxy thread, or before the proxy thread has been created; actually adds
// the timer. If the timer object hasn't been set up yet it gets set up here.
void LokiMQ::proxy_timer(std::function<void()> job, std::chrono::milliseconds interval, bool squelch) {
void LokiMQ::proxy_timer(std::function<void()> job, std::chrono::milliseconds interval, bool squelch, int thread) {
if (!timers)
timers.reset(zmq_timers_new());
@ -48,16 +64,17 @@ void LokiMQ::proxy_timer(std::function<void()> job, std::chrono::milliseconds in
this);
if (timer_id == -1)
throw zmq::error_t{};
timer_jobs[timer_id] = std::make_tuple(std::move(job), squelch, false);
timer_jobs[timer_id] = { std::move(job), squelch, false, thread };
}
void LokiMQ::proxy_timer(bt_list_consumer timer_data) {
std::unique_ptr<std::function<void()>> func{reinterpret_cast<std::function<void()>*>(timer_data.consume_integer<uintptr_t>())};
auto interval = std::chrono::milliseconds{timer_data.consume_integer<uint64_t>()};
auto squelch = timer_data.consume_integer<bool>();
auto thread = timer_data.consume_integer<int>();
if (!timer_data.is_finished())
throw std::runtime_error("Internal error: proxied timer request contains unexpected data");
proxy_timer(std::move(*func), interval, squelch);
proxy_timer(std::move(*func), interval, squelch, thread);
}
void LokiMQ::_queue_timer_job(int timer_id) {
@ -66,43 +83,72 @@ void LokiMQ::_queue_timer_job(int timer_id) {
LMQ_LOG(warn, "Could not find timer job ", timer_id);
return;
}
auto& timer = it->second;
auto& squelch = std::get<1>(timer);
auto& running = std::get<2>(timer);
auto& [func, squelch, running, thread] = it->second;
if (squelch && running) {
LMQ_LOG(debug, "Not running timer job ", timer_id, " because a job for that timer is still running");
return;
}
if (thread == -1) { // Run directly in proxy thread
try { func(); }
catch (const std::exception &e) { LMQ_LOG(warn, "timer job ", timer_id, " raised an exception: ", e.what()); }
catch (...) { LMQ_LOG(warn, "timer job ", timer_id, " raised a non-std exception"); }
return;
}
auto* b = new Batch<void>;
b->add_job(std::get<0>(timer));
b->add_job(func, thread);
if (squelch) {
running = true;
b->completion_proxy([this,timer_id](auto results) {
b->completion([this,timer_id](auto results) {
try { results[0].get(); }
catch (const std::exception &e) { LMQ_LOG(warn, "timer job ", timer_id, " raised an exception: ", e.what()); }
catch (...) { LMQ_LOG(warn, "timer job ", timer_id, " raised a non-std exception"); }
auto it = timer_jobs.find(timer_id);
if (it != timer_jobs.end())
std::get<2>(it->second)/*running*/ = false;
});
it->second.running = false;
}, LokiMQ::run_in_proxy);
}
batches.insert(b);
batch_jobs.emplace(static_cast<detail::Batch*>(b), 0);
assert(b->size() == 1);
LMQ_TRACE("b: ", b->size().first, ", ", b->size().second, "; thread = ", thread);
assert(b->size() == std::make_pair(size_t{1}, thread > 0));
auto& queue = thread > 0
? std::get<std::queue<batch_job>>(tagged_workers[thread - 1])
: batch_jobs;
queue.emplace(static_cast<detail::Batch*>(b), 0);
}
void LokiMQ::add_timer(std::function<void()> job, std::chrono::milliseconds interval, bool squelch) {
void LokiMQ::add_timer(std::function<void()> job, std::chrono::milliseconds interval, bool squelch, std::optional<TaggedThreadID> thread) {
int th_id = thread ? thread->_id : 0;
if (proxy_thread.joinable()) {
detail::send_control(get_control_socket(), "TIMER", bt_serialize(bt_list{{
detail::serialize_object(std::move(job)),
interval.count(),
squelch}}));
squelch,
th_id}}));
} else {
proxy_timer(std::move(job), interval, squelch);
proxy_timer(std::move(job), interval, squelch, th_id);
}
}
void LokiMQ::TimersDeleter::operator()(void* timers) { zmq_timers_destroy(&timers); }
TaggedThreadID LokiMQ::add_tagged_thread(std::string name, std::function<void()> start) {
if (proxy_thread.joinable())
throw std::logic_error{"Cannot add tagged threads after calling `start()`"};
if (name == "_proxy"sv || name.empty() || name.find('\0') != std::string::npos)
throw std::logic_error{"Invalid tagged thread name `" + name + "'"};
auto& [run, busy, queue] = tagged_workers.emplace_back();
busy = false;
run.worker_id = tagged_workers.size(); // We want index + 1 (b/c 0 is used for non-tagged jobs)
run.worker_routing_id = "t" + std::to_string(run.worker_id);
LMQ_TRACE("Created new tagged thread ", name, " with routing id ", run.worker_routing_id);
run.worker_thread = std::thread{&LokiMQ::worker_thread, this, run.worker_id, name, std::move(start)};
return TaggedThreadID{static_cast<int>(run.worker_id)};
}
}

View File

@ -4,12 +4,11 @@
// Inside some method:
// LMQ_LOG(warn, "bad ", 42, " stuff");
//
// (The "this->" is here to work around gcc 5 bugginess when called in a `this`-capturing lambda.)
#define LMQ_LOG(level, ...) this->log_(LogLevel::level, __FILE__, __LINE__, __VA_ARGS__)
#define LMQ_LOG(level, ...) log(LogLevel::level, __FILE__, __LINE__, __VA_ARGS__)
#ifndef NDEBUG
// Same as LMQ_LOG(trace, ...) when not doing a release build; nothing under a release build.
# define LMQ_TRACE(...) this->log_(LogLevel::trace, __FILE__, __LINE__, __VA_ARGS__)
# define LMQ_TRACE(...) log(LogLevel::trace, __FILE__, __LINE__, __VA_ARGS__)
#else
# define LMQ_TRACE(...)
#endif
@ -33,7 +32,7 @@ inline zmq::message_t create_message(std::string&& data) {
};
/// Create a message copying from a string_view
inline zmq::message_t create_message(string_view data) {
inline zmq::message_t create_message(std::string_view data) {
return zmq::message_t{data.begin(), data.end()};
}
@ -94,7 +93,7 @@ inline const char* peer_address(zmq::message_t& msg) {
// Returns a string view of the given message data. It's the caller's responsibility to keep the
// referenced message alive. If you want a std::string instead just call `m.to_string()`
inline string_view view(const zmq::message_t& m) {
inline std::string_view view(const zmq::message_t& m) {
return {m.data<char>(), m.size()};
}
@ -108,7 +107,7 @@ inline std::string to_string(AuthLevel a) {
}
}
inline AuthLevel auth_from_string(string_view a) {
inline AuthLevel auth_from_string(std::string_view a) {
if (a == "none") return AuthLevel::none;
if (a == "basic") return AuthLevel::basic;
if (a == "admin") return AuthLevel::admin;
@ -116,7 +115,7 @@ inline AuthLevel auth_from_string(string_view a) {
}
// Extracts and builds the "send" part of a message for proxy_send/proxy_reply
inline std::list<zmq::message_t> build_send_parts(bt_list_consumer send, string_view route) {
inline std::list<zmq::message_t> build_send_parts(bt_list_consumer send, std::string_view route) {
std::list<zmq::message_t> parts;
if (!route.empty())
parts.push_back(create_message(route));
@ -128,7 +127,7 @@ inline std::list<zmq::message_t> build_send_parts(bt_list_consumer send, string_
/// Sends a control message to a specific destination by prefixing the worker name (or identity)
/// then appending the command and optional data (if non-empty). (This is needed when sending the control message
/// to a router socket, i.e. inside the proxy thread).
inline void route_control(zmq::socket_t& sock, string_view identity, string_view cmd, const std::string& data = {}) {
inline void route_control(zmq::socket_t& sock, std::string_view identity, std::string_view cmd, const std::string& data = {}) {
sock.send(create_message(identity), zmq::send_flags::sndmore);
detail::send_control(sock, cmd, data);
}

View File

@ -2,9 +2,12 @@
#include "lokimq-internal.h"
#include <map>
#include <random>
#include <ostream>
extern "C" {
#include <sodium.h>
#include <sodium/core.h>
#include <sodium/crypto_box.h>
#include <sodium/crypto_scalarmult.h>
}
#include "hex.h"
@ -38,7 +41,7 @@ namespace detail {
// Sends a control messages between proxy and threads or between proxy and workers consisting of a
// single command codes with an optional data part (the data frame is omitted if empty).
void send_control(zmq::socket_t& sock, string_view cmd, std::string data) {
void send_control(zmq::socket_t& sock, std::string_view cmd, std::string data) {
auto c = create_message(std::move(cmd));
if (data.empty()) {
sock.send(c, zmq::send_flags::none);
@ -53,7 +56,7 @@ void send_control(zmq::socket_t& sock, string_view cmd, std::string data) {
std::pair<std::string, AuthLevel> extract_metadata(zmq::message_t& msg) {
auto result = std::make_pair(""s, AuthLevel::none);
try {
string_view pubkey_hex{msg.gets("User-Id")};
std::string_view pubkey_hex{msg.gets("User-Id")};
if (pubkey_hex.size() != 64)
throw std::logic_error("bad user-id");
assert(is_hex(pubkey_hex.begin(), pubkey_hex.end()));
@ -174,7 +177,7 @@ zmq::socket_t& LokiMQ::get_control_socket() {
return *last.second;
}
std::lock_guard<std::mutex> lock{control_sockets_mutex};
std::lock_guard lock{control_sockets_mutex};
if (proxy_shutting_down)
throw std::runtime_error("Unable to obtain LokiMQ control socket: proxy thread is shutting down");
auto control = std::make_shared<zmq::socket_t>(context, zmq::socket_type::dealer);
@ -201,6 +204,9 @@ LokiMQ::LokiMQ(
LMQ_TRACE("Constructing LokiMQ, id=", object_id, ", this=", this);
if (sodium_init() == -1)
throw std::runtime_error{"libsodium initialization failed"};
if (pubkey.empty() != privkey.empty()) {
throw std::invalid_argument("LokiMQ construction failed: one (and only one) of pubkey/privkey is empty. Both must be specified, or both empty to generate a key.");
} else if (pubkey.empty()) {
@ -343,35 +349,75 @@ void LokiMQ::set_general_threads(int threads) {
LokiMQ::run_info& LokiMQ::run_info::load(category* cat_, std::string command_, ConnectionID conn_, Access access_, std::string remote_,
std::vector<zmq::message_t> data_parts_, const std::pair<CommandCallback, bool>* callback_) {
is_batch_job = false;
is_reply_job = false;
reset();
cat = cat_;
command = std::move(command_);
conn = std::move(conn_);
access = std::move(access_);
remote = std::move(remote_);
data_parts = std::move(data_parts_);
callback = callback_;
to_run = callback_;
return *this;
}
LokiMQ::run_info& LokiMQ::run_info::load(category* cat_, std::string command_, std::string remote_, std::function<void()> callback) {
reset();
is_injected = true;
cat = cat_;
command = std::move(command_);
conn = {};
access = {};
remote = std::move(remote_);
to_run = std::move(callback);
return *this;
}
LokiMQ::run_info& LokiMQ::run_info::load(pending_command&& pending) {
if (auto *f = std::get_if<std::function<void()>>(&pending.callback))
return load(&pending.cat, std::move(pending.command), std::move(pending.remote), std::move(*f));
assert(pending.callback.index() == 0);
return load(&pending.cat, std::move(pending.command), std::move(pending.conn), std::move(pending.access),
std::move(pending.remote), std::move(pending.data_parts), pending.callback);
std::move(pending.remote), std::move(pending.data_parts), std::get<0>(pending.callback));
}
LokiMQ::run_info& LokiMQ::run_info::load(batch_job&& bj, bool reply_job) {
LokiMQ::run_info& LokiMQ::run_info::load(batch_job&& bj, bool reply_job, int tagged_thread) {
reset();
is_batch_job = true;
is_reply_job = reply_job;
is_tagged_thread_job = tagged_thread > 0;
batch_jobno = bj.second;
batch = bj.first;
to_run = bj.first;
return *this;
}
LokiMQ::~LokiMQ() {
if (!proxy_thread.joinable())
if (!proxy_thread.joinable()) {
if (!tagged_workers.empty()) {
// This is a bit icky: we have tagged workers that are waiting for a signal on
// workers_socket, but the listening end of workers_socket doesn't get set up until the
// proxy thread starts (and we're getting destructed here without a proxy thread). So
// we need to start listening on it here in the destructor so that we establish a
// connection and send the QUITs to the tagged worker threads.
workers_socket.setsockopt<int>(ZMQ_ROUTER_MANDATORY, 1);
workers_socket.bind(SN_ADDR_WORKERS);
for (auto& [run, busy, queue] : tagged_workers) {
while (true) {
try {
route_control(workers_socket, run.worker_routing_id, "QUIT");
break;
} catch (const zmq::error_t&) {
std::this_thread::sleep_for(5ms);
}
}
}
for (auto& [run, busy, queue] : tagged_workers)
run.worker_thread.join();
}
return;
}
LMQ_LOG(info, "LokiMQ shutting down proxy thread");
detail::send_control(get_control_socket(), "QUIT");

View File

@ -29,8 +29,10 @@
#pragma once
#include <string>
#include <string_view>
#include <list>
#include <queue>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#include <memory>
@ -41,9 +43,10 @@
#include <chrono>
#include <atomic>
#include <cassert>
#include <cstdint>
#include "zmq.hpp"
#include "address.h"
#include "bt_serialize.h"
#include "string_view.h"
#include "connections.h"
#include "message.h"
#include "auth.h"
@ -69,23 +72,33 @@ template <typename R> class Batch;
* use a longer keep-alive to a host call `connect()` first with the desired keep-alive time or pass
* the send_option::keep_alive.
*/
static constexpr auto DEFAULT_SEND_KEEP_ALIVE = 30s;
inline constexpr auto DEFAULT_SEND_KEEP_ALIVE = 30s;
// The default timeout for connect_remote()
static constexpr auto REMOTE_CONNECT_TIMEOUT = 10s;
inline constexpr auto REMOTE_CONNECT_TIMEOUT = 10s;
// The amount of time we wait for a reply to a REQUEST before calling the callback with
// `false` to signal a timeout.
static constexpr auto DEFAULT_REQUEST_TIMEOUT = 15s;
inline constexpr auto DEFAULT_REQUEST_TIMEOUT = 15s;
/// Maximum length of a category
static constexpr size_t MAX_CATEGORY_LENGTH = 50;
inline constexpr size_t MAX_CATEGORY_LENGTH = 50;
/// Maximum length of a command
static constexpr size_t MAX_COMMAND_LENGTH = 200;
inline constexpr size_t MAX_COMMAND_LENGTH = 200;
class CatHelper;
/// Opaque handle for a tagged thread constructed by add_tagged_thread(...). Not directly
/// constructible, but is safe (and cheap) to copy.
struct TaggedThreadID {
private:
int _id;
explicit constexpr TaggedThreadID(int id) : _id{id} {}
friend class LokiMQ;
template <typename R> friend class Batch;
};
/**
* Class that handles LokiMQ listeners, connections, proxying, and workers. An application
* typically has just one instance of this class.
@ -146,12 +159,12 @@ public:
///
/// @returns an `AuthLevel` enum value indicating the default auth level for the incoming
/// connection, or AuthLevel::denied if the connection should be refused.
using AllowFunc = std::function<AuthLevel(string_view address, string_view pubkey, bool service_node)>;
using AllowFunc = std::function<AuthLevel(std::string_view address, std::string_view pubkey, bool service_node)>;
/// Callback that is invoked when we need to send a "strong" message to a SN that we aren't
/// already connected to and need to establish a connection. This callback returns the ZMQ
/// connection string we should use which is typically a string such as `tcp://1.2.3.4:5678`.
using SNRemoteAddress = std::function<std::string(string_view pubkey)>;
using SNRemoteAddress = std::function<std::string(std::string_view pubkey)>;
/// The callback type for registered commands.
using CommandCallback = std::function<void(Message& message)>;
@ -169,7 +182,7 @@ public:
/// Callback for the success case of connect_remote()
using ConnectSuccess = std::function<void(ConnectionID)>;
/// Callback for the failure case of connect_remote()
using ConnectFailure = std::function<void(ConnectionID, string_view)>;
using ConnectFailure = std::function<void(ConnectionID, std::string_view)>;
/// Explicitly non-copyable, non-movable because most things here aren't copyable, and a few
/// things aren't movable, either. If you need to pass the LokiMQ instance around, wrap it
@ -246,6 +259,28 @@ public:
/// Allows you to set options on the internal zmq context object. For advanced use only.
int set_zmq_context_option(int option, int value);
/** The umask to apply when constructing sockets (which affects any new ipc:// listening sockets
* that get created). Does nothing if set to -1 (the default), and does nothing on Windows.
* Note that the umask is applied temporarily during `start()`, so may affect other threads that
* create files/directories at the same time as the start() call.
*/
int STARTUP_UMASK = -1;
/** The gid that owns any sockets when constructed (same as umask)
*/
int SOCKET_GID = -1;
/** The uid that owns any sockets when constructed (same as umask but requires root)
*/
int SOCKET_UID = -1;
/// A special TaggedThreadID value that always refers to the proxy thread; the main use of this is
/// to direct very simple batch completion jobs to be executed directly in the proxy thread.
inline static constexpr TaggedThreadID run_in_proxy{-1};
/// Writes a message to the logging system; intended mostly for internal use.
template <typename... T>
void log(LogLevel lvl, const char* filename, int line, const T&... stuff);
private:
/// The lookup function that tells us where to connect to a peer, or empty if not found.
@ -258,10 +293,6 @@ private:
/// The callback to call with log messages
Logger logger;
/// Logging implementation
template <typename... T>
void log_(LogLevel lvl, const char* filename, int line, const T&... stuff);
///////////////////////////////////////////////////////////////////////////////////
/// NB: The following are all the domain of the proxy thread (once it is started)!
@ -375,7 +406,8 @@ private:
/// Timers. TODO: once cppzmq adds an interface around the zmq C timers API then switch to it.
struct TimersDeleter { void operator()(void* timers); };
std::unordered_map<int, std::tuple<std::function<void()>, bool, bool>> timer_jobs; // id => {func, squelch, running}
struct timer_data { std::function<void()> function; bool squelch; bool running; int thread; };
std::unordered_map<int, timer_data> timer_jobs;
std::unique_ptr<void, TimersDeleter> timers;
public:
// This needs to be public because we have to be able to call it from a plain C function.
@ -401,8 +433,8 @@ private:
/// Number of active workers
int active_workers() const { return workers.size() - idle_workers.size(); }
/// Worker thread loop
void worker_thread(unsigned int index);
/// Worker thread loop. Tagged and start are provided for a tagged worker thread.
void worker_thread(unsigned int index, std::optional<std::string> tagged = std::nullopt, std::function<void()> start = nullptr);
/// If set, skip polling for one proxy loop iteration (set when we know we have something
/// processible without having to shove it onto a socket, such as scheduling an internal job).
@ -458,7 +490,7 @@ private:
// provided then the connection will be curve25519 encrypted and authenticate; otherwise it will
// be unencrypted and unauthenticated. Note that the remote end must be in the same mode (i.e.
// either accepting curve connections, or not accepting curve).
void setup_outgoing_socket(zmq::socket_t& socket, string_view remote_pubkey = {});
void setup_outgoing_socket(zmq::socket_t& socket, std::string_view remote_pubkey = {});
/// Common connection implementation used by proxy_connect/proxy_send. Returns the socket and,
/// if a routing prefix is needed, the required prefix (or an empty string if not needed). For
@ -474,7 +506,7 @@ private:
/// @param keep_alive the keep alive for the connection, if we establish a new outgoing
/// connection. If we already have an outgoing connection then its keep-alive gets increased to
/// this if currently less than this.
std::pair<zmq::socket_t*, std::string> proxy_connect_sn(string_view pubkey, string_view connect_hint,
std::pair<zmq::socket_t*, std::string> proxy_connect_sn(std::string_view pubkey, std::string_view connect_hint,
bool optional, bool incoming_only, bool outgoing_only, std::chrono::milliseconds keep_alive);
/// CONNECT_SN command telling us to connect to a new pubkey. Returns the socket (which could
@ -499,7 +531,8 @@ private:
/// Currently active batch/reply jobs; this is the container that owns the Batch instances
std::unordered_set<detail::Batch*> batches;
/// Individual batch jobs waiting to run
/// Individual batch jobs waiting to run; .second is the 0-n batch number or -1 for the
/// completion job
using batch_job = std::pair<detail::Batch*, int>;
std::queue<batch_job> batch_jobs, reply_jobs;
int batch_jobs_active = 0;
@ -519,7 +552,7 @@ private:
void proxy_timer(bt_list_consumer timer_data);
/// Same, but deserialized
void proxy_timer(std::function<void()> job, std::chrono::milliseconds interval, bool squelch);
void proxy_timer(std::function<void()> job, std::chrono::milliseconds interval, bool squelch, int thread);
/// ZAP (https://rfc.zeromq.org/spec:27/ZAP/) authentication handler; this does non-blocking
/// processing of any waiting authentication requests for new incoming connections.
@ -571,31 +604,53 @@ private:
bool proxy_check_auth(size_t conn_index, bool outgoing, const peer_info& peer,
zmq::message_t& command, const cat_call_t& cat_call, std::vector<zmq::message_t>& data);
struct injected_task {
category& cat;
std::string command;
std::string remote;
std::function<void()> callback;
};
/// Injects a external callback to be handled by a worker; this is the proxy side of
/// inject_task().
void proxy_inject_task(injected_task task);
/// Set of active service nodes.
pubkey_set active_service_nodes;
/// Resets or updates the stored set of active SN pubkeys
void proxy_set_active_sns(string_view data);
void proxy_set_active_sns(std::string_view data);
void proxy_set_active_sns(pubkey_set pubkeys);
void proxy_update_active_sns(bt_list_consumer data);
void proxy_update_active_sns(pubkey_set added, pubkey_set removed);
void proxy_update_active_sns_clean(pubkey_set added, pubkey_set removed);
/// Details for a pending command; such a command already has authenticated access and is just
/// waiting for a thread to become available to handle it.
/// waiting for a thread to become available to handle it. This also gets used (via the
/// `callback` variant) for injected external jobs to be able to integrate some external
/// interface with the lokimq job queue.
struct pending_command {
category& cat;
std::string command;
std::vector<zmq::message_t> data_parts;
const std::pair<CommandCallback, bool>* callback;
std::variant<
const std::pair<CommandCallback, bool>*, // Normal command callback
std::function<void()> // Injected external callback
> callback;
ConnectionID conn;
Access access;
std::string remote;
// Normal ctor for an actual lmq command being processed
pending_command(category& cat, std::string command, std::vector<zmq::message_t> data_parts,
const std::pair<CommandCallback, bool>* callback, ConnectionID conn, Access access, std::string remote)
: cat{cat}, command{std::move(command)}, data_parts{std::move(data_parts)},
callback{callback}, conn{std::move(conn)}, access{std::move(access)}, remote{std::move(remote)} {}
// Ctor for an injected external command.
pending_command(category& cat, std::string command, std::function<void()> callback, std::string remote)
: cat{cat}, command{std::move(command)}, callback{std::move(callback)}, remote{std::move(remote)} {}
};
std::list<pending_command> pending_commands;
@ -609,9 +664,16 @@ private:
struct run_info {
bool is_batch_job = false;
bool is_reply_job = false;
bool is_tagged_thread_job = false;
bool is_injected = false;
// resets the job type bools, above.
void reset() { is_batch_job = is_reply_job = is_tagged_thread_job = is_injected = false; }
// If is_batch_job is false then these will be set appropriate (if is_batch_job is true then
// these shouldn't be accessed and likely contain stale data).
// these shouldn't be accessed and likely contain stale data). Note that if the command is
// an external, injected command then conn, access, conn_route, and data_parts will be
// empty/default constructed.
category *cat;
std::string command;
ConnectionID conn; // The connection (or SN pubkey) to reply on/to.
@ -623,24 +685,31 @@ private:
// If is_batch_job true then these are set (if is_batch_job false then don't access these!):
int batch_jobno; // >= 0 for a job, -1 for the completion job
union {
const std::pair<CommandCallback, bool>* callback; // set if !is_batch_job
detail::Batch* batch; // set if is_batch_job
};
// The callback or batch job to run. The first of these is for regular tasks, the second
// for batch jobs, the third for injected external tasks.
std::variant<
const std::pair<CommandCallback, bool>*,
detail::Batch*,
std::function<void()>
> to_run;
// These belong to the proxy thread and must not be accessed by a worker:
std::thread worker_thread;
size_t worker_id; // The index in `workers`
std::string worker_routing_id; // "w123" where 123 == worker_id
size_t worker_id; // The index in `workers` (0-n) or index+1 in `tagged_workers` (1-n)
std::string worker_routing_id; // "w123" where 123 == worker_id; "n123" for tagged threads.
/// Loads the run info with an incoming command
run_info& load(category* cat, std::string command, ConnectionID conn, Access access, std::string remote,
std::vector<zmq::message_t> data_parts, const std::pair<CommandCallback, bool>* callback);
/// Loads the run info with an injected external command
run_info& load(category* cat, std::string command, std::string remote, std::function<void()> callback);
/// Loads the run info with a stored pending command
run_info& load(pending_command&& pending);
/// Loads the run info with a batch job
run_info& load(batch_job&& bj, bool reply_job = false);
run_info& load(batch_job&& bj, bool reply_job = false, int tagged_thread = 0);
};
/// Data passed to workers for the RUN command. The proxy thread sets elements in this before
/// sending RUN to a worker then the worker uses it to get call info, and only allocates it
@ -648,6 +717,11 @@ private:
/// change it.
std::vector<run_info> workers;
/// Workers that are reserved for tagged thread tasks (as created with add_tagged_thread). The
/// queue here is similar to worker_jobs, but contains only the tagged thread's jobs. The bool
/// is whether the worker is currently busy (true) or available (false).
std::vector<std::tuple<run_info, bool, std::queue<batch_job>>> tagged_workers;
public:
/**
* LokiMQ constructor. This constructs the object but does not start it; you will typically
@ -786,6 +860,28 @@ public:
*/
void add_command_alias(std::string from, std::string to);
/** Creates a "tagged thread" and starts it immediately. A tagged thread is one that batches,
* jobs, and timer jobs can be sent to specifically, typically to perform coordination of some
* thread-unsafe work.
*
* Tagged threads will *only* process jobs sent specifically to them; they do not participate in
* the thread pool used for regular jobs. Each tagged thread also has its own job queue
* completely separate from any other jobs.
*
* Tagged threads must be created *before* `start()` is called. The name will be used to set the
* thread name in the process table (if supported on the OS).
*
* \param name - the name of the thread; will be used in log messages and (if supported by the
* OS) as the system thread name.
*
* \param start - an optional callback to invoke from the thread as soon as LokiMQ itself starts
* up (i.e. after a call to `start()`).
*
* \returns a TaggedThreadID object that can be passed to job(), batch(), or add_timer() to
* direct the task to the tagged thread.
*/
TaggedThreadID add_tagged_thread(std::string name, std::function<void()> start = nullptr);
/**
* Sets the number of worker threads reserved for batch jobs. If not explicitly called then
* this defaults to half the general worker threads configured (rounded up). This works exactly
@ -895,7 +991,7 @@ public:
* *don't* need to worry about this (and can just discard it): you can always simply pass the
* pubkey as a string wherever a ConnectionID is called.
*/
ConnectionID connect_sn(string_view pubkey, std::chrono::milliseconds keep_alive = 5min, string_view hint = {});
ConnectionID connect_sn(std::string_view pubkey, std::chrono::milliseconds keep_alive = 5min, std::string_view hint = {});
/**
* Establish a connection to the given remote with callbacks invoked on a successful or failed
@ -912,12 +1008,11 @@ public:
* The `on_connect` and `on_failure` callbacks are invoked when a connection has been
* established or failed to establish.
*
* @param remote the remote connection address, such as `tcp://localhost:1234`.
* @param remote the remote connection address either as implicitly from a string or as a full
* lokimq::address object; see address.h for details. This specifies both the connection
* address and whether curve encryption should be used.
* @param on_connect called with the identifier after the connection has been established.
* @param on_failure called with the identifier and failure message if we fail to connect.
* @param pubkey if non-empty then connect securely (using curve encryption) and verify that the
* remote's pubkey equals the given value. Specifying this is similar to using connect_sn()
* except that we do not treat the remote as a SN for command authorization purposes.
* @param auth_level determines the authentication level of the remote for issuing commands to
* us. The default is `AuthLevel::none`.
* @param timeout how long to try before aborting the connection attempt and calling the
@ -927,8 +1022,23 @@ public:
* @param returns ConnectionID that uniquely identifies the connection to this remote node. In
* order to talk to it you will need the returned value (or a copy of it).
*/
ConnectionID connect_remote(string_view remote, ConnectSuccess on_connect, ConnectFailure on_failure,
string_view pubkey = {},
ConnectionID connect_remote(const address& remote, ConnectSuccess on_connect, ConnectFailure on_failure,
AuthLevel auth_level = AuthLevel::none, std::chrono::milliseconds timeout = REMOTE_CONNECT_TIMEOUT);
/// Same as the above, but takes the address as a string_view and constructs an `address` from
/// it.
ConnectionID connect_remote(std::string_view remote, ConnectSuccess on_connect, ConnectFailure on_failure,
AuthLevel auth_level = AuthLevel::none, std::chrono::milliseconds timeout = REMOTE_CONNECT_TIMEOUT) {
return connect_remote(address{remote}, std::move(on_connect), std::move(on_failure), auth_level, timeout);
}
/// Deprecated version of the above that takes the remote address and remote pubkey for curve
/// encryption as separate arguments. New code should either use a pubkey-embedded address
/// string, or specify remote address and pubkey with an `address` object such as:
/// connect_remote(address{remote, pubkey}, ...)
[[deprecated("use connect_remote() with a lokimq::address instead")]]
ConnectionID connect_remote(std::string_view remote, ConnectSuccess on_connect, ConnectFailure on_failure,
std::string_view pubkey,
AuthLevel auth_level = AuthLevel::none,
std::chrono::milliseconds timeout = REMOTE_CONNECT_TIMEOUT);
@ -987,7 +1097,7 @@ public:
* connection hint may be used rather than performing a connection address lookup on the pubkey.
*/
template <typename... T>
void send(ConnectionID to, string_view cmd, const T&... opts);
void send(ConnectionID to, std::string_view cmd, const T&... opts);
/** Send a command configured as a "REQUEST" command to a service node: the data parts will be
* prefixed with a random identifier. The remote is expected to reply with a ["REPLY",
@ -1019,7 +1129,33 @@ public:
* not running as a service node.
*/
template <typename... T>
void request(ConnectionID to, string_view cmd, ReplyCallback callback, const T&... opts);
void request(ConnectionID to, std::string_view cmd, ReplyCallback callback, const T&... opts);
/** Injects an external task into the lokimq command queue. This is used to allow connecting
* non-LokiMQ requests into the LokiMQ thread pool as if they were ordinary requests, to be
* scheduled as commands of an individual category. For example, you might support rpc requests
* via LokiMQ as `rpc.some_command` and *also* accept them over HTTP. Using `inject_task()`
* allows you to handle processing the request in the same thread pool with the same priority as
* `rpc.*` commands.
*
* @param category - the category name that should handle the request for the purposes of
* scheduling the job. The category must have been added using add_category(). The category
* can be an actual category with added commands, in which case the injected tasks are queued
* along with LMQ requests for that category, or can have no commands to set up a distinct
* category for the injected jobs.
*
* @param command - a command name; this is mainly used for debugging and does not need to
* actually exist (and, in fact, is often less confusing if it does not). It is recommended for
* clarity purposes to use something that doesn't look like a typical command, for example
* "(http)".
*
* @param remote - some free-form identifier of the remote connection. For example, this could
* be a remote IP address. Can be blank if there is nothing suitable.
*
* @param callback - the function to call from a worker thread when the injected task is
* processed. Takes no arguments.
*/
void inject_task(const std::string& category, std::string command, std::string remote, std::function<void()> callback);
/// The key pair this LokiMQ was created with; if empty keys were given during construction then
/// this returns the generated keys.
@ -1072,8 +1208,12 @@ public:
/**
* Queues a single job to be executed with no return value. This is a shortcut for creating and
* submitting a single-job, no-completion-function batch job.
*
* \param f the callback to invoke
* \param thread an optional tagged thread in which this job should run. You may *not* pass the
* proxy thread here.
*/
void job(std::function<void()> f);
void job(std::function<void()> f, std::optional<TaggedThreadID> = std::nullopt);
/**
* Adds a timer that gets scheduled periodically in the job queue. Normally jobs are not
@ -1081,8 +1221,10 @@ public:
* previously scheduled callback of the job has not yet completed. If you want to override this
* (so that, under heavy load or long jobs, there can be more than one of the same job scheduled
* or running at a time) then specify `squelch` as `false`.
*
* \param thread specifies a thread (added with add_tagged_thread()) on which this timer must run.
*/
void add_timer(std::function<void()> job, std::chrono::milliseconds interval, bool squelch = true);
void add_timer(std::function<void()> job, std::chrono::milliseconds interval, bool squelch = true, std::optional<TaggedThreadID> = std::nullopt);
};
/// Helper class that slightly simplifies adding commands to a category.
@ -1129,9 +1271,10 @@ struct data_parts_impl {
data_parts_impl(InputIt begin, InputIt end) : begin{std::move(begin)}, end{std::move(end)} {}
};
/// Specifies an iterator pair of data options to send, for when the number of arguments to send()
/// cannot be determined at compile time.
template <typename InputIt>
/// Specifies an iterator pair of data parts to send, for when the number of arguments to send()
/// cannot be determined at compile time. The iterator pair must be over strings or string_view (or
/// something convertible to a string_view).
template <typename InputIt, typename = std::enable_if_t<std::is_convertible_v<decltype(*std::declval<InputIt>()), std::string_view>>>
data_parts_impl<InputIt> data_parts(InputIt begin, InputIt end) { return {std::move(begin), std::move(end)}; }
/// Specifies a connection hint when passed in to send(). If there is no current connection to the
@ -1250,10 +1393,10 @@ template <typename T> T deserialize_object(uintptr_t ptrval) {
// Sends a control message to the given socket consisting of the command plus optional dict
// data (only sent if the data is non-empty).
void send_control(zmq::socket_t& sock, string_view cmd, std::string data = {});
void send_control(zmq::socket_t& sock, std::string_view cmd, std::string data = {});
/// Base case: takes a string-like value and appends it to the message parts
inline void apply_send_option(bt_list& parts, bt_dict&, string_view arg) {
inline void apply_send_option(bt_list& parts, bt_dict&, std::string_view arg) {
parts.emplace_back(arg);
}
@ -1261,7 +1404,7 @@ inline void apply_send_option(bt_list& parts, bt_dict&, string_view arg) {
template <typename InputIt>
void apply_send_option(bt_list& parts, bt_dict&, const send_option::data_parts_impl<InputIt> data) {
for (auto it = data.begin; it != data.end; ++it)
parts.push_back(lokimq::bt_deserialize(*it));
parts.emplace_back(*it);
}
/// `hint` specialization: sets the hint in the control data
@ -1308,14 +1451,10 @@ inline void apply_send_option(bt_list&, bt_dict& control_data, send_option::queu
std::pair<std::string, AuthLevel> extract_metadata(zmq::message_t& msg);
template <typename... T>
bt_dict build_send(ConnectionID to, string_view cmd, T&&... opts) {
bt_dict build_send(ConnectionID to, std::string_view cmd, T&&... opts) {
bt_dict control_data;
bt_list parts{{cmd}};
#ifdef __cpp_fold_expressions
(detail::apply_send_option(parts, control_data, std::forward<T>(opts)),...);
#else
(void) std::initializer_list<int>{(detail::apply_send_option(parts, control_data, std::forward<T>(opts)), 0)...};
#endif
if (to.sn())
control_data["conn_pubkey"] = std::move(to.pk);
@ -1332,7 +1471,7 @@ bt_dict build_send(ConnectionID to, string_view cmd, T&&... opts) {
template <typename... T>
void LokiMQ::send(ConnectionID to, string_view cmd, const T&... opts) {
void LokiMQ::send(ConnectionID to, std::string_view cmd, const T&... opts) {
detail::send_control(get_control_socket(), "SEND",
bt_serialize(detail::build_send(std::move(to), cmd, opts...)));
}
@ -1340,17 +1479,17 @@ void LokiMQ::send(ConnectionID to, string_view cmd, const T&... opts) {
std::string make_random_string(size_t size);
template <typename... T>
void LokiMQ::request(ConnectionID to, string_view cmd, ReplyCallback callback, const T &...opts) {
void LokiMQ::request(ConnectionID to, std::string_view cmd, ReplyCallback callback, const T &...opts) {
const auto reply_tag = make_random_string(15); // 15 random bytes is lots and should keep us in most stl implementations' small string optimization
bt_dict control_data = detail::build_send(std::move(to), cmd, reply_tag, opts...);
control_data["request"] = true;
control_data["request_callback"] = detail::serialize_object(std::move(callback));
control_data["request_tag"] = string_view{reply_tag};
control_data["request_tag"] = std::string_view{reply_tag};
detail::send_control(get_control_socket(), "SEND", bt_serialize(std::move(control_data)));
}
template <typename... Args>
void Message::send_back(string_view command, Args&&... args) {
void Message::send_back(std::string_view command, Args&&... args) {
lokimq.send(conn, command, send_option::optional{!conn.sn()}, std::forward<Args>(args)...);
}
@ -1361,14 +1500,14 @@ void Message::send_reply(Args&&... args) {
}
template <typename Callback, typename... Args>
void Message::send_request(string_view cmd, Callback&& callback, Args&&... args) {
void Message::send_request(std::string_view cmd, Callback&& callback, Args&&... args) {
lokimq.request(conn, cmd, std::forward<Callback>(callback),
send_option::optional{!conn.sn()}, std::forward<Args>(args)...);
}
// When log messages are invoked we strip out anything before this in the filename:
constexpr string_view LOG_PREFIX{"lokimq/", 7};
inline string_view trim_log_filename(string_view local_file) {
constexpr std::string_view LOG_PREFIX{"lokimq/", 7};
inline std::string_view trim_log_filename(std::string_view local_file) {
auto chop = local_file.rfind(LOG_PREFIX);
if (chop != local_file.npos)
local_file.remove_prefix(chop);
@ -1376,16 +1515,12 @@ inline string_view trim_log_filename(string_view local_file) {
}
template <typename... T>
void LokiMQ::log_(LogLevel lvl, const char* file, int line, const T&... stuff) {
void LokiMQ::log(LogLevel lvl, const char* file, int line, const T&... stuff) {
if (log_level() < lvl)
return;
std::ostringstream os;
#ifdef __cpp_fold_expressions
(os << ... << stuff);
#else
(void) std::initializer_list<int>{(os << stuff, 0)...};
#endif
logger(lvl, trim_log_filename(file).data(), line, os.str());
}

View File

@ -12,7 +12,7 @@ class LokiMQ;
class Message {
public:
LokiMQ& lokimq; ///< The owning LokiMQ object
std::vector<string_view> data; ///< The provided command data parts, if any.
std::vector<std::string_view> data; ///< The provided command data parts, if any.
ConnectionID conn; ///< The connection info for routing a reply; also contains the pubkey/sn status.
std::string reply_tag; ///< If the invoked command is a request command this is the required reply tag that will be prepended by `send_reply()`.
Access access; ///< The access level of the invoker. This can be higher than the access level of the command, for example for an admin invoking a basic command.
@ -36,7 +36,7 @@ public:
/// If you want to send a non-strong reply even when the remote is a service node then add
/// an explicit `send_option::optional()` argument.
template <typename... Args>
void send_back(string_view, Args&&... args);
void send_back(std::string_view, Args&&... args);
/// Sends a reply to a request. This takes no command: the command is always the built-in
/// "REPLY" command, followed by the unique reply tag, then any reply data parts. All other
@ -51,7 +51,7 @@ public:
/// Sends a request back to whomever sent this message. This is effectively a wrapper around
/// lmq.request() that takes care of setting up the recipient arguments.
template <typename ReplyCallback, typename... Args>
void send_request(string_view cmd, ReplyCallback&& callback, Args&&... args);
void send_request(std::string_view cmd, ReplyCallback&& callback, Args&&... args);
};
}

View File

@ -9,6 +9,13 @@ extern "C" {
}
#endif
#ifndef _WIN32
extern "C" {
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
}
#endif
namespace lokimq {
@ -16,11 +23,12 @@ void LokiMQ::proxy_quit() {
LMQ_LOG(debug, "Received quit command, shutting down proxy thread");
assert(std::none_of(workers.begin(), workers.end(), [](auto& worker) { return worker.worker_thread.joinable(); }));
assert(std::none_of(tagged_workers.begin(), tagged_workers.end(), [](auto& worker) { return std::get<0>(worker).worker_thread.joinable(); }));
command.setsockopt<int>(ZMQ_LINGER, 0);
command.close();
{
std::lock_guard<std::mutex> lock{control_sockets_mutex};
std::lock_guard lock{control_sockets_mutex};
for (auto &control : thread_control_sockets)
control->close();
proxy_shutting_down = true; // To prevent threads from opening new control sockets
@ -37,7 +45,7 @@ void LokiMQ::proxy_quit() {
void LokiMQ::proxy_send(bt_dict_consumer data) {
// NB: bt_dict_consumer goes in alphabetical order
string_view hint;
std::string_view hint;
std::chrono::milliseconds keep_alive{DEFAULT_SEND_KEEP_ALIVE};
std::chrono::milliseconds request_timeout{DEFAULT_REQUEST_TIMEOUT};
bool optional = false;
@ -106,7 +114,6 @@ void LokiMQ::proxy_send(bt_dict_consumer data) {
// connections open to that SN (e.g. one out + one in) so if one fails we can clean up that
// connection and try the next one.
bool retry = true, sent = false, nowarn = false;
std::unique_ptr<zmq::error_t> send_error;
while (retry) {
retry = false;
zmq::socket_t *send_to;
@ -265,6 +272,9 @@ void LokiMQ::proxy_control_message(std::vector<zmq::message_t>& parts) {
LMQ_TRACE("proxy batch jobs");
auto ptrval = bt_deserialize<uintptr_t>(data);
return proxy_batch(reinterpret_cast<detail::Batch*>(ptrval));
} else if (cmd == "INJECT") {
LMQ_TRACE("proxy inject");
return proxy_inject_task(detail::deserialize_object<injected_task>(bt_deserialize<uintptr_t>(data)));
} else if (cmd == "SET_SNS") {
return proxy_set_active_sns(data);
} else if (cmd == "UPDATE_SNS") {
@ -292,6 +302,9 @@ void LokiMQ::proxy_control_message(std::vector<zmq::message_t>& parts) {
for (const auto &route : idle_workers)
route_control(workers_socket, workers[route].worker_routing_id, "QUIT");
idle_workers.clear();
for (auto& [run, busy, queue] : tagged_workers)
if (!busy)
route_control(workers_socket, run.worker_routing_id, "QUIT");
return;
}
}
@ -332,12 +345,19 @@ void LokiMQ::proxy_loop() {
LMQ_LOG(debug, " - ", cat.first, ": ", cat.second.reserved_threads);
LMQ_LOG(debug, " - (batch jobs): ", batch_jobs_reserved);
LMQ_LOG(debug, " - (reply jobs): ", reply_jobs_reserved);
LMQ_LOG(debug, "Plus ", tagged_workers.size(), " tagged worker threads");
}
workers.reserve(max_workers);
if (!workers.empty())
throw std::logic_error("Internal error: proxy thread started with active worker threads");
#ifndef _WIN32
int saved_umask = -1;
if (STARTUP_UMASK >= 0)
saved_umask = umask(STARTUP_UMASK);
#endif
for (size_t i = 0; i < bind.size(); i++) {
auto& b = bind[i].second;
zmq::socket_t listener{context, zmq::socket_type::router};
@ -362,6 +382,24 @@ void LokiMQ::proxy_loop() {
incoming_conn_index[conn_id] = connections.size() - 1;
b.index = connections.size() - 1;
}
#ifndef _WIN32
if (saved_umask != -1)
umask(saved_umask);
// set socket gid / uid if it is provided
if (SOCKET_GID != -1 or SOCKET_UID != -1) {
for(size_t i = 0; i < bind.size(); i++) {
const address addr(bind[i].first);
if(addr.ipc()) {
if(chown(addr.socket.c_str(), SOCKET_UID, SOCKET_GID) == -1) {
throw std::runtime_error("cannot set group on " + addr.socket + ": " + strerror(errno));
}
}
}
}
#endif
pollitems_stale = true;
// Also add an internal connection to self so that calling code can avoid needing to
@ -385,10 +423,38 @@ void LokiMQ::proxy_loop() {
std::vector<zmq::message_t> parts;
// Wait for tagged worker threads to get ready and connect to us (we get a "STARTING" message)
// and send them back a "START" to let them know to go ahead with startup. We need this
// synchronization dance to guarantee that the workers are routable before we can proceed.
if (!tagged_workers.empty()) {
LMQ_LOG(debug, "Waiting for tagged workers");
std::unordered_set<std::string_view> waiting_on;
for (auto& w : tagged_workers)
waiting_on.emplace(std::get<run_info>(w).worker_routing_id);
for (; !waiting_on.empty(); parts.clear()) {
recv_message_parts(workers_socket, parts);
if (parts.size() != 2 || view(parts[1]) != "STARTING"sv) {
LMQ_LOG(error, "Received invalid message on worker socket while waiting for tagged thread startup");
continue;
}
LMQ_LOG(debug, "Received STARTING message from ", view(parts[0]));
if (auto it = waiting_on.find(view(parts[0])); it != waiting_on.end())
waiting_on.erase(it);
else
LMQ_LOG(error, "Received STARTING message from unknown worker ", view(parts[0]));
}
for (auto&w : tagged_workers) {
LMQ_LOG(debug, "Telling tagged thread worker ", std::get<run_info>(w).worker_routing_id, " to finish startup");
route_control(workers_socket, std::get<run_info>(w).worker_routing_id, "START");
}
}
while (true) {
std::chrono::milliseconds poll_timeout;
if (max_workers == 0) { // Will be 0 only if we are quitting
if (std::none_of(workers.begin(), workers.end(), [](auto &w) { return w.worker_thread.joinable(); })) {
if (std::none_of(workers.begin(), workers.end(), [](auto &w) { return w.worker_thread.joinable(); }) &&
std::none_of(tagged_workers.begin(), tagged_workers.end(), [](auto &w) { return std::get<0>(w).worker_thread.joinable(); })) {
// All the workers have finished, so we can finish shutting down
return proxy_quit();
}
@ -478,7 +544,7 @@ void LokiMQ::proxy_loop() {
}
}
static bool is_error_response(string_view cmd) {
static bool is_error_response(std::string_view cmd) {
return cmd == "FORBIDDEN" || cmd == "FORBIDDEN_SN" || cmd == "NOT_A_SERVICE_NODE" || cmd == "UNKNOWNCOMMAND" || cmd == "NO_REPLY_TAG";
}
@ -488,7 +554,7 @@ bool LokiMQ::proxy_handle_builtin(size_t conn_index, std::vector<zmq::message_t>
// Doubling as a bool and an offset:
size_t incoming = connections[conn_index].getsockopt<int>(ZMQ_TYPE) == ZMQ_ROUTER;
string_view route, cmd;
std::string_view route, cmd;
if (parts.size() < 1 + incoming) {
LMQ_LOG(warn, "Received empty message; ignoring");
return true;
@ -598,7 +664,7 @@ bool LokiMQ::proxy_handle_builtin(size_t conn_index, std::vector<zmq::message_t>
LMQ_LOG(warn, "Received REPLY with unknown or already handled reply tag (", to_hex(reply_tag), "); ignoring");
}
} else {
LMQ_LOG(warn, "Received ", cmd, ':', (parts.size() > 1 + incoming ? view(parts[1 + incoming]) : "(unknown command)"_sv),
LMQ_LOG(warn, "Received ", cmd, ':', (parts.size() > 1 + incoming ? view(parts[1 + incoming]) : "(unknown command)"sv),
" from ", peer_address(parts.back()));
}
return true;
@ -607,7 +673,19 @@ bool LokiMQ::proxy_handle_builtin(size_t conn_index, std::vector<zmq::message_t>
}
void LokiMQ::proxy_process_queue() {
// First up: process any batch jobs; since these are internal they are given higher priority.
if (max_workers == 0) // shutting down
return;
// First: send any tagged thread tasks to the tagged threads, if idle
for (auto& [run, busy, queue] : tagged_workers) {
if (!busy && !queue.empty()) {
busy = true;
proxy_run_worker(run.load(std::move(queue.front()), false, run.worker_id));
queue.pop();
}
}
// Second: process any batch jobs; since these are internal they are given higher priority.
proxy_run_batch_jobs(batch_jobs, batch_jobs_reserved, batch_jobs_active, false);
// Next any reply batch jobs (which are a bit different from the above, since they are
@ -630,5 +708,4 @@ void LokiMQ::proxy_process_queue() {
}
}
}

View File

@ -1,310 +1,15 @@
// Copyright (c) 2019-2020, The Loki Project
//
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without modification, are
// permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this list of
// conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice, this list
// of conditions and the following disclaimer in the documentation and/or other
// materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its contributors may be
// used to endorse or promote products derived from this software without specific
// prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL
// THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF
// THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include <string>
#ifdef __cpp_lib_string_view
#include <string_view>
namespace lokimq {
// Deprecated type alias for std::string_view
using string_view = std::string_view;
using ustring_view = std::basic_string_view<unsigned char>;
}
#else
#include <ostream>
#include <limits>
namespace lokimq {
/// Basic implementation of std::string_view (except for std::hash support).
template <typename CharT>
class simple_string_view {
const CharT *data_;
size_t size_;
public:
using traits_type = std::char_traits<CharT>;
using value_type = CharT;
using pointer = CharT*;
using const_pointer = const CharT*;
using reference = CharT&;
using const_reference = const CharT&;
using const_iterator = const_pointer;
using iterator = const_iterator;
using const_reverse_iterator = std::reverse_iterator<const_iterator>;
using reverse_iterator = const_reverse_iterator;
using size_type = std::size_t;
using different_type = std::ptrdiff_t;
static constexpr auto& npos = std::string::npos;
constexpr simple_string_view() noexcept : data_{nullptr}, size_{0} {}
constexpr simple_string_view(const simple_string_view&) noexcept = default;
simple_string_view(const std::basic_string<CharT>& str) : data_{str.data()}, size_{str.size()} {}
constexpr simple_string_view(const CharT* data, size_t size) noexcept : data_{data}, size_{size} {}
simple_string_view(const CharT* data) : data_{data}, size_{traits_type::length(data)} {}
simple_string_view& operator=(const simple_string_view&) = default;
constexpr const CharT* data() const noexcept { return data_; }
constexpr size_t size() const noexcept { return size_; }
constexpr size_t length() const noexcept { return size_; }
constexpr size_t max_size() const noexcept { return std::numeric_limits<size_t>::max(); }
constexpr bool empty() const noexcept { return size_ == 0; }
explicit operator std::basic_string<CharT>() const { return {data_, size_}; }
constexpr const CharT* begin() const noexcept { return data_; }
constexpr const CharT* cbegin() const noexcept { return data_; }
constexpr const CharT* end() const noexcept { return data_ + size_; }
constexpr const CharT* cend() const noexcept { return data_ + size_; }
reverse_iterator rbegin() const { return reverse_iterator{end()}; }
reverse_iterator crbegin() const { return reverse_iterator{end()}; }
reverse_iterator rend() const { return reverse_iterator{begin()}; }
reverse_iterator crend() const { return reverse_iterator{begin()}; }
constexpr const CharT& operator[](size_t pos) const { return data_[pos]; }
constexpr const CharT& front() const { return *data_; }
constexpr const CharT& back() const { return data_[size_ - 1]; }
int compare(simple_string_view s) const;
constexpr void remove_prefix(size_t n) { data_ += n; size_ -= n; }
constexpr void remove_suffix(size_t n) { size_ -= n; }
void swap(simple_string_view &s) noexcept { std::swap(data_, s.data_); std::swap(size_, s.size_); }
#if defined(__clang__) || !defined(__GNUG__) || __GNUC__ >= 6
constexpr // GCC 5.x is buggy wrt constexpr throwing
#endif
const CharT& at(size_t pos) const {
if (pos >= size())
throw std::out_of_range{"invalid string_view index"};
return data_[pos];
};
size_t copy(CharT* dest, size_t count, size_t pos = 0) const {
if (pos > size()) throw std::out_of_range{"invalid copy pos"};
size_t rcount = std::min(count, size_ - pos);
traits_type::copy(dest, data_ + pos, rcount);
return rcount;
}
#if defined(__clang__) || !defined(__GNUG__) || __GNUC__ >= 6
constexpr // GCC 5.x is buggy wrt constexpr throwing
#endif
simple_string_view substr(size_t pos = 0, size_t count = npos) const {
if (pos > size()) throw std::out_of_range{"invalid substr range"};
simple_string_view result = *this;
if (pos > 0) result.remove_prefix(pos);
if (count < result.size()) result.remove_suffix(result.size() - count);
return result;
}
size_t find(simple_string_view v, size_t pos = 0) const {
if (pos > size_ || v.size_ > size_) return npos;
for (const size_t max_pos = size_ - v.size_; pos <= max_pos; ++pos) {
if (0 == traits_type::compare(v.data_, data_ + pos, v.size_))
return pos;
}
return npos;
}
size_t find(CharT c, size_t pos = 0) const { return find({&c, 1}, pos); }
size_t find(const CharT* c, size_t pos, size_t count) const { return find({c, count}, pos); }
size_t find(const CharT* c, size_t pos = 0) const { return find(simple_string_view(c), pos); }
size_t rfind(simple_string_view v, size_t pos = npos) const {
if (v.size_ > size_) return npos;
const size_t max_pos = size_ - v.size_;
for (pos = std::min(pos, max_pos); pos <= max_pos; --pos) {
if (0 == traits_type::compare(v.data_, data_ + pos, v.size_))
return pos;
}
return npos;
}
size_t rfind(CharT c, size_t pos = npos) const { return rfind({&c, 1}, pos); }
size_t rfind(const CharT* c, size_t pos, size_t count) const { return rfind({c, count}, pos); }
size_t rfind(const CharT* c, size_t pos = npos) const { return rfind(simple_string_view(c), pos); }
constexpr size_t find_first_of(simple_string_view v, size_t pos = 0) const noexcept {
for (; pos < size_; ++pos)
for (CharT c : v)
if (data_[pos] == c)
return pos;
return npos;
}
constexpr size_t find_first_of(CharT c, size_t pos = 0) const noexcept { return find_first_of({&c, 1}, pos); }
constexpr size_t find_first_of(const CharT* c, size_t pos, size_t count) const { return find_first_of({c, count}, pos); }
size_t find_first_of(const CharT* c, size_t pos = 0) const { return find_first_of(simple_string_view(c), pos); }
constexpr size_t find_last_of(simple_string_view v, const size_t pos = npos) const noexcept {
if (size_ == 0) return npos;
const size_t last_pos = std::min(pos, size_-1);
for (size_t i = last_pos; i <= last_pos; --i)
for (CharT c : v)
if (data_[i] == c)
return i;
return npos;
}
constexpr size_t find_last_of(CharT c, size_t pos = npos) const noexcept { return find_last_of({&c, 1}, pos); }
constexpr size_t find_last_of(const CharT* c, size_t pos, size_t count) const { return find_last_of({c, count}, pos); }
size_t find_last_of(const CharT* c, size_t pos = npos) const { return find_last_of(simple_string_view(c), pos); }
constexpr size_t find_first_not_of(simple_string_view v, size_t pos = 0) const noexcept {
for (; pos < size_; ++pos) {
bool none = true;
for (CharT c : v) {
if (data_[pos] == c) {
none = false;
break;
}
}
if (none) return pos;
}
return npos;
}
constexpr size_t find_first_not_of(CharT c, size_t pos = 0) const noexcept { return find_first_not_of({&c, 1}, pos); }
constexpr size_t find_first_not_of(const CharT* c, size_t pos, size_t count) const { return find_first_not_of({c, count}, pos); }
size_t find_first_not_of(const CharT* c, size_t pos = 0) const { return find_first_not_of(simple_string_view(c), pos); }
constexpr size_t find_last_not_of(simple_string_view v, const size_t pos = npos) const noexcept {
if (size_ == 0) return npos;
const size_t last_pos = std::min(pos, size_-1);
for (size_t i = last_pos; i <= last_pos; --i) {
bool none = true;
for (CharT c : v) {
if (data_[i] == c) {
none = false;
break;
}
}
if (none) return i;
}
return npos;
}
constexpr size_t find_last_not_of(CharT c, size_t pos = npos) const noexcept { return find_last_not_of({&c, 1}, pos); }
constexpr size_t find_last_not_of(const CharT* c, size_t pos, size_t count) const { return find_last_not_of({c, count}, pos); }
size_t find_last_not_of(const CharT* c, size_t pos = npos) const { return find_last_not_of(simple_string_view(c), pos); }
};
/// We have three of each of these: one with two string views, one with RHS argument deduction, and
/// one with LHS argument deduction, so that you can do (sv == sv), (sv == "foo"), and ("foo" == sv)
template <typename CharT>
inline bool operator==(simple_string_view<CharT> lhs, simple_string_view<CharT> rhs) {
return lhs.size() == rhs.size() && 0 == std::char_traits<CharT>::compare(lhs.data(), rhs.data(), lhs.size());
};
template <typename CharT>
inline bool operator==(simple_string_view<CharT> lhs, std::common_type_t<simple_string_view<CharT>> rhs) {
return lhs.size() == rhs.size() && 0 == std::char_traits<CharT>::compare(lhs.data(), rhs.data(), lhs.size());
};
template <typename CharT>
inline bool operator==(std::common_type_t<simple_string_view<CharT>> lhs, simple_string_view<CharT> rhs) {
return lhs.size() == rhs.size() && 0 == std::char_traits<CharT>::compare(lhs.data(), rhs.data(), lhs.size());
};
template <typename CharT>
inline bool operator!=(simple_string_view<CharT> lhs, simple_string_view<CharT> rhs) {
return !(lhs == rhs);
}
template <typename CharT>
inline bool operator!=(simple_string_view<CharT> lhs, std::common_type_t<simple_string_view<CharT>> rhs) {
return !(lhs == rhs);
}
template <typename CharT>
inline bool operator!=(std::common_type_t<simple_string_view<CharT>> lhs, simple_string_view<CharT> rhs) {
return !(lhs == rhs);
}
template <typename CharT>
inline int simple_string_view<CharT>::compare(simple_string_view s) const {
int cmp = std::char_traits<CharT>::compare(data_, s.data(), std::min(size_, s.size()));
if (cmp) return cmp;
if (size_ < s.size()) return -1;
else if (size_ > s.size()) return 1;
return 0;
}
template <typename CharT>
inline bool operator<(simple_string_view<CharT> lhs, simple_string_view<CharT> rhs) {
return lhs.compare(rhs) < 0;
};
template <typename CharT>
inline bool operator<(simple_string_view<CharT> lhs, std::common_type_t<simple_string_view<CharT>> rhs) {
return lhs.compare(rhs) < 0;
};
template <typename CharT>
inline bool operator<(std::common_type_t<simple_string_view<CharT>> lhs, simple_string_view<CharT> rhs) {
return lhs.compare(rhs) < 0;
};
template <typename CharT>
inline bool operator<=(simple_string_view<CharT> lhs, simple_string_view<CharT> rhs) {
return lhs.compare(rhs) <= 0;
};
template <typename CharT>
inline bool operator<=(simple_string_view<CharT> lhs, std::common_type_t<simple_string_view<CharT>> rhs) {
return lhs.compare(rhs) <= 0;
};
template <typename CharT>
inline bool operator<=(std::common_type_t<simple_string_view<CharT>> lhs, simple_string_view<CharT> rhs) {
return lhs.compare(rhs) <= 0;
};
template <typename CharT>
inline bool operator>(simple_string_view<CharT> lhs, simple_string_view<CharT> rhs) {
return lhs.compare(rhs) > 0;
};
template <typename CharT>
inline bool operator>(simple_string_view<CharT> lhs, std::common_type_t<simple_string_view<CharT>> rhs) {
return lhs.compare(rhs) > 0;
};
template <typename CharT>
inline bool operator>(std::common_type_t<simple_string_view<CharT>> lhs, simple_string_view<CharT> rhs) {
return lhs.compare(rhs) > 0;
};
template <typename CharT>
inline bool operator>=(simple_string_view<CharT> lhs, simple_string_view<CharT> rhs) {
return lhs.compare(rhs) >= 0;
};
template <typename CharT>
inline bool operator>=(simple_string_view<CharT> lhs, std::common_type_t<simple_string_view<CharT>> rhs) {
return lhs.compare(rhs) >= 0;
};
template <typename CharT>
inline bool operator>=(std::common_type_t<simple_string_view<CharT>> lhs, simple_string_view<CharT> rhs) {
return lhs.compare(rhs) >= 0;
};
template <typename CharT>
inline std::basic_ostream<CharT>& operator<<(std::basic_ostream<CharT>& os, const simple_string_view<CharT>& s) {
os.write(s.data(), s.size());
return os;
}
using string_view = simple_string_view<char>;
using ustring_view = simple_string_view<unsigned char>;
}
#endif
// Add a "foo"_sv literal that works exactly like the C++17 "foo"sv literal, but works with our
// implementation in pre-C++17.
namespace lokimq {
// Deprecated "foo"_sv literal; you should use "foo"sv (from <string_view>) instead.
inline namespace literals {
inline constexpr string_view operator""_sv(const char* str, size_t len) { return {str, len}; }
inline constexpr std::string_view operator""_sv(const char* str, size_t len) { return {str, len}; }
}
}

View File

@ -1,6 +1,5 @@
#include "lokimq.h"
#include "batch.h"
#include "hex.h"
#include "lokimq-internal.h"
#if defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__)
@ -12,36 +11,107 @@ extern "C" {
namespace lokimq {
void LokiMQ::worker_thread(unsigned int index) {
std::string worker_id = "w" + std::to_string(index);
namespace {
// Waits for a specific command or "QUIT" on the given socket. Returns true if the command was
// received. If "QUIT" was received, replies with "QUITTING" on the socket and closes it, then
// returns false.
[[gnu::always_inline]] inline
bool worker_wait_for(LokiMQ& lmq, zmq::socket_t& sock, std::vector<zmq::message_t>& parts, const std::string_view worker_id, const std::string_view expect) {
while (true) {
lmq.log(LogLevel::debug, __FILE__, __LINE__, "worker ", worker_id, " waiting for ", expect);
parts.clear();
recv_message_parts(sock, parts);
if (parts.size() != 1) {
lmq.log(LogLevel::error, __FILE__, __LINE__, "Internal error: worker ", worker_id, " received invalid ", parts.size(), "-part control msg");
continue;
}
auto command = view(parts[0]);
if (command == expect) {
#ifndef NDEBUG
lmq.log(LogLevel::trace, __FILE__, __LINE__, "Worker ", worker_id, " received waited-for ", expect, " command");
#endif
return true;
} else if (command == "QUIT"sv) {
lmq.log(LogLevel::debug, __FILE__, __LINE__, "Worker ", worker_id, " received QUIT command, shutting down");
detail::send_control(sock, "QUITTING");
sock.setsockopt<int>(ZMQ_LINGER, 1000);
sock.close();
return false;
} else {
lmq.log(LogLevel::error, __FILE__, __LINE__, "Internal error: worker ", worker_id, " received invalid command: `", command, "'");
}
}
}
}
void LokiMQ::worker_thread(unsigned int index, std::optional<std::string> tagged, std::function<void()> start) {
std::string routing_id = (tagged ? "t" : "w") + std::to_string(index); // for routing
std::string_view worker_id{tagged ? *tagged : routing_id}; // for debug
[[maybe_unused]] std::string thread_name = tagged.value_or("lmq-" + routing_id);
#if defined(__linux__) || defined(__sun) || defined(__MINGW32__)
pthread_setname_np(pthread_self(), ("lmq-" + worker_id).c_str());
if (thread_name.size() > 15) thread_name.resize(15);
pthread_setname_np(pthread_self(), thread_name.c_str());
#elif defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__)
pthread_set_name_np(pthread_self(), ("lmq-" + worker_id).c_str());
pthread_set_name_np(pthread_self(), thread_name.c_str());
#elif defined(__MACH__)
pthread_setname_np(("lmq-" + worker_id).c_str());
pthread_setname_np(thread_name.c_str());
#endif
zmq::socket_t sock{context, zmq::socket_type::dealer};
sock.setsockopt(ZMQ_ROUTING_ID, worker_id.data(), worker_id.size());
LMQ_LOG(debug, "New worker thread ", worker_id, " started");
sock.setsockopt(ZMQ_ROUTING_ID, routing_id.data(), routing_id.size());
LMQ_LOG(debug, "New worker thread ", worker_id, " (", routing_id, ") started");
sock.connect(SN_ADDR_WORKERS);
if (tagged)
detail::send_control(sock, "STARTING");
Message message{*this, 0, AuthLevel::none, ""s};
std::vector<zmq::message_t> parts;
run_info& run = workers[index]; // This contains our first job, and will be updated later with subsequent jobs
bool waiting_for_command;
if (tagged) {
// If we're a tagged worker then we got started up before LokiMQ started, so we need to wait
// for an all-clear signal from LokiMQ first, then we fire our `start` callback, then we can
// start waiting for commands in the main loop further down. (We also can't get the
// reference to our `tagged_workers` element until the main proxy threads is running).
waiting_for_command = true;
if (!worker_wait_for(*this, sock, parts, worker_id, "START"sv))
return;
if (start) start();
} else {
// Otherwise for a regular worker we can only be started by an active main proxy thread
// which will have preloaded our first job so we can start off right away.
waiting_for_command = false;
}
// This will always contains the current job, and is guaranteed to never be invalidated.
run_info& run = tagged ? std::get<run_info>(tagged_workers[index - 1]) : workers[index];
while (true) {
if (waiting_for_command) {
if (!worker_wait_for(*this, sock, parts, worker_id, "RUN"sv))
return;
}
try {
if (run.is_batch_job) {
auto* batch = std::get<detail::Batch*>(run.to_run);
if (run.batch_jobno >= 0) {
LMQ_TRACE("worker thread ", worker_id, " running batch ", run.batch, "#", run.batch_jobno);
run.batch->run_job(run.batch_jobno);
LMQ_TRACE("worker thread ", worker_id, " running batch ", batch, "#", run.batch_jobno);
batch->run_job(run.batch_jobno);
} else if (run.batch_jobno == -1) {
LMQ_TRACE("worker thread ", worker_id, " running batch ", run.batch, " completion");
run.batch->job_completion();
LMQ_TRACE("worker thread ", worker_id, " running batch ", batch, " completion");
batch->job_completion();
}
} else if (run.is_injected) {
auto& func = std::get<std::function<void()>>(run.to_run);
LMQ_TRACE("worker thread ", worker_id, " invoking injected command ", run.command);
func();
func = nullptr;
} else {
message.conn = run.conn;
message.access = run.access;
@ -50,7 +120,8 @@ void LokiMQ::worker_thread(unsigned int index) {
LMQ_TRACE("Got incoming command from ", message.remote, "/", message.conn, message.conn.route.empty() ? " (outgoing)" : " (incoming)");
if (run.callback->second /*is_request*/) {
auto& [callback, is_request] = *std::get<const std::pair<CommandCallback, bool>*>(run.to_run);
if (is_request) {
message.reply_tag = {run.data_parts[0].data<char>(), run.data_parts[0].size()};
for (auto it = run.data_parts.begin() + 1; it != run.data_parts.end(); ++it)
message.data.emplace_back(it->data<char>(), it->size());
@ -60,13 +131,13 @@ void LokiMQ::worker_thread(unsigned int index) {
}
LMQ_TRACE("worker thread ", worker_id, " invoking ", run.command, " callback with ", message.data.size(), " message parts");
run.callback->first(message);
callback(message);
}
}
catch (const bt_deserialize_invalid& e) {
LMQ_LOG(warn, worker_id, " deserialization failed: ", e.what(), "; ignoring request");
}
catch (const mapbox::util::bad_variant_access& e) {
catch (const std::bad_variant_access& e) {
LMQ_LOG(warn, worker_id, " deserialization failed: found unexpected serialized type (", e.what(), "); ignoring request");
}
catch (const std::out_of_range& e) {
@ -79,32 +150,9 @@ void LokiMQ::worker_thread(unsigned int index) {
LMQ_LOG(warn, worker_id, " caught non-standard exception when processing command");
}
while (true) {
// Signal that we are ready for another job and wait for it. (We do this down here
// because our first job gets set up when the thread is started).
detail::send_control(sock, "RAN");
LMQ_TRACE("worker ", worker_id, " waiting for requests");
parts.clear();
recv_message_parts(sock, parts);
if (parts.size() != 1) {
LMQ_LOG(error, "Internal error: worker ", worker_id, " received invalid ", parts.size(), "-part worker instruction");
continue;
}
auto command = view(parts[0]);
if (command == "RUN") {
LMQ_LOG(debug, "worker ", worker_id, " running command ", run.command);
break; // proxy has set up a command for us, go back and run it.
} else if (command == "QUIT") {
LMQ_LOG(debug, "worker ", worker_id, " shutting down");
detail::send_control(sock, "QUITTING");
sock.setsockopt<int>(ZMQ_LINGER, 1000);
sock.close();
return;
} else {
LMQ_LOG(error, "Internal error: worker ", worker_id, " received invalid command: `", command, "'");
}
}
// Tell the proxy thread that we are ready for another job
detail::send_control(sock, "RAN");
waiting_for_command = true;
}
}
@ -132,52 +180,73 @@ void LokiMQ::proxy_worker_message(std::vector<zmq::message_t>& parts) {
}
auto route = view(parts[0]), cmd = view(parts[1]);
LMQ_TRACE("worker message from ", route);
assert(route.size() >= 2 && route[0] == 'w' && route[1] >= '0' && route[1] <= '9');
string_view worker_id_str{&route[1], route.size()-1}; // Chop off the leading "w"
assert(route.size() >= 2 && (route[0] == 'w' || route[0] == 't') && route[1] >= '0' && route[1] <= '9');
bool tagged_worker = route[0] == 't';
std::string_view worker_id_str{&route[1], route.size()-1}; // Chop off the leading "w" (or "t")
unsigned int worker_id = detail::extract_unsigned(worker_id_str);
if (!worker_id_str.empty() /* didn't consume everything */ || worker_id >= workers.size()) {
if (!worker_id_str.empty() /* didn't consume everything */ ||
(tagged_worker
? 0 == worker_id || worker_id > tagged_workers.size() // tagged worker ids are indexed from 1 to N (0 means untagged)
: worker_id >= workers.size() // regular worker ids are indexed from 0 to N-1
)
) {
LMQ_LOG(error, "Worker id '", route, "' is invalid, unable to process worker command");
return;
}
auto& run = workers[worker_id];
auto& run = tagged_worker ? std::get<run_info>(tagged_workers[worker_id - 1]) : workers[worker_id];
LMQ_TRACE("received ", cmd, " command from ", route);
if (cmd == "RAN") {
LMQ_LOG(debug, "Worker ", route, " finished ", run.command);
if (cmd == "RAN"sv) {
LMQ_TRACE("Worker ", route, " finished ", run.is_batch_job ? "batch job" : run.command);
if (run.is_batch_job) {
auto& jobs = run.is_reply_job ? reply_jobs : batch_jobs;
auto& active = run.is_reply_job ? reply_jobs_active : batch_jobs_active;
assert(active > 0);
active--;
if (tagged_worker) {
std::get<bool>(tagged_workers[worker_id - 1]) = false;
} else {
auto& active = run.is_reply_job ? reply_jobs_active : batch_jobs_active;
assert(active > 0);
active--;
}
bool clear_job = false;
auto* batch = std::get<detail::Batch*>(run.to_run);
if (run.batch_jobno == -1) {
// Returned from the completion function
clear_job = true;
} else {
auto status = run.batch->job_finished();
if (status == detail::BatchStatus::complete) {
jobs.emplace(run.batch, -1);
} else if (status == detail::BatchStatus::complete_proxy) {
try {
run.batch->job_completion(); // RUN DIRECTLY IN PROXY THREAD
} catch (const std::exception &e) {
// Raise these to error levels: the caller really shouldn't be doing
// anything non-trivial in an in-proxy completion function!
LMQ_LOG(error, "proxy thread caught exception when processing in-proxy completion command: ", e.what());
} catch (...) {
LMQ_LOG(error, "proxy thread caught non-standard exception when processing in-proxy completion command");
auto [state, thread] = batch->job_finished();
if (state == detail::BatchState::complete) {
if (thread == -1) { // run directly in proxy
LMQ_TRACE("Completion job running directly in proxy");
try {
batch->job_completion(); // RUN DIRECTLY IN PROXY THREAD
} catch (const std::exception &e) {
// Raise these to error levels: the caller really shouldn't be doing
// anything non-trivial in an in-proxy completion function!
LMQ_LOG(error, "proxy thread caught exception when processing in-proxy completion command: ", e.what());
} catch (...) {
LMQ_LOG(error, "proxy thread caught non-standard exception when processing in-proxy completion command");
}
clear_job = true;
} else {
auto& jobs =
thread > 0
? std::get<std::queue<batch_job>>(tagged_workers[thread - 1]) // run in tagged thread
: run.is_reply_job
? reply_jobs
: batch_jobs;
jobs.emplace(batch, -1);
}
clear_job = true;
} else if (status == detail::BatchStatus::done) {
} else if (state == detail::BatchState::done) {
// No completion job
clear_job = true;
}
// else the job is still running
}
if (clear_job) {
batches.erase(run.batch);
delete run.batch;
run.batch = nullptr;
batches.erase(batch);
delete batch;
run.to_run = static_cast<detail::Batch*>(nullptr);
}
} else {
assert(run.cat->active_threads > 0);
@ -186,11 +255,11 @@ void LokiMQ::proxy_worker_message(std::vector<zmq::message_t>& parts) {
if (max_workers == 0) { // Shutting down
LMQ_TRACE("Telling worker ", route, " to quit");
route_control(workers_socket, route, "QUIT");
} else {
} else if (!tagged_worker) {
idle_workers.push_back(worker_id);
}
} else if (cmd == "QUITTING") {
workers[worker_id].worker_thread.join();
} else if (cmd == "QUITTING"sv) {
run.worker_thread.join();
LMQ_LOG(debug, "Worker ", route, " exited normally");
} else {
LMQ_LOG(error, "Worker ", route, " sent unknown control message: `", cmd, "'");
@ -199,7 +268,7 @@ void LokiMQ::proxy_worker_message(std::vector<zmq::message_t>& parts) {
void LokiMQ::proxy_run_worker(run_info& run) {
if (!run.worker_thread.joinable())
run.worker_thread = std::thread{&LokiMQ::worker_thread, this, run.worker_id};
run.worker_thread = std::thread{[this, id=run.worker_id] { worker_thread(id); }};
else
send_routed_message(workers_socket, run.worker_routing_id, "RUN");
}
@ -306,5 +375,38 @@ void LokiMQ::proxy_to_worker(size_t conn_index, std::vector<zmq::message_t>& par
category.active_threads++;
}
void LokiMQ::inject_task(const std::string& category, std::string command, std::string remote, std::function<void()> callback) {
if (!callback) return;
auto it = categories.find(category);
if (it == categories.end())
throw std::out_of_range{"Invalid category `" + category + "': category does not exist"};
detail::send_control(get_control_socket(), "INJECT", bt_serialize(detail::serialize_object(
injected_task{it->second, std::move(command), std::move(remote), std::move(callback)})));
}
void LokiMQ::proxy_inject_task(injected_task task) {
auto& category = task.cat;
if (category.active_threads >= category.reserved_threads && active_workers() >= general_workers) {
// No free worker slot, queue for later
if (category.max_queue >= 0 && category.queued >= category.max_queue) {
LMQ_LOG(warn, "No space to queue injected task ", task.command, "; already have ", category.queued,
"commands queued in that category (max ", category.max_queue, "); dropping task");
return;
}
LMQ_LOG(debug, "No available free workers for injected task ", task.command, "; queuing for later");
pending_commands.emplace_back(category, std::move(task.command), std::move(task.callback), std::move(task.remote));
category.queued++;
return;
}
auto& run = get_idle_worker();
LMQ_TRACE("Forwarding incoming injected task ", task.command, " from ", task.remote, " to worker ", run.worker_routing_id);
run.load(&category, std::move(task.command), std::move(task.remote), std::move(task.callback));
proxy_run_worker(run);
category.active_threads++;
}
}

@ -1 +0,0 @@
Subproject commit c94634bbd294204c9ba3f5b267a39582a52e8e5a

View File

@ -3,24 +3,26 @@ add_subdirectory(Catch2)
set(LMQ_TEST_SRC
main.cpp
test_address.cpp
test_batch.cpp
test_bt.cpp
test_connect.cpp
test_commands.cpp
test_encoding.cpp
test_failures.cpp
test_inject.cpp
test_requests.cpp
test_string_view.cpp
test_tagged_threads.cpp
)
add_executable(tests ${LMQ_TEST_SRC})
find_package(Threads)
find_package(PkgConfig REQUIRED)
pkg_check_modules(SODIUM REQUIRED libsodium)
target_link_libraries(tests Catch2::Catch2 lokimq ${SODIUM_LIBRARIES} Threads::Threads)
target_link_libraries(tests Catch2::Catch2 lokimq Threads::Threads)
set_target_properties(tests PROPERTIES
CXX_STANDARD 14
CXX_STANDARD 17
CXX_STANDARD_REQUIRED ON
CXX_EXTENSIONS OFF
)

131
tests/test_address.cpp Normal file
View File

@ -0,0 +1,131 @@
#include "lokimq/address.h"
#include "common.h"
const std::string pk = "\xf1\x6b\xa5\x59\x10\x39\xf0\x89\xb4\x2a\x83\x41\x75\x09\x30\x94\x07\x4d\x0d\x93\x7a\x79\xe5\x3e\x5c\xe7\x30\xf9\x46\xe1\x4b\x88";
const std::string pk_hex = "f16ba5591039f089b42a834175093094074d0d937a79e53e5ce730f946e14b88";
const std::string pk_HEX = "F16BA5591039F089B42A834175093094074D0D937A79E53E5CE730F946E14B88";
const std::string pk_b32z = "6fi4kseo88aeupbkopyzknjo1odw4dcuxjh6kx1hhhax1tzbjqry";
const std::string pk_B32Z = "6FI4KSEO88AEUPBKOPYZKNJO1ODW4DCUXJH6KX1HHHAX1TZBJQRY";
const std::string pk_b64 = "8WulWRA58Im0KoNBdQkwlAdNDZN6eeU+XOcw+UbhS4g"; // NB: padding '=' omitted
TEST_CASE("tcp addresses", "[address][tcp]") {
address a{"tcp://1.2.3.4:5678"};
REQUIRE( a.host == "1.2.3.4" );
REQUIRE( a.port == 5678 );
REQUIRE_FALSE( a.curve() );
REQUIRE( a.tcp() );
REQUIRE( a.zmq_address() == "tcp://1.2.3.4:5678" );
REQUIRE( a.full_address() == "tcp://1.2.3.4:5678" );
REQUIRE( a.qr_address() == "TCP://1.2.3.4:5678" );
REQUIRE_THROWS_AS( address{"tcp://1:1:1"}, std::invalid_argument );
REQUIRE_THROWS_AS( address{"tcpz://localhost:123"}, std::invalid_argument );
REQUIRE_THROWS_AS( address{"tcp://abc"}, std::invalid_argument );
REQUIRE_THROWS_AS( address{"tcpz://localhost:0"}, std::invalid_argument );
REQUIRE_THROWS_AS( address{"tcpz://[::1:1080"}, std::invalid_argument );
address b = address::tcp("example.com", 80);
REQUIRE( b.host == "example.com" );
REQUIRE( b.port == 80 );
REQUIRE_FALSE( b.curve() );
REQUIRE( b.tcp() );
REQUIRE( b.zmq_address() == "tcp://example.com:80" );
REQUIRE( b.full_address() == "tcp://example.com:80" );
REQUIRE( b.qr_address() == "TCP://EXAMPLE.COM:80" );
address c{"tcp://[::1]:1111"};
REQUIRE( c.host == "[::1]" );
REQUIRE( c.port == 1111 );
}
TEST_CASE("unix sockets", "[address][ipc]") {
address a{"ipc:///path/to/foo"};
REQUIRE( a.socket == "/path/to/foo" );
REQUIRE_FALSE( a.curve() );
REQUIRE_FALSE( a.tcp() );
REQUIRE( a.zmq_address() == "ipc:///path/to/foo" );
REQUIRE( a.full_address() == "ipc:///path/to/foo" );
address b = address::ipc("../foo");
REQUIRE( b.socket == "../foo" );
REQUIRE_FALSE( b.curve() );
REQUIRE_FALSE( b.tcp() );
REQUIRE( b.zmq_address() == "ipc://../foo" );
REQUIRE( b.full_address() == "ipc://../foo" );
}
TEST_CASE("pubkey formats", "[address][curve][pubkey]") {
address a{"tcp+curve://a:1/" + pk_hex};
address b{"curve://a:1/" + pk_b32z};
address c{"curve://a:1/" + pk_b64};
address d{"CURVE://A:1/" + pk_B32Z};
REQUIRE( a.curve() );
REQUIRE( a.host == "a" );
REQUIRE( a.port == 1 );
REQUIRE((b.curve() && c.curve() && d.curve()));
REQUIRE( a.pubkey == pk );
REQUIRE( b.pubkey == pk );
REQUIRE( c.pubkey == pk );
REQUIRE( d.pubkey == pk );
address e{"ipc+curve://my.sock/" + pk_hex};
address f{"ipc+curve://../my.sock/" + pk_b32z};
address g{"ipc+curve:///my.sock/" + pk_B32Z};
address h{"ipc+curve://./my.sock/" + pk_b64};
REQUIRE( e.curve() );
REQUIRE( e.ipc() );
REQUIRE_FALSE( e.tcp() );
REQUIRE((f.curve() && g.curve() && h.curve()));
REQUIRE( e.socket == "my.sock" );
REQUIRE( f.socket == "../my.sock" );
REQUIRE( g.socket == "/my.sock" );
REQUIRE( h.socket == "./my.sock" );
REQUIRE( e.pubkey == pk );
REQUIRE( f.pubkey == pk );
REQUIRE( g.pubkey == pk );
REQUIRE( h.pubkey == pk );
REQUIRE( d.full_address(address::encoding::hex) == "curve://a:1/" + pk_hex );
REQUIRE( c.full_address(address::encoding::base32z) == "curve://a:1/" + pk_b32z );
REQUIRE( b.full_address(address::encoding::BASE32Z) == "curve://a:1/" + pk_B32Z );
REQUIRE( a.full_address(address::encoding::base64) == "curve://a:1/" + pk_b64 );
REQUIRE( h.full_address(address::encoding::hex) == "ipc+curve://./my.sock/" + pk_hex );
REQUIRE( g.full_address(address::encoding::base32z) == "ipc+curve:///my.sock/" + pk_b32z );
REQUIRE( f.full_address(address::encoding::BASE32Z) == "ipc+curve://../my.sock/" + pk_B32Z );
REQUIRE( e.full_address(address::encoding::base64) == "ipc+curve://my.sock/" + pk_b64 );
REQUIRE_THROWS_AS(address{"ipc+curve://my.sock/" + pk_hex.substr(0, 63)}, std::invalid_argument);
REQUIRE_THROWS_AS(address{"ipc+curve://my.sock/" + pk_b32z.substr(0, 51)}, std::invalid_argument);
REQUIRE_THROWS_AS(address{"ipc+curve://my.sock/" + pk_B32Z.substr(0, 51)}, std::invalid_argument);
REQUIRE_THROWS_AS(address{"ipc+curve://my.sock/" + pk_b64.substr(0, 42)}, std::invalid_argument);
REQUIRE_THROWS_AS(address{"ipc+curve://my.sock"}, std::invalid_argument);
REQUIRE_THROWS_AS(address{"ipc+curve://my.sock/"}, std::invalid_argument);
}
TEST_CASE("tcp QR-code friendly addresses", "[address][tcp][qr]") {
address a{"tcp://public.loki.foundation:12345"};
address a_qr{"TCP://PUBLIC.LOKI.FOUNDATION:12345"};
address b{"tcp://PUBLIC.LOKI.FOUNDATION:12345"};
REQUIRE( a == a_qr );
REQUIRE( a != b );
REQUIRE( a.host == "public.loki.foundation" );
REQUIRE( a.qr_address() == "TCP://PUBLIC.LOKI.FOUNDATION:12345" );
address c = address::tcp_curve("public.loki.foundation", 12345, pk);
REQUIRE( c.qr_address() == "CURVE://PUBLIC.LOKI.FOUNDATION:12345/" + pk_B32Z );
REQUIRE( address{"CURVE://PUBLIC.LOKI.FOUNDATION:12345/" + pk_B32Z} == c );
// We don't produce with upper-case hex, but we accept it:
REQUIRE( address{"CURVE://PUBLIC.LOKI.FOUNDATION:12345/" + pk_HEX} == c );
// lower case not permitted: ▾
REQUIRE_THROWS_AS(address{"CURVE://PUBLIC.LOKI.FOUNDATiON:12345/" + pk_B32Z}, std::invalid_argument);
// also only accept upper-base base32z and hex:
REQUIRE_THROWS_AS(address{"CURVE://PUBLIC.LOKI.FOUNDATION:12345/" + pk_b32z}, std::invalid_argument);
REQUIRE_THROWS_AS(address{"CURVE://PUBLIC.LOKI.FOUNDATION:12345/" + pk_hex}, std::invalid_argument);
// don't accept base64 even if it's upper-case (because case-converting it changes the value)
REQUIRE_THROWS_AS(address{"CURVE://PUBLIC.LOKI.FOUNDATION:12345/AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"}, std::invalid_argument);
REQUIRE_THROWS_AS(address{"CURVE://PUBLIC.LOKI.FOUNDATION:12345/AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA="}, std::invalid_argument);
}

257
tests/test_bt.cpp Normal file
View File

@ -0,0 +1,257 @@
#include "lokimq/bt_serialize.h"
#include "common.h"
#include <map>
#include <set>
#include <limits>
TEST_CASE("bt basic value serialization", "[bt][serialization]") {
int x = 42;
std::string x_ = bt_serialize(x);
REQUIRE( bt_serialize(x) == "i42e" );
int64_t ibig = -8'000'000'000'000'000'000LL;
uint64_t ubig = 10'000'000'000'000'000'000ULL;
REQUIRE( bt_serialize(ibig) == "i-8000000000000000000e" );
REQUIRE( bt_serialize(std::numeric_limits<int64_t>::min()) == "i-9223372036854775808e" );
REQUIRE( bt_serialize(ubig) == "i10000000000000000000e" );
REQUIRE( bt_serialize(std::numeric_limits<uint64_t>::max()) == "i18446744073709551615e" );
std::unordered_map<std::string, int> m;
m["hi"] = 123;
m["omg"] = -7890;
m["bye"] = 456;
m["zap"] = 0;
// bt values are always sorted:
REQUIRE( bt_serialize(m) == "d3:byei456e2:hii123e3:omgi-7890e3:zapi0ee" );
// Dict-like list serializes as a dict (and get sorted, as above)
std::list<std::pair<std::string, std::string>> d{{
{"c", "x"},
{"a", "z"},
{"b", "y"},
}};
REQUIRE( bt_serialize(d) == "d1:a1:z1:b1:y1:c1:xe" );
std::vector<std::string> v{{"a", "", "\x00"s, "\x00\x00\x00goo"s}};
REQUIRE( bt_serialize(v) == "l1:a0:1:\0006:\x00\x00\x00gooe"sv );
std::array v2 = {"a"sv, ""sv, "\x00"sv, "\x00\x00\x00goo"sv};
REQUIRE( bt_serialize(v2) == "l1:a0:1:\0006:\x00\x00\x00gooe"sv );
}
TEST_CASE("bt nested value serialization", "[bt][serialization]") {
std::unordered_map<std::string, std::list<std::map<std::string, std::set<int>>>> x{{
{"foo", {{{"a", {1,2,3}}, {"b", {}}}, {{"c", {4,-5}}}}},
{"bar", {}}
}};
REQUIRE( bt_serialize(x) == "d3:barle3:foold1:ali1ei2ei3ee1:bleed1:cli-5ei4eeeee" );
}
TEST_CASE("bt basic value deserialization", "[bt][deserialization]") {
REQUIRE( bt_deserialize<int>("i42e") == 42 );
int64_t ibig = -8'000'000'000'000'000'000LL;
uint64_t ubig = 10'000'000'000'000'000'000ULL;
REQUIRE( bt_deserialize<int64_t>("i-8000000000000000000e") == ibig );
REQUIRE( bt_deserialize<uint64_t>("i10000000000000000000e") == ubig );
REQUIRE( bt_deserialize<int64_t>("i-9223372036854775808e") == std::numeric_limits<int64_t>::min() );
REQUIRE( bt_deserialize<uint64_t>("i18446744073709551615e") == std::numeric_limits<uint64_t>::max() );
REQUIRE( bt_deserialize<uint32_t>("i4294967295e") == std::numeric_limits<uint32_t>::max() );
REQUIRE_THROWS( bt_deserialize<int64_t>("i-9223372036854775809e") );
REQUIRE_THROWS( bt_deserialize<uint64_t>("i-1e") );
REQUIRE_THROWS( bt_deserialize<uint32_t>("i4294967296e") );
std::unordered_map<std::string, int> m;
m["hi"] = 123;
m["omg"] = -7890;
m["bye"] = 456;
m["zap"] = 0;
// bt values are always sorted:
REQUIRE( bt_deserialize<std::unordered_map<std::string, int>>("d3:byei456e2:hii123e3:omgi-7890e3:zapi0ee") == m );
// Dict-like list can be used for deserialization
std::list<std::pair<std::string, std::string>> d{{
{"a", "z"},
{"b", "y"},
{"c", "x"},
}};
REQUIRE( bt_deserialize<std::list<std::pair<std::string, std::string>>>("d1:a1:z1:b1:y1:c1:xe") == d );
std::vector<std::string> v{{"a", "", "\x00"s, "\x00\x00\x00goo"s}};
REQUIRE( bt_deserialize<std::vector<std::string>>("l1:a0:1:\0006:\x00\x00\x00gooe"sv) == v );
std::vector v2 = {"a"sv, ""sv, "\x00"sv, "\x00\x00\x00goo"sv};
REQUIRE( bt_deserialize<decltype(v2)>("l1:a0:1:\0006:\x00\x00\x00gooe"sv) == v2 );
}
TEST_CASE("bt_value serialization", "[bt][serialization][bt_value]") {
bt_value dna{42};
std::string x_ = bt_serialize(dna);
REQUIRE( bt_serialize(dna) == "i42e" );
bt_value foo{"foo"};
REQUIRE( bt_serialize(foo) == "3:foo" );
bt_value ibig{-8'000'000'000'000'000'000LL};
bt_value ubig{10'000'000'000'000'000'000ULL};
int16_t ismall = -123;
uint16_t usmall = 123;
bt_dict nums{
{"a", 0},
{"b", -8'000'000'000'000'000'000LL},
{"c", 10'000'000'000'000'000'000ULL},
{"d", ismall},
{"e", usmall},
};
REQUIRE( bt_serialize(ibig) == "i-8000000000000000000e" );
REQUIRE( bt_serialize(ubig) == "i10000000000000000000e" );
REQUIRE( bt_serialize(nums) == "d1:ai0e1:bi-8000000000000000000e1:ci10000000000000000000e1:di-123e1:ei123ee" );
// Same as nested test, above, but with bt_* types
bt_dict x{{
{"foo", bt_list{{bt_dict{{ {"a", bt_list{{1,2,3}}}, {"b", bt_list{}}}}, bt_dict{{{"c", bt_list{{-5, 4}}}}}}}},
{"bar", bt_list{}}
}};
REQUIRE( bt_serialize(x) == "d3:barle3:foold1:ali1ei2ei3ee1:bleed1:cli-5ei4eeeee" );
std::vector<std::string> v{{"a", "", "\x00"s, "\x00\x00\x00goo"s}};
REQUIRE( bt_serialize(v) == "l1:a0:1:\0006:\x00\x00\x00gooe"sv );
std::array v2 = {"a"sv, ""sv, "\x00"sv, "\x00\x00\x00goo"sv};
REQUIRE( bt_serialize(v2) == "l1:a0:1:\0006:\x00\x00\x00gooe"sv );
}
TEST_CASE("bt_value deserialization", "[bt][deserialization][bt_value]") {
auto dna1 = bt_deserialize<bt_value>("i42e");
auto dna2 = bt_deserialize<bt_value>("i-42e");
REQUIRE( std::get<uint64_t>(dna1) == 42 );
REQUIRE( std::get<int64_t>(dna2) == -42 );
REQUIRE_THROWS( std::get<int64_t>(dna1) );
REQUIRE_THROWS( std::get<uint64_t>(dna2) );
REQUIRE( lokimq::get_int<int>(dna1) == 42 );
REQUIRE( lokimq::get_int<int>(dna2) == -42 );
REQUIRE( lokimq::get_int<unsigned>(dna1) == 42 );
REQUIRE_THROWS( lokimq::get_int<unsigned>(dna2) );
bt_value x = bt_deserialize<bt_value>("d3:barle3:foold1:ali1ei2ei3ee1:bleed1:cli-5ei4eeeee");
REQUIRE( std::holds_alternative<bt_dict>(x) );
bt_dict& a = std::get<bt_dict>(x);
REQUIRE( a.count("bar") );
REQUIRE( a.count("foo") );
REQUIRE( a.size() == 2 );
bt_list& foo = std::get<bt_list>(a["foo"]);
REQUIRE( foo.size() == 2 );
bt_dict& foo1 = std::get<bt_dict>(foo.front());
bt_dict& foo2 = std::get<bt_dict>(foo.back());
REQUIRE( foo1.size() == 2 );
REQUIRE( foo2.size() == 1 );
bt_list& foo1a = std::get<bt_list>(foo1.at("a"));
bt_list& foo1b = std::get<bt_list>(foo1.at("b"));
bt_list& foo2c = std::get<bt_list>(foo2.at("c"));
std::list<int> foo1a_vals, foo1b_vals, foo2c_vals;
for (auto& v : foo1a) foo1a_vals.push_back(lokimq::get_int<int>(v));
for (auto& v : foo1b) foo1b_vals.push_back(lokimq::get_int<int>(v));
for (auto& v : foo2c) foo2c_vals.push_back(lokimq::get_int<int>(v));
REQUIRE( foo1a_vals == std::list{{1,2,3}} );
REQUIRE( foo1b_vals == std::list<int>{} );
REQUIRE( foo2c_vals == std::list{{-5, 4}} );
REQUIRE( std::get<bt_list>(a.at("bar")).empty() );
}
TEST_CASE("bt tuple serialization", "[bt][tuple][serialization]") {
// Deserializing directly into a tuple:
std::tuple<int, std::string, std::vector<int>> x{42, "hi", {{1,2,3,4,5}}};
REQUIRE( bt_serialize(x) == "li42e2:hili1ei2ei3ei4ei5eee" );
using Y = std::tuple<std::string, std::string, std::unordered_map<std::string, int>>;
REQUIRE( bt_deserialize<Y>("l5:hello3:omgd1:ai1e1:bi2eee")
== Y{"hello", "omg", {{"a",1}, {"b",2}}} );
using Z = std::tuple<std::tuple<int, std::string, std::string>, std::pair<int, int>>;
Z z{{3, "abc", "def"}, {4, 5}};
REQUIRE( bt_serialize(z) == "lli3e3:abc3:defeli4ei5eee" );
REQUIRE( bt_deserialize<Z>("lli6e3:ghi3:jkleli7ei8eee") == Z{{6, "ghi", "jkl"}, {7, 8}} );
using W = std::pair<std::string, std::pair<int, unsigned>>;
REQUIRE( bt_serialize(W{"zzzzzzzzzz", {42, 42}}) == "l10:zzzzzzzzzzli42ei42eee" );
REQUIRE_THROWS( bt_deserialize<std::tuple<int>>("li1e") ); // missing closing e
REQUIRE_THROWS( bt_deserialize<std::pair<int, int>>("li1ei-4e") ); // missing closing e
REQUIRE_THROWS( bt_deserialize<std::tuple<int>>("li1ei2ee") ); // too many elements
REQUIRE_THROWS( bt_deserialize<std::pair<int, int>>("li1ei-2e0:e") ); // too many elements
REQUIRE_THROWS( bt_deserialize<std::tuple<int, int>>("li1ee") ); // too few elements
REQUIRE_THROWS( bt_deserialize<std::pair<int, int>>("li1ee") ); // too few elements
REQUIRE_THROWS( bt_deserialize<std::tuple<std::string>>("li1ee") ); // wrong element type
REQUIRE_THROWS( bt_deserialize<std::pair<int, std::string>>("li1ei8ee") ); // wrong element type
REQUIRE_THROWS( bt_deserialize<std::pair<int, std::string>>("l1:x1:xe") ); // wrong element type
// Converting from a generic bt_value/bt_list:
bt_value a = bt_get("l5:hello3:omgi12345ee");
using V1 = std::tuple<std::string, std::string_view, uint16_t>;
REQUIRE( get_tuple<V1>(a) == V1{"hello", "omg"sv, 12345} );
bt_value b = bt_get("l5:hellod1:ai1e1:bi2eee");
using V2 = std::pair<std::string_view, bt_dict>;
REQUIRE( get_tuple<V2>(b) == V2{"hello", {{"a",1U}, {"b",2U}}} );
bt_value c = bt_get("l5:helloi-4ed1:ai-1e1:bi-2eee");
using V3 = std::tuple<std::string, int, bt_dict>;
REQUIRE( get_tuple<V3>(c) == V3{"hello", -4, {{"a",-1}, {"b",-2}}} );
REQUIRE_THROWS( get_tuple<V1>(bt_get("l5:hello3:omge")) ); // too few
REQUIRE_THROWS( get_tuple<V1>(bt_get("l5:hello3:omgi1ei1ee")) ); // too many
REQUIRE_THROWS( get_tuple<V1>(bt_get("l5:helloi1ei1ee")) ); // wrong type
// Construct a bt_value from tuples:
bt_value l{std::make_tuple(3, 4, "hi"sv)};
REQUIRE( bt_serialize(l) == "li3ei4e2:hie" );
bt_list m{{1, 2, std::make_tuple(3, 4, "hi"sv), std::make_pair("foo"s, "bar"sv), -4}};
REQUIRE( bt_serialize(m) == "li1ei2eli3ei4e2:hiel3:foo3:barei-4ee" );
// Consumer deserialization:
bt_list_consumer lc{"li1ei2eli3ei4e2:hiel3:foo3:barei-4ee"};
REQUIRE( lc.consume_integer<int>() == 1 );
REQUIRE( lc.consume_integer<int>() == 2 );
REQUIRE( lc.consume_list<std::tuple<int, int, std::string>>() == std::make_tuple(3, 4, "hi"s) );
REQUIRE( lc.consume_list<std::pair<std::string_view, std::string_view>>() == std::make_pair("foo"sv, "bar"sv) );
REQUIRE( lc.consume_integer<int>() == -4 );
bt_dict_consumer dc{"d1:Ai0e1:ali1e3:omge1:bli1ei2ei3eee"};
REQUIRE( dc.key() == "A" );
REQUIRE( dc.skip_until("a") );
REQUIRE( dc.next_list<std::pair<int8_t, std::string_view>>() ==
std::make_pair("a"sv, std::make_pair(int8_t{1}, "omg"sv)) );
REQUIRE( dc.next_list<std::tuple<int, int, int>>() ==
std::make_pair("b"sv, std::make_tuple(1, 2, 3)) );
}
#if 0
{
std::cout << "zomg consumption\n";
bt_dict_consumer dc{zomg_};
for (int i = 0; i < 5; i++)
if (!dc.skip_until("b"))
throw std::runtime_error("Couldn't find b, but I know it's there!");
auto dc1 = dc;
if (dc.skip_until("z")) {
auto v = dc.consume_integer<int>();
std::cout << " - " << v.first << ": " << v.second << "\n";
} else {
std::cout << " - no z (bad!)\n";
}
std::cout << "zomg (second pass)\n";
for (auto &p : dc1.consume_dict().second) {
std::cout << " - " << p.first << " = (whatever)\n";
}
while (dc1) {
auto v = dc1.consume_integer<int>();
std::cout << " - " << v.first << ": " << v.second << "\n";
}
}
#endif

View File

@ -41,10 +41,9 @@ TEST_CASE("basic commands", "[commands]") {
bool success = false, failed = false;
std::string pubkey;
auto c = client.connect_remote(listen,
auto c = client.connect_remote(address{listen, server.get_pubkey()},
[&](auto conn) { pubkey = conn.pubkey(); success = true; got = true; },
[&](auto conn, string_view) { failed = true; got = true; },
server.get_pubkey());
[&](auto conn, std::string_view) { failed = true; got = true; });
wait_for_conn(got);
{
@ -107,9 +106,10 @@ TEST_CASE("outgoing auth level", "[commands][auth]") {
client.PUBKEY_BASED_ROUTING_ID = false; // establishing multiple connections below, so we need unique routing ids
auto public_c = client.connect_remote(listen, [](...) {}, [](...) {}, server.get_pubkey());
auto basic_c = client.connect_remote(listen, [](...) {}, [](...) {}, server.get_pubkey(), AuthLevel::basic);
auto admin_c = client.connect_remote(listen, [](...) {}, [](...) {}, server.get_pubkey(), AuthLevel::admin);
address server_addr{listen, server.get_pubkey()};
auto public_c = client.connect_remote(server_addr, [](auto&&...) {}, [](auto&&...) {});
auto basic_c = client.connect_remote(server_addr, [](auto&&...) {}, [](auto&&...) {}, AuthLevel::basic);
auto admin_c = client.connect_remote(server_addr, [](auto&&...) {}, [](auto&&...) {}, AuthLevel::admin);
client.send(public_c, "public.reflect", "public.hi");
wait_for([&] { return public_hi == 1; });
@ -193,8 +193,8 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]")
server.start();
auto connect_success = [&](...) { auto l = catch_lock(); REQUIRE(true); };
auto connect_failure = [&](...) { auto l = catch_lock(); REQUIRE(false); };
auto connect_success = [&](auto&&...) { auto l = catch_lock(); REQUIRE(true); };
auto connect_failure = [&](auto&&...) { auto l = catch_lock(); REQUIRE(false); };
std::set<std::string> backdoor_details;
@ -202,10 +202,11 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]")
LokiMQ nsa{get_logger("NSA» ")};
nsa.add_category("backdoor", Access{AuthLevel::admin});
nsa.add_command("backdoor", "data", [&](Message& m) {
backdoor_details.emplace(m.data[0]);
auto l = catch_lock();
backdoor_details.emplace(m.data[0]);
});
nsa.start();
auto nsa_c = nsa.connect_remote(listen, connect_success, connect_failure, server.get_pubkey(), AuthLevel::admin);
auto nsa_c = nsa.connect_remote(address{listen, server.get_pubkey()}, connect_success, connect_failure, AuthLevel::admin);
nsa.send(nsa_c, "hey google.install backdoor");
wait_for([&] { auto lock = catch_lock(); return (bool) backdoor; });
@ -226,6 +227,7 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]")
std::set<std::string> all_the_things;
for (auto& pd : personal_details) all_the_things.insert(pd.second.begin(), pd.second.end());
address server_addr{listen, server.get_pubkey()};
std::map<int, std::set<std::string>> google_knows;
int things_remembered{0};
for (int i = 0; i < 5; i++) {
@ -240,7 +242,7 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]")
});
c->start();
conns.push_back(
c->connect_remote(listen, connect_success, connect_failure, server.get_pubkey(), AuthLevel::basic));
c->connect_remote(server_addr, connect_success, connect_failure, AuthLevel::basic));
for (auto& personal_detail : personal_details[i])
c->request(conns.back(), "hey google.remember",
[&](bool success, std::vector<std::string> data) {
@ -252,7 +254,7 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]")
},
personal_detail);
}
wait_for([&] { auto lock = catch_lock(); return things_remembered == all_the_things.size(); });
wait_for([&] { auto lock = catch_lock(); return things_remembered == all_the_things.size() && things_remembered == backdoor_details.size(); });
{
auto l = catch_lock();
REQUIRE( things_remembered == all_the_things.size() );
@ -303,10 +305,11 @@ TEST_CASE("send failure callbacks", "[commands][queue_full]") {
// Handshake: we send HI, they reply HELLO.
client.send(zmq::message_t{"HI", 2}, zmq::send_flags::none);
zmq::message_t hello;
client.recv(hello);
string_view hello_sv{hello.data<char>(), hello.size()};
auto recvd = client.recv(hello);
std::string_view hello_sv{hello.data<char>(), hello.size()};
{
auto lock = catch_lock();
REQUIRE( recvd );
REQUIRE( hello_sv == "HELLO" );
REQUIRE_FALSE( hello.more() );
}
@ -359,3 +362,78 @@ TEST_CASE("send failure callbacks", "[commands][queue_full]") {
REQUIRE( send_failures.load() > 0 );
}
}
TEST_CASE("data parts", "[send][data_parts]") {
std::string listen = "tcp://127.0.0.1:4567";
LokiMQ server{
"", "", // generate ephemeral keys
false, // not a service node
[](auto) { return ""; },
get_logger(""),
LogLevel::trace
};
server.listen_curve(listen);
std::mutex mut;
std::vector<std::string> r;
server.add_category("public", Access{AuthLevel::none});
server.add_command("public", "hello", [&](Message& m) {
std::lock_guard l{mut};
for (const auto& s : m.data)
r.emplace_back(s);
});
server.start();
LokiMQ client{get_logger(""), LogLevel::trace};
client.start();
std::atomic<bool> got{false};
bool success = false, failed = false;
std::string pubkey;
auto c = client.connect_remote(address{listen, server.get_pubkey()},
[&](auto conn) { pubkey = conn.pubkey(); success = true; got = true; },
[&](auto conn, std::string_view) { failed = true; got = true; });
wait_for_conn(got);
{
auto lock = catch_lock();
REQUIRE( got );
REQUIRE( success );
REQUIRE_FALSE( failed );
REQUIRE( to_hex(pubkey) == to_hex(server.get_pubkey()) );
}
std::vector some_data{{"abc"s, "def"s, "omg123\0zzz"s}};
client.send(c, "public.hello", lokimq::send_option::data_parts(some_data.begin(), some_data.end()));
reply_sleep();
{
auto lock = catch_lock();
std::lock_guard l{mut};
REQUIRE( r == some_data );
r.clear();
}
std::vector some_data2{{"a"sv, "b"sv, "\0"sv}};
client.send(c, "public.hello",
"hi",
lokimq::send_option::data_parts(some_data2.begin(), some_data2.end()),
"another",
"string"sv,
lokimq::send_option::data_parts(some_data.begin(), some_data.end()));
std::vector<std::string> expected;
expected.push_back("hi");
expected.insert(expected.end(), some_data2.begin(), some_data2.end());
expected.push_back("another");
expected.push_back("string");
expected.insert(expected.end(), some_data.begin(), some_data.end());
reply_sleep();
{
auto lock = catch_lock();
std::lock_guard l{mut};
REQUIRE( r == expected );
}
}

View File

@ -27,10 +27,9 @@ TEST_CASE("connections with curve authentication", "[curve][connect]") {
auto pubkey = server.get_pubkey();
std::atomic<bool> got{false};
bool success = false;
auto server_conn = client.connect_remote(listen,
auto server_conn = client.connect_remote(address{listen, pubkey},
[&](auto conn) { success = true; got = true; },
[&](auto conn, string_view reason) { auto lock = catch_lock(); INFO("connection failed: " << reason); got = true; },
pubkey);
[&](auto conn, std::string_view reason) { auto lock = catch_lock(); INFO("connection failed: " << reason); got = true; });
wait_for_conn(got);
{
@ -53,6 +52,7 @@ TEST_CASE("self-connection SN optimization", "[connect][self]") {
std::string pubkey, privkey;
pubkey.resize(crypto_box_PUBLICKEYBYTES);
privkey.resize(crypto_box_SECRETKEYBYTES);
REQUIRE(sodium_init() != -1);
crypto_box_keypair(reinterpret_cast<unsigned char*>(&pubkey[0]), reinterpret_cast<unsigned char*>(&privkey[0]));
LokiMQ sn{
pubkey, privkey,
@ -108,7 +108,7 @@ TEST_CASE("plain-text connections", "[plaintext][connect]") {
bool success = false;
auto c = client.connect_remote(listen,
[&](auto conn) { success = true; got = true; },
[&](auto conn, string_view reason) { auto lock = catch_lock(); INFO("connection failed: " << reason); got = true; }
[&](auto conn, std::string_view reason) { auto lock = catch_lock(); INFO("connection failed: " << reason); got = true; }
);
wait_for_conn(got);
@ -150,11 +150,11 @@ TEST_CASE("unique connection IDs", "[connect][id]") {
std::atomic<bool> good1{false}, good2{false};
auto r1 = client1.connect_remote(listen,
[&](auto conn) { good1 = true; },
[&](auto conn, string_view reason) { auto lock = catch_lock(); INFO("connection failed: " << reason); }
[&](auto conn, std::string_view reason) { auto lock = catch_lock(); INFO("connection failed: " << reason); }
);
auto r2 = client2.connect_remote(listen,
[&](auto conn) { good2 = true; },
[&](auto conn, string_view reason) { auto lock = catch_lock(); INFO("connection failed: " << reason); }
[&](auto conn, std::string_view reason) { auto lock = catch_lock(); INFO("connection failed: " << reason); }
);
wait_for_conn(good1);
@ -188,6 +188,7 @@ TEST_CASE("SN disconnections", "[connect][disconnect]") {
std::vector<std::unique_ptr<LokiMQ>> lmq;
std::vector<std::string> pubkey, privkey;
std::unordered_map<std::string, std::string> conn;
REQUIRE(sodium_init() != -1);
for (int i = 0; i < 3; i++) {
pubkey.emplace_back();
privkey.emplace_back();
@ -234,6 +235,7 @@ TEST_CASE("SN auth checks", "[sandwich][auth]") {
std::string pubkey, privkey;
pubkey.resize(crypto_box_PUBLICKEYBYTES);
privkey.resize(crypto_box_SECRETKEYBYTES);
REQUIRE(sodium_init() != -1);
crypto_box_keypair(reinterpret_cast<unsigned char*>(&pubkey[0]), reinterpret_cast<unsigned char*>(&privkey[0]));
LokiMQ server{
pubkey, privkey,

193
tests/test_encoding.cpp Normal file
View File

@ -0,0 +1,193 @@
#include "lokimq/hex.h"
#include "lokimq/base32z.h"
#include "lokimq/base64.h"
#include "common.h"
using namespace std::literals;
const std::string pk = "\xf1\x6b\xa5\x59\x10\x39\xf0\x89\xb4\x2a\x83\x41\x75\x09\x30\x94\x07\x4d\x0d\x93\x7a\x79\xe5\x3e\x5c\xe7\x30\xf9\x46\xe1\x4b\x88";
const std::string pk_hex = "f16ba5591039f089b42a834175093094074d0d937a79e53e5ce730f946e14b88";
const std::string pk_b32z = "6fi4kseo88aeupbkopyzknjo1odw4dcuxjh6kx1hhhax1tzbjqry";
const std::string pk_b64 = "8WulWRA58Im0KoNBdQkwlAdNDZN6eeU+XOcw+UbhS4g=";
TEST_CASE("hex encoding/decoding", "[encoding][decoding][hex]") {
REQUIRE( lokimq::to_hex("\xff\x42\x12\x34") == "ff421234"s );
std::vector<uint8_t> chars{{1, 10, 100, 254}};
std::array<uint8_t, 8> out;
std::array<uint8_t, 8> expected{{'0', '1', '0', 'a', '6', '4', 'f', 'e'}};
lokimq::to_hex(chars.begin(), chars.end(), out.begin());
REQUIRE( out == expected );
REQUIRE( lokimq::to_hex(chars.begin(), chars.end()) == "010a64fe" );
REQUIRE( lokimq::from_hex("12345678ffEDbca9") == "\x12\x34\x56\x78\xff\xed\xbc\xa9"s );
REQUIRE( lokimq::is_hex("1234567890abcdefABCDEF1234567890abcdefABCDEF") );
REQUIRE_FALSE( lokimq::is_hex("1234567890abcdefABCDEF1234567890aGcdefABCDEF") );
REQUIRE_FALSE( lokimq::is_hex("1234567890abcdefABCDEF1234567890agcdefABCDEF") );
REQUIRE_FALSE( lokimq::is_hex("\x11\xff") );
REQUIRE( lokimq::from_hex(pk_hex) == pk );
REQUIRE( lokimq::to_hex(pk) == pk_hex );
REQUIRE( lokimq::from_hex(pk_hex.begin(), pk_hex.end()) == pk );
std::vector<std::byte> bytes{{std::byte{0xff}, std::byte{0x42}, std::byte{0x12}, std::byte{0x34}}};
std::basic_string_view<std::byte> b{bytes.data(), bytes.size()};
REQUIRE( lokimq::to_hex(b) == "ff421234"s );
bytes.resize(8);
bytes[0] = std::byte{'f'}; bytes[1] = std::byte{'f'}; bytes[2] = std::byte{'4'}; bytes[3] = std::byte{'2'};
bytes[4] = std::byte{'1'}; bytes[5] = std::byte{'2'}; bytes[6] = std::byte{'3'}; bytes[7] = std::byte{'4'};
std::basic_string_view<std::byte> hex_bytes{bytes.data(), bytes.size()};
REQUIRE( lokimq::is_hex(hex_bytes) );
REQUIRE( lokimq::from_hex(hex_bytes) == "\xff\x42\x12\x34" );
}
TEST_CASE("base32z encoding/decoding", "[encoding][decoding][base32z]") {
REQUIRE( lokimq::to_base32z("\0\0\0\0\0"s) == "yyyyyyyy" );
REQUIRE( lokimq::to_base32z("\x01\x23\x45\x67\x89\xab\xcd\xef\x01\x23\x45\x67\x89\xab\xcd\xef\x01\x23\x45\x67\x89\xab\xcd\xef\x01\x23\x45\x67\x89\xab\xcd\xef"sv)
== "yrtwk3hjixg66yjdeiuauk6p7hy1gtm8tgih55abrpnsxnpm3zzo");
REQUIRE( lokimq::from_base32z("yrtwk3hjixg66yjdeiuauk6p7hy1gtm8tgih55abrpnsxnpm3zzo")
== "\x01\x23\x45\x67\x89\xab\xcd\xef\x01\x23\x45\x67\x89\xab\xcd\xef\x01\x23\x45\x67\x89\xab\xcd\xef\x01\x23\x45\x67\x89\xab\xcd\xef"sv);
REQUIRE( lokimq::from_base32z("YRTWK3HJIXG66YJDEIUAUK6P7HY1GTM8TGIH55ABRPNSXNPM3ZZO")
== "\x01\x23\x45\x67\x89\xab\xcd\xef\x01\x23\x45\x67\x89\xab\xcd\xef\x01\x23\x45\x67\x89\xab\xcd\xef\x01\x23\x45\x67\x89\xab\xcd\xef"sv);
auto five_nulls = lokimq::from_base32z("yyyyyyyy");
REQUIRE( five_nulls.size() == 5 );
REQUIRE( five_nulls == "\0\0\0\0\0"s );
// 00000 00001 00010 00011 00100 00101 00110 00111
// ==
// 00000000 01000100 00110010 00010100 11000111
REQUIRE( lokimq::from_base32z("ybndrfg8") == "\x00\x44\x32\x14\xc7"s );
// Special case 1: 7 base32z digits with 3 trailing 0 bits -> 4 bytes (the trailing 0s are dropped)
// 00000 00001 00010 00011 00100 00101 11000
// ==
// 00000000 01000100 00110010 00010111
REQUIRE( lokimq::from_base32z("ybndrfa") == "\x00\x44\x32\x17"s );
// Round-trip it:
REQUIRE( lokimq::from_base32z(lokimq::to_base32z("\x00\x44\x32\x17"sv)) == "\x00\x44\x32\x17"sv );
REQUIRE( lokimq::to_base32z(lokimq::from_base32z("ybndrfa")) == "ybndrfa" );
// Special case 2: 7 base32z digits with 3 trailing bits 010; we just ignore the trailing stuff,
// as if it was specified as 0. (The last digit here is 11010 instead of 11000).
REQUIRE( lokimq::from_base32z("ybndrf4") == "\x00\x44\x32\x17"s );
// This one won't round-trip to the same value since it has ignored garbage bytes at the end
REQUIRE( lokimq::to_base32z(lokimq::from_base32z("ybndrf4"s)) == "ybndrfa" );
REQUIRE( lokimq::to_base32z(pk) == pk_b32z );
REQUIRE( lokimq::to_base32z(pk.begin(), pk.end()) == pk_b32z );
REQUIRE( lokimq::from_base32z(pk_b32z) == pk );
REQUIRE( lokimq::from_base32z(pk_b32z.begin(), pk_b32z.end()) == pk );
std::string pk_b32z_again, pk_again;
lokimq::to_base32z(pk.begin(), pk.end(), std::back_inserter(pk_b32z_again));
lokimq::from_base32z(pk_b32z.begin(), pk_b32z.end(), std::back_inserter(pk_again));
REQUIRE( pk_b32z_again == pk_b32z );
REQUIRE( pk_again == pk );
std::vector<std::byte> bytes{{std::byte{0}, std::byte{255}}};
std::basic_string_view<std::byte> b{bytes.data(), bytes.size()};
REQUIRE( lokimq::to_base32z(b) == "yd9o" );
bytes.resize(4);
bytes[0] = std::byte{'y'}; bytes[1] = std::byte{'d'}; bytes[2] = std::byte{'9'}; bytes[3] = std::byte{'o'};
std::basic_string_view<std::byte> b32_bytes{bytes.data(), bytes.size()};
REQUIRE( lokimq::is_base32z(b32_bytes) );
REQUIRE( lokimq::from_base32z(b32_bytes) == "\x00\xff"sv );
}
TEST_CASE("base64 encoding/decoding", "[encoding][decoding][base64]") {
// 00000000 00000000 00000000 -> 000000 000000 000000 000000
REQUIRE( lokimq::to_base64("\0\0\0"s) == "AAAA" );
// 00000001 00000002 00000003 -> 000000 010000 000200 000003
REQUIRE( lokimq::to_base64("\x01\x02\x03"s) == "AQID" );
REQUIRE( lokimq::to_base64("\0\0\0\0"s) == "AAAAAA==" );
// 00000000 00000000 00000000 11111111 ->
// 000000 000000 000000 000000 111111 110000 (pad) (pad)
REQUIRE( lokimq::to_base64("a") == "YQ==" );
REQUIRE( lokimq::to_base64("ab") == "YWI=" );
REQUIRE( lokimq::to_base64("abc") == "YWJj" );
REQUIRE( lokimq::to_base64("abcd") == "YWJjZA==" );
REQUIRE( lokimq::to_base64("abcde") == "YWJjZGU=" );
REQUIRE( lokimq::to_base64("abcdef") == "YWJjZGVm" );
REQUIRE( lokimq::to_base64("\0\0\0\xff"s) == "AAAA/w==" );
REQUIRE( lokimq::to_base64("\0\0\0\xff\xff"s) == "AAAA//8=" );
REQUIRE( lokimq::to_base64("\0\0\0\xff\xff\xff"s) == "AAAA////" );
REQUIRE( lokimq::to_base64(
"Man is distinguished, not only by his reason, but by this singular passion from other "
"animals, which is a lust of the mind, that by a perseverance of delight in the "
"continued and indefatigable generation of knowledge, exceeds the short vehemence of "
"any carnal pleasure.")
==
"TWFuIGlzIGRpc3Rpbmd1aXNoZWQsIG5vdCBvbmx5IGJ5IGhpcyByZWFzb24sIGJ1dCBieSB0aGlz"
"IHNpbmd1bGFyIHBhc3Npb24gZnJvbSBvdGhlciBhbmltYWxzLCB3aGljaCBpcyBhIGx1c3Qgb2Yg"
"dGhlIG1pbmQsIHRoYXQgYnkgYSBwZXJzZXZlcmFuY2Ugb2YgZGVsaWdodCBpbiB0aGUgY29udGlu"
"dWVkIGFuZCBpbmRlZmF0aWdhYmxlIGdlbmVyYXRpb24gb2Yga25vd2xlZGdlLCBleGNlZWRzIHRo"
"ZSBzaG9ydCB2ZWhlbWVuY2Ugb2YgYW55IGNhcm5hbCBwbGVhc3VyZS4=" );
REQUIRE( lokimq::from_base64("A+/A") == "\x03\xef\xc0" );
REQUIRE( lokimq::from_base64("YWJj") == "abc" );
REQUIRE( lokimq::from_base64("YWJjZA==") == "abcd" );
REQUIRE( lokimq::from_base64("YWJjZA") == "abcd" );
REQUIRE( lokimq::from_base64("YWJjZB") == "abcd" ); // ignore superfluous bits
REQUIRE( lokimq::from_base64("YWJjZB") == "abcd" ); // ignore superfluous bits
REQUIRE( lokimq::from_base64("YWJj+") == "abc" ); // ignore superfluous bits
REQUIRE( lokimq::from_base64("YWJjZGU=") == "abcde" );
REQUIRE( lokimq::from_base64("YWJjZGU") == "abcde" );
REQUIRE( lokimq::from_base64("YWJjZGVm") == "abcdef" );
REQUIRE( lokimq::is_base64("YWJjZGVm") );
REQUIRE( lokimq::is_base64("YWJjZGU") );
REQUIRE( lokimq::is_base64("YWJjZGU=") );
REQUIRE( lokimq::is_base64("YWJjZA==") );
REQUIRE( lokimq::is_base64("YWJjZA") );
REQUIRE( lokimq::is_base64("YWJjZB") ); // not really valid, but we explicitly accept it
REQUIRE_FALSE( lokimq::is_base64("YWJjZ=") ); // invalid padding (padding can only be 4th or 3rd+4th of a 4-char block)
REQUIRE_FALSE( lokimq::is_base64("YWJj=") );
REQUIRE_FALSE( lokimq::is_base64("YWJj=A") );
REQUIRE_FALSE( lokimq::is_base64("YWJjA===") );
REQUIRE_FALSE( lokimq::is_base64("YWJ[") );
REQUIRE_FALSE( lokimq::is_base64("YWJ.") );
REQUIRE_FALSE( lokimq::is_base64("_YWJ") );
REQUIRE( lokimq::from_base64(
"TWFuIGlzIGRpc3Rpbmd1aXNoZWQsIG5vdCBvbmx5IGJ5IGhpcyByZWFzb24sIGJ1dCBieSB0aGlz"
"IHNpbmd1bGFyIHBhc3Npb24gZnJvbSBvdGhlciBhbmltYWxzLCB3aGljaCBpcyBhIGx1c3Qgb2Yg"
"dGhlIG1pbmQsIHRoYXQgYnkgYSBwZXJzZXZlcmFuY2Ugb2YgZGVsaWdodCBpbiB0aGUgY29udGlu"
"dWVkIGFuZCBpbmRlZmF0aWdhYmxlIGdlbmVyYXRpb24gb2Yga25vd2xlZGdlLCBleGNlZWRzIHRo"
"ZSBzaG9ydCB2ZWhlbWVuY2Ugb2YgYW55IGNhcm5hbCBwbGVhc3VyZS4=" )
==
"Man is distinguished, not only by his reason, but by this singular passion from other "
"animals, which is a lust of the mind, that by a perseverance of delight in the "
"continued and indefatigable generation of knowledge, exceeds the short vehemence of "
"any carnal pleasure.");
REQUIRE( lokimq::to_base64(pk) == pk_b64 );
REQUIRE( lokimq::to_base64(pk.begin(), pk.end()) == pk_b64 );
REQUIRE( lokimq::from_base64(pk_b64) == pk );
REQUIRE( lokimq::from_base64(pk_b64.begin(), pk_b64.end()) == pk );
std::string pk_b64_again, pk_again;
lokimq::to_base64(pk.begin(), pk.end(), std::back_inserter(pk_b64_again));
lokimq::from_base64(pk_b64.begin(), pk_b64.end(), std::back_inserter(pk_again));
REQUIRE( pk_b64_again == pk_b64 );
REQUIRE( pk_again == pk );
std::vector<std::byte> bytes{{std::byte{0}, std::byte{255}}};
std::basic_string_view<std::byte> b{bytes.data(), bytes.size()};
REQUIRE( lokimq::to_base64(b) == "AP8=" );
bytes.resize(4);
bytes[0] = std::byte{'/'}; bytes[1] = std::byte{'w'}; bytes[2] = std::byte{'A'}; bytes[3] = std::byte{'='};
std::basic_string_view<std::byte> b64_bytes{bytes.data(), bytes.size()};
REQUIRE( lokimq::is_base64(b64_bytes) );
REQUIRE( lokimq::from_base64(b64_bytes) == "\xff\x00"sv );
}

View File

@ -25,21 +25,23 @@ TEST_CASE("failure responses - UNKNOWNCOMMAND", "[failure][UNKNOWNCOMMAND]") {
client.send(zmq::message_t{"HI", 2}, zmq::send_flags::none);
{
zmq::message_t hello;
client.recv(hello);
auto recvd = client.recv(hello);
auto lock = catch_lock();
REQUIRE( recvd );
REQUIRE( hello.to_string() == "HELLO" );
REQUIRE_FALSE( hello.more() );
}
client.send(zmq::message_t{"a.a", 3}, zmq::send_flags::none);
zmq::message_t resp;
client.recv(resp);
auto recvd = client.recv(resp);
auto lock = catch_lock();
REQUIRE( recvd );
REQUIRE( resp.to_string() == "UNKNOWNCOMMAND" );
REQUIRE( resp.more() );
client.recv(resp);
REQUIRE( client.recv(resp) );
REQUIRE( resp.to_string() == "a.a" );
REQUIRE_FALSE( resp.more() );
}
@ -66,37 +68,40 @@ TEST_CASE("failure responses - NO_REPLY_TAG", "[failure][NO_REPLY_TAG]") {
client.send(zmq::message_t{"HI", 2}, zmq::send_flags::none);
{
zmq::message_t hello;
client.recv(hello);
auto recvd = client.recv(hello);
auto lock = catch_lock();
REQUIRE( recvd );
REQUIRE( hello.to_string() == "HELLO" );
REQUIRE_FALSE( hello.more() );
}
client.send(zmq::message_t{"x.r", 3}, zmq::send_flags::none);
zmq::message_t resp;
client.recv(resp);
auto recvd = client.recv(resp);
{
auto lock = catch_lock();
REQUIRE( recvd );
REQUIRE( resp.to_string() == "NO_REPLY_TAG" );
REQUIRE( resp.more() );
client.recv(resp);
REQUIRE( client.recv(resp) );
REQUIRE( resp.to_string() == "x.r" );
REQUIRE_FALSE( resp.more() );
}
client.send(zmq::message_t{"x.r", 3}, zmq::send_flags::sndmore);
client.send(zmq::message_t{"foo", 3}, zmq::send_flags::none);
client.recv(resp);
recvd = client.recv(resp);
{
auto lock = catch_lock();
REQUIRE( recvd );
REQUIRE( resp.to_string() == "REPLY" );
REQUIRE( resp.more() );
client.recv(resp);
REQUIRE( client.recv(resp) );
REQUIRE( resp.to_string() == "foo" );
REQUIRE( resp.more() );
client.recv(resp);
REQUIRE( client.recv(resp) );
REQUIRE( resp.to_string() == "a" );
REQUIRE_FALSE( resp.more() );
}
@ -132,9 +137,10 @@ TEST_CASE("failure responses - FORBIDDEN", "[failure][FORBIDDEN]") {
client.send(zmq::message_t{"HI", 2}, zmq::send_flags::none);
{
zmq::message_t hello;
client.recv(hello);
auto recvd = client.recv(hello);
auto lock = catch_lock();
REQUIRE( recvd );
REQUIRE( hello.to_string() == "HELLO" );
REQUIRE_FALSE( hello.more() );
}
@ -144,18 +150,20 @@ TEST_CASE("failure responses - FORBIDDEN", "[failure][FORBIDDEN]") {
c.send(zmq::message_t{"x.x", 3}, zmq::send_flags::none);
zmq::message_t resp;
clients[0].recv(resp);
auto recvd = clients[0].recv(resp);
{
auto lock = catch_lock();
REQUIRE( recvd );
REQUIRE( resp.to_string() == "FORBIDDEN" );
REQUIRE( resp.more() );
clients[0].recv(resp);
REQUIRE( clients[0].recv(resp) );
REQUIRE( resp.to_string() == "x.x" );
REQUIRE_FALSE( resp.more() );
}
for (int i : {1, 2}) {
clients[i].recv(resp);
recvd = clients[i].recv(resp);
auto lock = catch_lock();
REQUIRE( recvd );
REQUIRE( resp.to_string() == "a" );
REQUIRE_FALSE( resp.more() );
}
@ -164,17 +172,19 @@ TEST_CASE("failure responses - FORBIDDEN", "[failure][FORBIDDEN]") {
c.send(zmq::message_t{"y.x", 3}, zmq::send_flags::none);
for (int i : {0, 1}) {
clients[i].recv(resp);
recvd = clients[i].recv(resp);
auto lock = catch_lock();
REQUIRE( recvd );
REQUIRE( resp.to_string() == "FORBIDDEN" );
REQUIRE( resp.more() );
clients[i].recv(resp);
REQUIRE( clients[i].recv(resp) );
REQUIRE( resp.to_string() == "y.x" );
REQUIRE_FALSE( resp.more() );
}
clients[2].recv(resp);
recvd = clients[2].recv(resp);
{
auto lock = catch_lock();
REQUIRE( recvd );
REQUIRE( resp.to_string() == "b" );
REQUIRE_FALSE( resp.more() );
}
@ -207,9 +217,10 @@ TEST_CASE("failure responses - NOT_A_SERVICE_NODE", "[failure][NOT_A_SERVICE_NOD
client.send(zmq::message_t{"HI", 2}, zmq::send_flags::none);
{
zmq::message_t hello;
client.recv(hello);
auto recvd = client.recv(hello);
auto lock = catch_lock();
REQUIRE( recvd );
REQUIRE( hello.to_string() == "HELLO" );
REQUIRE_FALSE( hello.more() );
}
@ -217,12 +228,13 @@ TEST_CASE("failure responses - NOT_A_SERVICE_NODE", "[failure][NOT_A_SERVICE_NOD
client.send(zmq::message_t{"x.x", 3}, zmq::send_flags::none);
zmq::message_t resp;
client.recv(resp);
auto recvd = client.recv(resp);
{
auto lock = catch_lock();
REQUIRE( recvd );
REQUIRE( resp.to_string() == "NOT_A_SERVICE_NODE" );
REQUIRE( resp.more() );
client.recv(resp);
REQUIRE( client.recv(resp) );
REQUIRE( resp.to_string() == "x.x" );
REQUIRE_FALSE( resp.more() );
}
@ -230,15 +242,16 @@ TEST_CASE("failure responses - NOT_A_SERVICE_NODE", "[failure][NOT_A_SERVICE_NOD
client.send(zmq::message_t{"x.r", 3}, zmq::send_flags::sndmore);
client.send(zmq::message_t{"xyz123", 6}, zmq::send_flags::none); // reply tag
client.recv(resp);
recvd = client.recv(resp);
{
auto lock = catch_lock();
REQUIRE( recvd );
REQUIRE( resp.to_string() == "NOT_A_SERVICE_NODE" );
REQUIRE( resp.more() );
client.recv(resp);
REQUIRE( client.recv(resp) );
REQUIRE( resp.to_string() == "REPLY" );
REQUIRE( resp.more() );
client.recv(resp);
REQUIRE( client.recv(resp) );
REQUIRE( resp.to_string() == "xyz123" );
REQUIRE_FALSE( resp.more() );
}
@ -271,9 +284,10 @@ TEST_CASE("failure responses - FORBIDDEN_SN", "[failure][FORBIDDEN_SN]") {
client.send(zmq::message_t{"HI", 2}, zmq::send_flags::none);
{
zmq::message_t hello;
client.recv(hello);
auto recvd = client.recv(hello);
auto lock = catch_lock();
REQUIRE( recvd );
REQUIRE( hello.to_string() == "HELLO" );
REQUIRE_FALSE( hello.more() );
}
@ -281,12 +295,13 @@ TEST_CASE("failure responses - FORBIDDEN_SN", "[failure][FORBIDDEN_SN]") {
client.send(zmq::message_t{"x.x", 3}, zmq::send_flags::none);
zmq::message_t resp;
client.recv(resp);
auto recvd = client.recv(resp);
{
auto lock = catch_lock();
REQUIRE( recvd );
REQUIRE( resp.to_string() == "FORBIDDEN_SN" );
REQUIRE( resp.more() );
client.recv(resp);
REQUIRE( client.recv(resp) );
REQUIRE( resp.to_string() == "x.x" );
REQUIRE_FALSE( resp.more() );
}
@ -294,15 +309,16 @@ TEST_CASE("failure responses - FORBIDDEN_SN", "[failure][FORBIDDEN_SN]") {
client.send(zmq::message_t{"x.r", 3}, zmq::send_flags::sndmore);
client.send(zmq::message_t{"xyz123", 6}, zmq::send_flags::none); // reply tag
client.recv(resp);
recvd = client.recv(resp);
{
auto lock = catch_lock();
REQUIRE( recvd );
REQUIRE( resp.to_string() == "FORBIDDEN_SN" );
REQUIRE( resp.more() );
client.recv(resp);
REQUIRE( client.recv(resp) );
REQUIRE( resp.to_string() == "REPLY" );
REQUIRE( resp.more() );
client.recv(resp);
REQUIRE( client.recv(resp) );
REQUIRE( resp.to_string() == "xyz123" );
REQUIRE_FALSE( resp.more() );
}

87
tests/test_inject.cpp Normal file
View File

@ -0,0 +1,87 @@
#include "common.h"
using namespace lokimq;
TEST_CASE("injected external commands", "[injected]") {
std::string listen = "tcp://127.0.0.1:4567";
LokiMQ server{
"", "", // generate ephemeral keys
false, // not a service node
[](auto) { return ""; },
get_logger(""),
LogLevel::trace
};
server.set_general_threads(1);
server.listen_curve(listen);
std::atomic<int> hellos = 0;
std::atomic<bool> done = false;
server.add_category("public", AuthLevel::none, 3);
server.add_command("public", "hello", [&](Message& m) {
hellos++;
while (!done) std::this_thread::sleep_for(10ms);
});
server.start();
LokiMQ client{get_logger(""), LogLevel::trace};
client.start();
std::atomic<bool> got{false};
bool success = false;
auto c = client.connect_remote(listen,
[&](auto conn) { success = true; got = true; },
[&](auto conn, std::string_view) { got = true; },
server.get_pubkey());
wait_for_conn(got);
{
auto lock = catch_lock();
REQUIRE( got );
REQUIRE( success );
}
// First make sure that basic message respects the 3 thread limit
client.send(c, "public.hello");
client.send(c, "public.hello");
client.send(c, "public.hello");
client.send(c, "public.hello");
wait_for([&] { return hellos >= 3; });
std::this_thread::sleep_for(20ms);
{
auto lock = catch_lock();
REQUIRE( hellos == 3 );
}
done = true;
wait_for([&] { return hellos >= 4; });
{
auto lock = catch_lock();
REQUIRE( hellos == 4 );
}
// Now try injecting external commands
done = false;
hellos = 0;
client.send(c, "public.hello");
wait_for([&] { return hellos >= 1; });
server.inject_task("public", "(injected)", "localhost", [&] { hellos += 10; while (!done) std::this_thread::sleep_for(10ms); });
wait_for([&] { return hellos >= 11; });
client.send(c, "public.hello");
wait_for([&] { return hellos >= 12; });
server.inject_task("public", "(injected)", "localhost", [&] { hellos += 10; while (!done) std::this_thread::sleep_for(10ms); });
server.inject_task("public", "(injected)", "localhost", [&] { hellos += 10; while (!done) std::this_thread::sleep_for(10ms); });
server.inject_task("public", "(injected)", "localhost", [&] { hellos += 10; while (!done) std::this_thread::sleep_for(10ms); });
wait_for([&] { return hellos >= 12; });
std::this_thread::sleep_for(20ms);
{
auto lock = catch_lock();
REQUIRE( hellos == 12 );
}
done = true;
wait_for([&] { return hellos >= 42; });
{
auto lock = catch_lock();
REQUIRE( hellos == 42 );
}
}

View File

@ -30,10 +30,9 @@ TEST_CASE("basic requests", "[requests]") {
std::atomic<bool> connected{false}, failed{false};
std::string pubkey;
auto c = client.connect_remote(listen,
auto c = client.connect_remote(address{listen, server.get_pubkey()},
[&](auto conn) { pubkey = conn.pubkey(); connected = true; },
[&](auto, auto) { failed = true; },
server.get_pubkey());
[&](auto, auto) { failed = true; });
wait_for([&] { return connected || failed; });
{
@ -88,10 +87,9 @@ TEST_CASE("request from server to client", "[requests]") {
std::atomic<bool> connected{false}, failed{false};
std::string pubkey;
auto c = client.connect_remote(listen,
auto c = client.connect_remote(address{listen, server.get_pubkey()},
[&](auto conn) { pubkey = conn.pubkey(); connected = true; },
[&](auto, auto) { failed = true; },
server.get_pubkey());
[&](auto, auto) { failed = true; });
int i;
for (i = 0; i < 5; i++) {
@ -151,10 +149,9 @@ TEST_CASE("request timeouts", "[requests][timeout]") {
std::atomic<bool> connected{false}, failed{false};
std::string pubkey;
auto c = client.connect_remote(listen,
auto c = client.connect_remote(address{listen, server.get_pubkey()},
[&](auto conn) { pubkey = conn.pubkey(); connected = true; },
[&](auto, auto) { failed = true; },
server.get_pubkey());
[&](auto, auto) { failed = true; });
wait_for([&] { return connected || failed; });
@ -170,7 +167,7 @@ TEST_CASE("request timeouts", "[requests][timeout]") {
success = ok;
data = std::move(data_);
},
lokimq::send_option::request_timeout{20ms}
lokimq::send_option::request_timeout{10ms}
);
std::atomic<bool> got_triggered2{false};
@ -179,10 +176,10 @@ TEST_CASE("request timeouts", "[requests][timeout]") {
success = ok;
data = std::move(data_);
},
lokimq::send_option::request_timeout{100ms}
lokimq::send_option::request_timeout{200ms}
);
std::this_thread::sleep_for(40ms);
std::this_thread::sleep_for(100ms);
REQUIRE( got_triggered );
REQUIRE_FALSE( got_triggered2 );
REQUIRE_FALSE( success );

View File

@ -1,187 +0,0 @@
#include <catch2/catch.hpp>
#include "lokimq/string_view.h"
#include <future>
using namespace lokimq;
using namespace std::literals;
TEST_CASE("string view", "[string_view]") {
std::string foo = "abc 123 xyz";
string_view f1{foo};
string_view f2{"def 789 uvw"};
string_view f3{"nu\0ll", 5};
REQUIRE( f1 == "abc 123 xyz" );
REQUIRE( f2 == "def 789 uvw" );
REQUIRE( f3.size() == 5 );
REQUIRE( f3 == std::string{"nu\0ll", 5} );
REQUIRE( f3 != "nu" );
REQUIRE( f3.data() == "nu"s );
REQUIRE( string_view(f3) == f3 );
auto f4 = f3;
REQUIRE( f4 == f3 );
f4 = f2;
REQUIRE( f4 == "def 789 uvw" );
REQUIRE( f1.size() == 11 );
REQUIRE( f3.length() == 5 );
string_view f5{""};
REQUIRE( !f3.empty() );
REQUIRE( f5.empty() );
REQUIRE( f1[5] == '2' );
size_t i = 0;
for (auto c : f3)
REQUIRE(c == f3[i++]);
std::string backwards;
for (auto it = std::rbegin(f2); it != f2.crend(); ++it)
backwards += *it;
REQUIRE( backwards == "wvu 987 fed" );
REQUIRE( f1.at(10) == 'z' );
REQUIRE_THROWS_AS( f1.at(15), std::out_of_range );
REQUIRE_THROWS_AS( f1.at(11), std::out_of_range );
f4 = f1;
f4.remove_prefix(2);
REQUIRE( f4 == "c 123 xyz" );
f4.remove_prefix(2);
f4.remove_suffix(4);
REQUIRE( f4 == "123" );
f4.remove_prefix(1);
REQUIRE( f4 == "23" );
REQUIRE( f1 == "abc 123 xyz" );
f4.swap(f1);
REQUIRE( f1 == "23" );
REQUIRE( f4 == "abc 123 xyz" );
f1.remove_suffix(2);
REQUIRE( f1.empty() );
REQUIRE( f4 == "abc 123 xyz" );
f1.swap(f4);
REQUIRE( f4.empty() );
REQUIRE( f1 == "abc 123 xyz" );
REQUIRE( f1.front() == 'a' );
REQUIRE( f1.back() == 'z' );
REQUIRE( f1.compare("abc") > 0 );
REQUIRE( f1.compare("abd") < 0 );
REQUIRE( f1.compare("abc 123 xyz") == 0 );
REQUIRE( f1.compare("abc 123 xyza") < 0 );
REQUIRE( f1.compare("abc 123 xy") > 0 );
std::string buf;
buf.resize(5);
f1.copy(&buf[0], 5, 2);
REQUIRE( buf == "c 123" );
buf.resize(100, 'X');
REQUIRE( f1.copy(&buf[0], 100) == 11 );
REQUIRE( buf.substr(0, 11) == f1 );
REQUIRE( buf.substr(11) == std::string(89, 'X') );
REQUIRE( f1.substr(4) == "123 xyz" );
REQUIRE( f1.substr(4, 3) == "123" );
REQUIRE_THROWS_AS( f1.substr(500, 3), std::out_of_range );
REQUIRE( f1.substr(11, 2) == "" );
REQUIRE( f1.substr(8, 500) == "xyz" );
REQUIRE( f1.find("123") == 4 );
REQUIRE( f1.find("abc") == 0 );
REQUIRE( f1.find("xyz") == 8 );
REQUIRE( f1.find("abc 123 xyz 7") == string_view::npos );
REQUIRE( f1.find("23") == 5 );
REQUIRE( f1.find("234") == string_view::npos );
string_view f6{"zz abc abcd abcde abcdef"};
REQUIRE( f6.find("abc") == 3 );
REQUIRE( f6.find("abc", 3) == 3 );
REQUIRE( f6.find("abc", 4) == 7 );
REQUIRE( f6.find("abc", 7) == 7 );
REQUIRE( f6.find("abc", 8) == 12 );
REQUIRE( f6.find("abc", 18) == 18 );
REQUIRE( f6.find("abc", 19) == string_view::npos );
REQUIRE( f6.find("abcd") == 7 );
REQUIRE( f6.rfind("abc") == 18 );
REQUIRE( f6.rfind("abcd") == 18 );
REQUIRE( f6.rfind("bcd") == 19 );
REQUIRE( f6.rfind("abc", 19) == 18 );
REQUIRE( f6.rfind("abc", 18) == 18 );
REQUIRE( f6.rfind("abc", 17) == 12 );
REQUIRE( f6.rfind("abc", 17) == 12 );
REQUIRE( f6.rfind("abc", 8) == 7 );
REQUIRE( f6.rfind("abc", 7) == 7 );
REQUIRE( f6.rfind("abc", 6) == 3 );
REQUIRE( f6.rfind("abc", 3) == 3 );
REQUIRE( f6.rfind("abc", 2) == string_view::npos );
REQUIRE( f6.find('a') == 3 );
REQUIRE( f6.find('a', 17) == 18 );
REQUIRE( f6.find('a', 20) == string_view::npos );
REQUIRE( f6.rfind('a') == 18 );
REQUIRE( f6.rfind('a', 17) == 12 );
REQUIRE( f6.rfind('a', 2) == string_view::npos );
string_view f7{"abc\0def", 7};
REQUIRE( f7.find("c\0d", 0, 3) == 2 );
REQUIRE( f7.find("c\0e", 0, 3) == string_view::npos );
REQUIRE( f7.rfind("c\0d", string_view::npos, 3) == 2 );
REQUIRE( f7.rfind("c\0e", 0, 3) == string_view::npos );
REQUIRE( f6.find_first_of("c789b") == 4 );
REQUIRE( f6.find_first_of("c789") == 5 );
REQUIRE( f2.find_first_of("c789b") == 4 );
REQUIRE( f6.find_first_of("c789b", 6) == 8 );
REQUIRE( f6.find_last_of("c789b") == 20 );
REQUIRE( f6.find_last_of("789b") == 19 );
REQUIRE( f2.find_last_of("c789b") == 6 );
REQUIRE( f6.find_last_of("c789b", 6) == 5 );
REQUIRE( f6.find_last_of("c789b", 5) == 5 );
REQUIRE( f6.find_last_of("c789b", 4) == 4 );
REQUIRE( f6.find_last_of("c789b", 3) == string_view::npos );
REQUIRE( f2.find_first_of(f7) == 0 );
REQUIRE( f3.find_first_of(f7) == 2 );
REQUIRE( f3.find_first_of('\0') == 2 );
REQUIRE( f3.find_first_of("jk\0", 0, 3) == 2 );
REQUIRE( f1.find_first_not_of("abc") == 3 );
REQUIRE( f1.find_first_not_of("abc ", 3) == 4 );
REQUIRE( f1.find_first_not_of(" 123", 3) == 8 );
REQUIRE( f1.find_last_not_of("abc") == 10 );
REQUIRE( f1.find_last_not_of("xyz") == 7 );
REQUIRE( f1.find_last_not_of("xyz 321") == 2 );
REQUIRE( f1.find_last_not_of("xay z1b2c3") == string_view::npos );
REQUIRE( f6.find_last_not_of("def") == 20 );
REQUIRE( f6.find_last_not_of("abcdef") == 17 );
REQUIRE( f6.find_last_not_of("abcdef ") == 1 );
REQUIRE( f6.find_first_not_of('z') == 2 );
REQUIRE( f6.find_first_not_of("z ") == 3 );
REQUIRE( f6.find_first_not_of("a ", 2) == 4 );
REQUIRE( f6.find_last_not_of("abc ", 9) == 1 );
REQUIRE( string_view{"abc"} == string_view{"abc"} );
REQUIRE_FALSE( string_view{"abc"} == string_view{"abd"} );
REQUIRE_FALSE( string_view{"abc"} == string_view{"abcd"} );
REQUIRE( string_view{"abc"} != string_view{"abd"} );
REQUIRE( string_view{"abc"} != string_view{"abcd"} );
REQUIRE( string_view{"abc"} < string_view{"abcd"} );
REQUIRE( string_view{"abc"} < string_view{"abd"} );
REQUIRE_FALSE( string_view{"abd"} < string_view{"abc"} );
REQUIRE_FALSE( string_view{"abcd"} < string_view{"abc"} );
REQUIRE_FALSE( string_view{"abc"} < string_view{"abc"} );
REQUIRE( string_view{"abd"} > string_view{"abc"} );
REQUIRE( string_view{"abcd"} > string_view{"abc"} );
REQUIRE_FALSE( string_view{"abc"} > string_view{"abd"} );
REQUIRE_FALSE( string_view{"abc"} > string_view{"abcd"} );
REQUIRE_FALSE( string_view{"abc"} > string_view{"abc"} );
REQUIRE( string_view{"abc"} <= string_view{"abcd"} );
REQUIRE( string_view{"abc"} <= string_view{"abc"} );
REQUIRE( string_view{"abc"} <= string_view{"abd"} );
REQUIRE( string_view{"abd"} >= string_view{"abc"} );
REQUIRE( string_view{"abc"} >= string_view{"abc"} );
REQUIRE( string_view{"abcd"} >= string_view{"abc"} );
}

View File

@ -0,0 +1,165 @@
#include "lokimq/batch.h"
#include "common.h"
#include <future>
TEST_CASE("tagged thread start functions", "[tagged][start]") {
lokimq::LokiMQ lmq{get_logger(""), LogLevel::trace};
lmq.set_general_threads(2);
lmq.set_batch_threads(2);
auto t_abc = lmq.add_tagged_thread("abc");
std::atomic<bool> start_called = false;
auto t_def = lmq.add_tagged_thread("def", [&] { start_called = true; });
std::this_thread::sleep_for(20ms);
{
auto lock = catch_lock();
REQUIRE_FALSE( start_called );
}
lmq.start();
wait_for([&] { return start_called.load(); });
{
auto lock = catch_lock();
REQUIRE( start_called );
}
}
TEST_CASE("tagged threads quit-before-start", "[tagged][quit]") {
auto lmq = std::make_unique<lokimq::LokiMQ>(get_logger(""), LogLevel::trace);
auto t_abc = lmq->add_tagged_thread("abc");
REQUIRE_NOTHROW(lmq.reset());
}
TEST_CASE("batch jobs to tagged threads", "[tagged][batch]") {
lokimq::LokiMQ lmq{get_logger(""), LogLevel::trace};
lmq.set_general_threads(2);
lmq.set_batch_threads(2);
std::thread::id id_abc, id_def;
auto t_abc = lmq.add_tagged_thread("abc", [&] { id_abc = std::this_thread::get_id(); });
auto t_def = lmq.add_tagged_thread("def", [&] { id_def = std::this_thread::get_id(); });
lmq.start();
std::atomic<bool> done = false;
std::thread::id id;
lmq.job([&] { id = std::this_thread::get_id(); done = true; });
wait_for([&] { return done.load(); });
{
auto lock = catch_lock();
REQUIRE( id != id_abc );
REQUIRE( id != id_def );
}
done = false;
lmq.job([&] { id = std::this_thread::get_id(); done = true; }, t_abc);
wait_for([&] { return done.load(); });
{
auto lock = catch_lock();
REQUIRE( id == id_abc );
}
done = false;
lmq.job([&] { id = std::this_thread::get_id(); done = true; }, t_def);
wait_for([&] { return done.load(); });
{
auto lock = catch_lock();
REQUIRE( id == id_def );
}
std::atomic<bool> sleep = true;
auto sleeper = [&] { for (int i = 0; sleep && i < 10; i++) { std::this_thread::sleep_for(25ms); } };
lmq.job(sleeper);
lmq.job(sleeper);
// This one should stall:
std::atomic<bool> bad = false;
lmq.job([&] { bad = true; });
std::this_thread::sleep_for(50ms);
done = false;
lmq.job([&] { id = std::this_thread::get_id(); done = true; }, t_abc);
wait_for([&] { return done.load(); });
{
auto lock = catch_lock();
REQUIRE( done.load() );
REQUIRE_FALSE( bad.load() );
}
done = false;
// We can queue up a bunch of jobs which should all happen in order, and all on the abc thread.
std::vector<int> v;
for (int i = 0; i < 100; i++) {
lmq.job([&] { if (std::this_thread::get_id() == id_abc) v.push_back(v.size()); }, t_abc);
}
lmq.job([&] { done = true; }, t_abc);
wait_for([&] { return done.load(); });
{
auto lock = catch_lock();
REQUIRE( done.load() );
REQUIRE_FALSE( bad.load() );
REQUIRE( v.size() == 100 );
for (int i = 0; i < 100; i++)
REQUIRE( v[i] == i );
}
sleep = false;
wait_for([&] { return bad.load(); });
{
auto lock = catch_lock();
REQUIRE( bad.load() );
}
}
TEST_CASE("batch job completion on tagged threads", "[tagged][batch-completion]") {
lokimq::LokiMQ lmq{get_logger(""), LogLevel::trace};
lmq.set_general_threads(4);
lmq.set_batch_threads(4);
std::thread::id id_abc;
auto t_abc = lmq.add_tagged_thread("abc", [&] { id_abc = std::this_thread::get_id(); });
lmq.start();
lokimq::Batch<int> batch;
for (int i = 1; i < 10; i++)
batch.add_job([i, &id_abc]() { if (std::this_thread::get_id() == id_abc) return 0; return i; });
std::atomic<int> result_sum = -1;
batch.completion([&](auto result) {
int sum = 0;
for (auto& r : result)
sum += r.get();
result_sum = std::this_thread::get_id() == id_abc ? sum : -sum;
}, t_abc);
lmq.batch(std::move(batch));
wait_for([&] { return result_sum.load() != -1; });
{
auto lock = catch_lock();
REQUIRE( result_sum == 45 );
}
}
TEST_CASE("timer job completion on tagged threads", "[tagged][timer]") {
lokimq::LokiMQ lmq{get_logger(""), LogLevel::trace};
lmq.set_general_threads(4);
lmq.set_batch_threads(4);
std::thread::id id_abc;
auto t_abc = lmq.add_tagged_thread("abc", [&] { id_abc = std::this_thread::get_id(); });
lmq.start();
std::atomic<int> ticks = 0;
std::atomic<int> abc_ticks = 0;
lmq.add_timer([&] { ticks++; }, 10ms);
lmq.add_timer([&] { if (std::this_thread::get_id() == id_abc) abc_ticks++; }, 10ms, true, t_abc);
wait_for([&] { return ticks.load() > 2 && abc_ticks > 2; });
{
auto lock = catch_lock();
REQUIRE( ticks.load() > 2 );
REQUIRE( abc_ticks.load() > 2 );
}
}