Compare commits

..

No commits in common. "dev" and "v1.0.3" have entirely different histories.
dev ... v1.0.3

56 changed files with 5103 additions and 8246 deletions

View file

@ -1,117 +0,0 @@
local docker_base = 'registry.oxen.rocks/lokinet-ci-';
local default_deps_nocxx = ['libsodium-dev', 'libzmq3-dev', 'liboxenc-dev'];
local submodule_commands = ['git fetch --tags', 'git submodule update --init --recursive --depth=1'];
local submodules = {
name: 'submodules',
image: 'drone/git',
commands: submodule_commands,
};
local apt_get_quiet = 'apt-get -o=Dpkg::Use-Pty=0 -q ';
local debian_pipeline(name,
image,
arch='amd64',
deps=['g++'] + default_deps_nocxx,
cmake_extra='',
build_type='Release',
extra_cmds=[],
distro='$$(lsb_release -sc)',
allow_fail=false) = {
kind: 'pipeline',
type: 'docker',
name: name,
platform: { arch: arch },
environment: { CLICOLOR_FORCE: '1' }, // Lets color through ninja (1.9+)
steps: [
submodules,
{
name: 'build',
image: image,
pull: 'always',
[if allow_fail then 'failure']: 'ignore',
commands: [
'echo "Building on ${DRONE_STAGE_MACHINE}"',
'echo "man-db man-db/auto-update boolean false" | debconf-set-selections',
apt_get_quiet + 'update',
apt_get_quiet + 'install -y eatmydata',
'eatmydata ' + apt_get_quiet + ' install --no-install-recommends -y lsb-release',
'cp contrib/deb.oxen.io.gpg /etc/apt/trusted.gpg.d',
'echo deb http://deb.oxen.io ' + distro + ' main >/etc/apt/sources.list.d/oxen.list',
'eatmydata ' + apt_get_quiet + ' update',
'eatmydata ' + apt_get_quiet + 'dist-upgrade -y',
'eatmydata ' + apt_get_quiet + 'install -y cmake git ninja-build pkg-config ccache ' + std.join(' ', deps),
'mkdir build',
'cd build',
'cmake .. -G Ninja -DCMAKE_CXX_FLAGS=-fdiagnostics-color=always -DCMAKE_BUILD_TYPE=' + build_type + ' -DCMAKE_CXX_COMPILER_LAUNCHER=ccache ' + cmake_extra,
'ninja -v',
'./tests/tests --use-colour yes',
] + extra_cmds,
},
],
};
local clang(version) = debian_pipeline(
'Debian sid/clang-' + version + ' (amd64)',
docker_base + 'debian-sid-clang',
distro='sid',
deps=['clang-' + version] + default_deps_nocxx,
cmake_extra='-DCMAKE_C_COMPILER=clang-' + version + ' -DCMAKE_CXX_COMPILER=clang++-' + version + ' '
);
local full_llvm(version) = debian_pipeline(
'Debian sid/llvm-' + version + ' (amd64)',
docker_base + 'debian-sid-clang',
distro='sid',
deps=['clang-' + version, 'lld-' + version, 'libc++-' + version + '-dev', 'libc++abi-' + version + '-dev']
+ default_deps_nocxx,
cmake_extra='-DCMAKE_C_COMPILER=clang-' + version +
' -DCMAKE_CXX_COMPILER=clang++-' + version +
' -DCMAKE_CXX_FLAGS=-stdlib=libc++ ' +
std.join(' ', [
'-DCMAKE_' + type + '_LINKER_FLAGS=-fuse-ld=lld-' + version
for type in ['EXE', 'MODULE', 'SHARED', 'STATIC']
])
);
[
debian_pipeline('Debian sid (amd64)', docker_base + 'debian-sid', distro='sid'),
debian_pipeline('Debian sid/Debug (amd64)', docker_base + 'debian-sid', build_type='Debug', distro='sid'),
clang(16),
full_llvm(16),
debian_pipeline('Debian buster (amd64)', docker_base + 'debian-buster'),
debian_pipeline('Debian stable (i386)', docker_base + 'debian-stable/i386'),
debian_pipeline('Debian sid (ARM64)', docker_base + 'debian-sid', arch='arm64', distro='sid'),
debian_pipeline('Debian stable (armhf)', docker_base + 'debian-stable/arm32v7', arch='arm64'),
debian_pipeline('Debian buster (armhf)', docker_base + 'debian-buster/arm32v7', arch='arm64'),
debian_pipeline('Ubuntu focal (amd64)', docker_base + 'ubuntu-focal'),
debian_pipeline('Ubuntu bionic (amd64)',
docker_base + 'ubuntu-bionic',
deps=default_deps_nocxx,
cmake_extra='-DCMAKE_C_COMPILER=gcc-8 -DCMAKE_CXX_COMPILER=g++-8'),
{
kind: 'pipeline',
type: 'exec',
name: 'macOS (w/macports)',
platform: { os: 'darwin', arch: 'amd64' },
environment: { CLICOLOR_FORCE: '1' }, // Lets color through ninja (1.9+)
steps: [
{ name: 'submodules', commands: submodule_commands },
{
name: 'build',
commands: [
'mkdir build',
'cd build',
'ulimit -n 1024', // Because macOS has a stupid tiny default ulimit
'cmake .. -G Ninja -DCMAKE_CXX_FLAGS=-fcolor-diagnostics -DCMAKE_BUILD_TYPE=Release -DCMAKE_CXX_COMPILER_LAUNCHER=ccache',
'ninja -v',
'./tests/tests --use-colour yes',
],
},
],
},
]

6
.gitmodules vendored
View file

@ -1,9 +1,9 @@
[submodule "mapbox-variant"]
path = mapbox-variant
url = https://github.com/mapbox/variant.git
[submodule "cppzmq"]
path = cppzmq
url = https://github.com/zeromq/cppzmq.git
[submodule "Catch2"]
path = tests/Catch2
url = https://github.com/catchorg/Catch2.git
[submodule "oxen-encoding"]
path = oxen-encoding
url = https://github.com/oxen-io/oxen-encoding.git

View file

@ -1,163 +1,88 @@
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
find_program(CCACHE_PROGRAM ccache)
if(CCACHE_PROGRAM)
foreach(lang C CXX)
if(NOT DEFINED CMAKE_${lang}_COMPILER_LAUNCHER AND NOT CMAKE_${lang}_COMPILER MATCHES ".*/ccache")
message(STATUS "Enabling ccache for ${lang}")
set(CMAKE_${lang}_COMPILER_LAUNCHER ${CCACHE_PROGRAM} CACHE STRING "")
endif()
endforeach()
endif()
cmake_minimum_required(VERSION 3.7)
# Has to be set before `project()`, and ignored on non-macos:
set(CMAKE_OSX_DEPLOYMENT_TARGET 10.12 CACHE STRING "macOS deployment target (Apple clang only)")
project(liboxenmq
VERSION 1.2.16
LANGUAGES CXX C)
project(liblokimq CXX)
include(GNUInstallDirs)
message(STATUS "oxenmq v${PROJECT_VERSION}")
set(LOKIMQ_VERSION_MAJOR 1)
set(LOKIMQ_VERSION_MINOR 0)
set(LOKIMQ_VERSION_PATCH 3)
set(LOKIMQ_VERSION "${LOKIMQ_VERSION_MAJOR}.${LOKIMQ_VERSION_MINOR}.${LOKIMQ_VERSION_PATCH}")
message(STATUS "lokimq v${LOKIMQ_VERSION}")
set(OXENMQ_LIBVERSION 0)
if(CMAKE_PROJECT_NAME STREQUAL PROJECT_NAME)
set(oxenmq_IS_TOPLEVEL_PROJECT TRUE)
else()
set(oxenmq_IS_TOPLEVEL_PROJECT FALSE)
endif()
set(LOKIMQ_LIBVERSION 0)
option(BUILD_SHARED_LIBS "Build shared libraries instead of static ones" ON)
set(oxenmq_INSTALL_DEFAULT OFF)
if(BUILD_SHARED_LIBS OR oxenmq_IS_TOPLEVEL_PROJECT)
set(oxenmq_INSTALL_DEFAULT ON)
endif()
set(oxenmq_EPOLL_DEFAULT OFF)
if(CMAKE_SYSTEM_NAME STREQUAL "Linux" AND NOT CMAKE_CROSSCOMPILING)
set(oxenmq_EPOLL_DEFAULT ON)
endif()
option(OXENMQ_BUILD_TESTS "Building and perform oxenmq tests" ${oxenmq_IS_TOPLEVEL_PROJECT})
option(OXENMQ_INSTALL "Add oxenmq libraries and headers to cmake install target; defaults to ON if BUILD_SHARED_LIBS is enabled or we are the top-level project; OFF for a static subdirectory build" ${oxenmq_INSTALL_DEFAULT})
option(OXENMQ_INSTALL_CPPZMQ "Install cppzmq header with oxenmq/ headers (requires OXENMQ_INSTALL)" ON)
option(OXENMQ_USE_EPOLL "Use epoll for socket polling (requires Linux)" ${oxenmq_EPOLL_DEFAULT})
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
configure_file(oxenmq/version.h.in oxenmq/version.h @ONLY)
configure_file(liboxenmq.pc.in liboxenmq.pc @ONLY)
configure_file(lokimq/version.h.in lokimq/version.h @ONLY)
configure_file(liblokimq.pc.in liblokimq.pc @ONLY)
add_library(oxenmq
oxenmq/address.cpp
oxenmq/auth.cpp
oxenmq/connections.cpp
oxenmq/jobs.cpp
oxenmq/oxenmq.cpp
oxenmq/proxy.cpp
oxenmq/worker.cpp
add_library(lokimq
lokimq/auth.cpp
lokimq/batch.cpp
lokimq/bt_serialize.cpp
lokimq/connections.cpp
lokimq/jobs.cpp
lokimq/lokimq.cpp
lokimq/proxy.cpp
lokimq/worker.cpp
)
set_target_properties(oxenmq PROPERTIES SOVERSION ${OXENMQ_LIBVERSION})
if(OXENMQ_USE_EPOLL)
target_compile_definitions(oxenmq PRIVATE OXENMQ_USE_EPOLL)
endif()
set_target_properties(lokimq PROPERTIES SOVERSION ${LOKIMQ_LIBVERSION})
set(THREADS_PREFER_PTHREAD_FLAG ON)
find_package(Threads REQUIRED)
target_link_libraries(oxenmq PRIVATE Threads::Threads)
if(TARGET oxenc::oxenc)
add_library(_oxenmq_external_oxenc INTERFACE IMPORTED)
target_link_libraries(_oxenmq_external_oxenc INTERFACE oxenc::oxenc)
target_link_libraries(oxenmq PUBLIC _oxenmq_external_oxenc)
message(STATUS "using pre-existing oxenc::oxenc target")
elseif(BUILD_SHARED_LIBS)
include(FindPkgConfig)
pkg_check_modules(oxenc liboxenc IMPORTED_TARGET)
if(oxenc_FOUND)
# Work around cmake bug 22180 (PkgConfig::tgt not set if no flags needed)
if(TARGET PkgConfig::oxenc OR CMAKE_VERSION VERSION_GREATER_EQUAL "3.21")
target_link_libraries(oxenmq PUBLIC PkgConfig::oxenc)
endif()
else()
add_subdirectory(oxen-encoding)
target_link_libraries(oxenmq PUBLIC oxenc::oxenc)
endif()
else()
add_subdirectory(oxen-encoding)
target_link_libraries(oxenmq PUBLIC oxenc::oxenc)
endif()
target_link_libraries(lokimq PRIVATE Threads::Threads)
# libzmq is nearly impossible to link statically from a system-installed static library: it depends
# on a ton of other libraries, some of which are not all statically available. If the caller wants
# to mess with this, so be it: they can set up a libzmq target and we'll use it. Otherwise if they
# asked us to do things statically, don't even try to find a system lib and just build it.
set(oxenmq_build_static_libzmq OFF)
set(lokimq_build_static_libzmq OFF)
if(TARGET libzmq)
target_link_libraries(oxenmq PUBLIC libzmq)
target_link_libraries(lokimq PUBLIC libzmq)
elseif(BUILD_SHARED_LIBS)
include(FindPkgConfig)
pkg_check_modules(libzmq libzmq>=4.3 IMPORTED_TARGET)
if(libzmq_FOUND)
# Debian sid includes a -isystem in the mit-krb package that, starting with pkg-config 0.29.2,
# breaks cmake's pkgconfig module because it stupidly thinks "-isystem" is a path, so if we find
# -isystem in the include dirs then hack it out.
get_property(zmq_inc TARGET PkgConfig::libzmq PROPERTY INTERFACE_INCLUDE_DIRECTORIES)
list(FIND zmq_inc "-isystem" broken_isystem)
if(NOT broken_isystem EQUAL -1)
list(REMOVE_AT zmq_inc ${broken_isystem})
set_property(TARGET PkgConfig::libzmq PROPERTY INTERFACE_INCLUDE_DIRECTORIES ${zmq_inc})
endif()
target_link_libraries(oxenmq PUBLIC PkgConfig::libzmq)
target_link_libraries(lokimq PUBLIC PkgConfig::libzmq)
else()
set(oxenmq_build_static_libzmq ON)
set(lokimq_build_static_libzmq ON)
endif()
else()
set(oxenmq_build_static_libzmq ON)
set(lokimq_build_static_libzmq ON)
endif()
if(oxenmq_build_static_libzmq)
message(STATUS "libzmq >= 4.3 not found or static build requested, building bundled version")
if(lokimq_build_static_libzmq)
message(STATUS "libzmq >= 4.3 not found or static build requested, building bundled 4.3.2")
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/local-libzmq")
include(LocalLibzmq)
target_link_libraries(oxenmq PUBLIC libzmq_vendor)
target_link_libraries(lokimq PUBLIC libzmq_vendor)
endif()
target_include_directories(oxenmq
target_include_directories(lokimq
PUBLIC
$<INSTALL_INTERFACE:>
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}>
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/cppzmq>
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/mapbox-variant/include>
)
target_compile_options(oxenmq PRIVATE -Wall -Wextra)
option(WARNINGS_AS_ERRORS "treat all warnings as errors" ON)
if(WARNINGS_AS_ERRORS)
target_compile_options(oxenmq PRIVATE -Werror)
endif()
set_target_properties(oxenmq PROPERTIES
POSITION_INDEPENDENT_CODE ON
CXX_STANDARD 17
target_compile_options(lokimq PRIVATE -Wall -Wextra -Werror)
set_target_properties(lokimq PROPERTIES
CXX_STANDARD 14
CXX_STANDARD_REQUIRED ON
CXX_EXTENSIONS OFF
POSITION_INDEPENDENT_CODE ON
)
function(link_dep_libs target linktype libdirs)
foreach(lib ${ARGN})
find_library(link_lib-${lib} NAMES ${lib} PATHS ${libdirs})
message(STATUS "FIND ${lib} FOUND ${link_lib-${lib}}")
if(link_lib-${lib})
target_link_libraries(${target} ${linktype} ${link_lib-${lib}})
endif()
@ -167,72 +92,87 @@ endfunction()
# If the caller has already set up a sodium target then we will just link to it, otherwise we go
# looking for it.
if(TARGET sodium)
target_link_libraries(oxenmq PUBLIC sodium)
if(oxenmq_build_static_libzmq)
target_link_libraries(lokimq PRIVATE sodium)
if(lokimq_build_static_libzmq)
target_link_libraries(libzmq_vendor INTERFACE sodium)
endif()
else()
include(FindPkgConfig)
pkg_check_modules(sodium REQUIRED libsodium IMPORTED_TARGET)
if(BUILD_SHARED_LIBS)
target_link_libraries(oxenmq PUBLIC PkgConfig::sodium)
if(oxenmq_build_static_libzmq)
target_link_libraries(lokimq PRIVATE PkgConfig::sodium)
if(lokimq_build_static_libzmq)
target_link_libraries(libzmq_vendor INTERFACE PkgConfig::sodium)
endif()
else()
link_dep_libs(oxenmq PUBLIC "${sodium_STATIC_LIBRARY_DIRS}" ${sodium_STATIC_LIBRARIES})
target_include_directories(oxenmq PUBLIC ${sodium_STATIC_INCLUDE_DIRS})
if(oxenmq_build_static_libzmq)
link_dep_libs(lokimq PRIVATE "${sodium_STATIC_LIBRARY_DIRS}" ${sodium_STATIC_LIBRARIES})
target_include_directories(lokimq PRIVATE ${sodium_STATIC_INCLUDE_DIRS})
if(lokimq_build_static_libzmq)
link_dep_libs(libzmq_vendor INTERFACE "${sodium_STATIC_LIBRARY_DIRS}" ${sodium_STATIC_LIBRARIES})
target_link_libraries(libzmq_vendor INTERFACE ${sodium_STATIC_INCLUDE_DIRS})
endif()
endif()
endif()
add_library(oxenmq::oxenmq ALIAS oxenmq)
add_library(lokimq::lokimq ALIAS lokimq)
export(
TARGETS oxenmq
NAMESPACE oxenmq::
FILE oxenmqTargets.cmake
TARGETS lokimq
NAMESPACE lokimq::
FILE lokimqTargets.cmake
)
install(
TARGETS lokimq
EXPORT lokimqConfig
DESTINATION ${CMAKE_INSTALL_LIBDIR}
)
if(OXENMQ_INSTALL)
install(
TARGETS oxenmq
EXPORT oxenmqConfig
DESTINATION ${CMAKE_INSTALL_LIBDIR}
)
install(
FILES oxenmq/address.h
oxenmq/auth.h
oxenmq/batch.h
oxenmq/connections.h
oxenmq/fmt.h
oxenmq/message.h
oxenmq/oxenmq.h
oxenmq/pubsub.h
${CMAKE_CURRENT_BINARY_DIR}/oxenmq/version.h
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/oxenmq
)
if(OXENMQ_INSTALL_CPPZMQ)
install(
FILES cppzmq/zmq.hpp
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/oxenmq
)
endif()
install(
FILES ${CMAKE_CURRENT_BINARY_DIR}/liboxenmq.pc
DESTINATION ${CMAKE_INSTALL_LIBDIR}/pkgconfig
)
install(
FILES lokimq/auth.h
lokimq/batch.h
lokimq/bt_serialize.h
lokimq/connections.h
lokimq/hex.h
lokimq/lokimq.h
lokimq/message.h
lokimq/string_view.h
${CMAKE_CURRENT_BINARY_DIR}/lokimq/version.h
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/lokimq
)
option(LOKIMQ_INSTALL_MAPBOX_VARIANT "Install mapbox-variant headers with lokimq/ headers" ON)
if(LOKIMQ_INSTALL_MAPBOX_VARIANT)
install(
FILES mapbox-variant/include/mapbox/variant.hpp
mapbox-variant/include/mapbox/variant_cast.hpp
mapbox-variant/include/mapbox/variant_io.hpp
mapbox-variant/include/mapbox/variant_visitor.hpp
mapbox-variant/include/mapbox/recursive_wrapper.hpp
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/lokimq/mapbox
)
endif()
if(OXENMQ_BUILD_TESTS)
option(LOKIMQ_INSTALL_CPPZMQ "Install cppzmq header with lokimq/ headers" ON)
if(LOKIMQ_INSTALL_CPPZMQ)
install(
FILES cppzmq/zmq.hpp
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/lokimq
)
endif()
install(
FILES ${CMAKE_CURRENT_BINARY_DIR}/liblokimq.pc
DESTINATION ${CMAKE_INSTALL_LIBDIR}/pkgconfig
)
if(CMAKE_PROJECT_NAME STREQUAL PROJECT_NAME)
set(lokimq_IS_TOPLEVEL_PROJECT TRUE)
else()
set(lokimq_IS_TOPLEVEL_PROJECT FALSE)
endif()
option(LOKIMQ_BUILD_TESTS "Building and perform lokimq tests" ${lokimq_IS_TOPLEVEL_PROJECT})
if(LOKIMQ_BUILD_TESTS)
add_subdirectory(tests)
endif()

View file

@ -1,15 +1,14 @@
# OxenMQ - high-level zeromq-based message passing for network-based projects
# LokiMQ - zeromq-based message passing for Loki projects
This C++17 library contains an abstraction layer around ZeroMQ to provide a high-level interface to
authentication, RPC, and message passing. It is used extensively within Oxen projects (hence the
name) as the underlying communication mechanism of SN-to-SN communication ("quorumnet"), the RPC
interface used by wallets and local daemon commands, communication channels between oxend and
auxiliary services (storage server, lokinet), and also provides local multithreaded job scheduling
within a process.
This C++14 library contains an abstraction layer around ZeroMQ to support integration with Loki
authentication, RPC, and message passing. It is designed to be usable as the underlying
communication mechanism of SN-to-SN communication ("quorumnet"), the RPC interface used by wallets
and local daemon commands, communication channels between lokid and auxiliary services (storage
server, lokinet), and also provides a local multithreaded job scheduling within a process.
Messages channels can be encrypted (using x25519) or not -- however opening an encrypted channel
requires knowing the server pubkey. Within Oxen, all SN-to-SN traffic is encrypted, and other
traffic can be encrypted as needed.
requires knowing the server pubkey. All SN-to-SN traffic is encrypted, and other traffic can be
encrypted as needed.
This library makes minimal use of mutexes, and none in the hot paths of the code, instead mostly
relying on ZMQ sockets for synchronization; for more information on this (and why this is generally
@ -17,20 +16,20 @@ much better performing and more scalable) see the ZMQ guide documentation on the
## Basic message structure
OxenMQ messages come in two fundamental forms: "commands", consisting of a command named and
LokiMQ messages come in two fundamental forms: "commands", consisting of a command named and
optional arguments, and "requests", consisting of a request name, a request tag, and optional
arguments.
All channels are capable of bidirectional communication, and multiple messages can be in transit in
either direction at any time. OxenMQ sets up a "listener" and "client" connections, but these only
either direction at any time. LokiMQ sets up a "listener" and "client" connections, but these only
determine how connections are established: once established, commands can be issued by either party.
The command/request string is one of two types:
`category.command` - for commands/requests registered by the OxenMQ caller (e.g. oxend). Here
`category.command` - for commands/requests registered by the LokiMQ caller (e.g. lokid). Here
`category` must be at least one character not containing a `.` and `command` may be anything. These
categories and commands are registered according to general function and authentication level (more
on this below). For example, for oxend categories are:
on this below). For example, for lokid categories are:
- `system` - is for RPC commands related to the system administration such as mining, getting
sensitive statistics, accessing SN private keys, remote shutdown, etc.
@ -43,14 +42,14 @@ on this below). For example, for oxend categories are:
The difference between a request and a command is that a request includes an additional opaque tag
value which is used to identify a reply. For example you could register a `general.backwards`
request that takes a string that receives a reply containing that string reversed. When invoking
the request via OxenMQ you provide a callback to be invoked when the reply arrives. On the wire
the request via LokiMQ you provide a callback to be invoked when the reply arrives. On the wire
this looks like:
<<< [general.backwards] [v71.&a] [hello world]
>>> [REPLY] [v71.&a] [dlrow olleh]
where each [] denotes a message part and `v71.&a` is a unique randomly generated identifier handled
by OxenMQ (both the invoker and the recipient code only see the `hello world`/`dlrow olleh` message
by LokiMQ (both the invoker and the recipient code only see the `hello world`/`dlrow olleh` message
parts).
In contrast, regular registered commands have no identifier or expected reply callback. For example
@ -93,7 +92,7 @@ handled for you transparently.
## Command arguments
Optional command/request arguments are always strings on the wire. The OxenMQ-using developer is
Optional command/request arguments are always strings on the wire. The LokiMQ-using developer is
free to create whatever encoding she wants, and these can vary across commands. For example
`wallet.tx` might be a request that returns a transaction in binary, while `wallet.tx_info` might
return tx metadata in JSON, and `p2p.send_tx` might encode tx data and metadata in a bt-encoded
@ -102,49 +101,47 @@ data string.
No structure at all is imposed on message data to allow maximum flexibility; it is entirely up to
the calling code to handle all encoding/decoding duties.
Internal commands passed between OxenMQ-managed threads use either plain strings or bt-encoded
dictionaries. See `oxenmq/bt_serialize.h` if you want a bt serializer/deserializer.
Internal commands passed between LokiMQ-managed threads use either plain strings or bt-encoded
dictionaries. See `lokimq/bt_serialize.h` if you want a bt serializer/deserializer.
## Sending commands
Sending a command to a peer is done by using a connection ID, and generally falls into either a
`send()` method or a `request()` method.
omq.send(conn, "category.command", "some data");
omq.request(conn, "category.command", [](bool success, std::vector<std::string> data) {
lmq.send(conn, "category.command", "some data");
lmq.request(conn, "category.command", [](bool success, std::vector<std::string> data) {
if (success) { std::cout << "Remote replied: " << data.at(0) << "\n"; } });
The connection ID generally has two possible values:
- a string containing a service node pubkey. In this mode OxenMQ will look for the given SN in
- a string containing a service node pubkey. In this mode LokiMQ will look for the given SN in
already-established connections, reusing a connection if one exists. If no connection already
exists, a new connection to the given SN is attempted (this requires constructing the OxenMQ
exists, a new connection to the given SN is attempted (this requires constructing the LokiMQ
object with a callback to determine SN remote addresses).
- a ConnectionID object, typically returned by the `connect_remote` method (although there are other
places to get one, such as from the `Message` object passed to a command: see the following
section).
```C++
// Send to a service node, establishing a connection if necessary:
std::string my_sn = ...; // 32-byte pubkey of a known SN
omq.send(my_sn, "sn.explode", "{ \"seconds\": 30 }");
lmq.send(my_sn, "sn.explode", "{ \"seconds\": 30 }");
// Connect to a remote by address then send it something
auto conn = omq.connect_remote("tcp://127.0.0.1:4567",
auto conn = lmq.connect_remote("tcp://127.0.0.1:4567",
[](ConnectionID c) { std::cout << "Connected!\n"; },
[](ConnectionID c, string_view f) { std::cout << "Connect failed: " << f << "\n" });
omq.request(conn, "rpc.get_height", [](bool s, std::vector<std::string> d) {
lmq.request(conn, "rpc.get_height", [](bool s, std::vector<std::string> d) {
if (s && d.size() == 1)
std::cout << "Current height: " << d[0] << "\n";
else
std::cout << "Timeout fetching height!";
});
```
## Command invocation
The application registers categories and registers commands within these categories with callbacks.
The callbacks are passed a OxenMQ::Message object from which the message (plus various connection
The callbacks are passed a LokiMQ::Message object from which the message (plus various connection
information) can be obtained. There is no structure imposed at all on the data passed in subsequent
message parts: it is up to the command itself to deserialize however it wishes (e.g. JSON,
bt-encoded, or any other encoding).
@ -152,13 +149,13 @@ bt-encoded, or any other encoding).
The Message object also provides methods for replying to the caller. Simple replies queue a reply
if the client is still connected. Replies to service nodes can also be "strong" replies: when
replying to a SN that has closed connection with a strong reply we will attempt to reestablish a
connection to deliver the message. In order for this to work the OxenMQ caller must provide a
connection to deliver the message. In order for this to work the LokiMQ caller must provide a
lookup function to retrieve the remote address given a SN x25519 pubkey.
### Callbacks
Invoked command functions are always invoked with exactly one arguments: a non-const OxenMQ::Message
reference from which the connection info, OxenMQ object, and message data can be obtained.
Invoked command functions are always invoked with exactly one arguments: a non-const LokiMQ::Message
reference from which the connection info, LokiMQ object, and message data can be obtained.
The Message object also contains a `ConnectionID` object as the public `conn` member; it is safe to
take a copy of this and then use it later to send commands to this peer. (For example, a wallet
@ -188,7 +185,7 @@ logins.
Configuration defaults allows controlling the default access for an incoming connection based on its
remote address. Typically this is used to allow connections from localhost (or a unix domain
socket) to automatically be an Admin connection without requiring explicit authentication. This
also allows configuration of how public connections should be treated: for example, an oxend running
also allows configuration of how public connections should be treated: for example, a lokid running
as a public RPC server would do so by granting Basic access to all incoming connections.
Explicit logins allow the daemon to specify username/passwords with mapping to Basic or Admin
@ -197,7 +194,7 @@ authentication levels.
Thus, for example, a daemon could be configured to be allow Basic remote access with authentication
(i.e. requiring a username/password login given out to people who should be able to access).
For example, in oxend the categories described above have authentication levels of:
For example, in lokid the categories described above have authentication levels of:
- `system` - Admin
- `sn` - ServiceNode
@ -206,7 +203,7 @@ For example, in oxend the categories described above have authentication levels
### Service Node authentication
In order to handle ServiceNode authentication, OxenMQ uses an Allow callback invoked during
In order to handle ServiceNode authentication, LokiMQ uses an Allow callback invoked during
connection to determine both whether to allow the connection, and to determine whether the incoming
connection is an active service node.
@ -227,7 +224,7 @@ such aliases be used only temporarily for version transitions.
## Threads
OxenMQ operates a pool of worker threads to handle jobs. The simplest use just allocates new jobs
LokiMQ operates a pool of worker threads to handle jobs. The simplest use just allocates new jobs
to a free worker thread, and we have a "general threads" value to configure how many such threads
are available.
@ -242,7 +239,7 @@ Note that these actual reserved threads are not exclusive: reserving M of N tota
category simply ensures that no more than (N-M) threads are being used for other categories at any
given time, but the actual jobs may run on any worker thread.
As mentioned above, OxenMQ tries to avoid exceeding the configured general threads value (G)
As mentioned above, LokiMQ tries to avoid exceeding the configured general threads value (G)
whenever possible: the only time we will dispatch a job to a worker thread when we have >= G threads
already running is when a new command arrives, the category reserves M threads, and the thread pool
is currently processing fewer than M jobs for that category.
@ -278,7 +275,7 @@ when a command with reserve threads arrived.
A common pattern is one where a single thread suddenly has some work that can be be parallelized.
You could employ some blocking, locking, mutex + condition variable monstrosity, but you shouldn't.
Instead OxenMQ provides a mechanism for this by allowing you to submit a batch of jobs with a
Instead LokiMQ provides a mechanism for this by allowing you to submit a batch of jobs with a
completion callback. All jobs will be queued and, when the last one finishes, the finalization
callback will be queued to continue with the task.
@ -303,7 +300,7 @@ double do_my_task(int input) {
return 3.0 * input;
}
void continue_big_task(std::vector<oxenmq::job_result<double>> results) {
void continue_big_task(std::vector<lokimq::job_result<double>> results) {
double sum = 0;
for (auto& r : results) {
try {
@ -324,7 +321,7 @@ void continue_big_task(std::vector<oxenmq::job_result<double>> results) {
void start_big_task() {
size_t num_jobs = 32;
oxenmq::Batch<double /*return type*/> batch;
lokimq::Batch<double /*return type*/> batch;
batch.reserve(num_jobs);
for (size_t i = 0; i < num_jobs; i++)
@ -332,7 +329,7 @@ void start_big_task() {
batch.completion(&continue_big_task);
omq.batch(std::move(batch));
lmq.batch(std::move(batch));
// ... to be continued in `continue_big_task` after all the jobs finish
// Can do other things here, but note that continue_big_task could run
@ -342,12 +339,12 @@ void start_big_task() {
This code deliberately does not support blocking to wait for the tasks to finish: if you want such a
poor design (which is a recipe for deadlocks: imagine jobs that queuing other jobs that can end up
exhausting the worker threads with waiting jobs) then you can implement it yourself; OxenMQ isn't
exhausting the worker threads with waiting jobs) then you can implement it yourself; LokiMQ isn't
going to help you hurt yourself like that.
### Single-job queuing
As a shortcut there is a `omq.job(...)` method that schedules a single task (with no return value)
As a shortcut there is a `lmq.job(...)` method that schedules a single task (with no return value)
in the batch job queue. This is useful when some event requires triggering some other event, but
you don't need to wait for or collect its result. (Internally this is just a convenience method
around creating a single-job, no-completion Batch job).
@ -359,7 +356,7 @@ either using your own thread or a periodic timer (see below) to shepherd those o
## Timers
OxenMQ supports scheduling periodic tasks via the `add_timer()` function. These timers have an
LokiMQ supports scheduling periodic tasks via the `add_timer()` function. These timers have an
interval and are scheduled as (single-job) batches when the timer fires. They also support
"squelching" (enabled by default) that supresses the job being scheduled if a previously scheduled
job is already scheduled or running.

View file

@ -1,7 +1,7 @@
set(LIBZMQ_PREFIX ${CMAKE_BINARY_DIR}/libzmq)
set(ZeroMQ_VERSION 4.3.5)
set(ZeroMQ_VERSION 4.3.2)
set(LIBZMQ_URL https://github.com/zeromq/libzmq/releases/download/v${ZeroMQ_VERSION}/zeromq-${ZeroMQ_VERSION}.tar.gz)
set(LIBZMQ_HASH SHA512=a71d48aa977ad8941c1609947d8db2679fc7a951e4cd0c3a1127ae026d883c11bd4203cf315de87f95f5031aec459a731aec34e5ce5b667b8d0559b157952541)
set(LIBZMQ_HASH SHA512=b6251641e884181db9e6b0b705cced7ea4038d404bdae812ff47bdd0eed12510b6af6846b85cb96898e253ccbac71eca7fe588673300ddb9c3109c973250c8e4)
message(${LIBZMQ_URL})
@ -13,30 +13,19 @@ endif()
file(MAKE_DIRECTORY ${LIBZMQ_PREFIX}/include)
set(libzmq_compiler_args)
foreach(lang C CXX)
foreach(thing COMPILER FLAGS COMPILER_LAUNCHER)
if(DEFINED CMAKE_${lang}_${thing})
list(APPEND libzmq_compiler_args "-DCMAKE_${lang}_${thing}=${CMAKE_${lang}_${thing}}")
endif()
endforeach()
endforeach()
include(ExternalProject)
include(ProcessorCount)
ExternalProject_Add(libzmq_external
PREFIX ${LIBZMQ_PREFIX}
URL ${LIBZMQ_URL}
URL_HASH ${LIBZMQ_HASH}
CMAKE_ARGS ${libzmq_compiler_args}
-DCMAKE_BUILD_TYPE=Release
-DWITH_LIBSODIUM=ON -DZMQ_BUILD_TESTS=OFF -DWITH_PERF_TOOL=OFF -DENABLE_DRAFTS=OFF
CMAKE_ARGS -DWITH_LIBSODIUM=ON -DZMQ_BUILD_TESTS=OFF -DWITH_PERF_TOOL=OFF -DENABLE_DRAFTS=OFF
-DBUILD_SHARED=OFF -DBUILD_STATIC=ON -DWITH_DOC=OFF -DCMAKE_INSTALL_PREFIX=${LIBZMQ_PREFIX}
BUILD_BYPRODUCTS ${LIBZMQ_PREFIX}/${CMAKE_INSTALL_LIBDIR}/libzmq.a
BUILD_BYPRODUCTS ${LIBZMQ_PREFIX}/lib/libzmq.a
)
add_library(libzmq_vendor STATIC IMPORTED GLOBAL)
add_dependencies(libzmq_vendor libzmq_external)
set_target_properties(libzmq_vendor PROPERTIES
INTERFACE_INCLUDE_DIRECTORIES ${LIBZMQ_PREFIX}/include
IMPORTED_LOCATION ${LIBZMQ_PREFIX}/${CMAKE_INSTALL_LIBDIR}/libzmq.a)
IMPORTED_LOCATION ${LIBZMQ_PREFIX}/lib/libzmq.a)

Binary file not shown.

2
cppzmq

@ -1 +1 @@
Subproject commit 76bf169fd67b8e99c1b0e6490029d9cd5ef97666
Subproject commit 8d5c9a88988dcbebb72939ca0939d432230ffde1

View file

@ -3,12 +3,11 @@ exec_prefix=${prefix}
libdir=@CMAKE_INSTALL_FULL_LIBDIR@
includedir=@CMAKE_INSTALL_FULL_INCLUDEDIR@
Name: liboxenmq
Description: ZeroMQ-based communication library
Version: @PROJECT_VERSION@
Name: liblokimq
Description: ZeroMQ-based communication library for Loki
Version: @LOKIMQ_VERSION@
Libs: -L${libdir} -loxenmq
Libs: -L${libdir} -llokimq
Libs.private: @PRIVATE_LIBS@
Requires: liboxenc
Requires.private: libzmq libsodium
Cflags: -I${includedir}

187
lokimq/auth.cpp Normal file
View file

@ -0,0 +1,187 @@
#include "lokimq.h"
#include "hex.h"
#include "lokimq-internal.h"
namespace lokimq {
std::ostream& operator<<(std::ostream& o, AuthLevel a) {
return o << to_string(a);
}
namespace {
// Builds a ZMTP metadata key-value pair. These will be available on every message from that peer.
// Keys must start with X- and be <= 255 characters.
std::string zmtp_metadata(string_view key, string_view value) {
assert(key.size() > 2 && key.size() <= 255 && key[0] == 'X' && key[1] == '-');
std::string result;
result.reserve(1 + key.size() + 4 + value.size());
result += static_cast<char>(key.size()); // Size octet of key
result.append(&key[0], key.size()); // key data
for (int i = 24; i >= 0; i -= 8) // 4-byte size of value in network order
result += static_cast<char>((value.size() >> i) & 0xff);
result.append(&value[0], value.size()); // value data
return result;
}
}
bool LokiMQ::proxy_check_auth(size_t conn_index, bool outgoing, const peer_info& peer,
const std::string& command, const category& cat, zmq::message_t& msg) {
std::string reply;
if (peer.auth_level < cat.access.auth) {
LMQ_LOG(warn, "Access denied to ", command, " for peer [", to_hex(peer.pubkey), "]/", peer_address(msg),
": peer auth level ", peer.auth_level, " < ", cat.access.auth);
reply = "FORBIDDEN";
}
else if (cat.access.local_sn && !local_service_node) {
LMQ_LOG(warn, "Access denied to ", command, " for peer [", to_hex(peer.pubkey), "]/", peer_address(msg),
": that command is only available when this LokiMQ is running in service node mode");
reply = "NOT_A_SERVICE_NODE";
}
else if (cat.access.remote_sn && !peer.service_node) {
LMQ_LOG(warn, "Access denied to ", command, " for peer [", to_hex(peer.pubkey), "]/", peer_address(msg),
": remote is not recognized as a service node");
// Disconnect: we don't think the remote is a SN, but it issued a command only SNs should be
// issuing. Drop the connection; if the remote has something important to relay it will
// reconnect, at which point we will reassess the SN status on the new incoming connection.
if (outgoing)
proxy_disconnect(peer.service_node ? ConnectionID{peer.pubkey} : conn_index_to_id[conn_index], 1s);
else
send_routed_message(connections[conn_index], peer.route, "BYE");
return false;
}
if (reply.empty())
return true;
if (outgoing)
send_direct_message(connections[conn_index], std::move(reply), command);
else
send_routed_message(connections[conn_index], peer.route, std::move(reply), command);
return false;
}
void LokiMQ::process_zap_requests() {
for (std::vector<zmq::message_t> frames; recv_message_parts(zap_auth, frames, zmq::recv_flags::dontwait); frames.clear()) {
#ifndef NDEBUG
if (log_level() >= LogLevel::trace) {
std::ostringstream o;
o << "Processing ZAP authentication request:";
for (size_t i = 0; i < frames.size(); i++) {
o << "\n[" << i << "]: ";
auto v = view(frames[i]);
if (i == 1 || i == 6)
o << to_hex(v);
else
o << v;
}
log_(LogLevel::trace, __FILE__, __LINE__, o.str());
} else
#endif
LMQ_LOG(debug, "Processing ZAP authentication request");
// https://rfc.zeromq.org/spec:27/ZAP/
//
// The request message SHALL consist of the following message frames:
//
// The version frame, which SHALL contain the three octets "1.0".
// The request id, which MAY contain an opaque binary blob.
// The domain, which SHALL contain a (non-empty) string.
// The address, the origin network IP address.
// The identity, the connection Identity, if any.
// The mechanism, which SHALL contain a string.
// The credentials, which SHALL be zero or more opaque frames.
//
// The reply message SHALL consist of the following message frames:
//
// The version frame, which SHALL contain the three octets "1.0".
// The request id, which MAY contain an opaque binary blob.
// The status code, which SHALL contain a string.
// The status text, which MAY contain a string.
// The user id, which SHALL contain a string.
// The metadata, which MAY contain a blob.
//
// (NB: there are also null address delimiters at the beginning of each mentioned in the
// RFC, but those have already been removed through the use of a REP socket)
std::vector<std::string> response_vals(6);
response_vals[0] = "1.0"; // version
if (frames.size() >= 2)
response_vals[1] = std::string{view(frames[1])}; // unique identifier
std::string &status_code = response_vals[2], &status_text = response_vals[3];
if (frames.size() < 6 || view(frames[0]) != "1.0") {
LMQ_LOG(error, "Bad ZAP authentication request: version != 1.0 or invalid ZAP message parts");
status_code = "500";
status_text = "Internal error: invalid auth request";
} else {
auto auth_domain = view(frames[2]);
size_t bind_id = (size_t) -1;
try {
bind_id = bt_deserialize<size_t>(view(frames[2]));
} catch (...) {}
if (bind_id >= bind.size()) {
LMQ_LOG(error, "Bad ZAP authentication request: invalid auth domain '", auth_domain, "'");
status_code = "400";
status_text = "Unknown authentication domain: " + std::string{auth_domain};
} else if (bind[bind_id].second.curve
? !(frames.size() == 7 && view(frames[5]) == "CURVE")
: !(frames.size() == 6 && view(frames[5]) == "NULL")) {
LMQ_LOG(error, "Bad ZAP authentication request: invalid ",
bind[bind_id].second.curve ? "CURVE" : "NULL", " authentication request");
status_code = "500";
status_text = "Invalid authentication request mechanism";
} else if (bind[bind_id].second.curve && frames[6].size() != 32) {
LMQ_LOG(error, "Bad ZAP authentication request: invalid request pubkey");
status_code = "500";
status_text = "Invalid public key size for CURVE authentication";
} else {
auto ip = view(frames[3]);
string_view pubkey;
if (bind[bind_id].second.curve)
pubkey = view(frames[6]);
auto result = bind[bind_id].second.allow(ip, pubkey);
bool sn = result.remote_sn;
auto& user_id = response_vals[4];
if (bind[bind_id].second.curve) {
user_id.reserve(64);
to_hex(pubkey.begin(), pubkey.end(), std::back_inserter(user_id));
}
if (result.auth <= AuthLevel::denied || result.auth > AuthLevel::admin) {
LMQ_LOG(info, "Access denied for incoming ", view(frames[5]), (sn ? " service node" : " client"),
" connection from ", !user_id.empty() ? user_id + " at " : ""s, ip,
" with initial auth level ", result.auth);
status_code = "400";
status_text = "Access denied";
user_id.clear();
} else {
LMQ_LOG(info, "Accepted incoming ", view(frames[5]), (sn ? " service node" : " client"),
" connection with authentication level ", result.auth,
" from ", !user_id.empty() ? user_id + " at " : ""s, ip);
auto& metadata = response_vals[5];
metadata += zmtp_metadata("X-SN", result.remote_sn ? "1" : "0");
metadata += zmtp_metadata("X-AuthLevel", to_string(result.auth));
status_code = "200";
status_text = "";
}
}
}
LMQ_TRACE("ZAP request result: ", status_code, " ", status_text);
std::vector<zmq::message_t> response;
response.reserve(response_vals.size());
for (auto &r : response_vals) response.push_back(create_message(std::move(r)));
send_message_parts(zap_auth, response.begin(), response.end());
}
}
}

33
lokimq/auth.h Normal file
View file

@ -0,0 +1,33 @@
#pragma once
#include <iostream>
namespace lokimq {
/// Authentication levels for command categories and connections
enum class AuthLevel {
denied, ///< Not actually an auth level, but can be returned by the AllowFunc to deny an incoming connection.
none, ///< No authentication at all; any random incoming ZMQ connection can invoke this command.
basic, ///< Basic authentication commands require a login, or a node that is specifically configured to be a public node (e.g. for public RPC).
admin, ///< Advanced authentication commands require an admin user, either via explicit login or by implicit login from localhost. This typically protects administrative commands like shutting down, starting mining, or access sensitive data.
};
std::ostream& operator<<(std::ostream& os, AuthLevel a);
/// The access level for a command category
struct Access {
/// Minimum access level required
AuthLevel auth = AuthLevel::none;
/// If true only remote SNs may call the category commands
bool remote_sn = false;
/// If true the category requires that the local node is a SN
bool local_sn = false;
};
/// Return type of the AllowFunc: this determines whether we allow the connection at all, and if so,
/// sets the initial authentication level and tells LokiMQ whether the other end is an active SN.
struct Allow {
AuthLevel auth = AuthLevel::none;
bool remote_sn = false;
};
}

7
lokimq/batch.cpp Normal file
View file

@ -0,0 +1,7 @@
#include "batch.h"
#include "lokimq-internal.h"
namespace lokimq {
}

View file

@ -1,4 +1,4 @@
// Copyright (c) 2020-2021, The Oxen Project
// Copyright (c) 2020, The Loki Project
//
// All rights reserved.
//
@ -30,31 +30,24 @@
#include <exception>
#include <functional>
#include <vector>
#include "oxenmq.h"
#include "lokimq.h"
namespace oxenmq {
namespace lokimq {
namespace detail {
enum class BatchState {
enum class BatchStatus {
running, // there are still jobs to run (or running)
complete, // the batch is complete but still has a completion job to call
complete_proxy, // same as `complete`, but the completion job should be invoked immediately in the proxy thread (be very careful)
done // the batch is complete and has no completion function
};
struct BatchStatus {
BatchState state;
int thread;
};
// Virtual base class for Batch<R>
class Batch {
public:
// Returns the number of jobs in this batch and whether any of them are thread-specific
virtual std::pair<size_t, bool> size() const = 0;
// Returns a vector of exactly the same length of size().first containing the tagged thread ids
// of the batch jobs or 0 for general jobs.
virtual std::vector<int> threads() const = 0;
// Returns the number of jobs in this batch
virtual size_t size() const = 0;
// Called in a worker thread to run the job
virtual void run_job(int i) = 0;
// Called in the main proxy thread when the worker returns from finishing a job. The return
@ -78,7 +71,7 @@ public:
* This is designed to be like a very stripped down version of a std::promise/std::future pair. We
* reimplemented it, however, because by ditching all the thread synchronization that promise/future
* guarantees we can substantially reduce call overhead (by a factor of ~8 according to benchmarking
* code). Since OxenMQ's proxy<->worker communication channel already gives us thread that overhead
* code). Since LokiMQ's proxy<->worker communication channel already gives us thread that overhead
* would just be wasted.
*
* @tparam R the value type held by the result; must be default constructible. Note, however, that
@ -135,13 +128,13 @@ public:
void get() { if (exc) std::rethrow_exception(exc); }
};
/// Helper class used to set up batches of jobs to be scheduled via the oxenmq job handler.
/// Helper class used to set up batches of jobs to be scheduled via the lokimq job handler.
///
/// @tparam R - the return type of the individual jobs
///
template <typename R>
class Batch final : private detail::Batch {
friend class OxenMQ;
friend class LokiMQ;
public:
/// The completion function type, called after all jobs have finished.
using CompletionFunc = std::function<void(std::vector<job_result<R>> results)>;
@ -158,17 +151,16 @@ public:
Batch &operator=(const Batch&) = delete;
private:
std::vector<std::pair<std::function<R()>, int>> jobs;
std::vector<std::function<R()>> jobs;
std::vector<job_result<R>> results;
CompletionFunc complete;
std::size_t jobs_outstanding = 0;
int complete_in_thread = 0;
bool complete_in_proxy = false;
bool started = false;
bool tagged_thread_jobs = false;
void check_not_started() {
if (started)
throw std::logic_error("Cannot add jobs or completion function after starting a oxenmq::Batch!");
throw std::logic_error("Cannot add jobs or completion function after starting a lokimq::Batch!");
}
public:
@ -183,61 +175,39 @@ public:
/// available. The called function may throw exceptions (which will be propagated to the
/// completion function through the job_result values). There is no guarantee on the order of
/// invocation of the jobs.
///
/// \param job the callback
/// \param thread an optional TaggedThreadID indicating a thread in which this job must run
void add_job(std::function<R()> job, std::optional<TaggedThreadID> thread = std::nullopt) {
void add_job(std::function<R()> job) {
check_not_started();
if (thread && thread->_id == -1)
// There are some special case internal jobs where we allow this, but they use the
// private method below that doesn't have this check.
throw std::logic_error{"Cannot add a proxy thread batch job -- this makes no sense"};
add_job(std::move(job), thread ? thread->_id : 0);
jobs.emplace_back(std::move(job));
results.emplace_back();
jobs_outstanding++;
}
/// Sets the completion function to invoke after all jobs have finished. If this is not set
/// then jobs simply run and results are discarded.
///
/// \param comp - function to call when all jobs have finished
/// \param thread - optional tagged thread in which to schedule the completion job. If not
/// provided then the completion job is scheduled in the pool of batch job threads.
///
/// `thread` can be provided the value &OxenMQ::run_in_proxy to invoke the completion function
/// *IN THE PROXY THREAD* itself after all jobs have finished. Be very, very careful: this
/// should be a nearly trivial job that does not require any substantial CPU time and does not
/// block for any reason. This is only intended for the case where the completion job is so
/// trivial that it will take less time than simply queuing the job to be executed by another
/// thread.
void completion(CompletionFunc comp, std::optional<TaggedThreadID> thread = std::nullopt) {
void completion(CompletionFunc comp) {
check_not_started();
if (complete)
throw std::logic_error("Completion function can only be set once");
complete = std::move(comp);
complete_in_thread = thread ? thread->_id : 0;
}
/// Sets a completion function to invoke *IN THE PROXY THREAD* after all jobs have finished. Be
/// very, very careful: this should not be a job that takes any significant amount of CPU time
/// or can block for any reason (NO MUTEXES).
void completion_proxy(CompletionFunc comp) {
check_not_started();
if (complete)
throw std::logic_error("Completion function can only be set once");
complete = std::move(comp);
complete_in_proxy = true;
}
private:
void add_job(std::function<R()> job, int thread_id) {
jobs.emplace_back(std::move(job), thread_id);
results.emplace_back();
jobs_outstanding++;
if (thread_id != 0)
tagged_thread_jobs = true;
std::size_t size() const override {
return jobs.size();
}
std::pair<std::size_t, bool> size() const override {
return {jobs.size(), tagged_thread_jobs};
}
std::vector<int> threads() const override {
std::vector<int> t;
t.reserve(jobs.size());
for (auto& j : jobs)
t.push_back(j.second);
return t;
};
template <typename S = R>
void set_value(job_result<S>& r, std::function<S()>& f) { r.set_value(f()); }
void set_value(job_result<void>&, std::function<void()>& f) { f(); }
@ -246,7 +216,7 @@ private:
// called by worker thread
auto& r = results[i];
try {
set_value(r, jobs[i].first);
set_value(r, jobs[i]);
} catch (...) {
r.set_exception(std::current_exception());
}
@ -255,10 +225,12 @@ private:
detail::BatchStatus job_finished() override {
--jobs_outstanding;
if (jobs_outstanding)
return {detail::BatchState::running, 0};
return detail::BatchStatus::running;
if (complete)
return {detail::BatchState::complete, complete_in_thread};
return {detail::BatchState::done, 0};
return complete_in_proxy
? detail::BatchStatus::complete_proxy
: detail::BatchStatus::complete;
return detail::BatchStatus::done;
}
void job_completion() override {
@ -266,60 +238,14 @@ private:
}
};
// Similar to Batch<void>, but doesn't support a completion function and only handles a single task.
class Job final : private detail::Batch {
friend class OxenMQ;
public:
/// Constructs the Job to run a single task. Takes any callable invokable with no arguments and
/// having no return value. The task will be scheduled and run when the next worker thread is
/// available. Any exceptions thrown by the job will be caught and squelched (the exception
/// terminates/completes the job).
explicit Job(std::function<void()> f, std::optional<TaggedThreadID> thread = std::nullopt)
: Job{std::move(f), thread ? thread->_id : 0}
{
if (thread && thread->_id == -1)
// There are some special case internal jobs where we allow this, but they use the
// private ctor below that doesn't have this check.
throw std::logic_error{"Cannot add a proxy thread job -- this makes no sense"};
}
// movable
Job(Job&&) = default;
Job &operator=(Job&&) = default;
// non-copyable
Job(const Job&) = delete;
Job &operator=(const Job&) = delete;
private:
explicit Job(std::function<void()> f, int thread_id)
: job{std::move(f), thread_id} {}
std::pair<std::function<void()>, int> job;
bool done = false;
std::pair<size_t, bool> size() const override { return {1, job.second != 0}; }
std::vector<int> threads() const override { return {job.second}; }
void run_job(const int /*i*/) override {
try { job.first(); }
catch (...) {}
}
detail::BatchStatus job_finished() override { return {detail::BatchState::done, 0}; }
void job_completion() override {} // Never called because we return ::done (not ::complete) above.
};
template <typename R>
void OxenMQ::batch(Batch<R>&& batch) {
if (batch.size().first == 0)
void LokiMQ::batch(Batch<R>&& batch) {
if (batch.size() == 0)
throw std::logic_error("Cannot batch a a job batch with 0 jobs");
// Need to send this over to the proxy thread via the base class pointer. It assumes ownership.
auto* baseptr = static_cast<detail::Batch*>(new Batch<R>(std::move(batch)));
detail::send_control(get_control_socket(), "BATCH", oxenc::bt_serialize(reinterpret_cast<uintptr_t>(baseptr)));
detail::send_control(get_control_socket(), "BATCH", bt_serialize(reinterpret_cast<uintptr_t>(baseptr)));
}
}

233
lokimq/bt_serialize.cpp Normal file
View file

@ -0,0 +1,233 @@
// Copyright (c) 2019-2020, The Loki Project
//
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without modification, are
// permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this list of
// conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice, this list
// of conditions and the following disclaimer in the documentation and/or other
// materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its contributors may be
// used to endorse or promote products derived from this software without specific
// prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL
// THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF
// THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "bt_serialize.h"
namespace lokimq {
namespace detail {
/// Reads digits into an unsigned 64-bit int.
uint64_t extract_unsigned(string_view& s) {
if (s.empty())
throw bt_deserialize_invalid{"Expected 0-9 but found end of string"};
if (s[0] < '0' || s[0] > '9')
throw bt_deserialize_invalid("Expected 0-9 but found '"s + s[0]);
uint64_t uval = 0;
while (!s.empty() && (s[0] >= '0' && s[0] <= '9')) {
uint64_t bigger = uval * 10 + (s[0] - '0');
s.remove_prefix(1);
if (bigger < uval) // overflow
throw bt_deserialize_invalid("Integer deserialization failed: value is too large for a 64-bit int");
uval = bigger;
}
return uval;
}
void bt_deserialize<string_view>::operator()(string_view& s, string_view& val) {
if (s.size() < 2) throw bt_deserialize_invalid{"Deserialize failed: given data is not an bt-encoded string"};
if (s[0] < '0' || s[0] > '9')
throw bt_deserialize_invalid_type{"Expected 0-9 but found '"s + s[0] + "'"};
auto len = static_cast<size_t>(extract_unsigned(s));
if (s.empty() || s[0] != ':')
throw bt_deserialize_invalid{"Did not find expected ':' during string deserialization"};
s.remove_prefix(1);
if (len > s.size())
throw bt_deserialize_invalid{"String deserialization failed: encoded string length is longer than the serialized data"};
val = {s.data(), len};
s.remove_prefix(len);
}
// Check that we are on a 2's complement architecture. It's highly unlikely that this code ever
// runs on a non-2s-complement architecture (especially since C++20 requires a two's complement
// signed value behaviour), but check at compile time anyway because we rely on these relations
// below.
static_assert(std::numeric_limits<int64_t>::min() + std::numeric_limits<int64_t>::max() == -1 &&
static_cast<uint64_t>(std::numeric_limits<int64_t>::max()) + uint64_t{1} == (uint64_t{1} << 63),
"Non 2s-complement architecture not supported!");
std::pair<maybe_signed_int64_t, bool> bt_deserialize_integer(string_view& s) {
// Smallest possible encoded integer is 3 chars: "i0e"
if (s.size() < 3) throw bt_deserialize_invalid("Deserialization failed: end of string found where integer expected");
if (s[0] != 'i') throw bt_deserialize_invalid_type("Deserialization failed: expected 'i', found '"s + s[0] + '\'');
s.remove_prefix(1);
std::pair<maybe_signed_int64_t, bool> result;
if (s[0] == '-') {
result.second = true;
s.remove_prefix(1);
}
uint64_t uval = extract_unsigned(s);
if (s.empty())
throw bt_deserialize_invalid("Integer deserialization failed: encountered end of string before integer was finished");
if (s[0] != 'e')
throw bt_deserialize_invalid("Integer deserialization failed: expected digit or 'e', found '"s + s[0] + '\'');
s.remove_prefix(1);
if (result.second) { // negative
if (uval > (uint64_t{1} << 63))
throw bt_deserialize_invalid("Deserialization of integer failed: negative integer value is too large for a 64-bit signed int");
result.first.i64 = -uval;
} else {
result.first.u64 = uval;
}
return result;
}
template struct bt_deserialize<int64_t>;
template struct bt_deserialize<uint64_t>;
void bt_deserialize<bt_value, void>::operator()(string_view& s, bt_value& val) {
if (s.size() < 2) throw bt_deserialize_invalid("Deserialization failed: end of string found where bt-encoded value expected");
switch (s[0]) {
case 'd': {
bt_dict dict;
bt_deserialize<bt_dict>{}(s, dict);
val = std::move(dict);
break;
}
case 'l': {
bt_list list;
bt_deserialize<bt_list>{}(s, list);
val = std::move(list);
break;
}
case 'i': {
auto read = bt_deserialize_integer(s);
val = read.first.i64; // We only store an i64, but can get a u64 out of it via get<uint64_t>(val)
break;
}
case '0': case '1': case '2': case '3': case '4': case '5': case '6': case '7': case '8': case '9': {
std::string str;
bt_deserialize<std::string>{}(s, str);
val = std::move(str);
break;
}
default:
throw bt_deserialize_invalid("Deserialize failed: encountered invalid value '"s + s[0] + "'; expected one of [0-9idl]");
}
}
} // namespace detail
bt_list_consumer::bt_list_consumer(string_view data_) : data{std::move(data_)} {
if (data.empty()) throw std::runtime_error{"Cannot create a bt_list_consumer with an empty string_view"};
if (data[0] != 'l') throw std::runtime_error{"Cannot create a bt_list_consumer with non-list data"};
data.remove_prefix(1);
}
/// Attempt to parse the next value as a string (and advance just past it). Throws if the next
/// value is not a string.
string_view bt_list_consumer::consume_string_view() {
if (data.empty())
throw bt_deserialize_invalid{"expected a string, but reached end of data"};
else if (!is_string())
throw bt_deserialize_invalid_type{"expected a string, but found "s + data.front()};
string_view next{data}, result;
detail::bt_deserialize<string_view>{}(next, result);
data = next;
return result;
}
std::string bt_list_consumer::consume_string() {
return std::string{consume_string_view()};
}
/// Consumes a value without returning it.
void bt_list_consumer::skip_value() {
if (is_string())
consume_string_view();
else if (is_integer())
detail::bt_deserialize_integer(data);
else if (is_list())
consume_list_data();
else if (is_dict())
consume_dict_data();
else
throw bt_deserialize_invalid_type{"next bt value has unknown type"};
}
string_view bt_list_consumer::consume_list_data() {
auto start = data.begin();
if (data.size() < 2 || !is_list()) throw bt_deserialize_invalid_type{"next bt value is not a list"};
data.remove_prefix(1); // Descend into the sublist, consume the "l"
while (!is_finished()) {
skip_value();
if (data.empty())
throw bt_deserialize_invalid{"bt list consumption failed: hit the end of string before the list was done"};
}
data.remove_prefix(1); // Back out from the sublist, consume the "e"
return {start, static_cast<size_t>(std::distance(start, data.begin()))};
}
string_view bt_list_consumer::consume_dict_data() {
auto start = data.begin();
if (data.size() < 2 || !is_dict()) throw bt_deserialize_invalid_type{"next bt value is not a dict"};
data.remove_prefix(1); // Descent into the dict, consumer the "d"
while (!is_finished()) {
consume_string_view(); // Key is always a string
if (!data.empty())
skip_value();
if (data.empty())
throw bt_deserialize_invalid{"bt dict consumption failed: hit the end of string before the dict was done"};
}
data.remove_prefix(1); // Back out of the dict, consume the "e"
return {start, static_cast<size_t>(std::distance(start, data.begin()))};
}
bt_dict_consumer::bt_dict_consumer(string_view data_) {
data = std::move(data_);
if (data.empty()) throw std::runtime_error{"Cannot create a bt_dict_consumer with an empty string_view"};
if (data.size() < 2 || data[0] != 'd') throw std::runtime_error{"Cannot create a bt_dict_consumer with non-dict data"};
data.remove_prefix(1);
}
bool bt_dict_consumer::consume_key() {
if (key_.data())
return true;
if (data.empty()) throw bt_deserialize_invalid_type{"expected a key or dict end, found end of string"};
if (data[0] == 'e') return false;
key_ = bt_list_consumer::consume_string_view();
if (data.empty() || data[0] == 'e')
throw bt_deserialize_invalid{"dict key isn't followed by a value"};
return true;
}
std::pair<string_view, string_view> bt_dict_consumer::next_string() {
if (!is_string())
throw bt_deserialize_invalid_type{"expected a string, but found "s + data.front()};
std::pair<string_view, string_view> ret;
ret.second = bt_list_consumer::consume_string_view();
ret.first = flush_key();
return ret;
}
} // namespace lokimq

829
lokimq/bt_serialize.h Normal file
View file

@ -0,0 +1,829 @@
// Copyright (c) 2019-2020, The Loki Project
//
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without modification, are
// permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this list of
// conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice, this list
// of conditions and the following disclaimer in the documentation and/or other
// materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its contributors may be
// used to endorse or promote products derived from this software without specific
// prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL
// THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF
// THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <iostream>
#pragma once
#include <vector>
#include <list>
#include <unordered_map>
#include <algorithm>
#include <functional>
#include <cstring>
#include <ostream>
#include <sstream>
#include "string_view.h"
#include "mapbox/variant.hpp"
namespace lokimq {
using namespace std::literals;
/** \file
* LokiMQ serialization for internal commands is very simple: we support two primitive types,
* strings and integers, and two container types, lists and dicts with string keys. On the wire
* these go in BitTorrent byte encoding as described in BEP-0003
* (https://www.bittorrent.org/beps/bep_0003.html#bencoding).
*
* On the C++ side, on input we allow strings, integral types, STL-like containers of these types,
* and STL-like containers of pairs with a string first value and any of these types as second
* value. We also accept std::variants (if compiled with std::variant support, i.e. in C++17 mode)
* that contain any of these, and mapbox::util::variants (the internal type used for its recursive
* support).
*
* One minor deviation from BEP-0003 is that we don't support serializing values that don't fit in a
* 64-bit integer (BEP-0003 specifies arbitrary precision integers).
*
* On deserialization we can either deserialize into a mapbox::util::variant that supports everything, or
* we can fill a container of your given type (though this fails if the container isn't compatible
* with the deserialized data).
*/
/// Exception throw if deserialization fails
class bt_deserialize_invalid : public std::invalid_argument {
using std::invalid_argument::invalid_argument;
};
/// A more specific subclass that is thown if the serialization type is an initial mismatch: for
/// example, trying deserializing an int but the next thing in input is a list. This is not,
/// however, thrown if the type initially looks fine but, say, a nested serialization fails. This
/// error will only be thrown when the input stream has not been advanced (and so can be tried for a
/// different type).
class bt_deserialize_invalid_type : public bt_deserialize_invalid {
using bt_deserialize_invalid::bt_deserialize_invalid;
};
class bt_list;
class bt_dict;
/// Recursive generic type that can fully represent everything valid for a BT serialization.
using bt_value = mapbox::util::variant<
std::string,
string_view,
int64_t,
mapbox::util::recursive_wrapper<bt_list>,
mapbox::util::recursive_wrapper<bt_dict>
>;
/// Very thin wrapper around a std::list<bt_value> that holds a list of generic values (though *any*
/// compatible data type can be used).
class bt_list : public std::list<bt_value> {
using std::list<bt_value>::list;
};
/// Very thin wrapper around a std::unordered_map<bt_value> that holds a list of string -> generic
/// value pairs (though *any* compatible data type can be used).
class bt_dict : public std::unordered_map<std::string, bt_value> {
using std::unordered_map<std::string, bt_value>::unordered_map;
};
#ifdef __cpp_lib_void_t
using std::void_t;
#else
/// C++17 void_t backport
template <typename... Ts> struct void_t_impl { using type = void; };
template <typename... Ts> using void_t = typename void_t_impl<Ts...>::type;
#endif
namespace detail {
/// Reads digits into an unsigned 64-bit int.
uint64_t extract_unsigned(string_view& s);
inline uint64_t extract_unsigned(string_view&& s) { return extract_unsigned(s); }
// Fallback base case; we only get here if none of the partial specializations below work
template <typename T, typename SFINAE = void>
struct bt_serialize { static_assert(!std::is_same<T, T>::value, "Cannot serialize T: unsupported type for bt serialization"); };
template <typename T, typename SFINAE = void>
struct bt_deserialize { static_assert(!std::is_same<T, T>::value, "Cannot deserialize T: unsupported type for bt deserialization"); };
/// Checks that we aren't at the end of a string view and throws if we are.
inline void bt_need_more(const string_view &s) {
if (s.empty())
throw bt_deserialize_invalid{"Unexpected end of string while deserializing"};
}
union maybe_signed_int64_t { int64_t i64; uint64_t u64; };
/// Deserializes a signed or unsigned 64-bit integer from a string. Sets the second bool to true
/// iff the value is int64_t because a negative value was read. Throws an exception if the read
/// value doesn't fit in a int64_t (if negative) or a uint64_t (if positive). Removes consumed
/// characters from the string_view.
std::pair<maybe_signed_int64_t, bool> bt_deserialize_integer(string_view& s);
/// Integer specializations
template <typename T>
struct bt_serialize<T, std::enable_if_t<std::is_integral<T>::value>> {
static_assert(sizeof(T) <= sizeof(uint64_t), "Serialization of integers larger than uint64_t is not supported");
void operator()(std::ostream &os, const T &val) {
// Cast 1-byte types to a larger type to avoid iostream interpreting them as single characters
using output_type = std::conditional_t<(sizeof(T) > 1), T, std::conditional_t<std::is_signed<T>::value, int, unsigned>>;
os << 'i' << static_cast<output_type>(val) << 'e';
}
};
template <typename T>
struct bt_deserialize<T, std::enable_if_t<std::is_integral<T>::value>> {
void operator()(string_view& s, T &val) {
constexpr uint64_t umax = static_cast<uint64_t>(std::numeric_limits<T>::max());
constexpr int64_t smin = static_cast<int64_t>(std::numeric_limits<T>::min()),
smax = static_cast<int64_t>(std::numeric_limits<T>::max());
auto read = bt_deserialize_integer(s);
if (std::is_signed<T>::value) {
if (!read.second) { // read a positive value
if (read.first.u64 > umax)
throw bt_deserialize_invalid("Integer deserialization failed: found too-large value " + std::to_string(read.first.u64) + " > " + std::to_string(umax));
val = static_cast<T>(read.first.u64);
} else {
bool oob = read.first.i64 < smin || read.first.i64 > smax;
if (sizeof(T) < sizeof(int64_t) && oob)
throw bt_deserialize_invalid("Integer deserialization failed: found out-of-range value " + std::to_string(read.first.i64) + " not in [" + std::to_string(smin) + "," + std::to_string(smax) + "]");
val = static_cast<T>(read.first.i64);
}
} else {
if (read.second)
throw bt_deserialize_invalid("Integer deserialization failed: found negative value " + std::to_string(read.first.i64) + " but type is unsigned");
if (sizeof(T) < sizeof(uint64_t) && read.first.u64 > umax)
throw bt_deserialize_invalid("Integer deserialization failed: found too-large value " + std::to_string(read.first.u64) + " > " + std::to_string(umax));
val = static_cast<T>(read.first.u64);
}
}
};
extern template struct bt_deserialize<int64_t>;
extern template struct bt_deserialize<uint64_t>;
template <>
struct bt_serialize<string_view> {
void operator()(std::ostream &os, const string_view &val) { os << val.size(); os.put(':'); os.write(val.data(), val.size()); }
};
template <>
struct bt_deserialize<string_view> {
void operator()(string_view& s, string_view& val);
};
/// String specialization
template <>
struct bt_serialize<std::string> {
void operator()(std::ostream &os, const std::string &val) { bt_serialize<string_view>{}(os, val); }
};
template <>
struct bt_deserialize<std::string> {
void operator()(string_view& s, std::string& val) { string_view view; bt_deserialize<string_view>{}(s, view); val = {view.data(), view.size()}; }
};
/// char * and string literals -- we allow serialization for convenience, but not deserialization
template <>
struct bt_serialize<char *> {
void operator()(std::ostream &os, const char *str) { bt_serialize<string_view>{}(os, {str, std::strlen(str)}); }
};
template <size_t N>
struct bt_serialize<char[N]> {
void operator()(std::ostream &os, const char *str) { bt_serialize<string_view>{}(os, {str, N-1}); }
};
/// Partial dict validity; we don't check the second type for serializability, that will be handled
/// via the base case static_assert if invalid.
template <typename T, typename = void> struct is_bt_input_dict_container : std::false_type {};
template <typename T>
struct is_bt_input_dict_container<T, std::enable_if_t<
std::is_same<std::string, std::remove_cv_t<typename T::value_type::first_type>>::value,
void_t<typename T::const_iterator /* is const iterable */,
typename T::value_type::second_type /* has a second type */>>>
: std::true_type {};
/// Determines whether the type looks like something we can insert into (using `v.insert(v.end(), x)`)
template <typename T, typename = void> struct is_bt_insertable : std::false_type {};
template <typename T>
struct is_bt_insertable<T,
void_t<decltype(std::declval<T>().insert(std::declval<T>().end(), std::declval<typename T::value_type>()))>>
: std::true_type {};
/// Determines whether the given type looks like a compatible map (i.e. has std::string keys) that
/// we can insert into.
template <typename T, typename = void> struct is_bt_output_dict_container : std::false_type {};
template <typename T>
struct is_bt_output_dict_container<T, std::enable_if_t<
std::is_same<std::string, std::remove_cv_t<typename T::key_type>>::value &&
is_bt_insertable<T>::value,
void_t<typename T::value_type::second_type /* has a second type */>>>
: std::true_type {};
/// Specialization for a dict-like container (such as an unordered_map). We accept anything for a
/// dict that is const iterable over something that looks like a pair with std::string for first
/// value type. The value (i.e. second element of the pair) also must be serializable.
template <typename T>
struct bt_serialize<T, std::enable_if_t<is_bt_input_dict_container<T>::value>> {
using second_type = typename T::value_type::second_type;
using ref_pair = std::reference_wrapper<const typename T::value_type>;
void operator()(std::ostream &os, const T &dict) {
os << 'd';
std::vector<ref_pair> pairs;
pairs.reserve(dict.size());
for (const auto &pair : dict)
pairs.emplace(pairs.end(), pair);
std::sort(pairs.begin(), pairs.end(), [](ref_pair a, ref_pair b) { return a.get().first < b.get().first; });
for (auto &ref : pairs) {
bt_serialize<std::string>{}(os, ref.get().first);
bt_serialize<second_type>{}(os, ref.get().second);
}
os << 'e';
}
};
template <typename T>
struct bt_deserialize<T, std::enable_if_t<is_bt_output_dict_container<T>::value>> {
using second_type = typename T::value_type::second_type;
void operator()(string_view& s, T& dict) {
// Smallest dict is 2 bytes "de", for an empty dict.
if (s.size() < 2) throw bt_deserialize_invalid("Deserialization failed: end of string found where dict expected");
if (s[0] != 'd') throw bt_deserialize_invalid_type("Deserialization failed: expected 'd', found '"s + s[0] + "'"s);
s.remove_prefix(1);
dict.clear();
bt_deserialize<std::string> key_deserializer;
bt_deserialize<second_type> val_deserializer;
while (!s.empty() && s[0] != 'e') {
std::string key;
second_type val;
key_deserializer(s, key);
val_deserializer(s, val);
dict.insert(dict.end(), typename T::value_type{std::move(key), std::move(val)});
}
if (s.empty())
throw bt_deserialize_invalid("Deserialization failed: encountered end of string before dict was finished");
s.remove_prefix(1); // Consume the 'e'
}
};
/// Accept anything that looks iterable; value serialization validity isn't checked here (it fails
/// via the base case static assert).
template <typename T, typename = void> struct is_bt_input_list_container : std::false_type {};
template <typename T>
struct is_bt_input_list_container<T, std::enable_if_t<
!std::is_same<T, std::string>::value &&
!is_bt_input_dict_container<T>::value,
void_t<typename T::const_iterator, typename T::value_type>>>
: std::true_type {};
template <typename T, typename = void> struct is_bt_output_list_container : std::false_type {};
template <typename T>
struct is_bt_output_list_container<T, std::enable_if_t<
!std::is_same<T, std::string>::value &&
!is_bt_output_dict_container<T>::value &&
is_bt_insertable<T>::value>>
: std::true_type {};
/// List specialization
template <typename T>
struct bt_serialize<T, std::enable_if_t<is_bt_input_list_container<T>::value>> {
void operator()(std::ostream& os, const T& list) {
os << 'l';
for (const auto &v : list)
bt_serialize<std::remove_cv_t<typename T::value_type>>{}(os, v);
os << 'e';
}
};
template <typename T>
struct bt_deserialize<T, std::enable_if_t<is_bt_output_list_container<T>::value>> {
using value_type = typename T::value_type;
void operator()(string_view& s, T& list) {
// Smallest list is 2 bytes "le", for an empty list.
if (s.size() < 2) throw bt_deserialize_invalid("Deserialization failed: end of string found where list expected");
if (s[0] != 'l') throw bt_deserialize_invalid_type("Deserialization failed: expected 'l', found '"s + s[0] + "'"s);
s.remove_prefix(1);
list.clear();
bt_deserialize<value_type> deserializer;
while (!s.empty() && s[0] != 'e') {
value_type v;
deserializer(s, v);
list.insert(list.end(), std::move(v));
}
if (s.empty())
throw bt_deserialize_invalid("Deserialization failed: encountered end of string before list was finished");
s.remove_prefix(1); // Consume the 'e'
}
};
/// variant visitor; serializes whatever is contained
class bt_serialize_visitor {
std::ostream &os;
public:
using result_type = void;
bt_serialize_visitor(std::ostream &os) : os{os} {}
template <typename T> void operator()(const T &val) const {
bt_serialize<T>{}(os, val);
}
};
template <typename T>
using is_bt_deserializable = std::integral_constant<bool,
std::is_same<T, std::string>::value || std::is_integral<T>::value ||
is_bt_output_dict_container<T>::value || is_bt_output_list_container<T>::value>;
// General template and base case; this base will only actually be invoked when Ts... is empty,
// which means we reached the end without finding any variant type capable of holding the value.
template <typename SFINAE, typename Variant, typename... Ts>
struct bt_deserialize_try_variant_impl {
void operator()(string_view&, Variant&) {
throw bt_deserialize_invalid("Deserialization failed: could not deserialize value into any variant type");
}
};
template <typename... Ts, typename Variant>
void bt_deserialize_try_variant(string_view& s, Variant& variant) {
bt_deserialize_try_variant_impl<void, Variant, Ts...>{}(s, variant);
}
template <typename Variant, typename T, typename... Ts>
struct bt_deserialize_try_variant_impl<std::enable_if_t<is_bt_deserializable<T>::value>, Variant, T, Ts...> {
void operator()(string_view& s, Variant& variant) {
if ( is_bt_output_list_container<T>::value ? s[0] == 'l' :
is_bt_output_dict_container<T>::value ? s[0] == 'd' :
std::is_integral<T>::value ? s[0] == 'i' :
std::is_same<T, std::string>::value ? s[0] >= '0' && s[0] <= '9' :
false) {
T val;
bt_deserialize<T>{}(s, val);
variant = std::move(val);
} else {
bt_deserialize_try_variant<Ts...>(s, variant);
}
}
};
template <typename Variant, typename T, typename... Ts>
struct bt_deserialize_try_variant_impl<std::enable_if_t<!is_bt_deserializable<T>::value>, Variant, T, Ts...> {
void operator()(string_view& s, Variant& variant) {
// Unsupported deserialization type, skip it
bt_deserialize_try_variant<Ts...>(s, variant);
}
};
template <>
struct bt_deserialize<bt_value, void> {
void operator()(string_view& s, bt_value& val);
};
template <typename... Ts>
struct bt_serialize<mapbox::util::variant<Ts...>> {
void operator()(std::ostream& os, const mapbox::util::variant<Ts...>& val) {
mapbox::util::apply_visitor(bt_serialize_visitor{os}, val);
}
};
template <typename... Ts>
struct bt_deserialize<mapbox::util::variant<Ts...>> {
void operator()(string_view& s, mapbox::util::variant<Ts...>& val) {
bt_deserialize_try_variant<Ts...>(s, val);
}
};
#ifdef __cpp_lib_variant
/// C++17 std::variant support
template <typename... Ts>
struct bt_serialize<std::variant<Ts...>> {
void operator()(std::ostream &os, const std::variant<Ts...>& val) {
mapbox::util::apply_visitor(bt_serialize_visitor{os}, val);
}
};
template <typename... Ts>
struct bt_deserialize<std::variant<Ts...>> {
void operator()(string_view& s, std::variant<Ts...>& val) {
bt_deserialize_try_variant<Ts...>(s, val);
}
};
#endif
template <typename T>
struct bt_stream_serializer {
const T &val;
explicit bt_stream_serializer(const T &val) : val{val} {}
operator std::string() const {
std::ostringstream oss;
oss << *this;
return oss.str();
}
};
template <typename T>
std::ostream &operator<<(std::ostream &os, const bt_stream_serializer<T> &s) {
bt_serialize<T>{}(os, s.val);
return os;
}
} // namespace detail
/// Returns a wrapper around a value reference that can serialize the value directly to an output
/// stream. This class is intended to be used inline (i.e. without being stored) as in:
///
/// std::list<int> my_list{{1,2,3}};
/// std::cout << bt_serializer(my_list);
///
/// While it is possible to store the returned object and use it, such as:
///
/// auto encoded = bt_serializer(42);
/// std::cout << encoded;
///
/// this approach is not generally recommended: the returned object stores a reference to the
/// passed-in type, which may not survive. If doing this note that it is the caller's
/// responsibility to ensure the serializer is not used past the end of the lifetime of the value
/// being serialized.
///
/// Also note that serializing directly to an output stream is more efficient as no intermediate
/// string containing the entire serialization has to be constructed.
///
template <typename T>
detail::bt_stream_serializer<T> bt_serializer(const T &val) { return detail::bt_stream_serializer<T>{val}; }
/// Serializes the given value into a std::string.
///
/// int number = 42;
/// std::string encoded = bt_serialize(number);
/// // Equivalent:
/// //auto encoded = (std::string) bt_serialize(number);
///
/// This takes any serializable type: integral types, strings, lists of serializable types, and
/// string->value maps of serializable types.
template <typename T>
std::string bt_serialize(const T &val) { return bt_serializer(val); }
/// Deserializes the given string view directly into `val`. Usage:
///
/// std::string encoded = "i42e";
/// int value;
/// bt_deserialize(encoded, value); // Sets value to 42
///
template <typename T, std::enable_if_t<!std::is_const<T>::value, int> = 0>
void bt_deserialize(string_view s, T& val) {
return detail::bt_deserialize<T>{}(s, val);
}
/// Deserializes the given string_view into a `T`, which is returned.
///
/// std::string encoded = "li1ei2ei3ee"; // bt-encoded list of ints: [1,2,3]
/// auto mylist = bt_deserialize<std::list<int>>(encoded);
///
template <typename T>
T bt_deserialize(string_view s) {
T val;
bt_deserialize(s, val);
return val;
}
/// Deserializes the given value into a generic `bt_value` type (mapbox::util::variant) which is capable
/// of holding all possible BT-encoded values (including recursion).
///
/// Example:
///
/// std::string encoded = "i42e";
/// auto val = bt_get(encoded);
/// int v = get_int<int>(val); // fails unless the encoded value was actually an integer that
/// // fits into an `int`
///
inline bt_value bt_get(string_view s) {
return bt_deserialize<bt_value>(s);
}
/// Helper functions to extract a value of some integral type from a bt_value which contains an
/// integer. Does range checking, throwing std::overflow_error if the stored value is outside the
/// range of the target type.
///
/// Example:
///
/// std::string encoded = "i123456789e";
/// auto val = bt_get(encoded);
/// auto v = get_int<uint32_t>(val); // throws if the decoded value doesn't fit in a uint32_t
template <typename IntType, std::enable_if_t<std::is_integral<IntType>::value, int> = 0>
IntType get_int(const bt_value &v) {
// It's highly unlikely that this code ever runs on a non-2s-complement architecture, but check
// at compile time if converting to a uint64_t (because while int64_t -> uint64_t is
// well-defined, uint64_t -> int64_t only does the right thing under 2's complement).
static_assert(!std::is_unsigned<IntType>::value || sizeof(IntType) != sizeof(int64_t) || -1 == ~0,
"Non 2s-complement architecture not supported!");
int64_t value = mapbox::util::get<int64_t>(v);
if (sizeof(IntType) < sizeof(int64_t)) {
if (value > static_cast<int64_t>(std::numeric_limits<IntType>::max())
|| value < static_cast<int64_t>(std::numeric_limits<IntType>::min()))
throw std::overflow_error("Unable to extract integer value: stored value is outside the range of the requested type");
}
return static_cast<IntType>(value);
}
/// Class that allows you to walk through a bt-encoded list in memory without copying or allocating
/// memory. It accesses existing memory directly and so the caller must ensure that the referenced
/// memory stays valid for the lifetime of the bt_list_consumer object.
class bt_list_consumer {
protected:
string_view data;
bt_list_consumer() = default;
public:
bt_list_consumer(string_view data_);
/// Copy constructor. Making a copy copies the current position so can be used for multipass
/// iteration through a list.
bt_list_consumer(const bt_list_consumer&) = default;
bt_list_consumer& operator=(const bt_list_consumer&) = default;
/// Returns true if the next value indicates the end of the list
bool is_finished() const { return data.front() == 'e'; }
/// Returns true if the next element looks like an encoded string
bool is_string() const { return data.front() >= '0' && data.front() <= '9'; }
/// Returns true if the next element looks like an encoded integer
bool is_integer() const { return data.front() == 'i'; }
/// Returns true if the next element looks like an encoded list
bool is_list() const { return data.front() == 'l'; }
/// Returns true if the next element looks like an encoded dict
bool is_dict() const { return data.front() == 'd'; }
/// Attempt to parse the next value as a string (and advance just past it). Throws if the next
/// value is not a string.
std::string consume_string();
string_view consume_string_view();
/// Attempts to parse the next value as an integer (and advance just past it). Throws if the
/// next value is not an integer.
template <typename IntType>
IntType consume_integer() {
if (!is_integer()) throw bt_deserialize_invalid_type{"next value is not an integer"};
string_view next{data};
IntType ret;
detail::bt_deserialize<IntType>{}(next, ret);
data = next;
return ret;
}
/// Consumes a list, return it as a list-like type. This typically requires dynamic allocation,
/// but only has to parse the data once. Compare with consume_list_data() which allows
/// alloc-free traversal, but requires parsing twice (if the contents are to be used).
template <typename T = bt_list>
T consume_list() {
T list;
consume_list(list);
return list;
}
/// Same as above, but takes a pre-existing list-like data type.
template <typename T>
void consume_list(T& list) {
if (!is_list()) throw bt_deserialize_invalid_type{"next bt value is not a list"};
string_view n{data};
detail::bt_deserialize<T>{}(n, list);
data = n;
}
/// Consumes a dict, return it as a dict-like type. This typically requires dynamic allocation,
/// but only has to parse the data once. Compare with consume_dict_data() which allows
/// alloc-free traversal, but requires parsing twice (if the contents are to be used).
template <typename T = bt_dict>
T consume_dict() {
T dict;
consume_dict(dict);
return dict;
}
/// Same as above, but takes a pre-existing dict-like data type.
template <typename T>
void consume_dict(T& dict) {
if (!is_dict()) throw bt_deserialize_invalid_type{"next bt value is not a dict"};
string_view n{data};
detail::bt_deserialize<T>{}(n, dict);
data = n;
}
/// Consumes a value without returning it.
void skip_value();
/// Attempts to parse the next value as a list and returns the string_view that contains the
/// entire thing. This is recursive into both lists and dicts and likely to be quite
/// inefficient for large, nested structures (unless the values only need to be skipped but
/// aren't separately needed). This, however, does not require dynamic memory allocation.
string_view consume_list_data();
/// Attempts to parse the next value as a dict and returns the string_view that contains the
/// entire thing. This is recursive into both lists and dicts and likely to be quite
/// inefficient for large, nested structures (unless the values only need to be skipped but
/// aren't separately needed). This, however, does not require dynamic memory allocation.
string_view consume_dict_data();
};
/// Class that allows you to walk through key-value pairs of a bt-encoded dict in memory without
/// copying or allocating memory. It accesses existing memory directly and so the caller must
/// ensure that the referenced memory stays valid for the lifetime of the bt_dict_consumer object.
class bt_dict_consumer : private bt_list_consumer {
string_view key_;
/// Consume the key if not already consumed and there is a key present (rather than 'e').
/// Throws exception if what should be a key isn't a string, or if the key consumes the entire
/// data (i.e. requires that it be followed by something). Returns true if the key was consumed
/// (either now or previously and cached).
bool consume_key();
/// Clears the cached key and returns it. Must have already called consume_key directly or
/// indirectly via one of the `is_{...}` methods.
string_view flush_key() {
string_view k;
k.swap(key_);
return k;
}
public:
bt_dict_consumer(string_view data_);
/// Copy constructor. Making a copy copies the current position so can be used for multipass
/// iteration through a list.
bt_dict_consumer(const bt_dict_consumer&) = default;
bt_dict_consumer& operator=(const bt_dict_consumer&) = default;
/// Returns true if the next value indicates the end of the dict
bool is_finished() { return !consume_key() && data.front() == 'e'; }
/// Operator bool is an alias for `!is_finished()`
operator bool() { return !is_finished(); }
/// Returns true if the next value looks like an encoded string
bool is_string() { return consume_key() && data.front() >= '0' && data.front() <= '9'; }
/// Returns true if the next element looks like an encoded integer
bool is_integer() { return consume_key() && data.front() == 'i'; }
/// Returns true if the next element looks like an encoded list
bool is_list() { return consume_key() && data.front() == 'l'; }
/// Returns true if the next element looks like an encoded dict
bool is_dict() { return consume_key() && data.front() == 'd'; }
/// Returns the key of the next pair. This does not have to be called; it is also returned by
/// all of the other consume_* methods. The value is cached whether called here or by some
/// other method; accessing it multiple times simple accesses the cache until the next value is
/// consumed.
string_view key() {
if (!consume_key())
throw bt_deserialize_invalid{"Cannot access next key: at the end of the dict"};
return key_;
}
/// Attempt to parse the next value as a string->string pair (and advance just past it). Throws
/// if the next value is not a string.
std::pair<string_view, string_view> next_string();
/// Attempts to parse the next value as an string->integer pair (and advance just past it).
/// Throws if the next value is not an integer.
template <typename IntType>
std::pair<string_view, IntType> next_integer() {
if (!is_integer()) throw bt_deserialize_invalid_type{"next bt dict value is not an integer"};
std::pair<string_view, IntType> ret;
ret.second = bt_list_consumer::consume_integer<IntType>();
ret.first = flush_key();
return ret;
}
/// Consumes a string->list pair, return it as a list-like type. This typically requires
/// dynamic allocation, but only has to parse the data once. Compare with consume_list_data()
/// which allows alloc-free traversal, but requires parsing twice (if the contents are to be
/// used).
template <typename T = bt_list>
std::pair<string_view, T> next_list() {
std::pair<string_view, T> pair;
pair.first = consume_list(pair.second);
return pair;
}
/// Same as above, but takes a pre-existing list-like data type. Returns the key.
template <typename T>
string_view next_list(T& list) {
if (!is_list()) throw bt_deserialize_invalid_type{"next bt value is not a list"};
bt_list_consumer::consume_list(list);
return flush_key();
}
/// Consumes a string->dict pair, return it as a dict-like type. This typically requires
/// dynamic allocation, but only has to parse the data once. Compare with consume_dict_data()
/// which allows alloc-free traversal, but requires parsing twice (if the contents are to be
/// used).
template <typename T = bt_dict>
std::pair<string_view, T> next_dict() {
std::pair<string_view, T> pair;
pair.first = consume_dict(pair.second);
return pair;
}
/// Same as above, but takes a pre-existing dict-like data type. Returns the key.
template <typename T>
string_view next_dict(T& dict) {
if (!is_dict()) throw bt_deserialize_invalid_type{"next bt value is not a dict"};
bt_list_consumer::consume_dict(dict);
return flush_key();
}
/// Attempts to parse the next value as a string->list pair and returns the string_view that
/// contains the entire thing. This is recursive into both lists and dicts and likely to be
/// quite inefficient for large, nested structures (unless the values only need to be skipped
/// but aren't separately needed). This, however, does not require dynamic memory allocation.
std::pair<string_view, string_view> next_list_data() {
if (data.size() < 2 || !is_list()) throw bt_deserialize_invalid_type{"next bt dict value is not a list"};
return {flush_key(), bt_list_consumer::consume_list_data()};
}
/// Same as next_list_data(), but wraps the value in a bt_list_consumer for convenience
std::pair<string_view, bt_list_consumer> next_list_consumer() { return next_list_data(); }
/// Attempts to parse the next value as a string->dict pair and returns the string_view that
/// contains the entire thing. This is recursive into both lists and dicts and likely to be
/// quite inefficient for large, nested structures (unless the values only need to be skipped
/// but aren't separately needed). This, however, does not require dynamic memory allocation.
std::pair<string_view, string_view> next_dict_data() {
if (data.size() < 2 || !is_dict()) throw bt_deserialize_invalid_type{"next bt dict value is not a dict"};
return {flush_key(), bt_list_consumer::consume_dict_data()};
}
/// Same as next_dict_data(), but wraps the value in a bt_dict_consumer for convenience
std::pair<string_view, bt_dict_consumer> next_dict_consumer() { return next_dict_data(); }
/// Skips ahead until we find the first key >= the given key or reach the end of the dict.
/// Returns true if we found an exact match, false if we reached some greater value or the end.
/// If we didn't hit the end, the next `consumer_*()` call will return the key-value pair we
/// found (either the exact match or the first key greater than the requested key).
///
/// Two important notes:
///
/// - properly encoded bt dicts must have lexicographically sorted keys, and this method assumes
/// that the input is correctly sorted (and thus if we find a greater value then your key does
/// not exist).
/// - this is irreversible; you cannot returned to skipped values without reparsing. (You *can*
/// however, make a copy of the bt_dict_consumer before calling and use the copy to return to
/// the pre-skipped position).
bool skip_until(string_view find) {
while (consume_key() && key_ < find) {
flush_key();
skip_value();
}
return key_ == find;
}
/// The `consume_*` functions are wrappers around next_whatever that discard the returned key.
///
/// Intended for use with skip_until such as:
///
/// std::string value;
/// if (d.skip_until("key"))
/// value = d.consume_string();
///
auto consume_string_view() { return next_string().second; }
auto consume_string() { return std::string{consume_string_view()}; }
template <typename IntType>
auto consume_integer() { return next_integer<IntType>().second; }
template <typename T = bt_list>
auto consume_list() { return next_list<T>().second; }
template <typename T>
void consume_list(T& list) { next_list(list); }
template <typename T = bt_dict>
auto consume_dict() { return next_dict<T>().second; }
template <typename T>
void consume_dict(T& dict) { next_dict(dict); }
string_view consume_list_data() { return next_list_data().second; }
string_view consume_dict_data() { return next_dict_data().second; }
bt_list_consumer consume_list_consumer() { return consume_list_data(); }
bt_dict_consumer consume_dict_consumer() { return consume_dict_data(); }
};
} // namespace lokimq

339
lokimq/connections.cpp Normal file
View file

@ -0,0 +1,339 @@
#include "lokimq.h"
#include "lokimq-internal.h"
#include "hex.h"
namespace lokimq {
std::ostream& operator<<(std::ostream& o, const ConnectionID& conn) {
if (!conn.pk.empty())
return o << (conn.sn() ? "SN " : "non-SN authenticated remote ") << to_hex(conn.pk);
else
return o << "unauthenticated remote [" << conn.id << "]";
}
namespace {
void add_pollitem(std::vector<zmq::pollitem_t>& pollitems, zmq::socket_t& sock) {
pollitems.emplace_back();
auto &p = pollitems.back();
p.socket = static_cast<void *>(sock);
p.fd = 0;
p.events = ZMQ_POLLIN;
}
} // anonymous namespace
void LokiMQ::rebuild_pollitems() {
pollitems.clear();
add_pollitem(pollitems, command);
add_pollitem(pollitems, workers_socket);
add_pollitem(pollitems, zap_auth);
for (auto& s : connections)
add_pollitem(pollitems, s);
}
void LokiMQ::setup_outgoing_socket(zmq::socket_t& socket, string_view remote_pubkey) {
if (!remote_pubkey.empty()) {
socket.setsockopt(ZMQ_CURVE_SERVERKEY, remote_pubkey.data(), remote_pubkey.size());
socket.setsockopt(ZMQ_CURVE_PUBLICKEY, pubkey.data(), pubkey.size());
socket.setsockopt(ZMQ_CURVE_SECRETKEY, privkey.data(), privkey.size());
}
socket.setsockopt(ZMQ_HANDSHAKE_IVL, (int) HANDSHAKE_TIME.count());
socket.setsockopt<int64_t>(ZMQ_MAXMSGSIZE, MAX_MSG_SIZE);
if (PUBKEY_BASED_ROUTING_ID) {
std::string routing_id;
routing_id.reserve(33);
routing_id += 'L'; // Prefix because routing id's starting with \0 are reserved by zmq (and our pubkey might start with \0)
routing_id.append(pubkey.begin(), pubkey.end());
socket.setsockopt(ZMQ_ROUTING_ID, routing_id.data(), routing_id.size());
}
// else let ZMQ pick a random one
}
std::pair<zmq::socket_t *, std::string>
LokiMQ::proxy_connect_sn(string_view remote, string_view connect_hint, bool optional, bool incoming_only, std::chrono::milliseconds keep_alive) {
ConnectionID remote_cid{remote};
auto its = peers.equal_range(remote_cid);
peer_info* peer = nullptr;
for (auto it = its.first; it != its.second; ++it) {
if (incoming_only && it->second.route.empty())
continue; // outgoing connection but we were asked to only use incoming connections
peer = &it->second;
break;
}
if (peer) {
LMQ_TRACE("proxy asked to connect to ", to_hex(remote), "; reusing existing connection");
if (peer->route.empty() /* == outgoing*/) {
if (peer->idle_expiry < keep_alive) {
LMQ_LOG(debug, "updating existing outgoing peer connection idle expiry time from ",
peer->idle_expiry.count(), "ms to ", keep_alive.count(), "ms");
peer->idle_expiry = keep_alive;
}
peer->activity();
}
return {&connections[peer->conn_index], peer->route};
} else if (optional || incoming_only) {
LMQ_LOG(debug, "proxy asked for optional or incoming connection, but no appropriate connection exists so aborting connection attempt");
return {nullptr, ""s};
}
// No connection so establish a new one
LMQ_LOG(debug, "proxy establishing new outbound connection to ", to_hex(remote));
std::string addr;
bool to_self = false && remote == pubkey; // FIXME; need to use a separate listening socket for this, otherwise we can't easily
// tell it wasn't from a remote.
if (to_self) {
// special inproc connection if self that doesn't need any external connection
addr = SN_ADDR_SELF;
} else {
addr = std::string{connect_hint};
if (addr.empty())
addr = sn_lookup(remote);
else
LMQ_LOG(debug, "using connection hint ", connect_hint);
if (addr.empty()) {
LMQ_LOG(error, "peer lookup failed for ", to_hex(remote));
return {nullptr, ""s};
}
}
LMQ_LOG(debug, to_hex(pubkey), " (me) connecting to ", addr, " to reach ", to_hex(remote));
zmq::socket_t socket{context, zmq::socket_type::dealer};
setup_outgoing_socket(socket, remote);
socket.connect(addr);
peer_info p{};
p.service_node = true;
p.pubkey = std::string{remote};
p.conn_index = connections.size();
p.idle_expiry = keep_alive;
p.activity();
conn_index_to_id.push_back(remote_cid);
peers.emplace(std::move(remote_cid), std::move(p));
connections.push_back(std::move(socket));
return {&connections.back(), ""s};
}
std::pair<zmq::socket_t *, std::string> LokiMQ::proxy_connect_sn(bt_dict_consumer data) {
string_view hint, remote_pk;
std::chrono::milliseconds keep_alive;
bool optional = false, incoming_only = false;
// Alphabetical order
if (data.skip_until("hint"))
hint = data.consume_string();
if (data.skip_until("incoming"))
incoming_only = data.consume_integer<bool>();
if (data.skip_until("keep_alive"))
keep_alive = std::chrono::milliseconds{data.consume_integer<uint64_t>()};
if (data.skip_until("optional"))
optional = data.consume_integer<bool>();
if (!data.skip_until("pubkey"))
throw std::runtime_error("Internal error: Invalid proxy_connect_sn command; pubkey missing");
remote_pk = data.consume_string();
return proxy_connect_sn(remote_pk, hint, optional, incoming_only, keep_alive);
}
template <typename Container, typename AccessIndex>
void update_connection_indices(Container& c, size_t index, AccessIndex get_index) {
for (auto it = c.begin(); it != c.end(); ) {
size_t& i = get_index(*it);
if (index == i) {
it = c.erase(it);
continue;
}
if (i > index)
--i;
++it;
}
}
/// Closes outgoing connections and removes all references. Note that this will invalidate
/// iterators on the various connection containers - if you don't want that, delete it first so that
/// the container won't contain the element being deleted.
void LokiMQ::proxy_close_connection(size_t index, std::chrono::milliseconds linger) {
connections[index].setsockopt<int>(ZMQ_LINGER, linger > 0ms ? linger.count() : 0);
pollitems_stale = true;
connections.erase(connections.begin() + index);
LMQ_LOG(debug, "Closing conn index ", index);
update_connection_indices(peers, index,
[](auto& p) -> size_t& { return p.second.conn_index; });
update_connection_indices(pending_connects, index,
[](auto& pc) -> size_t& { return std::get<size_t>(pc); });
update_connection_indices(bind, index,
[](auto& b) -> size_t& { return b.second.index; });
update_connection_indices(incoming_conn_index, index,
[](auto& oci) -> size_t& { return oci.second; });
assert(index < conn_index_to_id.size());
conn_index_to_id.erase(conn_index_to_id.begin() + index);
}
void LokiMQ::proxy_expire_idle_peers() {
for (auto it = peers.begin(); it != peers.end(); ) {
auto &info = it->second;
if (info.outgoing()) {
auto idle = info.last_activity - std::chrono::steady_clock::now();
if (idle <= info.idle_expiry) {
++it;
continue;
}
LMQ_LOG(info, "Closing outgoing connection to ", it->first, ": idle timeout reached");
proxy_close_connection(info.conn_index, CLOSE_LINGER);
} else {
++it;
}
}
}
void LokiMQ::proxy_conn_cleanup() {
LMQ_TRACE("starting proxy connections cleanup");
// Drop idle connections (if we haven't done it in a while) but *only* if we have some idle
// general workers: if we don't have any idle workers then we may still have incoming messages which
// we haven't processed yet and those messages might end up resetting the last activity time.
if (static_cast<int>(workers.size()) < general_workers) {
LMQ_TRACE("closing idle connections");
proxy_expire_idle_peers();
}
auto now = std::chrono::steady_clock::now();
// FIXME - check other outgoing connections to see if they died and if so purge them
LMQ_TRACE("Timing out pending outgoing connections");
// Check any pending outgoing connections for timeout
for (auto it = pending_connects.begin(); it != pending_connects.end(); ) {
auto& pc = *it;
if (std::get<std::chrono::steady_clock::time_point>(pc) < now) {
job([cid = ConnectionID{std::get<long long>(pc)}, callback = std::move(std::get<ConnectFailure>(pc))] { callback(cid, "connection attempt timed out"); });
it = pending_connects.erase(it); // Don't let the below erase it (because it invalidates iterators)
proxy_close_connection(std::get<size_t>(pc), CLOSE_LINGER);
} else {
++it;
}
}
LMQ_TRACE("Timing out pending requests");
// Remove any expired pending requests and schedule their callback with a failure
for (auto it = pending_requests.begin(); it != pending_requests.end(); ) {
auto& callback = it->second;
if (callback.first < now) {
LMQ_LOG(debug, "pending request ", to_hex(it->first), " expired, invoking callback with failure status and removing");
job([callback = std::move(callback.second)] { callback(false, {}); });
it = pending_requests.erase(it);
} else {
++it;
}
}
LMQ_TRACE("done proxy connections cleanup");
};
void LokiMQ::proxy_connect_remote(bt_dict_consumer data) {
AuthLevel auth_level = AuthLevel::none;
long long conn_id = -1;
ConnectSuccess on_connect;
ConnectFailure on_failure;
std::string remote;
std::string remote_pubkey;
std::chrono::milliseconds timeout = REMOTE_CONNECT_TIMEOUT;
if (data.skip_until("auth_level"))
auth_level = static_cast<AuthLevel>(data.consume_integer<std::underlying_type_t<AuthLevel>>());
if (data.skip_until("conn_id"))
conn_id = data.consume_integer<long long>();
if (data.skip_until("connect")) {
auto* ptr = reinterpret_cast<ConnectSuccess*>(data.consume_integer<uintptr_t>());
on_connect = std::move(*ptr);
delete ptr;
}
if (data.skip_until("failure")) {
auto* ptr = reinterpret_cast<ConnectFailure*>(data.consume_integer<uintptr_t>());
on_failure = std::move(*ptr);
delete ptr;
}
if (data.skip_until("pubkey")) {
remote_pubkey = data.consume_string();
assert(remote_pubkey.size() == 32 || remote_pubkey.empty());
}
if (data.skip_until("remote"))
remote = data.consume_string();
if (data.skip_until("timeout"))
timeout = std::chrono::milliseconds{data.consume_integer<uint64_t>()};
if (conn_id == -1 || remote.empty())
throw std::runtime_error("Internal error: CONNECT_REMOTE proxy command missing required 'conn_id' and/or 'remote' value");
LMQ_LOG(info, "Establishing remote connection to ", remote, remote_pubkey.empty() ? " (NULL auth)" : " via CURVE expecting pubkey " + to_hex(remote_pubkey));
assert(conn_index_to_id.size() == connections.size());
zmq::socket_t sock{context, zmq::socket_type::dealer};
try {
setup_outgoing_socket(sock, remote_pubkey);
sock.connect(remote);
} catch (const zmq::error_t &e) {
proxy_schedule_reply_job([conn_id, on_failure=std::move(on_failure), what="connect() failed: "s+e.what()] {
on_failure(conn_id, std::move(what));
});
return;
}
connections.push_back(std::move(sock));
LMQ_LOG(debug, "Opened new zmq socket to ", remote, ", conn_id ", conn_id, "; sending HI");
send_direct_message(connections.back(), "HI");
pending_connects.emplace_back(connections.size()-1, conn_id, std::chrono::steady_clock::now() + timeout,
std::move(on_connect), std::move(on_failure));
peer_info peer;
peer.pubkey = std::move(remote_pubkey);
peer.service_node = false;
peer.auth_level = auth_level;
peer.conn_index = connections.size() - 1;
ConnectionID conn{conn_id, peer.pubkey};
conn_index_to_id.push_back(conn);
assert(connections.size() == conn_index_to_id.size());
peer.idle_expiry = 24h * 10 * 365; // "forever"
peer.activity();
peers.emplace(std::move(conn), std::move(peer));
}
void LokiMQ::proxy_disconnect(bt_dict_consumer data) {
ConnectionID connid{-1};
std::chrono::milliseconds linger = 1s;
if (data.skip_until("conn_id"))
connid.id = data.consume_integer<long long>();
if (data.skip_until("linger_ms"))
linger = std::chrono::milliseconds(data.consume_integer<long long>());
if (data.skip_until("pubkey"))
connid.pk = data.consume_string();
if (connid.sn() && connid.pk.size() != 32)
throw std::runtime_error("Error: invalid disconnect of SN without a valid pubkey");
proxy_disconnect(std::move(connid), linger);
}
void LokiMQ::proxy_disconnect(ConnectionID conn, std::chrono::milliseconds linger) {
LMQ_TRACE("Disconnecting outgoing connection to ", conn);
auto pr = peers.equal_range(conn);
for (auto it = pr.first; it != pr.second; ++it) {
auto& peer = it->second;
if (peer.outgoing()) {
LMQ_LOG(info, "Closing outgoing connection to ", conn);
proxy_close_connection(peer.conn_index, linger);
return;
}
}
LMQ_LOG(warn, "Failed to disconnect ", conn, ": no such outgoing connection");
}
}

View file

@ -1,20 +1,14 @@
#pragma once
#include "auth.h"
#include <oxenc/bt_value.h>
#include <string_view>
#include <iosfwd>
#include <stdexcept>
#include <string>
#include <utility>
#include <variant>
#include "string_view.h"
namespace oxenmq {
namespace lokimq {
class bt_dict;
struct ConnectionID;
namespace detail {
template <typename... T>
oxenc::bt_dict build_send(ConnectionID to, std::string_view cmd, T&&... opts);
bt_dict build_send(ConnectionID to, string_view cmd, const T&... opts);
}
/// Opaque data structure representing a connection which supports ==, !=, < and std::hash. For
@ -22,17 +16,11 @@ oxenc::bt_dict build_send(ConnectionID to, std::string_view cmd, T&&... opts);
/// anywhere a ConnectionID is called for). For non-SN remote connections you need to keep a copy
/// of the ConnectionID returned by connect_remote().
struct ConnectionID {
// Default construction; creates a ConnectionID with an invalid internal ID that will not match
// an actual connection.
ConnectionID() : ConnectionID(0) {}
// Construction from a service node pubkey
ConnectionID(std::string pubkey_) : id{SN_ID}, pk{std::move(pubkey_)} {
if (pk.size() != 32)
throw std::runtime_error{"Invalid pubkey: expected 32 bytes"};
}
// Construction from a service node pubkey
ConnectionID(std::string_view pubkey_) : ConnectionID(std::string{pubkey_}) {}
ConnectionID(string_view pubkey_) : ConnectionID(std::string{pubkey_}) {}
ConnectionID(const ConnectionID&) = default;
ConnectionID(ConnectionID&&) = default;
ConnectionID& operator=(const ConnectionID&) = default;
@ -44,55 +32,52 @@ struct ConnectionID {
}
// Two ConnectionIDs are equal if they are both SNs and have matching pubkeys, or they are both
// not SNs and have matching internal IDs and routes. (Pubkeys do not have to match for
// non-SNs).
// not SNs and have matching internal IDs. (Pubkeys do not have to match for non-SNs, and
// routes are not considered for equality at all).
bool operator==(const ConnectionID &o) const {
if (sn() && o.sn())
if (id == SN_ID && o.id == SN_ID)
return pk == o.pk;
return id == o.id && route == o.route;
return id == o.id;
}
bool operator!=(const ConnectionID &o) const { return !(*this == o); }
bool operator<(const ConnectionID &o) const {
if (sn() && o.sn())
if (id == SN_ID && o.id == SN_ID)
return pk < o.pk;
return id < o.id || (id == o.id && route < o.route);
return id < o.id;
}
// Returns true if this ConnectionID represents a SN connection
bool sn() const { return id == SN_ID; }
// Returns this connection's pubkey, if any. (Note that all curve connections have pubkeys, not
// only SNs).
// Returns this connection's pubkey, if any. (Note that it is possible to have a pubkey and not
// be a SN when connecting to secure remotes: having a non-empty pubkey does not imply that
// `sn()` is true).
const std::string& pubkey() const { return pk; }
// Returns a copy of the ConnectionID with the route set to empty.
ConnectionID unrouted() { return ConnectionID{id, pk, ""}; }
std::string to_string() const;
// Default construction; creates a ConnectionID with an invalid internal ID that will not match
// an actual connection.
ConnectionID() : ConnectionID(0) {}
private:
ConnectionID(int64_t id) : id{id} {}
ConnectionID(int64_t id, std::string pubkey, std::string route = "")
ConnectionID(long long id) : id{id} {}
ConnectionID(long long id, std::string pubkey, std::string route = "")
: id{id}, pk{std::move(pubkey)}, route{std::move(route)} {}
constexpr static int64_t SN_ID = -1;
int64_t id = 0;
constexpr static long long SN_ID = -1;
long long id = 0;
std::string pk;
std::string route;
friend class OxenMQ;
friend class LokiMQ;
friend struct std::hash<ConnectionID>;
template <typename... T>
friend oxenc::bt_dict detail::build_send(ConnectionID to, std::string_view cmd, T&&... opts);
friend bt_dict detail::build_send(ConnectionID to, string_view cmd, const T&... opts);
friend std::ostream& operator<<(std::ostream& o, const ConnectionID& conn);
};
} // namespace oxenmq
} // namespace lokimq
namespace std {
template <> struct hash<oxenmq::ConnectionID> {
size_t operator()(const oxenmq::ConnectionID &c) const {
return c.sn() ? oxenmq::already_hashed{}(c.pk) :
std::hash<int64_t>{}(c.id) + std::hash<std::string>{}(c.route);
template <> struct hash<lokimq::ConnectionID> {
size_t operator()(const lokimq::ConnectionID &c) const {
return c.sn() ? std::hash<std::string>{}(c.pk) :
std::hash<long long>{}(c.id);
}
};
} // namespace std

122
lokimq/hex.h Normal file
View file

@ -0,0 +1,122 @@
// Copyright (c) 2019-2020, The Loki Project
//
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without modification, are
// permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this list of
// conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice, this list
// of conditions and the following disclaimer in the documentation and/or other
// materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its contributors may be
// used to endorse or promote products derived from this software without specific
// prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL
// THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF
// THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include "string_view.h"
#include <array>
#include <iterator>
#include <cassert>
namespace lokimq {
namespace detail {
/// Compile-time generated lookup tables hex conversion
struct hex_table {
char from_hex_lut[256];
char to_hex_lut[16];
constexpr hex_table() noexcept : from_hex_lut{}, to_hex_lut{} {
for (unsigned char c = 0; c < 10; c++) {
from_hex_lut[(unsigned char)('0' + c)] = 0 + c;
to_hex_lut[ (unsigned char)( 0 + c)] = '0' + c;
}
for (unsigned char c = 0; c < 6; c++) {
from_hex_lut[(unsigned char)('a' + c)] = 10 + c;
from_hex_lut[(unsigned char)('A' + c)] = 10 + c;
to_hex_lut[ (unsigned char)(10 + c)] = 'a' + c;
}
}
constexpr char from_hex(unsigned char c) const noexcept { return from_hex_lut[c]; }
constexpr char to_hex(unsigned char b) const noexcept { return to_hex_lut[b]; }
} constexpr hex_lut;
} // namespace detail
/// Creates hex digits from a character sequence.
template <typename InputIt, typename OutputIt>
void to_hex(InputIt begin, InputIt end, OutputIt out) {
for (; begin != end; ++begin) {
auto c = *begin;
*out++ = detail::hex_lut.to_hex((c & 0xf0) >> 4);
*out++ = detail::hex_lut.to_hex(c & 0x0f);
}
}
/// Creates a hex string from an iterable, std::string-like object
inline std::string to_hex(string_view s) {
std::string hex;
hex.reserve(s.size() * 2);
to_hex(s.begin(), s.end(), std::back_inserter(hex));
return hex;
}
/// Returns true if all elements in the range are hex characters
template <typename It>
constexpr bool is_hex(It begin, It end) {
for (; begin != end; ++begin) {
if (detail::hex_lut.from_hex(*begin) == 0 && *begin != '0')
return false;
}
return true;
}
/// Returns true if all elements in the string-like value are hex characters
constexpr bool is_hex(string_view s) { return is_hex(s.begin(), s.end()); }
/// Convert a hex digit into its numeric (0-15) value
constexpr char from_hex_digit(unsigned char x) noexcept {
return detail::hex_lut.from_hex(x);
}
/// Constructs a byte value from a pair of hex digits
constexpr char from_hex_pair(unsigned char a, unsigned char b) noexcept { return (from_hex_digit(a) << 4) | from_hex_digit(b); }
/// Converts a sequence of hex digits to bytes. Undefined behaviour if any characters are not in
/// [0-9a-fA-F] or if the input sequence length is not even. It is permitted for the input and
/// output ranges to overlap as long as out is no earlier than begin.
template <typename InputIt, typename OutputIt>
void from_hex(InputIt begin, InputIt end, OutputIt out) {
using std::distance;
assert(distance(begin, end) % 2 == 0);
while (begin != end) {
auto a = *begin++;
auto b = *begin++;
*out++ = from_hex_pair(a, b);
}
}
/// Converts hex digits from a std::string-like object into a std::string of bytes. Undefined
/// behaviour if any characters are not in [0-9a-fA-F] or if the input sequence length is not even.
inline std::string from_hex(string_view s) {
std::string bytes;
bytes.reserve(s.size() / 2);
from_hex(s.begin(), s.end(), std::back_inserter(bytes));
return bytes;
}
}

109
lokimq/jobs.cpp Normal file
View file

@ -0,0 +1,109 @@
#include "lokimq.h"
#include "batch.h"
#include "lokimq-internal.h"
namespace lokimq {
void LokiMQ::proxy_batch(detail::Batch* batch) {
batches.insert(batch);
const int jobs = batch->size();
for (int i = 0; i < jobs; i++)
batch_jobs.emplace(batch, i);
proxy_skip_one_poll = true;
}
void LokiMQ::job(std::function<void()> f) {
auto* b = new Batch<void>;
b->add_job(std::move(f));
auto* baseptr = static_cast<detail::Batch*>(b);
detail::send_control(get_control_socket(), "BATCH", bt_serialize(reinterpret_cast<uintptr_t>(baseptr)));
}
void LokiMQ::proxy_schedule_reply_job(std::function<void()> f) {
auto* b = new Batch<void>;
b->add_job(std::move(f));
batches.insert(b);
reply_jobs.emplace(static_cast<detail::Batch*>(b), 0);
proxy_skip_one_poll = true;
}
void LokiMQ::proxy_run_batch_jobs(std::queue<batch_job>& jobs, const int reserved, int& active, bool reply) {
while (!jobs.empty() &&
(active < reserved || static_cast<int>(workers.size() - idle_workers.size()) < general_workers)) {
proxy_run_worker(get_idle_worker().load(std::move(jobs.front()), reply));
jobs.pop();
active++;
}
}
// Called either within the proxy thread, or before the proxy thread has been created; actually adds
// the timer. If the timer object hasn't been set up yet it gets set up here.
void LokiMQ::proxy_timer(std::function<void()> job, std::chrono::milliseconds interval, bool squelch) {
if (!timers)
timers.reset(zmq_timers_new());
int timer_id = zmq_timers_add(timers.get(),
interval.count(),
[](int timer_id, void* self) { static_cast<LokiMQ*>(self)->_queue_timer_job(timer_id); },
this);
if (timer_id == -1)
throw zmq::error_t{};
timer_jobs[timer_id] = std::make_tuple(std::move(job), squelch, false);
}
void LokiMQ::proxy_timer(bt_list_consumer timer_data) {
std::unique_ptr<std::function<void()>> func{reinterpret_cast<std::function<void()>*>(timer_data.consume_integer<uintptr_t>())};
auto interval = std::chrono::milliseconds{timer_data.consume_integer<uint64_t>()};
auto squelch = timer_data.consume_integer<bool>();
if (!timer_data.is_finished())
throw std::runtime_error("Internal error: proxied timer request contains unexpected data");
proxy_timer(std::move(*func), interval, squelch);
}
void LokiMQ::_queue_timer_job(int timer_id) {
auto it = timer_jobs.find(timer_id);
if (it == timer_jobs.end()) {
LMQ_LOG(warn, "Could not find timer job ", timer_id);
return;
}
auto& timer = it->second;
auto& squelch = std::get<1>(timer);
auto& running = std::get<2>(timer);
if (squelch && running) {
LMQ_LOG(debug, "Not running timer job ", timer_id, " because a job for that timer is still running");
return;
}
auto* b = new Batch<void>;
b->add_job(std::get<0>(timer));
if (squelch) {
running = true;
b->completion_proxy([this,timer_id](auto results) {
try { results[0].get(); }
catch (const std::exception &e) { LMQ_LOG(warn, "timer job ", timer_id, " raised an exception: ", e.what()); }
catch (...) { LMQ_LOG(warn, "timer job ", timer_id, " raised a non-std exception"); }
auto it = timer_jobs.find(timer_id);
if (it != timer_jobs.end())
std::get<2>(it->second)/*running*/ = false;
});
}
batches.insert(b);
batch_jobs.emplace(static_cast<detail::Batch*>(b), 0);
assert(b->size() == 1);
}
void LokiMQ::add_timer(std::function<void()> job, std::chrono::milliseconds interval, bool squelch) {
if (proxy_thread.joinable()) {
auto *jobptr = new std::function<void()>{std::move(job)};
detail::send_control(get_control_socket(), "TIMER", bt_serialize(bt_list{{
reinterpret_cast<uintptr_t>(jobptr),
interval.count(),
squelch}}));
} else {
proxy_timer(std::move(job), interval, squelch);
}
}
void LokiMQ::TimersDeleter::operator()(void* timers) { zmq_timers_destroy(&timers); }
}

View file

@ -1,91 +1,82 @@
#pragma once
#include <limits>
#include "oxenmq.h"
#include "lokimq.h"
// Inside some method:
// OMQ_LOG(warn, "bad ", 42, " stuff");
// LMQ_LOG(warn, "bad ", 42, " stuff");
//
#define OMQ_LOG(level, ...) log(LogLevel::level, __FILE__, __LINE__, __VA_ARGS__)
// (The "this->" is here to work around gcc 5 bugginess when called in a `this`-capturing lambda.)
#define LMQ_LOG(level, ...) this->log_(LogLevel::level, __FILE__, __LINE__, __VA_ARGS__)
#ifndef NDEBUG
// Same as OMQ_LOG(trace, ...) when not doing a release build; nothing under a release build.
# define OMQ_TRACE(...) log(LogLevel::trace, __FILE__, __LINE__, __VA_ARGS__)
// Same as LMQ_LOG(trace, ...) when not doing a release build; nothing under a release build.
# define LMQ_TRACE(...) this->log_(LogLevel::trace, __FILE__, __LINE__, __VA_ARGS__)
#else
# define OMQ_TRACE(...)
# define LMQ_TRACE(...)
#endif
namespace oxenmq {
namespace lokimq {
constexpr char SN_ADDR_COMMAND[] = "inproc://sn-command";
constexpr char SN_ADDR_WORKERS[] = "inproc://sn-workers";
constexpr char SN_ADDR_SELF[] = "inproc://sn-self";
constexpr char ZMQ_ADDR_ZAP[] = "inproc://zeromq.zap.01";
#ifdef OXENMQ_USE_EPOLL
constexpr auto EPOLL_COMMAND_ID = std::numeric_limits<uint64_t>::max();
constexpr auto EPOLL_WORKER_ID = std::numeric_limits<uint64_t>::max() - 1;
constexpr auto EPOLL_ZAP_ID = std::numeric_limits<uint64_t>::max() - 2;
#endif
/// Destructor for create_message(std::string&&) that zmq calls when it's done with the message.
extern "C" inline void message_buffer_destroy(void*, void* hint) {
delete reinterpret_cast<std::string*>(hint);
}
/// Creates a message without needing to reallocate the provided string data
inline zmq::message_t create_message(std::string&& data) {
inline zmq::message_t create_message(std::string &&data) {
auto *buffer = new std::string(std::move(data));
return zmq::message_t{&(*buffer)[0], buffer->size(), message_buffer_destroy, buffer};
};
/// Create a message copying from a string_view
inline zmq::message_t create_message(std::string_view data) {
inline zmq::message_t create_message(string_view data) {
return zmq::message_t{data.begin(), data.end()};
}
template <typename It>
bool send_message_parts(zmq::socket_t& sock, It begin, It end) {
void send_message_parts(zmq::socket_t &sock, It begin, It end) {
while (begin != end) {
// FIXME: for outgoing connections on ZMQ_DEALER we want to use ZMQ_DONTWAIT and handle
// EAGAIN error (which either means the peer HWM is hit -- probably indicating a connection
// failure -- or the underlying connect() system call failed). Assuming it's an outgoing
// connection, we should destroy it.
zmq::message_t &msg = *begin++;
if (!sock.send(msg, begin == end ? zmq::send_flags::dontwait : zmq::send_flags::dontwait | zmq::send_flags::sndmore))
return false;
sock.send(msg, begin == end ? zmq::send_flags::none : zmq::send_flags::sndmore);
}
return true;
}
template <typename Container>
bool send_message_parts(zmq::socket_t& sock, Container&& c) {
return send_message_parts(sock, c.begin(), c.end());
void send_message_parts(zmq::socket_t &sock, Container &&c) {
send_message_parts(sock, c.begin(), c.end());
}
/// Sends a message with an initial route. `msg` and `data` can be empty: if `msg` is empty then
/// the msg frame will be an empty message; if `data` is empty then the data frame will be omitted.
/// `flags` is passed through to zmq: typically given `zmq::send_flags::dontwait` to throw rather
/// than block if a message can't be queued.
inline bool send_routed_message(zmq::socket_t& socket, std::string route, std::string msg = {}, std::string data = {}) {
inline void send_routed_message(zmq::socket_t &socket, std::string route, std::string msg = {}, std::string data = {}) {
assert(!route.empty());
std::array<zmq::message_t, 3> msgs{{create_message(std::move(route))}};
if (!msg.empty())
msgs[1] = create_message(std::move(msg));
if (!data.empty())
msgs[2] = create_message(std::move(data));
return send_message_parts(socket, msgs.begin(), data.empty() ? std::prev(msgs.end()) : msgs.end());
send_message_parts(socket, msgs.begin(), data.empty() ? std::prev(msgs.end()) : msgs.end());
}
// Sends some stuff to a socket directly. If dontwait is true then we throw instead of blocking if
// the message cannot be accepted by zmq (i.e. because the outgoing buffer is full).
inline bool send_direct_message(zmq::socket_t& socket, std::string msg, std::string data = {}) {
// Sends some stuff to a socket directly.
inline void send_direct_message(zmq::socket_t &socket, std::string msg, std::string data = {}) {
std::array<zmq::message_t, 2> msgs{{create_message(std::move(msg))}};
if (!data.empty())
msgs[1] = create_message(std::move(data));
return send_message_parts(socket, msgs.begin(), data.empty() ? std::prev(msgs.end()) : msgs.end());
send_message_parts(socket, msgs.begin(), data.empty() ? std::prev(msgs.end()) : msgs.end());
}
// Receive all the parts of a single message from the given socket. Returns true if a message was
// received, false if called with flags=zmq::recv_flags::dontwait and no message was available.
inline bool recv_message_parts(zmq::socket_t& sock, std::vector<zmq::message_t>& parts, const zmq::recv_flags flags = zmq::recv_flags::none) {
inline bool recv_message_parts(zmq::socket_t &sock, std::vector<zmq::message_t>& parts, const zmq::recv_flags flags = zmq::recv_flags::none) {
do {
zmq::message_t msg;
if (!sock.recv(msg, flags))
@ -95,21 +86,6 @@ inline bool recv_message_parts(zmq::socket_t& sock, std::vector<zmq::message_t>&
return true;
}
// Same as above, but using a fixed sized array; this is only used for internal jobs (e.g. control
// messages) where we know the message parts should never exceed a given size (this function does
// not bounds check except in debug builds). Returns the number of message parts received, or 0 on
// read error.
template <size_t N>
inline size_t recv_message_parts(zmq::socket_t& sock, std::array<zmq::message_t, N>& parts, const zmq::recv_flags flags = zmq::recv_flags::none) {
for (size_t count = 0; ; count++) {
assert(count < N);
if (!sock.recv(parts[count], flags))
return 0;
if (!parts[count].more())
return count + 1;
}
}
inline const char* peer_address(zmq::message_t& msg) {
try { return msg.gets("Peer-Address"); } catch (...) {}
return "(unknown)";
@ -117,12 +93,29 @@ inline const char* peer_address(zmq::message_t& msg) {
// Returns a string view of the given message data. It's the caller's responsibility to keep the
// referenced message alive. If you want a std::string instead just call `m.to_string()`
inline std::string_view view(const zmq::message_t& m) {
inline string_view view(const zmq::message_t& m) {
return {m.data<char>(), m.size()};
}
inline std::string to_string(AuthLevel a) {
switch (a) {
case AuthLevel::denied: return "denied";
case AuthLevel::none: return "none";
case AuthLevel::basic: return "basic";
case AuthLevel::admin: return "admin";
default: return "(unknown)";
}
}
inline AuthLevel auth_from_string(string_view a) {
if (a == "none") return AuthLevel::none;
if (a == "basic") return AuthLevel::basic;
if (a == "admin") return AuthLevel::admin;
return AuthLevel::denied;
}
// Extracts and builds the "send" part of a message for proxy_send/proxy_reply
inline std::list<zmq::message_t> build_send_parts(oxenc::bt_list_consumer send, std::string_view route) {
inline std::list<zmq::message_t> build_send_parts(bt_list_consumer send, string_view route) {
std::list<zmq::message_t> parts;
if (!route.empty())
parts.push_back(create_message(route));
@ -134,7 +127,7 @@ inline std::list<zmq::message_t> build_send_parts(oxenc::bt_list_consumer send,
/// Sends a control message to a specific destination by prefixing the worker name (or identity)
/// then appending the command and optional data (if non-empty). (This is needed when sending the control message
/// to a router socket, i.e. inside the proxy thread).
inline void route_control(zmq::socket_t& sock, std::string_view identity, std::string_view cmd, const std::string& data = {}) {
inline void route_control(zmq::socket_t& sock, string_view identity, string_view cmd, const std::string& data = {}) {
sock.send(create_message(identity), zmq::send_flags::sndmore);
detail::send_control(sock, cmd, data);
}

View file

@ -1,29 +1,21 @@
#include "oxenmq.h"
#include "oxenmq-internal.h"
#include "zmq.hpp"
#include "lokimq.h"
#include "lokimq-internal.h"
#include <map>
#include <mutex>
#include <random>
#include <ostream>
#include <thread>
#include <future>
extern "C" {
#include <sodium/core.h>
#include <sodium/crypto_box.h>
#include <sodium/crypto_scalarmult.h>
#include <sodium.h>
}
#include <oxenc/hex.h>
#include <oxenc/variant.h>
#include "hex.h"
namespace oxenmq {
namespace lokimq {
namespace {
/// Creates a message by bt-serializing the given value (string, number, list, or dict)
template <typename T>
zmq::message_t create_bt_message(T&& data) { return create_message(oxenc::bt_serialize(std::forward<T>(data))); }
zmq::message_t create_bt_message(T&& data) { return create_message(bt_serialize(std::forward<T>(data))); }
template <typename MessageContainer>
std::vector<std::string> as_strings(const MessageContainer& msgs) {
@ -34,6 +26,11 @@ std::vector<std::string> as_strings(const MessageContainer& msgs) {
return result;
}
void check_started(const std::thread& proxy_thread, const std::string &verb) {
if (!proxy_thread.joinable())
throw std::logic_error("Cannot " + verb + " before calling `start()`");
}
void check_not_started(const std::thread& proxy_thread, const std::string &verb) {
if (proxy_thread.joinable())
throw std::logic_error("Cannot " + verb + " after calling `start()`");
@ -46,7 +43,7 @@ namespace detail {
// Sends a control messages between proxy and threads or between proxy and workers consisting of a
// single command codes with an optional data part (the data frame is omitted if empty).
void send_control(zmq::socket_t& sock, std::string_view cmd, std::string data) {
void send_control(zmq::socket_t& sock, string_view cmd, std::string data) {
auto c = create_message(std::move(cmd));
if (data.empty()) {
sock.send(c, zmq::send_flags::none);
@ -57,20 +54,29 @@ void send_control(zmq::socket_t& sock, std::string_view cmd, std::string data) {
}
}
/// Extracts a pubkey and and auth level from a zmq message received on a *listening* socket.
std::pair<std::string, AuthLevel> extract_metadata(zmq::message_t& msg) {
auto result = std::make_pair(""s, AuthLevel::none);
/// Extracts a pubkey, SN status, and auth level from a zmq message received on a *listening*
/// socket.
std::tuple<std::string, bool, AuthLevel> extract_metadata(zmq::message_t& msg) {
auto result = std::make_tuple(""s, false, AuthLevel::none);
try {
std::string_view pubkey_hex{msg.gets("User-Id")};
string_view pubkey_hex{msg.gets("User-Id")};
if (pubkey_hex.size() != 64)
throw std::logic_error("bad user-id");
assert(oxenc::is_hex(pubkey_hex.begin(), pubkey_hex.end()));
result.first.resize(32, 0);
oxenc::from_hex(pubkey_hex.begin(), pubkey_hex.end(), result.first.begin());
assert(is_hex(pubkey_hex.begin(), pubkey_hex.end()));
auto& pubkey = std::get<std::string>(result);
pubkey.resize(32, 0);
from_hex(pubkey_hex.begin(), pubkey_hex.end(), pubkey.begin());
} catch (...) {}
try {
result.second = auth_from_string(msg.gets("X-AuthLevel"));
string_view is_sn{msg.gets("X-SN")};
if (is_sn.size() == 1 && is_sn[0] == '1')
std::get<bool>(result) = true;
} catch (...) {}
try {
string_view auth_level{msg.gets("X-AuthLevel")};
std::get<AuthLevel>(result) = auth_from_string(auth_level);
} catch (...) {}
return result;
@ -79,20 +85,16 @@ std::pair<std::string, AuthLevel> extract_metadata(zmq::message_t& msg) {
} // namespace detail
void OxenMQ::set_zmq_context_option(zmq::ctxopt option, int value) {
context.set(option, value);
}
void OxenMQ::log_level(LogLevel level) {
void LokiMQ::log_level(LogLevel level) {
log_lvl.store(level, std::memory_order_relaxed);
}
LogLevel OxenMQ::log_level() const {
LogLevel LokiMQ::log_level() const {
return log_lvl.load(std::memory_order_relaxed);
}
CatHelper OxenMQ::add_category(std::string name, Access access_level, unsigned int reserved_threads, int max_queue) {
CatHelper LokiMQ::add_category(std::string name, Access access_level, unsigned int reserved_threads, int max_queue) {
check_not_started(proxy_thread, "add a category");
if (name.size() > MAX_CATEGORY_LENGTH)
@ -110,7 +112,7 @@ CatHelper OxenMQ::add_category(std::string name, Access access_level, unsigned i
return ret;
}
void OxenMQ::add_command(const std::string& category, std::string name, CommandCallback callback) {
void LokiMQ::add_command(const std::string& category, std::string name, CommandCallback callback) {
check_not_started(proxy_thread, "add a command");
if (name.size() > MAX_COMMAND_LENGTH)
@ -129,12 +131,12 @@ void OxenMQ::add_command(const std::string& category, std::string name, CommandC
throw std::runtime_error("Cannot add command `" + fullname + "': that command already exists");
}
void OxenMQ::add_request_command(const std::string& category, std::string name, CommandCallback callback) {
void LokiMQ::add_request_command(const std::string& category, std::string name, CommandCallback callback) {
add_command(category, name, std::move(callback));
categories.at(category).commands.at(name).second = true;
}
void OxenMQ::add_command_alias(std::string from, std::string to) {
void LokiMQ::add_command_alias(std::string from, std::string to) {
check_not_started(proxy_thread, "add a command alias");
if (from.empty())
@ -163,56 +165,57 @@ std::atomic<int> next_id{1};
/// Accesses a thread-local command socket connected to the proxy's command socket used to issue
/// commands in a thread-safe manner. A mutex is only required here the first time a thread
/// accesses the control socket.
zmq::socket_t& OxenMQ::get_control_socket() {
zmq::socket_t& LokiMQ::get_control_socket() {
assert(proxy_thread.joinable());
// Optimize by caching the last value; OxenMQ is often a singleton and in that case we're
// Maps the LokiMQ unique ID to a local thread command socket.
static thread_local std::map<int, std::shared_ptr<zmq::socket_t>> control_sockets;
static thread_local std::pair<int, std::shared_ptr<zmq::socket_t>> last{-1, nullptr};
// Optimize by caching the last value; LokiMQ is often a singleton and in that case we're
// going to *always* hit this optimization. Even if it isn't, we're probably likely to need the
// same control socket from the same thread multiple times sequentially so this may still help.
static thread_local int last_id = -1;
static thread_local zmq::socket_t* last_socket = nullptr;
if (object_id == last_id)
return *last_socket;
if (object_id == last.first)
return *last.second;
std::lock_guard lock{control_sockets_mutex};
if (proxy_shutting_down)
throw std::runtime_error("Unable to obtain OxenMQ control socket: proxy thread is shutting down");
auto& socket = control_sockets[std::this_thread::get_id()];
if (!socket) {
socket = std::make_unique<zmq::socket_t>(context, zmq::socket_type::dealer);
socket->set(zmq::sockopt::linger, 0);
socket->connect(SN_ADDR_COMMAND);
auto it = control_sockets.find(object_id);
if (it != control_sockets.end()) {
last = *it;
return *last.second;
}
last_id = object_id;
last_socket = socket.get();
return *last_socket;
std::lock_guard<std::mutex> lock{control_sockets_mutex};
if (proxy_shutting_down)
throw std::runtime_error("Unable to obtain LokiMQ control socket: proxy thread is shutting down");
auto control = std::make_shared<zmq::socket_t>(context, zmq::socket_type::dealer);
control->setsockopt<int>(ZMQ_LINGER, 0);
control->connect(SN_ADDR_COMMAND);
thread_control_sockets.push_back(control);
control_sockets.emplace(object_id, control);
last.first = object_id;
last.second = std::move(control);
return *last.second;
}
OxenMQ::OxenMQ(
LokiMQ::LokiMQ(
std::string pubkey_,
std::string privkey_,
bool service_node,
SNRemoteAddress lookup,
Logger logger,
LogLevel level)
Logger logger)
: object_id{next_id++}, pubkey{std::move(pubkey_)}, privkey{std::move(privkey_)}, local_service_node{service_node},
sn_lookup{std::move(lookup)}, log_lvl{level}, logger{std::move(logger)}
sn_lookup{std::move(lookup)}, logger{std::move(logger)}
{
OMQ_TRACE("Constructing OxenMQ, id=", object_id, ", this=", this);
if (sodium_init() == -1)
throw std::runtime_error{"libsodium initialization failed"};
LMQ_TRACE("Constructing listening LokiMQ, id=", object_id, ", this=", this);
if (pubkey.empty() != privkey.empty()) {
throw std::invalid_argument("OxenMQ construction failed: one (and only one) of pubkey/privkey is empty. Both must be specified, or both empty to generate a key.");
throw std::invalid_argument("LokiMQ construction failed: one (and only one) of pubkey/privkey is empty. Both must be specified, or both empty to generate a key.");
} else if (pubkey.empty()) {
if (service_node)
throw std::invalid_argument("Cannot construct a service node mode OxenMQ without a keypair");
OMQ_LOG(debug, "generating x25519 keypair for remote-only OxenMQ instance");
throw std::invalid_argument("Cannot construct a service node mode LokiMQ without a keypair");
LMQ_LOG(debug, "generating x25519 keypair for remote-only LokiMQ instance");
pubkey.resize(crypto_box_PUBLICKEYBYTES);
privkey.resize(crypto_box_SECRETKEYBYTES);
crypto_box_keypair(reinterpret_cast<unsigned char*>(&pubkey[0]), reinterpret_cast<unsigned char*>(&privkey[0]));
@ -227,101 +230,63 @@ OxenMQ::OxenMQ(
std::string verify_pubkey(crypto_box_PUBLICKEYBYTES, 0);
crypto_scalarmult_base(reinterpret_cast<unsigned char*>(&verify_pubkey[0]), reinterpret_cast<unsigned char*>(&privkey[0]));
if (verify_pubkey != pubkey)
throw std::invalid_argument("Invalid pubkey/privkey values given to OxenMQ construction: pubkey verification failed");
throw std::invalid_argument("Invalid pubkey/privkey values given to LokiMQ construction: pubkey verification failed");
}
}
void OxenMQ::start() {
void LokiMQ::start() {
if (proxy_thread.joinable())
throw std::logic_error("Cannot call start() multiple times!");
OMQ_LOG(info, "Initializing OxenMQ ", bind.empty() ? "remote-only" : "listener", " with pubkey ", oxenc::to_hex(pubkey));
// If we're not binding to anything then we don't listen, i.e. we can only establish outbound
// connections. Don't allow this if we are in service_node mode because, if we aren't
// listening, we are useless as a service node.
if (bind.empty() && local_service_node)
throw std::invalid_argument{"Cannot create a service node listener with no address(es) to bind"};
assert(general_workers > 0);
if (batch_jobs_reserved < 0)
batch_jobs_reserved = (general_workers + 1) / 2;
if (reply_jobs_reserved < 0)
reply_jobs_reserved = (general_workers + 7) / 8;
max_workers = general_workers + batch_jobs_reserved + reply_jobs_reserved;
for (const auto& cat : categories) {
max_workers += cat.second.reserved_threads;
}
if (log_level() >= LogLevel::debug) {
OMQ_LOG(debug, "Reserving space for ", max_workers, " max workers = ", general_workers, " general plus reservations for:");
for (const auto& cat : categories)
OMQ_LOG(debug, " - ", cat.first, ": ", cat.second.reserved_threads);
OMQ_LOG(debug, " - (batch jobs): ", batch_jobs_reserved);
OMQ_LOG(debug, " - (reply jobs): ", reply_jobs_reserved);
OMQ_LOG(debug, "Plus ", tagged_workers.size(), " tagged worker threads");
}
if (MAX_SOCKETS != 0) {
// The max sockets setting we apply to the context here is used during zmq context
// initialization, which happens when the first socket is constructed using this context:
// hence we set this *before* constructing any socket_t on the context.
int zmq_socket_limit = context.get(zmq::ctxopt::socket_limit);
int want_sockets = MAX_SOCKETS < 0 ? zmq_socket_limit :
std::min<int>(zmq_socket_limit,
MAX_SOCKETS + max_workers + tagged_workers.size()
+ 4 /* zap_auth, workers_socket, command, inproc_listener */);
context.set(zmq::ctxopt::max_sockets, want_sockets);
}
LMQ_LOG(info, "Initializing LokiMQ ", bind.empty() ? "remote-only" : "listener", " with pubkey ", to_hex(pubkey));
// We bind `command` here so that the `get_control_socket()` below is always connecting to a
// bound socket, but we do nothing else here: the proxy thread is responsible for everything
// except binding it.
command = zmq::socket_t{context, zmq::socket_type::router};
command.bind(SN_ADDR_COMMAND);
std::promise<void> startup_prom;
auto proxy_startup = startup_prom.get_future();
proxy_thread = std::thread{&OxenMQ::proxy_loop, this, std::move(startup_prom)};
proxy_thread = std::thread{&LokiMQ::proxy_loop, this};
OMQ_LOG(debug, "Waiting for proxy thread to initialize...");
proxy_startup.get(); // Rethrows exceptions from the proxy startup (e.g. failure to bind)
OMQ_LOG(debug, "Waiting for proxy thread to get ready...");
LMQ_LOG(debug, "Waiting for proxy thread to get ready...");
auto &control = get_control_socket();
detail::send_control(control, "START");
OMQ_TRACE("Sent START command");
LMQ_TRACE("Sent START command");
zmq::message_t ready_msg;
std::vector<zmq::message_t> parts;
try { recv_message_parts(control, parts); }
catch (const zmq::error_t &e) { throw std::runtime_error("Failure reading from OxenMQ::Proxy thread: "s + e.what()); }
catch (const zmq::error_t &e) { throw std::runtime_error("Failure reading from LokiMQ::Proxy thread: "s + e.what()); }
if (!(parts.size() == 1 && view(parts.front()) == "READY"))
throw std::runtime_error("Invalid startup message from proxy thread (didn't get expected READY message)");
OMQ_LOG(debug, "Proxy thread is ready");
LMQ_LOG(debug, "Proxy thread is ready");
}
void OxenMQ::listen_curve(std::string bind_addr, AllowFunc allow_connection, std::function<void(bool)> on_bind) {
if (std::string_view{bind_addr}.substr(0, 9) == "inproc://")
throw std::logic_error{"inproc:// cannot be used with listen_curve"};
if (!allow_connection) allow_connection = [](auto&&...) { return AuthLevel::none; };
bind_data d{std::move(bind_addr), true, std::move(allow_connection), std::move(on_bind)};
if (proxy_thread.joinable())
detail::send_control(get_control_socket(), "BIND", oxenc::bt_serialize(detail::serialize_object(std::move(d))));
else
bind.push_back(std::move(d));
void LokiMQ::listen_curve(std::string bind_addr, AllowFunc allow_connection) {
// TODO: there's no particular reason we can't start listening after starting up; just needs to
// be implemented. (But if we can start we'll probably also want to be able to stop, so it's
// more than just binding that needs implementing).
check_not_started(proxy_thread, "start listening");
bind.emplace_back(std::move(bind_addr), bind_data{true, std::move(allow_connection)});
}
void OxenMQ::listen_plain(std::string bind_addr, AllowFunc allow_connection, std::function<void(bool)> on_bind) {
if (std::string_view{bind_addr}.substr(0, 9) == "inproc://")
throw std::logic_error{"inproc:// cannot be used with listen_plain"};
if (!allow_connection) allow_connection = [](auto&&...) { return AuthLevel::none; };
bind_data d{std::move(bind_addr), false, std::move(allow_connection), std::move(on_bind)};
if (proxy_thread.joinable())
detail::send_control(get_control_socket(), "BIND", oxenc::bt_serialize(detail::serialize_object(std::move(d))));
else
bind.push_back(std::move(d));
void LokiMQ::listen_plain(std::string bind_addr, AllowFunc allow_connection) {
// TODO: As above.
check_not_started(proxy_thread, "start listening");
bind.emplace_back(std::move(bind_addr), bind_data{false, std::move(allow_connection)});
}
std::pair<OxenMQ::category*, const std::pair<OxenMQ::CommandCallback, bool>*> OxenMQ::get_command(std::string& command) {
std::pair<LokiMQ::category*, const std::pair<LokiMQ::CommandCallback, bool>*> LokiMQ::get_command(std::string& command) {
if (command.size() > MAX_CATEGORY_LENGTH + 1 + MAX_COMMAND_LENGTH) {
OMQ_LOG(warn, "Invalid command '", command, "': command too long");
LMQ_LOG(warn, "Invalid command '", command, "': command too long");
return {};
}
@ -333,7 +298,7 @@ std::pair<OxenMQ::category*, const std::pair<OxenMQ::CommandCallback, bool>*> Ox
auto dot = command.find('.');
if (dot == 0 || dot == std::string::npos) {
OMQ_LOG(warn, "Invalid command '", command, "': expected <category>.<command>");
LMQ_LOG(warn, "Invalid command '", command, "': expected <category>.<command>");
return {};
}
std::string catname = command.substr(0, dot);
@ -341,21 +306,21 @@ std::pair<OxenMQ::category*, const std::pair<OxenMQ::CommandCallback, bool>*> Ox
auto catit = categories.find(catname);
if (catit == categories.end()) {
OMQ_LOG(warn, "Invalid command category '", catname, "'");
LMQ_LOG(warn, "Invalid command category '", catname, "'");
return {};
}
const auto& category = catit->second;
auto callback_it = category.commands.find(cmd);
if (callback_it == category.commands.end()) {
OMQ_LOG(warn, "Invalid command '", command, "'");
LMQ_LOG(warn, "Invalid command '", command, "'");
return {};
}
return {&catit->second, &callback_it->second};
}
void OxenMQ::set_batch_threads(int threads) {
void LokiMQ::set_batch_threads(int threads) {
if (proxy_thread.joinable())
throw std::logic_error("Cannot change reserved batch threads after calling `start()`");
if (threads < -1) // -1 is the default which is based on general threads
@ -363,7 +328,7 @@ void OxenMQ::set_batch_threads(int threads) {
batch_jobs_reserved = threads;
}
void OxenMQ::set_reply_threads(int threads) {
void LokiMQ::set_reply_threads(int threads) {
if (proxy_thread.joinable())
throw std::logic_error("Cannot change reserved reply threads after calling `start()`");
if (threads < -1) // -1 is the default which is based on general threads
@ -371,7 +336,7 @@ void OxenMQ::set_reply_threads(int threads) {
reply_jobs_reserved = threads;
}
void OxenMQ::set_general_threads(int threads) {
void LokiMQ::set_general_threads(int threads) {
if (proxy_thread.joinable())
throw std::logic_error("Cannot change general thread count after calling `start()`");
if (threads < 1)
@ -379,72 +344,80 @@ void OxenMQ::set_general_threads(int threads) {
general_workers = threads;
}
OxenMQ::run_info& OxenMQ::run_info::load(category* cat_, std::string command_, ConnectionID conn_, Access access_, std::string remote_,
LokiMQ::run_info& LokiMQ::run_info::load(category* cat_, std::string command_, ConnectionID conn_,
std::vector<zmq::message_t> data_parts_, const std::pair<CommandCallback, bool>* callback_) {
reset();
is_batch_job = false;
is_reply_job = false;
cat = cat_;
command = std::move(command_);
conn = std::move(conn_);
access = std::move(access_);
remote = std::move(remote_);
data_parts = std::move(data_parts_);
to_run = callback_;
callback = callback_;
return *this;
}
OxenMQ::run_info& OxenMQ::run_info::load(category* cat_, std::string command_, std::string remote_, std::function<void()> callback) {
reset();
is_injected = true;
cat = cat_;
command = std::move(command_);
conn = {};
access = {};
remote = std::move(remote_);
to_run = std::move(callback);
return *this;
LokiMQ::run_info& LokiMQ::run_info::load(pending_command&& pending) {
return load(&pending.cat, std::move(pending.command), std::move(pending.conn),
std::move(pending.data_parts), pending.callback);
}
OxenMQ::run_info& OxenMQ::run_info::load(pending_command&& pending) {
if (auto *f = std::get_if<std::function<void()>>(&pending.callback))
return load(&pending.cat, std::move(pending.command), std::move(pending.remote), std::move(*f));
assert(pending.callback.index() == 0);
return load(&pending.cat, std::move(pending.command), std::move(pending.conn), std::move(pending.access),
std::move(pending.remote), std::move(pending.data_parts), var::get<0>(pending.callback));
}
OxenMQ::run_info& OxenMQ::run_info::load(batch_job&& bj, bool reply_job, int tagged_thread) {
reset();
LokiMQ::run_info& LokiMQ::run_info::load(batch_job&& bj, bool reply_job) {
is_batch_job = true;
is_reply_job = reply_job;
is_tagged_thread_job = tagged_thread > 0;
batch_jobno = bj.second;
to_run = bj.first;
batch = bj.first;
return *this;
}
OxenMQ::~OxenMQ() {
if (!proxy_thread.joinable()) {
if (!tagged_workers.empty()) {
// We have tagged workers that are waiting on a signal for startup, but we didn't start
// up, so signal them so that they can end themselves.
{
std::lock_guard lock{tagged_startup_mutex};
tagged_go = tagged_go_mode::SHUTDOWN;
}
tagged_cv.notify_all();
for (auto& [run, busy, queue] : tagged_workers)
run.worker_thread.join();
}
LokiMQ::~LokiMQ() {
if (!proxy_thread.joinable())
return;
}
OMQ_LOG(info, "OxenMQ shutting down proxy thread");
LMQ_LOG(info, "LokiMQ shutting down proxy thread");
detail::send_control(get_control_socket(), "QUIT");
proxy_thread.join();
OMQ_LOG(info, "OxenMQ proxy thread has stopped");
LMQ_LOG(info, "LokiMQ proxy thread has stopped");
}
ConnectionID LokiMQ::connect_sn(string_view pubkey, std::chrono::milliseconds keep_alive, string_view hint) {
check_started(proxy_thread, "connect");
detail::send_control(get_control_socket(), "CONNECT_SN", bt_serialize<bt_dict>({{"pubkey",pubkey}, {"keep_alive",keep_alive.count()}, {"hint",hint}}));
return pubkey;
}
ConnectionID LokiMQ::connect_remote(string_view remote, ConnectSuccess on_connect, ConnectFailure on_failure,
string_view pubkey, AuthLevel auth_level, std::chrono::milliseconds timeout) {
if (!proxy_thread.joinable())
LMQ_LOG(warn, "connect_remote() called before start(); this won't take effect until start() is called");
if (remote.size() < 7 || !(remote.substr(0, 6) == "tcp://" || remote.substr(0, 6) == "ipc://" /* unix domain sockets */))
throw std::runtime_error("Invalid connect_remote: remote address '" + std::string{remote} + "' is not a valid or supported zmq connect string");
auto id = next_conn_id++;
LMQ_TRACE("telling proxy to connect to ", remote, ", id ", id,
pubkey.empty() ? "using NULL auth" : ", using CURVE with remote pubkey [" + to_hex(pubkey) + "]");
detail::send_control(get_control_socket(), "CONNECT_REMOTE", bt_serialize<bt_dict>({
{"auth_level", static_cast<std::underlying_type_t<AuthLevel>>(auth_level)},
{"conn_id", id},
{"connect", reinterpret_cast<uintptr_t>(new ConnectSuccess{std::move(on_connect)})},
{"failure", reinterpret_cast<uintptr_t>(new ConnectFailure{std::move(on_failure)})},
{"pubkey", pubkey},
{"remote", remote},
{"timeout", timeout.count()},
}));
return id;
}
void LokiMQ::disconnect(ConnectionID id, std::chrono::milliseconds linger) {
detail::send_control(get_control_socket(), "DISCONNECT", bt_serialize<bt_dict>({
{"conn_id", id.id},
{"linger_ms", linger.count()},
{"pubkey", id.pk},
}));
}
std::ostream &operator<<(std::ostream &os, LogLevel lvl) {
@ -460,14 +433,13 @@ std::ostream &operator<<(std::ostream &os, LogLevel lvl) {
std::string make_random_string(size_t size) {
static thread_local std::mt19937_64 rng{std::random_device{}()};
static thread_local std::uniform_int_distribution<char> dist{std::numeric_limits<char>::min(), std::numeric_limits<char>::max()};
std::string rando;
rando.reserve(size);
while (rando.size() < size) {
uint64_t x = rng();
rando.append(reinterpret_cast<const char*>(&x), std::min<size_t>(size - rando.size(), 8));
}
for (size_t i = 0; i < size; i++)
rando += dist(rng);
return rando;
}
} // namespace oxenmq
} // namespace lokimq
// vim:sw=4:et

1200
lokimq/lokimq.h Normal file

File diff suppressed because it is too large Load diff

54
lokimq/message.h Normal file
View file

@ -0,0 +1,54 @@
#pragma once
#include <vector>
#include "connections.h"
namespace lokimq {
class LokiMQ;
/// Encapsulates an incoming message from a remote connection with message details plus extra
/// info need to send a reply back through the proxy thread via the `reply()` method. Note that
/// this object gets reused: callbacks should use but not store any reference beyond the callback.
class Message {
public:
LokiMQ& lokimq; ///< The owning LokiMQ object
std::vector<string_view> data; ///< The provided command data parts, if any.
ConnectionID conn; ///< The connection info for routing a reply; also contains the pubkey/sn status.
std::string reply_tag; ///< If the invoked command is a request command this is the required reply tag that will be prepended by `send_reply()`.
/// Constructor
Message(LokiMQ& lmq, ConnectionID cid) : lokimq{lmq}, conn{std::move(cid)} {}
// Non-copyable
Message(const Message&) = delete;
Message& operator=(const Message&) = delete;
/// Sends a command back to whomever sent this message. Arguments are forwarded to send() but
/// with send_option::optional{} added if the originator is not a SN. For SN messages (i.e.
/// where `sn` is true) this is a "strong" reply by default in that the proxy will attempt to
/// establish a new connection to the SN if no longer connected. For non-SN messages the reply
/// will be attempted using the available routing information, but if the connection has already
/// been closed the reply will be dropped.
///
/// If you want to send a non-strong reply even when the remote is a service node then add
/// an explicit `send_option::optional()` argument.
template <typename... Args>
void send_back(string_view, Args&&... args);
/// Sends a reply to a request. This takes no command: the command is always the built-in
/// "REPLY" command, followed by the unique reply tag, then any reply data parts. All other
/// arguments are as in `send_back()`. You should only send one reply for a command expecting
/// replies, though this is not enforced: attempting to send multiple replies will simply be
/// dropped when received by the remote. (Note, however, that it is possible to send multiple
/// messages -- e.g. you could send a reply and then also call send_back() and/or send_request()
/// to send more requests back to the sender).
template <typename... Args>
void send_reply(Args&&... args);
/// Sends a request back to whomever sent this message. This is effectively a wrapper around
/// lmq.request() that takes care of setting up the recipient arguments.
template <typename ReplyCallback, typename... Args>
void send_request(string_view cmd, ReplyCallback&& callback, Args&&... args);
};
}

564
lokimq/proxy.cpp Normal file
View file

@ -0,0 +1,564 @@
#include "lokimq.h"
#include "lokimq-internal.h"
#include "hex.h"
namespace lokimq {
void LokiMQ::proxy_quit() {
LMQ_LOG(debug, "Received quit command, shutting down proxy thread");
assert(std::none_of(workers.begin(), workers.end(), [](auto& worker) { return worker.worker_thread.joinable(); }));
command.setsockopt<int>(ZMQ_LINGER, 0);
command.close();
{
std::lock_guard<std::mutex> lock{control_sockets_mutex};
for (auto &control : thread_control_sockets)
control->close();
proxy_shutting_down = true; // To prevent threads from opening new control sockets
}
workers_socket.close();
int linger = std::chrono::milliseconds{CLOSE_LINGER}.count();
for (auto& s : connections)
s.setsockopt(ZMQ_LINGER, linger);
connections.clear();
peers.clear();
LMQ_LOG(debug, "Proxy thread teardown complete");
}
void LokiMQ::proxy_send(bt_dict_consumer data) {
// NB: bt_dict_consumer goes in alphabetical order
string_view hint;
std::chrono::milliseconds keep_alive{DEFAULT_SEND_KEEP_ALIVE};
std::chrono::milliseconds request_timeout{DEFAULT_REQUEST_TIMEOUT};
bool optional = false;
bool incoming = false;
bool request = false;
bool have_conn_id = false;
ConnectionID conn_id;
std::string request_tag;
ReplyCallback request_callback;
if (data.skip_until("conn_id")) {
conn_id.id = data.consume_integer<long long>();
if (conn_id.id == -1)
throw std::runtime_error("Invalid error: invalid conn_id value (-1)");
have_conn_id = true;
}
if (data.skip_until("conn_pubkey")) {
if (have_conn_id)
throw std::runtime_error("Internal error: Invalid proxy send command; conn_id and conn_pubkey are exclusive");
conn_id.pk = data.consume_string();
conn_id.id = ConnectionID::SN_ID;
} else if (!have_conn_id)
throw std::runtime_error("Internal error: Invalid proxy send command; conn_pubkey or conn_id missing");
if (data.skip_until("conn_route"))
conn_id.route = data.consume_string();
if (data.skip_until("hint"))
hint = data.consume_string_view();
if (data.skip_until("incoming"))
incoming = data.consume_integer<bool>();
if (data.skip_until("keep_alive"))
keep_alive = std::chrono::milliseconds{data.consume_integer<uint64_t>()};
if (data.skip_until("optional"))
optional = data.consume_integer<bool>();
if (data.skip_until("request"))
request = data.consume_integer<bool>();
if (request) {
if (!data.skip_until("request_callback"))
throw std::runtime_error("Internal error: received request without request_callback");
// The initiator gives up ownership of the callback to us (serializing it through a
// uintptr_t), so we take the pointer, move the value out of it, then destroy the pointer we
// were given. Further down, if we are able to send the request successfully, we set up the
// pending request.
auto* cbptr = reinterpret_cast<ReplyCallback*>(data.consume_integer<uintptr_t>());
request_callback = std::move(*cbptr);
delete cbptr;
if (!data.skip_until("request_tag"))
throw std::runtime_error("Internal error: received request without request_name");
request_tag = data.consume_string();
if (data.skip_until("request_timeout"))
request_timeout = std::chrono::milliseconds{data.consume_integer<uint64_t>()};
}
if (!data.skip_until("send"))
throw std::runtime_error("Internal error: Invalid proxy send command; send parts missing");
bt_list_consumer send = data.consume_list_consumer();
// Now figure out which socket to send to and do the actual sending. We can repeat this loop
// multiple times, if we're sending to a SN, because it's possible that we have multiple
// connections open to that SN (e.g. one out + one in) so if one fails we can clean up that
// connection and try the next one.
bool retry = true, sent = false;
while (retry) {
retry = false;
zmq::socket_t *send_to;
if (conn_id.sn()) {
auto sock_route = proxy_connect_sn(conn_id.pk, hint, optional, incoming, keep_alive);
if (!sock_route.first) {
if (optional)
LMQ_LOG(debug, "Not sending: send is optional and no connection to ",
to_hex(conn_id.pk), " is currently established");
else
LMQ_LOG(error, "Unable to send to ", to_hex(conn_id.pk), ": no connection address found");
break;
}
send_to = sock_route.first;
conn_id.route = std::move(sock_route.second);
} else if (!conn_id.route.empty()) { // incoming non-SN connection
auto it = incoming_conn_index.find(conn_id);
if (it == incoming_conn_index.end()) {
LMQ_LOG(warn, "Unable to send to ", conn_id, ": incoming listening socket not found");
break;
}
send_to = &connections[it->second];
} else {
auto pr = peers.equal_range(conn_id);
if (pr.first == peers.end()) {
LMQ_LOG(warn, "Unable to send: connection id ", conn_id, " is not (or is no longer) a valid outgoing connection");
break;
}
auto& peer = pr.first->second;
send_to = &connections[peer.conn_index];
}
try {
send_message_parts(*send_to, build_send_parts(send, conn_id.route));
sent = true;
} catch (const zmq::error_t &e) {
if (e.num() == EHOSTUNREACH && !conn_id.route.empty() /*= incoming conn*/) {
LMQ_LOG(debug, "Incoming connection is no longer valid; removing peer details");
auto pr = peers.equal_range(conn_id);
if (pr.first != peers.end()) {
if (!conn_id.sn()) {
peers.erase(pr.first);
} else {
bool removed;
for (auto it = pr.first; it != pr.second; ) {
auto& peer = it->second;
if (peer.route == conn_id.route) {
peers.erase(it);
removed = true;
break;
}
}
// The incoming connection to the SN is no longer good, but we can retry because
// we may have another active connection with the SN (or may want to open one).
if (removed) {
LMQ_LOG(debug, "Retrying sending to SN ", to_hex(conn_id.pk), " using other sockets");
retry = true;
}
}
}
}
if (!retry) {
LMQ_LOG(warn, "Unable to send message to ", conn_id, ": ", e.what());
}
}
}
if (request) {
if (sent) {
LMQ_LOG(debug, "Added new pending request ", to_hex(request_tag));
pending_requests.insert({ request_tag, {
std::chrono::steady_clock::now() + request_timeout, std::move(request_callback) }});
} else {
LMQ_LOG(debug, "Could not send request, scheduling request callback failure");
job([callback = std::move(request_callback)] { callback(false, {}); });
}
}
}
void LokiMQ::proxy_reply(bt_dict_consumer data) {
bool have_conn_id = false;
ConnectionID conn_id{0};
if (data.skip_until("conn_id")) {
conn_id.id = data.consume_integer<long long>();
if (conn_id.id == -1)
throw std::runtime_error("Invalid error: invalid conn_id value (-1)");
have_conn_id = true;
}
if (data.skip_until("conn_pubkey")) {
if (have_conn_id)
throw std::runtime_error("Internal error: Invalid proxy reply command; conn_id and conn_pubkey are exclusive");
conn_id.pk = data.consume_string();
conn_id.id = ConnectionID::SN_ID;
} else if (!have_conn_id)
throw std::runtime_error("Internal error: Invalid proxy reply command; conn_pubkey or conn_id missing");
if (!data.skip_until("send"))
throw std::runtime_error("Internal error: Invalid proxy reply command; send parts missing");
bt_list_consumer send = data.consume_list_consumer();
auto pr = peers.equal_range(conn_id);
if (pr.first == pr.second) {
LMQ_LOG(warn, "Unable to send tagged reply: the connection is no longer valid");
return;
}
// We try any connections until one works (for ordinary remotes there will be just one, but for
// SNs there might be one incoming and one outgoing).
for (auto it = pr.first; it != pr.second; ) {
try {
send_message_parts(connections[it->second.conn_index], build_send_parts(send, it->second.route));
break;
} catch (const zmq::error_t &err) {
if (err.num() == EHOSTUNREACH) {
LMQ_LOG(info, "Unable to send reply to incoming non-SN request: remote is no longer connected");
LMQ_LOG(debug, "Incoming connection is no longer valid; removing peer details");
it = peers.erase(it);
} else {
LMQ_LOG(warn, "Unable to send reply to incoming non-SN request: ", err.what());
++it;
}
}
}
}
void LokiMQ::proxy_control_message(std::vector<zmq::message_t>& parts) {
if (parts.size() < 2)
throw std::logic_error("Expected 2-3 message parts for a proxy control message");
auto route = view(parts[0]), cmd = view(parts[1]);
LMQ_TRACE("control message: ", cmd);
if (parts.size() == 3) {
LMQ_TRACE("...: ", parts[2]);
if (cmd == "SEND") {
LMQ_TRACE("proxying message");
return proxy_send(view(parts[2]));
} else if (cmd == "REPLY") {
LMQ_TRACE("proxying reply to non-SN incoming message");
return proxy_reply(view(parts[2]));
} else if (cmd == "BATCH") {
LMQ_TRACE("proxy batch jobs");
auto ptrval = bt_deserialize<uintptr_t>(view(parts[2]));
return proxy_batch(reinterpret_cast<detail::Batch*>(ptrval));
} else if (cmd == "CONNECT_SN") {
proxy_connect_sn(view(parts[2]));
return;
} else if (cmd == "CONNECT_REMOTE") {
return proxy_connect_remote(view(parts[2]));
} else if (cmd == "DISCONNECT") {
return proxy_disconnect(view(parts[2]));
} else if (cmd == "TIMER") {
return proxy_timer(view(parts[2]));
}
} else if (parts.size() == 2) {
if (cmd == "START") {
// Command send by the owning thread during startup; we send back a simple READY reply to
// let it know we are running.
return route_control(command, route, "READY");
} else if (cmd == "QUIT") {
// Asked to quit: set max_workers to zero and tell any idle ones to quit. We will
// close workers as they come back to READY status, and then close external
// connections once all workers are done.
max_workers = 0;
for (const auto &route : idle_workers)
route_control(workers_socket, workers[route].worker_routing_id, "QUIT");
idle_workers.clear();
return;
}
}
throw std::runtime_error("Proxy received invalid control command: " + std::string{cmd} +
" (" + std::to_string(parts.size()) + ")");
}
void LokiMQ::proxy_loop() {
zap_auth.setsockopt<int>(ZMQ_LINGER, 0);
zap_auth.bind(ZMQ_ADDR_ZAP);
workers_socket.setsockopt<int>(ZMQ_ROUTER_MANDATORY, 1);
workers_socket.bind(SN_ADDR_WORKERS);
assert(general_workers > 0);
if (batch_jobs_reserved < 0)
batch_jobs_reserved = (general_workers + 1) / 2;
if (reply_jobs_reserved < 0)
reply_jobs_reserved = (general_workers + 7) / 8;
max_workers = general_workers + batch_jobs_reserved + reply_jobs_reserved;
for (const auto& cat : categories) {
max_workers += cat.second.reserved_threads;
}
#ifndef NDEBUG
if (log_level() >= LogLevel::trace) {
LMQ_TRACE("Reserving space for ", max_workers, " max workers = ", general_workers, " general plus reservations for:");
for (const auto& cat : categories)
LMQ_TRACE(" - ", cat.first, ": ", cat.second.reserved_threads);
LMQ_TRACE(" - (batch jobs): ", batch_jobs_reserved);
LMQ_TRACE(" - (reply jobs): ", reply_jobs_reserved);
}
#endif
workers.reserve(max_workers);
if (!workers.empty())
throw std::logic_error("Internal error: proxy thread started with active worker threads");
for (size_t i = 0; i < bind.size(); i++) {
auto& b = bind[i].second;
zmq::socket_t listener{context, zmq::socket_type::router};
std::string auth_domain = bt_serialize(i);
listener.setsockopt(ZMQ_ZAP_DOMAIN, auth_domain.c_str(), auth_domain.size());
if (b.curve) {
listener.setsockopt<int>(ZMQ_CURVE_SERVER, 1);
listener.setsockopt(ZMQ_CURVE_PUBLICKEY, pubkey.data(), pubkey.size());
listener.setsockopt(ZMQ_CURVE_SECRETKEY, privkey.data(), privkey.size());
}
listener.setsockopt(ZMQ_HANDSHAKE_IVL, (int) HANDSHAKE_TIME.count());
listener.setsockopt<int64_t>(ZMQ_MAXMSGSIZE, MAX_MSG_SIZE);
listener.setsockopt<int>(ZMQ_ROUTER_HANDOVER, 1);
listener.setsockopt<int>(ZMQ_ROUTER_MANDATORY, 1);
listener.bind(bind[i].first);
LMQ_LOG(info, "LokiMQ listening on ", bind[i].first);
connections.push_back(std::move(listener));
auto conn_id = next_conn_id++;
conn_index_to_id.push_back(conn_id);
incoming_conn_index[conn_id] = connections.size() - 1;
b.index = connections.size() - 1;
}
pollitems_stale = true;
// Also add an internal connection to self so that calling code can avoid needing to
// special-case rare situations where we are supposed to talk to a quorum member that happens to
// be ourselves (which can happen, for example, with cross-quoum Blink communication)
// FIXME: not working
//listener.bind(SN_ADDR_SELF);
if (!timers)
timers.reset(zmq_timers_new());
auto do_conn_cleanup = [this] { proxy_conn_cleanup(); };
using CleanupLambda = decltype(do_conn_cleanup);
if (-1 == zmq_timers_add(timers.get(),
std::chrono::milliseconds{CONN_CHECK_INTERVAL}.count(),
// Wrap our lambda into a C function pointer where we pass in the lambda pointer as extra arg
[](int /*timer_id*/, void* cleanup) { (*static_cast<CleanupLambda*>(cleanup))(); },
&do_conn_cleanup)) {
throw zmq::error_t{};
}
std::vector<zmq::message_t> parts;
while (true) {
std::chrono::milliseconds poll_timeout;
if (max_workers == 0) { // Will be 0 only if we are quitting
if (std::none_of(workers.begin(), workers.end(), [](auto &w) { return w.worker_thread.joinable(); })) {
// All the workers have finished, so we can finish shutting down
return proxy_quit();
}
poll_timeout = 1s; // We don't keep running timers when we're quitting, so don't have a timer to check
} else {
poll_timeout = std::chrono::milliseconds{zmq_timers_timeout(timers.get())};
}
if (proxy_skip_one_poll)
proxy_skip_one_poll = false;
else {
LMQ_TRACE("polling for new messages");
if (pollitems_stale)
rebuild_pollitems();
// We poll the control socket and worker socket for any incoming messages. If we have
// available worker room then also poll incoming connections and outgoing connections
// for messages to forward to a worker. Otherwise, we just look for a control message
// or a worker coming back with a ready message.
zmq::poll(pollitems.data(), pollitems.size(), poll_timeout);
}
LMQ_TRACE("processing control messages");
// Retrieve any waiting incoming control messages
for (parts.clear(); recv_message_parts(command, parts, zmq::recv_flags::dontwait); parts.clear()) {
proxy_control_message(parts);
}
LMQ_TRACE("processing worker messages");
for (parts.clear(); recv_message_parts(workers_socket, parts, zmq::recv_flags::dontwait); parts.clear()) {
proxy_worker_message(parts);
}
LMQ_TRACE("processing timers");
zmq_timers_execute(timers.get());
// Handle any zap authentication
LMQ_TRACE("processing zap requests");
process_zap_requests();
// See if we can drain anything from the current queue before we potentially add to it
// below.
LMQ_TRACE("processing queued jobs and messages");
proxy_process_queue();
LMQ_TRACE("processing new incoming messages");
// We round-robin connections when pulling off pending messages one-by-one rather than
// pulling off all messages from one connection before moving to the next; thus in cases of
// contention we end up fairly distributing.
const int num_sockets = connections.size();
std::queue<int> queue_index;
for (int i = 0; i < num_sockets; i++)
queue_index.push(i);
for (parts.clear(); !queue_index.empty() && static_cast<int>(workers.size()) < max_workers; parts.clear()) {
size_t i = queue_index.front();
queue_index.pop();
auto& sock = connections[i];
if (!recv_message_parts(sock, parts, zmq::recv_flags::dontwait))
continue;
// We only pull this one message now but then requeue the socket so that after we check
// all other sockets we come back to this one to check again.
queue_index.push(i);
if (parts.empty()) {
LMQ_LOG(warn, "Ignoring empty (0-part) incoming message");
continue;
}
if (!proxy_handle_builtin(i, parts))
proxy_to_worker(i, parts);
if (pollitems_stale) {
// If our items became stale then we may have just closed a connection and so our
// queue index maybe also be stale, so restart the proxy loop (so that we rebuild
// pollitems).
LMQ_TRACE("pollitems became stale; short-circuiting incoming message loop");
break;
}
}
LMQ_TRACE("done proxy loop");
}
}
// Return true if we recognized/handled the builtin command (even if we reject it for whatever
// reason)
bool LokiMQ::proxy_handle_builtin(size_t conn_index, std::vector<zmq::message_t>& parts) {
bool outgoing = connections[conn_index].getsockopt<int>(ZMQ_TYPE) == ZMQ_DEALER;
string_view route, cmd;
if (parts.size() < (outgoing ? 1 : 2)) {
LMQ_LOG(warn, "Received empty message; ignoring");
return true;
}
if (outgoing) {
cmd = view(parts[0]);
} else {
route = view(parts[0]);
cmd = view(parts[1]);
}
LMQ_TRACE("Checking for builtins: ", cmd, " from ", peer_address(parts.back()));
if (cmd == "REPLY") {
size_t tag_pos = (outgoing ? 1 : 2);
if (parts.size() <= tag_pos) {
LMQ_LOG(warn, "Received REPLY without a reply tag; ignoring");
return true;
}
std::string reply_tag{view(parts[tag_pos])};
auto it = pending_requests.find(reply_tag);
if (it != pending_requests.end()) {
LMQ_LOG(debug, "Received REPLY for pending command", to_hex(reply_tag), "; scheduling callback");
std::vector<std::string> data;
data.reserve(parts.size() - (tag_pos + 1));
for (auto it = parts.begin() + (tag_pos + 1); it != parts.end(); ++it)
data.emplace_back(view(*it));
proxy_schedule_reply_job([callback=std::move(it->second.second), data=std::move(data)] {
callback(true, std::move(data));
});
pending_requests.erase(it);
} else {
LMQ_LOG(warn, "Received REPLY with unknown or already handled reply tag (", to_hex(reply_tag), "); ignoring");
}
return true;
} else if (cmd == "HI") {
if (outgoing) {
LMQ_LOG(warn, "Got invalid 'HI' message on an outgoing connection; ignoring");
return true;
}
LMQ_LOG(info, "Incoming client from ", peer_address(parts.back()), " sent HI, replying with HELLO");
try {
send_routed_message(connections[conn_index], std::string{route}, "HELLO");
} catch (const std::exception &e) { LMQ_LOG(warn, "Couldn't reply with HELLO: ", e.what()); }
return true;
} else if (cmd == "HELLO") {
if (!outgoing) {
LMQ_LOG(warn, "Got invalid 'HELLO' message on an incoming connection; ignoring");
return true;
}
auto it = std::find_if(pending_connects.begin(), pending_connects.end(),
[&](auto& pc) { return std::get<size_t>(pc) == conn_index; });
if (it == pending_connects.end()) {
LMQ_LOG(warn, "Got invalid 'HELLO' message on an already handshaked incoming connection; ignoring");
return true;
}
auto& pc = *it;
auto pit = peers.find(std::get<long long>(pc));
if (pit == peers.end()) {
LMQ_LOG(warn, "Got invalid 'HELLO' message with invalid conn_id; ignoring");
return true;
}
LMQ_LOG(info, "Got initial HELLO server response from ", peer_address(parts.back()));
proxy_schedule_reply_job([on_success=std::move(std::get<ConnectSuccess>(pc)),
conn=conn_index_to_id[conn_index]] {
on_success(conn);
});
pending_connects.erase(it);
return true;
} else if (cmd == "BYE") {
if (outgoing) {
std::string pk;
bool sn;
AuthLevel a;
std::tie(pk, sn, a) = detail::extract_metadata(parts.back());
ConnectionID conn = sn ? ConnectionID{std::move(pk)} : conn_index_to_id[conn_index];
LMQ_LOG(info, "BYE command received; disconnecting from ", conn);
proxy_disconnect(conn, 1s);
} else {
LMQ_LOG(warn, "Got invalid 'BYE' command on an incoming socket; ignoring");
}
return true;
}
else if (cmd == "FORBIDDEN" || cmd == "NOT_A_SERVICE_NODE") {
return true; // FIXME - ignore these? Log?
}
return false;
}
void LokiMQ::proxy_process_queue() {
// First up: process any batch jobs; since these are internal they are given higher priority.
proxy_run_batch_jobs(batch_jobs, batch_jobs_reserved, batch_jobs_active, false);
// Next any reply batch jobs (which are a bit different from the above, since they are
// externally triggered but for things we initiated locally).
proxy_run_batch_jobs(reply_jobs, reply_jobs_reserved, reply_jobs_active, true);
// Finally general incoming commands
for (auto it = pending_commands.begin(); it != pending_commands.end() && active_workers() < max_workers; ) {
auto& pending = *it;
if (pending.cat.active_threads < pending.cat.reserved_threads
|| active_workers() < general_workers) {
proxy_run_worker(get_idle_worker().load(std::move(pending)));
pending.cat.queued--;
pending.cat.active_threads++;
assert(pending.cat.queued >= 0);
it = pending_commands.erase(it);
} else {
++it; // no available general or reserved worker spots for this job right now
}
}
}
}

247
lokimq/string_view.h Normal file
View file

@ -0,0 +1,247 @@
// Copyright (c) 2019-2020, The Loki Project
//
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without modification, are
// permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this list of
// conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice, this list
// of conditions and the following disclaimer in the documentation and/or other
// materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its contributors may be
// used to endorse or promote products derived from this software without specific
// prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL
// THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF
// THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include <string>
#ifdef __cpp_lib_string_view
#include <string_view>
namespace lokimq { using string_view = std::string_view; }
#else
#include <ostream>
#include <limits>
namespace lokimq {
/// Basic implementation of std::string_view (except for std::hash support).
class simple_string_view {
const char *data_;
size_t size_;
public:
using traits_type = std::char_traits<char>;
using value_type = char;
using pointer = char*;
using const_pointer = const char*;
using reference = char&;
using const_reference = const char&;
using const_iterator = const_pointer;
using iterator = const_iterator;
using const_reverse_iterator = std::reverse_iterator<const_iterator>;
using reverse_iterator = const_reverse_iterator;
using size_type = std::size_t;
using different_type = std::ptrdiff_t;
static constexpr auto& npos = std::string::npos;
constexpr simple_string_view() noexcept : data_{nullptr}, size_{0} {}
constexpr simple_string_view(const simple_string_view&) noexcept = default;
simple_string_view(const std::string& str) : data_{str.data()}, size_{str.size()} {}
constexpr simple_string_view(const char* data, size_t size) noexcept : data_{data}, size_{size} {}
simple_string_view(const char* data) : data_{data}, size_{traits_type::length(data)} {}
simple_string_view& operator=(const simple_string_view&) = default;
constexpr const char* data() const noexcept { return data_; }
constexpr size_t size() const noexcept { return size_; }
constexpr size_t length() const noexcept { return size_; }
constexpr size_t max_size() const noexcept { return std::numeric_limits<size_t>::max(); }
constexpr bool empty() const noexcept { return size_ == 0; }
explicit operator std::string() const { return {data_, size_}; }
constexpr const char* begin() const noexcept { return data_; }
constexpr const char* cbegin() const noexcept { return data_; }
constexpr const char* end() const noexcept { return data_ + size_; }
constexpr const char* cend() const noexcept { return data_ + size_; }
reverse_iterator rbegin() const { return reverse_iterator{end()}; }
reverse_iterator crbegin() const { return reverse_iterator{end()}; }
reverse_iterator rend() const { return reverse_iterator{begin()}; }
reverse_iterator crend() const { return reverse_iterator{begin()}; }
constexpr const char& operator[](size_t pos) const { return data_[pos]; }
constexpr const char& front() const { return *data_; }
constexpr const char& back() const { return data_[size_ - 1]; }
int compare(simple_string_view s) const;
constexpr void remove_prefix(size_t n) { data_ += n; size_ -= n; }
constexpr void remove_suffix(size_t n) { size_ -= n; }
void swap(simple_string_view &s) noexcept { std::swap(data_, s.data_); std::swap(size_, s.size_); }
#if defined(__clang__) || !defined(__GNUG__) || __GNUC__ >= 6
constexpr // GCC 5.x is buggy wrt constexpr throwing
#endif
const char& at(size_t pos) const {
if (pos >= size())
throw std::out_of_range{"invalid string_view index"};
return data_[pos];
};
size_t copy(char* dest, size_t count, size_t pos = 0) const {
if (pos > size()) throw std::out_of_range{"invalid copy pos"};
size_t rcount = std::min(count, size_ - pos);
traits_type::copy(dest, data_ + pos, rcount);
return rcount;
}
#if defined(__clang__) || !defined(__GNUG__) || __GNUC__ >= 6
constexpr // GCC 5.x is buggy wrt constexpr throwing
#endif
simple_string_view substr(size_t pos = 0, size_t count = npos) const {
if (pos > size()) throw std::out_of_range{"invalid substr range"};
simple_string_view result = *this;
if (pos > 0) result.remove_prefix(pos);
if (count < result.size()) result.remove_suffix(result.size() - count);
return result;
}
size_t find(simple_string_view v, size_t pos = 0) const {
if (pos > size_ || v.size_ > size_) return npos;
for (const size_t max_pos = size_ - v.size_; pos <= max_pos; ++pos) {
if (0 == traits_type::compare(v.data_, data_ + pos, v.size_))
return pos;
}
return npos;
}
size_t find(char c, size_t pos = 0) const { return find({&c, 1}, pos); }
size_t find(const char* c, size_t pos, size_t count) const { return find({c, count}, pos); }
size_t find(const char* c, size_t pos = 0) const { return find(simple_string_view(c), pos); }
size_t rfind(simple_string_view v, size_t pos = npos) const {
if (v.size_ > size_) return npos;
const size_t max_pos = size_ - v.size_;
for (pos = std::min(pos, max_pos); pos <= max_pos; --pos) {
if (0 == traits_type::compare(v.data_, data_ + pos, v.size_))
return pos;
}
return npos;
}
size_t rfind(char c, size_t pos = npos) const { return rfind({&c, 1}, pos); }
size_t rfind(const char* c, size_t pos, size_t count) const { return rfind({c, count}, pos); }
size_t rfind(const char* c, size_t pos = npos) const { return rfind(simple_string_view(c), pos); }
constexpr size_t find_first_of(simple_string_view v, size_t pos = 0) const noexcept {
for (; pos < size_; ++pos)
for (char c : v)
if (data_[pos] == c)
return pos;
return npos;
}
constexpr size_t find_first_of(char c, size_t pos = 0) const noexcept { return find_first_of({&c, 1}, pos); }
constexpr size_t find_first_of(const char* c, size_t pos, size_t count) const { return find_first_of({c, count}, pos); }
size_t find_first_of(const char* c, size_t pos = 0) const { return find_first_of(simple_string_view(c), pos); }
constexpr size_t find_last_of(simple_string_view v, const size_t pos = npos) const noexcept {
if (size_ == 0) return npos;
const size_t last_pos = std::min(pos, size_-1);
for (size_t i = last_pos; i <= last_pos; --i)
for (char c : v)
if (data_[i] == c)
return i;
return npos;
}
constexpr size_t find_last_of(char c, size_t pos = npos) const noexcept { return find_last_of({&c, 1}, pos); }
constexpr size_t find_last_of(const char* c, size_t pos, size_t count) const { return find_last_of({c, count}, pos); }
size_t find_last_of(const char* c, size_t pos = npos) const { return find_last_of(simple_string_view(c), pos); }
constexpr size_t find_first_not_of(simple_string_view v, size_t pos = 0) const noexcept {
for (; pos < size_; ++pos) {
bool none = true;
for (char c : v) {
if (data_[pos] == c) {
none = false;
break;
}
}
if (none) return pos;
}
return npos;
}
constexpr size_t find_first_not_of(char c, size_t pos = 0) const noexcept { return find_first_not_of({&c, 1}, pos); }
constexpr size_t find_first_not_of(const char* c, size_t pos, size_t count) const { return find_first_not_of({c, count}, pos); }
size_t find_first_not_of(const char* c, size_t pos = 0) const { return find_first_not_of(simple_string_view(c), pos); }
constexpr size_t find_last_not_of(simple_string_view v, const size_t pos = npos) const noexcept {
if (size_ == 0) return npos;
const size_t last_pos = std::min(pos, size_-1);
for (size_t i = last_pos; i <= last_pos; --i) {
bool none = true;
for (char c : v) {
if (data_[i] == c) {
none = false;
break;
}
}
if (none) return i;
}
return npos;
}
constexpr size_t find_last_not_of(char c, size_t pos = npos) const noexcept { return find_last_not_of({&c, 1}, pos); }
constexpr size_t find_last_not_of(const char* c, size_t pos, size_t count) const { return find_last_not_of({c, count}, pos); }
size_t find_last_not_of(const char* c, size_t pos = npos) const { return find_last_not_of(simple_string_view(c), pos); }
};
inline bool operator==(simple_string_view lhs, simple_string_view rhs) {
return lhs.size() == rhs.size() && 0 == std::char_traits<char>::compare(lhs.data(), rhs.data(), lhs.size());
};
inline bool operator!=(simple_string_view lhs, simple_string_view rhs) {
return !(lhs == rhs);
}
inline int simple_string_view::compare(simple_string_view s) const {
int cmp = std::char_traits<char>::compare(data_, s.data(), std::min(size_, s.size()));
if (cmp) return cmp;
if (size_ < s.size()) return -1;
else if (size_ > s.size()) return 1;
return 0;
}
inline bool operator<(simple_string_view lhs, simple_string_view rhs) {
return lhs.compare(rhs) < 0;
};
inline bool operator<=(simple_string_view lhs, simple_string_view rhs) {
return lhs.compare(rhs) <= 0;
};
inline bool operator>(simple_string_view lhs, simple_string_view rhs) {
return lhs.compare(rhs) > 0;
};
inline bool operator>=(simple_string_view lhs, simple_string_view rhs) {
return lhs.compare(rhs) >= 0;
};
inline std::ostream& operator<<(std::ostream& os, const simple_string_view& s) {
os.write(s.data(), s.size());
return os;
}
using string_view = simple_string_view;
}
#endif
// Add a "foo"_sv literal that works exactly like the C++17 "foo"sv literal, but works with out
// implementation in pre-C++17.
namespace lokimq {
inline namespace literals {
inline string_view operator""_sv(const char* str, size_t len) { return {str, len}; }
}
}

5
lokimq/version.h.in Normal file
View file

@ -0,0 +1,5 @@
namespace lokimq {
constexpr int VERSION_MAJOR = @LOKIMQ_VERSION_MAJOR@;
constexpr int VERSION_MINOR = @LOKIMQ_VERSION_MINOR@;
constexpr int VERSION_PATCH = @LOKIMQ_VERSION_PATCH@;
}

292
lokimq/worker.cpp Normal file
View file

@ -0,0 +1,292 @@
#include "lokimq.h"
#include "batch.h"
#include "lokimq-internal.h"
namespace lokimq {
void LokiMQ::worker_thread(unsigned int index) {
std::string worker_id = "w" + std::to_string(index);
zmq::socket_t sock{context, zmq::socket_type::dealer};
sock.setsockopt(ZMQ_ROUTING_ID, worker_id.data(), worker_id.size());
LMQ_LOG(debug, "New worker thread ", worker_id, " started");
sock.connect(SN_ADDR_WORKERS);
Message message{*this, 0};
std::vector<zmq::message_t> parts;
run_info& run = workers[index]; // This contains our first job, and will be updated later with subsequent jobs
while (true) {
try {
if (run.is_batch_job) {
if (run.batch_jobno >= 0) {
LMQ_TRACE("worker thread ", worker_id, " running batch ", run.batch, "#", run.batch_jobno);
run.batch->run_job(run.batch_jobno);
} else if (run.batch_jobno == -1) {
LMQ_TRACE("worker thread ", worker_id, " running batch ", run.batch, " completion");
run.batch->job_completion();
}
} else {
message.conn = run.conn;
message.data.clear();
LMQ_TRACE("Got incoming command from ", message.conn, message.conn.route.empty() ? "(outgoing)" : "(incoming)");
if (run.callback->second /*is_request*/) {
message.reply_tag = {run.data_parts[0].data<char>(), run.data_parts[0].size()};
for (auto it = run.data_parts.begin() + 1; it != run.data_parts.end(); ++it)
message.data.emplace_back(it->data<char>(), it->size());
} else {
for (auto& m : run.data_parts)
message.data.emplace_back(m.data<char>(), m.size());
}
LMQ_TRACE("worker thread ", worker_id, " invoking ", run.command, " callback with ", message.data.size(), " message parts");
run.callback->first(message);
}
}
catch (const bt_deserialize_invalid& e) {
LMQ_LOG(warn, worker_id, " deserialization failed: ", e.what(), "; ignoring request");
}
catch (const mapbox::util::bad_variant_access& e) {
LMQ_LOG(warn, worker_id, " deserialization failed: found unexpected serialized type (", e.what(), "); ignoring request");
}
catch (const std::out_of_range& e) {
LMQ_LOG(warn, worker_id, " deserialization failed: invalid data - required field missing (", e.what(), "); ignoring request");
}
catch (const std::exception& e) {
LMQ_LOG(warn, worker_id, " caught exception when processing command: ", e.what());
}
catch (...) {
LMQ_LOG(warn, worker_id, " caught non-standard exception when processing command");
}
while (true) {
// Signal that we are ready for another job and wait for it. (We do this down here
// because our first job gets set up when the thread is started).
detail::send_control(sock, "RAN");
LMQ_TRACE("worker ", worker_id, " waiting for requests");
parts.clear();
recv_message_parts(sock, parts);
if (parts.size() != 1) {
LMQ_LOG(error, "Internal error: worker ", worker_id, " received invalid ", parts.size(), "-part worker instruction");
continue;
}
auto command = view(parts[0]);
if (command == "RUN") {
LMQ_LOG(debug, "worker ", worker_id, " running command ", run.command);
break; // proxy has set up a command for us, go back and run it.
} else if (command == "QUIT") {
LMQ_LOG(debug, "worker ", worker_id, " shutting down");
detail::send_control(sock, "QUITTING");
sock.setsockopt<int>(ZMQ_LINGER, 1000);
sock.close();
return;
} else {
LMQ_LOG(error, "Internal error: worker ", worker_id, " received invalid command: `", command, "'");
}
}
}
}
LokiMQ::run_info& LokiMQ::get_idle_worker() {
if (idle_workers.empty()) {
size_t id = workers.size();
assert(workers.capacity() > id);
workers.emplace_back();
auto& r = workers.back();
r.worker_id = id;
r.worker_routing_id = "w" + std::to_string(id);
return r;
}
size_t id = idle_workers.back();
idle_workers.pop_back();
return workers[id];
}
void LokiMQ::proxy_worker_message(std::vector<zmq::message_t>& parts) {
// Process messages sent by workers
if (parts.size() != 2) {
LMQ_LOG(error, "Received send invalid ", parts.size(), "-part message");
return;
}
auto route = view(parts[0]), cmd = view(parts[1]);
LMQ_TRACE("worker message from ", route);
assert(route.size() >= 2 && route[0] == 'w' && route[1] >= '0' && route[1] <= '9');
string_view worker_id_str{&route[1], route.size()-1}; // Chop off the leading "w"
unsigned int worker_id = detail::extract_unsigned(worker_id_str);
if (!worker_id_str.empty() /* didn't consume everything */ || worker_id >= workers.size()) {
LMQ_LOG(error, "Worker id '", route, "' is invalid, unable to process worker command");
return;
}
auto& run = workers[worker_id];
LMQ_TRACE("received ", cmd, " command from ", route);
if (cmd == "RAN") {
LMQ_LOG(debug, "Worker ", route, " finished ", run.command);
if (run.is_batch_job) {
auto& jobs = run.is_reply_job ? reply_jobs : batch_jobs;
auto& active = run.is_reply_job ? reply_jobs_active : batch_jobs_active;
assert(active > 0);
active--;
bool clear_job = false;
if (run.batch_jobno == -1) {
// Returned from the completion function
clear_job = true;
} else {
auto status = run.batch->job_finished();
if (status == detail::BatchStatus::complete) {
jobs.emplace(run.batch, -1);
} else if (status == detail::BatchStatus::complete_proxy) {
try {
run.batch->job_completion(); // RUN DIRECTLY IN PROXY THREAD
} catch (const std::exception &e) {
// Raise these to error levels: the caller really shouldn't be doing
// anything non-trivial in an in-proxy completion function!
LMQ_LOG(error, "proxy thread caught exception when processing in-proxy completion command: ", e.what());
} catch (...) {
LMQ_LOG(error, "proxy thread caught non-standard exception when processing in-proxy completion command");
}
clear_job = true;
} else if (status == detail::BatchStatus::done) {
clear_job = true;
}
}
if (clear_job) {
batches.erase(run.batch);
delete run.batch;
run.batch = nullptr;
}
} else {
assert(run.cat->active_threads > 0);
run.cat->active_threads--;
}
if (max_workers == 0) { // Shutting down
LMQ_TRACE("Telling worker ", route, " to quit");
route_control(workers_socket, route, "QUIT");
} else {
idle_workers.push_back(worker_id);
}
} else if (cmd == "QUITTING") {
workers[worker_id].worker_thread.join();
LMQ_LOG(debug, "Worker ", route, " exited normally");
} else {
LMQ_LOG(error, "Worker ", route, " sent unknown control message: `", cmd, "'");
}
}
void LokiMQ::proxy_run_worker(run_info& run) {
if (!run.worker_thread.joinable())
run.worker_thread = std::thread{&LokiMQ::worker_thread, this, run.worker_id};
else
send_routed_message(workers_socket, run.worker_routing_id, "RUN");
}
void LokiMQ::proxy_to_worker(size_t conn_index, std::vector<zmq::message_t>& parts) {
bool outgoing = connections[conn_index].getsockopt<int>(ZMQ_TYPE) == ZMQ_DEALER;
peer_info tmp_peer;
tmp_peer.conn_index = conn_index;
if (!outgoing) tmp_peer.route = parts[0].to_string();
peer_info* peer = nullptr;
if (outgoing) {
auto it = peers.find(conn_index_to_id[conn_index]);
if (it == peers.end()) {
LMQ_LOG(warn, "Internal error: connection index ", conn_index, " not found");
return;
}
peer = &it->second;
} else {
std::tie(tmp_peer.pubkey, tmp_peer.service_node, tmp_peer.auth_level) = detail::extract_metadata(parts.back());
if (tmp_peer.service_node) {
// It's a service node so we should have a peer_info entry; see if we can find one with
// the same route, and if not, add one.
auto pr = peers.equal_range(tmp_peer.pubkey);
for (auto it = pr.first; it != pr.second; ++it) {
if (it->second.route == tmp_peer.route) {
peer = &it->second;
// Upgrade permissions in case we have something higher on the socket
peer->service_node |= tmp_peer.service_node;
if (tmp_peer.auth_level > peer->auth_level)
peer->auth_level = tmp_peer.auth_level;
break;
}
}
if (!peer) {
peer = &peers.emplace(ConnectionID{tmp_peer.pubkey}, std::move(tmp_peer))->second;
}
} else {
// Incoming, non-SN connection: we don't store a peer_info for this, so just use the
// temporary one
peer = &tmp_peer;
}
}
size_t command_part_index = outgoing ? 0 : 1;
std::string command = parts[command_part_index].to_string();
auto cat_call = get_command(command);
if (!cat_call.first) {
if (outgoing)
send_direct_message(connections[conn_index], "UNKNOWNCOMMAND", command);
else
send_routed_message(connections[conn_index], peer->route, "UNKNOWNCOMMAND", command);
return;
}
auto& category = *cat_call.first;
if (!proxy_check_auth(conn_index, outgoing, *peer, command, category, parts.back()))
return;
// Steal any data message parts
size_t data_part_index = command_part_index + 1;
std::vector<zmq::message_t> data_parts;
data_parts.reserve(parts.size() - data_part_index);
for (auto it = parts.begin() + data_part_index; it != parts.end(); ++it)
data_parts.push_back(std::move(*it));
if (category.active_threads >= category.reserved_threads && active_workers() >= general_workers) {
// No free reserved or general spots, try to queue it for later
if (category.max_queue >= 0 && category.queued >= category.max_queue) {
LMQ_LOG(warn, "No space to queue incoming command ", command, "; already have ", category.queued,
"commands queued in that category (max ", category.max_queue, "); dropping message");
return;
}
LMQ_LOG(debug, "No available free workers, queuing ", command, " for later");
ConnectionID conn{peer->service_node ? ConnectionID::SN_ID : conn_index_to_id[conn_index].id, peer->pubkey, std::move(tmp_peer.route)};
pending_commands.emplace_back(category, std::move(command), std::move(data_parts), cat_call.second, std::move(conn));
category.queued++;
return;
}
if (cat_call.second->second /*is_request*/ && data_parts.empty()) {
LMQ_LOG(warn, "Received an invalid request command with no reply tag; dropping message");
return;
}
auto& run = get_idle_worker();
{
ConnectionID c{peer->service_node ? ConnectionID::SN_ID : conn_index_to_id[conn_index].id, peer->pubkey};
c.route = std::move(tmp_peer.route);
if (outgoing || peer->service_node)
tmp_peer.route.clear();
run.load(&category, std::move(command), std::move(c), std::move(data_parts), cat_call.second);
}
if (outgoing)
peer->activity(); // outgoing connection activity, pump the activity timer
LMQ_TRACE("Forwarding incoming ", run.command, " from ", run.conn, " @ ", peer_address(parts[command_part_index]),
" to worker ", run.worker_routing_id);
proxy_run_worker(run);
category.active_threads++;
}
}

1
mapbox-variant Submodule

@ -0,0 +1 @@
Subproject commit c94634bbd294204c9ba3f5b267a39582a52e8e5a

@ -1 +0,0 @@
Subproject commit d6f300d7d250ae0a9708090c0011c0f495377e6a

View file

@ -1,351 +0,0 @@
#include "address.h"
#include <tuple>
#include <limits>
#include <cstddef>
#include <utility>
#include <stdexcept>
#include <ostream>
#include <oxenc/hex.h>
#include <oxenc/base32z.h>
#include <oxenc/base64.h>
namespace oxenmq {
constexpr size_t enc_length(address::encoding enc) {
return enc == address::encoding::hex ? 64 :
enc == address::encoding::base64 ? 43 : // this can be 44 with a padding byte, but we don't need it
52 /*base32z*/;
};
// Parses an encoding pubkey from the given string_view. Advanced the string_view beyond the
// consumed pubkey data, and returns the pubkey (as a 32-byte string). Throws if no valid pubkey
// was found at the beginning of addr. We look for hex, base32z, or base64 pubkeys *unless* qr is
// given: for QR-friendly we only accept hex or base32z (since QR cannot handle base64's alphabet).
std::string decode_pubkey(std::string_view& in, bool qr) {
std::string pubkey;
if (in.size() >= 64 && oxenc::is_hex(in.substr(0, 64))) {
pubkey = oxenc::from_hex(in.substr(0, 64));
in.remove_prefix(64);
} else if (in.size() >= 52 && oxenc::is_base32z(in.substr(0, 52))) {
pubkey = oxenc::from_base32z(in.substr(0, 52));
in.remove_prefix(52);
} else if (!qr && in.size() >= 43 && oxenc::is_base64(in.substr(0, 43))) {
pubkey = oxenc::from_base64(in.substr(0, 43));
in.remove_prefix(43);
if (!in.empty() && in.front() == '=')
in.remove_prefix(1); // allow (and eat) a padding byte at the end
} else {
throw std::invalid_argument{"No pubkey found"};
}
return pubkey;
}
// Parse the host, port, and optionally pubkey from a string view, mutating it to remove the parsed
// sections. qr should be true if we should accept $IPv6$ as a QR-encoding-friendly alternative to
// [IPv6] (the returned host will have the $ replaced, i.e. [IPv6]).
std::tuple<std::string, uint16_t, std::string> parse_tcp(std::string_view& addr, bool qr, bool expect_pubkey) {
std::tuple<std::string, uint16_t, std::string> result;
auto& host = std::get<0>(result);
if (addr.front() == '[' || (qr && addr.front() == '$')) { // IPv6 addr (though this is far from complete validation)
auto pos = addr.find_first_not_of(":.1234567890abcdefABCDEF", 1);
if (pos == std::string_view::npos)
throw std::invalid_argument("Could not find terminating ] while parsing an IPv6 address");
if (!(addr[pos] == ']' || (qr && addr[pos] == '$')))
throw std::invalid_argument{"Expected " + (qr ? "$"s : "]"s) + " to close IPv6 address but found " + std::string(1, addr[pos])};
host = std::string{addr.substr(0, pos+1)};
if (qr) {
if (host.front() == '$')
host.front() = '[';
if (host.back() == '$')
host.back() = ']';
}
addr.remove_prefix(pos+1);
} else {
auto port_pos = addr.find(':');
if (port_pos == std::string_view::npos)
throw std::invalid_argument{"Could not determine host (no following ':port' found)"};
if (port_pos == 0)
throw std::invalid_argument{"Host cannot be empty"};
host = std::string{addr.substr(0, port_pos)};
addr.remove_prefix(port_pos);
}
if (qr)
// Lower-case the host because upper case hostnames are ugly
for (char& c : host)
if (c >= 'A' && c <= 'Z')
c = c - 'A' + 'a';
if (addr.size() < 2 || addr[0] != ':')
throw std::invalid_argument{"Could not find :port in address string"};
addr.remove_prefix(1);
auto pos = addr.find_first_not_of("1234567890");
if (pos == 0)
throw std::invalid_argument{"Could not find numeric port in address string"};
if (pos == std::string_view::npos)
pos = addr.size();
size_t processed;
int port_int = std::stoi(std::string{addr.substr(0, pos)}, &processed);
if (port_int == 0 || processed != pos)
throw std::invalid_argument{"Could not parse numeric port in address string"};
if (port_int < 0 || port_int > std::numeric_limits<uint16_t>::max())
throw std::invalid_argument{"Invalid port: port must be in range 1-65535"};
std::get<1>(result) = static_cast<uint16_t>(port_int);
addr.remove_prefix(pos);
if (expect_pubkey) {
if (addr.size() < 1 + enc_length(qr ? address::encoding::base32z : address::encoding::base64)
|| addr.front() != '/')
throw std::invalid_argument{"Invalid address: expected /PUBKEY after port"};
addr.remove_prefix(1);
std::get<2>(result) = decode_pubkey(addr, qr);
if (!addr.empty())
throw std::invalid_argument{"Invalid address: found unexpected trailing data after pubkey"};
} else if (!addr.empty()) {
throw std::invalid_argument{"Invalid address: found unexpected trailing data after port"};
}
return result;
}
// Parse the socket path and (possibly) pubkey, mutating it to remove the parsed sections.
// Currently the /pubkey *must* be at the end of the string, but this might not always be the case
// (e.g. we could in the future support query string-like arguments).
std::pair<std::string, std::string> parse_unix(std::string_view& addr, bool expect_pubkey) {
std::pair<std::string, std::string> result;
if (expect_pubkey) {
size_t b64_len = addr.size() > 0 && addr.back() == '=' ? 44 : 43;
if (addr.size() > 64 && addr[addr.size() - 65] == '/' && oxenc::is_hex(addr.substr(addr.size() - 64))) {
result.first = std::string{addr.substr(0, addr.size() - 65)};
result.second = oxenc::from_hex(addr.substr(addr.size() - 64));
} else if (addr.size() > 52 && addr[addr.size() - 53] == '/' && oxenc::is_base32z(addr.substr(addr.size() - 52))) {
result.first = std::string{addr.substr(0, addr.size() - 53)};
result.second = oxenc::from_base32z(addr.substr(addr.size() - 52));
} else if (addr.size() > b64_len && addr[addr.size() - b64_len - 1] == '/' && oxenc::is_base64(addr.substr(addr.size() - b64_len))) {
result.first = std::string{addr.substr(0, addr.size() - b64_len - 1)};
result.second = oxenc::from_base64(addr.substr(addr.size() - b64_len));
} else {
throw std::invalid_argument{"icp+curve:// requires a trailing /PUBKEY value, got: " + std::string{addr}};
}
} else {
// Anything goes
result.first = std::string{addr};
}
// Any path above consumes everything:
addr.remove_prefix(addr.size());
return result;
}
address::address(std::string_view addr) {
auto protoend = addr.find("://"sv);
if (protoend == std::string_view::npos || protoend == 0)
throw std::invalid_argument("Invalid address: no protocol found");
auto pro = addr.substr(0, protoend);
addr.remove_prefix(protoend + 3);
if (addr.empty())
throw std::invalid_argument("Invalid address: no value specified after protocol");
bool qr = false;
if (pro == "tcp") protocol = proto::tcp;
else if (pro == "tcp+curve" || pro == "curve") protocol = proto::tcp_curve;
else if (pro == "ipc") protocol = proto::ipc;
else if (pro == "ipc+curve") protocol = proto::ipc_curve;
else if (pro == "TCP") {
protocol = proto::tcp;
qr = true;
} else if (pro == "CURVE") {
protocol = proto::tcp_curve;
qr = true;
} else {
throw std::invalid_argument("Invalid protocol '" + std::string{pro} + "'");
}
if (qr) {
// The QR variations only allow QR-alphanumeric characters (upper-case letters, numbers, and
// a few symbols):
for (char c : addr) {
// QR alphanumeric also allows space, %, *, +, but we don't need or allow any of those here.
if (!((c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '$' || c == ':' || c == '/' || c == '.' || c == '-'))
throw std::invalid_argument("Found non-QR-alphanumeric value in QR TCP:// or CURVE:// address");
}
}
if (tcp())
std::tie(host, port, pubkey) = parse_tcp(addr, qr, curve());
else
std::tie(socket, pubkey) = parse_unix(addr, curve());
if (!addr.empty())
throw std::invalid_argument{"Invalid trailing garbage '" + std::string{addr} + "' in address"};
}
address& address::set_pubkey(std::string_view pk) {
if (pk.size() == 0) {
if (protocol == proto::tcp_curve) protocol = proto::tcp;
else if (protocol == proto::ipc_curve) protocol = proto::ipc;
} else if (pk.size() == 32) {
if (protocol == proto::tcp) protocol = proto::tcp_curve;
else if (protocol == proto::ipc) protocol = proto::ipc_curve;
} else {
throw std::invalid_argument{"Invalid pubkey passed to set_pubkey(): require 0- or 32-byte pubkey"};
}
pubkey = pk;
return *this;
}
std::string address::encode_pubkey(encoding enc) const {
std::string pk;
if (enc == encoding::hex)
pk = oxenc::to_hex(pubkey);
else if (enc == encoding::base32z)
pk = oxenc::to_base32z(pubkey);
else if (enc == encoding::BASE32Z) {
pk = oxenc::to_base32z(pubkey);
for (char& c : pk)
if (c >= 'a' && c <= 'z')
c = c - 'a' + 'A';
} else if (enc == encoding::base64) {
pk = oxenc::to_base64(pubkey);
if (pk.size() == 44 && pk.back() == '=')
pk.resize(43);
} else {
throw std::logic_error{"Invalid encoding"};
}
return pk;
}
std::string address::full_address(encoding enc) const {
std::string result;
std::string pk;
if (curve())
pk = encode_pubkey(enc);
if (protocol == proto::tcp) {
result.reserve(6 /*tcp:// */ + host.size() + 6 /*:port*/);
result += "tcp://";
result += host;
result += ':';
result += std::to_string(port);
} else if (protocol == proto::tcp_curve) {
result.reserve(8 /*curve:// */ + host.size() + 6 /*:port*/ + 1 /* / */ + pk.size());
result += "curve://";
result += host;
result += ':';
result += std::to_string(port);
result += '/';
result += pk;
} else if (protocol == proto::ipc) {
result.reserve(6 /*ipc:// */ + socket.size());
result += "ipc://";
result += socket;
} else if (protocol == proto::ipc_curve) {
result.reserve(12 /*ipc+curve:// */ + socket.size() + 1 /* / */ + pk.size());
result += "ipc+curve://";
result += socket;
result += '/';
result += pk;
} else {
throw std::logic_error{"Invalid protocol"};
}
return result;
}
std::string address::zmq_address() const {
std::string result;
if (tcp()) {
result.reserve(6 /*tcp:// */ + host.size() + 6 /*:port*/);
result += "tcp://";
result += host;
result += ':';
result += std::to_string(port);
} else {
result.reserve(6 /*ipc:// */ + socket.size());
result += "ipc://";
result += socket;
}
return result;
}
std::string address::qr_address() const {
if (protocol != proto::tcp && protocol != proto::tcp_curve)
throw std::logic_error("Cannot construct a QR-friendly address for a non-TCP address");
if (host.empty())
throw std::logic_error("Cannot construct a QR-friendly address with an empty TCP host");
std::string result;
result.reserve((curve() ? 8 /*CURVE:// */ : 6 /*TCP:// */) + host.size() + 6 /*:port*/ +
(curve() ? 1 + enc_length(encoding::BASE32Z) : 0));
result += curve() ? "CURVE://" : "TCP://";
std::string uc_host = host;
for (auto& c : uc_host)
if (c >= 'a' && c <= 'z')
c = c - 'a' + 'A';
if (uc_host.front() == '[' && uc_host.back() == ']') {
uc_host.front() = '$';
uc_host.back() = '$';
}
result += uc_host;
result += ':';
result += std::to_string(port);
if (curve()) {
result += '/';
result += encode_pubkey(encoding::BASE32Z);
}
return result;
}
bool address::operator==(const address& other) const {
if (protocol != other.protocol)
return false;
if (tcp())
if (host != other.host || port != other.port)
return false;
if (ipc())
if (socket != other.socket)
return false;
if (curve())
if (pubkey != other.pubkey)
return false;
return true;
}
address address::tcp(std::string host, uint16_t port) {
address a;
a.protocol = proto::tcp;
a.host = std::move(host);
a.port = port;
return a;
}
address address::tcp_curve(std::string host, uint16_t port, std::string pubkey) {
address a;
a.protocol = proto::tcp_curve;
a.host = std::move(host);
a.port = port;
a.pubkey = std::move(pubkey);
return a;
}
address address::ipc(std::string path) {
address a;
a.protocol = proto::ipc;
a.socket = std::move(path);
return a;
}
address address::ipc_curve(std::string path, std::string pubkey) {
address a;
a.protocol = proto::ipc_curve;
a.socket = std::move(path);
a.pubkey = std::move(pubkey);
return a;
}
std::ostream& operator<<(std::ostream& o, const address& a) { return o << a.full_address(); }
}

View file

@ -1,218 +0,0 @@
// Copyright (c) 2020-2021, The Oxen Project
//
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without modification, are
// permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this list of
// conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice, this list
// of conditions and the following disclaimer in the documentation and/or other
// materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its contributors may be
// used to endorse or promote products derived from this software without specific
// prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL
// THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF
// THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include <string>
#include <string_view>
#include <cstdint>
#include <iosfwd>
#include <functional>
namespace oxenmq {
using namespace std::literals;
/** OxenMQ address abstraction class. This class uses and extends standard ZMQ addresses allowing
* extra parameters to be passed in in a relative standard way.
*
* External ZMQ addresses generally have two forms that we are concerned with: one for TCP and one
* for Unix sockets:
*
* tcp://HOST:PORT -- HOST can be a hostname, IPv4 address, or IPv6 address in [...]
* ipc://PATH -- PATH can be absolute (ipc:///path/to/some.sock) or relative (ipc://some.sock)
*
* but this doesn't carry enough info: in particular, we can connect with two very different
* protocols: curve25519-encrypted, or plaintext, but for curve25519-encrypted we require the
* remote's public key as well to verify the connection.
*
* This class, then, handles this by allowing addresses of:
*
* Standard ZMQ address: these carry no pubkey and so the connection will be unencrypted:
*
* tcp://HOSTNAME:PORT
* ipc://PATH
*
* Non-ZMQ address formats that specify that the connection shall be x25519 encrypted:
*
* curve://HOSTNAME:PORT/PUBKEY -- PUBKEY must be specified in hex (64 characters), base32z (52)
* or base64 (43 or 44 with one '=' trailing padding)
* ipc+curve:///path/to/my.sock/PUBKEY -- same requirements on PUBKEY as above.
* tcp+curve://(whatever) -- alias for curve://(whatever)
*
* We also accept special upper-case TCP-only variants which *only* accept uppercase characters and
* a few required symbols (:, /, $, ., and -) in the string:
*
* TCP://HOSTNAME:PORT
* CURVE://HOSTNAME:PORT/B32ZPUBKEY
*
* These versions are explicitly meant to be used with QR codes; the upper-case-only requirement
* allows a smaller QR code by allowing QR's alphanumeric mode (which allows only [A-Z0-9 $%*+./:-])
* to be used. Such a QR-friendly address can be created from the qr_address() method. To support
* literal IPv6 addresses we surround the address with $...$ instead of the usual [...].
*
* Note that this class does very little validate the host argument at all, and no socket path
* validation whatsoever. The only constraint on host is when parsing an encoded address: we check
* that it contains no : at all, or must be a [bracketed] expression that contains only hex
* characters, :'s, or .'s. Otherwise, if you pass broken crap into the hostname, expect broken
* crap out.
*/
struct address {
/// Supported address protocols: TCP connections (tcp), or unix sockets (ipc).
enum class proto {
tcp,
tcp_curve,
ipc,
ipc_curve
};
/// Supported public key encodings (used when regenerating an augmented address).
enum class encoding {
hex, ///< hexadecimal encoded
base32z, ///< base32z encoded
base64, ///< base64 encoded (*without* trailing = padding)
BASE32Z ///< upper-case base32z encoding, meant for QR encoding
};
/// The protocol: one of the `protocol` enum values for tcp or ipc (unix sockets), with or
/// without _curve encryption.
proto protocol = proto::tcp;
/// The host for tcp connections; can be a hostname or IP address. If this is an IPv6 it must be surrounded with [ ].
std::string host;
/// The port (for tcp connections)
uint16_t port = 0;
/// The socket path (for unix socket connections)
std::string socket;
/// If a curve connection, this is the required remote public key (in bytes)
std::string pubkey;
/// Default constructor; this gives you an unusable address.
address() = default;
/**
* Constructs an address by parsing a string_view containing one of the formats listed in the
* class description. This is intentionally implicitly constructible so that you can pass a
* string_view into anything expecting an `address`.
*
* Throw std::invalid_argument if the given address is not parseable.
*/
address(std::string_view addr);
/** Constructs an address from a remote string and a separate pubkey. Typically `remote` is a
* basic ZMQ connect string, though this is not enforced. Any pubkey information embedded in
* the remote string will be discarded and replaced with the given pubkey string. The result
* will be curve encrypted if `pubkey` is non-empty, plaintext if `pubkey` is empty.
*
* Throws an exception if either addr or pubkey is invalid.
*
* Exactly equivalent to `address a{remote}; a.set_pubkey(pubkey);`
*/
address(std::string_view addr, std::string_view pubkey) : address(addr) { set_pubkey(pubkey); }
/// Replaces the address's pubkey (if any) with the given pubkey (or no pubkey if empty). If
/// changing from pubkey to no-pubkey or no-pubkey to pubkey then the protocol is update to
/// switch to or from curve encryption.
///
/// pubkey should be the 32-byte binary pubkey, or an empty string to remove an existing pubkey.
///
/// Returns the object itself, so that you can chain it.
address& set_pubkey(std::string_view pubkey);
/// Constructs and builds the ZMQ connection address from the stored connection details. This
/// does not contain any of the curve-related details; those must be specified separately when
/// interfacing with ZMQ.
std::string zmq_address() const;
/// Returns true if the connection was specified as a curve-encryption-enabled connection, false
/// otherwise.
bool curve() const { return protocol == proto::tcp_curve || protocol == proto::ipc_curve; }
/// True if the protocol is TCP (either with or without curve)
bool tcp() const { return protocol == proto::tcp || protocol == proto::tcp_curve; }
/// True if the protocol is unix socket (either with or without curve)
bool ipc() const { return !tcp(); }
/// Returns the full "augmented" address string (i.e. that could be passed in to the
/// constructor). This will be equivalent (but not necessarily identical) to an augmented
/// string passed into the constructor. Takes an optional encoding format for the pubkey (if
/// any), which defaults to base32z.
std::string full_address(encoding enc = encoding::base32z) const;
/// Returns a QR-code friendly address string. This returns an all-uppercase version of the
/// address with "TCP://" or "CURVE://" for the protocol string, and uses upper-case base32z
/// encoding for the pubkey (for curve addresses). For literal IPv6 addresses we surround the
/// address with $ instead of [...]
///
/// \throws std::logic_error if called on a unix socket address.
std::string qr_address() const;
/// Returns `.pubkey` but encoded in the given format
std::string encode_pubkey(encoding enc) const;
/// Returns true if two addresses are identical (i.e. same protocol and relevant protocol
/// arguments).
///
/// Note that it is possible for addresses to connect to the same socket without being
/// identical: for example, using "foo.sock" and "./foo.sock", or writing IPv6 addresses (or
/// even IPv4 addresses) in slightly different ways). Such equivalent but non-equal values will
/// result in a false return here.
///
/// Note also that we ignore irrelevant arguments: for example, we don't care whether pubkeys
/// match when comparing two non-curve TCP addresses.
bool operator==(const address& other) const;
/// Negation of ==
bool operator!=(const address& other) const { return !operator==(other); }
/// Factory function that constructs a TCP address from a host and port. The connection will be
/// plaintext. If the host is an IPv6 address it *must* be surrounded with [ and ].
static address tcp(std::string host, uint16_t port);
/// Factory function that constructs a curve-encrypted TCP address from a host, port, and remote
/// pubkey. The pubkey must be 32 bytes. As above, IPv6 addresses must be specified as [addr].
static address tcp_curve(std::string host, uint16_t, std::string pubkey);
/// Factory function that constructs a unix socket address from a path. The connection will be
/// plaintext (which is usually fine for a socket since unix sockets are local machine).
static address ipc(std::string path);
/// Factory function that constructs a unix socket address from a path and remote pubkey. The
/// connection will be curve25519 encrypted; the remote pubkey must be 32 bytes.
static address ipc_curve(std::string path, std::string pubkey);
};
// Outputs address.full_address() when sent to an ostream.
std::ostream& operator<<(std::ostream& o, const address& a);
} // namespace oxenmq
namespace std {
template<> struct hash<oxenmq::address> {
std::size_t operator()(const oxenmq::address& a) const noexcept {
return std::hash<std::string>{}(a.full_address(oxenmq::address::encoding::hex));
}
};
} // namespace std

View file

@ -1,318 +0,0 @@
#include "oxenmq.h"
#include <oxenc/hex.h>
#include "oxenmq-internal.h"
#include <ostream>
#include <sstream>
namespace oxenmq {
std::ostream& operator<<(std::ostream& o, AuthLevel a) {
return o << to_string(a);
}
namespace {
// Builds a ZMTP metadata key-value pair. These will be available on every message from that peer.
// Keys must start with X- and be <= 255 characters.
std::string zmtp_metadata(std::string_view key, std::string_view value) {
assert(key.size() > 2 && key.size() <= 255 && key[0] == 'X' && key[1] == '-');
std::string result;
result.reserve(1 + key.size() + 4 + value.size());
result += static_cast<char>(key.size()); // Size octet of key
result.append(&key[0], key.size()); // key data
for (int i = 24; i >= 0; i -= 8) // 4-byte size of value in network order
result += static_cast<char>((value.size() >> i) & 0xff);
result.append(&value[0], value.size()); // value data
return result;
}
}
bool OxenMQ::proxy_check_auth(int64_t conn_id, bool outgoing, const peer_info& peer,
zmq::message_t& cmd, const cat_call_t& cat_call, std::vector<zmq::message_t>& data) {
auto command = view(cmd);
std::string reply;
if (!cat_call.first) {
OMQ_LOG(warn, "Invalid command '", command, "' sent by remote [", oxenc::to_hex(peer.pubkey), "]/", peer_address(cmd));
reply = "UNKNOWNCOMMAND";
} else if (peer.auth_level < cat_call.first->access.auth) {
OMQ_LOG(warn, "Access denied to ", command, " for peer [", oxenc::to_hex(peer.pubkey), "]/", peer_address(cmd),
": peer auth level ", peer.auth_level, " < ", cat_call.first->access.auth);
reply = "FORBIDDEN";
} else if (cat_call.first->access.local_sn && !local_service_node) {
OMQ_LOG(warn, "Access denied to ", command, " for peer [", oxenc::to_hex(peer.pubkey), "]/", peer_address(cmd),
": that command is only available when this OxenMQ is running in service node mode");
reply = "NOT_A_SERVICE_NODE";
} else if (cat_call.first->access.remote_sn && !peer.service_node) {
OMQ_LOG(warn, "Access denied to ", command, " for peer [", oxenc::to_hex(peer.pubkey), "]/", peer_address(cmd),
": remote is not recognized as a service node");
reply = "FORBIDDEN_SN";
} else if (cat_call.second->second /*is_request*/ && data.empty()) {
OMQ_LOG(warn, "Received an invalid request for '", command, "' with no reply tag from remote [",
oxenc::to_hex(peer.pubkey), "]/", peer_address(cmd));
reply = "NO_REPLY_TAG";
} else {
return true;
}
std::vector<zmq::message_t> msgs;
msgs.reserve(4);
if (!outgoing)
msgs.push_back(create_message(peer.route));
msgs.push_back(create_message(reply));
if (cat_call.second && cat_call.second->second /*request command*/ && !data.empty()) {
msgs.push_back(create_message("REPLY"sv));
msgs.push_back(create_message(view(data.front()))); // reply tag
} else {
msgs.push_back(create_message(view(cmd)));
}
try {
send_message_parts(connections.at(conn_id), msgs);
} catch (const zmq::error_t& err) {
/* can't send: possibly already disconnected. Ignore. */
OMQ_LOG(debug, "Couldn't send auth failure message ", reply, " to peer [", oxenc::to_hex(peer.pubkey), "]/", peer_address(cmd), ": ", err.what());
}
return false;
}
void OxenMQ::set_active_sns(pubkey_set pubkeys) {
if (proxy_thread.joinable()) {
auto data = oxenc::bt_serialize(detail::serialize_object(std::move(pubkeys)));
detail::send_control(get_control_socket(), "SET_SNS", data);
} else {
proxy_set_active_sns(std::move(pubkeys));
}
}
void OxenMQ::proxy_set_active_sns(std::string_view data) {
proxy_set_active_sns(detail::deserialize_object<pubkey_set>(oxenc::bt_deserialize<uintptr_t>(data)));
}
void OxenMQ::proxy_set_active_sns(pubkey_set pubkeys) {
pubkey_set added, removed;
for (auto it = pubkeys.begin(); it != pubkeys.end(); ) {
auto& pk = *it;
if (pk.size() != 32) {
OMQ_LOG(warn, "Invalid private key of length ", pk.size(), " (", oxenc::to_hex(pk), ") passed to set_active_sns");
it = pubkeys.erase(it);
continue;
}
if (!active_service_nodes.count(pk))
added.insert(std::move(pk));
++it;
}
if (added.empty() && active_service_nodes.size() == pubkeys.size()) {
OMQ_LOG(debug, "set_active_sns(): new set of SNs is unchanged, skipping update");
return;
}
for (const auto& pk : active_service_nodes) {
if (!pubkeys.count(pk))
removed.insert(pk);
if (active_service_nodes.size() + added.size() - removed.size() == pubkeys.size())
break;
}
proxy_update_active_sns_clean(std::move(added), std::move(removed));
}
void OxenMQ::update_active_sns(pubkey_set added, pubkey_set removed) {
if (proxy_thread.joinable()) {
std::array<uintptr_t, 2> data;
data[0] = detail::serialize_object(std::move(added));
data[1] = detail::serialize_object(std::move(removed));
detail::send_control(get_control_socket(), "UPDATE_SNS", oxenc::bt_serialize(data));
} else {
proxy_update_active_sns(std::move(added), std::move(removed));
}
}
void OxenMQ::proxy_update_active_sns(oxenc::bt_list_consumer data) {
auto added = detail::deserialize_object<pubkey_set>(data.consume_integer<uintptr_t>());
auto remed = detail::deserialize_object<pubkey_set>(data.consume_integer<uintptr_t>());
proxy_update_active_sns(std::move(added), std::move(remed));
}
void OxenMQ::proxy_update_active_sns(pubkey_set added, pubkey_set removed) {
// We take a caller-provided set of added/removed then filter out any junk (bad pks, conflicting
// values, pubkeys that already(added) or do not(removed) exist), then pass the purified lists
// to the _clean version.
for (auto it = removed.begin(); it != removed.end(); ) {
const auto& pk = *it;
if (pk.size() != 32) {
OMQ_LOG(warn, "Invalid private key of length ", pk.size(), " (", oxenc::to_hex(pk), ") passed to update_active_sns (removed)");
it = removed.erase(it);
} else if (!active_service_nodes.count(pk) || added.count(pk) /* added wins if in both */) {
it = removed.erase(it);
} else {
++it;
}
}
for (auto it = added.begin(); it != added.end(); ) {
const auto& pk = *it;
if (pk.size() != 32) {
OMQ_LOG(warn, "Invalid private key of length ", pk.size(), " (", oxenc::to_hex(pk), ") passed to update_active_sns (added)");
it = added.erase(it);
} else if (active_service_nodes.count(pk)) {
it = added.erase(it);
} else {
++it;
}
}
proxy_update_active_sns_clean(std::move(added), std::move(removed));
}
void OxenMQ::proxy_update_active_sns_clean(pubkey_set added, pubkey_set removed) {
OMQ_LOG(debug, "Updating SN auth status with +", added.size(), "/-", removed.size(), " pubkeys");
// For anything we remove we want close the connection to the SN (if outgoing), and remove the
// stored peer_info (incoming or outgoing).
for (const auto& pk : removed) {
ConnectionID c{pk};
active_service_nodes.erase(pk);
auto range = peers.equal_range(c);
for (auto it = range.first; it != range.second; ) {
bool outgoing = it->second.outgoing();
auto conn_id = it->second.conn_id;
it = peers.erase(it);
if (outgoing) {
OMQ_LOG(debug, "Closing outgoing connection to ", c);
proxy_close_connection(conn_id, CLOSE_LINGER);
}
}
}
// For pubkeys we add there's nothing special to be done beyond adding them to the pubkey set
for (auto& pk : added)
active_service_nodes.insert(std::move(pk));
}
void OxenMQ::process_zap_requests() {
for (std::vector<zmq::message_t> frames; recv_message_parts(zap_auth, frames, zmq::recv_flags::dontwait); frames.clear()) {
#ifndef NDEBUG
if (log_level() >= LogLevel::trace) {
std::ostringstream o;
o << "Processing ZAP authentication request:";
for (size_t i = 0; i < frames.size(); i++) {
o << "\n[" << i << "]: ";
auto v = view(frames[i]);
if (i == 1 || i == 6)
o << oxenc::to_hex(v);
else
o << v;
}
log(LogLevel::trace, __FILE__, __LINE__, o.str());
} else
#endif
OMQ_LOG(debug, "Processing ZAP authentication request");
// https://rfc.zeromq.org/spec:27/ZAP/
//
// The request message SHALL consist of the following message frames:
//
// The version frame, which SHALL contain the three octets "1.0".
// The request id, which MAY contain an opaque binary blob.
// The domain, which SHALL contain a (non-empty) string.
// The address, the origin network IP address.
// The identity, the connection Identity, if any.
// The mechanism, which SHALL contain a string.
// The credentials, which SHALL be zero or more opaque frames.
//
// The reply message SHALL consist of the following message frames:
//
// The version frame, which SHALL contain the three octets "1.0".
// The request id, which MAY contain an opaque binary blob.
// The status code, which SHALL contain a string.
// The status text, which MAY contain a string.
// The user id, which SHALL contain a string.
// The metadata, which MAY contain a blob.
//
// (NB: there are also null address delimiters at the beginning of each mentioned in the
// RFC, but those have already been removed through the use of a REP socket)
std::vector<std::string> response_vals(6);
response_vals[0] = "1.0"; // version
if (frames.size() >= 2)
response_vals[1] = std::string{view(frames[1])}; // unique identifier
std::string &status_code = response_vals[2], &status_text = response_vals[3];
if (frames.size() < 6 || view(frames[0]) != "1.0") {
OMQ_LOG(error, "Bad ZAP authentication request: version != 1.0 or invalid ZAP message parts");
status_code = "500";
status_text = "Internal error: invalid auth request";
} else {
auto auth_domain = view(frames[2]);
size_t bind_id = (size_t) -1;
try {
bind_id = oxenc::bt_deserialize<size_t>(view(frames[2]));
} catch (...) {}
if (bind_id >= bind.size()) {
OMQ_LOG(error, "Bad ZAP authentication request: invalid auth domain '", auth_domain, "'");
status_code = "400";
status_text = "Unknown authentication domain: " + std::string{auth_domain};
} else if (bind[bind_id].curve
? !(frames.size() == 7 && view(frames[5]) == "CURVE")
: !(frames.size() == 6 && view(frames[5]) == "NULL")) {
OMQ_LOG(error, "Bad ZAP authentication request: invalid ",
bind[bind_id].curve ? "CURVE" : "NULL", " authentication request");
status_code = "500";
status_text = "Invalid authentication request mechanism";
} else if (bind[bind_id].curve && frames[6].size() != 32) {
OMQ_LOG(error, "Bad ZAP authentication request: invalid request pubkey");
status_code = "500";
status_text = "Invalid public key size for CURVE authentication";
} else {
auto ip = view(frames[3]);
// If we're in dual stack mode IPv4 address might be IPv4-mapped IPv6 address (e.g.
// ::ffff:192.168.0.1); if so, remove the prefix to get a proper IPv4 address:
if (ip.size() >= 14 && ip.substr(0, 7) == "::ffff:"sv && ip.find_last_not_of("0123456789."sv) == 6)
ip = ip.substr(7);
std::string_view pubkey;
bool sn = false;
if (bind[bind_id].curve) {
pubkey = view(frames[6]);
sn = active_service_nodes.count(std::string{pubkey});
}
auto auth = bind[bind_id].allow(ip, pubkey, sn);
auto& user_id = response_vals[4];
if (bind[bind_id].curve) {
user_id.reserve(64);
oxenc::to_hex(pubkey.begin(), pubkey.end(), std::back_inserter(user_id));
}
if (auth <= AuthLevel::denied || auth > AuthLevel::admin) {
OMQ_LOG(info, "Access denied for incoming ", view(frames[5]), (sn ? " service node" : " client"),
" connection from ", !user_id.empty() ? user_id + " at " : ""s, ip,
" with initial auth level ", auth);
status_code = "400";
status_text = "Access denied";
user_id.clear();
} else {
OMQ_LOG(debug, "Accepted incoming ", view(frames[5]), (sn ? " service node" : " client"),
" connection with authentication level ", auth,
" from ", !user_id.empty() ? user_id + " at " : ""s, ip);
auto& metadata = response_vals[5];
metadata += zmtp_metadata("X-AuthLevel", to_string(auth));
status_code = "200";
status_text = "";
}
}
}
OMQ_TRACE("ZAP request result: ", status_code, " ", status_text);
std::vector<zmq::message_t> response;
response.reserve(response_vals.size());
for (auto &r : response_vals) response.push_back(create_message(std::move(r)));
send_message_parts(zap_auth, response.begin(), response.end());
}
}
}

View file

@ -1,73 +0,0 @@
#pragma once
#include <iosfwd>
#include <string>
#include <cstring>
#include <unordered_set>
namespace oxenmq {
/// Authentication levels for command categories and connections
enum class AuthLevel {
denied, ///< Not actually an auth level, but can be returned by the AllowFunc to deny an incoming connection.
none, ///< No authentication at all; any random incoming ZMQ connection can invoke this command.
basic, ///< Basic authentication commands require a login, or a node that is specifically configured to be a public node (e.g. for public RPC).
admin, ///< Advanced authentication commands require an admin user, either via explicit login or by implicit login from localhost. This typically protects administrative commands like shutting down, starting mining, or access sensitive data.
};
std::ostream& operator<<(std::ostream& os, AuthLevel a);
/// The access level for a command category
struct Access {
/// Minimum access level required
AuthLevel auth;
/// If true only remote SNs may call the category commands
bool remote_sn;
/// If true the category requires that the local node is a SN
bool local_sn;
/// Constructor. Intentionally allows implicit conversion from an AuthLevel so that an
/// AuthLevel can be passed anywhere an Access is required (the resulting Access will have both
/// remote and local sn set to false).
Access(AuthLevel auth = AuthLevel::none, bool remote_sn = false, bool local_sn = false)
: auth{auth}, remote_sn{remote_sn}, local_sn{local_sn} {}
};
/// Simple hash implementation for a string that is *already* a hash-like value (such as a pubkey).
/// Falls back to std::hash<std::string> if given a string smaller than a size_t.
struct already_hashed {
size_t operator()(const std::string& s) const {
if (s.size() < sizeof(size_t))
return std::hash<std::string>{}(s);
size_t hash;
std::memcpy(&hash, &s[0], sizeof(hash));
return hash;
}
};
/// std::unordered_set specialization for specifying pubkeys (used, in particular, by
/// OxenMQ::set_active_sns and OxenMQ::update_active_sns); this is a std::string unordered_set that
/// also uses a specialized trivial hash function that uses part of the value itself (i.e. the
/// pubkey) directly as a hash value. (This is nice and fast for uniformly distributed values like
/// pubkeys and a terrible hash choice for anything else).
using pubkey_set = std::unordered_set<std::string, already_hashed>;
inline constexpr std::string_view to_string(AuthLevel a) {
switch (a) {
case AuthLevel::denied: return "denied";
case AuthLevel::none: return "none";
case AuthLevel::basic: return "basic";
case AuthLevel::admin: return "admin";
default: return "(unknown)";
}
}
inline AuthLevel auth_from_string(std::string_view a) {
if (a == "none") return AuthLevel::none;
if (a == "basic") return AuthLevel::basic;
if (a == "admin") return AuthLevel::admin;
return AuthLevel::denied;
}
}

View file

@ -1,429 +0,0 @@
#include "oxenmq.h"
#include "oxenmq-internal.h"
#include <oxenc/hex.h>
#include <optional>
#ifdef OXENMQ_USE_EPOLL
extern "C" {
#include <sys/epoll.h>
#include <unistd.h>
}
#endif
namespace oxenmq {
std::ostream& operator<<(std::ostream& o, const ConnectionID& conn) {
return o << conn.to_string();
}
#ifdef OXENMQ_USE_EPOLL
void OxenMQ::rebuild_pollitems() {
if (epoll_fd != -1)
close(epoll_fd);
epoll_fd = epoll_create1(0);
struct epoll_event ev;
ev.events = EPOLLIN | EPOLLET;
ev.data.u64 = EPOLL_COMMAND_ID;
epoll_ctl(epoll_fd, EPOLL_CTL_ADD, command.get(zmq::sockopt::fd), &ev);
ev.data.u64 = EPOLL_WORKER_ID;
epoll_ctl(epoll_fd, EPOLL_CTL_ADD, workers_socket.get(zmq::sockopt::fd), &ev);
ev.data.u64 = EPOLL_ZAP_ID;
epoll_ctl(epoll_fd, EPOLL_CTL_ADD, zap_auth.get(zmq::sockopt::fd), &ev);
for (auto& [id, s] : connections) {
ev.data.u64 = id;
epoll_ctl(epoll_fd, EPOLL_CTL_ADD, s.get(zmq::sockopt::fd), &ev);
}
connections_updated = false;
}
#else // !OXENMQ_USE_EPOLL
namespace {
void add_pollitem(std::vector<zmq::pollitem_t>& pollitems, zmq::socket_t& sock) {
pollitems.emplace_back();
auto &p = pollitems.back();
p.socket = static_cast<void *>(sock);
p.fd = 0;
p.events = ZMQ_POLLIN;
}
} // anonymous namespace
void OxenMQ::rebuild_pollitems() {
pollitems.clear();
add_pollitem(pollitems, command);
add_pollitem(pollitems, workers_socket);
add_pollitem(pollitems, zap_auth);
for (auto& [id, s] : connections)
add_pollitem(pollitems, s);
connections_updated = false;
}
#endif // OXENMQ_USE_EPOLL
void OxenMQ::setup_external_socket(zmq::socket_t& socket) {
socket.set(zmq::sockopt::reconnect_ivl, (int) RECONNECT_INTERVAL.count());
socket.set(zmq::sockopt::reconnect_ivl_max, (int) RECONNECT_INTERVAL_MAX.count());
socket.set(zmq::sockopt::handshake_ivl, (int) HANDSHAKE_TIME.count());
socket.set(zmq::sockopt::maxmsgsize, MAX_MSG_SIZE);
if (IPV6)
socket.set(zmq::sockopt::ipv6, 1);
if (CONN_HEARTBEAT > 0s) {
socket.set(zmq::sockopt::heartbeat_ivl, (int) CONN_HEARTBEAT.count());
if (CONN_HEARTBEAT_TIMEOUT > 0s)
socket.set(zmq::sockopt::heartbeat_timeout, (int) CONN_HEARTBEAT_TIMEOUT.count());
}
}
void OxenMQ::setup_outgoing_socket(zmq::socket_t& socket, std::string_view remote_pubkey, bool use_ephemeral_routing_id) {
setup_external_socket(socket);
if (!remote_pubkey.empty()) {
socket.set(zmq::sockopt::curve_serverkey, remote_pubkey);
socket.set(zmq::sockopt::curve_publickey, pubkey);
socket.set(zmq::sockopt::curve_secretkey, privkey);
}
if (!use_ephemeral_routing_id) {
std::string routing_id;
routing_id.reserve(33);
routing_id += 'L'; // Prefix because routing id's starting with \0 are reserved by zmq (and our pubkey might start with \0)
routing_id.append(pubkey.begin(), pubkey.end());
socket.set(zmq::sockopt::routing_id, routing_id);
}
// else let ZMQ pick a random one
}
void OxenMQ::setup_incoming_socket(zmq::socket_t& listener, bool curve, std::string_view pubkey, std::string_view privkey, size_t bind_index) {
setup_external_socket(listener);
listener.set(zmq::sockopt::zap_domain, oxenc::bt_serialize(bind_index));
if (curve) {
listener.set(zmq::sockopt::curve_server, true);
listener.set(zmq::sockopt::curve_publickey, pubkey);
listener.set(zmq::sockopt::curve_secretkey, privkey);
}
listener.set(zmq::sockopt::router_handover, true);
listener.set(zmq::sockopt::router_mandatory, true);
}
// Deprecated versions:
ConnectionID OxenMQ::connect_remote(std::string_view remote, ConnectSuccess on_connect,
ConnectFailure on_failure, AuthLevel auth_level, std::chrono::milliseconds timeout) {
return connect_remote(address{remote}, std::move(on_connect), std::move(on_failure),
auth_level, connect_option::timeout{timeout});
}
ConnectionID OxenMQ::connect_remote(std::string_view remote, ConnectSuccess on_connect,
ConnectFailure on_failure, std::string_view pubkey, AuthLevel auth_level,
std::chrono::milliseconds timeout) {
return connect_remote(address{remote}.set_pubkey(pubkey), std::move(on_connect),
std::move(on_failure), auth_level, connect_option::timeout{timeout});
}
void OxenMQ::disconnect(ConnectionID id, std::chrono::milliseconds linger) {
detail::send_control(get_control_socket(), "DISCONNECT", oxenc::bt_serialize<oxenc::bt_dict>({
{"conn_id", id.id},
{"linger_ms", linger.count()},
{"pubkey", id.pk},
}));
}
std::pair<zmq::socket_t *, std::string>
OxenMQ::proxy_connect_sn(std::string_view remote, std::string_view connect_hint, bool optional, bool incoming_only, bool outgoing_only, bool use_ephemeral_routing_id, std::chrono::milliseconds keep_alive) {
ConnectionID remote_cid{remote};
auto its = peers.equal_range(remote_cid);
peer_info* peer = nullptr;
for (auto it = its.first; it != its.second; ++it) {
if (incoming_only && it->second.route.empty())
continue; // outgoing connection but we were asked to only use incoming connections
if (outgoing_only && !it->second.route.empty())
continue;
peer = &it->second;
break;
}
if (peer) {
OMQ_TRACE("proxy asked to connect to ", oxenc::to_hex(remote), "; reusing existing connection");
if (peer->route.empty() /* == outgoing*/) {
if (peer->idle_expiry < keep_alive) {
OMQ_LOG(debug, "updating existing outgoing peer connection idle expiry time from ",
peer->idle_expiry.count(), "ms to ", keep_alive.count(), "ms");
peer->idle_expiry = keep_alive;
}
peer->activity();
}
return {&connections[peer->conn_id], peer->route};
} else if (optional || incoming_only) {
OMQ_LOG(debug, "proxy asked for optional or incoming connection, but no appropriate connection exists so aborting connection attempt");
return {nullptr, ""s};
}
// No connection so establish a new one
OMQ_LOG(debug, "proxy establishing new outbound connection to ", oxenc::to_hex(remote));
std::string addr;
bool to_self = false && remote == pubkey; // FIXME; need to use a separate listening socket for this, otherwise we can't easily
// tell it wasn't from a remote.
if (to_self) {
// special inproc connection if self that doesn't need any external connection
addr = SN_ADDR_SELF;
} else {
addr = std::string{connect_hint};
if (addr.empty())
addr = sn_lookup(remote);
else
OMQ_LOG(debug, "using connection hint ", connect_hint);
if (addr.empty()) {
OMQ_LOG(error, "peer lookup failed for ", oxenc::to_hex(remote));
return {nullptr, ""s};
}
}
OMQ_LOG(debug, oxenc::to_hex(pubkey), " (me) connecting to ", addr, " to reach ", oxenc::to_hex(remote));
std::optional<zmq::socket_t> socket;
try {
socket.emplace(context, zmq::socket_type::dealer);
setup_outgoing_socket(*socket, remote, use_ephemeral_routing_id);
socket->connect(addr);
} catch (const zmq::error_t& e) {
// Note that this failure cases indicates something serious went wrong that means zmq isn't
// even going to try connecting (for example an unparseable remote address).
OMQ_LOG(error, "Outgoing connection to ", addr, " failed: ", e.what());
return {nullptr, ""s};
}
auto& p = peers.emplace(std::move(remote_cid), peer_info{})->second;
p.service_node = true;
p.pubkey = std::string{remote};
p.conn_id = next_conn_id++;
p.idle_expiry = keep_alive;
p.activity();
connections_updated = true;
outgoing_sn_conns.emplace_hint(outgoing_sn_conns.end(), p.conn_id, ConnectionID{remote});
auto it = connections.emplace_hint(connections.end(), p.conn_id, *std::move(socket));
return {&it->second, ""s};
}
std::pair<zmq::socket_t *, std::string> OxenMQ::proxy_connect_sn(oxenc::bt_dict_consumer data) {
std::string_view hint, remote_pk;
std::chrono::milliseconds keep_alive;
bool optional = false, incoming_only = false, outgoing_only = false, ephemeral_rid = EPHEMERAL_ROUTING_ID;
// Alphabetical order
if (data.skip_until("ephemeral_rid"))
ephemeral_rid = data.consume_integer<bool>();
if (data.skip_until("hint"))
hint = data.consume_string_view();
if (data.skip_until("incoming"))
incoming_only = data.consume_integer<bool>();
if (data.skip_until("keep_alive"))
keep_alive = std::chrono::milliseconds{data.consume_integer<uint64_t>()};
if (data.skip_until("optional"))
optional = data.consume_integer<bool>();
if (data.skip_until("outgoing_only"))
outgoing_only = data.consume_integer<bool>();
if (!data.skip_until("pubkey"))
throw std::runtime_error("Internal error: Invalid proxy_connect_sn command; pubkey missing");
remote_pk = data.consume_string_view();
return proxy_connect_sn(remote_pk, hint, optional, incoming_only, outgoing_only, ephemeral_rid, keep_alive);
}
/// Closes outgoing connections and removes all references. Note that this will call `erase()`
/// which can invalidate iterators on the various connection containers - if you don't want that,
/// delete it first so that the container won't contain the element being deleted.
void OxenMQ::proxy_close_connection(int64_t id, std::chrono::milliseconds linger) {
auto it = connections.find(id);
if (it == connections.end()) {
OMQ_LOG(warn, "internal error: connection to close (", id, ") doesn't exist!");
return;
}
OMQ_LOG(debug, "Closing conn ", id);
it->second.set(zmq::sockopt::linger, linger > 0ms ? (int) linger.count() : 0);
connections.erase(it);
connections_updated = true;
outgoing_sn_conns.erase(id);
}
void OxenMQ::proxy_expire_idle_peers() {
for (auto it = peers.begin(); it != peers.end(); ) {
auto &info = it->second;
if (info.outgoing()) {
auto idle = std::chrono::steady_clock::now() - info.last_activity;
if (idle > info.idle_expiry) {
OMQ_LOG(debug, "Closing outgoing connection to ", it->first, ": idle time (",
std::chrono::duration_cast<std::chrono::milliseconds>(idle).count(), "ms) reached connection timeout (",
info.idle_expiry.count(), "ms)");
proxy_close_connection(info.conn_id, CLOSE_LINGER);
it = peers.erase(it);
} else {
OMQ_LOG(trace, "Not closing ", it->first, ": ", std::chrono::duration_cast<std::chrono::milliseconds>(idle).count(),
"ms <= ", info.idle_expiry.count(), "ms");
++it;
continue;
}
} else {
++it;
}
}
}
void OxenMQ::proxy_conn_cleanup() {
OMQ_TRACE("starting proxy connections cleanup");
// Drop idle connections (if we haven't done it in a while)
OMQ_TRACE("closing idle connections");
proxy_expire_idle_peers();
auto now = std::chrono::steady_clock::now();
// FIXME - check other outgoing connections to see if they died and if so purge them
OMQ_TRACE("Timing out pending outgoing connections");
// Check any pending outgoing connections for timeout
for (auto it = pending_connects.begin(); it != pending_connects.end(); ) {
auto& pc = *it;
if (std::get<std::chrono::steady_clock::time_point>(pc) < now) {
auto id = std::get<int64_t>(pc);
job([cid = ConnectionID{id}, callback = std::move(std::get<ConnectFailure>(pc))] { callback(cid, "connection attempt timed out"); });
it = pending_connects.erase(it); // Don't let the below erase it (because it invalidates iterators)
proxy_close_connection(id, CLOSE_LINGER);
} else {
++it;
}
}
OMQ_TRACE("Timing out pending requests");
// Remove any expired pending requests and schedule their callback with a failure
for (auto it = pending_requests.begin(); it != pending_requests.end(); ) {
auto& callback = it->second;
if (callback.first < now) {
OMQ_LOG(debug, "pending request ", oxenc::to_hex(it->first), " expired, invoking callback with failure status and removing");
job([callback = std::move(callback.second)] { callback(false, {{"TIMEOUT"s}}); });
it = pending_requests.erase(it);
} else {
++it;
}
}
OMQ_TRACE("done proxy connections cleanup");
};
void OxenMQ::proxy_connect_remote(oxenc::bt_dict_consumer data) {
AuthLevel auth_level = AuthLevel::none;
long long conn_id = -1;
ConnectSuccess on_connect;
ConnectFailure on_failure;
std::string remote;
std::string remote_pubkey;
std::chrono::milliseconds timeout = REMOTE_CONNECT_TIMEOUT;
bool ephemeral_rid = EPHEMERAL_ROUTING_ID;
if (data.skip_until("auth_level"))
auth_level = static_cast<AuthLevel>(data.consume_integer<std::underlying_type_t<AuthLevel>>());
if (data.skip_until("conn_id"))
conn_id = data.consume_integer<long long>();
if (data.skip_until("connect"))
on_connect = detail::deserialize_object<ConnectSuccess>(data.consume_integer<uintptr_t>());
if (data.skip_until("ephemeral_rid"))
ephemeral_rid = data.consume_integer<bool>();
if (data.skip_until("failure"))
on_failure = detail::deserialize_object<ConnectFailure>(data.consume_integer<uintptr_t>());
if (data.skip_until("pubkey")) {
remote_pubkey = data.consume_string();
assert(remote_pubkey.size() == 32 || remote_pubkey.empty());
}
if (data.skip_until("remote"))
remote = data.consume_string();
if (data.skip_until("timeout"))
timeout = std::chrono::milliseconds{data.consume_integer<uint64_t>()};
if (conn_id == -1 || remote.empty())
throw std::runtime_error("Internal error: CONNECT_REMOTE proxy command missing required 'conn_id' and/or 'remote' value");
OMQ_LOG(debug, "Establishing remote connection to ", remote,
remote_pubkey.empty() ? " (NULL auth)" : " via CURVE expecting pubkey " + oxenc::to_hex(remote_pubkey));
std::optional<zmq::socket_t> sock;
try {
sock.emplace(context, zmq::socket_type::dealer);
setup_outgoing_socket(*sock, remote_pubkey, ephemeral_rid);
sock->connect(remote);
} catch (const zmq::error_t &e) {
proxy_schedule_reply_job([conn_id, on_failure=std::move(on_failure), what="connect() failed: "s+e.what()] {
on_failure(conn_id, std::move(what));
});
return;
}
auto &s = connections.emplace_hint(connections.end(), conn_id, std::move(*sock))->second;
connections_updated = true;
OMQ_LOG(debug, "Opened new zmq socket to ", remote, ", conn_id ", conn_id, "; sending HI");
send_direct_message(s, "HI");
pending_connects.emplace_back(conn_id, std::chrono::steady_clock::now() + timeout,
std::move(on_connect), std::move(on_failure));
auto& peer = peers.emplace(ConnectionID{conn_id, remote_pubkey}, peer_info{})->second;
peer.pubkey = std::move(remote_pubkey);
peer.service_node = false;
peer.auth_level = auth_level;
peer.conn_id = conn_id;
peer.idle_expiry = 24h * 10 * 365; // "forever"
peer.activity();
}
void OxenMQ::proxy_disconnect(oxenc::bt_dict_consumer data) {
ConnectionID connid{-1};
std::chrono::milliseconds linger = 1s;
if (data.skip_until("conn_id"))
connid.id = data.consume_integer<long long>();
if (data.skip_until("linger_ms"))
linger = std::chrono::milliseconds(data.consume_integer<long long>());
if (data.skip_until("pubkey"))
connid.pk = data.consume_string();
if (connid.sn() && connid.pk.size() != 32)
throw std::runtime_error("Error: invalid disconnect of SN without a valid pubkey");
proxy_disconnect(std::move(connid), linger);
}
void OxenMQ::proxy_disconnect(ConnectionID conn, std::chrono::milliseconds linger) {
OMQ_TRACE("Disconnecting outgoing connection to ", conn);
auto pr = peers.equal_range(conn);
for (auto it = pr.first; it != pr.second; ++it) {
auto& peer = it->second;
if (peer.outgoing()) {
OMQ_LOG(debug, "Closing outgoing connection to ", conn);
proxy_close_connection(peer.conn_id, linger);
peers.erase(it);
return;
}
}
OMQ_LOG(warn, "Failed to disconnect ", conn, ": no such outgoing connection");
}
std::string ConnectionID::to_string() const {
if (!pk.empty())
return (sn() ? std::string("SN ") : std::string("non-SN authenticated remote ")) + oxenc::to_hex(pk);
else
return std::string("unauthenticated remote [") + std::to_string(id) + "]";
}
}

View file

@ -1,28 +0,0 @@
#pragma once
#include <fmt/format.h>
#include "connections.h"
#include "auth.h"
#include "address.h"
template <>
struct fmt::formatter<oxenmq::AuthLevel> : fmt::formatter<std::string> {
auto format(oxenmq::AuthLevel v, format_context& ctx) {
return formatter<std::string>::format(
fmt::format("{}", to_string(v)), ctx);
}
};
template <>
struct fmt::formatter<oxenmq::ConnectionID> : fmt::formatter<std::string> {
auto format(oxenmq::ConnectionID conn, format_context& ctx) {
return formatter<std::string>::format(
fmt::format("{}", conn.to_string()), ctx);
}
};
template <>
struct fmt::formatter<oxenmq::address> : fmt::formatter<std::string> {
auto format(oxenmq::address addr, format_context& ctx) {
return formatter<std::string>::format(
fmt::format("{}", addr.full_address()), ctx);
}
};

View file

@ -1,182 +0,0 @@
#include "oxenmq.h"
#include "batch.h"
#include "oxenmq-internal.h"
namespace oxenmq {
void OxenMQ::proxy_batch(detail::Batch* batch) {
const auto [jobs, tagged_threads] = batch->size();
OMQ_TRACE("proxy queuing batch job with ", jobs, " jobs", tagged_threads ? " (job uses tagged thread(s))" : "");
if (!tagged_threads) {
for (size_t i = 0; i < jobs; i++)
batch_jobs.emplace_back(batch, i);
} else {
// Some (or all) jobs have a specific thread target so queue any such jobs in the tagged
// worker queue.
auto threads = batch->threads();
for (size_t i = 0; i < jobs; i++) {
auto& jobs = threads[i] > 0
? std::get<batch_queue>(tagged_workers[threads[i] - 1])
: batch_jobs;
jobs.emplace_back(batch, i);
}
}
proxy_skip_one_poll = true;
}
void OxenMQ::job(std::function<void()> f, std::optional<TaggedThreadID> thread) {
if (thread && thread->_id == -1)
throw std::logic_error{"job() cannot be used to queue an in-proxy job"};
auto* j = new Job(std::move(f), thread);
auto* baseptr = static_cast<detail::Batch*>(j);
detail::send_control(get_control_socket(), "BATCH", oxenc::bt_serialize(reinterpret_cast<uintptr_t>(baseptr)));
}
void OxenMQ::proxy_schedule_reply_job(std::function<void()> f) {
auto* j = new Job(std::move(f));
reply_jobs.emplace_back(static_cast<detail::Batch*>(j), 0);
proxy_skip_one_poll = true;
}
void OxenMQ::proxy_run_batch_jobs(batch_queue& jobs, const int reserved, int& active, bool reply) {
while (!jobs.empty() && active_workers() < max_workers &&
(active < reserved || active_workers() < general_workers)) {
proxy_run_worker(get_idle_worker().load(std::move(jobs.front()), reply));
jobs.pop_front();
active++;
}
}
// Called either within the proxy thread, or before the proxy thread has been created; actually adds
// the timer. If the timer object hasn't been set up yet it gets set up here.
void OxenMQ::proxy_timer(int id, std::function<void()> job, std::chrono::milliseconds interval, bool squelch, int thread) {
if (!timers)
timers.reset(zmq_timers_new());
int zmq_timer_id = zmq_timers_add(timers.get(),
interval.count(),
[](int timer_id, void* self) { static_cast<OxenMQ*>(self)->_queue_timer_job(timer_id); },
this);
if (zmq_timer_id == -1)
throw zmq::error_t{};
timer_jobs[zmq_timer_id] = { std::move(job), squelch, false, thread };
timer_zmq_id[id] = zmq_timer_id;
}
void OxenMQ::proxy_timer(oxenc::bt_list_consumer timer_data) {
auto timer_id = timer_data.consume_integer<int>();
std::unique_ptr<std::function<void()>> func{reinterpret_cast<std::function<void()>*>(timer_data.consume_integer<uintptr_t>())};
auto interval = std::chrono::milliseconds{timer_data.consume_integer<uint64_t>()};
auto squelch = timer_data.consume_integer<bool>();
auto thread = timer_data.consume_integer<int>();
if (!timer_data.is_finished())
throw std::runtime_error("Internal error: proxied timer request contains unexpected data");
proxy_timer(timer_id, std::move(*func), interval, squelch, thread);
}
void OxenMQ::_queue_timer_job(int timer_id) {
auto it = timer_jobs.find(timer_id);
if (it == timer_jobs.end()) {
OMQ_LOG(warn, "Could not find timer job ", timer_id);
return;
}
auto& [func, squelch, running, thread] = it->second;
if (squelch && running) {
OMQ_LOG(debug, "Not running timer job ", timer_id, " because a job for that timer is still running");
return;
}
if (thread == -1) { // Run directly in proxy thread
try { func(); }
catch (const std::exception &e) { OMQ_LOG(warn, "timer job ", timer_id, " raised an exception: ", e.what()); }
catch (...) { OMQ_LOG(warn, "timer job ", timer_id, " raised a non-std exception"); }
return;
}
detail::Batch* b;
if (squelch) {
auto* bv = new Batch<void>;
bv->add_job(func, thread);
running = true;
bv->completion([this,timer_id](auto results) {
try { results[0].get(); }
catch (const std::exception &e) { OMQ_LOG(warn, "timer job ", timer_id, " raised an exception: ", e.what()); }
catch (...) { OMQ_LOG(warn, "timer job ", timer_id, " raised a non-std exception"); }
auto it = timer_jobs.find(timer_id);
if (it != timer_jobs.end())
it->second.running = false;
}, OxenMQ::run_in_proxy);
b = bv;
} else {
b = new Job(func, thread);
}
OMQ_TRACE("b: ", b->size().first, ", ", b->size().second, "; thread = ", thread);
assert(b->size() == std::make_pair(size_t{1}, thread > 0));
auto& queue = thread > 0
? std::get<batch_queue>(tagged_workers[thread - 1])
: batch_jobs;
queue.emplace_back(static_cast<detail::Batch*>(b), 0);
}
void OxenMQ::add_timer(TimerID& timer, std::function<void()> job, std::chrono::milliseconds interval, bool squelch, std::optional<TaggedThreadID> thread) {
int th_id = thread ? thread->_id : 0;
timer._id = next_timer_id++;
if (proxy_thread.joinable()) {
detail::send_control(get_control_socket(), "TIMER", oxenc::bt_serialize(oxenc::bt_list{{
timer._id,
detail::serialize_object(std::move(job)),
interval.count(),
squelch,
th_id}}));
} else {
proxy_timer(timer._id, std::move(job), interval, squelch, th_id);
}
}
TimerID OxenMQ::add_timer(std::function<void()> job, std::chrono::milliseconds interval, bool squelch, std::optional<TaggedThreadID> thread) {
TimerID tid;
add_timer(tid, std::move(job), interval, squelch, std::move(thread));
return tid;
}
void OxenMQ::proxy_timer_del(int id) {
if (!timers)
return;
auto it = timer_zmq_id.find(id);
if (it == timer_zmq_id.end())
return;
zmq_timers_cancel(timers.get(), it->second);
timer_zmq_id.erase(it);
}
void OxenMQ::cancel_timer(TimerID timer_id) {
if (proxy_thread.joinable()) {
detail::send_control(get_control_socket(), "TIMER_DEL", oxenc::bt_serialize(timer_id._id));
} else {
proxy_timer_del(timer_id._id);
}
}
void OxenMQ::TimersDeleter::operator()(void* timers) { zmq_timers_destroy(&timers); }
TaggedThreadID OxenMQ::add_tagged_thread(std::string name, std::function<void()> start) {
if (proxy_thread.joinable())
throw std::logic_error{"Cannot add tagged threads after calling `start()`"};
if (name == "_proxy"sv || name.empty() || name.find('\0') != std::string::npos)
throw std::logic_error{"Invalid tagged thread name `" + name + "'"};
auto& [run, busy, queue] = tagged_workers.emplace_back();
busy = false;
run.worker_id = tagged_workers.size(); // We want index + 1 (b/c 0 is used for non-tagged jobs)
run.worker_routing_name = "t" + std::to_string(run.worker_id);
run.worker_routing_id = "t" + std::string{reinterpret_cast<const char*>(&run.worker_id), sizeof(run.worker_id)};
OMQ_TRACE("Created new tagged thread ", name, " with routing id ", run.worker_routing_name);
run.worker_thread = std::thread{&OxenMQ::worker_thread, this, run.worker_id, name, std::move(start)};
return TaggedThreadID{static_cast<int>(run.worker_id)};
}
}

View file

@ -1,106 +0,0 @@
#pragma once
#include <vector>
#include "connections.h"
namespace oxenmq {
class OxenMQ;
/// Encapsulates an incoming message from a remote connection with message details plus extra
/// info need to send a reply back through the proxy thread via the `reply()` method. Note that
/// this object gets reused: callbacks should use but not store any reference beyond the callback.
class Message {
public:
OxenMQ& oxenmq; ///< The owning OxenMQ object
std::vector<std::string_view> data; ///< The provided command data parts, if any.
ConnectionID conn; ///< The connection info for routing a reply; also contains the pubkey/sn status.
std::string reply_tag; ///< If the invoked command is a request command this is the required reply tag that will be prepended by `send_reply()`.
Access access; ///< The access level of the invoker. This can be higher than the access level of the command, for example for an admin invoking a basic command.
std::string remote; ///< Some sort of remote address from which the request came. Often "IP" for TCP connections and "localhost:UID:GID:PID" for unix socket connections.
/// Constructor
Message(OxenMQ& omq, ConnectionID cid, Access access, std::string remote)
: oxenmq{omq}, conn{std::move(cid)}, access{std::move(access)}, remote{std::move(remote)} {}
// Non-copyable
Message(const Message&) = delete;
Message& operator=(const Message&) = delete;
/// Sends a command back to whomever sent this message. Arguments are forwarded to send() but
/// with send_option::optional{} added if the originator is not a SN. For SN messages (i.e.
/// where `sn` is true) this is a "strong" reply by default in that the proxy will attempt to
/// establish a new connection to the SN if no longer connected. For non-SN messages the reply
/// will be attempted using the available routing information, but if the connection has already
/// been closed the reply will be dropped.
///
/// If you want to send a non-strong reply even when the remote is a service node then add
/// an explicit `send_option::optional()` argument.
template <typename... Args>
void send_back(std::string_view command, Args&&... args);
/// Sends a reply to a request. This takes no command: the command is always the built-in
/// "REPLY" command, followed by the unique reply tag, then any reply data parts. All other
/// arguments are as in `send_back()`. You should only send one reply for a command expecting
/// replies, though this is not enforced: attempting to send multiple replies will simply be
/// dropped when received by the remote. (Note, however, that it is possible to send multiple
/// messages -- e.g. you could send a reply and then also call send_back() and/or send_request()
/// to send more requests back to the sender).
template <typename... Args>
void send_reply(Args&&... args);
/// Sends a request back to whomever sent this message. This is effectively a wrapper around
/// omq.request() that takes care of setting up the recipient arguments.
template <typename ReplyCallback, typename... Args>
void send_request(std::string_view command, ReplyCallback&& callback, Args&&... args);
/** Class returned by `send_later()` that can be used to call `back()`, `reply()`, or
* `request()` beyond the lifetime of the Message instance as if calling `msg.send_back()`,
* `msg.send_reply()`, or `msg.send_request()`. For example:
*
* auto send = msg.send_later();
* // ... later, perhaps in a lambda or scheduled job:
* send.reply("content");
*
* is equivalent to
*
* msg.send_reply("content");
*
* except that it is valid even if `msg` is no longer valid.
*/
class DeferredSend {
public:
OxenMQ& oxenmq; ///< The owning OxenMQ object
ConnectionID conn; ///< The connection info for routing a reply; also contains the pubkey/sn status
std::string reply_tag; ///< If the invoked command is a request command this is the required reply tag that will be prepended by `reply()`.
explicit DeferredSend(Message& m) : oxenmq{m.oxenmq}, conn{m.conn}, reply_tag{m.reply_tag} {}
template <typename... Args>
void operator()(Args &&...args) const {
if (reply_tag.empty())
back(std::forward<Args>(args)...);
else
reply(std::forward<Args>(args)...);
};
/// Equivalent to msg.send_back(...), but can be invoked later.
template <typename... Args>
void back(std::string_view command, Args&&... args) const;
/// Equivalent to msg.send_reply(...), but can be invoked later.
template <typename... Args>
void reply(Args&&... args) const;
/// Equivalent to msg.send_request(...), but can be invoked later.
template <typename ReplyCallback, typename... Args>
void request(std::string_view command, ReplyCallback&& callback, Args&&... args) const;
};
/// Returns a DeferredSend object that can be used to send replies to this message even if the
/// message expires. Typically this is used when sending a reply requires waiting on another
/// task to complete without needing to block the handler thread.
DeferredSend send_later() { return DeferredSend{*this}; }
};
}

File diff suppressed because it is too large Load diff

View file

@ -1,843 +0,0 @@
#include "oxenmq.h"
#include "oxenmq-internal.h"
#include <oxenc/hex.h>
#include <exception>
#include <future>
#if defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__)
extern "C" {
#include <pthread.h>
#include <pthread_np.h>
}
#endif
#ifdef OXENMQ_USE_EPOLL
#include <sys/epoll.h>
#endif
#ifndef _WIN32
extern "C" {
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
}
#endif
namespace oxenmq {
void OxenMQ::proxy_quit() {
OMQ_LOG(debug, "Received quit command, shutting down proxy thread");
assert(std::none_of(workers.begin(), workers.end(), [](auto& worker) { return worker.worker_thread.joinable(); }));
assert(std::none_of(tagged_workers.begin(), tagged_workers.end(), [](auto& worker) { return std::get<0>(worker).worker_thread.joinable(); }));
command.set(zmq::sockopt::linger, 0);
command.close();
{
std::lock_guard lock{control_sockets_mutex};
proxy_shutting_down = true; // To prevent threads from opening new control sockets
}
workers_socket.close();
int linger = std::chrono::milliseconds{CLOSE_LINGER}.count();
for (auto& [id, s] : connections)
s.set(zmq::sockopt::linger, linger);
connections.clear();
peers.clear();
OMQ_LOG(debug, "Proxy thread teardown complete");
}
void OxenMQ::proxy_send(oxenc::bt_dict_consumer data) {
// NB: bt_dict_consumer goes in alphabetical order
std::string_view hint;
std::chrono::milliseconds keep_alive{DEFAULT_SEND_KEEP_ALIVE};
std::chrono::milliseconds request_timeout{DEFAULT_REQUEST_TIMEOUT};
bool optional = false;
bool outgoing = false;
bool incoming = false;
bool request = false;
bool have_conn_id = false;
ConnectionID conn_id;
std::string request_tag;
ReplyCallback request_callback;
if (data.skip_until("conn_id")) {
conn_id.id = data.consume_integer<long long>();
if (conn_id.id == -1)
throw std::runtime_error("Invalid error: invalid conn_id value (-1)");
have_conn_id = true;
}
if (data.skip_until("conn_pubkey")) {
if (have_conn_id)
throw std::runtime_error("Internal error: Invalid proxy send command; conn_id and conn_pubkey are exclusive");
conn_id.pk = data.consume_string();
conn_id.id = ConnectionID::SN_ID;
} else if (!have_conn_id)
throw std::runtime_error("Internal error: Invalid proxy send command; conn_pubkey or conn_id missing");
if (data.skip_until("conn_route"))
conn_id.route = data.consume_string();
if (data.skip_until("hint"))
hint = data.consume_string_view();
if (data.skip_until("incoming"))
incoming = data.consume_integer<bool>();
if (data.skip_until("keep_alive"))
keep_alive = std::chrono::milliseconds{data.consume_integer<uint64_t>()};
if (data.skip_until("optional"))
optional = data.consume_integer<bool>();
if (data.skip_until("outgoing"))
outgoing = data.consume_integer<bool>();
if (data.skip_until("request"))
request = data.consume_integer<bool>();
if (request) {
if (!data.skip_until("request_callback"))
throw std::runtime_error("Internal error: received request without request_callback");
request_callback = detail::deserialize_object<ReplyCallback>(data.consume_integer<uintptr_t>());
if (!data.skip_until("request_tag"))
throw std::runtime_error("Internal error: received request without request_name");
request_tag = data.consume_string();
if (data.skip_until("request_timeout"))
request_timeout = std::chrono::milliseconds{data.consume_integer<uint64_t>()};
}
if (!data.skip_until("send"))
throw std::runtime_error("Internal error: Invalid proxy send command; send parts missing");
oxenc::bt_list_consumer send = data.consume_list_consumer();
send_option::queue_failure::callback_t callback_nosend;
if (data.skip_until("send_fail"))
callback_nosend = detail::deserialize_object<decltype(callback_nosend)>(data.consume_integer<uintptr_t>());
send_option::queue_full::callback_t callback_noqueue;
if (data.skip_until("send_full_q"))
callback_noqueue = detail::deserialize_object<decltype(callback_noqueue)>(data.consume_integer<uintptr_t>());
// Now figure out which socket to send to and do the actual sending. We can repeat this loop
// multiple times, if we're sending to a SN, because it's possible that we have multiple
// connections open to that SN (e.g. one out + one in) so if one fails we can clean up that
// connection and try the next one.
bool retry = true, sent = false, nowarn = false;
while (retry) {
retry = false;
zmq::socket_t *send_to;
if (conn_id.sn()) {
auto sock_route = proxy_connect_sn(conn_id.pk, hint, optional, incoming, outgoing, EPHEMERAL_ROUTING_ID, keep_alive);
if (!sock_route.first) {
nowarn = true;
if (optional)
OMQ_LOG(debug, "Not sending: send is optional and no connection to ",
oxenc::to_hex(conn_id.pk), " is currently established");
else
OMQ_LOG(error, "Unable to send to ", oxenc::to_hex(conn_id.pk), ": no valid connection address found");
break;
}
send_to = sock_route.first;
conn_id.route = std::move(sock_route.second);
} else if (!conn_id.route.empty()) { // incoming non-SN connection
auto it = connections.find(conn_id.id);
if (it == connections.end()) {
OMQ_LOG(warn, "Unable to send to ", conn_id, ": incoming listening socket not found");
break;
}
send_to = &it->second;
} else {
auto pr = peers.equal_range(conn_id);
if (pr.first == peers.end()) {
OMQ_LOG(warn, "Unable to send: connection id ", conn_id, " is not (or is no longer) a valid outgoing connection");
break;
}
auto& peer = pr.first->second;
auto it = connections.find(peer.conn_id);
if (it == connections.end()) {
OMQ_LOG(warn, "Unable to send: peer connection id ", conn_id, " is not (or is no longer) a valid outgoing connection");
break;
}
send_to = &it->second;
}
try {
sent = send_message_parts(*send_to, build_send_parts(send, conn_id.route));
} catch (const zmq::error_t &e) {
if (e.num() == EHOSTUNREACH && !conn_id.route.empty() /*= incoming conn*/) {
OMQ_LOG(debug, "Incoming connection is no longer valid; removing peer details");
auto pr = peers.equal_range(conn_id);
if (pr.first != peers.end()) {
if (!conn_id.sn()) {
peers.erase(pr.first);
} else {
bool removed;
for (auto it = pr.first; it != pr.second; ) {
auto& peer = it->second;
if (peer.route == conn_id.route) {
peers.erase(it);
removed = true;
break;
}
}
// The incoming connection to the SN is no longer good, but we can retry because
// we may have another active connection with the SN (or may want to open one).
if (removed) {
OMQ_LOG(debug, "Retrying sending to SN ", oxenc::to_hex(conn_id.pk), " using other sockets");
retry = true;
}
}
}
}
if (!retry) {
if (!conn_id.sn() && !conn_id.route.empty()) { // incoming non-SN connection
OMQ_LOG(debug, "Unable to send message to incoming connection ", conn_id, ": ", e.what(),
"; remote has probably disconnected");
} else {
OMQ_LOG(warn, "Unable to send message to ", conn_id, ": ", e.what());
}
nowarn = true;
if (callback_nosend) {
job([callback = std::move(callback_nosend), error = e] { callback(&error); });
callback_nosend = nullptr;
}
}
}
}
if (request) {
if (sent) {
OMQ_LOG(debug, "Added new pending request ", oxenc::to_hex(request_tag));
pending_requests.insert({ request_tag, {
std::chrono::steady_clock::now() + request_timeout, std::move(request_callback) }});
} else {
OMQ_LOG(debug, "Could not send request, scheduling request callback failure");
job([callback = std::move(request_callback)] { callback(false, {{"TIMEOUT"s}}); });
}
}
if (!sent) {
if (callback_nosend)
job([callback = std::move(callback_nosend)] { callback(nullptr); });
else if (callback_noqueue)
job(std::move(callback_noqueue));
else if (!nowarn)
OMQ_LOG(warn, "Unable to send message to ", conn_id, ": sending would block");
}
}
void OxenMQ::proxy_reply(oxenc::bt_dict_consumer data) {
bool have_conn_id = false;
ConnectionID conn_id{0};
if (data.skip_until("conn_id")) {
conn_id.id = data.consume_integer<long long>();
if (conn_id.id == -1)
throw std::runtime_error("Invalid error: invalid conn_id value (-1)");
have_conn_id = true;
}
if (data.skip_until("conn_pubkey")) {
if (have_conn_id)
throw std::runtime_error("Internal error: Invalid proxy reply command; conn_id and conn_pubkey are exclusive");
conn_id.pk = data.consume_string();
conn_id.id = ConnectionID::SN_ID;
} else if (!have_conn_id)
throw std::runtime_error("Internal error: Invalid proxy reply command; conn_pubkey or conn_id missing");
if (!data.skip_until("send"))
throw std::runtime_error("Internal error: Invalid proxy reply command; send parts missing");
oxenc::bt_list_consumer send = data.consume_list_consumer();
auto pr = peers.equal_range(conn_id);
if (pr.first == pr.second) {
OMQ_LOG(warn, "Unable to send tagged reply: the connection is no longer valid");
return;
}
// We try any connections until one works (for ordinary remotes there will be just one, but for
// SNs there might be one incoming and one outgoing).
for (auto it = pr.first; it != pr.second; ) {
try {
send_message_parts(connections[it->second.conn_id], build_send_parts(send, it->second.route));
break;
} catch (const zmq::error_t &err) {
if (err.num() == EHOSTUNREACH) {
if (it->second.outgoing()) {
OMQ_LOG(debug, "Unable to send reply to non-SN request on outgoing socket: "
"remote is no longer connected; closing connection");
proxy_close_connection(it->second.conn_id, CLOSE_LINGER);
it = peers.erase(it);
++it;
} else {
OMQ_LOG(debug, "Unable to send reply to non-SN request on incoming socket: "
"remote is no longer connected; removing peer details");
it = peers.erase(it);
}
} else {
OMQ_LOG(warn, "Unable to send reply to incoming non-SN request: ", err.what());
++it;
}
}
}
}
void OxenMQ::proxy_control_message(OxenMQ::control_message_array& parts, size_t len) {
// We throw an uncaught exception here because we only generate control messages internally in
// oxenmq code: if one of these condition fail it's a oxenmq bug.
if (len < 2)
throw std::logic_error("OxenMQ bug: Expected 2-3 message parts for a proxy control message");
auto route = view(parts[0]), cmd = view(parts[1]);
OMQ_TRACE("control message: ", cmd);
if (len == 3) {
OMQ_TRACE("...: ", parts[2]);
auto data = view(parts[2]);
if (cmd == "SEND") {
OMQ_TRACE("proxying message");
return proxy_send(data);
} else if (cmd == "REPLY") {
OMQ_TRACE("proxying reply to non-SN incoming message");
return proxy_reply(data);
} else if (cmd == "BATCH") {
OMQ_TRACE("proxy batch jobs");
auto ptrval = oxenc::bt_deserialize<uintptr_t>(data);
return proxy_batch(reinterpret_cast<detail::Batch*>(ptrval));
} else if (cmd == "INJECT") {
OMQ_TRACE("proxy inject");
return proxy_inject_task(detail::deserialize_object<injected_task>(oxenc::bt_deserialize<uintptr_t>(data)));
} else if (cmd == "SET_SNS") {
return proxy_set_active_sns(data);
} else if (cmd == "UPDATE_SNS") {
return proxy_update_active_sns(data);
} else if (cmd == "CONNECT_SN") {
proxy_connect_sn(data);
return;
} else if (cmd == "CONNECT_REMOTE") {
return proxy_connect_remote(data);
} else if (cmd == "DISCONNECT") {
return proxy_disconnect(data);
} else if (cmd == "TIMER") {
return proxy_timer(data);
} else if (cmd == "TIMER_DEL") {
return proxy_timer_del(oxenc::bt_deserialize<int>(data));
} else if (cmd == "BIND") {
auto b = detail::deserialize_object<bind_data>(oxenc::bt_deserialize<uintptr_t>(data));
if (proxy_bind(b, bind.size()))
bind.push_back(std::move(b));
return;
}
} else if (len == 2) {
if (cmd == "START") {
// Command send by the owning thread during startup; we send back a simple READY reply to
// let it know we are running.
return route_control(command, route, "READY");
} else if (cmd == "QUIT") {
// Asked to quit: set max_workers to zero and tell any idle ones to quit. We will
// close workers as they come back to READY status, and then close external
// connections once all workers are done.
max_workers = 0;
for (size_t i = 0; i < idle_worker_count; i++)
route_control(workers_socket, workers[idle_workers[i]].worker_routing_id, "QUIT");
idle_worker_count = 0;
for (auto& [run, busy, queue] : tagged_workers)
if (!busy)
route_control(workers_socket, run.worker_routing_id, "QUIT");
return;
}
}
throw std::runtime_error("OxenMQ bug: Proxy received invalid control command: " +
std::string{cmd} + " (" + std::to_string(len) + ")");
}
bool OxenMQ::proxy_bind(bind_data& b, size_t bind_index) {
zmq::socket_t listener{context, zmq::socket_type::router};
setup_incoming_socket(listener, b.curve, pubkey, privkey, bind_index);
bool good = true;
try {
listener.bind(b.address);
} catch (const zmq::error_t&) {
good = false;
}
if (b.on_bind) {
b.on_bind(good);
b.on_bind = nullptr;
}
if (!good) {
OMQ_LOG(warn, "OxenMQ failed to listen on ", b.address);
return false;
}
OMQ_LOG(info, "OxenMQ listening on ", b.address);
b.conn_id = next_conn_id++;
connections.emplace_hint(connections.end(), b.conn_id, std::move(listener));
connections_updated = true;
return true;
}
void OxenMQ::proxy_loop_init() {
#if defined(__linux__) || defined(__sun) || defined(__MINGW32__)
pthread_setname_np(pthread_self(), "omq-proxy");
#elif defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__)
pthread_set_name_np(pthread_self(), "omq-proxy");
#elif defined(__MACH__)
pthread_setname_np("omq-proxy");
#endif
zap_auth = zmq::socket_t{context, zmq::socket_type::rep};
zap_auth.set(zmq::sockopt::linger, 0);
zap_auth.bind(ZMQ_ADDR_ZAP);
workers_socket = zmq::socket_t{context, zmq::socket_type::router};
workers_socket.set(zmq::sockopt::router_mandatory, true);
workers_socket.bind(SN_ADDR_WORKERS);
workers.reserve(max_workers);
idle_workers.resize(max_workers);
if (!workers.empty() || !worker_sockets.empty())
throw std::logic_error("Internal error: proxy thread started with active worker threads");
worker_sockets.reserve(max_workers);
// Pre-initialize these worker sockets rather than creating during thread initialization so that
// we can't hit the zmq socket limit during worker thread startup.
for (int i = 0; i < max_workers; i++)
worker_sockets.emplace_back(context, zmq::socket_type::dealer);
#ifndef _WIN32
int saved_umask = -1;
if (STARTUP_UMASK >= 0)
saved_umask = umask(STARTUP_UMASK);
#endif
{
zmq::socket_t inproc_listener{context, zmq::socket_type::router};
inproc_listener.bind(SN_ADDR_SELF);
inproc_listener_connid = next_conn_id++;
connections.emplace_hint(connections.end(), inproc_listener_connid, std::move(inproc_listener));
}
for (size_t i = 0; i < bind.size(); i++) {
if (!proxy_bind(bind[i], i)) {
OMQ_LOG(fatal, "OxenMQ failed to listen on ", bind[i].address);
throw zmq::error_t{};
}
}
#ifndef _WIN32
if (saved_umask != -1)
umask(saved_umask);
// set socket gid / uid if it is provided
if (SOCKET_GID != -1 or SOCKET_UID != -1) {
for (auto& b : bind) {
const address addr(b.address);
if (addr.ipc())
if (chown(addr.socket.c_str(), SOCKET_UID, SOCKET_GID) == -1)
throw std::runtime_error("cannot set group on " + addr.socket + ": " + strerror(errno));
}
}
#endif
connections_updated = true;
// Also add an internal connection to self so that calling code can avoid needing to
// special-case rare situations where we are supposed to talk to a quorum member that happens to
// be ourselves (which can happen, for example, with cross-quoum Blink communication)
// FIXME: not working
//listener.bind(SN_ADDR_SELF);
if (!timers)
timers.reset(zmq_timers_new());
if (-1 == zmq_timers_add(timers.get(),
std::chrono::milliseconds{CONN_CHECK_INTERVAL}.count(),
[](int /*timer_id*/, void* self) { static_cast<OxenMQ*>(self)->proxy_conn_cleanup(); },
this)) {
throw zmq::error_t{};
}
// Wait for tagged worker threads to get ready and connect to us (we get a "STARTING" message)
// and send them back a "START" to let them know to go ahead with startup. We need this
// synchronization dance to guarantee that the workers are routable before we can proceed.
if (!tagged_workers.empty()) {
OMQ_LOG(debug, "Waiting for tagged workers");
{
std::unique_lock lock{tagged_startup_mutex};
tagged_go = tagged_go_mode::GO;
}
tagged_cv.notify_all();
std::unordered_set<std::string_view> waiting_on;
for (auto& w : tagged_workers)
waiting_on.emplace(std::get<run_info>(w).worker_routing_id);
for (std::vector<zmq::message_t> parts; !waiting_on.empty(); parts.clear()) {
recv_message_parts(workers_socket, parts);
if (parts.size() != 2 || view(parts[1]) != "STARTING"sv) {
OMQ_LOG(error, "Received invalid message on worker socket while waiting for tagged thread startup");
continue;
}
OMQ_LOG(debug, "Received STARTING message from ", view(parts[0]));
if (auto it = waiting_on.find(view(parts[0])); it != waiting_on.end())
waiting_on.erase(it);
else
OMQ_LOG(error, "Received STARTING message from unknown worker ", view(parts[0]));
}
for (auto&w : tagged_workers) {
OMQ_LOG(debug, "Telling tagged thread worker ", std::get<run_info>(w).worker_routing_name, " to finish startup");
route_control(workers_socket, std::get<run_info>(w).worker_routing_id, "START");
}
}
}
void OxenMQ::proxy_loop(std::promise<void> startup) {
try {
proxy_loop_init();
} catch (...) {
startup.set_exception(std::current_exception());
return;
}
startup.set_value();
// Fixed array used for worker and control messages: these are never longer than 3 parts:
std::array<zmq::message_t, 3> control_parts;
// General vector for handling incoming messages:
std::vector<zmq::message_t> parts;
std::vector<std::pair<const int64_t, zmq::socket_t>*> queue; // Used as a circular buffer
#ifdef OXENMQ_USE_EPOLL
std::vector<struct epoll_event> evs;
#endif
while (true) {
std::chrono::milliseconds poll_timeout;
if (max_workers == 0) { // Will be 0 only if we are quitting
if (std::none_of(workers.begin(), workers.end(), [](auto &w) { return w.worker_thread.joinable(); }) &&
std::none_of(tagged_workers.begin(), tagged_workers.end(), [](auto &w) { return std::get<0>(w).worker_thread.joinable(); })) {
// All the workers have finished, so we can finish shutting down
return proxy_quit();
}
poll_timeout = 1s; // We don't keep running timers when we're quitting, so don't have a timer to check
} else {
poll_timeout = std::chrono::milliseconds{zmq_timers_timeout(timers.get())};
}
if (connections_updated) {
rebuild_pollitems();
// If we just rebuilt the queue then do a full check of everything, because we might
// have sockets that already edge-triggered that we need to fully drain before we start
// polling.
proxy_skip_one_poll = true;
}
// We round-robin connections when pulling off pending messages one-by-one rather than
// pulling off all messages from one connection before moving to the next; thus in cases of
// contention we end up fairly distributing.
queue.reserve(connections.size() + 1);
#ifdef OXENMQ_USE_EPOLL
bool process_command = false, process_worker = false, process_zap = false, process_all = false;
if (proxy_skip_one_poll) {
proxy_skip_one_poll = false;
process_command = command.get(zmq::sockopt::events) & ZMQ_POLLIN;
process_worker = workers_socket.get(zmq::sockopt::events) & ZMQ_POLLIN;
process_zap = zap_auth.get(zmq::sockopt::events) & ZMQ_POLLIN;
process_all = true;
}
else {
OMQ_TRACE("polling for new messages via epoll");
evs.resize(3 + connections.size());
const int max = epoll_wait(epoll_fd, evs.data(), evs.size(), poll_timeout.count());
queue.clear();
for (int i = 0; i < max; i++) {
const auto conn_id = evs[i].data.u64;
if (conn_id == EPOLL_COMMAND_ID)
process_command = true;
else if (conn_id == EPOLL_WORKER_ID)
process_worker = true;
else if (conn_id == EPOLL_ZAP_ID)
process_zap = true;
else if (auto it = connections.find(conn_id); it != connections.end())
queue.push_back(&*it);
}
queue.push_back(nullptr);
}
#else
if (proxy_skip_one_poll)
proxy_skip_one_poll = false;
else {
OMQ_TRACE("polling for new messages");
// We poll the control socket and worker socket for any incoming messages. If we have
// available worker room then also poll incoming connections and outgoing connections
// for messages to forward to a worker. Otherwise, we just look for a control message
// or a worker coming back with a ready message.
zmq::poll(pollitems.data(), pollitems.size(), poll_timeout);
}
constexpr bool process_command = true, process_worker = true, process_zap = true, process_all = true;
#endif
if (process_command) {
OMQ_TRACE("processing control messages");
while (size_t len = recv_message_parts(command, control_parts, zmq::recv_flags::dontwait))
proxy_control_message(control_parts, len);
}
if (process_worker) {
OMQ_TRACE("processing worker messages");
while (size_t len = recv_message_parts(workers_socket, control_parts, zmq::recv_flags::dontwait))
proxy_worker_message(control_parts, len);
}
OMQ_TRACE("processing timers");
zmq_timers_execute(timers.get());
if (process_zap) {
// Handle any zap authentication
OMQ_TRACE("processing zap requests");
process_zap_requests();
}
// See if we can drain anything from the current queue before we potentially add to it
// below.
OMQ_TRACE("processing queued jobs and messages");
proxy_process_queue();
OMQ_TRACE("processing new incoming messages");
if (process_all) {
queue.clear();
for (auto& id_sock : connections)
if (id_sock.second.get(zmq::sockopt::events) & ZMQ_POLLIN)
queue.push_back(&id_sock);
queue.push_back(nullptr);
}
size_t end = queue.size() - 1;
for (size_t pos = 0; pos != end; ++pos %= queue.size()) {
parts.clear();
auto& [id, sock] = *queue[pos];
if (!recv_message_parts(sock, parts, zmq::recv_flags::dontwait))
continue;
// We only pull this one message now but then requeue the socket so that after we check
// all other sockets we come back to this one to check again.
queue[end] = queue[pos];
++end %= queue.size();
if (parts.empty()) {
OMQ_LOG(warn, "Ignoring empty (0-part) incoming message");
continue;
}
if (!proxy_handle_builtin(id, sock, parts))
proxy_to_worker(id, sock, parts);
if (connections_updated) {
// If connections got updated then our points are stale, so restart the proxy loop;
// we'll immediately end up right back here at least once before we resume polling.
OMQ_TRACE("connections became stale; short-circuiting incoming message loop");
break;
}
}
#ifdef OXENMQ_USE_EPOLL
// If any socket still has ZMQ_POLLIN (which is possible if something we did above changed
// state on another socket, perhaps by writing to it) then we need to repeat the loop
// *without* going back to epoll again, until we get through everything without any
// ZMQ_POLLIN sockets. If we didn't, we could miss it and might end up deadlocked because
// of ZMQ's edge-triggered notifications on zmq fd's.
//
// More info on the complexities here at https://github.com/zeromq/libzmq/issues/3641 and
// https://funcptr.net/2012/09/10/zeromq---edge-triggered-notification/
if (!connections_updated && !proxy_skip_one_poll) {
for (auto* s : {&command, &workers_socket, &zap_auth}) {
if (s->get(zmq::sockopt::events) & ZMQ_POLLIN) {
proxy_skip_one_poll = true;
break;
}
}
if (!proxy_skip_one_poll) {
for (auto& [id, sock] : connections) {
if (sock.get(zmq::sockopt::events) & ZMQ_POLLIN) {
proxy_skip_one_poll = true;
break;
}
}
}
}
#endif
OMQ_TRACE("done proxy loop");
}
}
static bool is_error_response(std::string_view cmd) {
return cmd == "FORBIDDEN" || cmd == "FORBIDDEN_SN" || cmd == "NOT_A_SERVICE_NODE" || cmd == "UNKNOWNCOMMAND" || cmd == "NO_REPLY_TAG";
}
// Return true if we recognized/handled the builtin command (even if we reject it for whatever
// reason)
bool OxenMQ::proxy_handle_builtin(int64_t conn_id, zmq::socket_t& sock, std::vector<zmq::message_t>& parts) {
// Doubling as a bool and an offset:
size_t incoming = sock.get(zmq::sockopt::type) == ZMQ_ROUTER;
std::string_view route, cmd;
if (parts.size() < 1 + incoming) {
OMQ_LOG(warn, "Received empty message; ignoring");
return true;
}
if (incoming) {
route = view(parts[0]);
cmd = view(parts[1]);
} else {
cmd = view(parts[0]);
}
OMQ_TRACE("Checking for builtins: '", cmd, "' from ", peer_address(parts.back()));
if (cmd == "REPLY") {
size_t tag_pos = 1 + incoming;
if (parts.size() <= tag_pos) {
OMQ_LOG(warn, "Received REPLY without a reply tag; ignoring");
return true;
}
std::string reply_tag{view(parts[tag_pos])};
auto it = pending_requests.find(reply_tag);
if (it != pending_requests.end()) {
OMQ_LOG(debug, "Received REPLY for pending command ", oxenc::to_hex(reply_tag), "; scheduling callback");
std::vector<std::string> data;
data.reserve(parts.size() - (tag_pos + 1));
for (auto it = parts.begin() + (tag_pos + 1); it != parts.end(); ++it)
data.emplace_back(view(*it));
proxy_schedule_reply_job([callback=std::move(it->second.second), data=std::move(data)] {
callback(true, std::move(data));
});
pending_requests.erase(it);
} else {
OMQ_LOG(warn, "Received REPLY with unknown or already handled reply tag (", oxenc::to_hex(reply_tag), "); ignoring");
}
return true;
} else if (cmd == "HI") {
if (!incoming) {
OMQ_LOG(warn, "Got invalid 'HI' message on an outgoing connection; ignoring");
return true;
}
OMQ_LOG(debug, "Incoming client from ", peer_address(parts.back()), " sent HI, replying with HELLO");
try {
send_routed_message(sock, std::string{route}, "HELLO");
} catch (const std::exception &e) { OMQ_LOG(warn, "Couldn't reply with HELLO: ", e.what()); }
return true;
} else if (cmd == "HELLO") {
if (incoming) {
OMQ_LOG(warn, "Got invalid 'HELLO' message on an incoming connection; ignoring");
return true;
}
auto it = std::find_if(pending_connects.begin(), pending_connects.end(),
[&](auto& pc) { return std::get<int64_t>(pc) == conn_id; });
if (it == pending_connects.end()) {
OMQ_LOG(warn, "Got invalid 'HELLO' message on an already handshaked incoming connection; ignoring");
return true;
}
auto& pc = *it;
auto pit = peers.find(std::get<int64_t>(pc));
if (pit == peers.end()) {
OMQ_LOG(warn, "Got invalid 'HELLO' message with invalid conn_id; ignoring");
return true;
}
OMQ_LOG(debug, "Got initial HELLO server response from ", peer_address(parts.back()));
proxy_schedule_reply_job([on_success=std::move(std::get<ConnectSuccess>(pc)),
conn=pit->first] {
on_success(conn);
});
pending_connects.erase(it);
return true;
} else if (cmd == "BYE") {
if (!incoming) {
OMQ_LOG(debug, "BYE command received; disconnecting from ", peer_address(parts.back()));
proxy_close_connection(conn_id, 0s);
} else {
OMQ_LOG(warn, "Got invalid 'BYE' command on an incoming socket; ignoring");
}
return true;
}
else if (is_error_response(cmd)) {
// These messages (FORBIDDEN, UNKNOWNCOMMAND, etc.) are sent in response to us trying to
// invoke something that doesn't exist or we don't have permission to access. These have
// two forms (the latter is only sent by remotes running 1.1.0+).
// - ["XXX", "whatever.command"]
// - ["XXX", "REPLY", replytag]
// (ignoring the routing prefix on incoming commands).
// For the former, we log; for the latter we trigger the reply callback with a failure
if (parts.size() == (1 + incoming) && cmd == "UNKNOWNCOMMAND") {
// pre-1.1.0 sent just a plain UNKNOWNCOMMAND (without the actual command); this was not
// useful, but also this response is *expected* for things 1.0.5 didn't understand, like
// FORBIDDEN_SN: so log it only at debug level and move on.
OMQ_LOG(debug, "Received plain UNKNOWNCOMMAND; remote is probably an older oxenmq. Ignoring.");
return true;
}
if (parts.size() == (3 + incoming) && view(parts[1 + incoming]) == "REPLY") {
std::string reply_tag{view(parts[2 + incoming])};
auto it = pending_requests.find(reply_tag);
if (it != pending_requests.end()) {
OMQ_LOG(debug, "Received ", cmd, " REPLY for pending command ", oxenc::to_hex(reply_tag), "; scheduling failure callback");
proxy_schedule_reply_job([callback=std::move(it->second.second), cmd=std::string{cmd}] {
callback(false, {{std::move(cmd)}});
});
pending_requests.erase(it);
} else {
OMQ_LOG(warn, "Received REPLY with unknown or already handled reply tag (", oxenc::to_hex(reply_tag), "); ignoring");
}
} else {
OMQ_LOG(warn, "Received ", cmd, ':', (parts.size() > 1 + incoming ? view(parts[1 + incoming]) : "(unknown command)"sv),
" from ", peer_address(parts.back()));
}
return true;
}
return false;
}
void OxenMQ::proxy_process_queue() {
if (max_workers == 0) // shutting down
return;
// First: send any tagged thread tasks to the tagged threads, if idle
for (auto& [run, busy, queue] : tagged_workers) {
if (!busy && !queue.empty()) {
busy = true;
proxy_run_worker(run.load(std::move(queue.front()), false, run.worker_id));
queue.pop_front();
}
}
// Second: process any batch jobs; since these are internal they are given higher priority.
proxy_run_batch_jobs(batch_jobs, batch_jobs_reserved, batch_jobs_active, false);
// Next any reply batch jobs (which are a bit different from the above, since they are
// externally triggered but for things we initiated locally).
proxy_run_batch_jobs(reply_jobs, reply_jobs_reserved, reply_jobs_active, true);
// Finally general incoming commands
for (auto it = pending_commands.begin(); it != pending_commands.end() && active_workers() < max_workers; ) {
auto& pending = *it;
if (pending.cat.active_threads < pending.cat.reserved_threads
|| active_workers() < general_workers) {
proxy_run_worker(get_idle_worker().load(std::move(pending)));
pending.cat.queued--;
pending.cat.active_threads++;
assert(pending.cat.queued >= 0);
it = pending_commands.erase(it);
} else {
++it; // no available general or reserved worker spots for this job right now
}
}
}
}

View file

@ -1,172 +0,0 @@
#pragma once
#include "connections.h"
#include <chrono>
#include <mutex>
#include <shared_mutex>
#include <type_traits>
#include <unordered_map>
#include <vector>
#include <optional>
namespace oxenmq {
using namespace std::chrono_literals;
namespace detail {
struct no_data_t final {};
inline constexpr no_data_t no_data{};
template <typename UserData>
struct SubData {
std::chrono::steady_clock::time_point expiry;
UserData user_data;
explicit SubData(std::chrono::steady_clock::time_point _exp)
: expiry{_exp}, user_data{} {}
};
template <>
struct SubData<void> {
std::chrono::steady_clock::time_point expiry;
};
}
/**
* OMQ Subscription class. Handles pub/sub connections such that the user only needs to call
* methods to subscribe and publish.
*
* FIXME: do we want an unsubscribe, or is expiry / conn management sufficient?
*
* Type UserData can contain whatever information the user may need at publish time, for example if
* the subscription is for logs the subscriber can specify log levels or categories, and the
* publisher can choose to send or not based on those. The UserData type, if provided and non-void,
* must be default constructible, must be comparable with ==, and must be movable.
*/
template <typename UserData = void>
class Subscription {
static constexpr bool have_user_data = !std::is_void_v<UserData>;
using UserData_if_present = std::conditional_t<have_user_data, UserData, detail::no_data_t>;
using subdata_t = detail::SubData<UserData>;
std::unordered_map<ConnectionID, subdata_t> subs;
std::shared_mutex _mutex;
const std::string description; // description of the sub for logging
const std::chrono::milliseconds sub_duration; // extended by re-subscribe
public:
Subscription() = delete;
Subscription(std::string description, std::chrono::milliseconds sub_duration = 30min)
: description{std::move(description)}, sub_duration{sub_duration} {}
// returns true if new sub, false if refresh sub. throws on error. `data` will be checked
// against the existing data: if there is existing data and it compares `==` to the given value,
// false is returned (and the existing data is not replaced). Otherwise the given data gets
// stored for this connection (replacing existing data, if present), and true is returned.
bool subscribe(const ConnectionID& conn, UserData_if_present data) {
std::unique_lock lock{_mutex};
auto expiry = std::chrono::steady_clock::now() + sub_duration;
auto [value, added] = subs.emplace(conn, subdata_t{expiry});
if (added) {
if constexpr (have_user_data)
value->second.user_data = std::move(data);
return true;
}
value->second.expiry = expiry;
if constexpr (have_user_data) {
// if user_data changed, consider it a new sub rather than refresh, and update
// user_data in the mapped value.
if (!(value->second.user_data == data)) {
value->second.user_data = std::move(data);
return true;
}
}
return false;
}
// no-user-data version, only available for Subscription<void> (== Subscription without a
// UserData type).
template <bool enable = !have_user_data, std::enable_if_t<enable, int> = 0>
bool subscribe(const ConnectionID& conn) {
return subscribe(conn, detail::no_data);
}
// unsubscribe a connection ID. return the user data, if a sub was present.
template <bool enable = have_user_data, std::enable_if_t<enable, int> = 0>
std::optional<UserData> unsubscribe(const ConnectionID& conn) {
std::unique_lock lock{_mutex};
auto node = subs.extract(conn);
if (!node.empty())
return node.mapped().user_data;
return std::nullopt;
}
// no-user-data version, only available for Subscription<void> (== Subscription without a
// UserData type).
template <bool enable = !have_user_data, std::enable_if_t<enable, int> = 0>
bool unsubscribe(const ConnectionID& conn) {
std::unique_lock lock{_mutex};
auto node = subs.extract(conn);
return !node.empty(); // true if removed, false if wasn't present
}
// force removal of expired subscriptions. removal will otherwise only happen on publish
void remove_expired() {
std::unique_lock lock{_mutex};
auto now = std::chrono::steady_clock::now();
for (auto itr = subs.begin(); itr != subs.end();) {
if (itr->second.expiry < now)
itr = subs.erase(itr);
else
itr++;
}
}
// Func is any callable which takes:
// - (const ConnectionID&, const UserData&) for Subscription<UserData> with non-void UserData
// - (const ConnectionID&) for Subscription<void>.
template <typename Func>
void publish(Func&& func) {
std::vector<ConnectionID> to_remove;
{
std::shared_lock lock(_mutex);
if (subs.empty())
return;
auto now = std::chrono::steady_clock::now();
for (const auto& [conn, sub] : subs) {
if (sub.expiry < now)
to_remove.push_back(conn);
else if constexpr (have_user_data)
func(conn, sub.user_data);
else
func(conn);
}
}
if (to_remove.empty())
return;
std::unique_lock lock{_mutex};
auto now = std::chrono::steady_clock::now();
for (auto& conn : to_remove) {
auto it = subs.find(conn);
if (it != subs.end() && it->second.expiry < now /* recheck: client might have resubscribed in between locks */) {
subs.erase(it);
}
}
}
};
} // namespace oxenmq
// vim:sw=4:et

View file

@ -1,5 +0,0 @@
namespace oxenmq {
constexpr int VERSION_MAJOR = @PROJECT_VERSION_MAJOR@;
constexpr int VERSION_MINOR = @PROJECT_VERSION_MINOR@;
constexpr int VERSION_PATCH = @PROJECT_VERSION_PATCH@;
}

View file

@ -1,431 +0,0 @@
#include "oxenmq.h"
#include "batch.h"
#include "oxenmq-internal.h"
#if defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__)
extern "C" {
#include <pthread.h>
#include <pthread_np.h>
}
#endif
#include <oxenc/variant.h>
namespace oxenmq {
namespace {
// Waits for a specific command or "QUIT" on the given socket. Returns true if the command was
// received. If "QUIT" was received, replies with "QUITTING" on the socket and closes it, then
// returns false.
[[gnu::always_inline]] inline
bool worker_wait_for(OxenMQ& omq, zmq::socket_t& sock, std::vector<zmq::message_t>& parts, const std::string_view worker_id, const std::string_view expect) {
while (true) {
omq.log(LogLevel::trace, __FILE__, __LINE__, "worker ", worker_id, " waiting for ", expect);
parts.clear();
recv_message_parts(sock, parts);
if (parts.size() != 1) {
omq.log(LogLevel::error, __FILE__, __LINE__, "Internal error: worker ", worker_id, " received invalid ", parts.size(), "-part control msg");
continue;
}
auto command = view(parts[0]);
if (command == expect) {
#ifndef NDEBUG
omq.log(LogLevel::trace, __FILE__, __LINE__, "Worker ", worker_id, " received waited-for ", expect, " command");
#endif
return true;
} else if (command == "QUIT"sv) {
omq.log(LogLevel::debug, __FILE__, __LINE__, "Worker ", worker_id, " received QUIT command, shutting down");
detail::send_control(sock, "QUITTING");
sock.set(zmq::sockopt::linger, 1000);
sock.close();
return false;
} else {
omq.log(LogLevel::error, __FILE__, __LINE__, "Internal error: worker ", worker_id, " received invalid command: `", command, "'");
}
}
}
}
void OxenMQ::worker_thread(unsigned int index, std::optional<std::string> tagged, std::function<void()> start) {
std::string routing_id = (tagged ? "t" : "w") +
std::string(reinterpret_cast<const char*>(&index), sizeof(index)); // for routing
std::string worker_id{tagged ? *tagged : "w" + std::to_string(index)}; // for debug
[[maybe_unused]] std::string thread_name = tagged.value_or("omq-" + worker_id);
#if defined(__linux__) || defined(__sun) || defined(__MINGW32__)
if (thread_name.size() > 15) thread_name.resize(15);
pthread_setname_np(pthread_self(), thread_name.c_str());
#elif defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__)
pthread_set_name_np(pthread_self(), thread_name.c_str());
#elif defined(__MACH__)
pthread_setname_np(thread_name.c_str());
#endif
std::optional<zmq::socket_t> tagged_socket;
if (tagged) {
// If we're a tagged worker then we got started up before OxenMQ started, so we need to wait
// for an all-clear signal from OxenMQ first, then we fire our `start` callback, then we can
// start waiting for commands in the main loop further down. (We also can't get the
// reference to our `tagged_workers` element or create a socket until the main proxy thread
// is running).
{
std::unique_lock lock{tagged_startup_mutex};
tagged_cv.wait(lock, [this] { return tagged_go != tagged_go_mode::WAIT; });
}
if (tagged_go == tagged_go_mode::SHUTDOWN) // OxenMQ destroyed without starting
return;
tagged_socket.emplace(context, zmq::socket_type::dealer);
}
auto& sock = tagged ? *tagged_socket : worker_sockets[index];
sock.set(zmq::sockopt::routing_id, routing_id);
OMQ_LOG(debug, "New worker thread ", worker_id, " (", routing_id, ") started");
sock.connect(SN_ADDR_WORKERS);
if (tagged)
detail::send_control(sock, "STARTING");
Message message{*this, 0, AuthLevel::none, ""s};
std::vector<zmq::message_t> parts;
bool waiting_for_command;
if (tagged) {
waiting_for_command = true;
if (!worker_wait_for(*this, sock, parts, worker_id, "START"sv))
return;
if (start) start();
} else {
// Otherwise for a regular worker we can only be started by an active main proxy thread
// which will have preloaded our first job so we can start off right away.
waiting_for_command = false;
}
// This will always contains the current job, and is guaranteed to never be invalidated.
run_info& run = tagged ? std::get<run_info>(tagged_workers[index - 1]) : workers[index];
while (true) {
if (waiting_for_command) {
if (!worker_wait_for(*this, sock, parts, worker_id, "RUN"sv))
return;
}
try {
if (run.is_batch_job) {
auto* batch = var::get<detail::Batch*>(run.to_run);
if (run.batch_jobno >= 0) {
OMQ_TRACE("worker thread ", worker_id, " running batch ", batch, "#", run.batch_jobno);
batch->run_job(run.batch_jobno);
} else if (run.batch_jobno == -1) {
OMQ_TRACE("worker thread ", worker_id, " running batch ", batch, " completion");
batch->job_completion();
}
} else if (run.is_injected) {
auto& func = var::get<std::function<void()>>(run.to_run);
OMQ_TRACE("worker thread ", worker_id, " invoking injected command ", run.command);
func();
func = nullptr;
} else {
message.conn = run.conn;
message.access = run.access;
message.remote = std::move(run.remote);
message.data.clear();
OMQ_TRACE("Got incoming command from ", message.remote, "/", message.conn, message.conn.route.empty() ? " (outgoing)" : " (incoming)");
auto& [callback, is_request] = *var::get<const std::pair<CommandCallback, bool>*>(run.to_run);
if (is_request) {
message.reply_tag = {run.data_parts[0].data<char>(), run.data_parts[0].size()};
for (auto it = run.data_parts.begin() + 1; it != run.data_parts.end(); ++it)
message.data.emplace_back(it->data<char>(), it->size());
} else {
for (auto& m : run.data_parts)
message.data.emplace_back(m.data<char>(), m.size());
}
OMQ_TRACE("worker thread ", worker_id, " invoking ", run.command, " callback with ", message.data.size(), " message parts");
callback(message);
}
}
catch (const oxenc::bt_deserialize_invalid& e) {
OMQ_LOG(warn, worker_id, " deserialization failed: ", e.what(), "; ignoring request");
}
#ifndef BROKEN_APPLE_VARIANT
catch (const std::bad_variant_access& e) {
OMQ_LOG(warn, worker_id, " deserialization failed: found unexpected serialized type (", e.what(), "); ignoring request");
}
#endif
catch (const std::out_of_range& e) {
OMQ_LOG(warn, worker_id, " deserialization failed: invalid data - required field missing (", e.what(), "); ignoring request");
}
catch (const std::exception& e) {
OMQ_LOG(warn, worker_id, " caught exception when processing command: ", e.what());
}
catch (...) {
OMQ_LOG(warn, worker_id, " caught non-standard exception when processing command");
}
// Tell the proxy thread that we are ready for another job
detail::send_control(sock, "RAN");
waiting_for_command = true;
}
}
OxenMQ::run_info& OxenMQ::get_idle_worker() {
if (idle_worker_count == 0) {
uint32_t id = workers.size();
workers.emplace_back();
auto& r = workers.back();
r.worker_id = id;
r.worker_routing_id = "w" + std::string(reinterpret_cast<const char*>(&id), sizeof(id));
r.worker_routing_name = "w" + std::to_string(id);
return r;
}
size_t id = idle_workers[--idle_worker_count];
return workers[id];
}
void OxenMQ::proxy_worker_message(OxenMQ::control_message_array& parts, size_t len) {
// Process messages sent by workers
if (len != 2) {
OMQ_LOG(error, "Received send invalid ", len, "-part message");
return;
}
auto route = view(parts[0]), cmd = view(parts[1]);
if (route.size() != 5 || (route[0] != 'w' && route[0] != 't')) {
OMQ_LOG(error, "Received malformed worker id in worker message; unable to process worker command");
return;
}
bool tagged_worker = route[0] == 't';
uint32_t worker_id;
std::memcpy(&worker_id, route.data() + 1, 4);
if (tagged_worker
? 0 == worker_id || worker_id > tagged_workers.size() // tagged worker ids are indexed from 1 to N (0 means untagged)
: worker_id >= workers.size()) { // regular worker ids are indexed from 0 to N-1
OMQ_LOG(error, "Received invalid worker id w" + std::to_string(worker_id) + " in worker message; unable to process worker command");
return;
}
auto& run = tagged_worker ? std::get<run_info>(tagged_workers[worker_id - 1]) : workers[worker_id];
OMQ_TRACE("received ", cmd, " command from ", route);
if (cmd == "RAN"sv) {
OMQ_TRACE("Worker ", route, " finished ", run.is_batch_job ? "batch job" : run.command);
if (run.is_batch_job) {
if (tagged_worker) {
std::get<bool>(tagged_workers[worker_id - 1]) = false;
} else {
auto& active = run.is_reply_job ? reply_jobs_active : batch_jobs_active;
assert(active > 0);
active--;
}
bool clear_job = false;
auto* batch = var::get<detail::Batch*>(run.to_run);
if (run.batch_jobno == -1) {
// Returned from the completion function
clear_job = true;
} else {
auto [state, thread] = batch->job_finished();
if (state == detail::BatchState::complete) {
if (thread == -1) { // run directly in proxy
OMQ_TRACE("Completion job running directly in proxy");
try {
batch->job_completion(); // RUN DIRECTLY IN PROXY THREAD
} catch (const std::exception &e) {
// Raise these to error levels: the caller really shouldn't be doing
// anything non-trivial in an in-proxy completion function!
OMQ_LOG(error, "proxy thread caught exception when processing in-proxy completion command: ", e.what());
} catch (...) {
OMQ_LOG(error, "proxy thread caught non-standard exception when processing in-proxy completion command");
}
clear_job = true;
} else {
auto& jobs =
thread > 0
? std::get<batch_queue>(tagged_workers[thread - 1]) // run in tagged thread
: run.is_reply_job
? reply_jobs
: batch_jobs;
jobs.emplace_back(batch, -1);
}
} else if (state == detail::BatchState::done) {
// No completion job
clear_job = true;
}
// else the job is still running
}
if (clear_job) {
delete batch;
}
} else {
assert(run.cat->active_threads > 0);
run.cat->active_threads--;
}
if (max_workers == 0) { // Shutting down
OMQ_TRACE("Telling worker ", route, " to quit");
route_control(workers_socket, route, "QUIT");
} else if (!tagged_worker) {
idle_workers[idle_worker_count++] = worker_id;
}
} else if (cmd == "QUITTING"sv) {
run.worker_thread.join();
OMQ_LOG(debug, "Worker ", route, " exited normally");
} else {
OMQ_LOG(error, "Worker ", route, " sent unknown control message: `", cmd, "'");
}
}
void OxenMQ::proxy_run_worker(run_info& run) {
if (!run.worker_thread.joinable())
run.worker_thread = std::thread{&OxenMQ::worker_thread, this, run.worker_id, std::nullopt, nullptr};
else
send_routed_message(workers_socket, run.worker_routing_id, "RUN");
}
void OxenMQ::proxy_to_worker(int64_t conn_id, zmq::socket_t& sock, std::vector<zmq::message_t>& parts) {
bool outgoing = sock.get(zmq::sockopt::type) == ZMQ_DEALER;
peer_info tmp_peer;
tmp_peer.conn_id = conn_id;
if (!outgoing) tmp_peer.route = parts[0].to_string();
peer_info* peer = nullptr;
if (outgoing) {
auto snit = outgoing_sn_conns.find(conn_id);
auto it = snit != outgoing_sn_conns.end()
? peers.find(snit->second)
: peers.find(conn_id);
if (it == peers.end()) {
OMQ_LOG(warn, "Internal error: connection id ", conn_id, " not found");
return;
}
peer = &it->second;
} else if (conn_id == inproc_listener_connid) {
tmp_peer.auth_level = AuthLevel::admin;
tmp_peer.pubkey = pubkey;
tmp_peer.service_node = active_service_nodes.count(pubkey);
peer = &tmp_peer;
} else {
std::tie(tmp_peer.pubkey, tmp_peer.auth_level) = detail::extract_metadata(parts.back());
tmp_peer.service_node = tmp_peer.pubkey.size() == 32 && active_service_nodes.count(tmp_peer.pubkey);
if (tmp_peer.service_node) {
// It's a service node so we should have a peer_info entry; see if we can find one with
// the same route, and if not, add one.
auto pr = peers.equal_range(tmp_peer.pubkey);
for (auto it = pr.first; it != pr.second; ++it) {
if (it->second.conn_id == tmp_peer.conn_id && it->second.route == tmp_peer.route) {
peer = &it->second;
// Update the stored auth level just in case the peer reconnected
peer->auth_level = tmp_peer.auth_level;
break;
}
}
if (!peer) {
// We don't have a record: this is either a new SN connection or a new message on a
// connection that recently gained SN status.
peer = &peers.emplace(ConnectionID{tmp_peer.pubkey}, std::move(tmp_peer))->second;
}
} else {
// Incoming, non-SN connection: we don't store a peer_info for this, so just use the
// temporary one
peer = &tmp_peer;
}
}
size_t command_part_index = outgoing ? 0 : 1;
std::string command = parts[command_part_index].to_string();
// Steal any data message parts
size_t data_part_index = command_part_index + 1;
std::vector<zmq::message_t> data_parts;
data_parts.reserve(parts.size() - data_part_index);
for (auto it = parts.begin() + data_part_index; it != parts.end(); ++it)
data_parts.push_back(std::move(*it));
auto cat_call = get_command(command);
// Check that command is valid, that we have permission, etc.
if (!proxy_check_auth(conn_id, outgoing, *peer, parts[command_part_index], cat_call, data_parts))
return;
auto& category = *cat_call.first;
Access access{peer->auth_level, peer->service_node, local_service_node};
if (category.active_threads >= category.reserved_threads && active_workers() >= general_workers) {
// No free reserved or general spots, try to queue it for later
if (category.max_queue >= 0 && category.queued >= category.max_queue) {
OMQ_LOG(warn, "No space to queue incoming command ", command, "; already have ", category.queued,
"commands queued in that category (max ", category.max_queue, "); dropping message");
return;
}
OMQ_LOG(debug, "No available free workers, queuing ", command, " for later");
ConnectionID conn{peer->service_node ? ConnectionID::SN_ID : conn_id, peer->pubkey, std::move(tmp_peer.route)};
pending_commands.emplace_back(category, std::move(command), std::move(data_parts), cat_call.second,
std::move(conn), std::move(access), peer_address(parts[command_part_index]));
category.queued++;
return;
}
if (cat_call.second->second /*is_request*/ && data_parts.empty()) {
OMQ_LOG(warn, "Received an invalid request command with no reply tag; dropping message");
return;
}
auto& run = get_idle_worker();
{
ConnectionID c{peer->service_node ? ConnectionID::SN_ID : conn_id, peer->pubkey};
c.route = std::move(tmp_peer.route);
if (outgoing || peer->service_node)
tmp_peer.route.clear();
run.load(&category, std::move(command), std::move(c), std::move(access), peer_address(parts[command_part_index]),
std::move(data_parts), cat_call.second);
}
if (outgoing)
peer->activity(); // outgoing connection activity, pump the activity timer
OMQ_TRACE("Forwarding incoming ", run.command, " from ", run.conn, " @ ", peer_address(parts[command_part_index]),
" to worker ", run.worker_routing_name);
proxy_run_worker(run);
category.active_threads++;
}
void OxenMQ::inject_task(const std::string& category, std::string command, std::string remote, std::function<void()> callback) {
if (!callback) return;
auto it = categories.find(category);
if (it == categories.end())
throw std::out_of_range{"Invalid category `" + category + "': category does not exist"};
detail::send_control(get_control_socket(), "INJECT", oxenc::bt_serialize(detail::serialize_object(
injected_task{it->second, std::move(command), std::move(remote), std::move(callback)})));
}
void OxenMQ::proxy_inject_task(injected_task task) {
auto& category = task.cat;
if (category.active_threads >= category.reserved_threads && active_workers() >= general_workers) {
// No free worker slot, queue for later
if (category.max_queue >= 0 && category.queued >= category.max_queue) {
OMQ_LOG(warn, "No space to queue injected task ", task.command, "; already have ", category.queued,
"commands queued in that category (max ", category.max_queue, "); dropping task");
return;
}
OMQ_LOG(debug, "No available free workers for injected task ", task.command, "; queuing for later");
pending_commands.emplace_back(category, std::move(task.command), std::move(task.callback), std::move(task.remote));
category.queued++;
return;
}
auto& run = get_idle_worker();
OMQ_TRACE("Forwarding incoming injected task ", task.command, " from ", task.remote, " to worker ", run.worker_routing_name);
run.load(&category, std::move(task.command), std::move(task.remote), std::move(task.callback));
proxy_run_worker(run);
category.active_threads++;
}
}

View file

@ -1,27 +1,25 @@
add_subdirectory(Catch2)
add_executable(tests
set(LMQ_TEST_SRC
main.cpp
test_address.cpp
test_batch.cpp
test_connect.cpp
test_commands.cpp
test_failures.cpp
test_inject.cpp
test_pubsub.cpp
test_requests.cpp
test_socket_limit.cpp
test_tagged_threads.cpp
test_timer.cpp
)
test_string_view.cpp
)
add_executable(tests ${LMQ_TEST_SRC})
find_package(Threads)
find_package(PkgConfig REQUIRED)
pkg_check_modules(SODIUM REQUIRED libsodium)
target_link_libraries(tests Catch2::Catch2 oxenmq Threads::Threads)
target_link_libraries(tests Catch2::Catch2 lokimq ${SODIUM_LIBRARIES} Threads::Threads)
set_target_properties(tests PROPERTIES
CXX_STANDARD 17
CXX_STANDARD 14
CXX_STANDARD_REQUIRED ON
CXX_EXTENSIONS OFF
)

@ -1 +1 @@
Subproject commit dba29b60d639bf8d206a9a12c223e6ed4284fb13
Subproject commit b3b07215d1ca2224aea6ff3e21d87ad0f7750df2

View file

@ -1,61 +1,18 @@
#pragma once
#include "oxenmq/oxenmq.h"
#include "lokimq/lokimq.h"
#include <catch2/catch.hpp>
#include <chrono>
using namespace oxenmq;
// Apple's mutexes, thread scheduling, and IO handling are garbage and it shows up with lots of
// spurious failures in this test suite (because it expects a system to not suck that badly), so we
// multiply the time-sensitive bits by this factor as a hack to make the test suite work.
constexpr int TIME_DILATION =
#ifdef __APPLE__
5;
#else
1;
#endif
using namespace lokimq;
static auto startup = std::chrono::steady_clock::now();
/// Returns a localhost connection string to listen on. It can be considered random, though in
/// practice in the current implementation is sequential starting at 25432.
inline std::string random_localhost() {
static std::atomic<uint16_t> last = 25432;
last++;
assert(last); // We should never call this enough to overflow
return "tcp://127.0.0.1:" + std::to_string(last);
}
// Catch2 macros aren't thread safe, so guard with a mutex
inline std::unique_lock<std::mutex> catch_lock() {
static std::mutex mutex;
return std::unique_lock<std::mutex>{mutex};
}
/// Waits up to 200ms for something to happen.
template <typename Func>
inline void wait_for(Func f) {
auto start = std::chrono::steady_clock::now();
for (int i = 0; i < 20; i++) {
if (f())
break;
std::this_thread::sleep_for(10ms * TIME_DILATION);
}
auto lock = catch_lock();
UNSCOPED_INFO("done waiting after " << (std::chrono::steady_clock::now() - start).count() << "ns");
}
/// Waits on an atomic bool for up to 100ms for an initial connection, which is more than enough
/// time for an initial connection + request.
inline void wait_for_conn(std::atomic<bool> &c) {
wait_for([&c] { return c.load(); });
}
/// Waits enough time for us to receive a reply from a localhost remote.
inline void reply_sleep() { std::this_thread::sleep_for(10ms * TIME_DILATION); }
inline OxenMQ::Logger get_logger(std::string prefix = "") {
inline LokiMQ::Logger get_logger(std::string prefix = "") {
std::string me = "tests/common.h";
std::string strip = __FILE__;
if (strip.substr(strip.size() - me.size()) == me)

View file

@ -1,162 +0,0 @@
#include "oxenmq/address.h"
#include "common.h"
const std::string pk = "\xf1\x6b\xa5\x59\x10\x39\xf0\x89\xb4\x2a\x83\x41\x75\x09\x30\x94\x07\x4d\x0d\x93\x7a\x79\xe5\x3e\x5c\xe7\x30\xf9\x46\xe1\x4b\x88";
const std::string pk_hex = "f16ba5591039f089b42a834175093094074d0d937a79e53e5ce730f946e14b88";
const std::string pk_HEX = "F16BA5591039F089B42A834175093094074D0D937A79E53E5CE730F946E14B88";
const std::string pk_b32z = "6fi4kseo88aeupbkopyzknjo1odw4dcuxjh6kx1hhhax1tzbjqry";
const std::string pk_B32Z = "6FI4KSEO88AEUPBKOPYZKNJO1ODW4DCUXJH6KX1HHHAX1TZBJQRY";
const std::string pk_b64 = "8WulWRA58Im0KoNBdQkwlAdNDZN6eeU+XOcw+UbhS4g"; // NB: padding '=' omitted
TEST_CASE("tcp addresses", "[address][tcp]") {
address a{"tcp://1.2.3.4:5678"};
REQUIRE( a.host == "1.2.3.4" );
REQUIRE( a.port == 5678 );
REQUIRE_FALSE( a.curve() );
REQUIRE( a.tcp() );
REQUIRE( a.zmq_address() == "tcp://1.2.3.4:5678" );
REQUIRE( a.full_address() == "tcp://1.2.3.4:5678" );
REQUIRE( a.qr_address() == "TCP://1.2.3.4:5678" );
REQUIRE_THROWS_AS( address{"tcp://1:1:1"}, std::invalid_argument );
REQUIRE_THROWS_AS( address{"tcpz://localhost:123"}, std::invalid_argument );
REQUIRE_THROWS_AS( address{"tcp://abc"}, std::invalid_argument );
REQUIRE_THROWS_AS( address{"tcpz://localhost:0"}, std::invalid_argument );
REQUIRE_THROWS_AS( address{"tcpz://[::1:1080"}, std::invalid_argument );
address b = address::tcp("example.com", 80);
REQUIRE( b.host == "example.com" );
REQUIRE( b.port == 80 );
REQUIRE_FALSE( b.curve() );
REQUIRE( b.tcp() );
REQUIRE( b.zmq_address() == "tcp://example.com:80" );
REQUIRE( b.full_address() == "tcp://example.com:80" );
REQUIRE( b.qr_address() == "TCP://EXAMPLE.COM:80" );
address c{"tcp://[::1]:1111"};
REQUIRE( c.host == "[::1]" );
REQUIRE( c.port == 1111 );
}
TEST_CASE("unix sockets", "[address][ipc]") {
address a{"ipc:///path/to/foo"};
REQUIRE( a.socket == "/path/to/foo" );
REQUIRE_FALSE( a.curve() );
REQUIRE_FALSE( a.tcp() );
REQUIRE( a.zmq_address() == "ipc:///path/to/foo" );
REQUIRE( a.full_address() == "ipc:///path/to/foo" );
address b = address::ipc("../foo");
REQUIRE( b.socket == "../foo" );
REQUIRE_FALSE( b.curve() );
REQUIRE_FALSE( b.tcp() );
REQUIRE( b.zmq_address() == "ipc://../foo" );
REQUIRE( b.full_address() == "ipc://../foo" );
}
TEST_CASE("pubkey formats", "[address][curve][pubkey]") {
address a{"tcp+curve://a:1/" + pk_hex};
address b{"curve://a:1/" + pk_b32z};
address c{"curve://a:1/" + pk_b64};
address d{"CURVE://A:1/" + pk_B32Z};
REQUIRE( a.curve() );
REQUIRE( a.host == "a" );
REQUIRE( a.port == 1 );
REQUIRE((b.curve() && c.curve() && d.curve()));
REQUIRE( a.pubkey == pk );
REQUIRE( b.pubkey == pk );
REQUIRE( c.pubkey == pk );
REQUIRE( d.pubkey == pk );
address e{"ipc+curve://my.sock/" + pk_hex};
address f{"ipc+curve://../my.sock/" + pk_b32z};
address g{"ipc+curve:///my.sock/" + pk_B32Z};
address h{"ipc+curve://./my.sock/" + pk_b64};
REQUIRE( e.curve() );
REQUIRE( e.ipc() );
REQUIRE_FALSE( e.tcp() );
REQUIRE((f.curve() && g.curve() && h.curve()));
REQUIRE( e.socket == "my.sock" );
REQUIRE( f.socket == "../my.sock" );
REQUIRE( g.socket == "/my.sock" );
REQUIRE( h.socket == "./my.sock" );
REQUIRE( e.pubkey == pk );
REQUIRE( f.pubkey == pk );
REQUIRE( g.pubkey == pk );
REQUIRE( h.pubkey == pk );
REQUIRE( d.full_address(address::encoding::hex) == "curve://a:1/" + pk_hex );
REQUIRE( c.full_address(address::encoding::base32z) == "curve://a:1/" + pk_b32z );
REQUIRE( b.full_address(address::encoding::BASE32Z) == "curve://a:1/" + pk_B32Z );
REQUIRE( a.full_address(address::encoding::base64) == "curve://a:1/" + pk_b64 );
REQUIRE( h.full_address(address::encoding::hex) == "ipc+curve://./my.sock/" + pk_hex );
REQUIRE( g.full_address(address::encoding::base32z) == "ipc+curve:///my.sock/" + pk_b32z );
REQUIRE( f.full_address(address::encoding::BASE32Z) == "ipc+curve://../my.sock/" + pk_B32Z );
REQUIRE( e.full_address(address::encoding::base64) == "ipc+curve://my.sock/" + pk_b64 );
REQUIRE_THROWS_AS(address{"ipc+curve://my.sock/" + pk_hex.substr(0, 63)}, std::invalid_argument);
REQUIRE_THROWS_AS(address{"ipc+curve://my.sock/" + pk_b32z.substr(0, 51)}, std::invalid_argument);
REQUIRE_THROWS_AS(address{"ipc+curve://my.sock/" + pk_B32Z.substr(0, 51)}, std::invalid_argument);
REQUIRE_THROWS_AS(address{"ipc+curve://my.sock/" + pk_b64.substr(0, 42)}, std::invalid_argument);
REQUIRE_THROWS_AS(address{"ipc+curve://my.sock"}, std::invalid_argument);
REQUIRE_THROWS_AS(address{"ipc+curve://my.sock/"}, std::invalid_argument);
}
TEST_CASE("tcp QR-code friendly addresses", "[address][tcp][qr]") {
address a{"tcp://public.loki.foundation:12345"};
address a_qr{"TCP://PUBLIC.LOKI.FOUNDATION:12345"};
address b{"tcp://PUBLIC.LOKI.FOUNDATION:12345"};
REQUIRE( a == a_qr );
REQUIRE( a != b );
REQUIRE( a.host == "public.loki.foundation" );
REQUIRE( a.qr_address() == "TCP://PUBLIC.LOKI.FOUNDATION:12345" );
address c = address::tcp_curve("public.loki.foundation", 12345, pk);
REQUIRE( c.qr_address() == "CURVE://PUBLIC.LOKI.FOUNDATION:12345/" + pk_B32Z );
REQUIRE( address{"CURVE://PUBLIC.LOKI.FOUNDATION:12345/" + pk_B32Z} == c );
// We don't produce with upper-case hex, but we accept it:
REQUIRE( address{"CURVE://PUBLIC.LOKI.FOUNDATION:12345/" + pk_HEX} == c );
// lower case not permitted: ▾
REQUIRE_THROWS_AS(address{"CURVE://PUBLIC.LOKI.FOUNDATiON:12345/" + pk_B32Z}, std::invalid_argument);
// also only accept upper-base base32z and hex:
REQUIRE_THROWS_AS(address{"CURVE://PUBLIC.LOKI.FOUNDATION:12345/" + pk_b32z}, std::invalid_argument);
REQUIRE_THROWS_AS(address{"CURVE://PUBLIC.LOKI.FOUNDATION:12345/" + pk_hex}, std::invalid_argument);
// don't accept base64 even if it's upper-case (because case-converting it changes the value)
REQUIRE_THROWS_AS(address{"CURVE://PUBLIC.LOKI.FOUNDATION:12345/AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"}, std::invalid_argument);
REQUIRE_THROWS_AS(address{"CURVE://PUBLIC.LOKI.FOUNDATION:12345/AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA="}, std::invalid_argument);
}
TEST_CASE("address hashing", "[address][hash]") {
address a{"tcp://public.loki.foundation:12345"};
address b{"tcp+curve://public.loki.foundation:12345/" + pk_hex};
address c{"ipc:///tmp/some.sock"};
address d{"ipc:///tmp/some.other.sock"};
std::hash<oxenmq::address> hasher{};
REQUIRE( hasher(a) != hasher(b) );
REQUIRE( hasher(a) != hasher(c) );
REQUIRE( hasher(a) != hasher(d) );
REQUIRE( hasher(b) != hasher(c) );
REQUIRE( hasher(b) != hasher(d) );
REQUIRE( hasher(c) != hasher(d) );
std::unordered_set<oxenmq::address> set;
set.insert(a);
set.insert(b);
set.insert(c);
set.insert(d);
CHECK( set.size() == 4 );
std::unordered_map<oxenmq::address, int> count;
for (const auto& addr : set)
count[addr]++;
REQUIRE( count.size() == 4 );
CHECK( count[a] == 1 );
CHECK( count[b] == 1 );
CHECK( count[c] == 1 );
CHECK( count[d] == 1 );
}

View file

@ -1,4 +1,4 @@
#include "oxenmq/batch.h"
#include "lokimq/batch.h"
#include "common.h"
#include <future>
@ -12,7 +12,7 @@ double do_my_task(int input) {
std::promise<std::pair<double, int>> done;
void continue_big_task(std::vector<oxenmq::job_result<double>> results) {
void continue_big_task(std::vector<lokimq::job_result<double>> results) {
double sum = 0;
int exc_count = 0;
for (auto& r : results) {
@ -25,10 +25,10 @@ void continue_big_task(std::vector<oxenmq::job_result<double>> results) {
done.set_value({sum, exc_count});
}
void start_big_task(oxenmq::OxenMQ& omq) {
void start_big_task(lokimq::LokiMQ& lmq) {
size_t num_jobs = 32;
oxenmq::Batch<double /*return type*/> batch;
lokimq::Batch<double /*return type*/> batch;
batch.reserve(num_jobs);
for (size_t i = 0; i < num_jobs; i++)
@ -36,21 +36,21 @@ void start_big_task(oxenmq::OxenMQ& omq) {
batch.completion(&continue_big_task);
omq.batch(std::move(batch));
lmq.batch(std::move(batch));
}
TEST_CASE("batching many small jobs", "[batch-many]") {
oxenmq::OxenMQ omq{
lokimq::LokiMQ lmq{
"", "", // generate ephemeral keys
false, // not a service node
[](auto) { return ""; },
};
omq.set_general_threads(4);
omq.set_batch_threads(4);
omq.start();
lmq.set_general_threads(4);
lmq.set_batch_threads(4);
lmq.start();
start_big_task(omq);
start_big_task(lmq);
auto sum = done.get_future().get();
auto lock = catch_lock();
REQUIRE( sum.first == 1337.0 );
@ -58,14 +58,14 @@ TEST_CASE("batching many small jobs", "[batch-many]") {
}
TEST_CASE("batch exception propagation", "[batch-exceptions]") {
oxenmq::OxenMQ omq{
lokimq::LokiMQ lmq{
"", "", // generate ephemeral keys
false, // not a service node
[](auto) { return ""; },
};
omq.set_general_threads(4);
omq.set_batch_threads(4);
omq.start();
lmq.set_general_threads(4);
lmq.set_batch_threads(4);
lmq.start();
std::promise<void> done_promise;
std::future<void> done_future = done_promise.get_future();
@ -73,7 +73,7 @@ TEST_CASE("batch exception propagation", "[batch-exceptions]") {
using Catch::Matchers::Message;
SECTION( "value return" ) {
oxenmq::Batch<int> batch;
lokimq::Batch<int> batch;
for (int i : {1, 2})
batch.add_job([i]() { if (i == 1) return 42; throw std::domain_error("bad value " + std::to_string(i)); });
batch.completion([&done_promise](auto results) {
@ -83,12 +83,12 @@ TEST_CASE("batch exception propagation", "[batch-exceptions]") {
REQUIRE_THROWS_MATCHES( results[1].get() == 0, std::domain_error, Message("bad value 2") );
done_promise.set_value();
});
omq.batch(std::move(batch));
lmq.batch(std::move(batch));
done_future.get();
}
SECTION( "lvalue return" ) {
oxenmq::Batch<int&> batch;
lokimq::Batch<int&> batch;
int forty_two = 42;
for (int i : {1, 2})
batch.add_job([i,&forty_two]() -> int& {
@ -105,12 +105,12 @@ TEST_CASE("batch exception propagation", "[batch-exceptions]") {
REQUIRE_THROWS_MATCHES( results[1].get(), std::domain_error, Message("bad value 2") );
done_promise.set_value();
});
omq.batch(std::move(batch));
lmq.batch(std::move(batch));
done_future.get();
}
SECTION( "void return" ) {
oxenmq::Batch<void> batch;
lokimq::Batch<void> batch;
for (int i : {1, 2})
batch.add_job([i]() { if (i != 1) throw std::domain_error("bad value " + std::to_string(i)); });
batch.completion([&done_promise](auto results) {
@ -120,7 +120,7 @@ TEST_CASE("batch exception propagation", "[batch-exceptions]") {
REQUIRE_THROWS_MATCHES( results[1].get(), std::domain_error, Message("bad value 2") );
done_promise.set_value();
});
omq.batch(std::move(batch));
lmq.batch(std::move(batch));
done_future.get();
}
}

View file

@ -1,20 +1,21 @@
#include "common.h"
#include <oxenc/hex.h>
#include <future>
#include <lokimq/hex.h>
#include <map>
#include <set>
using namespace oxenmq;
using namespace lokimq;
TEST_CASE("basic commands", "[commands]") {
std::string listen = random_localhost();
OxenMQ server{
std::string listen = "tcp://127.0.0.1:4567";
LokiMQ server{
"", "", // generate ephemeral keys
false, // not a service node
[](auto) { return ""; },
get_logger(""),
LogLevel::trace
get_logger("")
};
server.listen_curve(listen);
server.log_level(LogLevel::trace);
server.listen_curve(listen, [](auto /*ip*/, auto /*pk*/) { return Allow{AuthLevel::none, false}; });
std::atomic<int> hellos{0}, his{0};
@ -31,43 +32,52 @@ TEST_CASE("basic commands", "[commands]") {
server.start();
OxenMQ client{get_logger(""), LogLevel::trace};
LokiMQ client{
get_logger("")
};
client.log_level(LogLevel::trace);
client.add_category("public", Access{AuthLevel::none});
client.add_command("public", "hi", [&](auto&) { his++; });
client.start();
std::atomic<bool> got{false};
bool success = false, failed = false;
std::atomic<bool> connected{false}, failed{false};
std::string pubkey;
auto c = client.connect_remote(address{listen, server.get_pubkey()},
[&](auto conn) { pubkey = conn.pubkey(); success = true; got = true; },
[&](auto conn, std::string_view) { failed = true; got = true; });
auto c = client.connect_remote(listen,
[&](auto conn) { pubkey = conn.pubkey(); connected = true; },
[&](auto conn, string_view) { failed = true; },
server.get_pubkey());
wait_for_conn(got);
int i;
for (i = 0; i < 5; i++) {
if (connected.load())
break;
std::this_thread::sleep_for(50ms);
}
{
auto lock = catch_lock();
REQUIRE( got );
REQUIRE( success );
REQUIRE_FALSE( failed );
REQUIRE( oxenc::to_hex(pubkey) == oxenc::to_hex(server.get_pubkey()) );
REQUIRE( connected.load() );
REQUIRE( i <= 1 ); // should be fast
REQUIRE( !failed.load() );
REQUIRE( to_hex(pubkey) == to_hex(server.get_pubkey()) );
}
client.send(c, "public.hello");
client.send(c, "public.client.pubkey");
reply_sleep();
std::this_thread::sleep_for(50ms);
{
auto lock = catch_lock();
REQUIRE( hellos == 1 );
REQUIRE( his == 1 );
REQUIRE( oxenc::to_hex(client_pubkey) == oxenc::to_hex(client.get_pubkey()) );
REQUIRE( to_hex(client_pubkey) == to_hex(client.get_pubkey()) );
}
for (int i = 0; i < 50; i++)
client.send(c, "public.hello");
wait_for([&] { return his == 26; });
std::this_thread::sleep_for(100ms);
{
auto lock = catch_lock();
REQUIRE( hellos == 51 );
@ -76,15 +86,15 @@ TEST_CASE("basic commands", "[commands]") {
}
TEST_CASE("outgoing auth level", "[commands][auth]") {
std::string listen = random_localhost();
OxenMQ server{
std::string listen = "tcp://127.0.0.1:4567";
LokiMQ server{
"", "", // generate ephemeral keys
false, // not a service node
[](auto) { return ""; },
get_logger(""),
LogLevel::trace
get_logger("")
};
server.listen_curve(listen);
server.log_level(LogLevel::trace);
server.listen_curve(listen, [](auto /*ip*/, auto /*pk*/) { return Allow{AuthLevel::none, false}; });
std::atomic<int> hellos{0};
@ -93,7 +103,10 @@ TEST_CASE("outgoing auth level", "[commands][auth]") {
server.start();
OxenMQ client{get_logger(""), LogLevel::trace};
LokiMQ client{
get_logger("")
};
client.log_level(LogLevel::trace);
std::atomic<int> public_hi{0}, basic_hi{0}, admin_hi{0};
client.add_category("public", Access{AuthLevel::none});
@ -104,15 +117,15 @@ TEST_CASE("outgoing auth level", "[commands][auth]") {
client.add_command("admin", "hi", [&](auto&) { admin_hi++; });
client.start();
client.EPHEMERAL_ROUTING_ID = true; // establishing multiple connections below, so we need unique routing ids
client.PUBKEY_BASED_ROUTING_ID = false; // establishing multiple connections below, so we need unique routing ids
address server_addr{listen, server.get_pubkey()};
auto public_c = client.connect_remote(server_addr, [](auto&&...) {}, [](auto&&...) {});
auto basic_c = client.connect_remote(server_addr, [](auto&&...) {}, [](auto&&...) {}, AuthLevel::basic);
auto admin_c = client.connect_remote(server_addr, [](auto&&...) {}, [](auto&&...) {}, AuthLevel::admin);
auto public_c = client.connect_remote(listen, [](...) {}, [](...) {}, server.get_pubkey());
auto basic_c = client.connect_remote(listen, [](...) {}, [](...) {}, server.get_pubkey(), AuthLevel::basic);
auto admin_c = client.connect_remote(listen, [](...) {}, [](...) {}, server.get_pubkey(), AuthLevel::admin);
client.send(public_c, "public.reflect", "public.hi");
wait_for([&] { return public_hi == 1; });
std::this_thread::sleep_for(50ms);
{
auto lock = catch_lock();
REQUIRE( public_hi == 1 );
@ -125,7 +138,7 @@ TEST_CASE("outgoing auth level", "[commands][auth]") {
client.send(admin_c, "public.reflect", "admin.hi");
client.send(basic_c, "public.reflect", "basic.hi");
wait_for([&] { return basic_hi == 2; });
std::this_thread::sleep_for(50ms);
{
auto lock = catch_lock();
REQUIRE( admin_hi == 3 );
@ -147,7 +160,7 @@ TEST_CASE("outgoing auth level", "[commands][auth]") {
client.send(admin_c, "public.reflect", "basic.hi");
client.send(admin_c, "public.reflect", "public.hi");
wait_for([&] { return public_hi == 3; });
std::this_thread::sleep_for(50ms);
auto lock = catch_lock();
REQUIRE( admin_hi == 1 );
REQUIRE( basic_hi == 2 );
@ -158,41 +171,33 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]")
// Tests that the ConnectionID from a Message can be stored and reused later to contact the
// original node.
std::string listen = random_localhost();
OxenMQ server{
std::string listen = "tcp://127.0.0.1:4567";
LokiMQ server{
"", "", // generate ephemeral keys
false, // not a service node
[](auto) { return ""; },
get_logger(""),
LogLevel::trace
get_logger("")
};
server.listen_curve(listen);
server.log_level(LogLevel::trace);
server.listen_curve(listen, [](auto /*ip*/, auto /*pk*/) { return Allow{AuthLevel::none, false}; });
std::vector<std::pair<ConnectionID, std::string>> subscribers;
ConnectionID backdoor;
server.add_category("hey google", Access{AuthLevel::none});
server.add_request_command("hey google", "remember", [&](Message& m) {
bool bd;
{
auto l = catch_lock();
subscribers.emplace_back(m.conn, std::string{m.data[0]});
bd = (bool) backdoor;
}
auto l = catch_lock();
subscribers.emplace_back(m.conn, std::string{m.data[0]});
m.send_reply("Okay, I'll remember that.");
if (bd)
m.oxenmq.send(backdoor, "backdoor.data", m.data[0]);
if (backdoor)
m.lokimq.send(backdoor, "backdoor.data", m.data[0]);
});
server.add_command("hey google", "recall", [&](Message& m) {
decltype(subscribers) subs;
{
auto l = catch_lock();
subs = subscribers;
}
for (auto& s : subs)
auto l = catch_lock();
for (auto& s : subscribers) {
server.send(s.first, "personal.detail", s.second);
}
});
server.add_command("hey google", "install backdoor", [&](Message& m) {
auto l = catch_lock();
@ -201,29 +206,29 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]")
server.start();
auto connect_success = [&](auto&&...) { auto l = catch_lock(); REQUIRE(true); };
auto connect_failure = [&](auto&&...) { auto l = catch_lock(); REQUIRE(false); };
auto connect_success = [&](...) { auto l = catch_lock(); REQUIRE(true); };
auto connect_failure = [&](...) { auto l = catch_lock(); REQUIRE(false); };
std::set<std::string> backdoor_details;
OxenMQ nsa{get_logger("NSA» ")};
LokiMQ nsa{get_logger("NSA» ")};
nsa.add_category("backdoor", Access{AuthLevel::admin});
nsa.add_command("backdoor", "data", [&](Message& m) {
auto l = catch_lock();
backdoor_details.emplace(m.data[0]);
backdoor_details.emplace(m.data[0]);
});
nsa.start();
auto nsa_c = nsa.connect_remote(address{listen, server.get_pubkey()}, connect_success, connect_failure, AuthLevel::admin);
auto nsa_c = nsa.connect_remote(listen, connect_success, connect_failure, server.get_pubkey(), AuthLevel::admin);
nsa.send(nsa_c, "hey google.install backdoor");
wait_for([&] { auto lock = catch_lock(); return (bool) backdoor; });
std::this_thread::sleep_for(50ms);
{
auto l = catch_lock();
REQUIRE( backdoor );
}
std::vector<std::unique_ptr<OxenMQ>> clients;
std::vector<std::unique_ptr<LokiMQ>> clients;
std::vector<ConnectionID> conns;
std::map<int, std::set<std::string>> personal_details{
{0, {"Loretta"s, "photos"s}},
@ -235,14 +240,12 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]")
std::set<std::string> all_the_things;
for (auto& pd : personal_details) all_the_things.insert(pd.second.begin(), pd.second.end());
address server_addr{listen, server.get_pubkey()};
std::map<int, std::set<std::string>> google_knows;
int things_remembered{0};
for (int i = 0; i < 5; i++) {
clients.push_back(std::make_unique<OxenMQ>(
get_logger("C" + std::to_string(i) + "» "), LogLevel::trace
));
clients.push_back(std::make_unique<LokiMQ>(get_logger("C" + std::to_string(i) + "» ")));
auto& c = clients.back();
c->log_level(LogLevel::trace);
c->add_category("personal", Access{AuthLevel::basic});
c->add_command("personal", "detail", [&,i](Message& m) {
auto l = catch_lock();
@ -250,7 +253,7 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]")
});
c->start();
conns.push_back(
c->connect_remote(server_addr, connect_success, connect_failure, AuthLevel::basic));
c->connect_remote(listen, connect_success, connect_failure, server.get_pubkey(), AuthLevel::basic));
for (auto& personal_detail : personal_details[i])
c->request(conns.back(), "hey google.remember",
[&](bool success, std::vector<std::string> data) {
@ -262,7 +265,7 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]")
},
personal_detail);
}
wait_for([&] { auto lock = catch_lock(); return things_remembered == all_the_things.size() && things_remembered == backdoor_details.size(); });
std::this_thread::sleep_for(50ms);
{
auto l = catch_lock();
REQUIRE( things_remembered == all_the_things.size() );
@ -270,256 +273,9 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]")
}
clients[0]->send(conns[0], "hey google.recall");
reply_sleep();
std::this_thread::sleep_for(50ms);
{
auto l = catch_lock();
REQUIRE( google_knows == personal_details );
}
}
TEST_CASE("send failure callbacks", "[commands][queue_full]") {
std::string listen = random_localhost();
OxenMQ server{
"", "", // generate ephemeral keys
false, // not a service node
[](auto) { return ""; },
get_logger(""),
LogLevel::debug // This test traces so much that it takes 2.5-3s of CPU time at trace level, so don't do that.
};
server.listen_plain(listen);
std::atomic<int> send_attempts{0};
std::atomic<int> send_failures{0};
// ZMQ TCP sockets' HWM is complicated and OS dependent; sender and receiver (probably) each
// have 1000 message queues, but there is also the TCP queue to worry about which means we can
// have more queued before we fill up, so we send 4kiB of null with each message so that we
// don't get too much TCP queuing.
std::string junk(4096, '0');
server.add_category("x", Access{AuthLevel::none})
.add_command("x", [&](Message& m) {
for (int x = 0; x < 500; x++) {
++send_attempts;
m.send_back("y.y", junk, send_option::queue_full{[&]() { ++send_failures; }});
}
});
server.start();
// Use a raw socket here because I want to stall it by not reading from it at all, and that is
// hard with OxenMQ.
zmq::context_t client_ctx;
zmq::socket_t client{client_ctx, zmq::socket_type::dealer};
client.connect(listen);
// Handshake: we send HI, they reply HELLO.
client.send(zmq::message_t{"HI", 2}, zmq::send_flags::none);
zmq::message_t hello;
auto recvd = client.recv(hello);
std::string_view hello_sv{hello.data<char>(), hello.size()};
{
auto lock = catch_lock();
REQUIRE( recvd );
REQUIRE( hello_sv == "HELLO" );
REQUIRE_FALSE( hello.more() );
}
// Tell the remote to queue up a batch of messages
client.send(zmq::message_t{"x.x", 3}, zmq::send_flags::none);
int i;
for (i = 0; i < 20; i++) {
if (send_attempts.load() >= 500)
break;
std::this_thread::sleep_for(10ms);
}
{
auto lock = catch_lock();
REQUIRE( i < 20 ); // should be not too slow
// We have two buffers here: 1000 on the receiver, and 1000 on the client, which means we
// should be able to get 2000 out before we hit HWM. We should only have been sent 501 so
// far (the "HELLO" handshake + 500 "y.y" messages).
REQUIRE( send_attempts.load() == 500 );
REQUIRE( send_failures.load() == 0 );
}
// Now we want to tell the server to send enough to fill the outgoing queue and start stalling.
// This is complicated as it depends on ZMQ internals *and* OS-level TCP buffers, so we really
// don't know precisely where this will start failing.
//
// In practice, I seem to reach HWM (for this test, with this amount of data being sent, on my
// Debian desktop) after 2499 messages (that is, queuing 2500 gives 1 failure).
int expected_attempts = 500;
for (int i = 0; i < 10; i++) {
client.send(zmq::message_t{"x.x", 3}, zmq::send_flags::none);
expected_attempts += 500;
if (i >= 4) {
if (send_failures.load() > 0)
break;
std::this_thread::sleep_for(25ms);
}
}
for (i = 0; i < 100; i++) {
if (send_attempts.load() >= expected_attempts)
break;
std::this_thread::sleep_for(10ms);
}
{
auto lock = catch_lock();
REQUIRE( i < 100 );
REQUIRE( send_attempts.load() == expected_attempts );
REQUIRE( send_failures.load() > 0 );
}
}
TEST_CASE("data parts", "[commands][send][data_parts]") {
std::string listen = random_localhost();
OxenMQ server{
"", "", // generate ephemeral keys
false, // not a service node
[](auto) { return ""; },
get_logger(""),
LogLevel::trace
};
server.listen_curve(listen);
std::mutex mut;
std::vector<std::string> r;
server.add_category("public", Access{AuthLevel::none});
server.add_command("public", "hello", [&](Message& m) {
std::lock_guard l{mut};
for (const auto& s : m.data)
r.emplace_back(s);
});
server.start();
OxenMQ client{get_logger(""), LogLevel::trace};
client.start();
std::atomic<bool> got{false};
bool success = false, failed = false;
std::string pubkey;
auto c = client.connect_remote(address{listen, server.get_pubkey()},
[&](auto conn) { pubkey = conn.pubkey(); success = true; got = true; },
[&](auto conn, std::string_view) { failed = true; got = true; });
wait_for_conn(got);
{
auto lock = catch_lock();
REQUIRE( got );
REQUIRE( success );
REQUIRE_FALSE( failed );
REQUIRE( oxenc::to_hex(pubkey) == oxenc::to_hex(server.get_pubkey()) );
}
std::vector some_data{{"abc"s, "def"s, "omg123\0zzz"s}};
client.send(c, "public.hello", oxenmq::send_option::data_parts(some_data.begin(), some_data.end()));
reply_sleep();
{
auto lock = catch_lock();
std::lock_guard l{mut};
REQUIRE( r == some_data );
r.clear();
}
std::optional<std::string_view> opt1, opt2;
std::optional<std::string> opt3, opt4;
opt1 = "o1"sv;
opt4 = "o4"s;
std::vector some_data2{{"a"sv, "b"sv, "\0"sv}};
client.send(c, "public.hello",
"hi",
oxenmq::send_option::data_parts(some_data2.begin(), some_data2.end()),
"another",
"string"sv,
oxenmq::send_option::data_parts(some_data.begin(), some_data.end()),
opt1, opt2, opt3, opt4
);
std::vector<std::string> expected;
expected.push_back("hi");
expected.insert(expected.end(), some_data2.begin(), some_data2.end());
expected.push_back("another");
expected.push_back("string");
expected.insert(expected.end(), some_data.begin(), some_data.end());
expected.push_back("o1");
expected.push_back("o4");
reply_sleep();
{
auto lock = catch_lock();
std::lock_guard l{mut};
REQUIRE( r == expected );
}
}
TEST_CASE("deferred replies", "[commands][send][deferred]") {
std::string listen = random_localhost();
OxenMQ server{
"", "", // generate ephemeral keys
false, // not a service node
[](auto) { return ""; },
};
server.listen_curve(listen);
std::atomic<int> hellos{0}, his{0};
server.add_category("public", Access{AuthLevel::none});
server.add_request_command("public", "echo", [&](Message& m) {
std::string msg = m.data.empty() ? ""s : std::string{m.data.front()};
std::thread t{[send=m.send_later(), msg=std::move(msg)] {
{ auto lock = catch_lock(); UNSCOPED_INFO("sleeping"); }
std::this_thread::sleep_for(50ms * TIME_DILATION);
{ auto lock = catch_lock(); UNSCOPED_INFO("sending"); }
send(msg);
}};
t.detach();
});
server.set_general_threads(1);
server.start();
OxenMQ client(
get_logger(""),
LogLevel::trace);
//client.log_level(LogLevel::trace);
client.start();
std::atomic<bool> connected{false}, failed{false};
std::string pubkey;
auto c = client.connect_remote(address{listen, server.get_pubkey()},
[&](auto conn) { pubkey = conn.pubkey(); connected = true; },
[&](auto, auto) { failed = true; });
wait_for([&] { return connected || failed; });
{
auto lock = catch_lock();
REQUIRE( connected );
REQUIRE_FALSE( failed );
REQUIRE( oxenc::to_hex(pubkey) == oxenc::to_hex(server.get_pubkey()) );
}
std::unordered_set<std::string> replies;
std::mutex reply_mut;
std::vector<std::string> data;
for (auto str : {"hello", "world", "omg"})
client.request(c, "public.echo", [&](bool ok, std::vector<std::string> data_) {
std::lock_guard lock{reply_mut};
replies.insert(std::move(data_[0]));
}, str);
reply_sleep();
{
std::lock_guard lq{reply_mut};
auto lock = catch_lock();
REQUIRE( replies.size() == 0 ); // The server waits 50ms before sending, so we shouldn't have any reply yet
}
std::this_thread::sleep_for(60ms * TIME_DILATION); // We're at least 70ms in now so the 50ms-delayed server responses should have arrived
{
std::lock_guard lq{reply_mut};
auto lock = catch_lock();
REQUIRE( replies == std::unordered_set<std::string>{{"hello", "world", "omg"}} );
}
}

View file

@ -1,75 +1,74 @@
#include "common.h"
#include <future>
#include <lokimq/hex.h>
extern "C" {
#include <sodium.h>
}
TEST_CASE("connections with curve authentication", "[curve][connect]") {
std::string listen = random_localhost();
OxenMQ server{
std::string listen = "tcp://127.0.0.1:4455";
LokiMQ server{
"", "", // generate ephemeral keys
false, // not a service node
[](auto) { return ""; },
get_logger(""),
LogLevel::trace
get_logger("")
};
server.log_level(LogLevel::trace);
server.listen_curve(listen);
server.listen_curve(listen, [](auto /*ip*/, auto /*pk*/) { return Allow{AuthLevel::none, false}; });
server.add_category("public", Access{AuthLevel::none});
server.add_request_command("public", "hello", [&](Message& m) { m.send_reply("hi"); });
server.start();
OxenMQ client{get_logger(""), LogLevel::trace};
LokiMQ client{get_logger("")};
client.log_level(LogLevel::trace);
client.start();
auto pubkey = server.get_pubkey();
std::atomic<bool> got{false};
bool success = false;
auto server_conn = client.connect_remote(address{listen, pubkey},
[&](auto conn) { success = true; got = true; },
[&](auto conn, std::string_view reason) { auto lock = catch_lock(); INFO("connection failed: " << reason); got = true; });
std::atomic<int> connected{0};
auto server_conn = client.connect_remote(listen,
[&](auto conn) { connected = 1; },
[&](auto conn, string_view reason) { auto lock = catch_lock(); INFO("connection failed: " << reason); },
pubkey);
wait_for_conn(got);
int i;
for (i = 0; i < 5; i++) {
if (connected.load())
break;
std::this_thread::sleep_for(50ms);
}
{
auto lock = catch_lock();
REQUIRE( got );
REQUIRE( success );
REQUIRE( i <= 1 );
REQUIRE( connected.load() );
}
success = false;
bool success = false;
std::vector<std::string> parts;
client.request(server_conn, "public.hello", [&](auto success_, auto parts_) { success = success_; parts = parts_; });
reply_sleep();
{
auto lock = catch_lock();
REQUIRE( success );
}
std::this_thread::sleep_for(50ms);
auto lock = catch_lock();
REQUIRE( success );
}
TEST_CASE("self-connection SN optimization", "[connect][self]") {
std::string pubkey, privkey;
pubkey.resize(crypto_box_PUBLICKEYBYTES);
privkey.resize(crypto_box_SECRETKEYBYTES);
REQUIRE(sodium_init() != -1);
auto listen_addr = random_localhost();
crypto_box_keypair(reinterpret_cast<unsigned char*>(&pubkey[0]), reinterpret_cast<unsigned char*>(&privkey[0]));
OxenMQ sn{
LokiMQ sn{
pubkey, privkey,
true,
[&](auto pk) { if (pk == pubkey) return listen_addr; else return ""s; },
get_logger(""),
LogLevel::trace
[&](auto pk) { if (pk == pubkey) return "tcp://127.0.0.1:5544"; else return ""; },
get_logger("")
};
sn.listen_curve(listen_addr, [&](auto ip, auto pk, auto sn) {
auto lock = catch_lock();
REQUIRE(ip == "127.0.0.1");
REQUIRE(sn == (pk == pubkey));
return AuthLevel::none;
});
sn.listen_curve("tcp://127.0.0.1:5544", [&](auto ip, auto pk) { REQUIRE(ip == "127.0.0.1"); return Allow{AuthLevel::none, pk == pubkey}; });
sn.add_category("a", Access{AuthLevel::none});
std::atomic<bool> invoked{false};
bool invoked = false;
sn.add_command("a", "b", [&](const Message& m) {
invoked = true;
auto lock = catch_lock();
@ -78,532 +77,98 @@ TEST_CASE("self-connection SN optimization", "[connect][self]") {
REQUIRE(!m.data.empty());
REQUIRE(m.data[0] == "my data");
});
sn.set_active_sns({{pubkey}});
sn.log_level(LogLevel::trace);
sn.start();
std::this_thread::sleep_for(50ms);
sn.send(pubkey, "a.b", "my data");
wait_for_conn(invoked);
{
auto lock = catch_lock();
REQUIRE(invoked);
}
std::this_thread::sleep_for(50ms);
auto lock = catch_lock();
REQUIRE(invoked);
}
TEST_CASE("plain-text connections", "[plaintext][connect]") {
std::string listen = random_localhost();
OxenMQ server{get_logger(""), LogLevel::trace};
std::string listen = "tcp://127.0.0.1:4455";
LokiMQ server{get_logger("")};
server.log_level(LogLevel::trace);
server.add_category("public", Access{AuthLevel::none});
server.add_request_command("public", "hello", [&](Message& m) { m.send_reply("hi"); });
server.listen_plain(listen);
server.listen_plain(listen, [](auto /*ip*/, auto /*pk*/) { return Allow{AuthLevel::none, false}; });
server.start();
OxenMQ client{get_logger(""), LogLevel::trace};
LokiMQ client{get_logger("")};
client.log_level(LogLevel::trace);
client.start();
std::atomic<bool> got{false};
bool success = false;
auto c = client.connect_remote(address{listen},
[&](auto conn) { success = true; got = true; },
[&](auto conn, std::string_view reason) { auto lock = catch_lock(); INFO("connection failed: " << reason); got = true; }
std::atomic<int> connected{0};
auto c = client.connect_remote(listen,
[&](auto conn) { connected = 1; },
[&](auto conn, string_view reason) { auto lock = catch_lock(); INFO("connection failed: " << reason); }
);
wait_for_conn(got);
int i;
for (i = 0; i < 5; i++) {
if (connected.load())
break;
std::this_thread::sleep_for(50ms);
}
{
auto lock = catch_lock();
REQUIRE( got );
REQUIRE( success );
REQUIRE( i <= 1 );
REQUIRE( connected.load() );
}
success = false;
bool success = false;
std::vector<std::string> parts;
client.request(c, "public.hello", [&](auto success_, auto parts_) { success = success_; parts = parts_; });
reply_sleep();
{
auto lock = catch_lock();
REQUIRE( success );
}
std::this_thread::sleep_for(50ms);
auto lock = catch_lock();
REQUIRE( success );
}
TEST_CASE("post-start listening", "[connect][listen]") {
OxenMQ server{get_logger(""), LogLevel::trace};
server.add_category("x", AuthLevel::none)
.add_request_command("y", [&](Message& m) { m.send_reply("hi", m.data[0]); });
server.start();
std::atomic<int> listens = 0;
auto listen_curve = random_localhost();
server.listen_curve(listen_curve, nullptr, [&](bool success) { if (success) listens++; });
auto listen_plain = random_localhost();
server.listen_plain(listen_plain, nullptr, [&](bool success) { if (success) listens += 10; });
wait_for([&] { return listens.load() >= 11; });
{
auto lock = catch_lock();
REQUIRE( listens == 11 );
}
// This should fail since we're already listening on it:
server.listen_curve(listen_plain, nullptr, [&](bool success) { if (!success) listens++; });
wait_for([&] { return listens.load() >= 12; });
{
auto lock = catch_lock();
REQUIRE( listens == 12 );
}
OxenMQ client{get_logger("C1» "), LogLevel::trace};
client.start();
std::atomic<int> conns = 0;
auto c1 = client.connect_remote(address{listen_curve, server.get_pubkey()},
[&](auto) { conns++; },
[&](auto, auto why) { auto lock = catch_lock(); UNSCOPED_INFO("connection failed: " << why); });
auto c2 = client.connect_remote(address{listen_plain},
[&](auto) { conns += 10; },
[&](auto, auto why) { auto lock = catch_lock(); UNSCOPED_INFO("connection failed: " << why); });
wait_for([&] { return conns.load() >= 11; });
{
auto lock = catch_lock();
REQUIRE( conns == 11 );
}
std::atomic<int> replies = 0;
std::string reply1, reply2;
client.request(c1, "x.y", [&](auto success, auto parts) { replies++; for (auto& p : parts) reply1 += p; }, " world");
client.request(c2, "x.y", [&](auto success, auto parts) { replies += 10; for (auto& p : parts) reply2 += p; }, " cat");
wait_for([&] { return replies.load() >= 11; });
{
auto lock = catch_lock();
REQUIRE( replies == 11 );
REQUIRE( reply1 == "hi world" );
REQUIRE( reply2 == "hi cat" );
}
}
TEST_CASE("unique connection IDs", "[connect][id]") {
std::string listen = random_localhost();
OxenMQ server{get_logger(""), LogLevel::trace};
ConnectionID first, second;
server.add_category("x", Access{AuthLevel::none})
.add_request_command("x", [&](Message& m) { first = m.conn; m.send_reply("hi"); })
.add_request_command("y", [&](Message& m) { second = m.conn; m.send_reply("hi"); })
;
server.listen_plain(listen);
server.start();
OxenMQ client1{get_logger("C1» "), LogLevel::trace};
OxenMQ client2{get_logger("C2» "), LogLevel::trace};
client1.start();
client2.start();
std::atomic<bool> good1{false}, good2{false};
auto r1 = client1.connect_remote(address{listen},
[&](auto conn) { good1 = true; },
[&](auto conn, std::string_view reason) { auto lock = catch_lock(); INFO("connection failed: " << reason); }
);
auto r2 = client2.connect_remote(address{listen},
[&](auto conn) { good2 = true; },
[&](auto conn, std::string_view reason) { auto lock = catch_lock(); INFO("connection failed: " << reason); }
);
wait_for_conn(good1);
wait_for_conn(good2);
{
auto lock = catch_lock();
REQUIRE( good1 );
REQUIRE( good2 );
REQUIRE( first == second );
REQUIRE_FALSE( first );
REQUIRE_FALSE( second );
}
good1 = false;
good2 = false;
client1.request(r1, "x.x", [&](auto success_, auto parts_) { good1 = true; });
client2.request(r2, "x.y", [&](auto success_, auto parts_) { good2 = true; });
reply_sleep();
{
auto lock = catch_lock();
REQUIRE( good1 );
REQUIRE( good2 );
REQUIRE_FALSE( first == second );
REQUIRE_FALSE( std::hash<ConnectionID>{}(first) == std::hash<ConnectionID>{}(second) );
}
}
TEST_CASE("SN disconnections", "[connect][disconnect]") {
std::vector<std::unique_ptr<OxenMQ>> omq;
std::vector<std::unique_ptr<LokiMQ>> lmq;
std::vector<std::string> pubkey, privkey;
std::unordered_map<std::string, std::string> conn;
REQUIRE(sodium_init() != -1);
for (int i = 0; i < 3; i++) {
pubkey.emplace_back();
privkey.emplace_back();
pubkey[i].resize(crypto_box_PUBLICKEYBYTES);
privkey[i].resize(crypto_box_SECRETKEYBYTES);
crypto_box_keypair(reinterpret_cast<unsigned char*>(&pubkey[i][0]), reinterpret_cast<unsigned char*>(&privkey[i][0]));
conn.emplace(pubkey[i], random_localhost());
conn.emplace(pubkey[i], "tcp://127.0.0.1:" + std::to_string(4450 + i));
}
std::atomic<int> his{0};
for (int i = 0; i < pubkey.size(); i++) {
omq.push_back(std::make_unique<OxenMQ>(
lmq.push_back(std::make_unique<LokiMQ>(
pubkey[i], privkey[i], true,
[conn](auto pk) { auto it = conn.find((std::string) pk); if (it != conn.end()) return it->second; return ""s; },
get_logger("S" + std::to_string(i) + "» "),
LogLevel::trace
get_logger("S" + std::to_string(i) + "» ")
));
auto& server = *omq.back();
auto& server = *lmq.back();
server.log_level(LogLevel::debug);
server.listen_curve(conn[pubkey[i]]);
server.listen_curve(conn[pubkey[i]], [](auto /*ip*/, auto /*pk*/) { return Allow{AuthLevel::none, true}; });
server.add_category("sn", Access{AuthLevel::none, true})
.add_command("hi", [&](Message& m) { his++; });
server.set_active_sns({pubkey.begin(), pubkey.end()});
server.start();
}
std::this_thread::sleep_for(50ms);
omq[0]->send(pubkey[1], "sn.hi");
omq[0]->send(pubkey[2], "sn.hi");
omq[2]->send(pubkey[0], "sn.hi");
omq[2]->send(pubkey[1], "sn.hi");
omq[1]->send(pubkey[0], "BYE");
omq[0]->send(pubkey[2], "sn.hi");
std::this_thread::sleep_for(50ms * TIME_DILATION);
lmq[0]->send(pubkey[1], "sn.hi");
lmq[0]->send(pubkey[2], "sn.hi");
std::this_thread::sleep_for(50ms);
lmq[2]->send(pubkey[0], "sn.hi");
lmq[2]->send(pubkey[1], "sn.hi");
lmq[1]->send(pubkey[0], "BYE");
std::this_thread::sleep_for(50ms);
lmq[0]->send(pubkey[2], "sn.hi");
std::this_thread::sleep_for(50ms);
auto lock = catch_lock();
REQUIRE(his == 5);
}
TEST_CASE("SN auth checks", "[sandwich][auth]") {
// When a remote connects, we check its authentication level; if at the time of connection it
// isn't recognized as a SN but tries to invoke a SN command it'll be told to disconnect; if it
// tries to send again it should reconnect and reauthenticate. This test is meant to test this
// pattern where the reconnection/reauthentication now authenticates it as a SN.
std::string listen = random_localhost();
std::string pubkey, privkey;
pubkey.resize(crypto_box_PUBLICKEYBYTES);
privkey.resize(crypto_box_SECRETKEYBYTES);
REQUIRE(sodium_init() != -1);
crypto_box_keypair(reinterpret_cast<unsigned char*>(&pubkey[0]), reinterpret_cast<unsigned char*>(&privkey[0]));
OxenMQ server{
pubkey, privkey,
true, // service node
[](auto) { return ""; },
get_logger(""),
LogLevel::trace
};
std::atomic<bool> incoming_is_sn{false};
server.listen_curve(listen);
server.add_category("public", Access{AuthLevel::none})
.add_request_command("hello", [&](Message& m) { m.send_reply("hi"); })
.add_request_command("sudo", [&](Message& m) {
server.update_active_sns({{m.conn.pubkey()}}, {});
m.send_reply("making sandwiches");
})
.add_request_command("nosudo", [&](Message& m) {
// Send the reply *first* because if we do it the other way we'll have just removed
// ourselves from the list of SNs and thus would try to open an outbound connection
// to deliver it since it's still queued as a message to a SN.
m.send_reply("make them yourself");
server.update_active_sns({}, {{m.conn.pubkey()}});
});
server.add_category("sandwich", Access{AuthLevel::none, true})
.add_request_command("make", [&](Message& m) { m.send_reply("okay"); });
server.start();
OxenMQ client{
"", "", false,
[&](auto remote_pk) { if (remote_pk == pubkey) return listen; return ""s; },
get_logger(""), LogLevel::trace};
client.start();
std::atomic<bool> got{false};
bool success;
client.request(pubkey, "public.hello", [&](auto success_, auto) { success = success_; got = true; });
wait_for_conn(got);
{
auto lock = catch_lock();
REQUIRE( got );
REQUIRE( success );
}
got = false;
using dvec = std::vector<std::string>;
dvec data;
client.request(pubkey, "sandwich.make", [&](auto success_, auto data_) {
success = success_;
data = std::move(data_);
got = true;
});
reply_sleep();
{
auto lock = catch_lock();
REQUIRE( got );
REQUIRE_FALSE( success );
REQUIRE( data == dvec{{"FORBIDDEN_SN"}} );
}
// Somebody set up us the bomb. Main sudo turn on.
got = false;
client.request(pubkey, "public.sudo", [&](auto success_, auto data_) { success = success_; data = data_; got = true; });
reply_sleep();
{
auto lock = catch_lock();
REQUIRE( got );
REQUIRE( success );
REQUIRE( data == dvec{{"making sandwiches"}} );
}
got = false;
client.request(pubkey, "sandwich.make", [&](auto success_, auto data_) {
success = success_;
data = std::move(data_);
got = true;
});
reply_sleep();
{
auto lock = catch_lock();
REQUIRE( got );
REQUIRE( success );
REQUIRE( data == dvec{{"okay"}} );
}
// Take off every 'SUDO', You [not] know what you doing
got = false;
client.request(pubkey, "public.nosudo", [&](auto success_, auto data_) { success = success_; data = data_; got = true; });
reply_sleep();
{
auto lock = catch_lock();
REQUIRE( got );
REQUIRE( success );
REQUIRE( data == dvec{{"make them yourself"}} );
}
got = false;
client.request(pubkey, "sandwich.make", [&](auto success_, auto data_) {
success = success_;
data = std::move(data_);
got = true;
});
reply_sleep();
{
auto lock = catch_lock();
REQUIRE( got );
REQUIRE_FALSE( success );
REQUIRE( data == dvec{{"FORBIDDEN_SN"}} );
}
}
TEST_CASE("SN single worker test", "[connect][worker]") {
// Tests a failure case that could trigger when all workers are allocated (here we make that
// simpler by just having one worker).
std::string listen = random_localhost();
OxenMQ server{
"", "",
false, // service node
[](auto) { return ""; },
get_logger(""),
LogLevel::trace
};
server.set_general_threads(1);
server.set_batch_threads(0);
server.set_reply_threads(0);
server.listen_plain(listen);
server.add_category("c", Access{AuthLevel::none})
.add_request_command("x", [&](Message& m) { m.send_reply(); })
;
server.start();
OxenMQ client{get_logger(""), LogLevel::trace};
client.start();
auto conn = client.connect_remote(address{listen}, [](auto) {}, [](auto, auto) {});
std::atomic<int> got{0};
std::atomic<int> success{0};
client.request(conn, "c.x", [&](auto success_, auto) { if (success_) ++success; ++got; });
wait_for([&] { return got.load() >= 1; });
{
auto lock = catch_lock();
REQUIRE( success == 1 );
}
client.request(conn, "c.x", [&](auto success_, auto) { if (success_) ++success; ++got; });
wait_for([&] { return got.load() >= 2; });
{
auto lock = catch_lock();
REQUIRE( success == 2 );
}
}
TEST_CASE("SN backchatter", "[connect][sn]") {
// When we have a SN connection A -> B and then B sends a message to A on that existing
// connection, A should see it as coming from B.
std::vector<std::unique_ptr<OxenMQ>> omq;
std::vector<std::string> pubkey, privkey;
std::unordered_map<std::string, std::string> conn;
REQUIRE(sodium_init() != -1);
for (int i = 0; i < 2; i++) {
pubkey.emplace_back();
privkey.emplace_back();
pubkey[i].resize(crypto_box_PUBLICKEYBYTES);
privkey[i].resize(crypto_box_SECRETKEYBYTES);
crypto_box_keypair(reinterpret_cast<unsigned char*>(&pubkey[i][0]), reinterpret_cast<unsigned char*>(&privkey[i][0]));
conn.emplace(pubkey[i], random_localhost());
}
for (int i = 0; i < pubkey.size(); i++) {
omq.push_back(std::make_unique<OxenMQ>(
pubkey[i], privkey[i], true,
[conn](auto pk) { auto it = conn.find((std::string) pk); if (it != conn.end()) return it->second; return ""s; },
get_logger("S" + std::to_string(i) + "» "),
LogLevel::trace
));
auto& server = *omq.back();
server.listen_curve(conn[pubkey[i]]);
server.set_active_sns({pubkey.begin(), pubkey.end()});
}
std::string f;
omq[0]->add_category("a", Access{AuthLevel::none, true})
.add_command("a", [&](Message& m) {
m.oxenmq.send(m.conn, "b.b", "abc");
//m.send_back("b.b", "abc");
})
.add_command("z", [&](Message& m) {
auto lock = catch_lock();
f = m.data[0];
});
omq[1]->add_category("b", Access{AuthLevel::none, true})
.add_command("b", [&](Message& m) {
{
auto lock = catch_lock();
UNSCOPED_INFO("b.b from conn " << m.conn);
}
m.send_back("a.z", m.data[0]);
});
for (auto& server : omq)
server->start();
auto c = omq[1]->connect_sn(pubkey[0]);
omq[1]->send(c, "a.a");
std::this_thread::sleep_for(50ms);
auto lock = catch_lock();
REQUIRE(f == "abc");
}
TEST_CASE("inproc connections", "[connect][inproc]") {
std::string inproc_name = "foo";
OxenMQ omq{get_logger("OMQ» "), LogLevel::trace};
omq.add_category("public", Access{AuthLevel::none});
omq.add_request_command("public", "hello", [&](Message& m) { m.send_reply("hi"); });
omq.start();
std::atomic<int> got{0};
bool success = false;
auto c_inproc = omq.connect_inproc(
[&](auto conn) { success = true; got++; },
[&](auto conn, std::string_view reason) { auto lock = catch_lock(); INFO("inproc connection failed: " << reason); got++; }
);
wait_for([&got] { return got.load() > 0; });
{
auto lock = catch_lock();
REQUIRE( success );
REQUIRE( got == 1 );
}
got = 0;
success = false;
omq.request(c_inproc, "public.hello", [&](auto success_, auto parts_) {
success = success_ && parts_.size() == 1 && parts_.front() == "hi"; got++;
});
reply_sleep();
{
auto lock = catch_lock();
REQUIRE( got == 1 );
REQUIRE( success );
}
}
TEST_CASE("no explicit inproc listening", "[connect][inproc]") {
OxenMQ omq{get_logger("OMQ» "), LogLevel::trace};
REQUIRE_THROWS_AS(omq.listen_plain("inproc://foo"), std::logic_error);
REQUIRE_THROWS_AS(omq.listen_curve("inproc://foo"), std::logic_error);
}
TEST_CASE("inproc connection permissions", "[connect][inproc]") {
std::string listen = random_localhost();
OxenMQ omq{get_logger("OMQ» "), LogLevel::trace};
omq.add_category("public", Access{AuthLevel::none});
omq.add_request_command("public", "hello", [&](Message& m) { m.send_reply("hi"); });
omq.add_category("private", Access{AuthLevel::admin});
omq.add_request_command("private", "handshake", [&](Message& m) { m.send_reply("yo dude"); });
omq.listen_plain(listen);
omq.start();
std::atomic<int> got{0};
bool success = false;
auto c_inproc = omq.connect_inproc(
[&](auto conn) { success = true; got++; },
[&](auto conn, std::string_view reason) { auto lock = catch_lock(); INFO("inproc connection failed: " << reason); got++; }
);
bool pub_success = false;
auto c_pub = omq.connect_remote(address{listen},
[&](auto conn) { pub_success = true; got++; },
[&](auto conn, std::string_view reason) { auto lock = catch_lock(); INFO("tcp connection failed: " << reason); got++; }
);
wait_for([&got] { return got.load() == 2; });
{
auto lock = catch_lock();
REQUIRE( got == 2 );
REQUIRE( success );
REQUIRE( pub_success );
}
got = 0;
success = false;
pub_success = false;
bool success_private = false;
bool pub_success_private = false;
omq.request(c_inproc, "public.hello", [&](auto success_, auto parts_) {
success = success_ && parts_.size() == 1 && parts_.front() == "hi"; got++;
});
omq.request(c_pub, "public.hello", [&](auto success_, auto parts_) {
pub_success = success_ && parts_.size() == 1 && parts_.front() == "hi"; got++;
});
omq.request(c_inproc, "private.handshake", [&](auto success_, auto parts_) {
success_private = success_ && parts_.size() == 1 && parts_.front() == "yo dude"; got++;
});
omq.request(c_pub, "private.handshake", [&](auto success_, auto parts_) {
pub_success_private = success_; got++;
});
wait_for([&got] { return got.load() == 4; });
{
auto lock = catch_lock();
REQUIRE( got == 4 );
REQUIRE( success );
REQUIRE( pub_success );
REQUIRE( success_private );
REQUIRE_FALSE( pub_success_private );
}
}

View file

@ -1,324 +0,0 @@
#include "common.h"
#include <map>
#include <set>
using namespace oxenmq;
TEST_CASE("failure responses - UNKNOWNCOMMAND", "[failure][UNKNOWNCOMMAND]") {
std::string listen = random_localhost();
OxenMQ server{
"", "", // generate ephemeral keys
false, // not a service node
[](auto) { return ""; },
get_logger(""),
LogLevel::trace
};
server.listen_plain(listen);
server.start();
// Use a raw socket here because I want to see the raw commands coming on the wire
zmq::context_t client_ctx;
zmq::socket_t client{client_ctx, zmq::socket_type::dealer};
client.connect(listen);
// Handshake: we send HI, they reply HELLO.
client.send(zmq::message_t{"HI", 2}, zmq::send_flags::none);
{
zmq::message_t hello;
auto recvd = client.recv(hello);
auto lock = catch_lock();
REQUIRE( recvd );
REQUIRE( hello.to_string() == "HELLO" );
REQUIRE_FALSE( hello.more() );
}
client.send(zmq::message_t{"a.a", 3}, zmq::send_flags::none);
zmq::message_t resp;
auto recvd = client.recv(resp);
auto lock = catch_lock();
REQUIRE( recvd );
REQUIRE( resp.to_string() == "UNKNOWNCOMMAND" );
REQUIRE( resp.more() );
REQUIRE( client.recv(resp) );
REQUIRE( resp.to_string() == "a.a" );
REQUIRE_FALSE( resp.more() );
}
TEST_CASE("failure responses - NO_REPLY_TAG", "[failure][NO_REPLY_TAG]") {
std::string listen = random_localhost();
OxenMQ server{
"", "", // generate ephemeral keys
false, // not a service node
[](auto) { return ""; },
get_logger(""),
LogLevel::trace
};
server.listen_plain(listen);
server.add_category("x", AuthLevel::none)
.add_request_command("r", [] (auto& m) { m.send_reply("a"); });
server.start();
// Use a raw socket here because I want to see the raw commands coming on the wire
zmq::context_t client_ctx;
zmq::socket_t client{client_ctx, zmq::socket_type::dealer};
client.connect(listen);
// Handshake: we send HI, they reply HELLO.
client.send(zmq::message_t{"HI", 2}, zmq::send_flags::none);
{
zmq::message_t hello;
auto recvd = client.recv(hello);
auto lock = catch_lock();
REQUIRE( recvd );
REQUIRE( hello.to_string() == "HELLO" );
REQUIRE_FALSE( hello.more() );
}
client.send(zmq::message_t{"x.r", 3}, zmq::send_flags::none);
zmq::message_t resp;
auto recvd = client.recv(resp);
{
auto lock = catch_lock();
REQUIRE( recvd );
REQUIRE( resp.to_string() == "NO_REPLY_TAG" );
REQUIRE( resp.more() );
REQUIRE( client.recv(resp) );
REQUIRE( resp.to_string() == "x.r" );
REQUIRE_FALSE( resp.more() );
}
client.send(zmq::message_t{"x.r", 3}, zmq::send_flags::sndmore);
client.send(zmq::message_t{"foo", 3}, zmq::send_flags::none);
recvd = client.recv(resp);
{
auto lock = catch_lock();
REQUIRE( recvd );
REQUIRE( resp.to_string() == "REPLY" );
REQUIRE( resp.more() );
REQUIRE( client.recv(resp) );
REQUIRE( resp.to_string() == "foo" );
REQUIRE( resp.more() );
REQUIRE( client.recv(resp) );
REQUIRE( resp.to_string() == "a" );
REQUIRE_FALSE( resp.more() );
}
}
TEST_CASE("failure responses - FORBIDDEN", "[failure][FORBIDDEN]") {
std::string listen = random_localhost();
OxenMQ server{
"", "", // generate ephemeral keys
false, // not a service node
[](auto) { return ""; },
get_logger(""),
LogLevel::trace
};
server.listen_plain(listen, [](auto, auto, auto) {
static int count = 0;
++count;
return count == 1 ? AuthLevel::none : count == 2 ? AuthLevel::basic : AuthLevel::admin;
});
server.add_category("x", AuthLevel::basic)
.add_command("x", [] (auto& m) { m.send_back("a"); });
server.add_category("y", AuthLevel::admin)
.add_command("x", [] (auto& m) { m.send_back("b"); });
server.start();
zmq::context_t client_ctx;
std::array<zmq::socket_t, 3> clients;
// Client 0 should get none auth level, client 1 should get basic, client 2 should get admin
for (auto& client : clients) {
client = {client_ctx, zmq::socket_type::dealer};
client.connect(listen);
// Handshake: we send HI, they reply HELLO.
client.send(zmq::message_t{"HI", 2}, zmq::send_flags::none);
{
zmq::message_t hello;
auto recvd = client.recv(hello);
auto lock = catch_lock();
REQUIRE( recvd );
REQUIRE( hello.to_string() == "HELLO" );
REQUIRE_FALSE( hello.more() );
}
}
for (auto& c : clients)
c.send(zmq::message_t{"x.x", 3}, zmq::send_flags::none);
zmq::message_t resp;
auto recvd = clients[0].recv(resp);
{
auto lock = catch_lock();
REQUIRE( recvd );
REQUIRE( resp.to_string() == "FORBIDDEN" );
REQUIRE( resp.more() );
REQUIRE( clients[0].recv(resp) );
REQUIRE( resp.to_string() == "x.x" );
REQUIRE_FALSE( resp.more() );
}
for (int i : {1, 2}) {
recvd = clients[i].recv(resp);
auto lock = catch_lock();
REQUIRE( recvd );
REQUIRE( resp.to_string() == "a" );
REQUIRE_FALSE( resp.more() );
}
for (auto& c : clients)
c.send(zmq::message_t{"y.x", 3}, zmq::send_flags::none);
for (int i : {0, 1}) {
recvd = clients[i].recv(resp);
auto lock = catch_lock();
REQUIRE( recvd );
REQUIRE( resp.to_string() == "FORBIDDEN" );
REQUIRE( resp.more() );
REQUIRE( clients[i].recv(resp) );
REQUIRE( resp.to_string() == "y.x" );
REQUIRE_FALSE( resp.more() );
}
recvd = clients[2].recv(resp);
{
auto lock = catch_lock();
REQUIRE( recvd );
REQUIRE( resp.to_string() == "b" );
REQUIRE_FALSE( resp.more() );
}
}
TEST_CASE("failure responses - NOT_A_SERVICE_NODE", "[failure][NOT_A_SERVICE_NODE]") {
std::string listen = random_localhost();
OxenMQ server{
"", "", // generate ephemeral keys
false, // not a service node
[](auto) { return ""; },
get_logger(""),
LogLevel::trace
};
server.listen_plain(listen, [](auto, auto, auto) {
static int count = 0;
++count;
return count == 1 ? AuthLevel::none : count == 2 ? AuthLevel::basic : AuthLevel::admin;
});
server.add_category("x", Access{AuthLevel::none, false, true})
.add_command("x", [] (auto&) {})
.add_request_command("r", [] (auto& m) { m.send_reply(); })
;
server.start();
zmq::context_t client_ctx;
zmq::socket_t client{client_ctx, zmq::socket_type::dealer};
client.connect(listen);
// Handshake: we send HI, they reply HELLO.
client.send(zmq::message_t{"HI", 2}, zmq::send_flags::none);
{
zmq::message_t hello;
auto recvd = client.recv(hello);
auto lock = catch_lock();
REQUIRE( recvd );
REQUIRE( hello.to_string() == "HELLO" );
REQUIRE_FALSE( hello.more() );
}
client.send(zmq::message_t{"x.x", 3}, zmq::send_flags::none);
zmq::message_t resp;
auto recvd = client.recv(resp);
{
auto lock = catch_lock();
REQUIRE( recvd );
REQUIRE( resp.to_string() == "NOT_A_SERVICE_NODE" );
REQUIRE( resp.more() );
REQUIRE( client.recv(resp) );
REQUIRE( resp.to_string() == "x.x" );
REQUIRE_FALSE( resp.more() );
}
client.send(zmq::message_t{"x.r", 3}, zmq::send_flags::sndmore);
client.send(zmq::message_t{"xyz123", 6}, zmq::send_flags::none); // reply tag
recvd = client.recv(resp);
{
auto lock = catch_lock();
REQUIRE( recvd );
REQUIRE( resp.to_string() == "NOT_A_SERVICE_NODE" );
REQUIRE( resp.more() );
REQUIRE( client.recv(resp) );
REQUIRE( resp.to_string() == "REPLY" );
REQUIRE( resp.more() );
REQUIRE( client.recv(resp) );
REQUIRE( resp.to_string() == "xyz123" );
REQUIRE_FALSE( resp.more() );
}
}
TEST_CASE("failure responses - FORBIDDEN_SN", "[failure][FORBIDDEN_SN]") {
std::string listen = random_localhost();
OxenMQ server{
"", "", // generate ephemeral keys
false, // not a service node
[](auto) { return ""; },
get_logger(""),
LogLevel::trace
};
server.listen_plain(listen, [](auto, auto, auto) {
static int count = 0;
++count;
return count == 1 ? AuthLevel::none : count == 2 ? AuthLevel::basic : AuthLevel::admin;
});
server.add_category("x", Access{AuthLevel::none, true, false})
.add_command("x", [] (auto&) {})
.add_request_command("r", [] (auto& m) { m.send_reply(); })
;
server.start();
zmq::context_t client_ctx;
zmq::socket_t client{client_ctx, zmq::socket_type::dealer};
client.connect(listen);
// Handshake: we send HI, they reply HELLO.
client.send(zmq::message_t{"HI", 2}, zmq::send_flags::none);
{
zmq::message_t hello;
auto recvd = client.recv(hello);
auto lock = catch_lock();
REQUIRE( recvd );
REQUIRE( hello.to_string() == "HELLO" );
REQUIRE_FALSE( hello.more() );
}
client.send(zmq::message_t{"x.x", 3}, zmq::send_flags::none);
zmq::message_t resp;
auto recvd = client.recv(resp);
{
auto lock = catch_lock();
REQUIRE( recvd );
REQUIRE( resp.to_string() == "FORBIDDEN_SN" );
REQUIRE( resp.more() );
REQUIRE( client.recv(resp) );
REQUIRE( resp.to_string() == "x.x" );
REQUIRE_FALSE( resp.more() );
}
client.send(zmq::message_t{"x.r", 3}, zmq::send_flags::sndmore);
client.send(zmq::message_t{"xyz123", 6}, zmq::send_flags::none); // reply tag
recvd = client.recv(resp);
{
auto lock = catch_lock();
REQUIRE( recvd );
REQUIRE( resp.to_string() == "FORBIDDEN_SN" );
REQUIRE( resp.more() );
REQUIRE( client.recv(resp) );
REQUIRE( resp.to_string() == "REPLY" );
REQUIRE( resp.more() );
REQUIRE( client.recv(resp) );
REQUIRE( resp.to_string() == "xyz123" );
REQUIRE_FALSE( resp.more() );
}
}

View file

@ -1,96 +0,0 @@
#include "common.h"
using namespace oxenmq;
TEST_CASE("injected external commands", "[injected]") {
std::string listen = random_localhost();
OxenMQ server{
"", "", // generate ephemeral keys
false, // not a service node
[](auto) { return ""; },
get_logger(""),
LogLevel::trace
};
server.set_general_threads(1);
server.listen_curve(listen);
std::atomic<int> hellos = 0;
std::atomic<bool> done = false;
server.add_category("public", AuthLevel::none, 3);
server.add_command("public", "hello", [&](Message& m) {
hellos++;
while (!done) std::this_thread::sleep_for(10ms);
});
server.start();
OxenMQ client{get_logger(""), LogLevel::trace};
client.start();
std::atomic<bool> got{false};
bool success = false;
// Deliberately using a deprecated command here, disable -Wdeprecated-declarations
#ifdef __GNUG__
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
#endif
auto c = client.connect_remote(listen,
[&](auto conn) { success = true; got = true; },
[&](auto conn, std::string_view) { got = true; },
server.get_pubkey());
#ifdef __GNUG__
#pragma GCC diagnostic pop
#endif
wait_for_conn(got);
{
auto lock = catch_lock();
REQUIRE( got );
REQUIRE( success );
}
// First make sure that basic message respects the 3 thread limit
client.send(c, "public.hello");
client.send(c, "public.hello");
client.send(c, "public.hello");
client.send(c, "public.hello");
wait_for([&] { return hellos >= 3; });
std::this_thread::sleep_for(20ms);
{
auto lock = catch_lock();
REQUIRE( hellos == 3 );
}
done = true;
wait_for([&] { return hellos >= 4; });
{
auto lock = catch_lock();
REQUIRE( hellos == 4 );
}
// Now try injecting external commands
done = false;
hellos = 0;
client.send(c, "public.hello");
wait_for([&] { return hellos >= 1; });
server.inject_task("public", "(injected)", "localhost", [&] { hellos += 10; while (!done) std::this_thread::sleep_for(10ms); });
wait_for([&] { return hellos >= 11; });
client.send(c, "public.hello");
wait_for([&] { return hellos >= 12; });
server.inject_task("public", "(injected)", "localhost", [&] { hellos += 10; while (!done) std::this_thread::sleep_for(10ms); });
server.inject_task("public", "(injected)", "localhost", [&] { hellos += 10; while (!done) std::this_thread::sleep_for(10ms); });
server.inject_task("public", "(injected)", "localhost", [&] { hellos += 10; while (!done) std::this_thread::sleep_for(10ms); });
wait_for([&] { return hellos >= 12; });
std::this_thread::sleep_for(20ms);
{
auto lock = catch_lock();
REQUIRE( hellos == 12 );
}
done = true;
wait_for([&] { return hellos >= 42; });
{
auto lock = catch_lock();
REQUIRE( hellos == 42 );
}
}

View file

@ -1,611 +0,0 @@
#include "common.h"
#include "oxenmq/pubsub.h"
#include <oxenc/hex.h>
using namespace oxenmq;
using namespace std::chrono_literals;
TEST_CASE("sub OK", "[pubsub]") {
std::string listen = random_localhost();
OxenMQ server{
"", "", // generate ephemeral keys
false, // not a service node
[](auto) { return ""; },
};
server.listen_curve(listen);
Subscription<> greetings{"greetings"};
std::atomic<bool> is_new{false};
server.add_category("public", Access{AuthLevel::none});
server.add_request_command("public", "greetings", [&](Message& m) {
is_new = greetings.subscribe(m.conn);
m.send_reply("OK");
});
server.start();
OxenMQ client(
[](LogLevel, const char* file, int line, std::string msg) { std::cerr << file << ":" << line << " --C-- " << msg << "\n"; }
);
std::atomic<int> reply_count{0};
client.add_category("notify", Access{AuthLevel::none});
client.add_command("notify", "greetings", [&](Message& m) {
const auto& data = m.data;
if (!data.size())
{
std::cerr << "client received public.greetings with empty data\n";
return;
}
if (data[0] == "hello")
reply_count++;
});
client.start();
std::atomic<bool> connected{false}, failed{false};
std::string pubkey;
auto c = client.connect_remote(address{listen, server.get_pubkey()},
[&](auto conn) { pubkey = conn.pubkey(); connected = true; },
[&](auto, auto) { failed = true; });
wait_for([&] { return connected || failed; });
{
auto lock = catch_lock();
REQUIRE( connected );
REQUIRE_FALSE( failed );
REQUIRE( oxenc::to_hex(pubkey) == oxenc::to_hex(server.get_pubkey()) );
}
std::atomic<bool> got_reply{false};
bool success;
std::vector<std::string> data;
client.request(c, "public.greetings", [&](bool ok, std::vector<std::string> data_) {
got_reply = true;
success = ok;
data = std::move(data_);
});
reply_sleep();
{
auto lock = catch_lock();
REQUIRE( got_reply.load() );
REQUIRE( success );
REQUIRE( data == std::vector<std::string>{{"OK"}} );
}
greetings.publish([&](auto& conn) {
server.send(conn, "notify.greetings", "hello");
});
reply_sleep();
{
auto lock = catch_lock();
REQUIRE( reply_count == 1 );
}
greetings.publish([&](auto& conn) {
server.send(conn, "notify.greetings", "hello");
});
reply_sleep();
{
auto lock = catch_lock();
REQUIRE( reply_count == 2 );
}
}
TEST_CASE("user data", "[pubsub]") {
std::string listen = random_localhost();
OxenMQ server{
"", "", // generate ephemeral keys
false, // not a service node
[](auto) { return ""; },
};
server.listen_curve(listen);
Subscription<std::string> greetings{"greetings"};
std::atomic<bool> is_new{false};
server.add_category("public", Access{AuthLevel::none});
server.add_request_command("public", "greetings", [&](Message& m) {
is_new = greetings.subscribe(m.conn, std::string{m.data[0]});
m.send_reply("OK");
});
server.start();
OxenMQ client(
[](LogLevel, const char* file, int line, std::string msg) { std::cerr << file << ":" << line << " --C-- " << msg << "\n"; }
);
std::string response{"foo"};
std::atomic<int> reply_count{0};
std::atomic<int> foo_count{0};
client.add_category("notify", Access{AuthLevel::none});
client.add_command("notify", "greetings", [&](Message& m) {
const auto& data = m.data;
if (!data.size())
{
std::cerr << "client received public.greetings with empty data\n";
return;
}
if (data[0] == response)
reply_count++;
if (data[0] == "foo")
foo_count++;
});
client.start();
std::atomic<bool> connected{false}, failed{false};
std::string pubkey;
auto c = client.connect_remote(address{listen, server.get_pubkey()},
[&](auto conn) { pubkey = conn.pubkey(); connected = true; },
[&](auto, auto) { failed = true; });
wait_for([&] { return connected || failed; });
{
auto lock = catch_lock();
REQUIRE( connected );
REQUIRE_FALSE( failed );
REQUIRE( oxenc::to_hex(pubkey) == oxenc::to_hex(server.get_pubkey()) );
}
std::atomic<bool> got_reply{false};
std::atomic<bool> success;
std::vector<std::string> data;
client.request(c, "public.greetings", [&](bool ok, std::vector<std::string> data_) {
got_reply = true;
success = ok;
data = std::move(data_);
}, response);
reply_sleep();
{
auto lock = catch_lock();
REQUIRE( got_reply.load() );
REQUIRE( success );
REQUIRE( is_new );
REQUIRE( data == std::vector<std::string>{{"OK"}} );
}
got_reply = false;
success = false;
client.request(c, "public.greetings", [&](bool ok, std::vector<std::string> data_) {
got_reply = true;
success = ok;
data = std::move(data_);
}, response);
reply_sleep();
{
auto lock = catch_lock();
REQUIRE( got_reply.load() );
REQUIRE( success );
REQUIRE_FALSE( is_new );
REQUIRE( data == std::vector<std::string>{{"OK"}} );
}
greetings.publish([&](auto& conn, std::string user) {
server.send(conn, "notify.greetings", user);
});
reply_sleep();
{
auto lock = catch_lock();
REQUIRE( reply_count == 1 );
REQUIRE( foo_count == 1 );
}
got_reply = false;
success = false;
response = "bar";
client.request(c, "public.greetings", [&](bool ok, std::vector<std::string> data_) {
got_reply = true;
success = ok;
data = std::move(data_);
}, response);
reply_sleep();
{
auto lock = catch_lock();
REQUIRE( got_reply.load() );
REQUIRE( success );
REQUIRE( is_new );
REQUIRE( data == std::vector<std::string>{{"OK"}} );
}
greetings.publish([&](auto& conn, std::string user) {
server.send(conn, "notify.greetings", user);
});
reply_sleep();
{
auto lock = catch_lock();
REQUIRE( reply_count == 2 );
REQUIRE( foo_count == 1 );
}
}
TEST_CASE("unsubscribe", "[pubsub]") {
std::string listen = random_localhost();
OxenMQ server{
"", "", // generate ephemeral keys
false, // not a service node
[](auto) { return ""; },
};
server.listen_curve(listen);
Subscription<> greetings{"greetings"};
std::atomic<bool> was_subbed{false};
server.add_category("public", Access{AuthLevel::none});
server.add_request_command("public", "greetings", [&](Message& m) {
greetings.subscribe(m.conn);
m.send_reply("OK");
});
server.add_request_command("public", "goodbye", [&](Message& m) {
was_subbed = greetings.unsubscribe(m.conn);
m.send_reply("OK");
});
server.start();
OxenMQ client(
[](LogLevel, const char* file, int line, std::string msg) { std::cerr << file << ":" << line << " --C-- " << msg << "\n"; }
);
std::atomic<int> reply_count{0};
client.add_category("notify", Access{AuthLevel::none});
client.add_command("notify", "greetings", [&](Message& m) {
const auto& data = m.data;
if (!data.size())
{
std::cerr << "client received public.greetings with empty data\n";
return;
}
if (data[0] == "hello")
reply_count++;
});
client.start();
std::atomic<bool> connected{false}, failed{false};
std::string pubkey;
auto c = client.connect_remote(address{listen, server.get_pubkey()},
[&](auto conn) { pubkey = conn.pubkey(); connected = true; },
[&](auto, auto) { failed = true; });
wait_for([&] { return connected || failed; });
{
auto lock = catch_lock();
REQUIRE( connected );
REQUIRE_FALSE( failed );
REQUIRE( oxenc::to_hex(pubkey) == oxenc::to_hex(server.get_pubkey()) );
}
std::atomic<bool> got_reply{false};
std::atomic<bool> success;
std::vector<std::string> data;
client.request(c, "public.greetings", [&](bool ok, std::vector<std::string> data_) {
got_reply = true;
success = ok;
data = std::move(data_);
});
reply_sleep();
{
auto lock = catch_lock();
REQUIRE( got_reply.load() );
REQUIRE( success );
REQUIRE( data == std::vector<std::string>{{"OK"}} );
}
greetings.publish([&](auto& conn) {
server.send(conn, "notify.greetings", "hello");
});
reply_sleep();
{
auto lock = catch_lock();
REQUIRE( reply_count == 1 );
}
got_reply = false;
success = false;
client.request(c, "public.goodbye", [&](bool ok, std::vector<std::string> data_) {
got_reply = true;
success = ok;
data = std::move(data_);
});
reply_sleep();
{
auto lock = catch_lock();
REQUIRE( got_reply.load() );
REQUIRE( success );
REQUIRE( data == std::vector<std::string>{{"OK"}} );
REQUIRE( was_subbed );
}
greetings.publish([&](auto& conn) {
server.send(conn, "notify.greetings", "hello");
});
reply_sleep();
{
auto lock = catch_lock();
REQUIRE( reply_count == 1 );
}
got_reply = false;
success = false;
client.request(c, "public.goodbye", [&](bool ok, std::vector<std::string> data_) {
got_reply = true;
success = ok;
data = std::move(data_);
});
reply_sleep();
{
auto lock = catch_lock();
REQUIRE( got_reply.load() );
REQUIRE( success );
REQUIRE( data == std::vector<std::string>{{"OK"}} );
REQUIRE( was_subbed == false);
}
}
TEST_CASE("expire", "[pubsub]") {
std::string listen = random_localhost();
OxenMQ server{
"", "", // generate ephemeral keys
false, // not a service node
[](auto) { return ""; },
};
server.listen_curve(listen);
Subscription<> greetings{"greetings", 250ms};
std::atomic<bool> was_subbed{false};
server.add_category("public", Access{AuthLevel::none});
server.add_request_command("public", "greetings", [&](Message& m) {
greetings.subscribe(m.conn);
m.send_reply("OK");
});
server.add_request_command("public", "goodbye", [&](Message& m) {
was_subbed = greetings.unsubscribe(m.conn);
m.send_reply("OK");
});
server.start();
OxenMQ client(
[](LogLevel, const char* file, int line, std::string msg) { std::cerr << file << ":" << line << " --C-- " << msg << "\n"; }
);
std::atomic<int> reply_count{0};
client.add_category("notify", Access{AuthLevel::none});
client.add_command("notify", "greetings", [&](Message& m) {
const auto& data = m.data;
if (!data.size())
{
std::cerr << "client received public.greetings with empty data\n";
return;
}
if (data[0] == "hello")
reply_count++;
});
client.start();
std::atomic<bool> connected{false}, failed{false};
std::string pubkey;
auto c = client.connect_remote(address{listen, server.get_pubkey()},
[&](auto conn) { pubkey = conn.pubkey(); connected = true; },
[&](auto, auto) { failed = true; });
wait_for([&] { return connected || failed; });
{
auto lock = catch_lock();
REQUIRE( connected );
REQUIRE_FALSE( failed );
REQUIRE( oxenc::to_hex(pubkey) == oxenc::to_hex(server.get_pubkey()) );
}
std::atomic<bool> got_reply{false};
bool success;
std::vector<std::string> data;
client.request(c, "public.greetings", [&](bool ok, std::vector<std::string> data_) {
got_reply = true;
success = ok;
data = std::move(data_);
});
reply_sleep();
{
auto lock = catch_lock();
REQUIRE( got_reply.load() );
REQUIRE( success );
REQUIRE( data == std::vector<std::string>{{"OK"}} );
}
// should be expired by now
std::this_thread::sleep_for(500ms);
greetings.remove_expired();
got_reply = false;
success = false;
client.request(c, "public.goodbye", [&](bool ok, std::vector<std::string> data_) {
got_reply = true;
success = ok;
data = std::move(data_);
});
reply_sleep();
{
auto lock = catch_lock();
REQUIRE( got_reply.load() );
REQUIRE( success );
REQUIRE( data == std::vector<std::string>{{"OK"}} );
REQUIRE( was_subbed == false);
}
}
TEST_CASE("multiple subs", "[pubsub]") {
std::string listen = random_localhost();
OxenMQ server{
"", "", // generate ephemeral keys
false, // not a service node
[](auto) { return ""; },
};
server.listen_curve(listen);
Subscription<> greetings{"greetings"};
std::atomic<bool> is_new{false};
server.add_category("public", Access{AuthLevel::none});
server.add_request_command("public", "greetings", [&](Message& m) {
is_new = greetings.subscribe(m.conn);
m.send_reply("OK");
});
server.start();
/* client 1 */
std::atomic<int> reply_count_c1{0};
std::atomic<bool> connected_c1{false}, failed_c1{false};
std::atomic<bool> got_reply_c1{false};
bool success_c1;
std::vector<std::string> data_c1;
std::string pubkey_c1;
OxenMQ client1(
[](LogLevel, const char* file, int line, std::string msg) { std::cerr << file << ":" << line << " --C-- " << msg << "\n"; }
);
client1.add_category("notify", Access{AuthLevel::none});
client1.add_command("notify", "greetings", [&](Message& m) {
const auto& data = m.data;
if (!data.size())
{
std::cerr << "client received public.greetings with empty data\n";
return;
}
if (data[0] == "hello")
reply_count_c1++;
});
client1.start();
auto c1 = client1.connect_remote(address{listen, server.get_pubkey()},
[&](auto conn) { pubkey_c1 = conn.pubkey(); connected_c1 = true; },
[&](auto, auto) { failed_c1 = true; });
wait_for([&] { return connected_c1 || failed_c1; });
{
auto lock = catch_lock();
REQUIRE( connected_c1 );
REQUIRE_FALSE( failed_c1 );
REQUIRE( oxenc::to_hex(pubkey_c1) == oxenc::to_hex(server.get_pubkey()) );
}
client1.request(c1, "public.greetings", [&](bool ok, std::vector<std::string> data_) {
got_reply_c1 = true;
success_c1 = ok;
data_c1 = std::move(data_);
});
reply_sleep();
{
auto lock = catch_lock();
REQUIRE( got_reply_c1.load() );
REQUIRE( success_c1 );
REQUIRE( data_c1 == std::vector<std::string>{{"OK"}} );
}
/* end client 1 */
/* client 2 */
std::atomic<int> reply_count_c2{0};
std::atomic<bool> connected_c2{false}, failed_c2{false};
std::atomic<bool> got_reply_c2{false};
bool success_c2;
std::vector<std::string> data_c2;
std::string pubkey_c2;
OxenMQ client2(
[](LogLevel, const char* file, int line, std::string msg) { std::cerr << file << ":" << line << " --C-- " << msg << "\n"; }
);
client2.add_category("notify", Access{AuthLevel::none});
client2.add_command("notify", "greetings", [&](Message& m) {
const auto& data = m.data;
if (!data.size())
{
std::cerr << "client received public.greetings with empty data\n";
return;
}
if (data[0] == "hello")
reply_count_c2++;
});
client2.start();
auto c2 = client2.connect_remote(address{listen, server.get_pubkey()},
[&](auto conn) { pubkey_c2 = conn.pubkey(); connected_c2 = true; },
[&](auto, auto) { failed_c2 = true; });
wait_for([&] { return connected_c2 || failed_c2; });
{
auto lock = catch_lock();
REQUIRE( connected_c2 );
REQUIRE_FALSE( failed_c2 );
REQUIRE( oxenc::to_hex(pubkey_c2) == oxenc::to_hex(server.get_pubkey()) );
}
client2.request(c2, "public.greetings", [&](bool ok, std::vector<std::string> data_) {
got_reply_c2 = true;
success_c2 = ok;
data_c2 = std::move(data_);
});
reply_sleep();
{
auto lock = catch_lock();
REQUIRE( got_reply_c2.load() );
REQUIRE( success_c2 );
REQUIRE( data_c2 == std::vector<std::string>{{"OK"}} );
}
/* end client2 */
greetings.publish([&](auto& conn) {
server.send(conn, "notify.greetings", "hello");
});
reply_sleep();
{
auto lock = catch_lock();
REQUIRE( reply_count_c1 == 1 );
REQUIRE( reply_count_c2 == 1 );
}
greetings.publish([&](auto& conn) {
server.send(conn, "notify.greetings", "hello");
});
reply_sleep();
{
auto lock = catch_lock();
REQUIRE( reply_count_c1 == 2 );
REQUIRE( reply_count_c2 == 2 );
}
}
// vim:sw=4:et

View file

@ -1,11 +1,12 @@
#include "common.h"
#include <oxenc/hex.h>
#include <future>
#include <lokimq/hex.h>
using namespace oxenmq;
using namespace lokimq;
TEST_CASE("basic requests", "[requests]") {
std::string listen = random_localhost();
OxenMQ server{
std::string listen = "tcp://127.0.0.1:5678";
LokiMQ server{
"", "", // generate ephemeral keys
false, // not a service node
[](auto) { return ""; },
@ -20,7 +21,7 @@ TEST_CASE("basic requests", "[requests]") {
});
server.start();
OxenMQ client(
LokiMQ client(
[](LogLevel, const char* file, int line, std::string msg) { std::cerr << file << ":" << line << " --C-- " << msg << "\n"; }
);
//client.log_level(LogLevel::trace);
@ -30,66 +31,10 @@ TEST_CASE("basic requests", "[requests]") {
std::atomic<bool> connected{false}, failed{false};
std::string pubkey;
auto c = client.connect_remote(address{listen, server.get_pubkey()},
auto c = client.connect_remote(listen,
[&](auto conn) { pubkey = conn.pubkey(); connected = true; },
[&](auto, auto) { failed = true; });
wait_for([&] { return connected || failed; });
{
auto lock = catch_lock();
REQUIRE( connected );
REQUIRE_FALSE( failed );
REQUIRE( oxenc::to_hex(pubkey) == oxenc::to_hex(server.get_pubkey()) );
}
std::atomic<bool> got_reply{false};
bool success;
std::vector<std::string> data;
client.request(c, "public.hello", [&](bool ok, std::vector<std::string> data_) {
got_reply = true;
success = ok;
data = std::move(data_);
});
reply_sleep();
{
auto lock = catch_lock();
REQUIRE( got_reply.load() );
REQUIRE( success );
REQUIRE( data == std::vector<std::string>{{"123"}} );
}
}
TEST_CASE("request from server to client", "[requests]") {
std::string listen = random_localhost();
OxenMQ server{
"", "", // generate ephemeral keys
false, // not a service node
[](auto) { return ""; },
};
server.listen_curve(listen);
std::atomic<int> hellos{0}, his{0};
server.add_category("public", Access{AuthLevel::none});
server.add_request_command("public", "hello", [&](Message& m) {
m.send_reply("123");
});
server.start();
OxenMQ client(
[](LogLevel, const char* file, int line, std::string msg) { std::cerr << file << ":" << line << " --C-- " << msg << "\n"; }
);
//client.log_level(LogLevel::trace);
client.start();
std::atomic<bool> connected{false}, failed{false};
std::string pubkey;
auto c = client.connect_remote(address{listen, server.get_pubkey()},
[&](auto conn) { pubkey = conn.pubkey(); connected = true; },
[&](auto, auto) { failed = true; });
[&](auto, auto) { failed = true; },
server.get_pubkey());
int i;
for (i = 0; i < 5; i++) {
@ -102,7 +47,69 @@ TEST_CASE("request from server to client", "[requests]") {
REQUIRE( connected.load() );
REQUIRE( !failed.load() );
REQUIRE( i <= 1 );
REQUIRE( oxenc::to_hex(pubkey) == oxenc::to_hex(server.get_pubkey()) );
REQUIRE( to_hex(pubkey) == to_hex(server.get_pubkey()) );
}
std::atomic<bool> got_reply{false};
bool success;
std::vector<std::string> data;
client.request(c, "public.hello", [&](bool ok, std::vector<std::string> data_) {
got_reply = true;
success = ok;
data = std::move(data_);
});
std::this_thread::sleep_for(50ms);
auto lock = catch_lock();
REQUIRE( got_reply.load() );
REQUIRE( success );
REQUIRE( data == std::vector<std::string>{{"123"}} );
}
TEST_CASE("request from server to client", "[requests]") {
std::string listen = "tcp://127.0.0.1:5678";
LokiMQ server{
"", "", // generate ephemeral keys
false, // not a service node
[](auto) { return ""; },
};
server.listen_curve(listen);
std::atomic<int> hellos{0}, his{0};
server.add_category("public", Access{AuthLevel::none});
server.add_request_command("public", "hello", [&](Message& m) {
m.send_reply("123");
});
server.start();
LokiMQ client(
[](LogLevel, const char* file, int line, std::string msg) { std::cerr << file << ":" << line << " --C-- " << msg << "\n"; }
);
//client.log_level(LogLevel::trace);
client.start();
std::atomic<bool> connected{false}, failed{false};
std::string pubkey;
auto c = client.connect_remote(listen,
[&](auto conn) { pubkey = conn.pubkey(); connected = true; },
[&](auto, auto) { failed = true; },
server.get_pubkey());
int i;
for (i = 0; i < 5; i++) {
if (connected.load())
break;
std::this_thread::sleep_for(50ms);
}
{
auto lock = catch_lock();
REQUIRE( connected.load() );
REQUIRE( !failed.load() );
REQUIRE( i <= 1 );
REQUIRE( to_hex(pubkey) == to_hex(server.get_pubkey()) );
}
std::atomic<bool> got_reply{false};
@ -124,8 +131,8 @@ TEST_CASE("request from server to client", "[requests]") {
}
TEST_CASE("request timeouts", "[requests][timeout]") {
std::string listen = random_localhost();
OxenMQ server{
std::string listen = "tcp://127.0.0.1:5678";
LokiMQ server{
"", "", // generate ephemeral keys
false, // not a service node
[](auto) { return ""; },
@ -138,7 +145,7 @@ TEST_CASE("request timeouts", "[requests][timeout]") {
server.add_request_command("public", "blackhole", [&](Message& m) { /* doesn't reply */ });
server.start();
OxenMQ client(
LokiMQ client(
[](LogLevel, const char* file, int line, std::string msg) { std::cerr << file << ":" << line << " --C-- " << msg << "\n"; }
);
//client.log_level(LogLevel::trace);
@ -149,15 +156,21 @@ TEST_CASE("request timeouts", "[requests][timeout]") {
std::atomic<bool> connected{false}, failed{false};
std::string pubkey;
auto c = client.connect_remote(address{listen, server.get_pubkey()},
auto c = client.connect_remote(listen,
[&](auto conn) { pubkey = conn.pubkey(); connected = true; },
[&](auto, auto) { failed = true; });
[&](auto, auto) { failed = true; },
server.get_pubkey());
wait_for([&] { return connected || failed; });
REQUIRE( connected );
REQUIRE_FALSE( failed );
REQUIRE( oxenc::to_hex(pubkey) == oxenc::to_hex(server.get_pubkey()) );
int i;
for (i = 0; i < 5; i++) {
if (connected.load())
break;
std::this_thread::sleep_for(50ms);
}
REQUIRE( connected.load() );
REQUIRE( !failed.load() );
REQUIRE( i <= 1 );
REQUIRE( to_hex(pubkey) == to_hex(server.get_pubkey()) );
std::atomic<bool> got_triggered{false};
bool success;
@ -167,7 +180,7 @@ TEST_CASE("request timeouts", "[requests][timeout]") {
success = ok;
data = std::move(data_);
},
oxenmq::send_option::request_timeout{10ms}
lokimq::send_option::request_timeout{30ms}
);
std::atomic<bool> got_triggered2{false};
@ -176,13 +189,13 @@ TEST_CASE("request timeouts", "[requests][timeout]") {
success = ok;
data = std::move(data_);
},
oxenmq::send_option::request_timeout{200ms}
lokimq::send_option::request_timeout{100ms}
);
std::this_thread::sleep_for(100ms);
REQUIRE( got_triggered );
REQUIRE_FALSE( got_triggered2 );
std::this_thread::sleep_for(50ms);
REQUIRE( got_triggered.load() );
REQUIRE_FALSE( success );
REQUIRE( data == std::vector<std::string>{{"TIMEOUT"}} );
REQUIRE( data.size() == 0 );
REQUIRE_FALSE( got_triggered2.load() );
}

View file

@ -1,43 +0,0 @@
#include "common.h"
#include <oxenc/hex.h>
using namespace oxenmq;
TEST_CASE("zmq socket limit", "[zmq][socket-limit]") {
// Make sure setting .MAX_SOCKETS works as expected. (This test was added when a bug was fixed
// that was causing it not to be applied).
std::string listen = random_localhost();
OxenMQ server{
"", "", // generate ephemeral keys
false, // not a service node
[](auto) { return ""; },
};
server.listen_plain(listen);
server.start();
std::atomic<int> failed = 0, good = 0, failed_toomany = 0;
OxenMQ client;
client.MAX_SOCKETS = 15;
client.start();
std::vector<ConnectionID> conns;
address server_addr{listen};
for (int i = 0; i < 16; i++)
client.connect_remote(server_addr,
[&](auto) { good++; },
[&](auto cid, auto msg) {
if (msg == "connect() failed: Too many open files")
failed_toomany++;
else
failed++;
});
wait_for([&] { return good > 0 && failed_toomany > 0; });
{
auto lock = catch_lock();
REQUIRE( good > 0 );
REQUIRE( failed == 0 );
REQUIRE( failed_toomany > 0 );
}
}

187
tests/test_string_view.cpp Normal file
View file

@ -0,0 +1,187 @@
#include <catch2/catch.hpp>
#include "lokimq/string_view.h"
#include <future>
using namespace lokimq;
using namespace std::literals;
TEST_CASE("string view", "[string_view]") {
std::string foo = "abc 123 xyz";
string_view f1{foo};
string_view f2{"def 789 uvw"};
string_view f3{"nu\0ll", 5};
REQUIRE( f1 == "abc 123 xyz" );
REQUIRE( f2 == "def 789 uvw" );
REQUIRE( f3.size() == 5 );
REQUIRE( f3 == std::string{"nu\0ll", 5} );
REQUIRE( f3 != "nu" );
REQUIRE( f3.data() == "nu"s );
REQUIRE( string_view(f3) == f3 );
auto f4 = f3;
REQUIRE( f4 == f3 );
f4 = f2;
REQUIRE( f4 == "def 789 uvw" );
REQUIRE( f1.size() == 11 );
REQUIRE( f3.length() == 5 );
string_view f5{""};
REQUIRE( !f3.empty() );
REQUIRE( f5.empty() );
REQUIRE( f1[5] == '2' );
size_t i = 0;
for (auto c : f3)
REQUIRE(c == f3[i++]);
std::string backwards;
for (auto it = std::rbegin(f2); it != f2.crend(); ++it)
backwards += *it;
REQUIRE( backwards == "wvu 987 fed" );
REQUIRE( f1.at(10) == 'z' );
REQUIRE_THROWS_AS( f1.at(15), std::out_of_range );
REQUIRE_THROWS_AS( f1.at(11), std::out_of_range );
f4 = f1;
f4.remove_prefix(2);
REQUIRE( f4 == "c 123 xyz" );
f4.remove_prefix(2);
f4.remove_suffix(4);
REQUIRE( f4 == "123" );
f4.remove_prefix(1);
REQUIRE( f4 == "23" );
REQUIRE( f1 == "abc 123 xyz" );
f4.swap(f1);
REQUIRE( f1 == "23" );
REQUIRE( f4 == "abc 123 xyz" );
f1.remove_suffix(2);
REQUIRE( f1.empty() );
REQUIRE( f4 == "abc 123 xyz" );
f1.swap(f4);
REQUIRE( f4.empty() );
REQUIRE( f1 == "abc 123 xyz" );
REQUIRE( f1.front() == 'a' );
REQUIRE( f1.back() == 'z' );
REQUIRE( f1.compare("abc") > 0 );
REQUIRE( f1.compare("abd") < 0 );
REQUIRE( f1.compare("abc 123 xyz") == 0 );
REQUIRE( f1.compare("abc 123 xyza") < 0 );
REQUIRE( f1.compare("abc 123 xy") > 0 );
std::string buf;
buf.resize(5);
f1.copy(&buf[0], 5, 2);
REQUIRE( buf == "c 123" );
buf.resize(100, 'X');
REQUIRE( f1.copy(&buf[0], 100) == 11 );
REQUIRE( buf.substr(0, 11) == f1 );
REQUIRE( buf.substr(11) == std::string(89, 'X') );
REQUIRE( f1.substr(4) == "123 xyz" );
REQUIRE( f1.substr(4, 3) == "123" );
REQUIRE_THROWS_AS( f1.substr(500, 3), std::out_of_range );
REQUIRE( f1.substr(11, 2) == "" );
REQUIRE( f1.substr(8, 500) == "xyz" );
REQUIRE( f1.find("123") == 4 );
REQUIRE( f1.find("abc") == 0 );
REQUIRE( f1.find("xyz") == 8 );
REQUIRE( f1.find("abc 123 xyz 7") == string_view::npos );
REQUIRE( f1.find("23") == 5 );
REQUIRE( f1.find("234") == string_view::npos );
string_view f6{"zz abc abcd abcde abcdef"};
REQUIRE( f6.find("abc") == 3 );
REQUIRE( f6.find("abc", 3) == 3 );
REQUIRE( f6.find("abc", 4) == 7 );
REQUIRE( f6.find("abc", 7) == 7 );
REQUIRE( f6.find("abc", 8) == 12 );
REQUIRE( f6.find("abc", 18) == 18 );
REQUIRE( f6.find("abc", 19) == string_view::npos );
REQUIRE( f6.find("abcd") == 7 );
REQUIRE( f6.rfind("abc") == 18 );
REQUIRE( f6.rfind("abcd") == 18 );
REQUIRE( f6.rfind("bcd") == 19 );
REQUIRE( f6.rfind("abc", 19) == 18 );
REQUIRE( f6.rfind("abc", 18) == 18 );
REQUIRE( f6.rfind("abc", 17) == 12 );
REQUIRE( f6.rfind("abc", 17) == 12 );
REQUIRE( f6.rfind("abc", 8) == 7 );
REQUIRE( f6.rfind("abc", 7) == 7 );
REQUIRE( f6.rfind("abc", 6) == 3 );
REQUIRE( f6.rfind("abc", 3) == 3 );
REQUIRE( f6.rfind("abc", 2) == string_view::npos );
REQUIRE( f6.find('a') == 3 );
REQUIRE( f6.find('a', 17) == 18 );
REQUIRE( f6.find('a', 20) == string_view::npos );
REQUIRE( f6.rfind('a') == 18 );
REQUIRE( f6.rfind('a', 17) == 12 );
REQUIRE( f6.rfind('a', 2) == string_view::npos );
string_view f7{"abc\0def", 7};
REQUIRE( f7.find("c\0d", 0, 3) == 2 );
REQUIRE( f7.find("c\0e", 0, 3) == string_view::npos );
REQUIRE( f7.rfind("c\0d", string_view::npos, 3) == 2 );
REQUIRE( f7.rfind("c\0e", 0, 3) == string_view::npos );
REQUIRE( f6.find_first_of("c789b") == 4 );
REQUIRE( f6.find_first_of("c789") == 5 );
REQUIRE( f2.find_first_of("c789b") == 4 );
REQUIRE( f6.find_first_of("c789b", 6) == 8 );
REQUIRE( f6.find_last_of("c789b") == 20 );
REQUIRE( f6.find_last_of("789b") == 19 );
REQUIRE( f2.find_last_of("c789b") == 6 );
REQUIRE( f6.find_last_of("c789b", 6) == 5 );
REQUIRE( f6.find_last_of("c789b", 5) == 5 );
REQUIRE( f6.find_last_of("c789b", 4) == 4 );
REQUIRE( f6.find_last_of("c789b", 3) == string_view::npos );
REQUIRE( f2.find_first_of(f7) == 0 );
REQUIRE( f3.find_first_of(f7) == 2 );
REQUIRE( f3.find_first_of('\0') == 2 );
REQUIRE( f3.find_first_of("jk\0", 0, 3) == 2 );
REQUIRE( f1.find_first_not_of("abc") == 3 );
REQUIRE( f1.find_first_not_of("abc ", 3) == 4 );
REQUIRE( f1.find_first_not_of(" 123", 3) == 8 );
REQUIRE( f1.find_last_not_of("abc") == 10 );
REQUIRE( f1.find_last_not_of("xyz") == 7 );
REQUIRE( f1.find_last_not_of("xyz 321") == 2 );
REQUIRE( f1.find_last_not_of("xay z1b2c3") == string_view::npos );
REQUIRE( f6.find_last_not_of("def") == 20 );
REQUIRE( f6.find_last_not_of("abcdef") == 17 );
REQUIRE( f6.find_last_not_of("abcdef ") == 1 );
REQUIRE( f6.find_first_not_of('z') == 2 );
REQUIRE( f6.find_first_not_of("z ") == 3 );
REQUIRE( f6.find_first_not_of("a ", 2) == 4 );
REQUIRE( f6.find_last_not_of("abc ", 9) == 1 );
REQUIRE( string_view{"abc"} == string_view{"abc"} );
REQUIRE_FALSE( string_view{"abc"} == string_view{"abd"} );
REQUIRE_FALSE( string_view{"abc"} == string_view{"abcd"} );
REQUIRE( string_view{"abc"} != string_view{"abd"} );
REQUIRE( string_view{"abc"} != string_view{"abcd"} );
REQUIRE( string_view{"abc"} < string_view{"abcd"} );
REQUIRE( string_view{"abc"} < string_view{"abd"} );
REQUIRE_FALSE( string_view{"abd"} < string_view{"abc"} );
REQUIRE_FALSE( string_view{"abcd"} < string_view{"abc"} );
REQUIRE_FALSE( string_view{"abc"} < string_view{"abc"} );
REQUIRE( string_view{"abd"} > string_view{"abc"} );
REQUIRE( string_view{"abcd"} > string_view{"abc"} );
REQUIRE_FALSE( string_view{"abc"} > string_view{"abd"} );
REQUIRE_FALSE( string_view{"abc"} > string_view{"abcd"} );
REQUIRE_FALSE( string_view{"abc"} > string_view{"abc"} );
REQUIRE( string_view{"abc"} <= string_view{"abcd"} );
REQUIRE( string_view{"abc"} <= string_view{"abc"} );
REQUIRE( string_view{"abc"} <= string_view{"abd"} );
REQUIRE( string_view{"abd"} >= string_view{"abc"} );
REQUIRE( string_view{"abc"} >= string_view{"abc"} );
REQUIRE( string_view{"abcd"} >= string_view{"abc"} );
}

View file

@ -1,165 +0,0 @@
#include "oxenmq/batch.h"
#include "common.h"
#include <future>
TEST_CASE("tagged thread start functions", "[tagged][start]") {
oxenmq::OxenMQ omq{get_logger(""), LogLevel::trace};
omq.set_general_threads(2);
omq.set_batch_threads(2);
auto t_abc = omq.add_tagged_thread("abc");
std::atomic<bool> start_called = false;
auto t_def = omq.add_tagged_thread("def", [&] { start_called = true; });
std::this_thread::sleep_for(20ms);
{
auto lock = catch_lock();
REQUIRE_FALSE( start_called );
}
omq.start();
wait_for([&] { return start_called.load(); });
{
auto lock = catch_lock();
REQUIRE( start_called );
}
}
TEST_CASE("tagged threads quit-before-start", "[tagged][quit]") {
auto omq = std::make_unique<oxenmq::OxenMQ>(get_logger(""), LogLevel::trace);
auto t_abc = omq->add_tagged_thread("abc");
REQUIRE_NOTHROW(omq.reset());
}
TEST_CASE("batch jobs to tagged threads", "[tagged][batch]") {
oxenmq::OxenMQ omq{get_logger(""), LogLevel::trace};
omq.set_general_threads(2);
omq.set_batch_threads(2);
std::thread::id id_abc, id_def;
auto t_abc = omq.add_tagged_thread("abc", [&] { id_abc = std::this_thread::get_id(); });
auto t_def = omq.add_tagged_thread("def", [&] { id_def = std::this_thread::get_id(); });
omq.start();
std::atomic<bool> done = false;
std::thread::id id;
omq.job([&] { id = std::this_thread::get_id(); done = true; });
wait_for([&] { return done.load(); });
{
auto lock = catch_lock();
REQUIRE( id != id_abc );
REQUIRE( id != id_def );
}
done = false;
omq.job([&] { id = std::this_thread::get_id(); done = true; }, t_abc);
wait_for([&] { return done.load(); });
{
auto lock = catch_lock();
REQUIRE( id == id_abc );
}
done = false;
omq.job([&] { id = std::this_thread::get_id(); done = true; }, t_def);
wait_for([&] { return done.load(); });
{
auto lock = catch_lock();
REQUIRE( id == id_def );
}
std::atomic<bool> sleep = true;
auto sleeper = [&] { for (int i = 0; sleep && i < 10; i++) { std::this_thread::sleep_for(25ms); } };
omq.job(sleeper);
omq.job(sleeper);
// This one should stall:
std::atomic<bool> bad = false;
omq.job([&] { bad = true; });
std::this_thread::sleep_for(50ms);
done = false;
omq.job([&] { id = std::this_thread::get_id(); done = true; }, t_abc);
wait_for([&] { return done.load(); });
{
auto lock = catch_lock();
REQUIRE( done.load() );
REQUIRE_FALSE( bad.load() );
}
done = false;
// We can queue up a bunch of jobs which should all happen in order, and all on the abc thread.
std::vector<int> v;
for (int i = 0; i < 100; i++) {
omq.job([&] { if (std::this_thread::get_id() == id_abc) v.push_back(v.size()); }, t_abc);
}
omq.job([&] { done = true; }, t_abc);
wait_for([&] { return done.load(); });
{
auto lock = catch_lock();
REQUIRE( done.load() );
REQUIRE_FALSE( bad.load() );
REQUIRE( v.size() == 100 );
for (int i = 0; i < 100; i++)
REQUIRE( v[i] == i );
}
sleep = false;
wait_for([&] { return bad.load(); });
{
auto lock = catch_lock();
REQUIRE( bad.load() );
}
}
TEST_CASE("batch job completion on tagged threads", "[tagged][batch-completion]") {
oxenmq::OxenMQ omq{get_logger(""), LogLevel::trace};
omq.set_general_threads(4);
omq.set_batch_threads(4);
std::thread::id id_abc;
auto t_abc = omq.add_tagged_thread("abc", [&] { id_abc = std::this_thread::get_id(); });
omq.start();
oxenmq::Batch<int> batch;
for (int i = 1; i < 10; i++)
batch.add_job([i, &id_abc]() { if (std::this_thread::get_id() == id_abc) return 0; return i; });
std::atomic<int> result_sum = -1;
batch.completion([&](auto result) {
int sum = 0;
for (auto& r : result)
sum += r.get();
result_sum = std::this_thread::get_id() == id_abc ? sum : -sum;
}, t_abc);
omq.batch(std::move(batch));
wait_for([&] { return result_sum.load() != -1; });
{
auto lock = catch_lock();
REQUIRE( result_sum == 45 );
}
}
TEST_CASE("timer job completion on tagged threads", "[tagged][timer]") {
oxenmq::OxenMQ omq{get_logger(""), LogLevel::trace};
omq.set_general_threads(4);
omq.set_batch_threads(4);
std::thread::id id_abc;
auto t_abc = omq.add_tagged_thread("abc", [&] { id_abc = std::this_thread::get_id(); });
omq.start();
std::atomic<int> ticks = 0;
std::atomic<int> abc_ticks = 0;
omq.add_timer([&] { ticks++; }, 10ms);
omq.add_timer([&] { if (std::this_thread::get_id() == id_abc) abc_ticks++; }, 10ms, true, t_abc);
wait_for([&] { return ticks.load() > 2 && abc_ticks > 2; });
{
auto lock = catch_lock();
REQUIRE( ticks.load() > 2 );
REQUIRE( abc_ticks.load() > 2 );
}
}

View file

@ -1,132 +0,0 @@
#include "oxenmq/oxenmq.h"
#include "common.h"
#include <chrono>
#include <future>
TEST_CASE("timer test", "[timer][basic]") {
oxenmq::OxenMQ omq{get_logger(""), LogLevel::trace};
omq.set_general_threads(1);
omq.set_batch_threads(1);
std::atomic<int> ticks = 0;
auto timer = omq.add_timer([&] { ticks++; }, 5ms);
omq.start();
auto start = std::chrono::steady_clock::now();
wait_for([&] { return ticks.load() > 3; });
{
auto elapsed_ms = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - start).count();
auto lock = catch_lock();
REQUIRE( ticks.load() > 3 );
REQUIRE( elapsed_ms < 50 * TIME_DILATION );
}
}
TEST_CASE("timer squelch", "[timer][squelch]") {
oxenmq::OxenMQ omq{get_logger(""), LogLevel::trace};
omq.set_general_threads(3);
omq.set_batch_threads(3);
std::atomic<bool> first = true;
std::atomic<bool> done = false;
std::atomic<int> ticks = 0;
// Set up a timer with squelch on; the job shouldn't get rescheduled until the first call
// finishes, by which point we set `done` and so should get exactly 1 tick.
auto timer = omq.add_timer([&] {
if (first.exchange(false)) {
std::this_thread::sleep_for(30ms * TIME_DILATION);
ticks++;
done = true;
} else if (!done) {
ticks++;
}
}, 5ms * TIME_DILATION, true /* squelch */);
omq.start();
wait_for([&] { return done.load(); });
{
auto lock = catch_lock();
REQUIRE( done.load() );
REQUIRE( ticks.load() == 1 );
}
// Start another timer with squelch *off*; the subsequent jobs should get scheduled even while
// the first one blocks
std::atomic<bool> first2 = true;
std::atomic<bool> done2 = false;
std::atomic<int> ticks2 = 0;
auto timer2 = omq.add_timer([&] {
if (first2.exchange(false)) {
std::this_thread::sleep_for(40ms * TIME_DILATION);
done2 = true;
} else if (!done2) {
ticks2++;
}
}, 5ms, false /* squelch */);
wait_for([&] { return done2.load(); });
{
auto lock = catch_lock();
REQUIRE( ticks2.load() > 2 );
REQUIRE( done2.load() );
}
}
TEST_CASE("timer cancel", "[timer][cancel]") {
oxenmq::OxenMQ omq{get_logger(""), LogLevel::trace};
omq.set_general_threads(1);
omq.set_batch_threads(1);
std::atomic<int> ticks = 0;
// We set up *and cancel* this timer before omq starts, so it should never fire
auto notimer = omq.add_timer([&] { ticks += 1000; }, 5ms * TIME_DILATION);
omq.cancel_timer(notimer);
TimerID timer = omq.add_timer([&] {
if (++ticks == 3)
omq.cancel_timer(timer);
}, 5ms * TIME_DILATION);
omq.start();
wait_for([&] { return ticks.load() >= 3; });
{
auto lock = catch_lock();
REQUIRE( ticks.load() == 3 );
}
// Test the alternative taking an lvalue reference instead of returning by value (see oxenmq.h
// for why this is sometimes needed).
std::atomic<int> ticks3 = 0;
std::weak_ptr<TimerID> w_timer3;
{
auto timer3 = std::make_shared<TimerID>();
auto& t3ref = *timer3; // Get this reference *before* we move the shared pointer into the lambda
omq.add_timer(t3ref, [&ticks3, &omq, timer3=std::move(timer3)] {
if (ticks3 == 0)
ticks3++;
else if (ticks3 > 1) {
omq.cancel_timer(*timer3);
ticks3++;
}
}, 1ms);
}
wait_for([&] { return ticks3.load() >= 1; });
{
auto lock = catch_lock();
REQUIRE( ticks3.load() == 1 );
}
ticks3++;
wait_for([&] { return ticks3.load() >= 3 && w_timer3.expired(); });
{
auto lock = catch_lock();
REQUIRE( ticks3.load() == 3 );
REQUIRE( w_timer3.expired() );
}
}