mirror of https://github.com/oxen-io/oxen-mq.git
Compare commits
214 Commits
Author | SHA1 | Date |
---|---|---|
Jason Rhinelander | a27961d787 | |
Jason Rhinelander | 5878473f67 | |
Jason Rhinelander | 68b3420bad | |
Jason Rhinelander | dc7fb35493 | |
Jason Rhinelander | caadd35052 | |
Jason Rhinelander | fd58ab9cac | |
Jason Rhinelander | 8f97add30f | |
Jason Rhinelander | e1b66ced48 | |
Jason Rhinelander | 4f3ee28784 | |
Jason Rhinelander | bd3e2cdfb0 | |
Jason Rhinelander | b8bb10eac5 | |
Jason Rhinelander | ff0e515c51 | |
Jason Rhinelander | 2e308d4f43 | |
Jason Rhinelander | 445f214840 | |
Jason Rhinelander | 358005df06 | |
Thomas Winget | 85437d167b | |
Jason Rhinelander | b26fe8cb04 | |
Jason Rhinelander | df19d1dd94 | |
Jason Rhinelander | 25f714371b | |
Jason Rhinelander | 0858dd278b | |
Jason Rhinelander | 057685b7c0 | |
Jason Rhinelander | 3a3ffa7d23 | |
Jason Rhinelander | edcde9246a | |
Sean | c854046684 | |
Sean Darcy | c91e56cf2d | |
Jason Rhinelander | 61b7505304 | |
Jason Rhinelander | b0c3bd4ee9 | |
Jason Rhinelander | fd95919704 | |
Jason Rhinelander | 4671af3ca0 | |
Jason Rhinelander | c4b7aa9b23 | |
Jason Rhinelander | 115c5550ca | |
Jason Rhinelander | ace6ea9d8e | |
Jason Rhinelander | 62a803f371 | |
Jason Rhinelander | d86ecb3a70 | |
Jason Rhinelander | 45791d3a19 | |
Jason Rhinelander | b8e4eb148f | |
Jason Rhinelander | fa6de369b2 | |
Jason Rhinelander | 371606cde0 | |
Jason Rhinelander | 3a51713396 | |
Jason Rhinelander | 5c7f6504d2 | |
Jason Rhinelander | 5a3c12e721 | |
xutaxkamay | f0c2222d6e | |
Jason Rhinelander | 320a85ac0c | |
Jason Rhinelander | 7fca36b3a9 | |
Jason Rhinelander | bbdf4af98f | |
Jason Rhinelander | 77c4840273 | |
Jason Rhinelander | d7f5efebc1 | |
Jason Rhinelander | a0a54ed461 | |
Jason Rhinelander | 045df9cb9b | |
Jason Rhinelander | 3d178ce3ea | |
Jason Rhinelander | fe8a1f4306 | |
Jason Rhinelander | 3b634329ac | |
Jason Rhinelander | f88691b7e9 | |
Jason Rhinelander | 9c022b29de | |
Jason Rhinelander | 4d68868482 | |
Jason Rhinelander | 430951bf3c | |
Jason Rhinelander | 03749c87f0 | |
Jason Rhinelander | 85d35fa505 | |
Jason Rhinelander | e180187746 | |
Jason Rhinelander | e382373f2e | |
Jason Rhinelander | 375cfab4ce | |
Jason Rhinelander | f04bd72a4c | |
Jason Rhinelander | 31f64821f8 | |
Jason Rhinelander | a53e1f1786 | |
Jason Rhinelander | 39b6d89037 | |
Jason Rhinelander | f0bb2c3d3f | |
Jason Rhinelander | 09f3de2232 | |
Jason Rhinelander | 519a107542 | |
Jason Rhinelander | 23c2d537a3 | |
Jason Rhinelander | 6a386b7d4a | |
Jason Rhinelander | 5e9b8c0948 | |
Jason Rhinelander | 560d38d069 | |
Jason Rhinelander | 504d0d10ea | |
Jason Rhinelander | 7695e770a7 | |
Jason Rhinelander | 0d0ed8efa9 | |
Jason Rhinelander | 02a542b9c6 | |
Jason Rhinelander | 9a8adb5bfd | |
Jason Rhinelander | ee1d69f333 | |
Jason Rhinelander | 24dd7a3854 | |
Jason Rhinelander | cd56ad8e08 | |
Jason Rhinelander | 6100802f82 | |
Jason Rhinelander | 7cb7c2fd6d | |
Jeff Becker | 5a41e84378 | |
Jason Rhinelander | 377932607c | |
Jason Rhinelander | cdd21a9e81 | |
Jason Rhinelander | 977bced84e | |
Jason Rhinelander | 9e3469d968 | |
Jason Rhinelander | f12a48a195 | |
Jason Rhinelander | e1b1a84c4b | |
Jason Rhinelander | 2ac4379fa6 | |
Jason Rhinelander | ae884d2f13 | |
Jason Rhinelander | 45f358ab5f | |
Jason Rhinelander | c6ae1faefa | |
Jason Rhinelander | 719d33f1cc | |
Jason Rhinelander | f553085558 | |
Jason Rhinelander | bae71ec6a8 | |
Jason Rhinelander | 29cd543af9 | |
Jason Rhinelander | 917c7d64c5 | |
Jason Rhinelander | 4a24ac9baa | |
Jason Rhinelander | e1d21d3faf | |
Jason Rhinelander | 1d2246cda8 | |
Jason Rhinelander | 3bb32a81ff | |
Jason Rhinelander | 9e0d2e24f6 | |
Jason Rhinelander | 4a6bb3f702 | |
Jason Rhinelander | ad04c53c0e | |
Jason Rhinelander | 7ba81a7d50 | |
Jason Rhinelander | 45db87f712 | |
Jason Rhinelander | a0642a894e | |
Jason Rhinelander | 5dd7c12219 | |
Jason Rhinelander | dccbd1e8cd | |
Jason Rhinelander | 780246858f | |
Jason Rhinelander | 0287f7834e | |
Jason Rhinelander | cdc6a9709c | |
Jason Rhinelander | 3991f50547 | |
Jason Rhinelander | 26745299ed | |
Jason Rhinelander | 4ef1060e3f | |
Jason Rhinelander | 5ccacafdb1 | |
Jason Rhinelander | 6d20a3614a | |
Jason Rhinelander | 39dce56e14 | |
Jason Rhinelander | ac58e5b574 | |
Jason Rhinelander | 99a3f1d840 | |
Jason Rhinelander | dc40ebd428 | |
Jason Rhinelander | e3e79e1fb7 | |
Jason Rhinelander | f9ef827075 | |
Jason Rhinelander | 506bd65b05 | |
Jason Rhinelander | 86247bc5c7 | |
Jason Rhinelander | 396f591fae | |
Jason Rhinelander | b49a94fb83 | |
Jason Rhinelander | 0738695eb9 | |
Jason Rhinelander | 2ae6b96016 | |
Jason Rhinelander | bd9313bf19 | |
Jason Rhinelander | 1959f8747d | |
Jason Rhinelander | 90701e5d62 | |
Jason Rhinelander | 178bd4f674 | |
Jason Rhinelander | b1543513bb | |
Jason Rhinelander | 253f1ee66e | |
Jason Rhinelander | d889f308ae | |
Jason Rhinelander | 768a639dea | |
Jason Rhinelander | ec0d44e143 | |
Jason Rhinelander | ea484729c7 | |
Jason Rhinelander | 7049d3cb5a | |
Jason Rhinelander | 8ed529200b | |
Jason Rhinelander | 318781a6d4 | |
Thomas Winget | f37e619d7b | |
Jason Rhinelander | 0ac1d48bc8 | |
Jeff Becker | 0938e1fc53 | |
Jeff Becker | 0c9eeeea43 | |
Jason Rhinelander | 9467c4682c | |
Jason Rhinelander | 8c28c52d41 | |
Jason Rhinelander | faeeaa86d4 | |
Jason Rhinelander | 8d3ed4606f | |
Jason Rhinelander | 30faadf01a | |
Jason Rhinelander | d8d1d8677c | |
Jason Rhinelander | e5cf174b83 | |
Jason Rhinelander | af189a8d72 | |
Jason Rhinelander | d2f852c217 | |
Jason Rhinelander | ee080e0550 | |
Jason Rhinelander | 7cd58e4677 | |
Jason Rhinelander | 9c54264321 | |
Jason Rhinelander | 932bbb33d7 | |
Jason Rhinelander | 07b31bd8a1 | |
Jason Rhinelander | 8a56b18cc6 | |
Jason Rhinelander | 1d56c3d44c | |
Jason Rhinelander | 66176d44d7 | |
Jason Rhinelander | 4e89dce5b6 | |
Jason Rhinelander | 0493f615b9 | |
Jason Rhinelander | d0a73e5e68 | |
Jason Rhinelander | 278909db77 | |
Jason Rhinelander | 3edcab9344 | |
Jason Rhinelander | ae8dd27cdd | |
Jason Rhinelander | 8caab97355 | |
Jason Rhinelander | 44b91534c2 | |
Jason Rhinelander | 29380922bf | |
Jason Rhinelander | 6356421488 | |
Jason Rhinelander | d28e39ffeb | |
Jason Rhinelander | 9a103f1bf6 | |
Jason Rhinelander | 211d5211b0 | |
Jason Rhinelander | 9a283a148c | |
Jason Rhinelander | 65aa5940be | |
Jason Rhinelander | ec9c58ea34 | |
Jason Rhinelander | af59d58797 | |
Jason Rhinelander | e072e68d84 | |
Jason Rhinelander | e5a8d09127 | |
Jason Rhinelander | a24e87d4d0 | |
Jason Rhinelander | 9ac47ec419 | |
Jason Rhinelander | d0a07f7c08 | |
Jason Rhinelander | 86f5b463e9 | |
Jason Rhinelander | 68c1899cda | |
Jason Rhinelander | 1479a030d7 | |
Stephen Shelton | f296b82ba5 | |
Jason Rhinelander | 1f60abf50e | |
Jason Rhinelander | de395af872 | |
Jason Rhinelander | e970f14e55 | |
Jason Rhinelander | 1e38f3b1d1 | |
Jason Rhinelander | c9cf833861 | |
Jason Rhinelander | 7b42537801 | |
Jason Rhinelander | 8984dfc4ea | |
Jason Rhinelander | be4cbc6641 | |
Jason Rhinelander | 46d007e1ac | |
Jason Rhinelander | 59a41943d4 | |
Jason Rhinelander | 719a9b0b58 | |
Jason Rhinelander | 22559548fc | |
Jason Rhinelander | 7b552007df | |
Jason Rhinelander | b905a8a4ff | |
Jason Rhinelander | 08a11bb9ba | |
Jason Rhinelander | 3a0508fdce | |
Jason Rhinelander | f4f1506df0 | |
Jason Rhinelander | a812abd422 | |
Jason Rhinelander | 730633bbae | |
Jason Rhinelander | 99bbf8dea9 | |
Jason Rhinelander | 1a65d7f5e5 | |
Jason Rhinelander | e7cd2dedc2 | |
Jason Rhinelander | 6ddf033674 | |
Jason Rhinelander | 0ebfef2164 |
|
@ -0,0 +1,117 @@
|
|||
local docker_base = 'registry.oxen.rocks/lokinet-ci-';
|
||||
|
||||
local default_deps_nocxx = ['libsodium-dev', 'libzmq3-dev', 'liboxenc-dev'];
|
||||
|
||||
local submodule_commands = ['git fetch --tags', 'git submodule update --init --recursive --depth=1'];
|
||||
|
||||
local submodules = {
|
||||
name: 'submodules',
|
||||
image: 'drone/git',
|
||||
commands: submodule_commands,
|
||||
};
|
||||
|
||||
local apt_get_quiet = 'apt-get -o=Dpkg::Use-Pty=0 -q ';
|
||||
|
||||
local debian_pipeline(name,
|
||||
image,
|
||||
arch='amd64',
|
||||
deps=['g++'] + default_deps_nocxx,
|
||||
cmake_extra='',
|
||||
build_type='Release',
|
||||
extra_cmds=[],
|
||||
distro='$$(lsb_release -sc)',
|
||||
allow_fail=false) = {
|
||||
kind: 'pipeline',
|
||||
type: 'docker',
|
||||
name: name,
|
||||
platform: { arch: arch },
|
||||
environment: { CLICOLOR_FORCE: '1' }, // Lets color through ninja (1.9+)
|
||||
steps: [
|
||||
submodules,
|
||||
{
|
||||
name: 'build',
|
||||
image: image,
|
||||
pull: 'always',
|
||||
[if allow_fail then 'failure']: 'ignore',
|
||||
commands: [
|
||||
'echo "Building on ${DRONE_STAGE_MACHINE}"',
|
||||
'echo "man-db man-db/auto-update boolean false" | debconf-set-selections',
|
||||
apt_get_quiet + 'update',
|
||||
apt_get_quiet + 'install -y eatmydata',
|
||||
'eatmydata ' + apt_get_quiet + ' install --no-install-recommends -y lsb-release',
|
||||
'cp contrib/deb.oxen.io.gpg /etc/apt/trusted.gpg.d',
|
||||
'echo deb http://deb.oxen.io ' + distro + ' main >/etc/apt/sources.list.d/oxen.list',
|
||||
'eatmydata ' + apt_get_quiet + ' update',
|
||||
'eatmydata ' + apt_get_quiet + 'dist-upgrade -y',
|
||||
'eatmydata ' + apt_get_quiet + 'install -y cmake git ninja-build pkg-config ccache ' + std.join(' ', deps),
|
||||
'mkdir build',
|
||||
'cd build',
|
||||
'cmake .. -G Ninja -DCMAKE_CXX_FLAGS=-fdiagnostics-color=always -DCMAKE_BUILD_TYPE=' + build_type + ' -DCMAKE_CXX_COMPILER_LAUNCHER=ccache ' + cmake_extra,
|
||||
'ninja -v',
|
||||
'./tests/tests --use-colour yes',
|
||||
] + extra_cmds,
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
local clang(version) = debian_pipeline(
|
||||
'Debian sid/clang-' + version + ' (amd64)',
|
||||
docker_base + 'debian-sid-clang',
|
||||
distro='sid',
|
||||
deps=['clang-' + version] + default_deps_nocxx,
|
||||
cmake_extra='-DCMAKE_C_COMPILER=clang-' + version + ' -DCMAKE_CXX_COMPILER=clang++-' + version + ' '
|
||||
);
|
||||
|
||||
local full_llvm(version) = debian_pipeline(
|
||||
'Debian sid/llvm-' + version + ' (amd64)',
|
||||
docker_base + 'debian-sid-clang',
|
||||
distro='sid',
|
||||
deps=['clang-' + version, 'lld-' + version, 'libc++-' + version + '-dev', 'libc++abi-' + version + '-dev']
|
||||
+ default_deps_nocxx,
|
||||
cmake_extra='-DCMAKE_C_COMPILER=clang-' + version +
|
||||
' -DCMAKE_CXX_COMPILER=clang++-' + version +
|
||||
' -DCMAKE_CXX_FLAGS=-stdlib=libc++ ' +
|
||||
std.join(' ', [
|
||||
'-DCMAKE_' + type + '_LINKER_FLAGS=-fuse-ld=lld-' + version
|
||||
for type in ['EXE', 'MODULE', 'SHARED', 'STATIC']
|
||||
])
|
||||
);
|
||||
|
||||
|
||||
[
|
||||
debian_pipeline('Debian sid (amd64)', docker_base + 'debian-sid', distro='sid'),
|
||||
debian_pipeline('Debian sid/Debug (amd64)', docker_base + 'debian-sid', build_type='Debug', distro='sid'),
|
||||
clang(16),
|
||||
full_llvm(16),
|
||||
debian_pipeline('Debian buster (amd64)', docker_base + 'debian-buster'),
|
||||
debian_pipeline('Debian stable (i386)', docker_base + 'debian-stable/i386'),
|
||||
debian_pipeline('Debian sid (ARM64)', docker_base + 'debian-sid', arch='arm64', distro='sid'),
|
||||
debian_pipeline('Debian stable (armhf)', docker_base + 'debian-stable/arm32v7', arch='arm64'),
|
||||
debian_pipeline('Debian buster (armhf)', docker_base + 'debian-buster/arm32v7', arch='arm64'),
|
||||
debian_pipeline('Ubuntu focal (amd64)', docker_base + 'ubuntu-focal'),
|
||||
debian_pipeline('Ubuntu bionic (amd64)',
|
||||
docker_base + 'ubuntu-bionic',
|
||||
deps=default_deps_nocxx,
|
||||
cmake_extra='-DCMAKE_C_COMPILER=gcc-8 -DCMAKE_CXX_COMPILER=g++-8'),
|
||||
{
|
||||
kind: 'pipeline',
|
||||
type: 'exec',
|
||||
name: 'macOS (w/macports)',
|
||||
platform: { os: 'darwin', arch: 'amd64' },
|
||||
environment: { CLICOLOR_FORCE: '1' }, // Lets color through ninja (1.9+)
|
||||
steps: [
|
||||
{ name: 'submodules', commands: submodule_commands },
|
||||
{
|
||||
name: 'build',
|
||||
commands: [
|
||||
'mkdir build',
|
||||
'cd build',
|
||||
'ulimit -n 1024', // Because macOS has a stupid tiny default ulimit
|
||||
'cmake .. -G Ninja -DCMAKE_CXX_FLAGS=-fcolor-diagnostics -DCMAKE_BUILD_TYPE=Release -DCMAKE_CXX_COMPILER_LAUNCHER=ccache',
|
||||
'ninja -v',
|
||||
'./tests/tests --use-colour yes',
|
||||
],
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
|
@ -1,9 +1,9 @@
|
|||
[submodule "mapbox-variant"]
|
||||
path = mapbox-variant
|
||||
url = https://github.com/mapbox/variant.git
|
||||
[submodule "cppzmq"]
|
||||
path = cppzmq
|
||||
url = https://github.com/zeromq/cppzmq.git
|
||||
[submodule "Catch2"]
|
||||
path = tests/Catch2
|
||||
url = https://github.com/catchorg/Catch2.git
|
||||
[submodule "oxen-encoding"]
|
||||
path = oxen-encoding
|
||||
url = https://github.com/oxen-io/oxen-encoding.git
|
||||
|
|
251
CMakeLists.txt
251
CMakeLists.txt
|
@ -1,87 +1,163 @@
|
|||
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||
|
||||
find_program(CCACHE_PROGRAM ccache)
|
||||
if(CCACHE_PROGRAM)
|
||||
foreach(lang C CXX)
|
||||
if(NOT DEFINED CMAKE_${lang}_COMPILER_LAUNCHER AND NOT CMAKE_${lang}_COMPILER MATCHES ".*/ccache")
|
||||
message(STATUS "Enabling ccache for ${lang}")
|
||||
set(CMAKE_${lang}_COMPILER_LAUNCHER ${CCACHE_PROGRAM} CACHE STRING "")
|
||||
endif()
|
||||
endforeach()
|
||||
endif()
|
||||
|
||||
cmake_minimum_required(VERSION 3.7)
|
||||
|
||||
project(liblokimq CXX)
|
||||
# Has to be set before `project()`, and ignored on non-macos:
|
||||
set(CMAKE_OSX_DEPLOYMENT_TARGET 10.12 CACHE STRING "macOS deployment target (Apple clang only)")
|
||||
|
||||
project(liboxenmq
|
||||
VERSION 1.2.16
|
||||
LANGUAGES CXX C)
|
||||
|
||||
include(GNUInstallDirs)
|
||||
|
||||
set(LOKIMQ_VERSION_MAJOR 1)
|
||||
set(LOKIMQ_VERSION_MINOR 1)
|
||||
set(LOKIMQ_VERSION_PATCH 1)
|
||||
set(LOKIMQ_VERSION "${LOKIMQ_VERSION_MAJOR}.${LOKIMQ_VERSION_MINOR}.${LOKIMQ_VERSION_PATCH}")
|
||||
message(STATUS "lokimq v${LOKIMQ_VERSION}")
|
||||
message(STATUS "oxenmq v${PROJECT_VERSION}")
|
||||
|
||||
set(LOKIMQ_LIBVERSION 0)
|
||||
set(OXENMQ_LIBVERSION 0)
|
||||
|
||||
|
||||
if(CMAKE_PROJECT_NAME STREQUAL PROJECT_NAME)
|
||||
set(oxenmq_IS_TOPLEVEL_PROJECT TRUE)
|
||||
else()
|
||||
set(oxenmq_IS_TOPLEVEL_PROJECT FALSE)
|
||||
endif()
|
||||
|
||||
|
||||
option(BUILD_SHARED_LIBS "Build shared libraries instead of static ones" ON)
|
||||
set(oxenmq_INSTALL_DEFAULT OFF)
|
||||
if(BUILD_SHARED_LIBS OR oxenmq_IS_TOPLEVEL_PROJECT)
|
||||
set(oxenmq_INSTALL_DEFAULT ON)
|
||||
endif()
|
||||
set(oxenmq_EPOLL_DEFAULT OFF)
|
||||
if(CMAKE_SYSTEM_NAME STREQUAL "Linux" AND NOT CMAKE_CROSSCOMPILING)
|
||||
set(oxenmq_EPOLL_DEFAULT ON)
|
||||
endif()
|
||||
|
||||
option(OXENMQ_BUILD_TESTS "Building and perform oxenmq tests" ${oxenmq_IS_TOPLEVEL_PROJECT})
|
||||
option(OXENMQ_INSTALL "Add oxenmq libraries and headers to cmake install target; defaults to ON if BUILD_SHARED_LIBS is enabled or we are the top-level project; OFF for a static subdirectory build" ${oxenmq_INSTALL_DEFAULT})
|
||||
option(OXENMQ_INSTALL_CPPZMQ "Install cppzmq header with oxenmq/ headers (requires OXENMQ_INSTALL)" ON)
|
||||
option(OXENMQ_USE_EPOLL "Use epoll for socket polling (requires Linux)" ${oxenmq_EPOLL_DEFAULT})
|
||||
|
||||
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
|
||||
|
||||
configure_file(lokimq/version.h.in lokimq/version.h @ONLY)
|
||||
configure_file(liblokimq.pc.in liblokimq.pc @ONLY)
|
||||
configure_file(oxenmq/version.h.in oxenmq/version.h @ONLY)
|
||||
configure_file(liboxenmq.pc.in liboxenmq.pc @ONLY)
|
||||
|
||||
add_library(lokimq
|
||||
lokimq/auth.cpp
|
||||
lokimq/bt_serialize.cpp
|
||||
lokimq/connections.cpp
|
||||
lokimq/jobs.cpp
|
||||
lokimq/lokimq.cpp
|
||||
lokimq/proxy.cpp
|
||||
lokimq/worker.cpp
|
||||
|
||||
add_library(oxenmq
|
||||
oxenmq/address.cpp
|
||||
oxenmq/auth.cpp
|
||||
oxenmq/connections.cpp
|
||||
oxenmq/jobs.cpp
|
||||
oxenmq/oxenmq.cpp
|
||||
oxenmq/proxy.cpp
|
||||
oxenmq/worker.cpp
|
||||
)
|
||||
set_target_properties(lokimq PROPERTIES SOVERSION ${LOKIMQ_LIBVERSION})
|
||||
set_target_properties(oxenmq PROPERTIES SOVERSION ${OXENMQ_LIBVERSION})
|
||||
if(OXENMQ_USE_EPOLL)
|
||||
target_compile_definitions(oxenmq PRIVATE OXENMQ_USE_EPOLL)
|
||||
endif()
|
||||
|
||||
set(THREADS_PREFER_PTHREAD_FLAG ON)
|
||||
find_package(Threads REQUIRED)
|
||||
target_link_libraries(lokimq PRIVATE Threads::Threads)
|
||||
target_link_libraries(oxenmq PRIVATE Threads::Threads)
|
||||
|
||||
|
||||
if(TARGET oxenc::oxenc)
|
||||
add_library(_oxenmq_external_oxenc INTERFACE IMPORTED)
|
||||
target_link_libraries(_oxenmq_external_oxenc INTERFACE oxenc::oxenc)
|
||||
target_link_libraries(oxenmq PUBLIC _oxenmq_external_oxenc)
|
||||
message(STATUS "using pre-existing oxenc::oxenc target")
|
||||
elseif(BUILD_SHARED_LIBS)
|
||||
include(FindPkgConfig)
|
||||
pkg_check_modules(oxenc liboxenc IMPORTED_TARGET)
|
||||
|
||||
if(oxenc_FOUND)
|
||||
# Work around cmake bug 22180 (PkgConfig::tgt not set if no flags needed)
|
||||
if(TARGET PkgConfig::oxenc OR CMAKE_VERSION VERSION_GREATER_EQUAL "3.21")
|
||||
target_link_libraries(oxenmq PUBLIC PkgConfig::oxenc)
|
||||
endif()
|
||||
else()
|
||||
add_subdirectory(oxen-encoding)
|
||||
target_link_libraries(oxenmq PUBLIC oxenc::oxenc)
|
||||
endif()
|
||||
else()
|
||||
add_subdirectory(oxen-encoding)
|
||||
target_link_libraries(oxenmq PUBLIC oxenc::oxenc)
|
||||
endif()
|
||||
|
||||
# libzmq is nearly impossible to link statically from a system-installed static library: it depends
|
||||
# on a ton of other libraries, some of which are not all statically available. If the caller wants
|
||||
# to mess with this, so be it: they can set up a libzmq target and we'll use it. Otherwise if they
|
||||
# asked us to do things statically, don't even try to find a system lib and just build it.
|
||||
set(lokimq_build_static_libzmq OFF)
|
||||
set(oxenmq_build_static_libzmq OFF)
|
||||
if(TARGET libzmq)
|
||||
target_link_libraries(lokimq PUBLIC libzmq)
|
||||
target_link_libraries(oxenmq PUBLIC libzmq)
|
||||
elseif(BUILD_SHARED_LIBS)
|
||||
include(FindPkgConfig)
|
||||
pkg_check_modules(libzmq libzmq>=4.3 IMPORTED_TARGET)
|
||||
|
||||
if(libzmq_FOUND)
|
||||
target_link_libraries(lokimq PUBLIC PkgConfig::libzmq)
|
||||
# Debian sid includes a -isystem in the mit-krb package that, starting with pkg-config 0.29.2,
|
||||
# breaks cmake's pkgconfig module because it stupidly thinks "-isystem" is a path, so if we find
|
||||
# -isystem in the include dirs then hack it out.
|
||||
get_property(zmq_inc TARGET PkgConfig::libzmq PROPERTY INTERFACE_INCLUDE_DIRECTORIES)
|
||||
list(FIND zmq_inc "-isystem" broken_isystem)
|
||||
if(NOT broken_isystem EQUAL -1)
|
||||
list(REMOVE_AT zmq_inc ${broken_isystem})
|
||||
set_property(TARGET PkgConfig::libzmq PROPERTY INTERFACE_INCLUDE_DIRECTORIES ${zmq_inc})
|
||||
endif()
|
||||
|
||||
target_link_libraries(oxenmq PUBLIC PkgConfig::libzmq)
|
||||
else()
|
||||
set(lokimq_build_static_libzmq ON)
|
||||
set(oxenmq_build_static_libzmq ON)
|
||||
endif()
|
||||
else()
|
||||
set(lokimq_build_static_libzmq ON)
|
||||
set(oxenmq_build_static_libzmq ON)
|
||||
endif()
|
||||
|
||||
if(lokimq_build_static_libzmq)
|
||||
message(STATUS "libzmq >= 4.3 not found or static build requested, building bundled 4.3.2")
|
||||
if(oxenmq_build_static_libzmq)
|
||||
message(STATUS "libzmq >= 4.3 not found or static build requested, building bundled version")
|
||||
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/local-libzmq")
|
||||
include(LocalLibzmq)
|
||||
target_link_libraries(lokimq PUBLIC libzmq_vendor)
|
||||
target_link_libraries(oxenmq PUBLIC libzmq_vendor)
|
||||
endif()
|
||||
|
||||
target_include_directories(lokimq
|
||||
target_include_directories(oxenmq
|
||||
PUBLIC
|
||||
$<INSTALL_INTERFACE:>
|
||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}>
|
||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/cppzmq>
|
||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/mapbox-variant/include>
|
||||
)
|
||||
|
||||
target_compile_options(lokimq PRIVATE -Wall -Wextra -Werror)
|
||||
set_target_properties(lokimq PROPERTIES
|
||||
CXX_STANDARD 14
|
||||
target_compile_options(oxenmq PRIVATE -Wall -Wextra)
|
||||
|
||||
option(WARNINGS_AS_ERRORS "treat all warnings as errors" ON)
|
||||
if(WARNINGS_AS_ERRORS)
|
||||
target_compile_options(oxenmq PRIVATE -Werror)
|
||||
endif()
|
||||
|
||||
set_target_properties(oxenmq PROPERTIES
|
||||
POSITION_INDEPENDENT_CODE ON
|
||||
CXX_STANDARD 17
|
||||
CXX_STANDARD_REQUIRED ON
|
||||
CXX_EXTENSIONS OFF
|
||||
POSITION_INDEPENDENT_CODE ON
|
||||
)
|
||||
|
||||
function(link_dep_libs target linktype libdirs)
|
||||
foreach(lib ${ARGN})
|
||||
find_library(link_lib-${lib} NAMES ${lib} PATHS ${libdirs})
|
||||
message(STATUS "FIND ${lib} FOUND ${link_lib-${lib}}")
|
||||
if(link_lib-${lib})
|
||||
target_link_libraries(${target} ${linktype} ${link_lib-${lib}})
|
||||
endif()
|
||||
|
@ -91,87 +167,72 @@ endfunction()
|
|||
# If the caller has already set up a sodium target then we will just link to it, otherwise we go
|
||||
# looking for it.
|
||||
if(TARGET sodium)
|
||||
target_link_libraries(lokimq PRIVATE sodium)
|
||||
if(lokimq_build_static_libzmq)
|
||||
target_link_libraries(oxenmq PUBLIC sodium)
|
||||
if(oxenmq_build_static_libzmq)
|
||||
target_link_libraries(libzmq_vendor INTERFACE sodium)
|
||||
endif()
|
||||
else()
|
||||
include(FindPkgConfig)
|
||||
pkg_check_modules(sodium REQUIRED libsodium IMPORTED_TARGET)
|
||||
|
||||
if(BUILD_SHARED_LIBS)
|
||||
target_link_libraries(lokimq PRIVATE PkgConfig::sodium)
|
||||
if(lokimq_build_static_libzmq)
|
||||
target_link_libraries(oxenmq PUBLIC PkgConfig::sodium)
|
||||
if(oxenmq_build_static_libzmq)
|
||||
target_link_libraries(libzmq_vendor INTERFACE PkgConfig::sodium)
|
||||
endif()
|
||||
else()
|
||||
link_dep_libs(lokimq PRIVATE "${sodium_STATIC_LIBRARY_DIRS}" ${sodium_STATIC_LIBRARIES})
|
||||
target_include_directories(lokimq PRIVATE ${sodium_STATIC_INCLUDE_DIRS})
|
||||
if(lokimq_build_static_libzmq)
|
||||
link_dep_libs(oxenmq PUBLIC "${sodium_STATIC_LIBRARY_DIRS}" ${sodium_STATIC_LIBRARIES})
|
||||
target_include_directories(oxenmq PUBLIC ${sodium_STATIC_INCLUDE_DIRS})
|
||||
if(oxenmq_build_static_libzmq)
|
||||
link_dep_libs(libzmq_vendor INTERFACE "${sodium_STATIC_LIBRARY_DIRS}" ${sodium_STATIC_LIBRARIES})
|
||||
target_link_libraries(libzmq_vendor INTERFACE ${sodium_STATIC_INCLUDE_DIRS})
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
|
||||
add_library(lokimq::lokimq ALIAS lokimq)
|
||||
add_library(oxenmq::oxenmq ALIAS oxenmq)
|
||||
|
||||
export(
|
||||
TARGETS lokimq
|
||||
NAMESPACE lokimq::
|
||||
FILE lokimqTargets.cmake
|
||||
)
|
||||
install(
|
||||
TARGETS lokimq
|
||||
EXPORT lokimqConfig
|
||||
DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
||||
TARGETS oxenmq
|
||||
NAMESPACE oxenmq::
|
||||
FILE oxenmqTargets.cmake
|
||||
)
|
||||
|
||||
install(
|
||||
FILES lokimq/auth.h
|
||||
lokimq/batch.h
|
||||
lokimq/bt_serialize.h
|
||||
lokimq/connections.h
|
||||
lokimq/hex.h
|
||||
lokimq/lokimq.h
|
||||
lokimq/message.h
|
||||
lokimq/string_view.h
|
||||
${CMAKE_CURRENT_BINARY_DIR}/lokimq/version.h
|
||||
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/lokimq
|
||||
)
|
||||
option(LOKIMQ_INSTALL_MAPBOX_VARIANT "Install mapbox-variant headers with lokimq/ headers" ON)
|
||||
if(LOKIMQ_INSTALL_MAPBOX_VARIANT)
|
||||
install(
|
||||
FILES mapbox-variant/include/mapbox/variant.hpp
|
||||
mapbox-variant/include/mapbox/variant_cast.hpp
|
||||
mapbox-variant/include/mapbox/variant_io.hpp
|
||||
mapbox-variant/include/mapbox/variant_visitor.hpp
|
||||
mapbox-variant/include/mapbox/recursive_wrapper.hpp
|
||||
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/lokimq/mapbox
|
||||
)
|
||||
if(OXENMQ_INSTALL)
|
||||
install(
|
||||
TARGETS oxenmq
|
||||
EXPORT oxenmqConfig
|
||||
DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
||||
)
|
||||
|
||||
install(
|
||||
FILES oxenmq/address.h
|
||||
oxenmq/auth.h
|
||||
oxenmq/batch.h
|
||||
oxenmq/connections.h
|
||||
oxenmq/fmt.h
|
||||
oxenmq/message.h
|
||||
oxenmq/oxenmq.h
|
||||
oxenmq/pubsub.h
|
||||
${CMAKE_CURRENT_BINARY_DIR}/oxenmq/version.h
|
||||
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/oxenmq
|
||||
)
|
||||
|
||||
if(OXENMQ_INSTALL_CPPZMQ)
|
||||
install(
|
||||
FILES cppzmq/zmq.hpp
|
||||
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/oxenmq
|
||||
)
|
||||
endif()
|
||||
|
||||
|
||||
install(
|
||||
FILES ${CMAKE_CURRENT_BINARY_DIR}/liboxenmq.pc
|
||||
DESTINATION ${CMAKE_INSTALL_LIBDIR}/pkgconfig
|
||||
)
|
||||
|
||||
endif()
|
||||
|
||||
option(LOKIMQ_INSTALL_CPPZMQ "Install cppzmq header with lokimq/ headers" ON)
|
||||
if(LOKIMQ_INSTALL_CPPZMQ)
|
||||
install(
|
||||
FILES cppzmq/zmq.hpp
|
||||
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/lokimq
|
||||
)
|
||||
endif()
|
||||
|
||||
|
||||
install(
|
||||
FILES ${CMAKE_CURRENT_BINARY_DIR}/liblokimq.pc
|
||||
DESTINATION ${CMAKE_INSTALL_LIBDIR}/pkgconfig
|
||||
)
|
||||
|
||||
if(CMAKE_PROJECT_NAME STREQUAL PROJECT_NAME)
|
||||
set(lokimq_IS_TOPLEVEL_PROJECT TRUE)
|
||||
else()
|
||||
set(lokimq_IS_TOPLEVEL_PROJECT FALSE)
|
||||
endif()
|
||||
|
||||
option(LOKIMQ_BUILD_TESTS "Building and perform lokimq tests" ${lokimq_IS_TOPLEVEL_PROJECT})
|
||||
if(LOKIMQ_BUILD_TESTS)
|
||||
if(OXENMQ_BUILD_TESTS)
|
||||
add_subdirectory(tests)
|
||||
endif()
|
||||
|
||||
|
|
83
README.md
83
README.md
|
@ -1,14 +1,15 @@
|
|||
# LokiMQ - zeromq-based message passing for Loki projects
|
||||
# OxenMQ - high-level zeromq-based message passing for network-based projects
|
||||
|
||||
This C++14 library contains an abstraction layer around ZeroMQ to support integration with Loki
|
||||
authentication, RPC, and message passing. It is designed to be usable as the underlying
|
||||
communication mechanism of SN-to-SN communication ("quorumnet"), the RPC interface used by wallets
|
||||
and local daemon commands, communication channels between lokid and auxiliary services (storage
|
||||
server, lokinet), and also provides a local multithreaded job scheduling within a process.
|
||||
This C++17 library contains an abstraction layer around ZeroMQ to provide a high-level interface to
|
||||
authentication, RPC, and message passing. It is used extensively within Oxen projects (hence the
|
||||
name) as the underlying communication mechanism of SN-to-SN communication ("quorumnet"), the RPC
|
||||
interface used by wallets and local daemon commands, communication channels between oxend and
|
||||
auxiliary services (storage server, lokinet), and also provides local multithreaded job scheduling
|
||||
within a process.
|
||||
|
||||
Messages channels can be encrypted (using x25519) or not -- however opening an encrypted channel
|
||||
requires knowing the server pubkey. All SN-to-SN traffic is encrypted, and other traffic can be
|
||||
encrypted as needed.
|
||||
requires knowing the server pubkey. Within Oxen, all SN-to-SN traffic is encrypted, and other
|
||||
traffic can be encrypted as needed.
|
||||
|
||||
This library makes minimal use of mutexes, and none in the hot paths of the code, instead mostly
|
||||
relying on ZMQ sockets for synchronization; for more information on this (and why this is generally
|
||||
|
@ -16,20 +17,20 @@ much better performing and more scalable) see the ZMQ guide documentation on the
|
|||
|
||||
## Basic message structure
|
||||
|
||||
LokiMQ messages come in two fundamental forms: "commands", consisting of a command named and
|
||||
OxenMQ messages come in two fundamental forms: "commands", consisting of a command named and
|
||||
optional arguments, and "requests", consisting of a request name, a request tag, and optional
|
||||
arguments.
|
||||
|
||||
All channels are capable of bidirectional communication, and multiple messages can be in transit in
|
||||
either direction at any time. LokiMQ sets up a "listener" and "client" connections, but these only
|
||||
either direction at any time. OxenMQ sets up a "listener" and "client" connections, but these only
|
||||
determine how connections are established: once established, commands can be issued by either party.
|
||||
|
||||
The command/request string is one of two types:
|
||||
|
||||
`category.command` - for commands/requests registered by the LokiMQ caller (e.g. lokid). Here
|
||||
`category.command` - for commands/requests registered by the OxenMQ caller (e.g. oxend). Here
|
||||
`category` must be at least one character not containing a `.` and `command` may be anything. These
|
||||
categories and commands are registered according to general function and authentication level (more
|
||||
on this below). For example, for lokid categories are:
|
||||
on this below). For example, for oxend categories are:
|
||||
|
||||
- `system` - is for RPC commands related to the system administration such as mining, getting
|
||||
sensitive statistics, accessing SN private keys, remote shutdown, etc.
|
||||
|
@ -42,14 +43,14 @@ on this below). For example, for lokid categories are:
|
|||
The difference between a request and a command is that a request includes an additional opaque tag
|
||||
value which is used to identify a reply. For example you could register a `general.backwards`
|
||||
request that takes a string that receives a reply containing that string reversed. When invoking
|
||||
the request via LokiMQ you provide a callback to be invoked when the reply arrives. On the wire
|
||||
the request via OxenMQ you provide a callback to be invoked when the reply arrives. On the wire
|
||||
this looks like:
|
||||
|
||||
<<< [general.backwards] [v71.&a] [hello world]
|
||||
>>> [REPLY] [v71.&a] [dlrow olleh]
|
||||
|
||||
where each [] denotes a message part and `v71.&a` is a unique randomly generated identifier handled
|
||||
by LokiMQ (both the invoker and the recipient code only see the `hello world`/`dlrow olleh` message
|
||||
by OxenMQ (both the invoker and the recipient code only see the `hello world`/`dlrow olleh` message
|
||||
parts).
|
||||
|
||||
In contrast, regular registered commands have no identifier or expected reply callback. For example
|
||||
|
@ -92,7 +93,7 @@ handled for you transparently.
|
|||
|
||||
## Command arguments
|
||||
|
||||
Optional command/request arguments are always strings on the wire. The LokiMQ-using developer is
|
||||
Optional command/request arguments are always strings on the wire. The OxenMQ-using developer is
|
||||
free to create whatever encoding she wants, and these can vary across commands. For example
|
||||
`wallet.tx` might be a request that returns a transaction in binary, while `wallet.tx_info` might
|
||||
return tx metadata in JSON, and `p2p.send_tx` might encode tx data and metadata in a bt-encoded
|
||||
|
@ -101,47 +102,49 @@ data string.
|
|||
No structure at all is imposed on message data to allow maximum flexibility; it is entirely up to
|
||||
the calling code to handle all encoding/decoding duties.
|
||||
|
||||
Internal commands passed between LokiMQ-managed threads use either plain strings or bt-encoded
|
||||
dictionaries. See `lokimq/bt_serialize.h` if you want a bt serializer/deserializer.
|
||||
Internal commands passed between OxenMQ-managed threads use either plain strings or bt-encoded
|
||||
dictionaries. See `oxenmq/bt_serialize.h` if you want a bt serializer/deserializer.
|
||||
|
||||
## Sending commands
|
||||
|
||||
Sending a command to a peer is done by using a connection ID, and generally falls into either a
|
||||
`send()` method or a `request()` method.
|
||||
|
||||
lmq.send(conn, "category.command", "some data");
|
||||
lmq.request(conn, "category.command", [](bool success, std::vector<std::string> data) {
|
||||
omq.send(conn, "category.command", "some data");
|
||||
omq.request(conn, "category.command", [](bool success, std::vector<std::string> data) {
|
||||
if (success) { std::cout << "Remote replied: " << data.at(0) << "\n"; } });
|
||||
|
||||
The connection ID generally has two possible values:
|
||||
|
||||
- a string containing a service node pubkey. In this mode LokiMQ will look for the given SN in
|
||||
- a string containing a service node pubkey. In this mode OxenMQ will look for the given SN in
|
||||
already-established connections, reusing a connection if one exists. If no connection already
|
||||
exists, a new connection to the given SN is attempted (this requires constructing the LokiMQ
|
||||
exists, a new connection to the given SN is attempted (this requires constructing the OxenMQ
|
||||
object with a callback to determine SN remote addresses).
|
||||
- a ConnectionID object, typically returned by the `connect_remote` method (although there are other
|
||||
places to get one, such as from the `Message` object passed to a command: see the following
|
||||
section).
|
||||
|
||||
```C++
|
||||
// Send to a service node, establishing a connection if necessary:
|
||||
std::string my_sn = ...; // 32-byte pubkey of a known SN
|
||||
lmq.send(my_sn, "sn.explode", "{ \"seconds\": 30 }");
|
||||
omq.send(my_sn, "sn.explode", "{ \"seconds\": 30 }");
|
||||
|
||||
// Connect to a remote by address then send it something
|
||||
auto conn = lmq.connect_remote("tcp://127.0.0.1:4567",
|
||||
auto conn = omq.connect_remote("tcp://127.0.0.1:4567",
|
||||
[](ConnectionID c) { std::cout << "Connected!\n"; },
|
||||
[](ConnectionID c, string_view f) { std::cout << "Connect failed: " << f << "\n" });
|
||||
lmq.request(conn, "rpc.get_height", [](bool s, std::vector<std::string> d) {
|
||||
omq.request(conn, "rpc.get_height", [](bool s, std::vector<std::string> d) {
|
||||
if (s && d.size() == 1)
|
||||
std::cout << "Current height: " << d[0] << "\n";
|
||||
else
|
||||
std::cout << "Timeout fetching height!";
|
||||
});
|
||||
```
|
||||
|
||||
## Command invocation
|
||||
|
||||
The application registers categories and registers commands within these categories with callbacks.
|
||||
The callbacks are passed a LokiMQ::Message object from which the message (plus various connection
|
||||
The callbacks are passed a OxenMQ::Message object from which the message (plus various connection
|
||||
information) can be obtained. There is no structure imposed at all on the data passed in subsequent
|
||||
message parts: it is up to the command itself to deserialize however it wishes (e.g. JSON,
|
||||
bt-encoded, or any other encoding).
|
||||
|
@ -149,13 +152,13 @@ bt-encoded, or any other encoding).
|
|||
The Message object also provides methods for replying to the caller. Simple replies queue a reply
|
||||
if the client is still connected. Replies to service nodes can also be "strong" replies: when
|
||||
replying to a SN that has closed connection with a strong reply we will attempt to reestablish a
|
||||
connection to deliver the message. In order for this to work the LokiMQ caller must provide a
|
||||
connection to deliver the message. In order for this to work the OxenMQ caller must provide a
|
||||
lookup function to retrieve the remote address given a SN x25519 pubkey.
|
||||
|
||||
### Callbacks
|
||||
|
||||
Invoked command functions are always invoked with exactly one arguments: a non-const LokiMQ::Message
|
||||
reference from which the connection info, LokiMQ object, and message data can be obtained.
|
||||
Invoked command functions are always invoked with exactly one arguments: a non-const OxenMQ::Message
|
||||
reference from which the connection info, OxenMQ object, and message data can be obtained.
|
||||
|
||||
The Message object also contains a `ConnectionID` object as the public `conn` member; it is safe to
|
||||
take a copy of this and then use it later to send commands to this peer. (For example, a wallet
|
||||
|
@ -185,7 +188,7 @@ logins.
|
|||
Configuration defaults allows controlling the default access for an incoming connection based on its
|
||||
remote address. Typically this is used to allow connections from localhost (or a unix domain
|
||||
socket) to automatically be an Admin connection without requiring explicit authentication. This
|
||||
also allows configuration of how public connections should be treated: for example, a lokid running
|
||||
also allows configuration of how public connections should be treated: for example, an oxend running
|
||||
as a public RPC server would do so by granting Basic access to all incoming connections.
|
||||
|
||||
Explicit logins allow the daemon to specify username/passwords with mapping to Basic or Admin
|
||||
|
@ -194,7 +197,7 @@ authentication levels.
|
|||
Thus, for example, a daemon could be configured to be allow Basic remote access with authentication
|
||||
(i.e. requiring a username/password login given out to people who should be able to access).
|
||||
|
||||
For example, in lokid the categories described above have authentication levels of:
|
||||
For example, in oxend the categories described above have authentication levels of:
|
||||
|
||||
- `system` - Admin
|
||||
- `sn` - ServiceNode
|
||||
|
@ -203,7 +206,7 @@ For example, in lokid the categories described above have authentication levels
|
|||
|
||||
### Service Node authentication
|
||||
|
||||
In order to handle ServiceNode authentication, LokiMQ uses an Allow callback invoked during
|
||||
In order to handle ServiceNode authentication, OxenMQ uses an Allow callback invoked during
|
||||
connection to determine both whether to allow the connection, and to determine whether the incoming
|
||||
connection is an active service node.
|
||||
|
||||
|
@ -224,7 +227,7 @@ such aliases be used only temporarily for version transitions.
|
|||
|
||||
## Threads
|
||||
|
||||
LokiMQ operates a pool of worker threads to handle jobs. The simplest use just allocates new jobs
|
||||
OxenMQ operates a pool of worker threads to handle jobs. The simplest use just allocates new jobs
|
||||
to a free worker thread, and we have a "general threads" value to configure how many such threads
|
||||
are available.
|
||||
|
||||
|
@ -239,7 +242,7 @@ Note that these actual reserved threads are not exclusive: reserving M of N tota
|
|||
category simply ensures that no more than (N-M) threads are being used for other categories at any
|
||||
given time, but the actual jobs may run on any worker thread.
|
||||
|
||||
As mentioned above, LokiMQ tries to avoid exceeding the configured general threads value (G)
|
||||
As mentioned above, OxenMQ tries to avoid exceeding the configured general threads value (G)
|
||||
whenever possible: the only time we will dispatch a job to a worker thread when we have >= G threads
|
||||
already running is when a new command arrives, the category reserves M threads, and the thread pool
|
||||
is currently processing fewer than M jobs for that category.
|
||||
|
@ -275,7 +278,7 @@ when a command with reserve threads arrived.
|
|||
A common pattern is one where a single thread suddenly has some work that can be be parallelized.
|
||||
You could employ some blocking, locking, mutex + condition variable monstrosity, but you shouldn't.
|
||||
|
||||
Instead LokiMQ provides a mechanism for this by allowing you to submit a batch of jobs with a
|
||||
Instead OxenMQ provides a mechanism for this by allowing you to submit a batch of jobs with a
|
||||
completion callback. All jobs will be queued and, when the last one finishes, the finalization
|
||||
callback will be queued to continue with the task.
|
||||
|
||||
|
@ -300,7 +303,7 @@ double do_my_task(int input) {
|
|||
return 3.0 * input;
|
||||
}
|
||||
|
||||
void continue_big_task(std::vector<lokimq::job_result<double>> results) {
|
||||
void continue_big_task(std::vector<oxenmq::job_result<double>> results) {
|
||||
double sum = 0;
|
||||
for (auto& r : results) {
|
||||
try {
|
||||
|
@ -321,7 +324,7 @@ void continue_big_task(std::vector<lokimq::job_result<double>> results) {
|
|||
void start_big_task() {
|
||||
size_t num_jobs = 32;
|
||||
|
||||
lokimq::Batch<double /*return type*/> batch;
|
||||
oxenmq::Batch<double /*return type*/> batch;
|
||||
batch.reserve(num_jobs);
|
||||
|
||||
for (size_t i = 0; i < num_jobs; i++)
|
||||
|
@ -329,7 +332,7 @@ void start_big_task() {
|
|||
|
||||
batch.completion(&continue_big_task);
|
||||
|
||||
lmq.batch(std::move(batch));
|
||||
omq.batch(std::move(batch));
|
||||
// ... to be continued in `continue_big_task` after all the jobs finish
|
||||
|
||||
// Can do other things here, but note that continue_big_task could run
|
||||
|
@ -339,12 +342,12 @@ void start_big_task() {
|
|||
|
||||
This code deliberately does not support blocking to wait for the tasks to finish: if you want such a
|
||||
poor design (which is a recipe for deadlocks: imagine jobs that queuing other jobs that can end up
|
||||
exhausting the worker threads with waiting jobs) then you can implement it yourself; LokiMQ isn't
|
||||
exhausting the worker threads with waiting jobs) then you can implement it yourself; OxenMQ isn't
|
||||
going to help you hurt yourself like that.
|
||||
|
||||
### Single-job queuing
|
||||
|
||||
As a shortcut there is a `lmq.job(...)` method that schedules a single task (with no return value)
|
||||
As a shortcut there is a `omq.job(...)` method that schedules a single task (with no return value)
|
||||
in the batch job queue. This is useful when some event requires triggering some other event, but
|
||||
you don't need to wait for or collect its result. (Internally this is just a convenience method
|
||||
around creating a single-job, no-completion Batch job).
|
||||
|
@ -356,7 +359,7 @@ either using your own thread or a periodic timer (see below) to shepherd those o
|
|||
|
||||
## Timers
|
||||
|
||||
LokiMQ supports scheduling periodic tasks via the `add_timer()` function. These timers have an
|
||||
OxenMQ supports scheduling periodic tasks via the `add_timer()` function. These timers have an
|
||||
interval and are scheduled as (single-job) batches when the timer fires. They also support
|
||||
"squelching" (enabled by default) that supresses the job being scheduled if a previously scheduled
|
||||
job is already scheduled or running.
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
set(LIBZMQ_PREFIX ${CMAKE_BINARY_DIR}/libzmq)
|
||||
set(ZeroMQ_VERSION 4.3.2)
|
||||
set(ZeroMQ_VERSION 4.3.5)
|
||||
set(LIBZMQ_URL https://github.com/zeromq/libzmq/releases/download/v${ZeroMQ_VERSION}/zeromq-${ZeroMQ_VERSION}.tar.gz)
|
||||
set(LIBZMQ_HASH SHA512=b6251641e884181db9e6b0b705cced7ea4038d404bdae812ff47bdd0eed12510b6af6846b85cb96898e253ccbac71eca7fe588673300ddb9c3109c973250c8e4)
|
||||
set(LIBZMQ_HASH SHA512=a71d48aa977ad8941c1609947d8db2679fc7a951e4cd0c3a1127ae026d883c11bd4203cf315de87f95f5031aec459a731aec34e5ce5b667b8d0559b157952541)
|
||||
|
||||
message(${LIBZMQ_URL})
|
||||
|
||||
|
@ -13,19 +13,30 @@ endif()
|
|||
|
||||
file(MAKE_DIRECTORY ${LIBZMQ_PREFIX}/include)
|
||||
|
||||
set(libzmq_compiler_args)
|
||||
foreach(lang C CXX)
|
||||
foreach(thing COMPILER FLAGS COMPILER_LAUNCHER)
|
||||
if(DEFINED CMAKE_${lang}_${thing})
|
||||
list(APPEND libzmq_compiler_args "-DCMAKE_${lang}_${thing}=${CMAKE_${lang}_${thing}}")
|
||||
endif()
|
||||
endforeach()
|
||||
endforeach()
|
||||
|
||||
include(ExternalProject)
|
||||
include(ProcessorCount)
|
||||
ExternalProject_Add(libzmq_external
|
||||
PREFIX ${LIBZMQ_PREFIX}
|
||||
URL ${LIBZMQ_URL}
|
||||
URL_HASH ${LIBZMQ_HASH}
|
||||
CMAKE_ARGS -DWITH_LIBSODIUM=ON -DZMQ_BUILD_TESTS=OFF -DWITH_PERF_TOOL=OFF -DENABLE_DRAFTS=OFF
|
||||
CMAKE_ARGS ${libzmq_compiler_args}
|
||||
-DCMAKE_BUILD_TYPE=Release
|
||||
-DWITH_LIBSODIUM=ON -DZMQ_BUILD_TESTS=OFF -DWITH_PERF_TOOL=OFF -DENABLE_DRAFTS=OFF
|
||||
-DBUILD_SHARED=OFF -DBUILD_STATIC=ON -DWITH_DOC=OFF -DCMAKE_INSTALL_PREFIX=${LIBZMQ_PREFIX}
|
||||
BUILD_BYPRODUCTS ${LIBZMQ_PREFIX}/lib/libzmq.a
|
||||
BUILD_BYPRODUCTS ${LIBZMQ_PREFIX}/${CMAKE_INSTALL_LIBDIR}/libzmq.a
|
||||
)
|
||||
|
||||
add_library(libzmq_vendor STATIC IMPORTED GLOBAL)
|
||||
add_dependencies(libzmq_vendor libzmq_external)
|
||||
set_target_properties(libzmq_vendor PROPERTIES
|
||||
INTERFACE_INCLUDE_DIRECTORIES ${LIBZMQ_PREFIX}/include
|
||||
IMPORTED_LOCATION ${LIBZMQ_PREFIX}/lib/libzmq.a)
|
||||
IMPORTED_LOCATION ${LIBZMQ_PREFIX}/${CMAKE_INSTALL_LIBDIR}/libzmq.a)
|
||||
|
|
Binary file not shown.
2
cppzmq
2
cppzmq
|
@ -1 +1 @@
|
|||
Subproject commit 8d5c9a88988dcbebb72939ca0939d432230ffde1
|
||||
Subproject commit 76bf169fd67b8e99c1b0e6490029d9cd5ef97666
|
|
@ -3,11 +3,12 @@ exec_prefix=${prefix}
|
|||
libdir=@CMAKE_INSTALL_FULL_LIBDIR@
|
||||
includedir=@CMAKE_INSTALL_FULL_INCLUDEDIR@
|
||||
|
||||
Name: liblokimq
|
||||
Description: ZeroMQ-based communication library for Loki
|
||||
Version: @LOKIMQ_VERSION@
|
||||
Name: liboxenmq
|
||||
Description: ZeroMQ-based communication library
|
||||
Version: @PROJECT_VERSION@
|
||||
|
||||
Libs: -L${libdir} -llokimq
|
||||
Libs: -L${libdir} -loxenmq
|
||||
Libs.private: @PRIVATE_LIBS@
|
||||
Requires: liboxenc
|
||||
Requires.private: libzmq libsodium
|
||||
Cflags: -I${includedir}
|
|
@ -1,233 +0,0 @@
|
|||
// Copyright (c) 2019-2020, The Loki Project
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without modification, are
|
||||
// permitted provided that the following conditions are met:
|
||||
//
|
||||
// 1. Redistributions of source code must retain the above copyright notice, this list of
|
||||
// conditions and the following disclaimer.
|
||||
//
|
||||
// 2. Redistributions in binary form must reproduce the above copyright notice, this list
|
||||
// of conditions and the following disclaimer in the documentation and/or other
|
||||
// materials provided with the distribution.
|
||||
//
|
||||
// 3. Neither the name of the copyright holder nor the names of its contributors may be
|
||||
// used to endorse or promote products derived from this software without specific
|
||||
// prior written permission.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
|
||||
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
|
||||
// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL
|
||||
// THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF
|
||||
// THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
#include "bt_serialize.h"
|
||||
|
||||
namespace lokimq {
|
||||
namespace detail {
|
||||
|
||||
/// Reads digits into an unsigned 64-bit int.
|
||||
uint64_t extract_unsigned(string_view& s) {
|
||||
if (s.empty())
|
||||
throw bt_deserialize_invalid{"Expected 0-9 but found end of string"};
|
||||
if (s[0] < '0' || s[0] > '9')
|
||||
throw bt_deserialize_invalid("Expected 0-9 but found '"s + s[0]);
|
||||
uint64_t uval = 0;
|
||||
while (!s.empty() && (s[0] >= '0' && s[0] <= '9')) {
|
||||
uint64_t bigger = uval * 10 + (s[0] - '0');
|
||||
s.remove_prefix(1);
|
||||
if (bigger < uval) // overflow
|
||||
throw bt_deserialize_invalid("Integer deserialization failed: value is too large for a 64-bit int");
|
||||
uval = bigger;
|
||||
}
|
||||
return uval;
|
||||
}
|
||||
|
||||
void bt_deserialize<string_view>::operator()(string_view& s, string_view& val) {
|
||||
if (s.size() < 2) throw bt_deserialize_invalid{"Deserialize failed: given data is not an bt-encoded string"};
|
||||
if (s[0] < '0' || s[0] > '9')
|
||||
throw bt_deserialize_invalid_type{"Expected 0-9 but found '"s + s[0] + "'"};
|
||||
auto len = static_cast<size_t>(extract_unsigned(s));
|
||||
if (s.empty() || s[0] != ':')
|
||||
throw bt_deserialize_invalid{"Did not find expected ':' during string deserialization"};
|
||||
s.remove_prefix(1);
|
||||
|
||||
if (len > s.size())
|
||||
throw bt_deserialize_invalid{"String deserialization failed: encoded string length is longer than the serialized data"};
|
||||
|
||||
val = {s.data(), len};
|
||||
s.remove_prefix(len);
|
||||
}
|
||||
|
||||
// Check that we are on a 2's complement architecture. It's highly unlikely that this code ever
|
||||
// runs on a non-2s-complement architecture (especially since C++20 requires a two's complement
|
||||
// signed value behaviour), but check at compile time anyway because we rely on these relations
|
||||
// below.
|
||||
static_assert(std::numeric_limits<int64_t>::min() + std::numeric_limits<int64_t>::max() == -1 &&
|
||||
static_cast<uint64_t>(std::numeric_limits<int64_t>::max()) + uint64_t{1} == (uint64_t{1} << 63),
|
||||
"Non 2s-complement architecture not supported!");
|
||||
|
||||
std::pair<maybe_signed_int64_t, bool> bt_deserialize_integer(string_view& s) {
|
||||
// Smallest possible encoded integer is 3 chars: "i0e"
|
||||
if (s.size() < 3) throw bt_deserialize_invalid("Deserialization failed: end of string found where integer expected");
|
||||
if (s[0] != 'i') throw bt_deserialize_invalid_type("Deserialization failed: expected 'i', found '"s + s[0] + '\'');
|
||||
s.remove_prefix(1);
|
||||
std::pair<maybe_signed_int64_t, bool> result;
|
||||
if (s[0] == '-') {
|
||||
result.second = true;
|
||||
s.remove_prefix(1);
|
||||
}
|
||||
|
||||
uint64_t uval = extract_unsigned(s);
|
||||
if (s.empty())
|
||||
throw bt_deserialize_invalid("Integer deserialization failed: encountered end of string before integer was finished");
|
||||
if (s[0] != 'e')
|
||||
throw bt_deserialize_invalid("Integer deserialization failed: expected digit or 'e', found '"s + s[0] + '\'');
|
||||
s.remove_prefix(1);
|
||||
if (result.second) { // negative
|
||||
if (uval > (uint64_t{1} << 63))
|
||||
throw bt_deserialize_invalid("Deserialization of integer failed: negative integer value is too large for a 64-bit signed int");
|
||||
result.first.i64 = -uval;
|
||||
} else {
|
||||
result.first.u64 = uval;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
template struct bt_deserialize<int64_t>;
|
||||
template struct bt_deserialize<uint64_t>;
|
||||
|
||||
void bt_deserialize<bt_value, void>::operator()(string_view& s, bt_value& val) {
|
||||
if (s.size() < 2) throw bt_deserialize_invalid("Deserialization failed: end of string found where bt-encoded value expected");
|
||||
|
||||
switch (s[0]) {
|
||||
case 'd': {
|
||||
bt_dict dict;
|
||||
bt_deserialize<bt_dict>{}(s, dict);
|
||||
val = std::move(dict);
|
||||
break;
|
||||
}
|
||||
case 'l': {
|
||||
bt_list list;
|
||||
bt_deserialize<bt_list>{}(s, list);
|
||||
val = std::move(list);
|
||||
break;
|
||||
}
|
||||
case 'i': {
|
||||
auto read = bt_deserialize_integer(s);
|
||||
val = read.first.i64; // We only store an i64, but can get a u64 out of it via get<uint64_t>(val)
|
||||
break;
|
||||
}
|
||||
case '0': case '1': case '2': case '3': case '4': case '5': case '6': case '7': case '8': case '9': {
|
||||
std::string str;
|
||||
bt_deserialize<std::string>{}(s, str);
|
||||
val = std::move(str);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
throw bt_deserialize_invalid("Deserialize failed: encountered invalid value '"s + s[0] + "'; expected one of [0-9idl]");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
|
||||
bt_list_consumer::bt_list_consumer(string_view data_) : data{std::move(data_)} {
|
||||
if (data.empty()) throw std::runtime_error{"Cannot create a bt_list_consumer with an empty string_view"};
|
||||
if (data[0] != 'l') throw std::runtime_error{"Cannot create a bt_list_consumer with non-list data"};
|
||||
data.remove_prefix(1);
|
||||
}
|
||||
|
||||
/// Attempt to parse the next value as a string (and advance just past it). Throws if the next
|
||||
/// value is not a string.
|
||||
string_view bt_list_consumer::consume_string_view() {
|
||||
if (data.empty())
|
||||
throw bt_deserialize_invalid{"expected a string, but reached end of data"};
|
||||
else if (!is_string())
|
||||
throw bt_deserialize_invalid_type{"expected a string, but found "s + data.front()};
|
||||
string_view next{data}, result;
|
||||
detail::bt_deserialize<string_view>{}(next, result);
|
||||
data = next;
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string bt_list_consumer::consume_string() {
|
||||
return std::string{consume_string_view()};
|
||||
}
|
||||
|
||||
/// Consumes a value without returning it.
|
||||
void bt_list_consumer::skip_value() {
|
||||
if (is_string())
|
||||
consume_string_view();
|
||||
else if (is_integer())
|
||||
detail::bt_deserialize_integer(data);
|
||||
else if (is_list())
|
||||
consume_list_data();
|
||||
else if (is_dict())
|
||||
consume_dict_data();
|
||||
else
|
||||
throw bt_deserialize_invalid_type{"next bt value has unknown type"};
|
||||
}
|
||||
|
||||
string_view bt_list_consumer::consume_list_data() {
|
||||
auto start = data.begin();
|
||||
if (data.size() < 2 || !is_list()) throw bt_deserialize_invalid_type{"next bt value is not a list"};
|
||||
data.remove_prefix(1); // Descend into the sublist, consume the "l"
|
||||
while (!is_finished()) {
|
||||
skip_value();
|
||||
if (data.empty())
|
||||
throw bt_deserialize_invalid{"bt list consumption failed: hit the end of string before the list was done"};
|
||||
}
|
||||
data.remove_prefix(1); // Back out from the sublist, consume the "e"
|
||||
return {start, static_cast<size_t>(std::distance(start, data.begin()))};
|
||||
}
|
||||
|
||||
string_view bt_list_consumer::consume_dict_data() {
|
||||
auto start = data.begin();
|
||||
if (data.size() < 2 || !is_dict()) throw bt_deserialize_invalid_type{"next bt value is not a dict"};
|
||||
data.remove_prefix(1); // Descent into the dict, consumer the "d"
|
||||
while (!is_finished()) {
|
||||
consume_string_view(); // Key is always a string
|
||||
if (!data.empty())
|
||||
skip_value();
|
||||
if (data.empty())
|
||||
throw bt_deserialize_invalid{"bt dict consumption failed: hit the end of string before the dict was done"};
|
||||
}
|
||||
data.remove_prefix(1); // Back out of the dict, consume the "e"
|
||||
return {start, static_cast<size_t>(std::distance(start, data.begin()))};
|
||||
}
|
||||
|
||||
bt_dict_consumer::bt_dict_consumer(string_view data_) {
|
||||
data = std::move(data_);
|
||||
if (data.empty()) throw std::runtime_error{"Cannot create a bt_dict_consumer with an empty string_view"};
|
||||
if (data.size() < 2 || data[0] != 'd') throw std::runtime_error{"Cannot create a bt_dict_consumer with non-dict data"};
|
||||
data.remove_prefix(1);
|
||||
}
|
||||
|
||||
bool bt_dict_consumer::consume_key() {
|
||||
if (key_.data())
|
||||
return true;
|
||||
if (data.empty()) throw bt_deserialize_invalid_type{"expected a key or dict end, found end of string"};
|
||||
if (data[0] == 'e') return false;
|
||||
key_ = bt_list_consumer::consume_string_view();
|
||||
if (data.empty() || data[0] == 'e')
|
||||
throw bt_deserialize_invalid{"dict key isn't followed by a value"};
|
||||
return true;
|
||||
}
|
||||
|
||||
std::pair<string_view, string_view> bt_dict_consumer::next_string() {
|
||||
if (!is_string())
|
||||
throw bt_deserialize_invalid_type{"expected a string, but found "s + data.front()};
|
||||
std::pair<string_view, string_view> ret;
|
||||
ret.second = bt_list_consumer::consume_string_view();
|
||||
ret.first = flush_key();
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
||||
} // namespace lokimq
|
|
@ -1,829 +0,0 @@
|
|||
// Copyright (c) 2019-2020, The Loki Project
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without modification, are
|
||||
// permitted provided that the following conditions are met:
|
||||
//
|
||||
// 1. Redistributions of source code must retain the above copyright notice, this list of
|
||||
// conditions and the following disclaimer.
|
||||
//
|
||||
// 2. Redistributions in binary form must reproduce the above copyright notice, this list
|
||||
// of conditions and the following disclaimer in the documentation and/or other
|
||||
// materials provided with the distribution.
|
||||
//
|
||||
// 3. Neither the name of the copyright holder nor the names of its contributors may be
|
||||
// used to endorse or promote products derived from this software without specific
|
||||
// prior written permission.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
|
||||
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
|
||||
// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL
|
||||
// THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF
|
||||
// THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <list>
|
||||
#include <unordered_map>
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <cstring>
|
||||
#include <ostream>
|
||||
#include <sstream>
|
||||
#include "string_view.h"
|
||||
#include "mapbox/variant.hpp"
|
||||
|
||||
|
||||
namespace lokimq {
|
||||
|
||||
using namespace std::literals;
|
||||
|
||||
/** \file
|
||||
* LokiMQ serialization for internal commands is very simple: we support two primitive types,
|
||||
* strings and integers, and two container types, lists and dicts with string keys. On the wire
|
||||
* these go in BitTorrent byte encoding as described in BEP-0003
|
||||
* (https://www.bittorrent.org/beps/bep_0003.html#bencoding).
|
||||
*
|
||||
* On the C++ side, on input we allow strings, integral types, STL-like containers of these types,
|
||||
* and STL-like containers of pairs with a string first value and any of these types as second
|
||||
* value. We also accept std::variants (if compiled with std::variant support, i.e. in C++17 mode)
|
||||
* that contain any of these, and mapbox::util::variants (the internal type used for its recursive
|
||||
* support).
|
||||
*
|
||||
* One minor deviation from BEP-0003 is that we don't support serializing values that don't fit in a
|
||||
* 64-bit integer (BEP-0003 specifies arbitrary precision integers).
|
||||
*
|
||||
* On deserialization we can either deserialize into a mapbox::util::variant that supports everything, or
|
||||
* we can fill a container of your given type (though this fails if the container isn't compatible
|
||||
* with the deserialized data).
|
||||
*/
|
||||
|
||||
/// Exception throw if deserialization fails
|
||||
class bt_deserialize_invalid : public std::invalid_argument {
|
||||
using std::invalid_argument::invalid_argument;
|
||||
};
|
||||
|
||||
/// A more specific subclass that is thown if the serialization type is an initial mismatch: for
|
||||
/// example, trying deserializing an int but the next thing in input is a list. This is not,
|
||||
/// however, thrown if the type initially looks fine but, say, a nested serialization fails. This
|
||||
/// error will only be thrown when the input stream has not been advanced (and so can be tried for a
|
||||
/// different type).
|
||||
class bt_deserialize_invalid_type : public bt_deserialize_invalid {
|
||||
using bt_deserialize_invalid::bt_deserialize_invalid;
|
||||
};
|
||||
|
||||
class bt_list;
|
||||
class bt_dict;
|
||||
|
||||
/// Recursive generic type that can fully represent everything valid for a BT serialization.
|
||||
using bt_value = mapbox::util::variant<
|
||||
std::string,
|
||||
string_view,
|
||||
int64_t,
|
||||
mapbox::util::recursive_wrapper<bt_list>,
|
||||
mapbox::util::recursive_wrapper<bt_dict>
|
||||
>;
|
||||
|
||||
/// Very thin wrapper around a std::list<bt_value> that holds a list of generic values (though *any*
|
||||
/// compatible data type can be used).
|
||||
class bt_list : public std::list<bt_value> {
|
||||
using std::list<bt_value>::list;
|
||||
};
|
||||
/// Very thin wrapper around a std::unordered_map<bt_value> that holds a list of string -> generic
|
||||
/// value pairs (though *any* compatible data type can be used).
|
||||
class bt_dict : public std::unordered_map<std::string, bt_value> {
|
||||
using std::unordered_map<std::string, bt_value>::unordered_map;
|
||||
};
|
||||
|
||||
#ifdef __cpp_lib_void_t
|
||||
using std::void_t;
|
||||
#else
|
||||
/// C++17 void_t backport
|
||||
template <typename... Ts> struct void_t_impl { using type = void; };
|
||||
template <typename... Ts> using void_t = typename void_t_impl<Ts...>::type;
|
||||
#endif
|
||||
|
||||
namespace detail {
|
||||
|
||||
/// Reads digits into an unsigned 64-bit int.
|
||||
uint64_t extract_unsigned(string_view& s);
|
||||
inline uint64_t extract_unsigned(string_view&& s) { return extract_unsigned(s); }
|
||||
|
||||
// Fallback base case; we only get here if none of the partial specializations below work
|
||||
template <typename T, typename SFINAE = void>
|
||||
struct bt_serialize { static_assert(!std::is_same<T, T>::value, "Cannot serialize T: unsupported type for bt serialization"); };
|
||||
|
||||
template <typename T, typename SFINAE = void>
|
||||
struct bt_deserialize { static_assert(!std::is_same<T, T>::value, "Cannot deserialize T: unsupported type for bt deserialization"); };
|
||||
|
||||
/// Checks that we aren't at the end of a string view and throws if we are.
|
||||
inline void bt_need_more(const string_view &s) {
|
||||
if (s.empty())
|
||||
throw bt_deserialize_invalid{"Unexpected end of string while deserializing"};
|
||||
}
|
||||
|
||||
union maybe_signed_int64_t { int64_t i64; uint64_t u64; };
|
||||
|
||||
/// Deserializes a signed or unsigned 64-bit integer from a string. Sets the second bool to true
|
||||
/// iff the value is int64_t because a negative value was read. Throws an exception if the read
|
||||
/// value doesn't fit in a int64_t (if negative) or a uint64_t (if positive). Removes consumed
|
||||
/// characters from the string_view.
|
||||
std::pair<maybe_signed_int64_t, bool> bt_deserialize_integer(string_view& s);
|
||||
|
||||
/// Integer specializations
|
||||
template <typename T>
|
||||
struct bt_serialize<T, std::enable_if_t<std::is_integral<T>::value>> {
|
||||
static_assert(sizeof(T) <= sizeof(uint64_t), "Serialization of integers larger than uint64_t is not supported");
|
||||
void operator()(std::ostream &os, const T &val) {
|
||||
// Cast 1-byte types to a larger type to avoid iostream interpreting them as single characters
|
||||
using output_type = std::conditional_t<(sizeof(T) > 1), T, std::conditional_t<std::is_signed<T>::value, int, unsigned>>;
|
||||
os << 'i' << static_cast<output_type>(val) << 'e';
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct bt_deserialize<T, std::enable_if_t<std::is_integral<T>::value>> {
|
||||
void operator()(string_view& s, T &val) {
|
||||
constexpr uint64_t umax = static_cast<uint64_t>(std::numeric_limits<T>::max());
|
||||
constexpr int64_t smin = static_cast<int64_t>(std::numeric_limits<T>::min()),
|
||||
smax = static_cast<int64_t>(std::numeric_limits<T>::max());
|
||||
|
||||
auto read = bt_deserialize_integer(s);
|
||||
if (std::is_signed<T>::value) {
|
||||
if (!read.second) { // read a positive value
|
||||
if (read.first.u64 > umax)
|
||||
throw bt_deserialize_invalid("Integer deserialization failed: found too-large value " + std::to_string(read.first.u64) + " > " + std::to_string(umax));
|
||||
val = static_cast<T>(read.first.u64);
|
||||
} else {
|
||||
bool oob = read.first.i64 < smin || read.first.i64 > smax;
|
||||
if (sizeof(T) < sizeof(int64_t) && oob)
|
||||
throw bt_deserialize_invalid("Integer deserialization failed: found out-of-range value " + std::to_string(read.first.i64) + " not in [" + std::to_string(smin) + "," + std::to_string(smax) + "]");
|
||||
val = static_cast<T>(read.first.i64);
|
||||
}
|
||||
} else {
|
||||
if (read.second)
|
||||
throw bt_deserialize_invalid("Integer deserialization failed: found negative value " + std::to_string(read.first.i64) + " but type is unsigned");
|
||||
if (sizeof(T) < sizeof(uint64_t) && read.first.u64 > umax)
|
||||
throw bt_deserialize_invalid("Integer deserialization failed: found too-large value " + std::to_string(read.first.u64) + " > " + std::to_string(umax));
|
||||
val = static_cast<T>(read.first.u64);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
extern template struct bt_deserialize<int64_t>;
|
||||
extern template struct bt_deserialize<uint64_t>;
|
||||
|
||||
template <>
|
||||
struct bt_serialize<string_view> {
|
||||
void operator()(std::ostream &os, const string_view &val) { os << val.size(); os.put(':'); os.write(val.data(), val.size()); }
|
||||
};
|
||||
template <>
|
||||
struct bt_deserialize<string_view> {
|
||||
void operator()(string_view& s, string_view& val);
|
||||
};
|
||||
|
||||
/// String specialization
|
||||
template <>
|
||||
struct bt_serialize<std::string> {
|
||||
void operator()(std::ostream &os, const std::string &val) { bt_serialize<string_view>{}(os, val); }
|
||||
};
|
||||
template <>
|
||||
struct bt_deserialize<std::string> {
|
||||
void operator()(string_view& s, std::string& val) { string_view view; bt_deserialize<string_view>{}(s, view); val = {view.data(), view.size()}; }
|
||||
};
|
||||
|
||||
/// char * and string literals -- we allow serialization for convenience, but not deserialization
|
||||
template <>
|
||||
struct bt_serialize<char *> {
|
||||
void operator()(std::ostream &os, const char *str) { bt_serialize<string_view>{}(os, {str, std::strlen(str)}); }
|
||||
};
|
||||
template <size_t N>
|
||||
struct bt_serialize<char[N]> {
|
||||
void operator()(std::ostream &os, const char *str) { bt_serialize<string_view>{}(os, {str, N-1}); }
|
||||
};
|
||||
|
||||
/// Partial dict validity; we don't check the second type for serializability, that will be handled
|
||||
/// via the base case static_assert if invalid.
|
||||
template <typename T, typename = void> struct is_bt_input_dict_container : std::false_type {};
|
||||
template <typename T>
|
||||
struct is_bt_input_dict_container<T, std::enable_if_t<
|
||||
std::is_same<std::string, std::remove_cv_t<typename T::value_type::first_type>>::value,
|
||||
void_t<typename T::const_iterator /* is const iterable */,
|
||||
typename T::value_type::second_type /* has a second type */>>>
|
||||
: std::true_type {};
|
||||
|
||||
/// Determines whether the type looks like something we can insert into (using `v.insert(v.end(), x)`)
|
||||
template <typename T, typename = void> struct is_bt_insertable : std::false_type {};
|
||||
template <typename T>
|
||||
struct is_bt_insertable<T,
|
||||
void_t<decltype(std::declval<T>().insert(std::declval<T>().end(), std::declval<typename T::value_type>()))>>
|
||||
: std::true_type {};
|
||||
|
||||
/// Determines whether the given type looks like a compatible map (i.e. has std::string keys) that
|
||||
/// we can insert into.
|
||||
template <typename T, typename = void> struct is_bt_output_dict_container : std::false_type {};
|
||||
template <typename T>
|
||||
struct is_bt_output_dict_container<T, std::enable_if_t<
|
||||
std::is_same<std::string, std::remove_cv_t<typename T::key_type>>::value &&
|
||||
is_bt_insertable<T>::value,
|
||||
void_t<typename T::value_type::second_type /* has a second type */>>>
|
||||
: std::true_type {};
|
||||
|
||||
|
||||
/// Specialization for a dict-like container (such as an unordered_map). We accept anything for a
|
||||
/// dict that is const iterable over something that looks like a pair with std::string for first
|
||||
/// value type. The value (i.e. second element of the pair) also must be serializable.
|
||||
template <typename T>
|
||||
struct bt_serialize<T, std::enable_if_t<is_bt_input_dict_container<T>::value>> {
|
||||
using second_type = typename T::value_type::second_type;
|
||||
using ref_pair = std::reference_wrapper<const typename T::value_type>;
|
||||
void operator()(std::ostream &os, const T &dict) {
|
||||
os << 'd';
|
||||
std::vector<ref_pair> pairs;
|
||||
pairs.reserve(dict.size());
|
||||
for (const auto &pair : dict)
|
||||
pairs.emplace(pairs.end(), pair);
|
||||
std::sort(pairs.begin(), pairs.end(), [](ref_pair a, ref_pair b) { return a.get().first < b.get().first; });
|
||||
for (auto &ref : pairs) {
|
||||
bt_serialize<std::string>{}(os, ref.get().first);
|
||||
bt_serialize<second_type>{}(os, ref.get().second);
|
||||
}
|
||||
os << 'e';
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct bt_deserialize<T, std::enable_if_t<is_bt_output_dict_container<T>::value>> {
|
||||
using second_type = typename T::value_type::second_type;
|
||||
void operator()(string_view& s, T& dict) {
|
||||
// Smallest dict is 2 bytes "de", for an empty dict.
|
||||
if (s.size() < 2) throw bt_deserialize_invalid("Deserialization failed: end of string found where dict expected");
|
||||
if (s[0] != 'd') throw bt_deserialize_invalid_type("Deserialization failed: expected 'd', found '"s + s[0] + "'"s);
|
||||
s.remove_prefix(1);
|
||||
dict.clear();
|
||||
bt_deserialize<std::string> key_deserializer;
|
||||
bt_deserialize<second_type> val_deserializer;
|
||||
|
||||
while (!s.empty() && s[0] != 'e') {
|
||||
std::string key;
|
||||
second_type val;
|
||||
key_deserializer(s, key);
|
||||
val_deserializer(s, val);
|
||||
dict.insert(dict.end(), typename T::value_type{std::move(key), std::move(val)});
|
||||
}
|
||||
if (s.empty())
|
||||
throw bt_deserialize_invalid("Deserialization failed: encountered end of string before dict was finished");
|
||||
s.remove_prefix(1); // Consume the 'e'
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/// Accept anything that looks iterable; value serialization validity isn't checked here (it fails
|
||||
/// via the base case static assert).
|
||||
template <typename T, typename = void> struct is_bt_input_list_container : std::false_type {};
|
||||
template <typename T>
|
||||
struct is_bt_input_list_container<T, std::enable_if_t<
|
||||
!std::is_same<T, std::string>::value &&
|
||||
!is_bt_input_dict_container<T>::value,
|
||||
void_t<typename T::const_iterator, typename T::value_type>>>
|
||||
: std::true_type {};
|
||||
|
||||
template <typename T, typename = void> struct is_bt_output_list_container : std::false_type {};
|
||||
template <typename T>
|
||||
struct is_bt_output_list_container<T, std::enable_if_t<
|
||||
!std::is_same<T, std::string>::value &&
|
||||
!is_bt_output_dict_container<T>::value &&
|
||||
is_bt_insertable<T>::value>>
|
||||
: std::true_type {};
|
||||
|
||||
|
||||
/// List specialization
|
||||
template <typename T>
|
||||
struct bt_serialize<T, std::enable_if_t<is_bt_input_list_container<T>::value>> {
|
||||
void operator()(std::ostream& os, const T& list) {
|
||||
os << 'l';
|
||||
for (const auto &v : list)
|
||||
bt_serialize<std::remove_cv_t<typename T::value_type>>{}(os, v);
|
||||
os << 'e';
|
||||
}
|
||||
};
|
||||
template <typename T>
|
||||
struct bt_deserialize<T, std::enable_if_t<is_bt_output_list_container<T>::value>> {
|
||||
using value_type = typename T::value_type;
|
||||
void operator()(string_view& s, T& list) {
|
||||
// Smallest list is 2 bytes "le", for an empty list.
|
||||
if (s.size() < 2) throw bt_deserialize_invalid("Deserialization failed: end of string found where list expected");
|
||||
if (s[0] != 'l') throw bt_deserialize_invalid_type("Deserialization failed: expected 'l', found '"s + s[0] + "'"s);
|
||||
s.remove_prefix(1);
|
||||
list.clear();
|
||||
bt_deserialize<value_type> deserializer;
|
||||
while (!s.empty() && s[0] != 'e') {
|
||||
value_type v;
|
||||
deserializer(s, v);
|
||||
list.insert(list.end(), std::move(v));
|
||||
}
|
||||
if (s.empty())
|
||||
throw bt_deserialize_invalid("Deserialization failed: encountered end of string before list was finished");
|
||||
s.remove_prefix(1); // Consume the 'e'
|
||||
}
|
||||
};
|
||||
|
||||
/// variant visitor; serializes whatever is contained
|
||||
class bt_serialize_visitor {
|
||||
std::ostream &os;
|
||||
public:
|
||||
using result_type = void;
|
||||
bt_serialize_visitor(std::ostream &os) : os{os} {}
|
||||
template <typename T> void operator()(const T &val) const {
|
||||
bt_serialize<T>{}(os, val);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
using is_bt_deserializable = std::integral_constant<bool,
|
||||
std::is_same<T, std::string>::value || std::is_integral<T>::value ||
|
||||
is_bt_output_dict_container<T>::value || is_bt_output_list_container<T>::value>;
|
||||
|
||||
// General template and base case; this base will only actually be invoked when Ts... is empty,
|
||||
// which means we reached the end without finding any variant type capable of holding the value.
|
||||
template <typename SFINAE, typename Variant, typename... Ts>
|
||||
struct bt_deserialize_try_variant_impl {
|
||||
void operator()(string_view&, Variant&) {
|
||||
throw bt_deserialize_invalid("Deserialization failed: could not deserialize value into any variant type");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename... Ts, typename Variant>
|
||||
void bt_deserialize_try_variant(string_view& s, Variant& variant) {
|
||||
bt_deserialize_try_variant_impl<void, Variant, Ts...>{}(s, variant);
|
||||
}
|
||||
|
||||
|
||||
template <typename Variant, typename T, typename... Ts>
|
||||
struct bt_deserialize_try_variant_impl<std::enable_if_t<is_bt_deserializable<T>::value>, Variant, T, Ts...> {
|
||||
void operator()(string_view& s, Variant& variant) {
|
||||
if ( is_bt_output_list_container<T>::value ? s[0] == 'l' :
|
||||
is_bt_output_dict_container<T>::value ? s[0] == 'd' :
|
||||
std::is_integral<T>::value ? s[0] == 'i' :
|
||||
std::is_same<T, std::string>::value ? s[0] >= '0' && s[0] <= '9' :
|
||||
false) {
|
||||
T val;
|
||||
bt_deserialize<T>{}(s, val);
|
||||
variant = std::move(val);
|
||||
} else {
|
||||
bt_deserialize_try_variant<Ts...>(s, variant);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Variant, typename T, typename... Ts>
|
||||
struct bt_deserialize_try_variant_impl<std::enable_if_t<!is_bt_deserializable<T>::value>, Variant, T, Ts...> {
|
||||
void operator()(string_view& s, Variant& variant) {
|
||||
// Unsupported deserialization type, skip it
|
||||
bt_deserialize_try_variant<Ts...>(s, variant);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct bt_deserialize<bt_value, void> {
|
||||
void operator()(string_view& s, bt_value& val);
|
||||
};
|
||||
|
||||
template <typename... Ts>
|
||||
struct bt_serialize<mapbox::util::variant<Ts...>> {
|
||||
void operator()(std::ostream& os, const mapbox::util::variant<Ts...>& val) {
|
||||
mapbox::util::apply_visitor(bt_serialize_visitor{os}, val);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename... Ts>
|
||||
struct bt_deserialize<mapbox::util::variant<Ts...>> {
|
||||
void operator()(string_view& s, mapbox::util::variant<Ts...>& val) {
|
||||
bt_deserialize_try_variant<Ts...>(s, val);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
#ifdef __cpp_lib_variant
|
||||
/// C++17 std::variant support
|
||||
template <typename... Ts>
|
||||
struct bt_serialize<std::variant<Ts...>> {
|
||||
void operator()(std::ostream &os, const std::variant<Ts...>& val) {
|
||||
mapbox::util::apply_visitor(bt_serialize_visitor{os}, val);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename... Ts>
|
||||
struct bt_deserialize<std::variant<Ts...>> {
|
||||
void operator()(string_view& s, std::variant<Ts...>& val) {
|
||||
bt_deserialize_try_variant<Ts...>(s, val);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
struct bt_stream_serializer {
|
||||
const T &val;
|
||||
explicit bt_stream_serializer(const T &val) : val{val} {}
|
||||
operator std::string() const {
|
||||
std::ostringstream oss;
|
||||
oss << *this;
|
||||
return oss.str();
|
||||
}
|
||||
};
|
||||
template <typename T>
|
||||
std::ostream &operator<<(std::ostream &os, const bt_stream_serializer<T> &s) {
|
||||
bt_serialize<T>{}(os, s.val);
|
||||
return os;
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
|
||||
/// Returns a wrapper around a value reference that can serialize the value directly to an output
|
||||
/// stream. This class is intended to be used inline (i.e. without being stored) as in:
|
||||
///
|
||||
/// std::list<int> my_list{{1,2,3}};
|
||||
/// std::cout << bt_serializer(my_list);
|
||||
///
|
||||
/// While it is possible to store the returned object and use it, such as:
|
||||
///
|
||||
/// auto encoded = bt_serializer(42);
|
||||
/// std::cout << encoded;
|
||||
///
|
||||
/// this approach is not generally recommended: the returned object stores a reference to the
|
||||
/// passed-in type, which may not survive. If doing this note that it is the caller's
|
||||
/// responsibility to ensure the serializer is not used past the end of the lifetime of the value
|
||||
/// being serialized.
|
||||
///
|
||||
/// Also note that serializing directly to an output stream is more efficient as no intermediate
|
||||
/// string containing the entire serialization has to be constructed.
|
||||
///
|
||||
template <typename T>
|
||||
detail::bt_stream_serializer<T> bt_serializer(const T &val) { return detail::bt_stream_serializer<T>{val}; }
|
||||
|
||||
/// Serializes the given value into a std::string.
|
||||
///
|
||||
/// int number = 42;
|
||||
/// std::string encoded = bt_serialize(number);
|
||||
/// // Equivalent:
|
||||
/// //auto encoded = (std::string) bt_serialize(number);
|
||||
///
|
||||
/// This takes any serializable type: integral types, strings, lists of serializable types, and
|
||||
/// string->value maps of serializable types.
|
||||
template <typename T>
|
||||
std::string bt_serialize(const T &val) { return bt_serializer(val); }
|
||||
|
||||
/// Deserializes the given string view directly into `val`. Usage:
|
||||
///
|
||||
/// std::string encoded = "i42e";
|
||||
/// int value;
|
||||
/// bt_deserialize(encoded, value); // Sets value to 42
|
||||
///
|
||||
template <typename T, std::enable_if_t<!std::is_const<T>::value, int> = 0>
|
||||
void bt_deserialize(string_view s, T& val) {
|
||||
return detail::bt_deserialize<T>{}(s, val);
|
||||
}
|
||||
|
||||
|
||||
/// Deserializes the given string_view into a `T`, which is returned.
|
||||
///
|
||||
/// std::string encoded = "li1ei2ei3ee"; // bt-encoded list of ints: [1,2,3]
|
||||
/// auto mylist = bt_deserialize<std::list<int>>(encoded);
|
||||
///
|
||||
template <typename T>
|
||||
T bt_deserialize(string_view s) {
|
||||
T val;
|
||||
bt_deserialize(s, val);
|
||||
return val;
|
||||
}
|
||||
|
||||
/// Deserializes the given value into a generic `bt_value` type (mapbox::util::variant) which is capable
|
||||
/// of holding all possible BT-encoded values (including recursion).
|
||||
///
|
||||
/// Example:
|
||||
///
|
||||
/// std::string encoded = "i42e";
|
||||
/// auto val = bt_get(encoded);
|
||||
/// int v = get_int<int>(val); // fails unless the encoded value was actually an integer that
|
||||
/// // fits into an `int`
|
||||
///
|
||||
inline bt_value bt_get(string_view s) {
|
||||
return bt_deserialize<bt_value>(s);
|
||||
}
|
||||
|
||||
/// Helper functions to extract a value of some integral type from a bt_value which contains an
|
||||
/// integer. Does range checking, throwing std::overflow_error if the stored value is outside the
|
||||
/// range of the target type.
|
||||
///
|
||||
/// Example:
|
||||
///
|
||||
/// std::string encoded = "i123456789e";
|
||||
/// auto val = bt_get(encoded);
|
||||
/// auto v = get_int<uint32_t>(val); // throws if the decoded value doesn't fit in a uint32_t
|
||||
template <typename IntType, std::enable_if_t<std::is_integral<IntType>::value, int> = 0>
|
||||
IntType get_int(const bt_value &v) {
|
||||
// It's highly unlikely that this code ever runs on a non-2s-complement architecture, but check
|
||||
// at compile time if converting to a uint64_t (because while int64_t -> uint64_t is
|
||||
// well-defined, uint64_t -> int64_t only does the right thing under 2's complement).
|
||||
static_assert(!std::is_unsigned<IntType>::value || sizeof(IntType) != sizeof(int64_t) || -1 == ~0,
|
||||
"Non 2s-complement architecture not supported!");
|
||||
int64_t value = mapbox::util::get<int64_t>(v);
|
||||
if (sizeof(IntType) < sizeof(int64_t)) {
|
||||
if (value > static_cast<int64_t>(std::numeric_limits<IntType>::max())
|
||||
|| value < static_cast<int64_t>(std::numeric_limits<IntType>::min()))
|
||||
throw std::overflow_error("Unable to extract integer value: stored value is outside the range of the requested type");
|
||||
}
|
||||
return static_cast<IntType>(value);
|
||||
}
|
||||
|
||||
/// Class that allows you to walk through a bt-encoded list in memory without copying or allocating
|
||||
/// memory. It accesses existing memory directly and so the caller must ensure that the referenced
|
||||
/// memory stays valid for the lifetime of the bt_list_consumer object.
|
||||
class bt_list_consumer {
|
||||
protected:
|
||||
string_view data;
|
||||
bt_list_consumer() = default;
|
||||
public:
|
||||
bt_list_consumer(string_view data_);
|
||||
|
||||
/// Copy constructor. Making a copy copies the current position so can be used for multipass
|
||||
/// iteration through a list.
|
||||
bt_list_consumer(const bt_list_consumer&) = default;
|
||||
bt_list_consumer& operator=(const bt_list_consumer&) = default;
|
||||
|
||||
/// Returns true if the next value indicates the end of the list
|
||||
bool is_finished() const { return data.front() == 'e'; }
|
||||
/// Returns true if the next element looks like an encoded string
|
||||
bool is_string() const { return data.front() >= '0' && data.front() <= '9'; }
|
||||
/// Returns true if the next element looks like an encoded integer
|
||||
bool is_integer() const { return data.front() == 'i'; }
|
||||
/// Returns true if the next element looks like an encoded list
|
||||
bool is_list() const { return data.front() == 'l'; }
|
||||
/// Returns true if the next element looks like an encoded dict
|
||||
bool is_dict() const { return data.front() == 'd'; }
|
||||
|
||||
/// Attempt to parse the next value as a string (and advance just past it). Throws if the next
|
||||
/// value is not a string.
|
||||
std::string consume_string();
|
||||
string_view consume_string_view();
|
||||
|
||||
/// Attempts to parse the next value as an integer (and advance just past it). Throws if the
|
||||
/// next value is not an integer.
|
||||
template <typename IntType>
|
||||
IntType consume_integer() {
|
||||
if (!is_integer()) throw bt_deserialize_invalid_type{"next value is not an integer"};
|
||||
string_view next{data};
|
||||
IntType ret;
|
||||
detail::bt_deserialize<IntType>{}(next, ret);
|
||||
data = next;
|
||||
return ret;
|
||||
}
|
||||
|
||||
/// Consumes a list, return it as a list-like type. This typically requires dynamic allocation,
|
||||
/// but only has to parse the data once. Compare with consume_list_data() which allows
|
||||
/// alloc-free traversal, but requires parsing twice (if the contents are to be used).
|
||||
template <typename T = bt_list>
|
||||
T consume_list() {
|
||||
T list;
|
||||
consume_list(list);
|
||||
return list;
|
||||
}
|
||||
|
||||
/// Same as above, but takes a pre-existing list-like data type.
|
||||
template <typename T>
|
||||
void consume_list(T& list) {
|
||||
if (!is_list()) throw bt_deserialize_invalid_type{"next bt value is not a list"};
|
||||
string_view n{data};
|
||||
detail::bt_deserialize<T>{}(n, list);
|
||||
data = n;
|
||||
}
|
||||
|
||||
/// Consumes a dict, return it as a dict-like type. This typically requires dynamic allocation,
|
||||
/// but only has to parse the data once. Compare with consume_dict_data() which allows
|
||||
/// alloc-free traversal, but requires parsing twice (if the contents are to be used).
|
||||
template <typename T = bt_dict>
|
||||
T consume_dict() {
|
||||
T dict;
|
||||
consume_dict(dict);
|
||||
return dict;
|
||||
}
|
||||
|
||||
/// Same as above, but takes a pre-existing dict-like data type.
|
||||
template <typename T>
|
||||
void consume_dict(T& dict) {
|
||||
if (!is_dict()) throw bt_deserialize_invalid_type{"next bt value is not a dict"};
|
||||
string_view n{data};
|
||||
detail::bt_deserialize<T>{}(n, dict);
|
||||
data = n;
|
||||
}
|
||||
|
||||
/// Consumes a value without returning it.
|
||||
void skip_value();
|
||||
|
||||
/// Attempts to parse the next value as a list and returns the string_view that contains the
|
||||
/// entire thing. This is recursive into both lists and dicts and likely to be quite
|
||||
/// inefficient for large, nested structures (unless the values only need to be skipped but
|
||||
/// aren't separately needed). This, however, does not require dynamic memory allocation.
|
||||
string_view consume_list_data();
|
||||
|
||||
/// Attempts to parse the next value as a dict and returns the string_view that contains the
|
||||
/// entire thing. This is recursive into both lists and dicts and likely to be quite
|
||||
/// inefficient for large, nested structures (unless the values only need to be skipped but
|
||||
/// aren't separately needed). This, however, does not require dynamic memory allocation.
|
||||
string_view consume_dict_data();
|
||||
};
|
||||
|
||||
|
||||
/// Class that allows you to walk through key-value pairs of a bt-encoded dict in memory without
|
||||
/// copying or allocating memory. It accesses existing memory directly and so the caller must
|
||||
/// ensure that the referenced memory stays valid for the lifetime of the bt_dict_consumer object.
|
||||
class bt_dict_consumer : private bt_list_consumer {
|
||||
string_view key_;
|
||||
|
||||
/// Consume the key if not already consumed and there is a key present (rather than 'e').
|
||||
/// Throws exception if what should be a key isn't a string, or if the key consumes the entire
|
||||
/// data (i.e. requires that it be followed by something). Returns true if the key was consumed
|
||||
/// (either now or previously and cached).
|
||||
bool consume_key();
|
||||
|
||||
/// Clears the cached key and returns it. Must have already called consume_key directly or
|
||||
/// indirectly via one of the `is_{...}` methods.
|
||||
string_view flush_key() {
|
||||
string_view k;
|
||||
k.swap(key_);
|
||||
return k;
|
||||
}
|
||||
|
||||
public:
|
||||
bt_dict_consumer(string_view data_);
|
||||
|
||||
/// Copy constructor. Making a copy copies the current position so can be used for multipass
|
||||
/// iteration through a list.
|
||||
bt_dict_consumer(const bt_dict_consumer&) = default;
|
||||
bt_dict_consumer& operator=(const bt_dict_consumer&) = default;
|
||||
|
||||
/// Returns true if the next value indicates the end of the dict
|
||||
bool is_finished() { return !consume_key() && data.front() == 'e'; }
|
||||
/// Operator bool is an alias for `!is_finished()`
|
||||
operator bool() { return !is_finished(); }
|
||||
/// Returns true if the next value looks like an encoded string
|
||||
bool is_string() { return consume_key() && data.front() >= '0' && data.front() <= '9'; }
|
||||
/// Returns true if the next element looks like an encoded integer
|
||||
bool is_integer() { return consume_key() && data.front() == 'i'; }
|
||||
/// Returns true if the next element looks like an encoded list
|
||||
bool is_list() { return consume_key() && data.front() == 'l'; }
|
||||
/// Returns true if the next element looks like an encoded dict
|
||||
bool is_dict() { return consume_key() && data.front() == 'd'; }
|
||||
/// Returns the key of the next pair. This does not have to be called; it is also returned by
|
||||
/// all of the other consume_* methods. The value is cached whether called here or by some
|
||||
/// other method; accessing it multiple times simple accesses the cache until the next value is
|
||||
/// consumed.
|
||||
string_view key() {
|
||||
if (!consume_key())
|
||||
throw bt_deserialize_invalid{"Cannot access next key: at the end of the dict"};
|
||||
return key_;
|
||||
}
|
||||
|
||||
/// Attempt to parse the next value as a string->string pair (and advance just past it). Throws
|
||||
/// if the next value is not a string.
|
||||
std::pair<string_view, string_view> next_string();
|
||||
|
||||
/// Attempts to parse the next value as an string->integer pair (and advance just past it).
|
||||
/// Throws if the next value is not an integer.
|
||||
template <typename IntType>
|
||||
std::pair<string_view, IntType> next_integer() {
|
||||
if (!is_integer()) throw bt_deserialize_invalid_type{"next bt dict value is not an integer"};
|
||||
std::pair<string_view, IntType> ret;
|
||||
ret.second = bt_list_consumer::consume_integer<IntType>();
|
||||
ret.first = flush_key();
|
||||
return ret;
|
||||
}
|
||||
|
||||
/// Consumes a string->list pair, return it as a list-like type. This typically requires
|
||||
/// dynamic allocation, but only has to parse the data once. Compare with consume_list_data()
|
||||
/// which allows alloc-free traversal, but requires parsing twice (if the contents are to be
|
||||
/// used).
|
||||
template <typename T = bt_list>
|
||||
std::pair<string_view, T> next_list() {
|
||||
std::pair<string_view, T> pair;
|
||||
pair.first = consume_list(pair.second);
|
||||
return pair;
|
||||
}
|
||||
|
||||
/// Same as above, but takes a pre-existing list-like data type. Returns the key.
|
||||
template <typename T>
|
||||
string_view next_list(T& list) {
|
||||
if (!is_list()) throw bt_deserialize_invalid_type{"next bt value is not a list"};
|
||||
bt_list_consumer::consume_list(list);
|
||||
return flush_key();
|
||||
}
|
||||
|
||||
/// Consumes a string->dict pair, return it as a dict-like type. This typically requires
|
||||
/// dynamic allocation, but only has to parse the data once. Compare with consume_dict_data()
|
||||
/// which allows alloc-free traversal, but requires parsing twice (if the contents are to be
|
||||
/// used).
|
||||
template <typename T = bt_dict>
|
||||
std::pair<string_view, T> next_dict() {
|
||||
std::pair<string_view, T> pair;
|
||||
pair.first = consume_dict(pair.second);
|
||||
return pair;
|
||||
}
|
||||
|
||||
/// Same as above, but takes a pre-existing dict-like data type. Returns the key.
|
||||
template <typename T>
|
||||
string_view next_dict(T& dict) {
|
||||
if (!is_dict()) throw bt_deserialize_invalid_type{"next bt value is not a dict"};
|
||||
bt_list_consumer::consume_dict(dict);
|
||||
return flush_key();
|
||||
}
|
||||
|
||||
/// Attempts to parse the next value as a string->list pair and returns the string_view that
|
||||
/// contains the entire thing. This is recursive into both lists and dicts and likely to be
|
||||
/// quite inefficient for large, nested structures (unless the values only need to be skipped
|
||||
/// but aren't separately needed). This, however, does not require dynamic memory allocation.
|
||||
std::pair<string_view, string_view> next_list_data() {
|
||||
if (data.size() < 2 || !is_list()) throw bt_deserialize_invalid_type{"next bt dict value is not a list"};
|
||||
return {flush_key(), bt_list_consumer::consume_list_data()};
|
||||
}
|
||||
|
||||
/// Same as next_list_data(), but wraps the value in a bt_list_consumer for convenience
|
||||
std::pair<string_view, bt_list_consumer> next_list_consumer() { return next_list_data(); }
|
||||
|
||||
/// Attempts to parse the next value as a string->dict pair and returns the string_view that
|
||||
/// contains the entire thing. This is recursive into both lists and dicts and likely to be
|
||||
/// quite inefficient for large, nested structures (unless the values only need to be skipped
|
||||
/// but aren't separately needed). This, however, does not require dynamic memory allocation.
|
||||
std::pair<string_view, string_view> next_dict_data() {
|
||||
if (data.size() < 2 || !is_dict()) throw bt_deserialize_invalid_type{"next bt dict value is not a dict"};
|
||||
return {flush_key(), bt_list_consumer::consume_dict_data()};
|
||||
}
|
||||
|
||||
/// Same as next_dict_data(), but wraps the value in a bt_dict_consumer for convenience
|
||||
std::pair<string_view, bt_dict_consumer> next_dict_consumer() { return next_dict_data(); }
|
||||
|
||||
/// Skips ahead until we find the first key >= the given key or reach the end of the dict.
|
||||
/// Returns true if we found an exact match, false if we reached some greater value or the end.
|
||||
/// If we didn't hit the end, the next `consumer_*()` call will return the key-value pair we
|
||||
/// found (either the exact match or the first key greater than the requested key).
|
||||
///
|
||||
/// Two important notes:
|
||||
///
|
||||
/// - properly encoded bt dicts must have lexicographically sorted keys, and this method assumes
|
||||
/// that the input is correctly sorted (and thus if we find a greater value then your key does
|
||||
/// not exist).
|
||||
/// - this is irreversible; you cannot returned to skipped values without reparsing. (You *can*
|
||||
/// however, make a copy of the bt_dict_consumer before calling and use the copy to return to
|
||||
/// the pre-skipped position).
|
||||
bool skip_until(string_view find) {
|
||||
while (consume_key() && key_ < find) {
|
||||
flush_key();
|
||||
skip_value();
|
||||
}
|
||||
return key_ == find;
|
||||
}
|
||||
|
||||
/// The `consume_*` functions are wrappers around next_whatever that discard the returned key.
|
||||
///
|
||||
/// Intended for use with skip_until such as:
|
||||
///
|
||||
/// std::string value;
|
||||
/// if (d.skip_until("key"))
|
||||
/// value = d.consume_string();
|
||||
///
|
||||
|
||||
auto consume_string_view() { return next_string().second; }
|
||||
auto consume_string() { return std::string{consume_string_view()}; }
|
||||
|
||||
template <typename IntType>
|
||||
auto consume_integer() { return next_integer<IntType>().second; }
|
||||
|
||||
template <typename T = bt_list>
|
||||
auto consume_list() { return next_list<T>().second; }
|
||||
|
||||
template <typename T>
|
||||
void consume_list(T& list) { next_list(list); }
|
||||
|
||||
template <typename T = bt_dict>
|
||||
auto consume_dict() { return next_dict<T>().second; }
|
||||
|
||||
template <typename T>
|
||||
void consume_dict(T& dict) { next_dict(dict); }
|
||||
|
||||
string_view consume_list_data() { return next_list_data().second; }
|
||||
string_view consume_dict_data() { return next_dict_data().second; }
|
||||
|
||||
bt_list_consumer consume_list_consumer() { return consume_list_data(); }
|
||||
bt_dict_consumer consume_dict_consumer() { return consume_dict_data(); }
|
||||
};
|
||||
|
||||
|
||||
} // namespace lokimq
|
137
lokimq/hex.h
137
lokimq/hex.h
|
@ -1,137 +0,0 @@
|
|||
// Copyright (c) 2019-2020, The Loki Project
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without modification, are
|
||||
// permitted provided that the following conditions are met:
|
||||
//
|
||||
// 1. Redistributions of source code must retain the above copyright notice, this list of
|
||||
// conditions and the following disclaimer.
|
||||
//
|
||||
// 2. Redistributions in binary form must reproduce the above copyright notice, this list
|
||||
// of conditions and the following disclaimer in the documentation and/or other
|
||||
// materials provided with the distribution.
|
||||
//
|
||||
// 3. Neither the name of the copyright holder nor the names of its contributors may be
|
||||
// used to endorse or promote products derived from this software without specific
|
||||
// prior written permission.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
|
||||
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
|
||||
// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL
|
||||
// THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF
|
||||
// THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
#pragma once
|
||||
#include "string_view.h"
|
||||
#include <array>
|
||||
#include <iterator>
|
||||
#include <cassert>
|
||||
|
||||
namespace lokimq {
|
||||
|
||||
namespace detail {
|
||||
|
||||
/// Compile-time generated lookup tables hex conversion
|
||||
struct hex_table {
|
||||
char from_hex_lut[256];
|
||||
char to_hex_lut[16];
|
||||
constexpr hex_table() noexcept : from_hex_lut{}, to_hex_lut{} {
|
||||
for (unsigned char c = 0; c < 10; c++) {
|
||||
from_hex_lut[(unsigned char)('0' + c)] = 0 + c;
|
||||
to_hex_lut[ (unsigned char)( 0 + c)] = '0' + c;
|
||||
}
|
||||
for (unsigned char c = 0; c < 6; c++) {
|
||||
from_hex_lut[(unsigned char)('a' + c)] = 10 + c;
|
||||
from_hex_lut[(unsigned char)('A' + c)] = 10 + c;
|
||||
to_hex_lut[ (unsigned char)(10 + c)] = 'a' + c;
|
||||
}
|
||||
}
|
||||
constexpr char from_hex(unsigned char c) const noexcept { return from_hex_lut[c]; }
|
||||
constexpr char to_hex(unsigned char b) const noexcept { return to_hex_lut[b]; }
|
||||
} constexpr hex_lut;
|
||||
|
||||
} // namespace detail
|
||||
|
||||
/// Creates hex digits from a character sequence.
|
||||
template <typename InputIt, typename OutputIt>
|
||||
void to_hex(InputIt begin, InputIt end, OutputIt out) {
|
||||
for (; begin != end; ++begin) {
|
||||
auto c = *begin;
|
||||
*out++ = detail::hex_lut.to_hex((c & 0xf0) >> 4);
|
||||
*out++ = detail::hex_lut.to_hex(c & 0x0f);
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a hex string from an iterable, std::string-like object
|
||||
inline std::string to_hex(string_view s) {
|
||||
std::string hex;
|
||||
hex.reserve(s.size() * 2);
|
||||
to_hex(s.begin(), s.end(), std::back_inserter(hex));
|
||||
return hex;
|
||||
}
|
||||
|
||||
inline std::string to_hex(ustring_view s) {
|
||||
std::string hex;
|
||||
hex.reserve(s.size() * 2);
|
||||
to_hex(s.begin(), s.end(), std::back_inserter(hex));
|
||||
return hex;
|
||||
}
|
||||
|
||||
/// Returns true if all elements in the range are hex characters
|
||||
template <typename It>
|
||||
constexpr bool is_hex(It begin, It end) {
|
||||
for (; begin != end; ++begin) {
|
||||
if (detail::hex_lut.from_hex(*begin) == 0 && *begin != '0')
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Returns true if all elements in the string-like value are hex characters
|
||||
constexpr bool is_hex(string_view s) { return is_hex(s.begin(), s.end()); }
|
||||
constexpr bool is_hex(ustring_view s) { return is_hex(s.begin(), s.end()); }
|
||||
|
||||
/// Convert a hex digit into its numeric (0-15) value
|
||||
constexpr char from_hex_digit(unsigned char x) noexcept {
|
||||
return detail::hex_lut.from_hex(x);
|
||||
}
|
||||
|
||||
/// Constructs a byte value from a pair of hex digits
|
||||
constexpr char from_hex_pair(unsigned char a, unsigned char b) noexcept { return (from_hex_digit(a) << 4) | from_hex_digit(b); }
|
||||
|
||||
/// Converts a sequence of hex digits to bytes. Undefined behaviour if any characters are not in
|
||||
/// [0-9a-fA-F] or if the input sequence length is not even. It is permitted for the input and
|
||||
/// output ranges to overlap as long as out is no earlier than begin.
|
||||
template <typename InputIt, typename OutputIt>
|
||||
void from_hex(InputIt begin, InputIt end, OutputIt out) {
|
||||
using std::distance;
|
||||
assert(distance(begin, end) % 2 == 0);
|
||||
while (begin != end) {
|
||||
auto a = *begin++;
|
||||
auto b = *begin++;
|
||||
*out++ = from_hex_pair(a, b);
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts hex digits from a std::string-like object into a std::string of bytes. Undefined
|
||||
/// behaviour if any characters are not in [0-9a-fA-F] or if the input sequence length is not even.
|
||||
inline std::string from_hex(string_view s) {
|
||||
std::string bytes;
|
||||
bytes.reserve(s.size() / 2);
|
||||
from_hex(s.begin(), s.end(), std::back_inserter(bytes));
|
||||
return bytes;
|
||||
}
|
||||
|
||||
inline std::string from_hex(ustring_view s) {
|
||||
std::string bytes;
|
||||
bytes.reserve(s.size() / 2);
|
||||
from_hex(s.begin(), s.end(), std::back_inserter(bytes));
|
||||
return bytes;
|
||||
}
|
||||
|
||||
}
|
108
lokimq/jobs.cpp
108
lokimq/jobs.cpp
|
@ -1,108 +0,0 @@
|
|||
#include "lokimq.h"
|
||||
#include "batch.h"
|
||||
#include "lokimq-internal.h"
|
||||
|
||||
namespace lokimq {
|
||||
|
||||
void LokiMQ::proxy_batch(detail::Batch* batch) {
|
||||
batches.insert(batch);
|
||||
const int jobs = batch->size();
|
||||
for (int i = 0; i < jobs; i++)
|
||||
batch_jobs.emplace(batch, i);
|
||||
proxy_skip_one_poll = true;
|
||||
}
|
||||
|
||||
void LokiMQ::job(std::function<void()> f) {
|
||||
auto* b = new Batch<void>;
|
||||
b->add_job(std::move(f));
|
||||
auto* baseptr = static_cast<detail::Batch*>(b);
|
||||
detail::send_control(get_control_socket(), "BATCH", bt_serialize(reinterpret_cast<uintptr_t>(baseptr)));
|
||||
}
|
||||
|
||||
void LokiMQ::proxy_schedule_reply_job(std::function<void()> f) {
|
||||
auto* b = new Batch<void>;
|
||||
b->add_job(std::move(f));
|
||||
batches.insert(b);
|
||||
reply_jobs.emplace(static_cast<detail::Batch*>(b), 0);
|
||||
proxy_skip_one_poll = true;
|
||||
}
|
||||
|
||||
void LokiMQ::proxy_run_batch_jobs(std::queue<batch_job>& jobs, const int reserved, int& active, bool reply) {
|
||||
while (!jobs.empty() && static_cast<int>(workers.size()) < max_workers &&
|
||||
(active < reserved || active_workers() < general_workers)) {
|
||||
proxy_run_worker(get_idle_worker().load(std::move(jobs.front()), reply));
|
||||
jobs.pop();
|
||||
active++;
|
||||
}
|
||||
}
|
||||
|
||||
// Called either within the proxy thread, or before the proxy thread has been created; actually adds
|
||||
// the timer. If the timer object hasn't been set up yet it gets set up here.
|
||||
void LokiMQ::proxy_timer(std::function<void()> job, std::chrono::milliseconds interval, bool squelch) {
|
||||
if (!timers)
|
||||
timers.reset(zmq_timers_new());
|
||||
|
||||
int timer_id = zmq_timers_add(timers.get(),
|
||||
interval.count(),
|
||||
[](int timer_id, void* self) { static_cast<LokiMQ*>(self)->_queue_timer_job(timer_id); },
|
||||
this);
|
||||
if (timer_id == -1)
|
||||
throw zmq::error_t{};
|
||||
timer_jobs[timer_id] = std::make_tuple(std::move(job), squelch, false);
|
||||
}
|
||||
|
||||
void LokiMQ::proxy_timer(bt_list_consumer timer_data) {
|
||||
std::unique_ptr<std::function<void()>> func{reinterpret_cast<std::function<void()>*>(timer_data.consume_integer<uintptr_t>())};
|
||||
auto interval = std::chrono::milliseconds{timer_data.consume_integer<uint64_t>()};
|
||||
auto squelch = timer_data.consume_integer<bool>();
|
||||
if (!timer_data.is_finished())
|
||||
throw std::runtime_error("Internal error: proxied timer request contains unexpected data");
|
||||
proxy_timer(std::move(*func), interval, squelch);
|
||||
}
|
||||
|
||||
void LokiMQ::_queue_timer_job(int timer_id) {
|
||||
auto it = timer_jobs.find(timer_id);
|
||||
if (it == timer_jobs.end()) {
|
||||
LMQ_LOG(warn, "Could not find timer job ", timer_id);
|
||||
return;
|
||||
}
|
||||
auto& timer = it->second;
|
||||
auto& squelch = std::get<1>(timer);
|
||||
auto& running = std::get<2>(timer);
|
||||
if (squelch && running) {
|
||||
LMQ_LOG(debug, "Not running timer job ", timer_id, " because a job for that timer is still running");
|
||||
return;
|
||||
}
|
||||
|
||||
auto* b = new Batch<void>;
|
||||
b->add_job(std::get<0>(timer));
|
||||
if (squelch) {
|
||||
running = true;
|
||||
b->completion_proxy([this,timer_id](auto results) {
|
||||
try { results[0].get(); }
|
||||
catch (const std::exception &e) { LMQ_LOG(warn, "timer job ", timer_id, " raised an exception: ", e.what()); }
|
||||
catch (...) { LMQ_LOG(warn, "timer job ", timer_id, " raised a non-std exception"); }
|
||||
auto it = timer_jobs.find(timer_id);
|
||||
if (it != timer_jobs.end())
|
||||
std::get<2>(it->second)/*running*/ = false;
|
||||
});
|
||||
}
|
||||
batches.insert(b);
|
||||
batch_jobs.emplace(static_cast<detail::Batch*>(b), 0);
|
||||
assert(b->size() == 1);
|
||||
}
|
||||
|
||||
void LokiMQ::add_timer(std::function<void()> job, std::chrono::milliseconds interval, bool squelch) {
|
||||
if (proxy_thread.joinable()) {
|
||||
detail::send_control(get_control_socket(), "TIMER", bt_serialize(bt_list{{
|
||||
detail::serialize_object(std::move(job)),
|
||||
interval.count(),
|
||||
squelch}}));
|
||||
} else {
|
||||
proxy_timer(std::move(job), interval, squelch);
|
||||
}
|
||||
}
|
||||
|
||||
void LokiMQ::TimersDeleter::operator()(void* timers) { zmq_timers_destroy(&timers); }
|
||||
|
||||
}
|
|
@ -1,54 +0,0 @@
|
|||
#pragma once
|
||||
#include <vector>
|
||||
#include "connections.h"
|
||||
|
||||
namespace lokimq {
|
||||
|
||||
class LokiMQ;
|
||||
|
||||
/// Encapsulates an incoming message from a remote connection with message details plus extra
|
||||
/// info need to send a reply back through the proxy thread via the `reply()` method. Note that
|
||||
/// this object gets reused: callbacks should use but not store any reference beyond the callback.
|
||||
class Message {
|
||||
public:
|
||||
LokiMQ& lokimq; ///< The owning LokiMQ object
|
||||
std::vector<string_view> data; ///< The provided command data parts, if any.
|
||||
ConnectionID conn; ///< The connection info for routing a reply; also contains the pubkey/sn status.
|
||||
std::string reply_tag; ///< If the invoked command is a request command this is the required reply tag that will be prepended by `send_reply()`.
|
||||
|
||||
/// Constructor
|
||||
Message(LokiMQ& lmq, ConnectionID cid) : lokimq{lmq}, conn{std::move(cid)} {}
|
||||
|
||||
// Non-copyable
|
||||
Message(const Message&) = delete;
|
||||
Message& operator=(const Message&) = delete;
|
||||
|
||||
/// Sends a command back to whomever sent this message. Arguments are forwarded to send() but
|
||||
/// with send_option::optional{} added if the originator is not a SN. For SN messages (i.e.
|
||||
/// where `sn` is true) this is a "strong" reply by default in that the proxy will attempt to
|
||||
/// establish a new connection to the SN if no longer connected. For non-SN messages the reply
|
||||
/// will be attempted using the available routing information, but if the connection has already
|
||||
/// been closed the reply will be dropped.
|
||||
///
|
||||
/// If you want to send a non-strong reply even when the remote is a service node then add
|
||||
/// an explicit `send_option::optional()` argument.
|
||||
template <typename... Args>
|
||||
void send_back(string_view, Args&&... args);
|
||||
|
||||
/// Sends a reply to a request. This takes no command: the command is always the built-in
|
||||
/// "REPLY" command, followed by the unique reply tag, then any reply data parts. All other
|
||||
/// arguments are as in `send_back()`. You should only send one reply for a command expecting
|
||||
/// replies, though this is not enforced: attempting to send multiple replies will simply be
|
||||
/// dropped when received by the remote. (Note, however, that it is possible to send multiple
|
||||
/// messages -- e.g. you could send a reply and then also call send_back() and/or send_request()
|
||||
/// to send more requests back to the sender).
|
||||
template <typename... Args>
|
||||
void send_reply(Args&&... args);
|
||||
|
||||
/// Sends a request back to whomever sent this message. This is effectively a wrapper around
|
||||
/// lmq.request() that takes care of setting up the recipient arguments.
|
||||
template <typename ReplyCallback, typename... Args>
|
||||
void send_request(string_view cmd, ReplyCallback&& callback, Args&&... args);
|
||||
};
|
||||
|
||||
}
|
617
lokimq/proxy.cpp
617
lokimq/proxy.cpp
|
@ -1,617 +0,0 @@
|
|||
#include "lokimq.h"
|
||||
#include "lokimq-internal.h"
|
||||
#include "hex.h"
|
||||
|
||||
namespace lokimq {
|
||||
|
||||
void LokiMQ::proxy_quit() {
|
||||
LMQ_LOG(debug, "Received quit command, shutting down proxy thread");
|
||||
|
||||
assert(std::none_of(workers.begin(), workers.end(), [](auto& worker) { return worker.worker_thread.joinable(); }));
|
||||
|
||||
command.setsockopt<int>(ZMQ_LINGER, 0);
|
||||
command.close();
|
||||
{
|
||||
std::lock_guard<std::mutex> lock{control_sockets_mutex};
|
||||
for (auto &control : thread_control_sockets)
|
||||
control->close();
|
||||
proxy_shutting_down = true; // To prevent threads from opening new control sockets
|
||||
}
|
||||
workers_socket.close();
|
||||
int linger = std::chrono::milliseconds{CLOSE_LINGER}.count();
|
||||
for (auto& s : connections)
|
||||
s.setsockopt(ZMQ_LINGER, linger);
|
||||
connections.clear();
|
||||
peers.clear();
|
||||
|
||||
LMQ_LOG(debug, "Proxy thread teardown complete");
|
||||
}
|
||||
|
||||
void LokiMQ::proxy_send(bt_dict_consumer data) {
|
||||
// NB: bt_dict_consumer goes in alphabetical order
|
||||
string_view hint;
|
||||
std::chrono::milliseconds keep_alive{DEFAULT_SEND_KEEP_ALIVE};
|
||||
std::chrono::milliseconds request_timeout{DEFAULT_REQUEST_TIMEOUT};
|
||||
bool optional = false;
|
||||
bool outgoing = false;
|
||||
bool incoming = false;
|
||||
bool request = false;
|
||||
bool have_conn_id = false;
|
||||
ConnectionID conn_id;
|
||||
|
||||
std::string request_tag;
|
||||
ReplyCallback request_callback;
|
||||
if (data.skip_until("conn_id")) {
|
||||
conn_id.id = data.consume_integer<long long>();
|
||||
if (conn_id.id == -1)
|
||||
throw std::runtime_error("Invalid error: invalid conn_id value (-1)");
|
||||
have_conn_id = true;
|
||||
}
|
||||
if (data.skip_until("conn_pubkey")) {
|
||||
if (have_conn_id)
|
||||
throw std::runtime_error("Internal error: Invalid proxy send command; conn_id and conn_pubkey are exclusive");
|
||||
conn_id.pk = data.consume_string();
|
||||
conn_id.id = ConnectionID::SN_ID;
|
||||
} else if (!have_conn_id)
|
||||
throw std::runtime_error("Internal error: Invalid proxy send command; conn_pubkey or conn_id missing");
|
||||
if (data.skip_until("conn_route"))
|
||||
conn_id.route = data.consume_string();
|
||||
if (data.skip_until("hint"))
|
||||
hint = data.consume_string_view();
|
||||
if (data.skip_until("incoming"))
|
||||
incoming = data.consume_integer<bool>();
|
||||
if (data.skip_until("keep_alive"))
|
||||
keep_alive = std::chrono::milliseconds{data.consume_integer<uint64_t>()};
|
||||
if (data.skip_until("optional"))
|
||||
optional = data.consume_integer<bool>();
|
||||
if (data.skip_until("outgoing"))
|
||||
outgoing = data.consume_integer<bool>();
|
||||
|
||||
if (data.skip_until("request"))
|
||||
request = data.consume_integer<bool>();
|
||||
if (request) {
|
||||
if (!data.skip_until("request_callback"))
|
||||
throw std::runtime_error("Internal error: received request without request_callback");
|
||||
|
||||
request_callback = detail::deserialize_object<ReplyCallback>(data.consume_integer<uintptr_t>());
|
||||
|
||||
if (!data.skip_until("request_tag"))
|
||||
throw std::runtime_error("Internal error: received request without request_name");
|
||||
request_tag = data.consume_string();
|
||||
if (data.skip_until("request_timeout"))
|
||||
request_timeout = std::chrono::milliseconds{data.consume_integer<uint64_t>()};
|
||||
}
|
||||
if (!data.skip_until("send"))
|
||||
throw std::runtime_error("Internal error: Invalid proxy send command; send parts missing");
|
||||
bt_list_consumer send = data.consume_list_consumer();
|
||||
|
||||
send_option::queue_failure::callback_t callback_nosend;
|
||||
if (data.skip_until("send_fail"))
|
||||
callback_nosend = detail::deserialize_object<decltype(callback_nosend)>(data.consume_integer<uintptr_t>());
|
||||
|
||||
send_option::queue_full::callback_t callback_noqueue;
|
||||
if (data.skip_until("send_full_q"))
|
||||
callback_noqueue = detail::deserialize_object<decltype(callback_noqueue)>(data.consume_integer<uintptr_t>());
|
||||
|
||||
// Now figure out which socket to send to and do the actual sending. We can repeat this loop
|
||||
// multiple times, if we're sending to a SN, because it's possible that we have multiple
|
||||
// connections open to that SN (e.g. one out + one in) so if one fails we can clean up that
|
||||
// connection and try the next one.
|
||||
bool retry = true, sent = false, warned = false;
|
||||
std::unique_ptr<zmq::error_t> send_error;
|
||||
while (retry) {
|
||||
retry = false;
|
||||
zmq::socket_t *send_to;
|
||||
if (conn_id.sn()) {
|
||||
auto sock_route = proxy_connect_sn(conn_id.pk, hint, optional, incoming, outgoing, keep_alive);
|
||||
if (!sock_route.first) {
|
||||
if (optional)
|
||||
LMQ_LOG(debug, "Not sending: send is optional and no connection to ",
|
||||
to_hex(conn_id.pk), " is currently established");
|
||||
else
|
||||
LMQ_LOG(error, "Unable to send to ", to_hex(conn_id.pk), ": no valid connection address found");
|
||||
break;
|
||||
}
|
||||
send_to = sock_route.first;
|
||||
conn_id.route = std::move(sock_route.second);
|
||||
} else if (!conn_id.route.empty()) { // incoming non-SN connection
|
||||
auto it = incoming_conn_index.find(conn_id);
|
||||
if (it == incoming_conn_index.end()) {
|
||||
LMQ_LOG(warn, "Unable to send to ", conn_id, ": incoming listening socket not found");
|
||||
break;
|
||||
}
|
||||
send_to = &connections[it->second];
|
||||
} else {
|
||||
auto pr = peers.equal_range(conn_id);
|
||||
if (pr.first == peers.end()) {
|
||||
LMQ_LOG(warn, "Unable to send: connection id ", conn_id, " is not (or is no longer) a valid outgoing connection");
|
||||
break;
|
||||
}
|
||||
auto& peer = pr.first->second;
|
||||
send_to = &connections[peer.conn_index];
|
||||
}
|
||||
|
||||
try {
|
||||
sent = send_message_parts(*send_to, build_send_parts(send, conn_id.route));
|
||||
} catch (const zmq::error_t &e) {
|
||||
if (e.num() == EHOSTUNREACH && !conn_id.route.empty() /*= incoming conn*/) {
|
||||
|
||||
LMQ_LOG(debug, "Incoming connection is no longer valid; removing peer details");
|
||||
|
||||
auto pr = peers.equal_range(conn_id);
|
||||
if (pr.first != peers.end()) {
|
||||
if (!conn_id.sn()) {
|
||||
peers.erase(pr.first);
|
||||
} else {
|
||||
bool removed;
|
||||
for (auto it = pr.first; it != pr.second; ) {
|
||||
auto& peer = it->second;
|
||||
if (peer.route == conn_id.route) {
|
||||
peers.erase(it);
|
||||
removed = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
// The incoming connection to the SN is no longer good, but we can retry because
|
||||
// we may have another active connection with the SN (or may want to open one).
|
||||
if (removed) {
|
||||
LMQ_LOG(debug, "Retrying sending to SN ", to_hex(conn_id.pk), " using other sockets");
|
||||
retry = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!retry) {
|
||||
LMQ_LOG(warn, "Unable to send message to ", conn_id, ": ", e.what());
|
||||
warned = true;
|
||||
if (callback_nosend) {
|
||||
job([callback = std::move(callback_nosend), error = e] { callback(&error); });
|
||||
callback_nosend = nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (request) {
|
||||
if (sent) {
|
||||
LMQ_LOG(debug, "Added new pending request ", to_hex(request_tag));
|
||||
pending_requests.insert({ request_tag, {
|
||||
std::chrono::steady_clock::now() + request_timeout, std::move(request_callback) }});
|
||||
} else {
|
||||
LMQ_LOG(debug, "Could not send request, scheduling request callback failure");
|
||||
job([callback = std::move(request_callback)] { callback(false, {{"TIMEOUT"s}}); });
|
||||
}
|
||||
}
|
||||
if (!sent) {
|
||||
if (callback_nosend)
|
||||
job([callback = std::move(callback_nosend)] { callback(nullptr); });
|
||||
else if (callback_noqueue)
|
||||
job(std::move(callback_noqueue));
|
||||
else if (!warned)
|
||||
LMQ_LOG(warn, "Unable to send message to ", conn_id, ": sending would block");
|
||||
}
|
||||
}
|
||||
|
||||
void LokiMQ::proxy_reply(bt_dict_consumer data) {
|
||||
bool have_conn_id = false;
|
||||
ConnectionID conn_id{0};
|
||||
if (data.skip_until("conn_id")) {
|
||||
conn_id.id = data.consume_integer<long long>();
|
||||
if (conn_id.id == -1)
|
||||
throw std::runtime_error("Invalid error: invalid conn_id value (-1)");
|
||||
have_conn_id = true;
|
||||
}
|
||||
if (data.skip_until("conn_pubkey")) {
|
||||
if (have_conn_id)
|
||||
throw std::runtime_error("Internal error: Invalid proxy reply command; conn_id and conn_pubkey are exclusive");
|
||||
conn_id.pk = data.consume_string();
|
||||
conn_id.id = ConnectionID::SN_ID;
|
||||
} else if (!have_conn_id)
|
||||
throw std::runtime_error("Internal error: Invalid proxy reply command; conn_pubkey or conn_id missing");
|
||||
if (!data.skip_until("send"))
|
||||
throw std::runtime_error("Internal error: Invalid proxy reply command; send parts missing");
|
||||
|
||||
bt_list_consumer send = data.consume_list_consumer();
|
||||
|
||||
auto pr = peers.equal_range(conn_id);
|
||||
if (pr.first == pr.second) {
|
||||
LMQ_LOG(warn, "Unable to send tagged reply: the connection is no longer valid");
|
||||
return;
|
||||
}
|
||||
|
||||
// We try any connections until one works (for ordinary remotes there will be just one, but for
|
||||
// SNs there might be one incoming and one outgoing).
|
||||
for (auto it = pr.first; it != pr.second; ) {
|
||||
try {
|
||||
send_message_parts(connections[it->second.conn_index], build_send_parts(send, it->second.route));
|
||||
break;
|
||||
} catch (const zmq::error_t &err) {
|
||||
if (err.num() == EHOSTUNREACH) {
|
||||
LMQ_LOG(debug, "Unable to send reply to incoming non-SN request: remote is no longer connected; removing peer details");
|
||||
it = peers.erase(it);
|
||||
} else {
|
||||
LMQ_LOG(warn, "Unable to send reply to incoming non-SN request: ", err.what());
|
||||
++it;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void LokiMQ::proxy_control_message(std::vector<zmq::message_t>& parts) {
|
||||
// We throw an uncaught exception here because we only generate control messages internally in
|
||||
// lokimq code: if one of these condition fail it's a lokimq bug.
|
||||
if (parts.size() < 2)
|
||||
throw std::logic_error("LokiMQ bug: Expected 2-3 message parts for a proxy control message");
|
||||
auto route = view(parts[0]), cmd = view(parts[1]);
|
||||
LMQ_TRACE("control message: ", cmd);
|
||||
if (parts.size() == 3) {
|
||||
LMQ_TRACE("...: ", parts[2]);
|
||||
auto data = view(parts[2]);
|
||||
if (cmd == "SEND") {
|
||||
LMQ_TRACE("proxying message");
|
||||
return proxy_send(data);
|
||||
} else if (cmd == "REPLY") {
|
||||
LMQ_TRACE("proxying reply to non-SN incoming message");
|
||||
return proxy_reply(data);
|
||||
} else if (cmd == "BATCH") {
|
||||
LMQ_TRACE("proxy batch jobs");
|
||||
auto ptrval = bt_deserialize<uintptr_t>(data);
|
||||
return proxy_batch(reinterpret_cast<detail::Batch*>(ptrval));
|
||||
} else if (cmd == "SET_SNS") {
|
||||
return proxy_set_active_sns(data);
|
||||
} else if (cmd == "UPDATE_SNS") {
|
||||
return proxy_update_active_sns(data);
|
||||
} else if (cmd == "CONNECT_SN") {
|
||||
proxy_connect_sn(data);
|
||||
return;
|
||||
} else if (cmd == "CONNECT_REMOTE") {
|
||||
return proxy_connect_remote(data);
|
||||
} else if (cmd == "DISCONNECT") {
|
||||
return proxy_disconnect(data);
|
||||
} else if (cmd == "TIMER") {
|
||||
return proxy_timer(data);
|
||||
}
|
||||
} else if (parts.size() == 2) {
|
||||
if (cmd == "START") {
|
||||
// Command send by the owning thread during startup; we send back a simple READY reply to
|
||||
// let it know we are running.
|
||||
return route_control(command, route, "READY");
|
||||
} else if (cmd == "QUIT") {
|
||||
// Asked to quit: set max_workers to zero and tell any idle ones to quit. We will
|
||||
// close workers as they come back to READY status, and then close external
|
||||
// connections once all workers are done.
|
||||
max_workers = 0;
|
||||
for (const auto &route : idle_workers)
|
||||
route_control(workers_socket, workers[route].worker_routing_id, "QUIT");
|
||||
idle_workers.clear();
|
||||
return;
|
||||
}
|
||||
}
|
||||
throw std::runtime_error("LokiMQ bug: Proxy received invalid control command: " +
|
||||
std::string{cmd} + " (" + std::to_string(parts.size()) + ")");
|
||||
}
|
||||
|
||||
void LokiMQ::proxy_loop() {
|
||||
|
||||
zap_auth.setsockopt<int>(ZMQ_LINGER, 0);
|
||||
zap_auth.bind(ZMQ_ADDR_ZAP);
|
||||
|
||||
workers_socket.setsockopt<int>(ZMQ_ROUTER_MANDATORY, 1);
|
||||
workers_socket.bind(SN_ADDR_WORKERS);
|
||||
|
||||
assert(general_workers > 0);
|
||||
if (batch_jobs_reserved < 0)
|
||||
batch_jobs_reserved = (general_workers + 1) / 2;
|
||||
if (reply_jobs_reserved < 0)
|
||||
reply_jobs_reserved = (general_workers + 7) / 8;
|
||||
|
||||
max_workers = general_workers + batch_jobs_reserved + reply_jobs_reserved;
|
||||
for (const auto& cat : categories) {
|
||||
max_workers += cat.second.reserved_threads;
|
||||
}
|
||||
|
||||
if (log_level() >= LogLevel::debug) {
|
||||
LMQ_LOG(debug, "Reserving space for ", max_workers, " max workers = ", general_workers, " general plus reservations for:");
|
||||
for (const auto& cat : categories)
|
||||
LMQ_LOG(debug, " - ", cat.first, ": ", cat.second.reserved_threads);
|
||||
LMQ_LOG(debug, " - (batch jobs): ", batch_jobs_reserved);
|
||||
LMQ_LOG(debug, " - (reply jobs): ", reply_jobs_reserved);
|
||||
}
|
||||
|
||||
workers.reserve(max_workers);
|
||||
if (!workers.empty())
|
||||
throw std::logic_error("Internal error: proxy thread started with active worker threads");
|
||||
|
||||
for (size_t i = 0; i < bind.size(); i++) {
|
||||
auto& b = bind[i].second;
|
||||
zmq::socket_t listener{context, zmq::socket_type::router};
|
||||
|
||||
std::string auth_domain = bt_serialize(i);
|
||||
setup_external_socket(listener);
|
||||
listener.setsockopt(ZMQ_ZAP_DOMAIN, auth_domain.c_str(), auth_domain.size());
|
||||
if (b.curve) {
|
||||
listener.setsockopt<int>(ZMQ_CURVE_SERVER, 1);
|
||||
listener.setsockopt(ZMQ_CURVE_PUBLICKEY, pubkey.data(), pubkey.size());
|
||||
listener.setsockopt(ZMQ_CURVE_SECRETKEY, privkey.data(), privkey.size());
|
||||
}
|
||||
listener.setsockopt<int>(ZMQ_ROUTER_HANDOVER, 1);
|
||||
listener.setsockopt<int>(ZMQ_ROUTER_MANDATORY, 1);
|
||||
|
||||
listener.bind(bind[i].first);
|
||||
LMQ_LOG(info, "LokiMQ listening on ", bind[i].first);
|
||||
|
||||
connections.push_back(std::move(listener));
|
||||
auto conn_id = next_conn_id++;
|
||||
conn_index_to_id.push_back(conn_id);
|
||||
incoming_conn_index[conn_id] = connections.size() - 1;
|
||||
b.index = connections.size() - 1;
|
||||
}
|
||||
pollitems_stale = true;
|
||||
|
||||
// Also add an internal connection to self so that calling code can avoid needing to
|
||||
// special-case rare situations where we are supposed to talk to a quorum member that happens to
|
||||
// be ourselves (which can happen, for example, with cross-quoum Blink communication)
|
||||
// FIXME: not working
|
||||
//listener.bind(SN_ADDR_SELF);
|
||||
|
||||
if (!timers)
|
||||
timers.reset(zmq_timers_new());
|
||||
|
||||
auto do_conn_cleanup = [this] { proxy_conn_cleanup(); };
|
||||
using CleanupLambda = decltype(do_conn_cleanup);
|
||||
if (-1 == zmq_timers_add(timers.get(),
|
||||
std::chrono::milliseconds{CONN_CHECK_INTERVAL}.count(),
|
||||
// Wrap our lambda into a C function pointer where we pass in the lambda pointer as extra arg
|
||||
[](int /*timer_id*/, void* cleanup) { (*static_cast<CleanupLambda*>(cleanup))(); },
|
||||
&do_conn_cleanup)) {
|
||||
throw zmq::error_t{};
|
||||
}
|
||||
|
||||
std::vector<zmq::message_t> parts;
|
||||
|
||||
while (true) {
|
||||
std::chrono::milliseconds poll_timeout;
|
||||
if (max_workers == 0) { // Will be 0 only if we are quitting
|
||||
if (std::none_of(workers.begin(), workers.end(), [](auto &w) { return w.worker_thread.joinable(); })) {
|
||||
// All the workers have finished, so we can finish shutting down
|
||||
return proxy_quit();
|
||||
}
|
||||
poll_timeout = 1s; // We don't keep running timers when we're quitting, so don't have a timer to check
|
||||
} else {
|
||||
poll_timeout = std::chrono::milliseconds{zmq_timers_timeout(timers.get())};
|
||||
}
|
||||
|
||||
if (proxy_skip_one_poll)
|
||||
proxy_skip_one_poll = false;
|
||||
else {
|
||||
LMQ_TRACE("polling for new messages");
|
||||
|
||||
if (pollitems_stale)
|
||||
rebuild_pollitems();
|
||||
|
||||
// We poll the control socket and worker socket for any incoming messages. If we have
|
||||
// available worker room then also poll incoming connections and outgoing connections
|
||||
// for messages to forward to a worker. Otherwise, we just look for a control message
|
||||
// or a worker coming back with a ready message.
|
||||
zmq::poll(pollitems.data(), pollitems.size(), poll_timeout);
|
||||
}
|
||||
|
||||
LMQ_TRACE("processing control messages");
|
||||
// Retrieve any waiting incoming control messages
|
||||
for (parts.clear(); recv_message_parts(command, parts, zmq::recv_flags::dontwait); parts.clear()) {
|
||||
proxy_control_message(parts);
|
||||
}
|
||||
|
||||
LMQ_TRACE("processing worker messages");
|
||||
for (parts.clear(); recv_message_parts(workers_socket, parts, zmq::recv_flags::dontwait); parts.clear()) {
|
||||
proxy_worker_message(parts);
|
||||
}
|
||||
|
||||
LMQ_TRACE("processing timers");
|
||||
zmq_timers_execute(timers.get());
|
||||
|
||||
// Handle any zap authentication
|
||||
LMQ_TRACE("processing zap requests");
|
||||
process_zap_requests();
|
||||
|
||||
// See if we can drain anything from the current queue before we potentially add to it
|
||||
// below.
|
||||
LMQ_TRACE("processing queued jobs and messages");
|
||||
proxy_process_queue();
|
||||
|
||||
LMQ_TRACE("processing new incoming messages");
|
||||
|
||||
// We round-robin connections when pulling off pending messages one-by-one rather than
|
||||
// pulling off all messages from one connection before moving to the next; thus in cases of
|
||||
// contention we end up fairly distributing.
|
||||
const int num_sockets = connections.size();
|
||||
std::queue<int> queue_index;
|
||||
for (int i = 0; i < num_sockets; i++)
|
||||
queue_index.push(i);
|
||||
|
||||
for (parts.clear(); !queue_index.empty() && static_cast<int>(workers.size()) < max_workers; parts.clear()) {
|
||||
size_t i = queue_index.front();
|
||||
queue_index.pop();
|
||||
auto& sock = connections[i];
|
||||
|
||||
if (!recv_message_parts(sock, parts, zmq::recv_flags::dontwait))
|
||||
continue;
|
||||
|
||||
// We only pull this one message now but then requeue the socket so that after we check
|
||||
// all other sockets we come back to this one to check again.
|
||||
queue_index.push(i);
|
||||
|
||||
if (parts.empty()) {
|
||||
LMQ_LOG(warn, "Ignoring empty (0-part) incoming message");
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!proxy_handle_builtin(i, parts))
|
||||
proxy_to_worker(i, parts);
|
||||
|
||||
if (pollitems_stale) {
|
||||
// If our items became stale then we may have just closed a connection and so our
|
||||
// queue index maybe also be stale, so restart the proxy loop (so that we rebuild
|
||||
// pollitems).
|
||||
LMQ_TRACE("pollitems became stale; short-circuiting incoming message loop");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
LMQ_TRACE("done proxy loop");
|
||||
}
|
||||
}
|
||||
|
||||
static bool is_error_response(string_view cmd) {
|
||||
return cmd == "FORBIDDEN" || cmd == "FORBIDDEN_SN" || cmd == "NOT_A_SERVICE_NODE" || cmd == "UNKNOWNCOMMAND" || cmd == "NO_REPLY_TAG";
|
||||
}
|
||||
|
||||
// Return true if we recognized/handled the builtin command (even if we reject it for whatever
|
||||
// reason)
|
||||
bool LokiMQ::proxy_handle_builtin(size_t conn_index, std::vector<zmq::message_t>& parts) {
|
||||
// Doubling as a bool and an offset:
|
||||
size_t incoming = connections[conn_index].getsockopt<int>(ZMQ_TYPE) == ZMQ_ROUTER;
|
||||
|
||||
string_view route, cmd;
|
||||
if (parts.size() < 1 + incoming) {
|
||||
LMQ_LOG(warn, "Received empty message; ignoring");
|
||||
return true;
|
||||
}
|
||||
if (incoming) {
|
||||
route = view(parts[0]);
|
||||
cmd = view(parts[1]);
|
||||
} else {
|
||||
cmd = view(parts[0]);
|
||||
}
|
||||
LMQ_TRACE("Checking for builtins: '", cmd, "' from ", peer_address(parts.back()));
|
||||
|
||||
if (cmd == "REPLY") {
|
||||
size_t tag_pos = 1 + incoming;
|
||||
if (parts.size() <= tag_pos) {
|
||||
LMQ_LOG(warn, "Received REPLY without a reply tag; ignoring");
|
||||
return true;
|
||||
}
|
||||
std::string reply_tag{view(parts[tag_pos])};
|
||||
auto it = pending_requests.find(reply_tag);
|
||||
if (it != pending_requests.end()) {
|
||||
LMQ_LOG(debug, "Received REPLY for pending command ", to_hex(reply_tag), "; scheduling callback");
|
||||
std::vector<std::string> data;
|
||||
data.reserve(parts.size() - (tag_pos + 1));
|
||||
for (auto it = parts.begin() + (tag_pos + 1); it != parts.end(); ++it)
|
||||
data.emplace_back(view(*it));
|
||||
proxy_schedule_reply_job([callback=std::move(it->second.second), data=std::move(data)] {
|
||||
callback(true, std::move(data));
|
||||
});
|
||||
pending_requests.erase(it);
|
||||
} else {
|
||||
LMQ_LOG(warn, "Received REPLY with unknown or already handled reply tag (", to_hex(reply_tag), "); ignoring");
|
||||
}
|
||||
return true;
|
||||
} else if (cmd == "HI") {
|
||||
if (!incoming) {
|
||||
LMQ_LOG(warn, "Got invalid 'HI' message on an outgoing connection; ignoring");
|
||||
return true;
|
||||
}
|
||||
LMQ_LOG(debug, "Incoming client from ", peer_address(parts.back()), " sent HI, replying with HELLO");
|
||||
try {
|
||||
send_routed_message(connections[conn_index], std::string{route}, "HELLO");
|
||||
} catch (const std::exception &e) { LMQ_LOG(warn, "Couldn't reply with HELLO: ", e.what()); }
|
||||
return true;
|
||||
} else if (cmd == "HELLO") {
|
||||
if (incoming) {
|
||||
LMQ_LOG(warn, "Got invalid 'HELLO' message on an incoming connection; ignoring");
|
||||
return true;
|
||||
}
|
||||
auto it = std::find_if(pending_connects.begin(), pending_connects.end(),
|
||||
[&](auto& pc) { return std::get<size_t>(pc) == conn_index; });
|
||||
if (it == pending_connects.end()) {
|
||||
LMQ_LOG(warn, "Got invalid 'HELLO' message on an already handshaked incoming connection; ignoring");
|
||||
return true;
|
||||
}
|
||||
auto& pc = *it;
|
||||
auto pit = peers.find(std::get<long long>(pc));
|
||||
if (pit == peers.end()) {
|
||||
LMQ_LOG(warn, "Got invalid 'HELLO' message with invalid conn_id; ignoring");
|
||||
return true;
|
||||
}
|
||||
|
||||
LMQ_LOG(debug, "Got initial HELLO server response from ", peer_address(parts.back()));
|
||||
proxy_schedule_reply_job([on_success=std::move(std::get<ConnectSuccess>(pc)),
|
||||
conn=conn_index_to_id[conn_index]] {
|
||||
on_success(conn);
|
||||
});
|
||||
pending_connects.erase(it);
|
||||
return true;
|
||||
} else if (cmd == "BYE") {
|
||||
if (!incoming) {
|
||||
LMQ_LOG(debug, "BYE command received; disconnecting from ", peer_address(parts.back()));
|
||||
proxy_close_connection(conn_index, 0s);
|
||||
} else {
|
||||
LMQ_LOG(warn, "Got invalid 'BYE' command on an incoming socket; ignoring");
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
else if (is_error_response(cmd)) {
|
||||
// These messages (FORBIDDEN, UNKNOWNCOMMAND, etc.) are sent in response to us trying to
|
||||
// invoke something that doesn't exist or we don't have permission to access. These have
|
||||
// two forms (the latter is only sent by remotes running 1.1.0+).
|
||||
// - ["XXX", "whatever.command"]
|
||||
// - ["XXX", "REPLY", replytag]
|
||||
// (ignoring the routing prefix on incoming commands).
|
||||
// For the former, we log; for the latter we trigger the reply callback with a failure
|
||||
|
||||
if (parts.size() == (1 + incoming) && cmd == "UNKNOWNCOMMAND") {
|
||||
// pre-1.1.0 sent just a plain UNKNOWNCOMMAND (without the actual command); this was not
|
||||
// useful, but also this response is *expected* for things 1.0.5 didn't understand, like
|
||||
// FORBIDDEN_SN: so log it only at debug level and move on.
|
||||
LMQ_LOG(debug, "Received plain UNKNOWNCOMMAND; remote is probably an older lokimq. Ignoring.");
|
||||
return true;
|
||||
}
|
||||
|
||||
if (parts.size() == (3 + incoming) && view(parts[1 + incoming]) == "REPLY") {
|
||||
std::string reply_tag{view(parts[2 + incoming])};
|
||||
auto it = pending_requests.find(reply_tag);
|
||||
if (it != pending_requests.end()) {
|
||||
LMQ_LOG(debug, "Received ", cmd, " REPLY for pending command ", to_hex(reply_tag), "; scheduling failure callback");
|
||||
proxy_schedule_reply_job([callback=std::move(it->second.second), cmd=std::string{cmd}] {
|
||||
callback(false, {{std::move(cmd)}});
|
||||
});
|
||||
pending_requests.erase(it);
|
||||
} else {
|
||||
LMQ_LOG(warn, "Received REPLY with unknown or already handled reply tag (", to_hex(reply_tag), "); ignoring");
|
||||
}
|
||||
} else {
|
||||
LMQ_LOG(warn, "Received ", cmd, ':', (parts.size() > 1 + incoming ? view(parts[1 + incoming]) : "(unknown command)"_sv),
|
||||
" from ", peer_address(parts.back()));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void LokiMQ::proxy_process_queue() {
|
||||
// First up: process any batch jobs; since these are internal they are given higher priority.
|
||||
proxy_run_batch_jobs(batch_jobs, batch_jobs_reserved, batch_jobs_active, false);
|
||||
|
||||
// Next any reply batch jobs (which are a bit different from the above, since they are
|
||||
// externally triggered but for things we initiated locally).
|
||||
proxy_run_batch_jobs(reply_jobs, reply_jobs_reserved, reply_jobs_active, true);
|
||||
|
||||
// Finally general incoming commands
|
||||
for (auto it = pending_commands.begin(); it != pending_commands.end() && active_workers() < max_workers; ) {
|
||||
auto& pending = *it;
|
||||
if (pending.cat.active_threads < pending.cat.reserved_threads
|
||||
|| active_workers() < general_workers) {
|
||||
proxy_run_worker(get_idle_worker().load(std::move(pending)));
|
||||
pending.cat.queued--;
|
||||
pending.cat.active_threads++;
|
||||
assert(pending.cat.queued >= 0);
|
||||
it = pending_commands.erase(it);
|
||||
} else {
|
||||
++it; // no available general or reserved worker spots for this job right now
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
|
@ -1,310 +0,0 @@
|
|||
// Copyright (c) 2019-2020, The Loki Project
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without modification, are
|
||||
// permitted provided that the following conditions are met:
|
||||
//
|
||||
// 1. Redistributions of source code must retain the above copyright notice, this list of
|
||||
// conditions and the following disclaimer.
|
||||
//
|
||||
// 2. Redistributions in binary form must reproduce the above copyright notice, this list
|
||||
// of conditions and the following disclaimer in the documentation and/or other
|
||||
// materials provided with the distribution.
|
||||
//
|
||||
// 3. Neither the name of the copyright holder nor the names of its contributors may be
|
||||
// used to endorse or promote products derived from this software without specific
|
||||
// prior written permission.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
|
||||
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
|
||||
// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL
|
||||
// THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF
|
||||
// THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
#ifdef __cpp_lib_string_view
|
||||
|
||||
#include <string_view>
|
||||
namespace lokimq {
|
||||
using string_view = std::string_view;
|
||||
using ustring_view = std::basic_string_view<unsigned char>;
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
#include <ostream>
|
||||
#include <limits>
|
||||
|
||||
namespace lokimq {
|
||||
|
||||
/// Basic implementation of std::string_view (except for std::hash support).
|
||||
template <typename CharT>
|
||||
class simple_string_view {
|
||||
const CharT *data_;
|
||||
size_t size_;
|
||||
public:
|
||||
using traits_type = std::char_traits<CharT>;
|
||||
using value_type = CharT;
|
||||
using pointer = CharT*;
|
||||
using const_pointer = const CharT*;
|
||||
using reference = CharT&;
|
||||
using const_reference = const CharT&;
|
||||
using const_iterator = const_pointer;
|
||||
using iterator = const_iterator;
|
||||
using const_reverse_iterator = std::reverse_iterator<const_iterator>;
|
||||
using reverse_iterator = const_reverse_iterator;
|
||||
using size_type = std::size_t;
|
||||
using different_type = std::ptrdiff_t;
|
||||
|
||||
static constexpr auto& npos = std::string::npos;
|
||||
|
||||
constexpr simple_string_view() noexcept : data_{nullptr}, size_{0} {}
|
||||
constexpr simple_string_view(const simple_string_view&) noexcept = default;
|
||||
simple_string_view(const std::basic_string<CharT>& str) : data_{str.data()}, size_{str.size()} {}
|
||||
constexpr simple_string_view(const CharT* data, size_t size) noexcept : data_{data}, size_{size} {}
|
||||
simple_string_view(const CharT* data) : data_{data}, size_{traits_type::length(data)} {}
|
||||
simple_string_view& operator=(const simple_string_view&) = default;
|
||||
constexpr const CharT* data() const noexcept { return data_; }
|
||||
constexpr size_t size() const noexcept { return size_; }
|
||||
constexpr size_t length() const noexcept { return size_; }
|
||||
constexpr size_t max_size() const noexcept { return std::numeric_limits<size_t>::max(); }
|
||||
constexpr bool empty() const noexcept { return size_ == 0; }
|
||||
explicit operator std::basic_string<CharT>() const { return {data_, size_}; }
|
||||
constexpr const CharT* begin() const noexcept { return data_; }
|
||||
constexpr const CharT* cbegin() const noexcept { return data_; }
|
||||
constexpr const CharT* end() const noexcept { return data_ + size_; }
|
||||
constexpr const CharT* cend() const noexcept { return data_ + size_; }
|
||||
reverse_iterator rbegin() const { return reverse_iterator{end()}; }
|
||||
reverse_iterator crbegin() const { return reverse_iterator{end()}; }
|
||||
reverse_iterator rend() const { return reverse_iterator{begin()}; }
|
||||
reverse_iterator crend() const { return reverse_iterator{begin()}; }
|
||||
constexpr const CharT& operator[](size_t pos) const { return data_[pos]; }
|
||||
constexpr const CharT& front() const { return *data_; }
|
||||
constexpr const CharT& back() const { return data_[size_ - 1]; }
|
||||
int compare(simple_string_view s) const;
|
||||
constexpr void remove_prefix(size_t n) { data_ += n; size_ -= n; }
|
||||
constexpr void remove_suffix(size_t n) { size_ -= n; }
|
||||
void swap(simple_string_view &s) noexcept { std::swap(data_, s.data_); std::swap(size_, s.size_); }
|
||||
|
||||
#if defined(__clang__) || !defined(__GNUG__) || __GNUC__ >= 6
|
||||
constexpr // GCC 5.x is buggy wrt constexpr throwing
|
||||
#endif
|
||||
const CharT& at(size_t pos) const {
|
||||
if (pos >= size())
|
||||
throw std::out_of_range{"invalid string_view index"};
|
||||
return data_[pos];
|
||||
};
|
||||
|
||||
size_t copy(CharT* dest, size_t count, size_t pos = 0) const {
|
||||
if (pos > size()) throw std::out_of_range{"invalid copy pos"};
|
||||
size_t rcount = std::min(count, size_ - pos);
|
||||
traits_type::copy(dest, data_ + pos, rcount);
|
||||
return rcount;
|
||||
}
|
||||
|
||||
#if defined(__clang__) || !defined(__GNUG__) || __GNUC__ >= 6
|
||||
constexpr // GCC 5.x is buggy wrt constexpr throwing
|
||||
#endif
|
||||
simple_string_view substr(size_t pos = 0, size_t count = npos) const {
|
||||
if (pos > size()) throw std::out_of_range{"invalid substr range"};
|
||||
simple_string_view result = *this;
|
||||
if (pos > 0) result.remove_prefix(pos);
|
||||
if (count < result.size()) result.remove_suffix(result.size() - count);
|
||||
return result;
|
||||
}
|
||||
|
||||
size_t find(simple_string_view v, size_t pos = 0) const {
|
||||
if (pos > size_ || v.size_ > size_) return npos;
|
||||
for (const size_t max_pos = size_ - v.size_; pos <= max_pos; ++pos) {
|
||||
if (0 == traits_type::compare(v.data_, data_ + pos, v.size_))
|
||||
return pos;
|
||||
}
|
||||
return npos;
|
||||
}
|
||||
size_t find(CharT c, size_t pos = 0) const { return find({&c, 1}, pos); }
|
||||
size_t find(const CharT* c, size_t pos, size_t count) const { return find({c, count}, pos); }
|
||||
size_t find(const CharT* c, size_t pos = 0) const { return find(simple_string_view(c), pos); }
|
||||
|
||||
size_t rfind(simple_string_view v, size_t pos = npos) const {
|
||||
if (v.size_ > size_) return npos;
|
||||
const size_t max_pos = size_ - v.size_;
|
||||
for (pos = std::min(pos, max_pos); pos <= max_pos; --pos) {
|
||||
if (0 == traits_type::compare(v.data_, data_ + pos, v.size_))
|
||||
return pos;
|
||||
}
|
||||
return npos;
|
||||
}
|
||||
size_t rfind(CharT c, size_t pos = npos) const { return rfind({&c, 1}, pos); }
|
||||
size_t rfind(const CharT* c, size_t pos, size_t count) const { return rfind({c, count}, pos); }
|
||||
size_t rfind(const CharT* c, size_t pos = npos) const { return rfind(simple_string_view(c), pos); }
|
||||
|
||||
constexpr size_t find_first_of(simple_string_view v, size_t pos = 0) const noexcept {
|
||||
for (; pos < size_; ++pos)
|
||||
for (CharT c : v)
|
||||
if (data_[pos] == c)
|
||||
return pos;
|
||||
return npos;
|
||||
}
|
||||
constexpr size_t find_first_of(CharT c, size_t pos = 0) const noexcept { return find_first_of({&c, 1}, pos); }
|
||||
constexpr size_t find_first_of(const CharT* c, size_t pos, size_t count) const { return find_first_of({c, count}, pos); }
|
||||
size_t find_first_of(const CharT* c, size_t pos = 0) const { return find_first_of(simple_string_view(c), pos); }
|
||||
|
||||
constexpr size_t find_last_of(simple_string_view v, const size_t pos = npos) const noexcept {
|
||||
if (size_ == 0) return npos;
|
||||
const size_t last_pos = std::min(pos, size_-1);
|
||||
for (size_t i = last_pos; i <= last_pos; --i)
|
||||
for (CharT c : v)
|
||||
if (data_[i] == c)
|
||||
return i;
|
||||
return npos;
|
||||
}
|
||||
constexpr size_t find_last_of(CharT c, size_t pos = npos) const noexcept { return find_last_of({&c, 1}, pos); }
|
||||
constexpr size_t find_last_of(const CharT* c, size_t pos, size_t count) const { return find_last_of({c, count}, pos); }
|
||||
size_t find_last_of(const CharT* c, size_t pos = npos) const { return find_last_of(simple_string_view(c), pos); }
|
||||
|
||||
constexpr size_t find_first_not_of(simple_string_view v, size_t pos = 0) const noexcept {
|
||||
for (; pos < size_; ++pos) {
|
||||
bool none = true;
|
||||
for (CharT c : v) {
|
||||
if (data_[pos] == c) {
|
||||
none = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (none) return pos;
|
||||
}
|
||||
return npos;
|
||||
}
|
||||
constexpr size_t find_first_not_of(CharT c, size_t pos = 0) const noexcept { return find_first_not_of({&c, 1}, pos); }
|
||||
constexpr size_t find_first_not_of(const CharT* c, size_t pos, size_t count) const { return find_first_not_of({c, count}, pos); }
|
||||
size_t find_first_not_of(const CharT* c, size_t pos = 0) const { return find_first_not_of(simple_string_view(c), pos); }
|
||||
|
||||
constexpr size_t find_last_not_of(simple_string_view v, const size_t pos = npos) const noexcept {
|
||||
if (size_ == 0) return npos;
|
||||
const size_t last_pos = std::min(pos, size_-1);
|
||||
for (size_t i = last_pos; i <= last_pos; --i) {
|
||||
bool none = true;
|
||||
for (CharT c : v) {
|
||||
if (data_[i] == c) {
|
||||
none = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (none) return i;
|
||||
}
|
||||
return npos;
|
||||
}
|
||||
constexpr size_t find_last_not_of(CharT c, size_t pos = npos) const noexcept { return find_last_not_of({&c, 1}, pos); }
|
||||
constexpr size_t find_last_not_of(const CharT* c, size_t pos, size_t count) const { return find_last_not_of({c, count}, pos); }
|
||||
size_t find_last_not_of(const CharT* c, size_t pos = npos) const { return find_last_not_of(simple_string_view(c), pos); }
|
||||
};
|
||||
/// We have three of each of these: one with two string views, one with RHS argument deduction, and
|
||||
/// one with LHS argument deduction, so that you can do (sv == sv), (sv == "foo"), and ("foo" == sv)
|
||||
template <typename CharT>
|
||||
inline bool operator==(simple_string_view<CharT> lhs, simple_string_view<CharT> rhs) {
|
||||
return lhs.size() == rhs.size() && 0 == std::char_traits<CharT>::compare(lhs.data(), rhs.data(), lhs.size());
|
||||
};
|
||||
template <typename CharT>
|
||||
inline bool operator==(simple_string_view<CharT> lhs, std::common_type_t<simple_string_view<CharT>> rhs) {
|
||||
return lhs.size() == rhs.size() && 0 == std::char_traits<CharT>::compare(lhs.data(), rhs.data(), lhs.size());
|
||||
};
|
||||
template <typename CharT>
|
||||
inline bool operator==(std::common_type_t<simple_string_view<CharT>> lhs, simple_string_view<CharT> rhs) {
|
||||
return lhs.size() == rhs.size() && 0 == std::char_traits<CharT>::compare(lhs.data(), rhs.data(), lhs.size());
|
||||
};
|
||||
template <typename CharT>
|
||||
inline bool operator!=(simple_string_view<CharT> lhs, simple_string_view<CharT> rhs) {
|
||||
return !(lhs == rhs);
|
||||
}
|
||||
template <typename CharT>
|
||||
inline bool operator!=(simple_string_view<CharT> lhs, std::common_type_t<simple_string_view<CharT>> rhs) {
|
||||
return !(lhs == rhs);
|
||||
}
|
||||
template <typename CharT>
|
||||
inline bool operator!=(std::common_type_t<simple_string_view<CharT>> lhs, simple_string_view<CharT> rhs) {
|
||||
return !(lhs == rhs);
|
||||
}
|
||||
template <typename CharT>
|
||||
inline int simple_string_view<CharT>::compare(simple_string_view s) const {
|
||||
int cmp = std::char_traits<CharT>::compare(data_, s.data(), std::min(size_, s.size()));
|
||||
if (cmp) return cmp;
|
||||
if (size_ < s.size()) return -1;
|
||||
else if (size_ > s.size()) return 1;
|
||||
return 0;
|
||||
}
|
||||
template <typename CharT>
|
||||
inline bool operator<(simple_string_view<CharT> lhs, simple_string_view<CharT> rhs) {
|
||||
return lhs.compare(rhs) < 0;
|
||||
};
|
||||
template <typename CharT>
|
||||
inline bool operator<(simple_string_view<CharT> lhs, std::common_type_t<simple_string_view<CharT>> rhs) {
|
||||
return lhs.compare(rhs) < 0;
|
||||
};
|
||||
template <typename CharT>
|
||||
inline bool operator<(std::common_type_t<simple_string_view<CharT>> lhs, simple_string_view<CharT> rhs) {
|
||||
return lhs.compare(rhs) < 0;
|
||||
};
|
||||
template <typename CharT>
|
||||
inline bool operator<=(simple_string_view<CharT> lhs, simple_string_view<CharT> rhs) {
|
||||
return lhs.compare(rhs) <= 0;
|
||||
};
|
||||
template <typename CharT>
|
||||
inline bool operator<=(simple_string_view<CharT> lhs, std::common_type_t<simple_string_view<CharT>> rhs) {
|
||||
return lhs.compare(rhs) <= 0;
|
||||
};
|
||||
template <typename CharT>
|
||||
inline bool operator<=(std::common_type_t<simple_string_view<CharT>> lhs, simple_string_view<CharT> rhs) {
|
||||
return lhs.compare(rhs) <= 0;
|
||||
};
|
||||
template <typename CharT>
|
||||
inline bool operator>(simple_string_view<CharT> lhs, simple_string_view<CharT> rhs) {
|
||||
return lhs.compare(rhs) > 0;
|
||||
};
|
||||
template <typename CharT>
|
||||
inline bool operator>(simple_string_view<CharT> lhs, std::common_type_t<simple_string_view<CharT>> rhs) {
|
||||
return lhs.compare(rhs) > 0;
|
||||
};
|
||||
template <typename CharT>
|
||||
inline bool operator>(std::common_type_t<simple_string_view<CharT>> lhs, simple_string_view<CharT> rhs) {
|
||||
return lhs.compare(rhs) > 0;
|
||||
};
|
||||
template <typename CharT>
|
||||
inline bool operator>=(simple_string_view<CharT> lhs, simple_string_view<CharT> rhs) {
|
||||
return lhs.compare(rhs) >= 0;
|
||||
};
|
||||
template <typename CharT>
|
||||
inline bool operator>=(simple_string_view<CharT> lhs, std::common_type_t<simple_string_view<CharT>> rhs) {
|
||||
return lhs.compare(rhs) >= 0;
|
||||
};
|
||||
template <typename CharT>
|
||||
inline bool operator>=(std::common_type_t<simple_string_view<CharT>> lhs, simple_string_view<CharT> rhs) {
|
||||
return lhs.compare(rhs) >= 0;
|
||||
};
|
||||
template <typename CharT>
|
||||
inline std::basic_ostream<CharT>& operator<<(std::basic_ostream<CharT>& os, const simple_string_view<CharT>& s) {
|
||||
os.write(s.data(), s.size());
|
||||
return os;
|
||||
}
|
||||
|
||||
using string_view = simple_string_view<char>;
|
||||
using ustring_view = simple_string_view<unsigned char>;
|
||||
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
// Add a "foo"_sv literal that works exactly like the C++17 "foo"sv literal, but works with our
|
||||
// implementation in pre-C++17.
|
||||
namespace lokimq {
|
||||
inline namespace literals {
|
||||
inline string_view operator""_sv(const char* str, size_t len) { return {str, len}; }
|
||||
}
|
||||
}
|
|
@ -1,5 +0,0 @@
|
|||
namespace lokimq {
|
||||
constexpr int VERSION_MAJOR = @LOKIMQ_VERSION_MAJOR@;
|
||||
constexpr int VERSION_MINOR = @LOKIMQ_VERSION_MINOR@;
|
||||
constexpr int VERSION_PATCH = @LOKIMQ_VERSION_PATCH@;
|
||||
}
|
|
@ -1,289 +0,0 @@
|
|||
#include "lokimq.h"
|
||||
#include "batch.h"
|
||||
#include "hex.h"
|
||||
#include "lokimq-internal.h"
|
||||
|
||||
namespace lokimq {
|
||||
|
||||
void LokiMQ::worker_thread(unsigned int index) {
|
||||
std::string worker_id = "w" + std::to_string(index);
|
||||
zmq::socket_t sock{context, zmq::socket_type::dealer};
|
||||
sock.setsockopt(ZMQ_ROUTING_ID, worker_id.data(), worker_id.size());
|
||||
LMQ_LOG(debug, "New worker thread ", worker_id, " started");
|
||||
sock.connect(SN_ADDR_WORKERS);
|
||||
|
||||
Message message{*this, 0};
|
||||
std::vector<zmq::message_t> parts;
|
||||
run_info& run = workers[index]; // This contains our first job, and will be updated later with subsequent jobs
|
||||
|
||||
while (true) {
|
||||
try {
|
||||
if (run.is_batch_job) {
|
||||
if (run.batch_jobno >= 0) {
|
||||
LMQ_TRACE("worker thread ", worker_id, " running batch ", run.batch, "#", run.batch_jobno);
|
||||
run.batch->run_job(run.batch_jobno);
|
||||
} else if (run.batch_jobno == -1) {
|
||||
LMQ_TRACE("worker thread ", worker_id, " running batch ", run.batch, " completion");
|
||||
run.batch->job_completion();
|
||||
}
|
||||
} else {
|
||||
message.conn = run.conn;
|
||||
message.data.clear();
|
||||
|
||||
LMQ_TRACE("Got incoming command from ", message.conn, message.conn.route.empty() ? "(outgoing)" : "(incoming)");
|
||||
|
||||
if (run.callback->second /*is_request*/) {
|
||||
message.reply_tag = {run.data_parts[0].data<char>(), run.data_parts[0].size()};
|
||||
for (auto it = run.data_parts.begin() + 1; it != run.data_parts.end(); ++it)
|
||||
message.data.emplace_back(it->data<char>(), it->size());
|
||||
} else {
|
||||
for (auto& m : run.data_parts)
|
||||
message.data.emplace_back(m.data<char>(), m.size());
|
||||
}
|
||||
|
||||
LMQ_TRACE("worker thread ", worker_id, " invoking ", run.command, " callback with ", message.data.size(), " message parts");
|
||||
run.callback->first(message);
|
||||
}
|
||||
}
|
||||
catch (const bt_deserialize_invalid& e) {
|
||||
LMQ_LOG(warn, worker_id, " deserialization failed: ", e.what(), "; ignoring request");
|
||||
}
|
||||
catch (const mapbox::util::bad_variant_access& e) {
|
||||
LMQ_LOG(warn, worker_id, " deserialization failed: found unexpected serialized type (", e.what(), "); ignoring request");
|
||||
}
|
||||
catch (const std::out_of_range& e) {
|
||||
LMQ_LOG(warn, worker_id, " deserialization failed: invalid data - required field missing (", e.what(), "); ignoring request");
|
||||
}
|
||||
catch (const std::exception& e) {
|
||||
LMQ_LOG(warn, worker_id, " caught exception when processing command: ", e.what());
|
||||
}
|
||||
catch (...) {
|
||||
LMQ_LOG(warn, worker_id, " caught non-standard exception when processing command");
|
||||
}
|
||||
|
||||
while (true) {
|
||||
// Signal that we are ready for another job and wait for it. (We do this down here
|
||||
// because our first job gets set up when the thread is started).
|
||||
detail::send_control(sock, "RAN");
|
||||
LMQ_TRACE("worker ", worker_id, " waiting for requests");
|
||||
parts.clear();
|
||||
recv_message_parts(sock, parts);
|
||||
|
||||
if (parts.size() != 1) {
|
||||
LMQ_LOG(error, "Internal error: worker ", worker_id, " received invalid ", parts.size(), "-part worker instruction");
|
||||
continue;
|
||||
}
|
||||
auto command = view(parts[0]);
|
||||
if (command == "RUN") {
|
||||
LMQ_LOG(debug, "worker ", worker_id, " running command ", run.command);
|
||||
break; // proxy has set up a command for us, go back and run it.
|
||||
} else if (command == "QUIT") {
|
||||
LMQ_LOG(debug, "worker ", worker_id, " shutting down");
|
||||
detail::send_control(sock, "QUITTING");
|
||||
sock.setsockopt<int>(ZMQ_LINGER, 1000);
|
||||
sock.close();
|
||||
return;
|
||||
} else {
|
||||
LMQ_LOG(error, "Internal error: worker ", worker_id, " received invalid command: `", command, "'");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
LokiMQ::run_info& LokiMQ::get_idle_worker() {
|
||||
if (idle_workers.empty()) {
|
||||
size_t id = workers.size();
|
||||
assert(workers.capacity() > id);
|
||||
workers.emplace_back();
|
||||
auto& r = workers.back();
|
||||
r.worker_id = id;
|
||||
r.worker_routing_id = "w" + std::to_string(id);
|
||||
return r;
|
||||
}
|
||||
size_t id = idle_workers.back();
|
||||
idle_workers.pop_back();
|
||||
return workers[id];
|
||||
}
|
||||
|
||||
void LokiMQ::proxy_worker_message(std::vector<zmq::message_t>& parts) {
|
||||
// Process messages sent by workers
|
||||
if (parts.size() != 2) {
|
||||
LMQ_LOG(error, "Received send invalid ", parts.size(), "-part message");
|
||||
return;
|
||||
}
|
||||
auto route = view(parts[0]), cmd = view(parts[1]);
|
||||
LMQ_TRACE("worker message from ", route);
|
||||
assert(route.size() >= 2 && route[0] == 'w' && route[1] >= '0' && route[1] <= '9');
|
||||
string_view worker_id_str{&route[1], route.size()-1}; // Chop off the leading "w"
|
||||
unsigned int worker_id = detail::extract_unsigned(worker_id_str);
|
||||
if (!worker_id_str.empty() /* didn't consume everything */ || worker_id >= workers.size()) {
|
||||
LMQ_LOG(error, "Worker id '", route, "' is invalid, unable to process worker command");
|
||||
return;
|
||||
}
|
||||
|
||||
auto& run = workers[worker_id];
|
||||
|
||||
LMQ_TRACE("received ", cmd, " command from ", route);
|
||||
if (cmd == "RAN") {
|
||||
LMQ_LOG(debug, "Worker ", route, " finished ", run.command);
|
||||
if (run.is_batch_job) {
|
||||
auto& jobs = run.is_reply_job ? reply_jobs : batch_jobs;
|
||||
auto& active = run.is_reply_job ? reply_jobs_active : batch_jobs_active;
|
||||
assert(active > 0);
|
||||
active--;
|
||||
bool clear_job = false;
|
||||
if (run.batch_jobno == -1) {
|
||||
// Returned from the completion function
|
||||
clear_job = true;
|
||||
} else {
|
||||
auto status = run.batch->job_finished();
|
||||
if (status == detail::BatchStatus::complete) {
|
||||
jobs.emplace(run.batch, -1);
|
||||
} else if (status == detail::BatchStatus::complete_proxy) {
|
||||
try {
|
||||
run.batch->job_completion(); // RUN DIRECTLY IN PROXY THREAD
|
||||
} catch (const std::exception &e) {
|
||||
// Raise these to error levels: the caller really shouldn't be doing
|
||||
// anything non-trivial in an in-proxy completion function!
|
||||
LMQ_LOG(error, "proxy thread caught exception when processing in-proxy completion command: ", e.what());
|
||||
} catch (...) {
|
||||
LMQ_LOG(error, "proxy thread caught non-standard exception when processing in-proxy completion command");
|
||||
}
|
||||
clear_job = true;
|
||||
} else if (status == detail::BatchStatus::done) {
|
||||
clear_job = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (clear_job) {
|
||||
batches.erase(run.batch);
|
||||
delete run.batch;
|
||||
run.batch = nullptr;
|
||||
}
|
||||
} else {
|
||||
assert(run.cat->active_threads > 0);
|
||||
run.cat->active_threads--;
|
||||
}
|
||||
if (max_workers == 0) { // Shutting down
|
||||
LMQ_TRACE("Telling worker ", route, " to quit");
|
||||
route_control(workers_socket, route, "QUIT");
|
||||
} else {
|
||||
idle_workers.push_back(worker_id);
|
||||
}
|
||||
} else if (cmd == "QUITTING") {
|
||||
workers[worker_id].worker_thread.join();
|
||||
LMQ_LOG(debug, "Worker ", route, " exited normally");
|
||||
} else {
|
||||
LMQ_LOG(error, "Worker ", route, " sent unknown control message: `", cmd, "'");
|
||||
}
|
||||
}
|
||||
|
||||
void LokiMQ::proxy_run_worker(run_info& run) {
|
||||
if (!run.worker_thread.joinable())
|
||||
run.worker_thread = std::thread{&LokiMQ::worker_thread, this, run.worker_id};
|
||||
else
|
||||
send_routed_message(workers_socket, run.worker_routing_id, "RUN");
|
||||
}
|
||||
|
||||
void LokiMQ::proxy_to_worker(size_t conn_index, std::vector<zmq::message_t>& parts) {
|
||||
bool outgoing = connections[conn_index].getsockopt<int>(ZMQ_TYPE) == ZMQ_DEALER;
|
||||
|
||||
peer_info tmp_peer;
|
||||
tmp_peer.conn_index = conn_index;
|
||||
if (!outgoing) tmp_peer.route = parts[0].to_string();
|
||||
peer_info* peer = nullptr;
|
||||
if (outgoing) {
|
||||
auto it = peers.find(conn_index_to_id[conn_index]);
|
||||
if (it == peers.end()) {
|
||||
LMQ_LOG(warn, "Internal error: connection index ", conn_index, " not found");
|
||||
return;
|
||||
}
|
||||
peer = &it->second;
|
||||
} else {
|
||||
std::tie(tmp_peer.pubkey, tmp_peer.auth_level) = detail::extract_metadata(parts.back());
|
||||
tmp_peer.service_node = tmp_peer.pubkey.size() == 32 && active_service_nodes.count(tmp_peer.pubkey);
|
||||
|
||||
if (tmp_peer.service_node) {
|
||||
// It's a service node so we should have a peer_info entry; see if we can find one with
|
||||
// the same route, and if not, add one.
|
||||
auto pr = peers.equal_range(tmp_peer.pubkey);
|
||||
for (auto it = pr.first; it != pr.second; ++it) {
|
||||
if (it->second.conn_index == tmp_peer.conn_index && it->second.route == tmp_peer.route) {
|
||||
peer = &it->second;
|
||||
// Update the stored auth level just in case the peer reconnected
|
||||
peer->auth_level = tmp_peer.auth_level;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!peer) {
|
||||
// We don't have a record: this is either a new SN connection or a new message on a
|
||||
// connection that recently gained SN status.
|
||||
peer = &peers.emplace(ConnectionID{tmp_peer.pubkey}, std::move(tmp_peer))->second;
|
||||
}
|
||||
} else {
|
||||
// Incoming, non-SN connection: we don't store a peer_info for this, so just use the
|
||||
// temporary one
|
||||
peer = &tmp_peer;
|
||||
}
|
||||
}
|
||||
|
||||
size_t command_part_index = outgoing ? 0 : 1;
|
||||
std::string command = parts[command_part_index].to_string();
|
||||
|
||||
// Steal any data message parts
|
||||
size_t data_part_index = command_part_index + 1;
|
||||
std::vector<zmq::message_t> data_parts;
|
||||
data_parts.reserve(parts.size() - data_part_index);
|
||||
for (auto it = parts.begin() + data_part_index; it != parts.end(); ++it)
|
||||
data_parts.push_back(std::move(*it));
|
||||
|
||||
auto cat_call = get_command(command);
|
||||
|
||||
// Check that command is valid, that we have permission, etc.
|
||||
if (!proxy_check_auth(conn_index, outgoing, *peer, parts[command_part_index], cat_call, data_parts))
|
||||
return;
|
||||
|
||||
auto& category = *cat_call.first;
|
||||
|
||||
if (category.active_threads >= category.reserved_threads && active_workers() >= general_workers) {
|
||||
// No free reserved or general spots, try to queue it for later
|
||||
if (category.max_queue >= 0 && category.queued >= category.max_queue) {
|
||||
LMQ_LOG(warn, "No space to queue incoming command ", command, "; already have ", category.queued,
|
||||
"commands queued in that category (max ", category.max_queue, "); dropping message");
|
||||
return;
|
||||
}
|
||||
|
||||
LMQ_LOG(debug, "No available free workers, queuing ", command, " for later");
|
||||
ConnectionID conn{peer->service_node ? ConnectionID::SN_ID : conn_index_to_id[conn_index].id, peer->pubkey, std::move(tmp_peer.route)};
|
||||
pending_commands.emplace_back(category, std::move(command), std::move(data_parts), cat_call.second, std::move(conn));
|
||||
category.queued++;
|
||||
return;
|
||||
}
|
||||
|
||||
if (cat_call.second->second /*is_request*/ && data_parts.empty()) {
|
||||
LMQ_LOG(warn, "Received an invalid request command with no reply tag; dropping message");
|
||||
return;
|
||||
}
|
||||
|
||||
auto& run = get_idle_worker();
|
||||
{
|
||||
ConnectionID c{peer->service_node ? ConnectionID::SN_ID : conn_index_to_id[conn_index].id, peer->pubkey};
|
||||
c.route = std::move(tmp_peer.route);
|
||||
if (outgoing || peer->service_node)
|
||||
tmp_peer.route.clear();
|
||||
run.load(&category, std::move(command), std::move(c), std::move(data_parts), cat_call.second);
|
||||
}
|
||||
|
||||
if (outgoing)
|
||||
peer->activity(); // outgoing connection activity, pump the activity timer
|
||||
|
||||
LMQ_TRACE("Forwarding incoming ", run.command, " from ", run.conn, " @ ", peer_address(parts[command_part_index]),
|
||||
" to worker ", run.worker_routing_id);
|
||||
|
||||
proxy_run_worker(run);
|
||||
category.active_threads++;
|
||||
}
|
||||
|
||||
|
||||
}
|
|
@ -1 +0,0 @@
|
|||
Subproject commit c94634bbd294204c9ba3f5b267a39582a52e8e5a
|
|
@ -0,0 +1 @@
|
|||
Subproject commit d6f300d7d250ae0a9708090c0011c0f495377e6a
|
|
@ -0,0 +1,351 @@
|
|||
#include "address.h"
|
||||
#include <tuple>
|
||||
#include <limits>
|
||||
#include <cstddef>
|
||||
#include <utility>
|
||||
#include <stdexcept>
|
||||
#include <ostream>
|
||||
#include <oxenc/hex.h>
|
||||
#include <oxenc/base32z.h>
|
||||
#include <oxenc/base64.h>
|
||||
|
||||
namespace oxenmq {
|
||||
|
||||
constexpr size_t enc_length(address::encoding enc) {
|
||||
return enc == address::encoding::hex ? 64 :
|
||||
enc == address::encoding::base64 ? 43 : // this can be 44 with a padding byte, but we don't need it
|
||||
52 /*base32z*/;
|
||||
};
|
||||
|
||||
// Parses an encoding pubkey from the given string_view. Advanced the string_view beyond the
|
||||
// consumed pubkey data, and returns the pubkey (as a 32-byte string). Throws if no valid pubkey
|
||||
// was found at the beginning of addr. We look for hex, base32z, or base64 pubkeys *unless* qr is
|
||||
// given: for QR-friendly we only accept hex or base32z (since QR cannot handle base64's alphabet).
|
||||
std::string decode_pubkey(std::string_view& in, bool qr) {
|
||||
std::string pubkey;
|
||||
if (in.size() >= 64 && oxenc::is_hex(in.substr(0, 64))) {
|
||||
pubkey = oxenc::from_hex(in.substr(0, 64));
|
||||
in.remove_prefix(64);
|
||||
} else if (in.size() >= 52 && oxenc::is_base32z(in.substr(0, 52))) {
|
||||
pubkey = oxenc::from_base32z(in.substr(0, 52));
|
||||
in.remove_prefix(52);
|
||||
} else if (!qr && in.size() >= 43 && oxenc::is_base64(in.substr(0, 43))) {
|
||||
pubkey = oxenc::from_base64(in.substr(0, 43));
|
||||
in.remove_prefix(43);
|
||||
if (!in.empty() && in.front() == '=')
|
||||
in.remove_prefix(1); // allow (and eat) a padding byte at the end
|
||||
} else {
|
||||
throw std::invalid_argument{"No pubkey found"};
|
||||
}
|
||||
return pubkey;
|
||||
}
|
||||
|
||||
// Parse the host, port, and optionally pubkey from a string view, mutating it to remove the parsed
|
||||
// sections. qr should be true if we should accept $IPv6$ as a QR-encoding-friendly alternative to
|
||||
// [IPv6] (the returned host will have the $ replaced, i.e. [IPv6]).
|
||||
std::tuple<std::string, uint16_t, std::string> parse_tcp(std::string_view& addr, bool qr, bool expect_pubkey) {
|
||||
std::tuple<std::string, uint16_t, std::string> result;
|
||||
auto& host = std::get<0>(result);
|
||||
if (addr.front() == '[' || (qr && addr.front() == '$')) { // IPv6 addr (though this is far from complete validation)
|
||||
auto pos = addr.find_first_not_of(":.1234567890abcdefABCDEF", 1);
|
||||
if (pos == std::string_view::npos)
|
||||
throw std::invalid_argument("Could not find terminating ] while parsing an IPv6 address");
|
||||
if (!(addr[pos] == ']' || (qr && addr[pos] == '$')))
|
||||
throw std::invalid_argument{"Expected " + (qr ? "$"s : "]"s) + " to close IPv6 address but found " + std::string(1, addr[pos])};
|
||||
host = std::string{addr.substr(0, pos+1)};
|
||||
if (qr) {
|
||||
if (host.front() == '$')
|
||||
host.front() = '[';
|
||||
if (host.back() == '$')
|
||||
host.back() = ']';
|
||||
}
|
||||
addr.remove_prefix(pos+1);
|
||||
} else {
|
||||
auto port_pos = addr.find(':');
|
||||
if (port_pos == std::string_view::npos)
|
||||
throw std::invalid_argument{"Could not determine host (no following ':port' found)"};
|
||||
if (port_pos == 0)
|
||||
throw std::invalid_argument{"Host cannot be empty"};
|
||||
host = std::string{addr.substr(0, port_pos)};
|
||||
addr.remove_prefix(port_pos);
|
||||
}
|
||||
|
||||
if (qr)
|
||||
// Lower-case the host because upper case hostnames are ugly
|
||||
for (char& c : host)
|
||||
if (c >= 'A' && c <= 'Z')
|
||||
c = c - 'A' + 'a';
|
||||
|
||||
if (addr.size() < 2 || addr[0] != ':')
|
||||
throw std::invalid_argument{"Could not find :port in address string"};
|
||||
addr.remove_prefix(1);
|
||||
auto pos = addr.find_first_not_of("1234567890");
|
||||
if (pos == 0)
|
||||
throw std::invalid_argument{"Could not find numeric port in address string"};
|
||||
if (pos == std::string_view::npos)
|
||||
pos = addr.size();
|
||||
size_t processed;
|
||||
int port_int = std::stoi(std::string{addr.substr(0, pos)}, &processed);
|
||||
if (port_int == 0 || processed != pos)
|
||||
throw std::invalid_argument{"Could not parse numeric port in address string"};
|
||||
if (port_int < 0 || port_int > std::numeric_limits<uint16_t>::max())
|
||||
throw std::invalid_argument{"Invalid port: port must be in range 1-65535"};
|
||||
std::get<1>(result) = static_cast<uint16_t>(port_int);
|
||||
addr.remove_prefix(pos);
|
||||
|
||||
if (expect_pubkey) {
|
||||
if (addr.size() < 1 + enc_length(qr ? address::encoding::base32z : address::encoding::base64)
|
||||
|| addr.front() != '/')
|
||||
throw std::invalid_argument{"Invalid address: expected /PUBKEY after port"};
|
||||
addr.remove_prefix(1);
|
||||
|
||||
std::get<2>(result) = decode_pubkey(addr, qr);
|
||||
if (!addr.empty())
|
||||
throw std::invalid_argument{"Invalid address: found unexpected trailing data after pubkey"};
|
||||
} else if (!addr.empty()) {
|
||||
throw std::invalid_argument{"Invalid address: found unexpected trailing data after port"};
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// Parse the socket path and (possibly) pubkey, mutating it to remove the parsed sections.
|
||||
// Currently the /pubkey *must* be at the end of the string, but this might not always be the case
|
||||
// (e.g. we could in the future support query string-like arguments).
|
||||
std::pair<std::string, std::string> parse_unix(std::string_view& addr, bool expect_pubkey) {
|
||||
std::pair<std::string, std::string> result;
|
||||
if (expect_pubkey) {
|
||||
size_t b64_len = addr.size() > 0 && addr.back() == '=' ? 44 : 43;
|
||||
if (addr.size() > 64 && addr[addr.size() - 65] == '/' && oxenc::is_hex(addr.substr(addr.size() - 64))) {
|
||||
result.first = std::string{addr.substr(0, addr.size() - 65)};
|
||||
result.second = oxenc::from_hex(addr.substr(addr.size() - 64));
|
||||
} else if (addr.size() > 52 && addr[addr.size() - 53] == '/' && oxenc::is_base32z(addr.substr(addr.size() - 52))) {
|
||||
result.first = std::string{addr.substr(0, addr.size() - 53)};
|
||||
result.second = oxenc::from_base32z(addr.substr(addr.size() - 52));
|
||||
} else if (addr.size() > b64_len && addr[addr.size() - b64_len - 1] == '/' && oxenc::is_base64(addr.substr(addr.size() - b64_len))) {
|
||||
result.first = std::string{addr.substr(0, addr.size() - b64_len - 1)};
|
||||
result.second = oxenc::from_base64(addr.substr(addr.size() - b64_len));
|
||||
} else {
|
||||
throw std::invalid_argument{"icp+curve:// requires a trailing /PUBKEY value, got: " + std::string{addr}};
|
||||
}
|
||||
} else {
|
||||
// Anything goes
|
||||
result.first = std::string{addr};
|
||||
}
|
||||
|
||||
// Any path above consumes everything:
|
||||
addr.remove_prefix(addr.size());
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
address::address(std::string_view addr) {
|
||||
auto protoend = addr.find("://"sv);
|
||||
if (protoend == std::string_view::npos || protoend == 0)
|
||||
throw std::invalid_argument("Invalid address: no protocol found");
|
||||
auto pro = addr.substr(0, protoend);
|
||||
addr.remove_prefix(protoend + 3);
|
||||
if (addr.empty())
|
||||
throw std::invalid_argument("Invalid address: no value specified after protocol");
|
||||
bool qr = false;
|
||||
if (pro == "tcp") protocol = proto::tcp;
|
||||
else if (pro == "tcp+curve" || pro == "curve") protocol = proto::tcp_curve;
|
||||
else if (pro == "ipc") protocol = proto::ipc;
|
||||
else if (pro == "ipc+curve") protocol = proto::ipc_curve;
|
||||
else if (pro == "TCP") {
|
||||
protocol = proto::tcp;
|
||||
qr = true;
|
||||
} else if (pro == "CURVE") {
|
||||
protocol = proto::tcp_curve;
|
||||
qr = true;
|
||||
} else {
|
||||
throw std::invalid_argument("Invalid protocol '" + std::string{pro} + "'");
|
||||
}
|
||||
|
||||
if (qr) {
|
||||
// The QR variations only allow QR-alphanumeric characters (upper-case letters, numbers, and
|
||||
// a few symbols):
|
||||
for (char c : addr) {
|
||||
// QR alphanumeric also allows space, %, *, +, but we don't need or allow any of those here.
|
||||
if (!((c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '$' || c == ':' || c == '/' || c == '.' || c == '-'))
|
||||
throw std::invalid_argument("Found non-QR-alphanumeric value in QR TCP:// or CURVE:// address");
|
||||
}
|
||||
}
|
||||
|
||||
if (tcp())
|
||||
std::tie(host, port, pubkey) = parse_tcp(addr, qr, curve());
|
||||
else
|
||||
std::tie(socket, pubkey) = parse_unix(addr, curve());
|
||||
|
||||
if (!addr.empty())
|
||||
throw std::invalid_argument{"Invalid trailing garbage '" + std::string{addr} + "' in address"};
|
||||
}
|
||||
|
||||
address& address::set_pubkey(std::string_view pk) {
|
||||
if (pk.size() == 0) {
|
||||
if (protocol == proto::tcp_curve) protocol = proto::tcp;
|
||||
else if (protocol == proto::ipc_curve) protocol = proto::ipc;
|
||||
} else if (pk.size() == 32) {
|
||||
if (protocol == proto::tcp) protocol = proto::tcp_curve;
|
||||
else if (protocol == proto::ipc) protocol = proto::ipc_curve;
|
||||
} else {
|
||||
throw std::invalid_argument{"Invalid pubkey passed to set_pubkey(): require 0- or 32-byte pubkey"};
|
||||
}
|
||||
pubkey = pk;
|
||||
return *this;
|
||||
}
|
||||
|
||||
std::string address::encode_pubkey(encoding enc) const {
|
||||
std::string pk;
|
||||
if (enc == encoding::hex)
|
||||
pk = oxenc::to_hex(pubkey);
|
||||
else if (enc == encoding::base32z)
|
||||
pk = oxenc::to_base32z(pubkey);
|
||||
else if (enc == encoding::BASE32Z) {
|
||||
pk = oxenc::to_base32z(pubkey);
|
||||
for (char& c : pk)
|
||||
if (c >= 'a' && c <= 'z')
|
||||
c = c - 'a' + 'A';
|
||||
} else if (enc == encoding::base64) {
|
||||
pk = oxenc::to_base64(pubkey);
|
||||
if (pk.size() == 44 && pk.back() == '=')
|
||||
pk.resize(43);
|
||||
} else {
|
||||
throw std::logic_error{"Invalid encoding"};
|
||||
}
|
||||
return pk;
|
||||
}
|
||||
|
||||
std::string address::full_address(encoding enc) const {
|
||||
std::string result;
|
||||
std::string pk;
|
||||
if (curve())
|
||||
pk = encode_pubkey(enc);
|
||||
|
||||
if (protocol == proto::tcp) {
|
||||
result.reserve(6 /*tcp:// */ + host.size() + 6 /*:port*/);
|
||||
result += "tcp://";
|
||||
result += host;
|
||||
result += ':';
|
||||
result += std::to_string(port);
|
||||
} else if (protocol == proto::tcp_curve) {
|
||||
result.reserve(8 /*curve:// */ + host.size() + 6 /*:port*/ + 1 /* / */ + pk.size());
|
||||
result += "curve://";
|
||||
result += host;
|
||||
result += ':';
|
||||
result += std::to_string(port);
|
||||
result += '/';
|
||||
result += pk;
|
||||
} else if (protocol == proto::ipc) {
|
||||
result.reserve(6 /*ipc:// */ + socket.size());
|
||||
result += "ipc://";
|
||||
result += socket;
|
||||
} else if (protocol == proto::ipc_curve) {
|
||||
result.reserve(12 /*ipc+curve:// */ + socket.size() + 1 /* / */ + pk.size());
|
||||
result += "ipc+curve://";
|
||||
result += socket;
|
||||
result += '/';
|
||||
result += pk;
|
||||
} else {
|
||||
throw std::logic_error{"Invalid protocol"};
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string address::zmq_address() const {
|
||||
std::string result;
|
||||
if (tcp()) {
|
||||
result.reserve(6 /*tcp:// */ + host.size() + 6 /*:port*/);
|
||||
result += "tcp://";
|
||||
result += host;
|
||||
result += ':';
|
||||
result += std::to_string(port);
|
||||
} else {
|
||||
result.reserve(6 /*ipc:// */ + socket.size());
|
||||
result += "ipc://";
|
||||
result += socket;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string address::qr_address() const {
|
||||
if (protocol != proto::tcp && protocol != proto::tcp_curve)
|
||||
throw std::logic_error("Cannot construct a QR-friendly address for a non-TCP address");
|
||||
if (host.empty())
|
||||
throw std::logic_error("Cannot construct a QR-friendly address with an empty TCP host");
|
||||
std::string result;
|
||||
result.reserve((curve() ? 8 /*CURVE:// */ : 6 /*TCP:// */) + host.size() + 6 /*:port*/ +
|
||||
(curve() ? 1 + enc_length(encoding::BASE32Z) : 0));
|
||||
result += curve() ? "CURVE://" : "TCP://";
|
||||
std::string uc_host = host;
|
||||
for (auto& c : uc_host)
|
||||
if (c >= 'a' && c <= 'z')
|
||||
c = c - 'a' + 'A';
|
||||
|
||||
if (uc_host.front() == '[' && uc_host.back() == ']') {
|
||||
uc_host.front() = '$';
|
||||
uc_host.back() = '$';
|
||||
}
|
||||
result += uc_host;
|
||||
result += ':';
|
||||
result += std::to_string(port);
|
||||
|
||||
if (curve()) {
|
||||
result += '/';
|
||||
result += encode_pubkey(encoding::BASE32Z);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
bool address::operator==(const address& other) const {
|
||||
if (protocol != other.protocol)
|
||||
return false;
|
||||
if (tcp())
|
||||
if (host != other.host || port != other.port)
|
||||
return false;
|
||||
if (ipc())
|
||||
if (socket != other.socket)
|
||||
return false;
|
||||
if (curve())
|
||||
if (pubkey != other.pubkey)
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
address address::tcp(std::string host, uint16_t port) {
|
||||
address a;
|
||||
a.protocol = proto::tcp;
|
||||
a.host = std::move(host);
|
||||
a.port = port;
|
||||
return a;
|
||||
}
|
||||
|
||||
address address::tcp_curve(std::string host, uint16_t port, std::string pubkey) {
|
||||
address a;
|
||||
a.protocol = proto::tcp_curve;
|
||||
a.host = std::move(host);
|
||||
a.port = port;
|
||||
a.pubkey = std::move(pubkey);
|
||||
return a;
|
||||
}
|
||||
|
||||
address address::ipc(std::string path) {
|
||||
address a;
|
||||
a.protocol = proto::ipc;
|
||||
a.socket = std::move(path);
|
||||
return a;
|
||||
}
|
||||
|
||||
address address::ipc_curve(std::string path, std::string pubkey) {
|
||||
address a;
|
||||
a.protocol = proto::ipc_curve;
|
||||
a.socket = std::move(path);
|
||||
a.pubkey = std::move(pubkey);
|
||||
return a;
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& o, const address& a) { return o << a.full_address(); }
|
||||
|
||||
}
|
|
@ -0,0 +1,218 @@
|
|||
// Copyright (c) 2020-2021, The Oxen Project
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without modification, are
|
||||
// permitted provided that the following conditions are met:
|
||||
//
|
||||
// 1. Redistributions of source code must retain the above copyright notice, this list of
|
||||
// conditions and the following disclaimer.
|
||||
//
|
||||
// 2. Redistributions in binary form must reproduce the above copyright notice, this list
|
||||
// of conditions and the following disclaimer in the documentation and/or other
|
||||
// materials provided with the distribution.
|
||||
//
|
||||
// 3. Neither the name of the copyright holder nor the names of its contributors may be
|
||||
// used to endorse or promote products derived from this software without specific
|
||||
// prior written permission.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
|
||||
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
|
||||
// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL
|
||||
// THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF
|
||||
// THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
#pragma once
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <cstdint>
|
||||
#include <iosfwd>
|
||||
#include <functional>
|
||||
|
||||
namespace oxenmq {
|
||||
|
||||
using namespace std::literals;
|
||||
|
||||
/** OxenMQ address abstraction class. This class uses and extends standard ZMQ addresses allowing
|
||||
* extra parameters to be passed in in a relative standard way.
|
||||
*
|
||||
* External ZMQ addresses generally have two forms that we are concerned with: one for TCP and one
|
||||
* for Unix sockets:
|
||||
*
|
||||
* tcp://HOST:PORT -- HOST can be a hostname, IPv4 address, or IPv6 address in [...]
|
||||
* ipc://PATH -- PATH can be absolute (ipc:///path/to/some.sock) or relative (ipc://some.sock)
|
||||
*
|
||||
* but this doesn't carry enough info: in particular, we can connect with two very different
|
||||
* protocols: curve25519-encrypted, or plaintext, but for curve25519-encrypted we require the
|
||||
* remote's public key as well to verify the connection.
|
||||
*
|
||||
* This class, then, handles this by allowing addresses of:
|
||||
*
|
||||
* Standard ZMQ address: these carry no pubkey and so the connection will be unencrypted:
|
||||
*
|
||||
* tcp://HOSTNAME:PORT
|
||||
* ipc://PATH
|
||||
*
|
||||
* Non-ZMQ address formats that specify that the connection shall be x25519 encrypted:
|
||||
*
|
||||
* curve://HOSTNAME:PORT/PUBKEY -- PUBKEY must be specified in hex (64 characters), base32z (52)
|
||||
* or base64 (43 or 44 with one '=' trailing padding)
|
||||
* ipc+curve:///path/to/my.sock/PUBKEY -- same requirements on PUBKEY as above.
|
||||
* tcp+curve://(whatever) -- alias for curve://(whatever)
|
||||
*
|
||||
* We also accept special upper-case TCP-only variants which *only* accept uppercase characters and
|
||||
* a few required symbols (:, /, $, ., and -) in the string:
|
||||
*
|
||||
* TCP://HOSTNAME:PORT
|
||||
* CURVE://HOSTNAME:PORT/B32ZPUBKEY
|
||||
*
|
||||
* These versions are explicitly meant to be used with QR codes; the upper-case-only requirement
|
||||
* allows a smaller QR code by allowing QR's alphanumeric mode (which allows only [A-Z0-9 $%*+./:-])
|
||||
* to be used. Such a QR-friendly address can be created from the qr_address() method. To support
|
||||
* literal IPv6 addresses we surround the address with $...$ instead of the usual [...].
|
||||
*
|
||||
* Note that this class does very little validate the host argument at all, and no socket path
|
||||
* validation whatsoever. The only constraint on host is when parsing an encoded address: we check
|
||||
* that it contains no : at all, or must be a [bracketed] expression that contains only hex
|
||||
* characters, :'s, or .'s. Otherwise, if you pass broken crap into the hostname, expect broken
|
||||
* crap out.
|
||||
*/
|
||||
struct address {
|
||||
/// Supported address protocols: TCP connections (tcp), or unix sockets (ipc).
|
||||
enum class proto {
|
||||
tcp,
|
||||
tcp_curve,
|
||||
ipc,
|
||||
ipc_curve
|
||||
};
|
||||
/// Supported public key encodings (used when regenerating an augmented address).
|
||||
enum class encoding {
|
||||
hex, ///< hexadecimal encoded
|
||||
base32z, ///< base32z encoded
|
||||
base64, ///< base64 encoded (*without* trailing = padding)
|
||||
BASE32Z ///< upper-case base32z encoding, meant for QR encoding
|
||||
};
|
||||
|
||||
/// The protocol: one of the `protocol` enum values for tcp or ipc (unix sockets), with or
|
||||
/// without _curve encryption.
|
||||
proto protocol = proto::tcp;
|
||||
/// The host for tcp connections; can be a hostname or IP address. If this is an IPv6 it must be surrounded with [ ].
|
||||
std::string host;
|
||||
/// The port (for tcp connections)
|
||||
uint16_t port = 0;
|
||||
/// The socket path (for unix socket connections)
|
||||
std::string socket;
|
||||
/// If a curve connection, this is the required remote public key (in bytes)
|
||||
std::string pubkey;
|
||||
|
||||
/// Default constructor; this gives you an unusable address.
|
||||
address() = default;
|
||||
|
||||
/**
|
||||
* Constructs an address by parsing a string_view containing one of the formats listed in the
|
||||
* class description. This is intentionally implicitly constructible so that you can pass a
|
||||
* string_view into anything expecting an `address`.
|
||||
*
|
||||
* Throw std::invalid_argument if the given address is not parseable.
|
||||
*/
|
||||
address(std::string_view addr);
|
||||
|
||||
/** Constructs an address from a remote string and a separate pubkey. Typically `remote` is a
|
||||
* basic ZMQ connect string, though this is not enforced. Any pubkey information embedded in
|
||||
* the remote string will be discarded and replaced with the given pubkey string. The result
|
||||
* will be curve encrypted if `pubkey` is non-empty, plaintext if `pubkey` is empty.
|
||||
*
|
||||
* Throws an exception if either addr or pubkey is invalid.
|
||||
*
|
||||
* Exactly equivalent to `address a{remote}; a.set_pubkey(pubkey);`
|
||||
*/
|
||||
address(std::string_view addr, std::string_view pubkey) : address(addr) { set_pubkey(pubkey); }
|
||||
|
||||
/// Replaces the address's pubkey (if any) with the given pubkey (or no pubkey if empty). If
|
||||
/// changing from pubkey to no-pubkey or no-pubkey to pubkey then the protocol is update to
|
||||
/// switch to or from curve encryption.
|
||||
///
|
||||
/// pubkey should be the 32-byte binary pubkey, or an empty string to remove an existing pubkey.
|
||||
///
|
||||
/// Returns the object itself, so that you can chain it.
|
||||
address& set_pubkey(std::string_view pubkey);
|
||||
|
||||
/// Constructs and builds the ZMQ connection address from the stored connection details. This
|
||||
/// does not contain any of the curve-related details; those must be specified separately when
|
||||
/// interfacing with ZMQ.
|
||||
std::string zmq_address() const;
|
||||
|
||||
/// Returns true if the connection was specified as a curve-encryption-enabled connection, false
|
||||
/// otherwise.
|
||||
bool curve() const { return protocol == proto::tcp_curve || protocol == proto::ipc_curve; }
|
||||
|
||||
/// True if the protocol is TCP (either with or without curve)
|
||||
bool tcp() const { return protocol == proto::tcp || protocol == proto::tcp_curve; }
|
||||
|
||||
/// True if the protocol is unix socket (either with or without curve)
|
||||
bool ipc() const { return !tcp(); }
|
||||
|
||||
/// Returns the full "augmented" address string (i.e. that could be passed in to the
|
||||
/// constructor). This will be equivalent (but not necessarily identical) to an augmented
|
||||
/// string passed into the constructor. Takes an optional encoding format for the pubkey (if
|
||||
/// any), which defaults to base32z.
|
||||
std::string full_address(encoding enc = encoding::base32z) const;
|
||||
|
||||
/// Returns a QR-code friendly address string. This returns an all-uppercase version of the
|
||||
/// address with "TCP://" or "CURVE://" for the protocol string, and uses upper-case base32z
|
||||
/// encoding for the pubkey (for curve addresses). For literal IPv6 addresses we surround the
|
||||
/// address with $ instead of [...]
|
||||
///
|
||||
/// \throws std::logic_error if called on a unix socket address.
|
||||
std::string qr_address() const;
|
||||
|
||||
/// Returns `.pubkey` but encoded in the given format
|
||||
std::string encode_pubkey(encoding enc) const;
|
||||
|
||||
/// Returns true if two addresses are identical (i.e. same protocol and relevant protocol
|
||||
/// arguments).
|
||||
///
|
||||
/// Note that it is possible for addresses to connect to the same socket without being
|
||||
/// identical: for example, using "foo.sock" and "./foo.sock", or writing IPv6 addresses (or
|
||||
/// even IPv4 addresses) in slightly different ways). Such equivalent but non-equal values will
|
||||
/// result in a false return here.
|
||||
///
|
||||
/// Note also that we ignore irrelevant arguments: for example, we don't care whether pubkeys
|
||||
/// match when comparing two non-curve TCP addresses.
|
||||
bool operator==(const address& other) const;
|
||||
/// Negation of ==
|
||||
bool operator!=(const address& other) const { return !operator==(other); }
|
||||
|
||||
/// Factory function that constructs a TCP address from a host and port. The connection will be
|
||||
/// plaintext. If the host is an IPv6 address it *must* be surrounded with [ and ].
|
||||
static address tcp(std::string host, uint16_t port);
|
||||
|
||||
/// Factory function that constructs a curve-encrypted TCP address from a host, port, and remote
|
||||
/// pubkey. The pubkey must be 32 bytes. As above, IPv6 addresses must be specified as [addr].
|
||||
static address tcp_curve(std::string host, uint16_t, std::string pubkey);
|
||||
|
||||
/// Factory function that constructs a unix socket address from a path. The connection will be
|
||||
/// plaintext (which is usually fine for a socket since unix sockets are local machine).
|
||||
static address ipc(std::string path);
|
||||
|
||||
/// Factory function that constructs a unix socket address from a path and remote pubkey. The
|
||||
/// connection will be curve25519 encrypted; the remote pubkey must be 32 bytes.
|
||||
static address ipc_curve(std::string path, std::string pubkey);
|
||||
};
|
||||
|
||||
// Outputs address.full_address() when sent to an ostream.
|
||||
std::ostream& operator<<(std::ostream& o, const address& a);
|
||||
|
||||
} // namespace oxenmq
|
||||
|
||||
namespace std {
|
||||
template<> struct hash<oxenmq::address> {
|
||||
std::size_t operator()(const oxenmq::address& a) const noexcept {
|
||||
return std::hash<std::string>{}(a.full_address(oxenmq::address::encoding::hex));
|
||||
}
|
||||
};
|
||||
} // namespace std
|
|
@ -1,8 +1,10 @@
|
|||
#include "lokimq.h"
|
||||
#include "hex.h"
|
||||
#include "lokimq-internal.h"
|
||||
#include "oxenmq.h"
|
||||
#include <oxenc/hex.h>
|
||||
#include "oxenmq-internal.h"
|
||||
#include <ostream>
|
||||
#include <sstream>
|
||||
|
||||
namespace lokimq {
|
||||
namespace oxenmq {
|
||||
|
||||
std::ostream& operator<<(std::ostream& o, AuthLevel a) {
|
||||
return o << to_string(a);
|
||||
|
@ -12,7 +14,7 @@ namespace {
|
|||
|
||||
// Builds a ZMTP metadata key-value pair. These will be available on every message from that peer.
|
||||
// Keys must start with X- and be <= 255 characters.
|
||||
std::string zmtp_metadata(string_view key, string_view value) {
|
||||
std::string zmtp_metadata(std::string_view key, std::string_view value) {
|
||||
assert(key.size() > 2 && key.size() <= 255 && key[0] == 'X' && key[1] == '-');
|
||||
|
||||
std::string result;
|
||||
|
@ -29,29 +31,29 @@ std::string zmtp_metadata(string_view key, string_view value) {
|
|||
}
|
||||
|
||||
|
||||
bool LokiMQ::proxy_check_auth(size_t conn_index, bool outgoing, const peer_info& peer,
|
||||
bool OxenMQ::proxy_check_auth(int64_t conn_id, bool outgoing, const peer_info& peer,
|
||||
zmq::message_t& cmd, const cat_call_t& cat_call, std::vector<zmq::message_t>& data) {
|
||||
auto command = view(cmd);
|
||||
std::string reply;
|
||||
|
||||
if (!cat_call.first) {
|
||||
LMQ_LOG(warn, "Invalid command '", command, "' sent by remote [", to_hex(peer.pubkey), "]/", peer_address(cmd));
|
||||
OMQ_LOG(warn, "Invalid command '", command, "' sent by remote [", oxenc::to_hex(peer.pubkey), "]/", peer_address(cmd));
|
||||
reply = "UNKNOWNCOMMAND";
|
||||
} else if (peer.auth_level < cat_call.first->access.auth) {
|
||||
LMQ_LOG(warn, "Access denied to ", command, " for peer [", to_hex(peer.pubkey), "]/", peer_address(cmd),
|
||||
OMQ_LOG(warn, "Access denied to ", command, " for peer [", oxenc::to_hex(peer.pubkey), "]/", peer_address(cmd),
|
||||
": peer auth level ", peer.auth_level, " < ", cat_call.first->access.auth);
|
||||
reply = "FORBIDDEN";
|
||||
} else if (cat_call.first->access.local_sn && !local_service_node) {
|
||||
LMQ_LOG(warn, "Access denied to ", command, " for peer [", to_hex(peer.pubkey), "]/", peer_address(cmd),
|
||||
": that command is only available when this LokiMQ is running in service node mode");
|
||||
OMQ_LOG(warn, "Access denied to ", command, " for peer [", oxenc::to_hex(peer.pubkey), "]/", peer_address(cmd),
|
||||
": that command is only available when this OxenMQ is running in service node mode");
|
||||
reply = "NOT_A_SERVICE_NODE";
|
||||
} else if (cat_call.first->access.remote_sn && !peer.service_node) {
|
||||
LMQ_LOG(warn, "Access denied to ", command, " for peer [", to_hex(peer.pubkey), "]/", peer_address(cmd),
|
||||
OMQ_LOG(warn, "Access denied to ", command, " for peer [", oxenc::to_hex(peer.pubkey), "]/", peer_address(cmd),
|
||||
": remote is not recognized as a service node");
|
||||
reply = "FORBIDDEN_SN";
|
||||
} else if (cat_call.second->second /*is_request*/ && data.empty()) {
|
||||
LMQ_LOG(warn, "Received an invalid request for '", command, "' with no reply tag from remote [",
|
||||
to_hex(peer.pubkey), "]/", peer_address(cmd));
|
||||
OMQ_LOG(warn, "Received an invalid request for '", command, "' with no reply tag from remote [",
|
||||
oxenc::to_hex(peer.pubkey), "]/", peer_address(cmd));
|
||||
reply = "NO_REPLY_TAG";
|
||||
} else {
|
||||
return true;
|
||||
|
@ -63,39 +65,39 @@ bool LokiMQ::proxy_check_auth(size_t conn_index, bool outgoing, const peer_info&
|
|||
msgs.push_back(create_message(peer.route));
|
||||
msgs.push_back(create_message(reply));
|
||||
if (cat_call.second && cat_call.second->second /*request command*/ && !data.empty()) {
|
||||
msgs.push_back(create_message("REPLY"_sv));
|
||||
msgs.push_back(create_message("REPLY"sv));
|
||||
msgs.push_back(create_message(view(data.front()))); // reply tag
|
||||
} else {
|
||||
msgs.push_back(create_message(view(cmd)));
|
||||
}
|
||||
|
||||
try {
|
||||
send_message_parts(connections[conn_index], msgs);
|
||||
send_message_parts(connections.at(conn_id), msgs);
|
||||
} catch (const zmq::error_t& err) {
|
||||
/* can't send: possibly already disconnected. Ignore. */
|
||||
LMQ_LOG(debug, "Couldn't send auth failure message ", reply, " to peer [", to_hex(peer.pubkey), "]/", peer_address(cmd), ": ", err.what());
|
||||
OMQ_LOG(debug, "Couldn't send auth failure message ", reply, " to peer [", oxenc::to_hex(peer.pubkey), "]/", peer_address(cmd), ": ", err.what());
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
void LokiMQ::set_active_sns(pubkey_set pubkeys) {
|
||||
void OxenMQ::set_active_sns(pubkey_set pubkeys) {
|
||||
if (proxy_thread.joinable()) {
|
||||
auto data = bt_serialize(detail::serialize_object(std::move(pubkeys)));
|
||||
auto data = oxenc::bt_serialize(detail::serialize_object(std::move(pubkeys)));
|
||||
detail::send_control(get_control_socket(), "SET_SNS", data);
|
||||
} else {
|
||||
proxy_set_active_sns(std::move(pubkeys));
|
||||
}
|
||||
}
|
||||
void LokiMQ::proxy_set_active_sns(string_view data) {
|
||||
proxy_set_active_sns(detail::deserialize_object<pubkey_set>(bt_deserialize<uintptr_t>(data)));
|
||||
void OxenMQ::proxy_set_active_sns(std::string_view data) {
|
||||
proxy_set_active_sns(detail::deserialize_object<pubkey_set>(oxenc::bt_deserialize<uintptr_t>(data)));
|
||||
}
|
||||
void LokiMQ::proxy_set_active_sns(pubkey_set pubkeys) {
|
||||
void OxenMQ::proxy_set_active_sns(pubkey_set pubkeys) {
|
||||
pubkey_set added, removed;
|
||||
for (auto it = pubkeys.begin(); it != pubkeys.end(); ) {
|
||||
auto& pk = *it;
|
||||
if (pk.size() != 32) {
|
||||
LMQ_LOG(warn, "Invalid private key of length ", pk.size(), " (", to_hex(pk), ") passed to set_active_sns");
|
||||
OMQ_LOG(warn, "Invalid private key of length ", pk.size(), " (", oxenc::to_hex(pk), ") passed to set_active_sns");
|
||||
it = pubkeys.erase(it);
|
||||
continue;
|
||||
}
|
||||
|
@ -104,7 +106,7 @@ void LokiMQ::proxy_set_active_sns(pubkey_set pubkeys) {
|
|||
++it;
|
||||
}
|
||||
if (added.empty() && active_service_nodes.size() == pubkeys.size()) {
|
||||
LMQ_LOG(debug, "set_active_sns(): new set of SNs is unchanged, skipping update");
|
||||
OMQ_LOG(debug, "set_active_sns(): new set of SNs is unchanged, skipping update");
|
||||
return;
|
||||
}
|
||||
for (const auto& pk : active_service_nodes) {
|
||||
|
@ -116,32 +118,30 @@ void LokiMQ::proxy_set_active_sns(pubkey_set pubkeys) {
|
|||
proxy_update_active_sns_clean(std::move(added), std::move(removed));
|
||||
}
|
||||
|
||||
void LokiMQ::update_active_sns(pubkey_set added, pubkey_set removed) {
|
||||
LMQ_LOG(info, "uh, ", added.size());
|
||||
void OxenMQ::update_active_sns(pubkey_set added, pubkey_set removed) {
|
||||
if (proxy_thread.joinable()) {
|
||||
std::array<uintptr_t, 2> data;
|
||||
data[0] = detail::serialize_object(std::move(added));
|
||||
data[1] = detail::serialize_object(std::move(removed));
|
||||
detail::send_control(get_control_socket(), "UPDATE_SNS", bt_serialize(data));
|
||||
detail::send_control(get_control_socket(), "UPDATE_SNS", oxenc::bt_serialize(data));
|
||||
} else {
|
||||
proxy_update_active_sns(std::move(added), std::move(removed));
|
||||
}
|
||||
}
|
||||
void LokiMQ::proxy_update_active_sns(bt_list_consumer data) {
|
||||
void OxenMQ::proxy_update_active_sns(oxenc::bt_list_consumer data) {
|
||||
auto added = detail::deserialize_object<pubkey_set>(data.consume_integer<uintptr_t>());
|
||||
auto remed = detail::deserialize_object<pubkey_set>(data.consume_integer<uintptr_t>());
|
||||
proxy_update_active_sns(std::move(added), std::move(remed));
|
||||
}
|
||||
void LokiMQ::proxy_update_active_sns(pubkey_set added, pubkey_set removed) {
|
||||
void OxenMQ::proxy_update_active_sns(pubkey_set added, pubkey_set removed) {
|
||||
// We take a caller-provided set of added/removed then filter out any junk (bad pks, conflicting
|
||||
// values, pubkeys that already(added) or do not(removed) exist), then pass the purified lists
|
||||
// to the _clean version.
|
||||
|
||||
LMQ_LOG(info, "uh, ", added.size(), ", ", removed.size());
|
||||
for (auto it = removed.begin(); it != removed.end(); ) {
|
||||
const auto& pk = *it;
|
||||
if (pk.size() != 32) {
|
||||
LMQ_LOG(warn, "Invalid private key of length ", pk.size(), " (", to_hex(pk), ") passed to update_active_sns (removed)");
|
||||
OMQ_LOG(warn, "Invalid private key of length ", pk.size(), " (", oxenc::to_hex(pk), ") passed to update_active_sns (removed)");
|
||||
it = removed.erase(it);
|
||||
} else if (!active_service_nodes.count(pk) || added.count(pk) /* added wins if in both */) {
|
||||
it = removed.erase(it);
|
||||
|
@ -153,7 +153,7 @@ void LokiMQ::proxy_update_active_sns(pubkey_set added, pubkey_set removed) {
|
|||
for (auto it = added.begin(); it != added.end(); ) {
|
||||
const auto& pk = *it;
|
||||
if (pk.size() != 32) {
|
||||
LMQ_LOG(warn, "Invalid private key of length ", pk.size(), " (", to_hex(pk), ") passed to update_active_sns (added)");
|
||||
OMQ_LOG(warn, "Invalid private key of length ", pk.size(), " (", oxenc::to_hex(pk), ") passed to update_active_sns (added)");
|
||||
it = added.erase(it);
|
||||
} else if (active_service_nodes.count(pk)) {
|
||||
it = added.erase(it);
|
||||
|
@ -165,8 +165,8 @@ void LokiMQ::proxy_update_active_sns(pubkey_set added, pubkey_set removed) {
|
|||
proxy_update_active_sns_clean(std::move(added), std::move(removed));
|
||||
}
|
||||
|
||||
void LokiMQ::proxy_update_active_sns_clean(pubkey_set added, pubkey_set removed) {
|
||||
LMQ_LOG(debug, "Updating SN auth status with +", added.size(), "/-", removed.size(), " pubkeys");
|
||||
void OxenMQ::proxy_update_active_sns_clean(pubkey_set added, pubkey_set removed) {
|
||||
OMQ_LOG(debug, "Updating SN auth status with +", added.size(), "/-", removed.size(), " pubkeys");
|
||||
|
||||
// For anything we remove we want close the connection to the SN (if outgoing), and remove the
|
||||
// stored peer_info (incoming or outgoing).
|
||||
|
@ -176,11 +176,11 @@ void LokiMQ::proxy_update_active_sns_clean(pubkey_set added, pubkey_set removed)
|
|||
auto range = peers.equal_range(c);
|
||||
for (auto it = range.first; it != range.second; ) {
|
||||
bool outgoing = it->second.outgoing();
|
||||
size_t conn_index = it->second.conn_index;
|
||||
auto conn_id = it->second.conn_id;
|
||||
it = peers.erase(it);
|
||||
if (outgoing) {
|
||||
LMQ_LOG(debug, "Closing outgoing connection to ", c);
|
||||
proxy_close_connection(conn_index, CLOSE_LINGER);
|
||||
OMQ_LOG(debug, "Closing outgoing connection to ", c);
|
||||
proxy_close_connection(conn_id, CLOSE_LINGER);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -190,7 +190,7 @@ void LokiMQ::proxy_update_active_sns_clean(pubkey_set added, pubkey_set removed)
|
|||
active_service_nodes.insert(std::move(pk));
|
||||
}
|
||||
|
||||
void LokiMQ::process_zap_requests() {
|
||||
void OxenMQ::process_zap_requests() {
|
||||
for (std::vector<zmq::message_t> frames; recv_message_parts(zap_auth, frames, zmq::recv_flags::dontwait); frames.clear()) {
|
||||
#ifndef NDEBUG
|
||||
if (log_level() >= LogLevel::trace) {
|
||||
|
@ -200,14 +200,14 @@ void LokiMQ::process_zap_requests() {
|
|||
o << "\n[" << i << "]: ";
|
||||
auto v = view(frames[i]);
|
||||
if (i == 1 || i == 6)
|
||||
o << to_hex(v);
|
||||
o << oxenc::to_hex(v);
|
||||
else
|
||||
o << v;
|
||||
}
|
||||
log_(LogLevel::trace, __FILE__, __LINE__, o.str());
|
||||
log(LogLevel::trace, __FILE__, __LINE__, o.str());
|
||||
} else
|
||||
#endif
|
||||
LMQ_LOG(debug, "Processing ZAP authentication request");
|
||||
OMQ_LOG(debug, "Processing ZAP authentication request");
|
||||
|
||||
// https://rfc.zeromq.org/spec:27/ZAP/
|
||||
//
|
||||
|
@ -240,55 +240,60 @@ void LokiMQ::process_zap_requests() {
|
|||
std::string &status_code = response_vals[2], &status_text = response_vals[3];
|
||||
|
||||
if (frames.size() < 6 || view(frames[0]) != "1.0") {
|
||||
LMQ_LOG(error, "Bad ZAP authentication request: version != 1.0 or invalid ZAP message parts");
|
||||
OMQ_LOG(error, "Bad ZAP authentication request: version != 1.0 or invalid ZAP message parts");
|
||||
status_code = "500";
|
||||
status_text = "Internal error: invalid auth request";
|
||||
} else {
|
||||
auto auth_domain = view(frames[2]);
|
||||
size_t bind_id = (size_t) -1;
|
||||
try {
|
||||
bind_id = bt_deserialize<size_t>(view(frames[2]));
|
||||
bind_id = oxenc::bt_deserialize<size_t>(view(frames[2]));
|
||||
} catch (...) {}
|
||||
|
||||
if (bind_id >= bind.size()) {
|
||||
LMQ_LOG(error, "Bad ZAP authentication request: invalid auth domain '", auth_domain, "'");
|
||||
OMQ_LOG(error, "Bad ZAP authentication request: invalid auth domain '", auth_domain, "'");
|
||||
status_code = "400";
|
||||
status_text = "Unknown authentication domain: " + std::string{auth_domain};
|
||||
} else if (bind[bind_id].second.curve
|
||||
} else if (bind[bind_id].curve
|
||||
? !(frames.size() == 7 && view(frames[5]) == "CURVE")
|
||||
: !(frames.size() == 6 && view(frames[5]) == "NULL")) {
|
||||
LMQ_LOG(error, "Bad ZAP authentication request: invalid ",
|
||||
bind[bind_id].second.curve ? "CURVE" : "NULL", " authentication request");
|
||||
OMQ_LOG(error, "Bad ZAP authentication request: invalid ",
|
||||
bind[bind_id].curve ? "CURVE" : "NULL", " authentication request");
|
||||
status_code = "500";
|
||||
status_text = "Invalid authentication request mechanism";
|
||||
} else if (bind[bind_id].second.curve && frames[6].size() != 32) {
|
||||
LMQ_LOG(error, "Bad ZAP authentication request: invalid request pubkey");
|
||||
} else if (bind[bind_id].curve && frames[6].size() != 32) {
|
||||
OMQ_LOG(error, "Bad ZAP authentication request: invalid request pubkey");
|
||||
status_code = "500";
|
||||
status_text = "Invalid public key size for CURVE authentication";
|
||||
} else {
|
||||
auto ip = view(frames[3]);
|
||||
string_view pubkey;
|
||||
// If we're in dual stack mode IPv4 address might be IPv4-mapped IPv6 address (e.g.
|
||||
// ::ffff:192.168.0.1); if so, remove the prefix to get a proper IPv4 address:
|
||||
if (ip.size() >= 14 && ip.substr(0, 7) == "::ffff:"sv && ip.find_last_not_of("0123456789."sv) == 6)
|
||||
ip = ip.substr(7);
|
||||
|
||||
std::string_view pubkey;
|
||||
bool sn = false;
|
||||
if (bind[bind_id].second.curve) {
|
||||
if (bind[bind_id].curve) {
|
||||
pubkey = view(frames[6]);
|
||||
sn = active_service_nodes.count(std::string{pubkey});
|
||||
}
|
||||
auto auth = bind[bind_id].second.allow(ip, pubkey, sn);
|
||||
auto auth = bind[bind_id].allow(ip, pubkey, sn);
|
||||
auto& user_id = response_vals[4];
|
||||
if (bind[bind_id].second.curve) {
|
||||
if (bind[bind_id].curve) {
|
||||
user_id.reserve(64);
|
||||
to_hex(pubkey.begin(), pubkey.end(), std::back_inserter(user_id));
|
||||
oxenc::to_hex(pubkey.begin(), pubkey.end(), std::back_inserter(user_id));
|
||||
}
|
||||
|
||||
if (auth <= AuthLevel::denied || auth > AuthLevel::admin) {
|
||||
LMQ_LOG(info, "Access denied for incoming ", view(frames[5]), (sn ? " service node" : " client"),
|
||||
OMQ_LOG(info, "Access denied for incoming ", view(frames[5]), (sn ? " service node" : " client"),
|
||||
" connection from ", !user_id.empty() ? user_id + " at " : ""s, ip,
|
||||
" with initial auth level ", auth);
|
||||
status_code = "400";
|
||||
status_text = "Access denied";
|
||||
user_id.clear();
|
||||
} else {
|
||||
LMQ_LOG(debug, "Accepted incoming ", view(frames[5]), (sn ? " service node" : " client"),
|
||||
OMQ_LOG(debug, "Accepted incoming ", view(frames[5]), (sn ? " service node" : " client"),
|
||||
" connection with authentication level ", auth,
|
||||
" from ", !user_id.empty() ? user_id + " at " : ""s, ip);
|
||||
|
||||
|
@ -301,7 +306,7 @@ void LokiMQ::process_zap_requests() {
|
|||
}
|
||||
}
|
||||
|
||||
LMQ_TRACE("ZAP request result: ", status_code, " ", status_text);
|
||||
OMQ_TRACE("ZAP request result: ", status_code, " ", status_text);
|
||||
|
||||
std::vector<zmq::message_t> response;
|
||||
response.reserve(response_vals.size());
|
|
@ -1,10 +1,10 @@
|
|||
#pragma once
|
||||
#include <iostream>
|
||||
#include <iosfwd>
|
||||
#include <string>
|
||||
#include <cstring>
|
||||
#include <unordered_set>
|
||||
|
||||
namespace lokimq {
|
||||
namespace oxenmq {
|
||||
|
||||
/// Authentication levels for command categories and connections
|
||||
enum class AuthLevel {
|
||||
|
@ -19,16 +19,16 @@ std::ostream& operator<<(std::ostream& os, AuthLevel a);
|
|||
/// The access level for a command category
|
||||
struct Access {
|
||||
/// Minimum access level required
|
||||
AuthLevel auth = AuthLevel::none;
|
||||
AuthLevel auth;
|
||||
/// If true only remote SNs may call the category commands
|
||||
bool remote_sn = false;
|
||||
bool remote_sn;
|
||||
/// If true the category requires that the local node is a SN
|
||||
bool local_sn = false;
|
||||
bool local_sn;
|
||||
|
||||
/// Constructor. Intentionally allows implicit conversion from an AuthLevel so that an
|
||||
/// AuthLevel can be passed anywhere an Access is required (the resulting Access will have both
|
||||
/// remote and local sn set to false).
|
||||
Access(AuthLevel auth, bool remote_sn = false, bool local_sn = false)
|
||||
Access(AuthLevel auth = AuthLevel::none, bool remote_sn = false, bool local_sn = false)
|
||||
: auth{auth}, remote_sn{remote_sn}, local_sn{local_sn} {}
|
||||
};
|
||||
|
||||
|
@ -45,11 +45,29 @@ struct already_hashed {
|
|||
};
|
||||
|
||||
/// std::unordered_set specialization for specifying pubkeys (used, in particular, by
|
||||
/// LokiMQ::set_active_sns and LokiMQ::update_active_sns); this is a std::string unordered_set that
|
||||
/// OxenMQ::set_active_sns and OxenMQ::update_active_sns); this is a std::string unordered_set that
|
||||
/// also uses a specialized trivial hash function that uses part of the value itself (i.e. the
|
||||
/// pubkey) directly as a hash value. (This is nice and fast for uniformly distributed values like
|
||||
/// pubkeys and a terrible hash choice for anything else).
|
||||
using pubkey_set = std::unordered_set<std::string, already_hashed>;
|
||||
|
||||
inline constexpr std::string_view to_string(AuthLevel a) {
|
||||
switch (a) {
|
||||
case AuthLevel::denied: return "denied";
|
||||
case AuthLevel::none: return "none";
|
||||
case AuthLevel::basic: return "basic";
|
||||
case AuthLevel::admin: return "admin";
|
||||
default: return "(unknown)";
|
||||
}
|
||||
}
|
||||
|
||||
inline AuthLevel auth_from_string(std::string_view a) {
|
||||
if (a == "none") return AuthLevel::none;
|
||||
if (a == "basic") return AuthLevel::basic;
|
||||
if (a == "admin") return AuthLevel::admin;
|
||||
return AuthLevel::denied;
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
// Copyright (c) 2020, The Loki Project
|
||||
// Copyright (c) 2020-2021, The Oxen Project
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
|
@ -30,24 +30,31 @@
|
|||
#include <exception>
|
||||
#include <functional>
|
||||
#include <vector>
|
||||
#include "lokimq.h"
|
||||
#include "oxenmq.h"
|
||||
|
||||
namespace lokimq {
|
||||
namespace oxenmq {
|
||||
|
||||
namespace detail {
|
||||
|
||||
enum class BatchStatus {
|
||||
enum class BatchState {
|
||||
running, // there are still jobs to run (or running)
|
||||
complete, // the batch is complete but still has a completion job to call
|
||||
complete_proxy, // same as `complete`, but the completion job should be invoked immediately in the proxy thread (be very careful)
|
||||
done // the batch is complete and has no completion function
|
||||
};
|
||||
|
||||
struct BatchStatus {
|
||||
BatchState state;
|
||||
int thread;
|
||||
};
|
||||
|
||||
// Virtual base class for Batch<R>
|
||||
class Batch {
|
||||
public:
|
||||
// Returns the number of jobs in this batch
|
||||
virtual size_t size() const = 0;
|
||||
// Returns the number of jobs in this batch and whether any of them are thread-specific
|
||||
virtual std::pair<size_t, bool> size() const = 0;
|
||||
// Returns a vector of exactly the same length of size().first containing the tagged thread ids
|
||||
// of the batch jobs or 0 for general jobs.
|
||||
virtual std::vector<int> threads() const = 0;
|
||||
// Called in a worker thread to run the job
|
||||
virtual void run_job(int i) = 0;
|
||||
// Called in the main proxy thread when the worker returns from finishing a job. The return
|
||||
|
@ -71,7 +78,7 @@ public:
|
|||
* This is designed to be like a very stripped down version of a std::promise/std::future pair. We
|
||||
* reimplemented it, however, because by ditching all the thread synchronization that promise/future
|
||||
* guarantees we can substantially reduce call overhead (by a factor of ~8 according to benchmarking
|
||||
* code). Since LokiMQ's proxy<->worker communication channel already gives us thread that overhead
|
||||
* code). Since OxenMQ's proxy<->worker communication channel already gives us thread that overhead
|
||||
* would just be wasted.
|
||||
*
|
||||
* @tparam R the value type held by the result; must be default constructible. Note, however, that
|
||||
|
@ -128,13 +135,13 @@ public:
|
|||
void get() { if (exc) std::rethrow_exception(exc); }
|
||||
};
|
||||
|
||||
/// Helper class used to set up batches of jobs to be scheduled via the lokimq job handler.
|
||||
/// Helper class used to set up batches of jobs to be scheduled via the oxenmq job handler.
|
||||
///
|
||||
/// @tparam R - the return type of the individual jobs
|
||||
///
|
||||
template <typename R>
|
||||
class Batch final : private detail::Batch {
|
||||
friend class LokiMQ;
|
||||
friend class OxenMQ;
|
||||
public:
|
||||
/// The completion function type, called after all jobs have finished.
|
||||
using CompletionFunc = std::function<void(std::vector<job_result<R>> results)>;
|
||||
|
@ -151,16 +158,17 @@ public:
|
|||
Batch &operator=(const Batch&) = delete;
|
||||
|
||||
private:
|
||||
std::vector<std::function<R()>> jobs;
|
||||
std::vector<std::pair<std::function<R()>, int>> jobs;
|
||||
std::vector<job_result<R>> results;
|
||||
CompletionFunc complete;
|
||||
std::size_t jobs_outstanding = 0;
|
||||
bool complete_in_proxy = false;
|
||||
int complete_in_thread = 0;
|
||||
bool started = false;
|
||||
bool tagged_thread_jobs = false;
|
||||
|
||||
void check_not_started() {
|
||||
if (started)
|
||||
throw std::logic_error("Cannot add jobs or completion function after starting a lokimq::Batch!");
|
||||
throw std::logic_error("Cannot add jobs or completion function after starting a oxenmq::Batch!");
|
||||
}
|
||||
|
||||
public:
|
||||
|
@ -175,39 +183,61 @@ public:
|
|||
/// available. The called function may throw exceptions (which will be propagated to the
|
||||
/// completion function through the job_result values). There is no guarantee on the order of
|
||||
/// invocation of the jobs.
|
||||
void add_job(std::function<R()> job) {
|
||||
///
|
||||
/// \param job the callback
|
||||
/// \param thread an optional TaggedThreadID indicating a thread in which this job must run
|
||||
void add_job(std::function<R()> job, std::optional<TaggedThreadID> thread = std::nullopt) {
|
||||
check_not_started();
|
||||
jobs.emplace_back(std::move(job));
|
||||
results.emplace_back();
|
||||
jobs_outstanding++;
|
||||
if (thread && thread->_id == -1)
|
||||
// There are some special case internal jobs where we allow this, but they use the
|
||||
// private method below that doesn't have this check.
|
||||
throw std::logic_error{"Cannot add a proxy thread batch job -- this makes no sense"};
|
||||
add_job(std::move(job), thread ? thread->_id : 0);
|
||||
}
|
||||
|
||||
/// Sets the completion function to invoke after all jobs have finished. If this is not set
|
||||
/// then jobs simply run and results are discarded.
|
||||
void completion(CompletionFunc comp) {
|
||||
///
|
||||
/// \param comp - function to call when all jobs have finished
|
||||
/// \param thread - optional tagged thread in which to schedule the completion job. If not
|
||||
/// provided then the completion job is scheduled in the pool of batch job threads.
|
||||
///
|
||||
/// `thread` can be provided the value &OxenMQ::run_in_proxy to invoke the completion function
|
||||
/// *IN THE PROXY THREAD* itself after all jobs have finished. Be very, very careful: this
|
||||
/// should be a nearly trivial job that does not require any substantial CPU time and does not
|
||||
/// block for any reason. This is only intended for the case where the completion job is so
|
||||
/// trivial that it will take less time than simply queuing the job to be executed by another
|
||||
/// thread.
|
||||
void completion(CompletionFunc comp, std::optional<TaggedThreadID> thread = std::nullopt) {
|
||||
check_not_started();
|
||||
if (complete)
|
||||
throw std::logic_error("Completion function can only be set once");
|
||||
complete = std::move(comp);
|
||||
}
|
||||
|
||||
/// Sets a completion function to invoke *IN THE PROXY THREAD* after all jobs have finished. Be
|
||||
/// very, very careful: this should not be a job that takes any significant amount of CPU time
|
||||
/// or can block for any reason (NO MUTEXES).
|
||||
void completion_proxy(CompletionFunc comp) {
|
||||
check_not_started();
|
||||
if (complete)
|
||||
throw std::logic_error("Completion function can only be set once");
|
||||
complete = std::move(comp);
|
||||
complete_in_proxy = true;
|
||||
complete_in_thread = thread ? thread->_id : 0;
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
std::size_t size() const override {
|
||||
return jobs.size();
|
||||
void add_job(std::function<R()> job, int thread_id) {
|
||||
jobs.emplace_back(std::move(job), thread_id);
|
||||
results.emplace_back();
|
||||
jobs_outstanding++;
|
||||
if (thread_id != 0)
|
||||
tagged_thread_jobs = true;
|
||||
}
|
||||
|
||||
std::pair<std::size_t, bool> size() const override {
|
||||
return {jobs.size(), tagged_thread_jobs};
|
||||
}
|
||||
|
||||
std::vector<int> threads() const override {
|
||||
std::vector<int> t;
|
||||
t.reserve(jobs.size());
|
||||
for (auto& j : jobs)
|
||||
t.push_back(j.second);
|
||||
return t;
|
||||
};
|
||||
|
||||
template <typename S = R>
|
||||
void set_value(job_result<S>& r, std::function<S()>& f) { r.set_value(f()); }
|
||||
void set_value(job_result<void>&, std::function<void()>& f) { f(); }
|
||||
|
@ -216,7 +246,7 @@ private:
|
|||
// called by worker thread
|
||||
auto& r = results[i];
|
||||
try {
|
||||
set_value(r, jobs[i]);
|
||||
set_value(r, jobs[i].first);
|
||||
} catch (...) {
|
||||
r.set_exception(std::current_exception());
|
||||
}
|
||||
|
@ -225,12 +255,10 @@ private:
|
|||
detail::BatchStatus job_finished() override {
|
||||
--jobs_outstanding;
|
||||
if (jobs_outstanding)
|
||||
return detail::BatchStatus::running;
|
||||
return {detail::BatchState::running, 0};
|
||||
if (complete)
|
||||
return complete_in_proxy
|
||||
? detail::BatchStatus::complete_proxy
|
||||
: detail::BatchStatus::complete;
|
||||
return detail::BatchStatus::done;
|
||||
return {detail::BatchState::complete, complete_in_thread};
|
||||
return {detail::BatchState::done, 0};
|
||||
}
|
||||
|
||||
void job_completion() override {
|
||||
|
@ -238,14 +266,60 @@ private:
|
|||
}
|
||||
};
|
||||
|
||||
// Similar to Batch<void>, but doesn't support a completion function and only handles a single task.
|
||||
class Job final : private detail::Batch {
|
||||
friend class OxenMQ;
|
||||
public:
|
||||
/// Constructs the Job to run a single task. Takes any callable invokable with no arguments and
|
||||
/// having no return value. The task will be scheduled and run when the next worker thread is
|
||||
/// available. Any exceptions thrown by the job will be caught and squelched (the exception
|
||||
/// terminates/completes the job).
|
||||
|
||||
explicit Job(std::function<void()> f, std::optional<TaggedThreadID> thread = std::nullopt)
|
||||
: Job{std::move(f), thread ? thread->_id : 0}
|
||||
{
|
||||
if (thread && thread->_id == -1)
|
||||
// There are some special case internal jobs where we allow this, but they use the
|
||||
// private ctor below that doesn't have this check.
|
||||
throw std::logic_error{"Cannot add a proxy thread job -- this makes no sense"};
|
||||
}
|
||||
|
||||
// movable
|
||||
Job(Job&&) = default;
|
||||
Job &operator=(Job&&) = default;
|
||||
|
||||
// non-copyable
|
||||
Job(const Job&) = delete;
|
||||
Job &operator=(const Job&) = delete;
|
||||
|
||||
private:
|
||||
explicit Job(std::function<void()> f, int thread_id)
|
||||
: job{std::move(f), thread_id} {}
|
||||
|
||||
std::pair<std::function<void()>, int> job;
|
||||
bool done = false;
|
||||
|
||||
std::pair<size_t, bool> size() const override { return {1, job.second != 0}; }
|
||||
std::vector<int> threads() const override { return {job.second}; }
|
||||
|
||||
void run_job(const int /*i*/) override {
|
||||
try { job.first(); }
|
||||
catch (...) {}
|
||||
}
|
||||
|
||||
detail::BatchStatus job_finished() override { return {detail::BatchState::done, 0}; }
|
||||
|
||||
void job_completion() override {} // Never called because we return ::done (not ::complete) above.
|
||||
|
||||
};
|
||||
|
||||
template <typename R>
|
||||
void LokiMQ::batch(Batch<R>&& batch) {
|
||||
if (batch.size() == 0)
|
||||
void OxenMQ::batch(Batch<R>&& batch) {
|
||||
if (batch.size().first == 0)
|
||||
throw std::logic_error("Cannot batch a a job batch with 0 jobs");
|
||||
// Need to send this over to the proxy thread via the base class pointer. It assumes ownership.
|
||||
auto* baseptr = static_cast<detail::Batch*>(new Batch<R>(std::move(batch)));
|
||||
detail::send_control(get_control_socket(), "BATCH", bt_serialize(reinterpret_cast<uintptr_t>(baseptr)));
|
||||
detail::send_control(get_control_socket(), "BATCH", oxenc::bt_serialize(reinterpret_cast<uintptr_t>(baseptr)));
|
||||
}
|
||||
|
||||
}
|
|
@ -1,16 +1,49 @@
|
|||
#include "lokimq.h"
|
||||
#include "lokimq-internal.h"
|
||||
#include "hex.h"
|
||||
#include "oxenmq.h"
|
||||
#include "oxenmq-internal.h"
|
||||
#include <oxenc/hex.h>
|
||||
#include <optional>
|
||||
|
||||
namespace lokimq {
|
||||
#ifdef OXENMQ_USE_EPOLL
|
||||
extern "C" {
|
||||
#include <sys/epoll.h>
|
||||
#include <unistd.h>
|
||||
}
|
||||
#endif
|
||||
|
||||
namespace oxenmq {
|
||||
|
||||
std::ostream& operator<<(std::ostream& o, const ConnectionID& conn) {
|
||||
if (!conn.pk.empty())
|
||||
return o << (conn.sn() ? "SN " : "non-SN authenticated remote ") << to_hex(conn.pk);
|
||||
else
|
||||
return o << "unauthenticated remote [" << conn.id << "]";
|
||||
return o << conn.to_string();
|
||||
}
|
||||
|
||||
#ifdef OXENMQ_USE_EPOLL
|
||||
|
||||
void OxenMQ::rebuild_pollitems() {
|
||||
|
||||
if (epoll_fd != -1)
|
||||
close(epoll_fd);
|
||||
epoll_fd = epoll_create1(0);
|
||||
|
||||
struct epoll_event ev;
|
||||
ev.events = EPOLLIN | EPOLLET;
|
||||
ev.data.u64 = EPOLL_COMMAND_ID;
|
||||
epoll_ctl(epoll_fd, EPOLL_CTL_ADD, command.get(zmq::sockopt::fd), &ev);
|
||||
|
||||
ev.data.u64 = EPOLL_WORKER_ID;
|
||||
epoll_ctl(epoll_fd, EPOLL_CTL_ADD, workers_socket.get(zmq::sockopt::fd), &ev);
|
||||
|
||||
ev.data.u64 = EPOLL_ZAP_ID;
|
||||
epoll_ctl(epoll_fd, EPOLL_CTL_ADD, zap_auth.get(zmq::sockopt::fd), &ev);
|
||||
|
||||
for (auto& [id, s] : connections) {
|
||||
ev.data.u64 = id;
|
||||
epoll_ctl(epoll_fd, EPOLL_CTL_ADD, s.get(zmq::sockopt::fd), &ev);
|
||||
}
|
||||
connections_updated = false;
|
||||
}
|
||||
|
||||
#else // !OXENMQ_USE_EPOLL
|
||||
|
||||
namespace {
|
||||
|
||||
void add_pollitem(std::vector<zmq::pollitem_t>& pollitems, zmq::socket_t& sock) {
|
||||
|
@ -24,84 +57,85 @@ void add_pollitem(std::vector<zmq::pollitem_t>& pollitems, zmq::socket_t& sock)
|
|||
} // anonymous namespace
|
||||
|
||||
|
||||
void LokiMQ::rebuild_pollitems() {
|
||||
void OxenMQ::rebuild_pollitems() {
|
||||
pollitems.clear();
|
||||
add_pollitem(pollitems, command);
|
||||
add_pollitem(pollitems, workers_socket);
|
||||
add_pollitem(pollitems, zap_auth);
|
||||
|
||||
for (auto& s : connections)
|
||||
for (auto& [id, s] : connections)
|
||||
add_pollitem(pollitems, s);
|
||||
pollitems_stale = false;
|
||||
connections_updated = false;
|
||||
}
|
||||
|
||||
void LokiMQ::setup_external_socket(zmq::socket_t& socket) {
|
||||
socket.setsockopt(ZMQ_RECONNECT_IVL, (int) RECONNECT_INTERVAL.count());
|
||||
socket.setsockopt(ZMQ_RECONNECT_IVL_MAX, (int) RECONNECT_INTERVAL_MAX.count());
|
||||
socket.setsockopt(ZMQ_HANDSHAKE_IVL, (int) HANDSHAKE_TIME.count());
|
||||
socket.setsockopt<int64_t>(ZMQ_MAXMSGSIZE, MAX_MSG_SIZE);
|
||||
#endif // OXENMQ_USE_EPOLL
|
||||
|
||||
void OxenMQ::setup_external_socket(zmq::socket_t& socket) {
|
||||
socket.set(zmq::sockopt::reconnect_ivl, (int) RECONNECT_INTERVAL.count());
|
||||
socket.set(zmq::sockopt::reconnect_ivl_max, (int) RECONNECT_INTERVAL_MAX.count());
|
||||
socket.set(zmq::sockopt::handshake_ivl, (int) HANDSHAKE_TIME.count());
|
||||
socket.set(zmq::sockopt::maxmsgsize, MAX_MSG_SIZE);
|
||||
if (IPV6)
|
||||
socket.set(zmq::sockopt::ipv6, 1);
|
||||
|
||||
if (CONN_HEARTBEAT > 0s) {
|
||||
socket.setsockopt(ZMQ_HEARTBEAT_IVL, (int) CONN_HEARTBEAT.count());
|
||||
socket.set(zmq::sockopt::heartbeat_ivl, (int) CONN_HEARTBEAT.count());
|
||||
if (CONN_HEARTBEAT_TIMEOUT > 0s)
|
||||
socket.setsockopt(ZMQ_HEARTBEAT_TIMEOUT, (int) CONN_HEARTBEAT_TIMEOUT.count());
|
||||
socket.set(zmq::sockopt::heartbeat_timeout, (int) CONN_HEARTBEAT_TIMEOUT.count());
|
||||
}
|
||||
}
|
||||
|
||||
void LokiMQ::setup_outgoing_socket(zmq::socket_t& socket, string_view remote_pubkey) {
|
||||
void OxenMQ::setup_outgoing_socket(zmq::socket_t& socket, std::string_view remote_pubkey, bool use_ephemeral_routing_id) {
|
||||
|
||||
setup_external_socket(socket);
|
||||
|
||||
if (!remote_pubkey.empty()) {
|
||||
socket.setsockopt(ZMQ_CURVE_SERVERKEY, remote_pubkey.data(), remote_pubkey.size());
|
||||
socket.setsockopt(ZMQ_CURVE_PUBLICKEY, pubkey.data(), pubkey.size());
|
||||
socket.setsockopt(ZMQ_CURVE_SECRETKEY, privkey.data(), privkey.size());
|
||||
socket.set(zmq::sockopt::curve_serverkey, remote_pubkey);
|
||||
socket.set(zmq::sockopt::curve_publickey, pubkey);
|
||||
socket.set(zmq::sockopt::curve_secretkey, privkey);
|
||||
}
|
||||
|
||||
if (PUBKEY_BASED_ROUTING_ID) {
|
||||
if (!use_ephemeral_routing_id) {
|
||||
std::string routing_id;
|
||||
routing_id.reserve(33);
|
||||
routing_id += 'L'; // Prefix because routing id's starting with \0 are reserved by zmq (and our pubkey might start with \0)
|
||||
routing_id.append(pubkey.begin(), pubkey.end());
|
||||
socket.setsockopt(ZMQ_ROUTING_ID, routing_id.data(), routing_id.size());
|
||||
socket.set(zmq::sockopt::routing_id, routing_id);
|
||||
}
|
||||
// else let ZMQ pick a random one
|
||||
}
|
||||
|
||||
ConnectionID LokiMQ::connect_sn(string_view pubkey, std::chrono::milliseconds keep_alive, string_view hint) {
|
||||
if (!proxy_thread.joinable())
|
||||
throw std::logic_error("Cannot call connect_sn() before calling `start()`");
|
||||
|
||||
detail::send_control(get_control_socket(), "CONNECT_SN", bt_serialize<bt_dict>({{"pubkey",pubkey}, {"keep_alive",keep_alive.count()}, {"hint",hint}}));
|
||||
void OxenMQ::setup_incoming_socket(zmq::socket_t& listener, bool curve, std::string_view pubkey, std::string_view privkey, size_t bind_index) {
|
||||
|
||||
return pubkey;
|
||||
setup_external_socket(listener);
|
||||
|
||||
listener.set(zmq::sockopt::zap_domain, oxenc::bt_serialize(bind_index));
|
||||
if (curve) {
|
||||
listener.set(zmq::sockopt::curve_server, true);
|
||||
listener.set(zmq::sockopt::curve_publickey, pubkey);
|
||||
listener.set(zmq::sockopt::curve_secretkey, privkey);
|
||||
}
|
||||
listener.set(zmq::sockopt::router_handover, true);
|
||||
listener.set(zmq::sockopt::router_mandatory, true);
|
||||
}
|
||||
|
||||
ConnectionID LokiMQ::connect_remote(string_view remote, ConnectSuccess on_connect, ConnectFailure on_failure,
|
||||
string_view pubkey, AuthLevel auth_level, std::chrono::milliseconds timeout) {
|
||||
if (!proxy_thread.joinable())
|
||||
throw std::logic_error("Cannot call connect_remote() before calling `start()`");
|
||||
|
||||
if (remote.size() < 7 || !(remote.substr(0, 6) == "tcp://" || remote.substr(0, 6) == "ipc://" /* unix domain sockets */))
|
||||
throw std::runtime_error("Invalid connect_remote: remote address '" + std::string{remote} + "' is not a valid or supported zmq connect string");
|
||||
|
||||
auto id = next_conn_id++;
|
||||
LMQ_TRACE("telling proxy to connect to ", remote, ", id ", id,
|
||||
pubkey.empty() ? "using NULL auth" : ", using CURVE with remote pubkey [" + to_hex(pubkey) + "]");
|
||||
detail::send_control(get_control_socket(), "CONNECT_REMOTE", bt_serialize<bt_dict>({
|
||||
{"auth_level", static_cast<std::underlying_type_t<AuthLevel>>(auth_level)},
|
||||
{"conn_id", id},
|
||||
{"connect", detail::serialize_object(std::move(on_connect))},
|
||||
{"failure", detail::serialize_object(std::move(on_failure))},
|
||||
{"pubkey", pubkey},
|
||||
{"remote", remote},
|
||||
{"timeout", timeout.count()},
|
||||
}));
|
||||
|
||||
return id;
|
||||
// Deprecated versions:
|
||||
ConnectionID OxenMQ::connect_remote(std::string_view remote, ConnectSuccess on_connect,
|
||||
ConnectFailure on_failure, AuthLevel auth_level, std::chrono::milliseconds timeout) {
|
||||
return connect_remote(address{remote}, std::move(on_connect), std::move(on_failure),
|
||||
auth_level, connect_option::timeout{timeout});
|
||||
}
|
||||
|
||||
void LokiMQ::disconnect(ConnectionID id, std::chrono::milliseconds linger) {
|
||||
detail::send_control(get_control_socket(), "DISCONNECT", bt_serialize<bt_dict>({
|
||||
ConnectionID OxenMQ::connect_remote(std::string_view remote, ConnectSuccess on_connect,
|
||||
ConnectFailure on_failure, std::string_view pubkey, AuthLevel auth_level,
|
||||
std::chrono::milliseconds timeout) {
|
||||
return connect_remote(address{remote}.set_pubkey(pubkey), std::move(on_connect),
|
||||
std::move(on_failure), auth_level, connect_option::timeout{timeout});
|
||||
}
|
||||
|
||||
void OxenMQ::disconnect(ConnectionID id, std::chrono::milliseconds linger) {
|
||||
detail::send_control(get_control_socket(), "DISCONNECT", oxenc::bt_serialize<oxenc::bt_dict>({
|
||||
{"conn_id", id.id},
|
||||
{"linger_ms", linger.count()},
|
||||
{"pubkey", id.pk},
|
||||
|
@ -109,7 +143,7 @@ void LokiMQ::disconnect(ConnectionID id, std::chrono::milliseconds linger) {
|
|||
}
|
||||
|
||||
std::pair<zmq::socket_t *, std::string>
|
||||
LokiMQ::proxy_connect_sn(string_view remote, string_view connect_hint, bool optional, bool incoming_only, bool outgoing_only, std::chrono::milliseconds keep_alive) {
|
||||
OxenMQ::proxy_connect_sn(std::string_view remote, std::string_view connect_hint, bool optional, bool incoming_only, bool outgoing_only, bool use_ephemeral_routing_id, std::chrono::milliseconds keep_alive) {
|
||||
ConnectionID remote_cid{remote};
|
||||
auto its = peers.equal_range(remote_cid);
|
||||
peer_info* peer = nullptr;
|
||||
|
@ -123,23 +157,23 @@ LokiMQ::proxy_connect_sn(string_view remote, string_view connect_hint, bool opti
|
|||
}
|
||||
|
||||
if (peer) {
|
||||
LMQ_TRACE("proxy asked to connect to ", to_hex(remote), "; reusing existing connection");
|
||||
OMQ_TRACE("proxy asked to connect to ", oxenc::to_hex(remote), "; reusing existing connection");
|
||||
if (peer->route.empty() /* == outgoing*/) {
|
||||
if (peer->idle_expiry < keep_alive) {
|
||||
LMQ_LOG(debug, "updating existing outgoing peer connection idle expiry time from ",
|
||||
OMQ_LOG(debug, "updating existing outgoing peer connection idle expiry time from ",
|
||||
peer->idle_expiry.count(), "ms to ", keep_alive.count(), "ms");
|
||||
peer->idle_expiry = keep_alive;
|
||||
}
|
||||
peer->activity();
|
||||
}
|
||||
return {&connections[peer->conn_index], peer->route};
|
||||
return {&connections[peer->conn_id], peer->route};
|
||||
} else if (optional || incoming_only) {
|
||||
LMQ_LOG(debug, "proxy asked for optional or incoming connection, but no appropriate connection exists so aborting connection attempt");
|
||||
OMQ_LOG(debug, "proxy asked for optional or incoming connection, but no appropriate connection exists so aborting connection attempt");
|
||||
return {nullptr, ""s};
|
||||
}
|
||||
|
||||
// No connection so establish a new one
|
||||
LMQ_LOG(debug, "proxy establishing new outbound connection to ", to_hex(remote));
|
||||
OMQ_LOG(debug, "proxy establishing new outbound connection to ", oxenc::to_hex(remote));
|
||||
std::string addr;
|
||||
bool to_self = false && remote == pubkey; // FIXME; need to use a separate listening socket for this, otherwise we can't easily
|
||||
// tell it wasn't from a remote.
|
||||
|
@ -151,45 +185,48 @@ LokiMQ::proxy_connect_sn(string_view remote, string_view connect_hint, bool opti
|
|||
if (addr.empty())
|
||||
addr = sn_lookup(remote);
|
||||
else
|
||||
LMQ_LOG(debug, "using connection hint ", connect_hint);
|
||||
OMQ_LOG(debug, "using connection hint ", connect_hint);
|
||||
|
||||
if (addr.empty()) {
|
||||
LMQ_LOG(error, "peer lookup failed for ", to_hex(remote));
|
||||
OMQ_LOG(error, "peer lookup failed for ", oxenc::to_hex(remote));
|
||||
return {nullptr, ""s};
|
||||
}
|
||||
}
|
||||
|
||||
LMQ_LOG(debug, to_hex(pubkey), " (me) connecting to ", addr, " to reach ", to_hex(remote));
|
||||
zmq::socket_t socket{context, zmq::socket_type::dealer};
|
||||
setup_outgoing_socket(socket, remote);
|
||||
OMQ_LOG(debug, oxenc::to_hex(pubkey), " (me) connecting to ", addr, " to reach ", oxenc::to_hex(remote));
|
||||
std::optional<zmq::socket_t> socket;
|
||||
try {
|
||||
socket.connect(addr);
|
||||
socket.emplace(context, zmq::socket_type::dealer);
|
||||
setup_outgoing_socket(*socket, remote, use_ephemeral_routing_id);
|
||||
socket->connect(addr);
|
||||
} catch (const zmq::error_t& e) {
|
||||
// Note that this failure cases indicates something serious went wrong that means zmq isn't
|
||||
// even going to try connecting (for example an unparseable remote address).
|
||||
LMQ_LOG(error, "Outgoing connection to ", addr, " failed: ", e.what());
|
||||
OMQ_LOG(error, "Outgoing connection to ", addr, " failed: ", e.what());
|
||||
return {nullptr, ""s};
|
||||
}
|
||||
peer_info p{};
|
||||
|
||||
auto& p = peers.emplace(std::move(remote_cid), peer_info{})->second;
|
||||
p.service_node = true;
|
||||
p.pubkey = std::string{remote};
|
||||
p.conn_index = connections.size();
|
||||
p.conn_id = next_conn_id++;
|
||||
p.idle_expiry = keep_alive;
|
||||
p.activity();
|
||||
conn_index_to_id.push_back(remote_cid);
|
||||
peers.emplace(std::move(remote_cid), std::move(p));
|
||||
connections.push_back(std::move(socket));
|
||||
pollitems_stale = true;
|
||||
connections_updated = true;
|
||||
outgoing_sn_conns.emplace_hint(outgoing_sn_conns.end(), p.conn_id, ConnectionID{remote});
|
||||
auto it = connections.emplace_hint(connections.end(), p.conn_id, *std::move(socket));
|
||||
|
||||
return {&connections.back(), ""s};
|
||||
return {&it->second, ""s};
|
||||
}
|
||||
|
||||
std::pair<zmq::socket_t *, std::string> LokiMQ::proxy_connect_sn(bt_dict_consumer data) {
|
||||
string_view hint, remote_pk;
|
||||
std::pair<zmq::socket_t *, std::string> OxenMQ::proxy_connect_sn(oxenc::bt_dict_consumer data) {
|
||||
std::string_view hint, remote_pk;
|
||||
std::chrono::milliseconds keep_alive;
|
||||
bool optional = false, incoming_only = false, outgoing_only = false;
|
||||
bool optional = false, incoming_only = false, outgoing_only = false, ephemeral_rid = EPHEMERAL_ROUTING_ID;
|
||||
|
||||
// Alphabetical order
|
||||
if (data.skip_until("ephemeral_rid"))
|
||||
ephemeral_rid = data.consume_integer<bool>();
|
||||
if (data.skip_until("hint"))
|
||||
hint = data.consume_string_view();
|
||||
if (data.skip_until("incoming"))
|
||||
|
@ -204,57 +241,39 @@ std::pair<zmq::socket_t *, std::string> LokiMQ::proxy_connect_sn(bt_dict_consume
|
|||
throw std::runtime_error("Internal error: Invalid proxy_connect_sn command; pubkey missing");
|
||||
remote_pk = data.consume_string_view();
|
||||
|
||||
return proxy_connect_sn(remote_pk, hint, optional, incoming_only, outgoing_only, keep_alive);
|
||||
}
|
||||
|
||||
template <typename Container, typename AccessIndex>
|
||||
void update_connection_indices(Container& c, size_t index, AccessIndex get_index) {
|
||||
for (auto it = c.begin(); it != c.end(); ) {
|
||||
size_t& i = get_index(*it);
|
||||
if (index == i) {
|
||||
it = c.erase(it);
|
||||
continue;
|
||||
}
|
||||
if (i > index)
|
||||
--i;
|
||||
++it;
|
||||
}
|
||||
return proxy_connect_sn(remote_pk, hint, optional, incoming_only, outgoing_only, ephemeral_rid, keep_alive);
|
||||
}
|
||||
|
||||
/// Closes outgoing connections and removes all references. Note that this will call `erase()`
|
||||
/// which can invalidate iterators on the various connection containers - if you don't want that,
|
||||
/// delete it first so that the container won't contain the element being deleted.
|
||||
void LokiMQ::proxy_close_connection(size_t index, std::chrono::milliseconds linger) {
|
||||
connections[index].setsockopt<int>(ZMQ_LINGER, linger > 0ms ? linger.count() : 0);
|
||||
pollitems_stale = true;
|
||||
connections.erase(connections.begin() + index);
|
||||
void OxenMQ::proxy_close_connection(int64_t id, std::chrono::milliseconds linger) {
|
||||
auto it = connections.find(id);
|
||||
if (it == connections.end()) {
|
||||
OMQ_LOG(warn, "internal error: connection to close (", id, ") doesn't exist!");
|
||||
return;
|
||||
}
|
||||
OMQ_LOG(debug, "Closing conn ", id);
|
||||
it->second.set(zmq::sockopt::linger, linger > 0ms ? (int) linger.count() : 0);
|
||||
connections.erase(it);
|
||||
connections_updated = true;
|
||||
|
||||
LMQ_LOG(debug, "Closing conn index ", index);
|
||||
update_connection_indices(peers, index,
|
||||
[](auto& p) -> size_t& { return p.second.conn_index; });
|
||||
update_connection_indices(pending_connects, index,
|
||||
[](auto& pc) -> size_t& { return std::get<size_t>(pc); });
|
||||
update_connection_indices(bind, index,
|
||||
[](auto& b) -> size_t& { return b.second.index; });
|
||||
update_connection_indices(incoming_conn_index, index,
|
||||
[](auto& oci) -> size_t& { return oci.second; });
|
||||
assert(index < conn_index_to_id.size());
|
||||
conn_index_to_id.erase(conn_index_to_id.begin() + index);
|
||||
outgoing_sn_conns.erase(id);
|
||||
}
|
||||
|
||||
void LokiMQ::proxy_expire_idle_peers() {
|
||||
void OxenMQ::proxy_expire_idle_peers() {
|
||||
for (auto it = peers.begin(); it != peers.end(); ) {
|
||||
auto &info = it->second;
|
||||
if (info.outgoing()) {
|
||||
auto idle = std::chrono::steady_clock::now() - info.last_activity;
|
||||
if (idle > info.idle_expiry) {
|
||||
LMQ_LOG(debug, "Closing outgoing connection to ", it->first, ": idle time (",
|
||||
OMQ_LOG(debug, "Closing outgoing connection to ", it->first, ": idle time (",
|
||||
std::chrono::duration_cast<std::chrono::milliseconds>(idle).count(), "ms) reached connection timeout (",
|
||||
info.idle_expiry.count(), "ms)");
|
||||
++it; // The below is going to delete our current element
|
||||
proxy_close_connection(info.conn_index, CLOSE_LINGER);
|
||||
proxy_close_connection(info.conn_id, CLOSE_LINGER);
|
||||
it = peers.erase(it);
|
||||
} else {
|
||||
LMQ_LOG(trace, "Not closing ", it->first, ": ", std::chrono::duration_cast<std::chrono::milliseconds>(idle).count(),
|
||||
OMQ_LOG(trace, "Not closing ", it->first, ": ", std::chrono::duration_cast<std::chrono::milliseconds>(idle).count(),
|
||||
"ms <= ", info.idle_expiry.count(), "ms");
|
||||
++it;
|
||||
continue;
|
||||
|
@ -265,36 +284,37 @@ void LokiMQ::proxy_expire_idle_peers() {
|
|||
}
|
||||
}
|
||||
|
||||
void LokiMQ::proxy_conn_cleanup() {
|
||||
LMQ_TRACE("starting proxy connections cleanup");
|
||||
void OxenMQ::proxy_conn_cleanup() {
|
||||
OMQ_TRACE("starting proxy connections cleanup");
|
||||
|
||||
// Drop idle connections (if we haven't done it in a while)
|
||||
LMQ_TRACE("closing idle connections");
|
||||
OMQ_TRACE("closing idle connections");
|
||||
proxy_expire_idle_peers();
|
||||
|
||||
auto now = std::chrono::steady_clock::now();
|
||||
|
||||
// FIXME - check other outgoing connections to see if they died and if so purge them
|
||||
|
||||
LMQ_TRACE("Timing out pending outgoing connections");
|
||||
OMQ_TRACE("Timing out pending outgoing connections");
|
||||
// Check any pending outgoing connections for timeout
|
||||
for (auto it = pending_connects.begin(); it != pending_connects.end(); ) {
|
||||
auto& pc = *it;
|
||||
if (std::get<std::chrono::steady_clock::time_point>(pc) < now) {
|
||||
job([cid = ConnectionID{std::get<long long>(pc)}, callback = std::move(std::get<ConnectFailure>(pc))] { callback(cid, "connection attempt timed out"); });
|
||||
auto id = std::get<int64_t>(pc);
|
||||
job([cid = ConnectionID{id}, callback = std::move(std::get<ConnectFailure>(pc))] { callback(cid, "connection attempt timed out"); });
|
||||
it = pending_connects.erase(it); // Don't let the below erase it (because it invalidates iterators)
|
||||
proxy_close_connection(std::get<size_t>(pc), CLOSE_LINGER);
|
||||
proxy_close_connection(id, CLOSE_LINGER);
|
||||
} else {
|
||||
++it;
|
||||
}
|
||||
}
|
||||
|
||||
LMQ_TRACE("Timing out pending requests");
|
||||
OMQ_TRACE("Timing out pending requests");
|
||||
// Remove any expired pending requests and schedule their callback with a failure
|
||||
for (auto it = pending_requests.begin(); it != pending_requests.end(); ) {
|
||||
auto& callback = it->second;
|
||||
if (callback.first < now) {
|
||||
LMQ_LOG(debug, "pending request ", to_hex(it->first), " expired, invoking callback with failure status and removing");
|
||||
OMQ_LOG(debug, "pending request ", oxenc::to_hex(it->first), " expired, invoking callback with failure status and removing");
|
||||
job([callback = std::move(callback.second)] { callback(false, {{"TIMEOUT"s}}); });
|
||||
it = pending_requests.erase(it);
|
||||
} else {
|
||||
|
@ -302,10 +322,10 @@ void LokiMQ::proxy_conn_cleanup() {
|
|||
}
|
||||
}
|
||||
|
||||
LMQ_TRACE("done proxy connections cleanup");
|
||||
OMQ_TRACE("done proxy connections cleanup");
|
||||
};
|
||||
|
||||
void LokiMQ::proxy_connect_remote(bt_dict_consumer data) {
|
||||
void OxenMQ::proxy_connect_remote(oxenc::bt_dict_consumer data) {
|
||||
AuthLevel auth_level = AuthLevel::none;
|
||||
long long conn_id = -1;
|
||||
ConnectSuccess on_connect;
|
||||
|
@ -313,17 +333,18 @@ void LokiMQ::proxy_connect_remote(bt_dict_consumer data) {
|
|||
std::string remote;
|
||||
std::string remote_pubkey;
|
||||
std::chrono::milliseconds timeout = REMOTE_CONNECT_TIMEOUT;
|
||||
bool ephemeral_rid = EPHEMERAL_ROUTING_ID;
|
||||
|
||||
if (data.skip_until("auth_level"))
|
||||
auth_level = static_cast<AuthLevel>(data.consume_integer<std::underlying_type_t<AuthLevel>>());
|
||||
if (data.skip_until("conn_id"))
|
||||
conn_id = data.consume_integer<long long>();
|
||||
if (data.skip_until("connect")) {
|
||||
if (data.skip_until("connect"))
|
||||
on_connect = detail::deserialize_object<ConnectSuccess>(data.consume_integer<uintptr_t>());
|
||||
}
|
||||
if (data.skip_until("failure")) {
|
||||
if (data.skip_until("ephemeral_rid"))
|
||||
ephemeral_rid = data.consume_integer<bool>();
|
||||
if (data.skip_until("failure"))
|
||||
on_failure = detail::deserialize_object<ConnectFailure>(data.consume_integer<uintptr_t>());
|
||||
}
|
||||
if (data.skip_until("pubkey")) {
|
||||
remote_pubkey = data.consume_string();
|
||||
assert(remote_pubkey.size() == 32 || remote_pubkey.empty());
|
||||
|
@ -336,14 +357,14 @@ void LokiMQ::proxy_connect_remote(bt_dict_consumer data) {
|
|||
if (conn_id == -1 || remote.empty())
|
||||
throw std::runtime_error("Internal error: CONNECT_REMOTE proxy command missing required 'conn_id' and/or 'remote' value");
|
||||
|
||||
LMQ_LOG(debug, "Establishing remote connection to ", remote, remote_pubkey.empty() ? " (NULL auth)" : " via CURVE expecting pubkey " + to_hex(remote_pubkey));
|
||||
OMQ_LOG(debug, "Establishing remote connection to ", remote,
|
||||
remote_pubkey.empty() ? " (NULL auth)" : " via CURVE expecting pubkey " + oxenc::to_hex(remote_pubkey));
|
||||
|
||||
assert(conn_index_to_id.size() == connections.size());
|
||||
|
||||
zmq::socket_t sock{context, zmq::socket_type::dealer};
|
||||
std::optional<zmq::socket_t> sock;
|
||||
try {
|
||||
setup_outgoing_socket(sock, remote_pubkey);
|
||||
sock.connect(remote);
|
||||
sock.emplace(context, zmq::socket_type::dealer);
|
||||
setup_outgoing_socket(*sock, remote_pubkey, ephemeral_rid);
|
||||
sock->connect(remote);
|
||||
} catch (const zmq::error_t &e) {
|
||||
proxy_schedule_reply_job([conn_id, on_failure=std::move(on_failure), what="connect() failed: "s+e.what()] {
|
||||
on_failure(conn_id, std::move(what));
|
||||
|
@ -351,26 +372,22 @@ void LokiMQ::proxy_connect_remote(bt_dict_consumer data) {
|
|||
return;
|
||||
}
|
||||
|
||||
connections.push_back(std::move(sock));
|
||||
pollitems_stale = true;
|
||||
LMQ_LOG(debug, "Opened new zmq socket to ", remote, ", conn_id ", conn_id, "; sending HI");
|
||||
send_direct_message(connections.back(), "HI");
|
||||
pending_connects.emplace_back(connections.size()-1, conn_id, std::chrono::steady_clock::now() + timeout,
|
||||
auto &s = connections.emplace_hint(connections.end(), conn_id, std::move(*sock))->second;
|
||||
connections_updated = true;
|
||||
OMQ_LOG(debug, "Opened new zmq socket to ", remote, ", conn_id ", conn_id, "; sending HI");
|
||||
send_direct_message(s, "HI");
|
||||
pending_connects.emplace_back(conn_id, std::chrono::steady_clock::now() + timeout,
|
||||
std::move(on_connect), std::move(on_failure));
|
||||
peer_info peer;
|
||||
auto& peer = peers.emplace(ConnectionID{conn_id, remote_pubkey}, peer_info{})->second;
|
||||
peer.pubkey = std::move(remote_pubkey);
|
||||
peer.service_node = false;
|
||||
peer.auth_level = auth_level;
|
||||
peer.conn_index = connections.size() - 1;
|
||||
ConnectionID conn{conn_id, peer.pubkey};
|
||||
conn_index_to_id.push_back(conn);
|
||||
assert(connections.size() == conn_index_to_id.size());
|
||||
peer.conn_id = conn_id;
|
||||
peer.idle_expiry = 24h * 10 * 365; // "forever"
|
||||
peer.activity();
|
||||
peers.emplace(std::move(conn), std::move(peer));
|
||||
}
|
||||
|
||||
void LokiMQ::proxy_disconnect(bt_dict_consumer data) {
|
||||
void OxenMQ::proxy_disconnect(oxenc::bt_dict_consumer data) {
|
||||
ConnectionID connid{-1};
|
||||
std::chrono::milliseconds linger = 1s;
|
||||
|
||||
|
@ -386,21 +403,27 @@ void LokiMQ::proxy_disconnect(bt_dict_consumer data) {
|
|||
|
||||
proxy_disconnect(std::move(connid), linger);
|
||||
}
|
||||
void LokiMQ::proxy_disconnect(ConnectionID conn, std::chrono::milliseconds linger) {
|
||||
LMQ_TRACE("Disconnecting outgoing connection to ", conn);
|
||||
void OxenMQ::proxy_disconnect(ConnectionID conn, std::chrono::milliseconds linger) {
|
||||
OMQ_TRACE("Disconnecting outgoing connection to ", conn);
|
||||
auto pr = peers.equal_range(conn);
|
||||
for (auto it = pr.first; it != pr.second; ++it) {
|
||||
auto& peer = it->second;
|
||||
if (peer.outgoing()) {
|
||||
LMQ_LOG(debug, "Closing outgoing connection to ", conn);
|
||||
proxy_close_connection(peer.conn_index, linger);
|
||||
OMQ_LOG(debug, "Closing outgoing connection to ", conn);
|
||||
proxy_close_connection(peer.conn_id, linger);
|
||||
peers.erase(it);
|
||||
return;
|
||||
}
|
||||
}
|
||||
LMQ_LOG(warn, "Failed to disconnect ", conn, ": no such outgoing connection");
|
||||
OMQ_LOG(warn, "Failed to disconnect ", conn, ": no such outgoing connection");
|
||||
}
|
||||
|
||||
std::string ConnectionID::to_string() const {
|
||||
if (!pk.empty())
|
||||
return (sn() ? std::string("SN ") : std::string("non-SN authenticated remote ")) + oxenc::to_hex(pk);
|
||||
else
|
||||
return std::string("unauthenticated remote [") + std::to_string(id) + "]";
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
}
|
|
@ -1,15 +1,20 @@
|
|||
#pragma once
|
||||
#include "auth.h"
|
||||
#include "string_view.h"
|
||||
#include <oxenc/bt_value.h>
|
||||
#include <string_view>
|
||||
#include <iosfwd>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
|
||||
namespace lokimq {
|
||||
namespace oxenmq {
|
||||
|
||||
class bt_dict;
|
||||
struct ConnectionID;
|
||||
|
||||
namespace detail {
|
||||
template <typename... T>
|
||||
bt_dict build_send(ConnectionID to, string_view cmd, T&&... opts);
|
||||
oxenc::bt_dict build_send(ConnectionID to, std::string_view cmd, T&&... opts);
|
||||
}
|
||||
|
||||
/// Opaque data structure representing a connection which supports ==, !=, < and std::hash. For
|
||||
|
@ -17,11 +22,17 @@ bt_dict build_send(ConnectionID to, string_view cmd, T&&... opts);
|
|||
/// anywhere a ConnectionID is called for). For non-SN remote connections you need to keep a copy
|
||||
/// of the ConnectionID returned by connect_remote().
|
||||
struct ConnectionID {
|
||||
// Default construction; creates a ConnectionID with an invalid internal ID that will not match
|
||||
// an actual connection.
|
||||
ConnectionID() : ConnectionID(0) {}
|
||||
// Construction from a service node pubkey
|
||||
ConnectionID(std::string pubkey_) : id{SN_ID}, pk{std::move(pubkey_)} {
|
||||
if (pk.size() != 32)
|
||||
throw std::runtime_error{"Invalid pubkey: expected 32 bytes"};
|
||||
}
|
||||
ConnectionID(string_view pubkey_) : ConnectionID(std::string{pubkey_}) {}
|
||||
// Construction from a service node pubkey
|
||||
ConnectionID(std::string_view pubkey_) : ConnectionID(std::string{pubkey_}) {}
|
||||
|
||||
ConnectionID(const ConnectionID&) = default;
|
||||
ConnectionID(ConnectionID&&) = default;
|
||||
ConnectionID& operator=(const ConnectionID&) = default;
|
||||
|
@ -33,52 +44,55 @@ struct ConnectionID {
|
|||
}
|
||||
|
||||
// Two ConnectionIDs are equal if they are both SNs and have matching pubkeys, or they are both
|
||||
// not SNs and have matching internal IDs. (Pubkeys do not have to match for non-SNs, and
|
||||
// routes are not considered for equality at all).
|
||||
// not SNs and have matching internal IDs and routes. (Pubkeys do not have to match for
|
||||
// non-SNs).
|
||||
bool operator==(const ConnectionID &o) const {
|
||||
if (id == SN_ID && o.id == SN_ID)
|
||||
if (sn() && o.sn())
|
||||
return pk == o.pk;
|
||||
return id == o.id;
|
||||
return id == o.id && route == o.route;
|
||||
}
|
||||
bool operator!=(const ConnectionID &o) const { return !(*this == o); }
|
||||
bool operator<(const ConnectionID &o) const {
|
||||
if (id == SN_ID && o.id == SN_ID)
|
||||
if (sn() && o.sn())
|
||||
return pk < o.pk;
|
||||
return id < o.id;
|
||||
return id < o.id || (id == o.id && route < o.route);
|
||||
}
|
||||
|
||||
// Returns true if this ConnectionID represents a SN connection
|
||||
bool sn() const { return id == SN_ID; }
|
||||
|
||||
// Returns this connection's pubkey, if any. (Note that it is possible to have a pubkey and not
|
||||
// be a SN when connecting to secure remotes: having a non-empty pubkey does not imply that
|
||||
// `sn()` is true).
|
||||
// Returns this connection's pubkey, if any. (Note that all curve connections have pubkeys, not
|
||||
// only SNs).
|
||||
const std::string& pubkey() const { return pk; }
|
||||
// Default construction; creates a ConnectionID with an invalid internal ID that will not match
|
||||
// an actual connection.
|
||||
ConnectionID() : ConnectionID(0) {}
|
||||
|
||||
// Returns a copy of the ConnectionID with the route set to empty.
|
||||
ConnectionID unrouted() { return ConnectionID{id, pk, ""}; }
|
||||
|
||||
std::string to_string() const;
|
||||
|
||||
|
||||
private:
|
||||
ConnectionID(long long id) : id{id} {}
|
||||
ConnectionID(long long id, std::string pubkey, std::string route = "")
|
||||
ConnectionID(int64_t id) : id{id} {}
|
||||
ConnectionID(int64_t id, std::string pubkey, std::string route = "")
|
||||
: id{id}, pk{std::move(pubkey)}, route{std::move(route)} {}
|
||||
|
||||
constexpr static long long SN_ID = -1;
|
||||
long long id = 0;
|
||||
constexpr static int64_t SN_ID = -1;
|
||||
int64_t id = 0;
|
||||
std::string pk;
|
||||
std::string route;
|
||||
friend class LokiMQ;
|
||||
friend class OxenMQ;
|
||||
friend struct std::hash<ConnectionID>;
|
||||
template <typename... T>
|
||||
friend bt_dict detail::build_send(ConnectionID to, string_view cmd, T&&... opts);
|
||||
friend oxenc::bt_dict detail::build_send(ConnectionID to, std::string_view cmd, T&&... opts);
|
||||
friend std::ostream& operator<<(std::ostream& o, const ConnectionID& conn);
|
||||
};
|
||||
|
||||
} // namespace lokimq
|
||||
} // namespace oxenmq
|
||||
namespace std {
|
||||
template <> struct hash<lokimq::ConnectionID> {
|
||||
size_t operator()(const lokimq::ConnectionID &c) const {
|
||||
return c.sn() ? lokimq::already_hashed{}(c.pk) :
|
||||
std::hash<long long>{}(c.id);
|
||||
template <> struct hash<oxenmq::ConnectionID> {
|
||||
size_t operator()(const oxenmq::ConnectionID &c) const {
|
||||
return c.sn() ? oxenmq::already_hashed{}(c.pk) :
|
||||
std::hash<int64_t>{}(c.id) + std::hash<std::string>{}(c.route);
|
||||
}
|
||||
};
|
||||
} // namespace std
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
#pragma once
|
||||
|
||||
#include <fmt/format.h>
|
||||
#include "connections.h"
|
||||
#include "auth.h"
|
||||
#include "address.h"
|
||||
|
||||
template <>
|
||||
struct fmt::formatter<oxenmq::AuthLevel> : fmt::formatter<std::string> {
|
||||
auto format(oxenmq::AuthLevel v, format_context& ctx) {
|
||||
return formatter<std::string>::format(
|
||||
fmt::format("{}", to_string(v)), ctx);
|
||||
}
|
||||
};
|
||||
template <>
|
||||
struct fmt::formatter<oxenmq::ConnectionID> : fmt::formatter<std::string> {
|
||||
auto format(oxenmq::ConnectionID conn, format_context& ctx) {
|
||||
return formatter<std::string>::format(
|
||||
fmt::format("{}", conn.to_string()), ctx);
|
||||
}
|
||||
};
|
||||
template <>
|
||||
struct fmt::formatter<oxenmq::address> : fmt::formatter<std::string> {
|
||||
auto format(oxenmq::address addr, format_context& ctx) {
|
||||
return formatter<std::string>::format(
|
||||
fmt::format("{}", addr.full_address()), ctx);
|
||||
}
|
||||
};
|
|
@ -0,0 +1,182 @@
|
|||
#include "oxenmq.h"
|
||||
#include "batch.h"
|
||||
#include "oxenmq-internal.h"
|
||||
|
||||
namespace oxenmq {
|
||||
|
||||
void OxenMQ::proxy_batch(detail::Batch* batch) {
|
||||
const auto [jobs, tagged_threads] = batch->size();
|
||||
OMQ_TRACE("proxy queuing batch job with ", jobs, " jobs", tagged_threads ? " (job uses tagged thread(s))" : "");
|
||||
if (!tagged_threads) {
|
||||
for (size_t i = 0; i < jobs; i++)
|
||||
batch_jobs.emplace_back(batch, i);
|
||||
} else {
|
||||
// Some (or all) jobs have a specific thread target so queue any such jobs in the tagged
|
||||
// worker queue.
|
||||
auto threads = batch->threads();
|
||||
for (size_t i = 0; i < jobs; i++) {
|
||||
auto& jobs = threads[i] > 0
|
||||
? std::get<batch_queue>(tagged_workers[threads[i] - 1])
|
||||
: batch_jobs;
|
||||
jobs.emplace_back(batch, i);
|
||||
}
|
||||
}
|
||||
|
||||
proxy_skip_one_poll = true;
|
||||
}
|
||||
|
||||
void OxenMQ::job(std::function<void()> f, std::optional<TaggedThreadID> thread) {
|
||||
if (thread && thread->_id == -1)
|
||||
throw std::logic_error{"job() cannot be used to queue an in-proxy job"};
|
||||
auto* j = new Job(std::move(f), thread);
|
||||
auto* baseptr = static_cast<detail::Batch*>(j);
|
||||
detail::send_control(get_control_socket(), "BATCH", oxenc::bt_serialize(reinterpret_cast<uintptr_t>(baseptr)));
|
||||
}
|
||||
|
||||
void OxenMQ::proxy_schedule_reply_job(std::function<void()> f) {
|
||||
auto* j = new Job(std::move(f));
|
||||
reply_jobs.emplace_back(static_cast<detail::Batch*>(j), 0);
|
||||
proxy_skip_one_poll = true;
|
||||
}
|
||||
|
||||
void OxenMQ::proxy_run_batch_jobs(batch_queue& jobs, const int reserved, int& active, bool reply) {
|
||||
while (!jobs.empty() && active_workers() < max_workers &&
|
||||
(active < reserved || active_workers() < general_workers)) {
|
||||
proxy_run_worker(get_idle_worker().load(std::move(jobs.front()), reply));
|
||||
jobs.pop_front();
|
||||
active++;
|
||||
}
|
||||
}
|
||||
|
||||
// Called either within the proxy thread, or before the proxy thread has been created; actually adds
|
||||
// the timer. If the timer object hasn't been set up yet it gets set up here.
|
||||
void OxenMQ::proxy_timer(int id, std::function<void()> job, std::chrono::milliseconds interval, bool squelch, int thread) {
|
||||
if (!timers)
|
||||
timers.reset(zmq_timers_new());
|
||||
|
||||
int zmq_timer_id = zmq_timers_add(timers.get(),
|
||||
interval.count(),
|
||||
[](int timer_id, void* self) { static_cast<OxenMQ*>(self)->_queue_timer_job(timer_id); },
|
||||
this);
|
||||
if (zmq_timer_id == -1)
|
||||
throw zmq::error_t{};
|
||||
timer_jobs[zmq_timer_id] = { std::move(job), squelch, false, thread };
|
||||
timer_zmq_id[id] = zmq_timer_id;
|
||||
}
|
||||
|
||||
void OxenMQ::proxy_timer(oxenc::bt_list_consumer timer_data) {
|
||||
auto timer_id = timer_data.consume_integer<int>();
|
||||
std::unique_ptr<std::function<void()>> func{reinterpret_cast<std::function<void()>*>(timer_data.consume_integer<uintptr_t>())};
|
||||
auto interval = std::chrono::milliseconds{timer_data.consume_integer<uint64_t>()};
|
||||
auto squelch = timer_data.consume_integer<bool>();
|
||||
auto thread = timer_data.consume_integer<int>();
|
||||
if (!timer_data.is_finished())
|
||||
throw std::runtime_error("Internal error: proxied timer request contains unexpected data");
|
||||
proxy_timer(timer_id, std::move(*func), interval, squelch, thread);
|
||||
}
|
||||
|
||||
void OxenMQ::_queue_timer_job(int timer_id) {
|
||||
auto it = timer_jobs.find(timer_id);
|
||||
if (it == timer_jobs.end()) {
|
||||
OMQ_LOG(warn, "Could not find timer job ", timer_id);
|
||||
return;
|
||||
}
|
||||
auto& [func, squelch, running, thread] = it->second;
|
||||
if (squelch && running) {
|
||||
OMQ_LOG(debug, "Not running timer job ", timer_id, " because a job for that timer is still running");
|
||||
return;
|
||||
}
|
||||
|
||||
if (thread == -1) { // Run directly in proxy thread
|
||||
try { func(); }
|
||||
catch (const std::exception &e) { OMQ_LOG(warn, "timer job ", timer_id, " raised an exception: ", e.what()); }
|
||||
catch (...) { OMQ_LOG(warn, "timer job ", timer_id, " raised a non-std exception"); }
|
||||
return;
|
||||
}
|
||||
|
||||
detail::Batch* b;
|
||||
if (squelch) {
|
||||
auto* bv = new Batch<void>;
|
||||
bv->add_job(func, thread);
|
||||
running = true;
|
||||
bv->completion([this,timer_id](auto results) {
|
||||
try { results[0].get(); }
|
||||
catch (const std::exception &e) { OMQ_LOG(warn, "timer job ", timer_id, " raised an exception: ", e.what()); }
|
||||
catch (...) { OMQ_LOG(warn, "timer job ", timer_id, " raised a non-std exception"); }
|
||||
auto it = timer_jobs.find(timer_id);
|
||||
if (it != timer_jobs.end())
|
||||
it->second.running = false;
|
||||
}, OxenMQ::run_in_proxy);
|
||||
b = bv;
|
||||
} else {
|
||||
b = new Job(func, thread);
|
||||
}
|
||||
OMQ_TRACE("b: ", b->size().first, ", ", b->size().second, "; thread = ", thread);
|
||||
assert(b->size() == std::make_pair(size_t{1}, thread > 0));
|
||||
auto& queue = thread > 0
|
||||
? std::get<batch_queue>(tagged_workers[thread - 1])
|
||||
: batch_jobs;
|
||||
queue.emplace_back(static_cast<detail::Batch*>(b), 0);
|
||||
}
|
||||
|
||||
void OxenMQ::add_timer(TimerID& timer, std::function<void()> job, std::chrono::milliseconds interval, bool squelch, std::optional<TaggedThreadID> thread) {
|
||||
int th_id = thread ? thread->_id : 0;
|
||||
timer._id = next_timer_id++;
|
||||
if (proxy_thread.joinable()) {
|
||||
detail::send_control(get_control_socket(), "TIMER", oxenc::bt_serialize(oxenc::bt_list{{
|
||||
timer._id,
|
||||
detail::serialize_object(std::move(job)),
|
||||
interval.count(),
|
||||
squelch,
|
||||
th_id}}));
|
||||
} else {
|
||||
proxy_timer(timer._id, std::move(job), interval, squelch, th_id);
|
||||
}
|
||||
}
|
||||
|
||||
TimerID OxenMQ::add_timer(std::function<void()> job, std::chrono::milliseconds interval, bool squelch, std::optional<TaggedThreadID> thread) {
|
||||
TimerID tid;
|
||||
add_timer(tid, std::move(job), interval, squelch, std::move(thread));
|
||||
return tid;
|
||||
}
|
||||
|
||||
void OxenMQ::proxy_timer_del(int id) {
|
||||
if (!timers)
|
||||
return;
|
||||
auto it = timer_zmq_id.find(id);
|
||||
if (it == timer_zmq_id.end())
|
||||
return;
|
||||
zmq_timers_cancel(timers.get(), it->second);
|
||||
timer_zmq_id.erase(it);
|
||||
}
|
||||
|
||||
void OxenMQ::cancel_timer(TimerID timer_id) {
|
||||
if (proxy_thread.joinable()) {
|
||||
detail::send_control(get_control_socket(), "TIMER_DEL", oxenc::bt_serialize(timer_id._id));
|
||||
} else {
|
||||
proxy_timer_del(timer_id._id);
|
||||
}
|
||||
}
|
||||
|
||||
void OxenMQ::TimersDeleter::operator()(void* timers) { zmq_timers_destroy(&timers); }
|
||||
|
||||
TaggedThreadID OxenMQ::add_tagged_thread(std::string name, std::function<void()> start) {
|
||||
if (proxy_thread.joinable())
|
||||
throw std::logic_error{"Cannot add tagged threads after calling `start()`"};
|
||||
|
||||
if (name == "_proxy"sv || name.empty() || name.find('\0') != std::string::npos)
|
||||
throw std::logic_error{"Invalid tagged thread name `" + name + "'"};
|
||||
|
||||
auto& [run, busy, queue] = tagged_workers.emplace_back();
|
||||
busy = false;
|
||||
run.worker_id = tagged_workers.size(); // We want index + 1 (b/c 0 is used for non-tagged jobs)
|
||||
run.worker_routing_name = "t" + std::to_string(run.worker_id);
|
||||
run.worker_routing_id = "t" + std::string{reinterpret_cast<const char*>(&run.worker_id), sizeof(run.worker_id)};
|
||||
OMQ_TRACE("Created new tagged thread ", name, " with routing id ", run.worker_routing_name);
|
||||
|
||||
run.worker_thread = std::thread{&OxenMQ::worker_thread, this, run.worker_id, name, std::move(start)};
|
||||
|
||||
return TaggedThreadID{static_cast<int>(run.worker_id)};
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,106 @@
|
|||
#pragma once
|
||||
#include <vector>
|
||||
#include "connections.h"
|
||||
|
||||
namespace oxenmq {
|
||||
|
||||
class OxenMQ;
|
||||
|
||||
/// Encapsulates an incoming message from a remote connection with message details plus extra
|
||||
/// info need to send a reply back through the proxy thread via the `reply()` method. Note that
|
||||
/// this object gets reused: callbacks should use but not store any reference beyond the callback.
|
||||
class Message {
|
||||
public:
|
||||
OxenMQ& oxenmq; ///< The owning OxenMQ object
|
||||
std::vector<std::string_view> data; ///< The provided command data parts, if any.
|
||||
ConnectionID conn; ///< The connection info for routing a reply; also contains the pubkey/sn status.
|
||||
std::string reply_tag; ///< If the invoked command is a request command this is the required reply tag that will be prepended by `send_reply()`.
|
||||
Access access; ///< The access level of the invoker. This can be higher than the access level of the command, for example for an admin invoking a basic command.
|
||||
std::string remote; ///< Some sort of remote address from which the request came. Often "IP" for TCP connections and "localhost:UID:GID:PID" for unix socket connections.
|
||||
|
||||
/// Constructor
|
||||
Message(OxenMQ& omq, ConnectionID cid, Access access, std::string remote)
|
||||
: oxenmq{omq}, conn{std::move(cid)}, access{std::move(access)}, remote{std::move(remote)} {}
|
||||
|
||||
// Non-copyable
|
||||
Message(const Message&) = delete;
|
||||
Message& operator=(const Message&) = delete;
|
||||
|
||||
/// Sends a command back to whomever sent this message. Arguments are forwarded to send() but
|
||||
/// with send_option::optional{} added if the originator is not a SN. For SN messages (i.e.
|
||||
/// where `sn` is true) this is a "strong" reply by default in that the proxy will attempt to
|
||||
/// establish a new connection to the SN if no longer connected. For non-SN messages the reply
|
||||
/// will be attempted using the available routing information, but if the connection has already
|
||||
/// been closed the reply will be dropped.
|
||||
///
|
||||
/// If you want to send a non-strong reply even when the remote is a service node then add
|
||||
/// an explicit `send_option::optional()` argument.
|
||||
template <typename... Args>
|
||||
void send_back(std::string_view command, Args&&... args);
|
||||
|
||||
/// Sends a reply to a request. This takes no command: the command is always the built-in
|
||||
/// "REPLY" command, followed by the unique reply tag, then any reply data parts. All other
|
||||
/// arguments are as in `send_back()`. You should only send one reply for a command expecting
|
||||
/// replies, though this is not enforced: attempting to send multiple replies will simply be
|
||||
/// dropped when received by the remote. (Note, however, that it is possible to send multiple
|
||||
/// messages -- e.g. you could send a reply and then also call send_back() and/or send_request()
|
||||
/// to send more requests back to the sender).
|
||||
template <typename... Args>
|
||||
void send_reply(Args&&... args);
|
||||
|
||||
/// Sends a request back to whomever sent this message. This is effectively a wrapper around
|
||||
/// omq.request() that takes care of setting up the recipient arguments.
|
||||
template <typename ReplyCallback, typename... Args>
|
||||
void send_request(std::string_view command, ReplyCallback&& callback, Args&&... args);
|
||||
|
||||
/** Class returned by `send_later()` that can be used to call `back()`, `reply()`, or
|
||||
* `request()` beyond the lifetime of the Message instance as if calling `msg.send_back()`,
|
||||
* `msg.send_reply()`, or `msg.send_request()`. For example:
|
||||
*
|
||||
* auto send = msg.send_later();
|
||||
* // ... later, perhaps in a lambda or scheduled job:
|
||||
* send.reply("content");
|
||||
*
|
||||
* is equivalent to
|
||||
*
|
||||
* msg.send_reply("content");
|
||||
*
|
||||
* except that it is valid even if `msg` is no longer valid.
|
||||
*/
|
||||
class DeferredSend {
|
||||
public:
|
||||
OxenMQ& oxenmq; ///< The owning OxenMQ object
|
||||
ConnectionID conn; ///< The connection info for routing a reply; also contains the pubkey/sn status
|
||||
std::string reply_tag; ///< If the invoked command is a request command this is the required reply tag that will be prepended by `reply()`.
|
||||
|
||||
explicit DeferredSend(Message& m) : oxenmq{m.oxenmq}, conn{m.conn}, reply_tag{m.reply_tag} {}
|
||||
|
||||
template <typename... Args>
|
||||
void operator()(Args &&...args) const {
|
||||
if (reply_tag.empty())
|
||||
back(std::forward<Args>(args)...);
|
||||
else
|
||||
reply(std::forward<Args>(args)...);
|
||||
};
|
||||
|
||||
|
||||
/// Equivalent to msg.send_back(...), but can be invoked later.
|
||||
template <typename... Args>
|
||||
void back(std::string_view command, Args&&... args) const;
|
||||
|
||||
/// Equivalent to msg.send_reply(...), but can be invoked later.
|
||||
template <typename... Args>
|
||||
void reply(Args&&... args) const;
|
||||
|
||||
/// Equivalent to msg.send_request(...), but can be invoked later.
|
||||
template <typename ReplyCallback, typename... Args>
|
||||
void request(std::string_view command, ReplyCallback&& callback, Args&&... args) const;
|
||||
};
|
||||
|
||||
/// Returns a DeferredSend object that can be used to send replies to this message even if the
|
||||
/// message expires. Typically this is used when sending a reply requires waiting on another
|
||||
/// task to complete without needing to block the handler thread.
|
||||
DeferredSend send_later() { return DeferredSend{*this}; }
|
||||
};
|
||||
|
||||
}
|
|
@ -1,26 +1,34 @@
|
|||
#pragma once
|
||||
#include "lokimq.h"
|
||||
#include <limits>
|
||||
#include "oxenmq.h"
|
||||
|
||||
// Inside some method:
|
||||
// LMQ_LOG(warn, "bad ", 42, " stuff");
|
||||
// OMQ_LOG(warn, "bad ", 42, " stuff");
|
||||
//
|
||||
// (The "this->" is here to work around gcc 5 bugginess when called in a `this`-capturing lambda.)
|
||||
#define LMQ_LOG(level, ...) this->log_(LogLevel::level, __FILE__, __LINE__, __VA_ARGS__)
|
||||
#define OMQ_LOG(level, ...) log(LogLevel::level, __FILE__, __LINE__, __VA_ARGS__)
|
||||
|
||||
#ifndef NDEBUG
|
||||
// Same as LMQ_LOG(trace, ...) when not doing a release build; nothing under a release build.
|
||||
# define LMQ_TRACE(...) this->log_(LogLevel::trace, __FILE__, __LINE__, __VA_ARGS__)
|
||||
// Same as OMQ_LOG(trace, ...) when not doing a release build; nothing under a release build.
|
||||
# define OMQ_TRACE(...) log(LogLevel::trace, __FILE__, __LINE__, __VA_ARGS__)
|
||||
#else
|
||||
# define LMQ_TRACE(...)
|
||||
# define OMQ_TRACE(...)
|
||||
#endif
|
||||
|
||||
namespace lokimq {
|
||||
namespace oxenmq {
|
||||
|
||||
constexpr char SN_ADDR_COMMAND[] = "inproc://sn-command";
|
||||
constexpr char SN_ADDR_WORKERS[] = "inproc://sn-workers";
|
||||
constexpr char SN_ADDR_SELF[] = "inproc://sn-self";
|
||||
constexpr char ZMQ_ADDR_ZAP[] = "inproc://zeromq.zap.01";
|
||||
|
||||
#ifdef OXENMQ_USE_EPOLL
|
||||
|
||||
constexpr auto EPOLL_COMMAND_ID = std::numeric_limits<uint64_t>::max();
|
||||
constexpr auto EPOLL_WORKER_ID = std::numeric_limits<uint64_t>::max() - 1;
|
||||
constexpr auto EPOLL_ZAP_ID = std::numeric_limits<uint64_t>::max() - 2;
|
||||
|
||||
#endif
|
||||
|
||||
/// Destructor for create_message(std::string&&) that zmq calls when it's done with the message.
|
||||
extern "C" inline void message_buffer_destroy(void*, void* hint) {
|
||||
delete reinterpret_cast<std::string*>(hint);
|
||||
|
@ -33,7 +41,7 @@ inline zmq::message_t create_message(std::string&& data) {
|
|||
};
|
||||
|
||||
/// Create a message copying from a string_view
|
||||
inline zmq::message_t create_message(string_view data) {
|
||||
inline zmq::message_t create_message(std::string_view data) {
|
||||
return zmq::message_t{data.begin(), data.end()};
|
||||
}
|
||||
|
||||
|
@ -87,6 +95,21 @@ inline bool recv_message_parts(zmq::socket_t& sock, std::vector<zmq::message_t>&
|
|||
return true;
|
||||
}
|
||||
|
||||
// Same as above, but using a fixed sized array; this is only used for internal jobs (e.g. control
|
||||
// messages) where we know the message parts should never exceed a given size (this function does
|
||||
// not bounds check except in debug builds). Returns the number of message parts received, or 0 on
|
||||
// read error.
|
||||
template <size_t N>
|
||||
inline size_t recv_message_parts(zmq::socket_t& sock, std::array<zmq::message_t, N>& parts, const zmq::recv_flags flags = zmq::recv_flags::none) {
|
||||
for (size_t count = 0; ; count++) {
|
||||
assert(count < N);
|
||||
if (!sock.recv(parts[count], flags))
|
||||
return 0;
|
||||
if (!parts[count].more())
|
||||
return count + 1;
|
||||
}
|
||||
}
|
||||
|
||||
inline const char* peer_address(zmq::message_t& msg) {
|
||||
try { return msg.gets("Peer-Address"); } catch (...) {}
|
||||
return "(unknown)";
|
||||
|
@ -94,29 +117,12 @@ inline const char* peer_address(zmq::message_t& msg) {
|
|||
|
||||
// Returns a string view of the given message data. It's the caller's responsibility to keep the
|
||||
// referenced message alive. If you want a std::string instead just call `m.to_string()`
|
||||
inline string_view view(const zmq::message_t& m) {
|
||||
inline std::string_view view(const zmq::message_t& m) {
|
||||
return {m.data<char>(), m.size()};
|
||||
}
|
||||
|
||||
inline std::string to_string(AuthLevel a) {
|
||||
switch (a) {
|
||||
case AuthLevel::denied: return "denied";
|
||||
case AuthLevel::none: return "none";
|
||||
case AuthLevel::basic: return "basic";
|
||||
case AuthLevel::admin: return "admin";
|
||||
default: return "(unknown)";
|
||||
}
|
||||
}
|
||||
|
||||
inline AuthLevel auth_from_string(string_view a) {
|
||||
if (a == "none") return AuthLevel::none;
|
||||
if (a == "basic") return AuthLevel::basic;
|
||||
if (a == "admin") return AuthLevel::admin;
|
||||
return AuthLevel::denied;
|
||||
}
|
||||
|
||||
// Extracts and builds the "send" part of a message for proxy_send/proxy_reply
|
||||
inline std::list<zmq::message_t> build_send_parts(bt_list_consumer send, string_view route) {
|
||||
inline std::list<zmq::message_t> build_send_parts(oxenc::bt_list_consumer send, std::string_view route) {
|
||||
std::list<zmq::message_t> parts;
|
||||
if (!route.empty())
|
||||
parts.push_back(create_message(route));
|
||||
|
@ -128,7 +134,7 @@ inline std::list<zmq::message_t> build_send_parts(bt_list_consumer send, string_
|
|||
/// Sends a control message to a specific destination by prefixing the worker name (or identity)
|
||||
/// then appending the command and optional data (if non-empty). (This is needed when sending the control message
|
||||
/// to a router socket, i.e. inside the proxy thread).
|
||||
inline void route_control(zmq::socket_t& sock, string_view identity, string_view cmd, const std::string& data = {}) {
|
||||
inline void route_control(zmq::socket_t& sock, std::string_view identity, std::string_view cmd, const std::string& data = {}) {
|
||||
sock.send(create_message(identity), zmq::send_flags::sndmore);
|
||||
detail::send_control(sock, cmd, data);
|
||||
}
|
|
@ -1,21 +1,29 @@
|
|||
#include "lokimq.h"
|
||||
#include "lokimq-internal.h"
|
||||
#include "oxenmq.h"
|
||||
#include "oxenmq-internal.h"
|
||||
#include "zmq.hpp"
|
||||
#include <map>
|
||||
#include <mutex>
|
||||
#include <random>
|
||||
#include <ostream>
|
||||
#include <thread>
|
||||
#include <future>
|
||||
|
||||
extern "C" {
|
||||
#include <sodium.h>
|
||||
#include <sodium/core.h>
|
||||
#include <sodium/crypto_box.h>
|
||||
#include <sodium/crypto_scalarmult.h>
|
||||
}
|
||||
#include "hex.h"
|
||||
#include <oxenc/hex.h>
|
||||
#include <oxenc/variant.h>
|
||||
|
||||
namespace lokimq {
|
||||
namespace oxenmq {
|
||||
|
||||
namespace {
|
||||
|
||||
|
||||
/// Creates a message by bt-serializing the given value (string, number, list, or dict)
|
||||
template <typename T>
|
||||
zmq::message_t create_bt_message(T&& data) { return create_message(bt_serialize(std::forward<T>(data))); }
|
||||
zmq::message_t create_bt_message(T&& data) { return create_message(oxenc::bt_serialize(std::forward<T>(data))); }
|
||||
|
||||
template <typename MessageContainer>
|
||||
std::vector<std::string> as_strings(const MessageContainer& msgs) {
|
||||
|
@ -38,7 +46,7 @@ namespace detail {
|
|||
|
||||
// Sends a control messages between proxy and threads or between proxy and workers consisting of a
|
||||
// single command codes with an optional data part (the data frame is omitted if empty).
|
||||
void send_control(zmq::socket_t& sock, string_view cmd, std::string data) {
|
||||
void send_control(zmq::socket_t& sock, std::string_view cmd, std::string data) {
|
||||
auto c = create_message(std::move(cmd));
|
||||
if (data.empty()) {
|
||||
sock.send(c, zmq::send_flags::none);
|
||||
|
@ -53,12 +61,12 @@ void send_control(zmq::socket_t& sock, string_view cmd, std::string data) {
|
|||
std::pair<std::string, AuthLevel> extract_metadata(zmq::message_t& msg) {
|
||||
auto result = std::make_pair(""s, AuthLevel::none);
|
||||
try {
|
||||
string_view pubkey_hex{msg.gets("User-Id")};
|
||||
std::string_view pubkey_hex{msg.gets("User-Id")};
|
||||
if (pubkey_hex.size() != 64)
|
||||
throw std::logic_error("bad user-id");
|
||||
assert(is_hex(pubkey_hex.begin(), pubkey_hex.end()));
|
||||
assert(oxenc::is_hex(pubkey_hex.begin(), pubkey_hex.end()));
|
||||
result.first.resize(32, 0);
|
||||
from_hex(pubkey_hex.begin(), pubkey_hex.end(), result.first.begin());
|
||||
oxenc::from_hex(pubkey_hex.begin(), pubkey_hex.end(), result.first.begin());
|
||||
} catch (...) {}
|
||||
|
||||
try {
|
||||
|
@ -71,20 +79,20 @@ std::pair<std::string, AuthLevel> extract_metadata(zmq::message_t& msg) {
|
|||
|
||||
} // namespace detail
|
||||
|
||||
int LokiMQ::set_zmq_context_option(int option, int value) {
|
||||
return context.setctxopt(option, value);
|
||||
void OxenMQ::set_zmq_context_option(zmq::ctxopt option, int value) {
|
||||
context.set(option, value);
|
||||
}
|
||||
|
||||
void LokiMQ::log_level(LogLevel level) {
|
||||
void OxenMQ::log_level(LogLevel level) {
|
||||
log_lvl.store(level, std::memory_order_relaxed);
|
||||
}
|
||||
|
||||
LogLevel LokiMQ::log_level() const {
|
||||
LogLevel OxenMQ::log_level() const {
|
||||
return log_lvl.load(std::memory_order_relaxed);
|
||||
}
|
||||
|
||||
|
||||
CatHelper LokiMQ::add_category(std::string name, Access access_level, unsigned int reserved_threads, int max_queue) {
|
||||
CatHelper OxenMQ::add_category(std::string name, Access access_level, unsigned int reserved_threads, int max_queue) {
|
||||
check_not_started(proxy_thread, "add a category");
|
||||
|
||||
if (name.size() > MAX_CATEGORY_LENGTH)
|
||||
|
@ -102,7 +110,7 @@ CatHelper LokiMQ::add_category(std::string name, Access access_level, unsigned i
|
|||
return ret;
|
||||
}
|
||||
|
||||
void LokiMQ::add_command(const std::string& category, std::string name, CommandCallback callback) {
|
||||
void OxenMQ::add_command(const std::string& category, std::string name, CommandCallback callback) {
|
||||
check_not_started(proxy_thread, "add a command");
|
||||
|
||||
if (name.size() > MAX_COMMAND_LENGTH)
|
||||
|
@ -121,12 +129,12 @@ void LokiMQ::add_command(const std::string& category, std::string name, CommandC
|
|||
throw std::runtime_error("Cannot add command `" + fullname + "': that command already exists");
|
||||
}
|
||||
|
||||
void LokiMQ::add_request_command(const std::string& category, std::string name, CommandCallback callback) {
|
||||
void OxenMQ::add_request_command(const std::string& category, std::string name, CommandCallback callback) {
|
||||
add_command(category, name, std::move(callback));
|
||||
categories.at(category).commands.at(name).second = true;
|
||||
}
|
||||
|
||||
void LokiMQ::add_command_alias(std::string from, std::string to) {
|
||||
void OxenMQ::add_command_alias(std::string from, std::string to) {
|
||||
check_not_started(proxy_thread, "add a command alias");
|
||||
|
||||
if (from.empty())
|
||||
|
@ -155,40 +163,35 @@ std::atomic<int> next_id{1};
|
|||
/// Accesses a thread-local command socket connected to the proxy's command socket used to issue
|
||||
/// commands in a thread-safe manner. A mutex is only required here the first time a thread
|
||||
/// accesses the control socket.
|
||||
zmq::socket_t& LokiMQ::get_control_socket() {
|
||||
zmq::socket_t& OxenMQ::get_control_socket() {
|
||||
assert(proxy_thread.joinable());
|
||||
|
||||
// Maps the LokiMQ unique ID to a local thread command socket.
|
||||
static thread_local std::map<int, std::shared_ptr<zmq::socket_t>> control_sockets;
|
||||
static thread_local std::pair<int, std::shared_ptr<zmq::socket_t>> last{-1, nullptr};
|
||||
|
||||
// Optimize by caching the last value; LokiMQ is often a singleton and in that case we're
|
||||
// Optimize by caching the last value; OxenMQ is often a singleton and in that case we're
|
||||
// going to *always* hit this optimization. Even if it isn't, we're probably likely to need the
|
||||
// same control socket from the same thread multiple times sequentially so this may still help.
|
||||
if (object_id == last.first)
|
||||
return *last.second;
|
||||
static thread_local int last_id = -1;
|
||||
static thread_local zmq::socket_t* last_socket = nullptr;
|
||||
if (object_id == last_id)
|
||||
return *last_socket;
|
||||
|
||||
auto it = control_sockets.find(object_id);
|
||||
if (it != control_sockets.end()) {
|
||||
last = *it;
|
||||
return *last.second;
|
||||
}
|
||||
std::lock_guard lock{control_sockets_mutex};
|
||||
|
||||
std::lock_guard<std::mutex> lock{control_sockets_mutex};
|
||||
if (proxy_shutting_down)
|
||||
throw std::runtime_error("Unable to obtain LokiMQ control socket: proxy thread is shutting down");
|
||||
auto control = std::make_shared<zmq::socket_t>(context, zmq::socket_type::dealer);
|
||||
control->setsockopt<int>(ZMQ_LINGER, 0);
|
||||
control->connect(SN_ADDR_COMMAND);
|
||||
thread_control_sockets.push_back(control);
|
||||
control_sockets.emplace(object_id, control);
|
||||
last.first = object_id;
|
||||
last.second = std::move(control);
|
||||
return *last.second;
|
||||
throw std::runtime_error("Unable to obtain OxenMQ control socket: proxy thread is shutting down");
|
||||
|
||||
auto& socket = control_sockets[std::this_thread::get_id()];
|
||||
if (!socket) {
|
||||
socket = std::make_unique<zmq::socket_t>(context, zmq::socket_type::dealer);
|
||||
socket->set(zmq::sockopt::linger, 0);
|
||||
socket->connect(SN_ADDR_COMMAND);
|
||||
}
|
||||
last_id = object_id;
|
||||
last_socket = socket.get();
|
||||
return *last_socket;
|
||||
}
|
||||
|
||||
|
||||
LokiMQ::LokiMQ(
|
||||
OxenMQ::OxenMQ(
|
||||
std::string pubkey_,
|
||||
std::string privkey_,
|
||||
bool service_node,
|
||||
|
@ -199,14 +202,17 @@ LokiMQ::LokiMQ(
|
|||
sn_lookup{std::move(lookup)}, log_lvl{level}, logger{std::move(logger)}
|
||||
{
|
||||
|
||||
LMQ_TRACE("Constructing listening LokiMQ, id=", object_id, ", this=", this);
|
||||
OMQ_TRACE("Constructing OxenMQ, id=", object_id, ", this=", this);
|
||||
|
||||
if (sodium_init() == -1)
|
||||
throw std::runtime_error{"libsodium initialization failed"};
|
||||
|
||||
if (pubkey.empty() != privkey.empty()) {
|
||||
throw std::invalid_argument("LokiMQ construction failed: one (and only one) of pubkey/privkey is empty. Both must be specified, or both empty to generate a key.");
|
||||
throw std::invalid_argument("OxenMQ construction failed: one (and only one) of pubkey/privkey is empty. Both must be specified, or both empty to generate a key.");
|
||||
} else if (pubkey.empty()) {
|
||||
if (service_node)
|
||||
throw std::invalid_argument("Cannot construct a service node mode LokiMQ without a keypair");
|
||||
LMQ_LOG(debug, "generating x25519 keypair for remote-only LokiMQ instance");
|
||||
throw std::invalid_argument("Cannot construct a service node mode OxenMQ without a keypair");
|
||||
OMQ_LOG(debug, "generating x25519 keypair for remote-only OxenMQ instance");
|
||||
pubkey.resize(crypto_box_PUBLICKEYBYTES);
|
||||
privkey.resize(crypto_box_SECRETKEYBYTES);
|
||||
crypto_box_keypair(reinterpret_cast<unsigned char*>(&pubkey[0]), reinterpret_cast<unsigned char*>(&privkey[0]));
|
||||
|
@ -221,69 +227,101 @@ LokiMQ::LokiMQ(
|
|||
std::string verify_pubkey(crypto_box_PUBLICKEYBYTES, 0);
|
||||
crypto_scalarmult_base(reinterpret_cast<unsigned char*>(&verify_pubkey[0]), reinterpret_cast<unsigned char*>(&privkey[0]));
|
||||
if (verify_pubkey != pubkey)
|
||||
throw std::invalid_argument("Invalid pubkey/privkey values given to LokiMQ construction: pubkey verification failed");
|
||||
throw std::invalid_argument("Invalid pubkey/privkey values given to OxenMQ construction: pubkey verification failed");
|
||||
}
|
||||
}
|
||||
|
||||
void LokiMQ::start() {
|
||||
void OxenMQ::start() {
|
||||
if (proxy_thread.joinable())
|
||||
throw std::logic_error("Cannot call start() multiple times!");
|
||||
|
||||
// If we're not binding to anything then we don't listen, i.e. we can only establish outbound
|
||||
// connections. Don't allow this if we are in service_node mode because, if we aren't
|
||||
// listening, we are useless as a service node.
|
||||
if (bind.empty() && local_service_node)
|
||||
throw std::invalid_argument{"Cannot create a service node listener with no address(es) to bind"};
|
||||
OMQ_LOG(info, "Initializing OxenMQ ", bind.empty() ? "remote-only" : "listener", " with pubkey ", oxenc::to_hex(pubkey));
|
||||
|
||||
LMQ_LOG(info, "Initializing LokiMQ ", bind.empty() ? "remote-only" : "listener", " with pubkey ", to_hex(pubkey));
|
||||
assert(general_workers > 0);
|
||||
if (batch_jobs_reserved < 0)
|
||||
batch_jobs_reserved = (general_workers + 1) / 2;
|
||||
if (reply_jobs_reserved < 0)
|
||||
reply_jobs_reserved = (general_workers + 7) / 8;
|
||||
|
||||
int zmq_socket_limit = context.getctxopt(ZMQ_SOCKET_LIMIT);
|
||||
if (MAX_SOCKETS > 1 && MAX_SOCKETS <= zmq_socket_limit)
|
||||
context.setctxopt(ZMQ_MAX_SOCKETS, MAX_SOCKETS);
|
||||
else
|
||||
LMQ_LOG(error, "Not applying LokiMQ::MAX_SOCKETS setting: ", MAX_SOCKETS, " must be in [1, ", zmq_socket_limit, "]");
|
||||
max_workers = general_workers + batch_jobs_reserved + reply_jobs_reserved;
|
||||
for (const auto& cat : categories) {
|
||||
max_workers += cat.second.reserved_threads;
|
||||
}
|
||||
|
||||
if (log_level() >= LogLevel::debug) {
|
||||
OMQ_LOG(debug, "Reserving space for ", max_workers, " max workers = ", general_workers, " general plus reservations for:");
|
||||
for (const auto& cat : categories)
|
||||
OMQ_LOG(debug, " - ", cat.first, ": ", cat.second.reserved_threads);
|
||||
OMQ_LOG(debug, " - (batch jobs): ", batch_jobs_reserved);
|
||||
OMQ_LOG(debug, " - (reply jobs): ", reply_jobs_reserved);
|
||||
OMQ_LOG(debug, "Plus ", tagged_workers.size(), " tagged worker threads");
|
||||
}
|
||||
|
||||
if (MAX_SOCKETS != 0) {
|
||||
// The max sockets setting we apply to the context here is used during zmq context
|
||||
// initialization, which happens when the first socket is constructed using this context:
|
||||
// hence we set this *before* constructing any socket_t on the context.
|
||||
int zmq_socket_limit = context.get(zmq::ctxopt::socket_limit);
|
||||
int want_sockets = MAX_SOCKETS < 0 ? zmq_socket_limit :
|
||||
std::min<int>(zmq_socket_limit,
|
||||
MAX_SOCKETS + max_workers + tagged_workers.size()
|
||||
+ 4 /* zap_auth, workers_socket, command, inproc_listener */);
|
||||
context.set(zmq::ctxopt::max_sockets, want_sockets);
|
||||
}
|
||||
|
||||
// We bind `command` here so that the `get_control_socket()` below is always connecting to a
|
||||
// bound socket, but we do nothing else here: the proxy thread is responsible for everything
|
||||
// except binding it.
|
||||
command = zmq::socket_t{context, zmq::socket_type::router};
|
||||
command.bind(SN_ADDR_COMMAND);
|
||||
proxy_thread = std::thread{&LokiMQ::proxy_loop, this};
|
||||
std::promise<void> startup_prom;
|
||||
auto proxy_startup = startup_prom.get_future();
|
||||
proxy_thread = std::thread{&OxenMQ::proxy_loop, this, std::move(startup_prom)};
|
||||
|
||||
LMQ_LOG(debug, "Waiting for proxy thread to get ready...");
|
||||
OMQ_LOG(debug, "Waiting for proxy thread to initialize...");
|
||||
proxy_startup.get(); // Rethrows exceptions from the proxy startup (e.g. failure to bind)
|
||||
|
||||
OMQ_LOG(debug, "Waiting for proxy thread to get ready...");
|
||||
auto &control = get_control_socket();
|
||||
detail::send_control(control, "START");
|
||||
LMQ_TRACE("Sent START command");
|
||||
OMQ_TRACE("Sent START command");
|
||||
|
||||
zmq::message_t ready_msg;
|
||||
std::vector<zmq::message_t> parts;
|
||||
try { recv_message_parts(control, parts); }
|
||||
catch (const zmq::error_t &e) { throw std::runtime_error("Failure reading from LokiMQ::Proxy thread: "s + e.what()); }
|
||||
catch (const zmq::error_t &e) { throw std::runtime_error("Failure reading from OxenMQ::Proxy thread: "s + e.what()); }
|
||||
|
||||
if (!(parts.size() == 1 && view(parts.front()) == "READY"))
|
||||
throw std::runtime_error("Invalid startup message from proxy thread (didn't get expected READY message)");
|
||||
LMQ_LOG(debug, "Proxy thread is ready");
|
||||
OMQ_LOG(debug, "Proxy thread is ready");
|
||||
}
|
||||
|
||||
void LokiMQ::listen_curve(std::string bind_addr, AllowFunc allow_connection) {
|
||||
// TODO: there's no particular reason we can't start listening after starting up; just needs to
|
||||
// be implemented. (But if we can start we'll probably also want to be able to stop, so it's
|
||||
// more than just binding that needs implementing).
|
||||
check_not_started(proxy_thread, "start listening");
|
||||
|
||||
bind.emplace_back(std::move(bind_addr), bind_data{true, std::move(allow_connection)});
|
||||
void OxenMQ::listen_curve(std::string bind_addr, AllowFunc allow_connection, std::function<void(bool)> on_bind) {
|
||||
if (std::string_view{bind_addr}.substr(0, 9) == "inproc://")
|
||||
throw std::logic_error{"inproc:// cannot be used with listen_curve"};
|
||||
if (!allow_connection) allow_connection = [](auto&&...) { return AuthLevel::none; };
|
||||
bind_data d{std::move(bind_addr), true, std::move(allow_connection), std::move(on_bind)};
|
||||
if (proxy_thread.joinable())
|
||||
detail::send_control(get_control_socket(), "BIND", oxenc::bt_serialize(detail::serialize_object(std::move(d))));
|
||||
else
|
||||
bind.push_back(std::move(d));
|
||||
}
|
||||
|
||||
void LokiMQ::listen_plain(std::string bind_addr, AllowFunc allow_connection) {
|
||||
// TODO: As above.
|
||||
check_not_started(proxy_thread, "start listening");
|
||||
|
||||
bind.emplace_back(std::move(bind_addr), bind_data{false, std::move(allow_connection)});
|
||||
void OxenMQ::listen_plain(std::string bind_addr, AllowFunc allow_connection, std::function<void(bool)> on_bind) {
|
||||
if (std::string_view{bind_addr}.substr(0, 9) == "inproc://")
|
||||
throw std::logic_error{"inproc:// cannot be used with listen_plain"};
|
||||
if (!allow_connection) allow_connection = [](auto&&...) { return AuthLevel::none; };
|
||||
bind_data d{std::move(bind_addr), false, std::move(allow_connection), std::move(on_bind)};
|
||||
if (proxy_thread.joinable())
|
||||
detail::send_control(get_control_socket(), "BIND", oxenc::bt_serialize(detail::serialize_object(std::move(d))));
|
||||
else
|
||||
bind.push_back(std::move(d));
|
||||
}
|
||||
|
||||
|
||||
std::pair<LokiMQ::category*, const std::pair<LokiMQ::CommandCallback, bool>*> LokiMQ::get_command(std::string& command) {
|
||||
std::pair<OxenMQ::category*, const std::pair<OxenMQ::CommandCallback, bool>*> OxenMQ::get_command(std::string& command) {
|
||||
if (command.size() > MAX_CATEGORY_LENGTH + 1 + MAX_COMMAND_LENGTH) {
|
||||
LMQ_LOG(warn, "Invalid command '", command, "': command too long");
|
||||
OMQ_LOG(warn, "Invalid command '", command, "': command too long");
|
||||
return {};
|
||||
}
|
||||
|
||||
|
@ -295,7 +333,7 @@ std::pair<LokiMQ::category*, const std::pair<LokiMQ::CommandCallback, bool>*> Lo
|
|||
|
||||
auto dot = command.find('.');
|
||||
if (dot == 0 || dot == std::string::npos) {
|
||||
LMQ_LOG(warn, "Invalid command '", command, "': expected <category>.<command>");
|
||||
OMQ_LOG(warn, "Invalid command '", command, "': expected <category>.<command>");
|
||||
return {};
|
||||
}
|
||||
std::string catname = command.substr(0, dot);
|
||||
|
@ -303,21 +341,21 @@ std::pair<LokiMQ::category*, const std::pair<LokiMQ::CommandCallback, bool>*> Lo
|
|||
|
||||
auto catit = categories.find(catname);
|
||||
if (catit == categories.end()) {
|
||||
LMQ_LOG(warn, "Invalid command category '", catname, "'");
|
||||
OMQ_LOG(warn, "Invalid command category '", catname, "'");
|
||||
return {};
|
||||
}
|
||||
|
||||
const auto& category = catit->second;
|
||||
auto callback_it = category.commands.find(cmd);
|
||||
if (callback_it == category.commands.end()) {
|
||||
LMQ_LOG(warn, "Invalid command '", command, "'");
|
||||
OMQ_LOG(warn, "Invalid command '", command, "'");
|
||||
return {};
|
||||
}
|
||||
|
||||
return {&catit->second, &callback_it->second};
|
||||
}
|
||||
|
||||
void LokiMQ::set_batch_threads(int threads) {
|
||||
void OxenMQ::set_batch_threads(int threads) {
|
||||
if (proxy_thread.joinable())
|
||||
throw std::logic_error("Cannot change reserved batch threads after calling `start()`");
|
||||
if (threads < -1) // -1 is the default which is based on general threads
|
||||
|
@ -325,7 +363,7 @@ void LokiMQ::set_batch_threads(int threads) {
|
|||
batch_jobs_reserved = threads;
|
||||
}
|
||||
|
||||
void LokiMQ::set_reply_threads(int threads) {
|
||||
void OxenMQ::set_reply_threads(int threads) {
|
||||
if (proxy_thread.joinable())
|
||||
throw std::logic_error("Cannot change reserved reply threads after calling `start()`");
|
||||
if (threads < -1) // -1 is the default which is based on general threads
|
||||
|
@ -333,7 +371,7 @@ void LokiMQ::set_reply_threads(int threads) {
|
|||
reply_jobs_reserved = threads;
|
||||
}
|
||||
|
||||
void LokiMQ::set_general_threads(int threads) {
|
||||
void OxenMQ::set_general_threads(int threads) {
|
||||
if (proxy_thread.joinable())
|
||||
throw std::logic_error("Cannot change general thread count after calling `start()`");
|
||||
if (threads < 1)
|
||||
|
@ -341,40 +379,72 @@ void LokiMQ::set_general_threads(int threads) {
|
|||
general_workers = threads;
|
||||
}
|
||||
|
||||
LokiMQ::run_info& LokiMQ::run_info::load(category* cat_, std::string command_, ConnectionID conn_,
|
||||
OxenMQ::run_info& OxenMQ::run_info::load(category* cat_, std::string command_, ConnectionID conn_, Access access_, std::string remote_,
|
||||
std::vector<zmq::message_t> data_parts_, const std::pair<CommandCallback, bool>* callback_) {
|
||||
is_batch_job = false;
|
||||
is_reply_job = false;
|
||||
reset();
|
||||
cat = cat_;
|
||||
command = std::move(command_);
|
||||
conn = std::move(conn_);
|
||||
access = std::move(access_);
|
||||
remote = std::move(remote_);
|
||||
data_parts = std::move(data_parts_);
|
||||
callback = callback_;
|
||||
to_run = callback_;
|
||||
return *this;
|
||||
}
|
||||
|
||||
LokiMQ::run_info& LokiMQ::run_info::load(pending_command&& pending) {
|
||||
return load(&pending.cat, std::move(pending.command), std::move(pending.conn),
|
||||
std::move(pending.data_parts), pending.callback);
|
||||
OxenMQ::run_info& OxenMQ::run_info::load(category* cat_, std::string command_, std::string remote_, std::function<void()> callback) {
|
||||
reset();
|
||||
is_injected = true;
|
||||
cat = cat_;
|
||||
command = std::move(command_);
|
||||
conn = {};
|
||||
access = {};
|
||||
remote = std::move(remote_);
|
||||
to_run = std::move(callback);
|
||||
return *this;
|
||||
}
|
||||
|
||||
LokiMQ::run_info& LokiMQ::run_info::load(batch_job&& bj, bool reply_job) {
|
||||
OxenMQ::run_info& OxenMQ::run_info::load(pending_command&& pending) {
|
||||
if (auto *f = std::get_if<std::function<void()>>(&pending.callback))
|
||||
return load(&pending.cat, std::move(pending.command), std::move(pending.remote), std::move(*f));
|
||||
|
||||
assert(pending.callback.index() == 0);
|
||||
return load(&pending.cat, std::move(pending.command), std::move(pending.conn), std::move(pending.access),
|
||||
std::move(pending.remote), std::move(pending.data_parts), var::get<0>(pending.callback));
|
||||
}
|
||||
|
||||
OxenMQ::run_info& OxenMQ::run_info::load(batch_job&& bj, bool reply_job, int tagged_thread) {
|
||||
reset();
|
||||
is_batch_job = true;
|
||||
is_reply_job = reply_job;
|
||||
is_tagged_thread_job = tagged_thread > 0;
|
||||
batch_jobno = bj.second;
|
||||
batch = bj.first;
|
||||
to_run = bj.first;
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
||||
LokiMQ::~LokiMQ() {
|
||||
if (!proxy_thread.joinable())
|
||||
return;
|
||||
OxenMQ::~OxenMQ() {
|
||||
if (!proxy_thread.joinable()) {
|
||||
if (!tagged_workers.empty()) {
|
||||
// We have tagged workers that are waiting on a signal for startup, but we didn't start
|
||||
// up, so signal them so that they can end themselves.
|
||||
{
|
||||
std::lock_guard lock{tagged_startup_mutex};
|
||||
tagged_go = tagged_go_mode::SHUTDOWN;
|
||||
}
|
||||
tagged_cv.notify_all();
|
||||
for (auto& [run, busy, queue] : tagged_workers)
|
||||
run.worker_thread.join();
|
||||
}
|
||||
|
||||
LMQ_LOG(info, "LokiMQ shutting down proxy thread");
|
||||
return;
|
||||
}
|
||||
|
||||
OMQ_LOG(info, "OxenMQ shutting down proxy thread");
|
||||
detail::send_control(get_control_socket(), "QUIT");
|
||||
proxy_thread.join();
|
||||
LMQ_LOG(info, "LokiMQ proxy thread has stopped");
|
||||
OMQ_LOG(info, "OxenMQ proxy thread has stopped");
|
||||
}
|
||||
|
||||
std::ostream &operator<<(std::ostream &os, LogLevel lvl) {
|
||||
|
@ -390,13 +460,14 @@ std::ostream &operator<<(std::ostream &os, LogLevel lvl) {
|
|||
|
||||
std::string make_random_string(size_t size) {
|
||||
static thread_local std::mt19937_64 rng{std::random_device{}()};
|
||||
static thread_local std::uniform_int_distribution<char> dist{std::numeric_limits<char>::min(), std::numeric_limits<char>::max()};
|
||||
std::string rando;
|
||||
rando.reserve(size);
|
||||
for (size_t i = 0; i < size; i++)
|
||||
rando += dist(rng);
|
||||
while (rando.size() < size) {
|
||||
uint64_t x = rng();
|
||||
rando.append(reinterpret_cast<const char*>(&x), std::min<size_t>(size - rando.size(), 8));
|
||||
}
|
||||
return rando;
|
||||
}
|
||||
|
||||
} // namespace lokimq
|
||||
} // namespace oxenmq
|
||||
// vim:sw=4:et
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,843 @@
|
|||
#include "oxenmq.h"
|
||||
#include "oxenmq-internal.h"
|
||||
#include <oxenc/hex.h>
|
||||
#include <exception>
|
||||
#include <future>
|
||||
|
||||
#if defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__)
|
||||
extern "C" {
|
||||
#include <pthread.h>
|
||||
#include <pthread_np.h>
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef OXENMQ_USE_EPOLL
|
||||
#include <sys/epoll.h>
|
||||
#endif
|
||||
|
||||
#ifndef _WIN32
|
||||
extern "C" {
|
||||
#include <sys/stat.h>
|
||||
#include <sys/types.h>
|
||||
#include <unistd.h>
|
||||
}
|
||||
#endif
|
||||
|
||||
namespace oxenmq {
|
||||
|
||||
void OxenMQ::proxy_quit() {
|
||||
OMQ_LOG(debug, "Received quit command, shutting down proxy thread");
|
||||
|
||||
assert(std::none_of(workers.begin(), workers.end(), [](auto& worker) { return worker.worker_thread.joinable(); }));
|
||||
assert(std::none_of(tagged_workers.begin(), tagged_workers.end(), [](auto& worker) { return std::get<0>(worker).worker_thread.joinable(); }));
|
||||
|
||||
command.set(zmq::sockopt::linger, 0);
|
||||
command.close();
|
||||
{
|
||||
std::lock_guard lock{control_sockets_mutex};
|
||||
proxy_shutting_down = true; // To prevent threads from opening new control sockets
|
||||
}
|
||||
workers_socket.close();
|
||||
int linger = std::chrono::milliseconds{CLOSE_LINGER}.count();
|
||||
for (auto& [id, s] : connections)
|
||||
s.set(zmq::sockopt::linger, linger);
|
||||
connections.clear();
|
||||
peers.clear();
|
||||
|
||||
OMQ_LOG(debug, "Proxy thread teardown complete");
|
||||
}
|
||||
|
||||
void OxenMQ::proxy_send(oxenc::bt_dict_consumer data) {
|
||||
// NB: bt_dict_consumer goes in alphabetical order
|
||||
std::string_view hint;
|
||||
std::chrono::milliseconds keep_alive{DEFAULT_SEND_KEEP_ALIVE};
|
||||
std::chrono::milliseconds request_timeout{DEFAULT_REQUEST_TIMEOUT};
|
||||
bool optional = false;
|
||||
bool outgoing = false;
|
||||
bool incoming = false;
|
||||
bool request = false;
|
||||
bool have_conn_id = false;
|
||||
ConnectionID conn_id;
|
||||
|
||||
std::string request_tag;
|
||||
ReplyCallback request_callback;
|
||||
if (data.skip_until("conn_id")) {
|
||||
conn_id.id = data.consume_integer<long long>();
|
||||
if (conn_id.id == -1)
|
||||
throw std::runtime_error("Invalid error: invalid conn_id value (-1)");
|
||||
have_conn_id = true;
|
||||
}
|
||||
if (data.skip_until("conn_pubkey")) {
|
||||
if (have_conn_id)
|
||||
throw std::runtime_error("Internal error: Invalid proxy send command; conn_id and conn_pubkey are exclusive");
|
||||
conn_id.pk = data.consume_string();
|
||||
conn_id.id = ConnectionID::SN_ID;
|
||||
} else if (!have_conn_id)
|
||||
throw std::runtime_error("Internal error: Invalid proxy send command; conn_pubkey or conn_id missing");
|
||||
if (data.skip_until("conn_route"))
|
||||
conn_id.route = data.consume_string();
|
||||
if (data.skip_until("hint"))
|
||||
hint = data.consume_string_view();
|
||||
if (data.skip_until("incoming"))
|
||||
incoming = data.consume_integer<bool>();
|
||||
if (data.skip_until("keep_alive"))
|
||||
keep_alive = std::chrono::milliseconds{data.consume_integer<uint64_t>()};
|
||||
if (data.skip_until("optional"))
|
||||
optional = data.consume_integer<bool>();
|
||||
if (data.skip_until("outgoing"))
|
||||
outgoing = data.consume_integer<bool>();
|
||||
|
||||
if (data.skip_until("request"))
|
||||
request = data.consume_integer<bool>();
|
||||
if (request) {
|
||||
if (!data.skip_until("request_callback"))
|
||||
throw std::runtime_error("Internal error: received request without request_callback");
|
||||
|
||||
request_callback = detail::deserialize_object<ReplyCallback>(data.consume_integer<uintptr_t>());
|
||||
|
||||
if (!data.skip_until("request_tag"))
|
||||
throw std::runtime_error("Internal error: received request without request_name");
|
||||
request_tag = data.consume_string();
|
||||
if (data.skip_until("request_timeout"))
|
||||
request_timeout = std::chrono::milliseconds{data.consume_integer<uint64_t>()};
|
||||
}
|
||||
if (!data.skip_until("send"))
|
||||
throw std::runtime_error("Internal error: Invalid proxy send command; send parts missing");
|
||||
oxenc::bt_list_consumer send = data.consume_list_consumer();
|
||||
|
||||
send_option::queue_failure::callback_t callback_nosend;
|
||||
if (data.skip_until("send_fail"))
|
||||
callback_nosend = detail::deserialize_object<decltype(callback_nosend)>(data.consume_integer<uintptr_t>());
|
||||
|
||||
send_option::queue_full::callback_t callback_noqueue;
|
||||
if (data.skip_until("send_full_q"))
|
||||
callback_noqueue = detail::deserialize_object<decltype(callback_noqueue)>(data.consume_integer<uintptr_t>());
|
||||
|
||||
// Now figure out which socket to send to and do the actual sending. We can repeat this loop
|
||||
// multiple times, if we're sending to a SN, because it's possible that we have multiple
|
||||
// connections open to that SN (e.g. one out + one in) so if one fails we can clean up that
|
||||
// connection and try the next one.
|
||||
bool retry = true, sent = false, nowarn = false;
|
||||
while (retry) {
|
||||
retry = false;
|
||||
zmq::socket_t *send_to;
|
||||
if (conn_id.sn()) {
|
||||
auto sock_route = proxy_connect_sn(conn_id.pk, hint, optional, incoming, outgoing, EPHEMERAL_ROUTING_ID, keep_alive);
|
||||
if (!sock_route.first) {
|
||||
nowarn = true;
|
||||
if (optional)
|
||||
OMQ_LOG(debug, "Not sending: send is optional and no connection to ",
|
||||
oxenc::to_hex(conn_id.pk), " is currently established");
|
||||
else
|
||||
OMQ_LOG(error, "Unable to send to ", oxenc::to_hex(conn_id.pk), ": no valid connection address found");
|
||||
break;
|
||||
}
|
||||
send_to = sock_route.first;
|
||||
conn_id.route = std::move(sock_route.second);
|
||||
} else if (!conn_id.route.empty()) { // incoming non-SN connection
|
||||
auto it = connections.find(conn_id.id);
|
||||
if (it == connections.end()) {
|
||||
OMQ_LOG(warn, "Unable to send to ", conn_id, ": incoming listening socket not found");
|
||||
break;
|
||||
}
|
||||
send_to = &it->second;
|
||||
} else {
|
||||
auto pr = peers.equal_range(conn_id);
|
||||
if (pr.first == peers.end()) {
|
||||
OMQ_LOG(warn, "Unable to send: connection id ", conn_id, " is not (or is no longer) a valid outgoing connection");
|
||||
break;
|
||||
}
|
||||
auto& peer = pr.first->second;
|
||||
auto it = connections.find(peer.conn_id);
|
||||
if (it == connections.end()) {
|
||||
OMQ_LOG(warn, "Unable to send: peer connection id ", conn_id, " is not (or is no longer) a valid outgoing connection");
|
||||
break;
|
||||
}
|
||||
send_to = &it->second;
|
||||
}
|
||||
|
||||
try {
|
||||
sent = send_message_parts(*send_to, build_send_parts(send, conn_id.route));
|
||||
} catch (const zmq::error_t &e) {
|
||||
if (e.num() == EHOSTUNREACH && !conn_id.route.empty() /*= incoming conn*/) {
|
||||
|
||||
OMQ_LOG(debug, "Incoming connection is no longer valid; removing peer details");
|
||||
|
||||
auto pr = peers.equal_range(conn_id);
|
||||
if (pr.first != peers.end()) {
|
||||
if (!conn_id.sn()) {
|
||||
peers.erase(pr.first);
|
||||
} else {
|
||||
bool removed;
|
||||
for (auto it = pr.first; it != pr.second; ) {
|
||||
auto& peer = it->second;
|
||||
if (peer.route == conn_id.route) {
|
||||
peers.erase(it);
|
||||
removed = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
// The incoming connection to the SN is no longer good, but we can retry because
|
||||
// we may have another active connection with the SN (or may want to open one).
|
||||
if (removed) {
|
||||
OMQ_LOG(debug, "Retrying sending to SN ", oxenc::to_hex(conn_id.pk), " using other sockets");
|
||||
retry = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!retry) {
|
||||
if (!conn_id.sn() && !conn_id.route.empty()) { // incoming non-SN connection
|
||||
OMQ_LOG(debug, "Unable to send message to incoming connection ", conn_id, ": ", e.what(),
|
||||
"; remote has probably disconnected");
|
||||
} else {
|
||||
OMQ_LOG(warn, "Unable to send message to ", conn_id, ": ", e.what());
|
||||
}
|
||||
nowarn = true;
|
||||
if (callback_nosend) {
|
||||
job([callback = std::move(callback_nosend), error = e] { callback(&error); });
|
||||
callback_nosend = nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (request) {
|
||||
if (sent) {
|
||||
OMQ_LOG(debug, "Added new pending request ", oxenc::to_hex(request_tag));
|
||||
pending_requests.insert({ request_tag, {
|
||||
std::chrono::steady_clock::now() + request_timeout, std::move(request_callback) }});
|
||||
} else {
|
||||
OMQ_LOG(debug, "Could not send request, scheduling request callback failure");
|
||||
job([callback = std::move(request_callback)] { callback(false, {{"TIMEOUT"s}}); });
|
||||
}
|
||||
}
|
||||
if (!sent) {
|
||||
if (callback_nosend)
|
||||
job([callback = std::move(callback_nosend)] { callback(nullptr); });
|
||||
else if (callback_noqueue)
|
||||
job(std::move(callback_noqueue));
|
||||
else if (!nowarn)
|
||||
OMQ_LOG(warn, "Unable to send message to ", conn_id, ": sending would block");
|
||||
}
|
||||
}
|
||||
|
||||
void OxenMQ::proxy_reply(oxenc::bt_dict_consumer data) {
|
||||
bool have_conn_id = false;
|
||||
ConnectionID conn_id{0};
|
||||
if (data.skip_until("conn_id")) {
|
||||
conn_id.id = data.consume_integer<long long>();
|
||||
if (conn_id.id == -1)
|
||||
throw std::runtime_error("Invalid error: invalid conn_id value (-1)");
|
||||
have_conn_id = true;
|
||||
}
|
||||
if (data.skip_until("conn_pubkey")) {
|
||||
if (have_conn_id)
|
||||
throw std::runtime_error("Internal error: Invalid proxy reply command; conn_id and conn_pubkey are exclusive");
|
||||
conn_id.pk = data.consume_string();
|
||||
conn_id.id = ConnectionID::SN_ID;
|
||||
} else if (!have_conn_id)
|
||||
throw std::runtime_error("Internal error: Invalid proxy reply command; conn_pubkey or conn_id missing");
|
||||
if (!data.skip_until("send"))
|
||||
throw std::runtime_error("Internal error: Invalid proxy reply command; send parts missing");
|
||||
|
||||
oxenc::bt_list_consumer send = data.consume_list_consumer();
|
||||
|
||||
auto pr = peers.equal_range(conn_id);
|
||||
if (pr.first == pr.second) {
|
||||
OMQ_LOG(warn, "Unable to send tagged reply: the connection is no longer valid");
|
||||
return;
|
||||
}
|
||||
|
||||
// We try any connections until one works (for ordinary remotes there will be just one, but for
|
||||
// SNs there might be one incoming and one outgoing).
|
||||
for (auto it = pr.first; it != pr.second; ) {
|
||||
try {
|
||||
send_message_parts(connections[it->second.conn_id], build_send_parts(send, it->second.route));
|
||||
break;
|
||||
} catch (const zmq::error_t &err) {
|
||||
if (err.num() == EHOSTUNREACH) {
|
||||
if (it->second.outgoing()) {
|
||||
OMQ_LOG(debug, "Unable to send reply to non-SN request on outgoing socket: "
|
||||
"remote is no longer connected; closing connection");
|
||||
proxy_close_connection(it->second.conn_id, CLOSE_LINGER);
|
||||
it = peers.erase(it);
|
||||
++it;
|
||||
} else {
|
||||
OMQ_LOG(debug, "Unable to send reply to non-SN request on incoming socket: "
|
||||
"remote is no longer connected; removing peer details");
|
||||
it = peers.erase(it);
|
||||
}
|
||||
} else {
|
||||
OMQ_LOG(warn, "Unable to send reply to incoming non-SN request: ", err.what());
|
||||
++it;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void OxenMQ::proxy_control_message(OxenMQ::control_message_array& parts, size_t len) {
|
||||
// We throw an uncaught exception here because we only generate control messages internally in
|
||||
// oxenmq code: if one of these condition fail it's a oxenmq bug.
|
||||
if (len < 2)
|
||||
throw std::logic_error("OxenMQ bug: Expected 2-3 message parts for a proxy control message");
|
||||
auto route = view(parts[0]), cmd = view(parts[1]);
|
||||
OMQ_TRACE("control message: ", cmd);
|
||||
if (len == 3) {
|
||||
OMQ_TRACE("...: ", parts[2]);
|
||||
auto data = view(parts[2]);
|
||||
if (cmd == "SEND") {
|
||||
OMQ_TRACE("proxying message");
|
||||
return proxy_send(data);
|
||||
} else if (cmd == "REPLY") {
|
||||
OMQ_TRACE("proxying reply to non-SN incoming message");
|
||||
return proxy_reply(data);
|
||||
} else if (cmd == "BATCH") {
|
||||
OMQ_TRACE("proxy batch jobs");
|
||||
auto ptrval = oxenc::bt_deserialize<uintptr_t>(data);
|
||||
return proxy_batch(reinterpret_cast<detail::Batch*>(ptrval));
|
||||
} else if (cmd == "INJECT") {
|
||||
OMQ_TRACE("proxy inject");
|
||||
return proxy_inject_task(detail::deserialize_object<injected_task>(oxenc::bt_deserialize<uintptr_t>(data)));
|
||||
} else if (cmd == "SET_SNS") {
|
||||
return proxy_set_active_sns(data);
|
||||
} else if (cmd == "UPDATE_SNS") {
|
||||
return proxy_update_active_sns(data);
|
||||
} else if (cmd == "CONNECT_SN") {
|
||||
proxy_connect_sn(data);
|
||||
return;
|
||||
} else if (cmd == "CONNECT_REMOTE") {
|
||||
return proxy_connect_remote(data);
|
||||
} else if (cmd == "DISCONNECT") {
|
||||
return proxy_disconnect(data);
|
||||
} else if (cmd == "TIMER") {
|
||||
return proxy_timer(data);
|
||||
} else if (cmd == "TIMER_DEL") {
|
||||
return proxy_timer_del(oxenc::bt_deserialize<int>(data));
|
||||
} else if (cmd == "BIND") {
|
||||
auto b = detail::deserialize_object<bind_data>(oxenc::bt_deserialize<uintptr_t>(data));
|
||||
if (proxy_bind(b, bind.size()))
|
||||
bind.push_back(std::move(b));
|
||||
return;
|
||||
}
|
||||
} else if (len == 2) {
|
||||
if (cmd == "START") {
|
||||
// Command send by the owning thread during startup; we send back a simple READY reply to
|
||||
// let it know we are running.
|
||||
return route_control(command, route, "READY");
|
||||
} else if (cmd == "QUIT") {
|
||||
// Asked to quit: set max_workers to zero and tell any idle ones to quit. We will
|
||||
// close workers as they come back to READY status, and then close external
|
||||
// connections once all workers are done.
|
||||
max_workers = 0;
|
||||
for (size_t i = 0; i < idle_worker_count; i++)
|
||||
route_control(workers_socket, workers[idle_workers[i]].worker_routing_id, "QUIT");
|
||||
idle_worker_count = 0;
|
||||
for (auto& [run, busy, queue] : tagged_workers)
|
||||
if (!busy)
|
||||
route_control(workers_socket, run.worker_routing_id, "QUIT");
|
||||
return;
|
||||
}
|
||||
}
|
||||
throw std::runtime_error("OxenMQ bug: Proxy received invalid control command: " +
|
||||
std::string{cmd} + " (" + std::to_string(len) + ")");
|
||||
}
|
||||
|
||||
bool OxenMQ::proxy_bind(bind_data& b, size_t bind_index) {
|
||||
zmq::socket_t listener{context, zmq::socket_type::router};
|
||||
setup_incoming_socket(listener, b.curve, pubkey, privkey, bind_index);
|
||||
|
||||
bool good = true;
|
||||
try {
|
||||
listener.bind(b.address);
|
||||
} catch (const zmq::error_t&) {
|
||||
good = false;
|
||||
}
|
||||
if (b.on_bind) {
|
||||
b.on_bind(good);
|
||||
b.on_bind = nullptr;
|
||||
}
|
||||
if (!good) {
|
||||
OMQ_LOG(warn, "OxenMQ failed to listen on ", b.address);
|
||||
return false;
|
||||
}
|
||||
|
||||
OMQ_LOG(info, "OxenMQ listening on ", b.address);
|
||||
|
||||
b.conn_id = next_conn_id++;
|
||||
connections.emplace_hint(connections.end(), b.conn_id, std::move(listener));
|
||||
|
||||
connections_updated = true;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void OxenMQ::proxy_loop_init() {
|
||||
|
||||
#if defined(__linux__) || defined(__sun) || defined(__MINGW32__)
|
||||
pthread_setname_np(pthread_self(), "omq-proxy");
|
||||
#elif defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__)
|
||||
pthread_set_name_np(pthread_self(), "omq-proxy");
|
||||
#elif defined(__MACH__)
|
||||
pthread_setname_np("omq-proxy");
|
||||
#endif
|
||||
|
||||
zap_auth = zmq::socket_t{context, zmq::socket_type::rep};
|
||||
zap_auth.set(zmq::sockopt::linger, 0);
|
||||
zap_auth.bind(ZMQ_ADDR_ZAP);
|
||||
|
||||
workers_socket = zmq::socket_t{context, zmq::socket_type::router};
|
||||
workers_socket.set(zmq::sockopt::router_mandatory, true);
|
||||
workers_socket.bind(SN_ADDR_WORKERS);
|
||||
|
||||
workers.reserve(max_workers);
|
||||
idle_workers.resize(max_workers);
|
||||
if (!workers.empty() || !worker_sockets.empty())
|
||||
throw std::logic_error("Internal error: proxy thread started with active worker threads");
|
||||
worker_sockets.reserve(max_workers);
|
||||
// Pre-initialize these worker sockets rather than creating during thread initialization so that
|
||||
// we can't hit the zmq socket limit during worker thread startup.
|
||||
for (int i = 0; i < max_workers; i++)
|
||||
worker_sockets.emplace_back(context, zmq::socket_type::dealer);
|
||||
|
||||
#ifndef _WIN32
|
||||
int saved_umask = -1;
|
||||
if (STARTUP_UMASK >= 0)
|
||||
saved_umask = umask(STARTUP_UMASK);
|
||||
#endif
|
||||
|
||||
{
|
||||
zmq::socket_t inproc_listener{context, zmq::socket_type::router};
|
||||
inproc_listener.bind(SN_ADDR_SELF);
|
||||
inproc_listener_connid = next_conn_id++;
|
||||
connections.emplace_hint(connections.end(), inproc_listener_connid, std::move(inproc_listener));
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < bind.size(); i++) {
|
||||
if (!proxy_bind(bind[i], i)) {
|
||||
OMQ_LOG(fatal, "OxenMQ failed to listen on ", bind[i].address);
|
||||
throw zmq::error_t{};
|
||||
}
|
||||
}
|
||||
|
||||
#ifndef _WIN32
|
||||
if (saved_umask != -1)
|
||||
umask(saved_umask);
|
||||
|
||||
// set socket gid / uid if it is provided
|
||||
if (SOCKET_GID != -1 or SOCKET_UID != -1) {
|
||||
for (auto& b : bind) {
|
||||
const address addr(b.address);
|
||||
if (addr.ipc())
|
||||
if (chown(addr.socket.c_str(), SOCKET_UID, SOCKET_GID) == -1)
|
||||
throw std::runtime_error("cannot set group on " + addr.socket + ": " + strerror(errno));
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
connections_updated = true;
|
||||
|
||||
// Also add an internal connection to self so that calling code can avoid needing to
|
||||
// special-case rare situations where we are supposed to talk to a quorum member that happens to
|
||||
// be ourselves (which can happen, for example, with cross-quoum Blink communication)
|
||||
// FIXME: not working
|
||||
//listener.bind(SN_ADDR_SELF);
|
||||
|
||||
if (!timers)
|
||||
timers.reset(zmq_timers_new());
|
||||
|
||||
if (-1 == zmq_timers_add(timers.get(),
|
||||
std::chrono::milliseconds{CONN_CHECK_INTERVAL}.count(),
|
||||
[](int /*timer_id*/, void* self) { static_cast<OxenMQ*>(self)->proxy_conn_cleanup(); },
|
||||
this)) {
|
||||
throw zmq::error_t{};
|
||||
}
|
||||
|
||||
// Wait for tagged worker threads to get ready and connect to us (we get a "STARTING" message)
|
||||
// and send them back a "START" to let them know to go ahead with startup. We need this
|
||||
// synchronization dance to guarantee that the workers are routable before we can proceed.
|
||||
if (!tagged_workers.empty()) {
|
||||
OMQ_LOG(debug, "Waiting for tagged workers");
|
||||
{
|
||||
std::unique_lock lock{tagged_startup_mutex};
|
||||
tagged_go = tagged_go_mode::GO;
|
||||
}
|
||||
tagged_cv.notify_all();
|
||||
std::unordered_set<std::string_view> waiting_on;
|
||||
for (auto& w : tagged_workers)
|
||||
waiting_on.emplace(std::get<run_info>(w).worker_routing_id);
|
||||
for (std::vector<zmq::message_t> parts; !waiting_on.empty(); parts.clear()) {
|
||||
recv_message_parts(workers_socket, parts);
|
||||
if (parts.size() != 2 || view(parts[1]) != "STARTING"sv) {
|
||||
OMQ_LOG(error, "Received invalid message on worker socket while waiting for tagged thread startup");
|
||||
continue;
|
||||
}
|
||||
OMQ_LOG(debug, "Received STARTING message from ", view(parts[0]));
|
||||
if (auto it = waiting_on.find(view(parts[0])); it != waiting_on.end())
|
||||
waiting_on.erase(it);
|
||||
else
|
||||
OMQ_LOG(error, "Received STARTING message from unknown worker ", view(parts[0]));
|
||||
}
|
||||
|
||||
for (auto&w : tagged_workers) {
|
||||
OMQ_LOG(debug, "Telling tagged thread worker ", std::get<run_info>(w).worker_routing_name, " to finish startup");
|
||||
route_control(workers_socket, std::get<run_info>(w).worker_routing_id, "START");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void OxenMQ::proxy_loop(std::promise<void> startup) {
|
||||
try {
|
||||
proxy_loop_init();
|
||||
} catch (...) {
|
||||
startup.set_exception(std::current_exception());
|
||||
return;
|
||||
}
|
||||
startup.set_value();
|
||||
|
||||
// Fixed array used for worker and control messages: these are never longer than 3 parts:
|
||||
std::array<zmq::message_t, 3> control_parts;
|
||||
|
||||
// General vector for handling incoming messages:
|
||||
std::vector<zmq::message_t> parts;
|
||||
|
||||
std::vector<std::pair<const int64_t, zmq::socket_t>*> queue; // Used as a circular buffer
|
||||
|
||||
#ifdef OXENMQ_USE_EPOLL
|
||||
std::vector<struct epoll_event> evs;
|
||||
#endif
|
||||
|
||||
while (true) {
|
||||
std::chrono::milliseconds poll_timeout;
|
||||
if (max_workers == 0) { // Will be 0 only if we are quitting
|
||||
if (std::none_of(workers.begin(), workers.end(), [](auto &w) { return w.worker_thread.joinable(); }) &&
|
||||
std::none_of(tagged_workers.begin(), tagged_workers.end(), [](auto &w) { return std::get<0>(w).worker_thread.joinable(); })) {
|
||||
// All the workers have finished, so we can finish shutting down
|
||||
return proxy_quit();
|
||||
}
|
||||
poll_timeout = 1s; // We don't keep running timers when we're quitting, so don't have a timer to check
|
||||
} else {
|
||||
poll_timeout = std::chrono::milliseconds{zmq_timers_timeout(timers.get())};
|
||||
}
|
||||
|
||||
if (connections_updated) {
|
||||
rebuild_pollitems();
|
||||
// If we just rebuilt the queue then do a full check of everything, because we might
|
||||
// have sockets that already edge-triggered that we need to fully drain before we start
|
||||
// polling.
|
||||
proxy_skip_one_poll = true;
|
||||
}
|
||||
|
||||
// We round-robin connections when pulling off pending messages one-by-one rather than
|
||||
// pulling off all messages from one connection before moving to the next; thus in cases of
|
||||
// contention we end up fairly distributing.
|
||||
queue.reserve(connections.size() + 1);
|
||||
|
||||
#ifdef OXENMQ_USE_EPOLL
|
||||
bool process_command = false, process_worker = false, process_zap = false, process_all = false;
|
||||
|
||||
if (proxy_skip_one_poll) {
|
||||
proxy_skip_one_poll = false;
|
||||
|
||||
process_command = command.get(zmq::sockopt::events) & ZMQ_POLLIN;
|
||||
process_worker = workers_socket.get(zmq::sockopt::events) & ZMQ_POLLIN;
|
||||
process_zap = zap_auth.get(zmq::sockopt::events) & ZMQ_POLLIN;
|
||||
process_all = true;
|
||||
}
|
||||
else {
|
||||
OMQ_TRACE("polling for new messages via epoll");
|
||||
|
||||
evs.resize(3 + connections.size());
|
||||
const int max = epoll_wait(epoll_fd, evs.data(), evs.size(), poll_timeout.count());
|
||||
|
||||
queue.clear();
|
||||
for (int i = 0; i < max; i++) {
|
||||
const auto conn_id = evs[i].data.u64;
|
||||
if (conn_id == EPOLL_COMMAND_ID)
|
||||
process_command = true;
|
||||
else if (conn_id == EPOLL_WORKER_ID)
|
||||
process_worker = true;
|
||||
else if (conn_id == EPOLL_ZAP_ID)
|
||||
process_zap = true;
|
||||
else if (auto it = connections.find(conn_id); it != connections.end())
|
||||
queue.push_back(&*it);
|
||||
}
|
||||
queue.push_back(nullptr);
|
||||
}
|
||||
|
||||
#else
|
||||
if (proxy_skip_one_poll)
|
||||
proxy_skip_one_poll = false;
|
||||
else {
|
||||
OMQ_TRACE("polling for new messages");
|
||||
|
||||
// We poll the control socket and worker socket for any incoming messages. If we have
|
||||
// available worker room then also poll incoming connections and outgoing connections
|
||||
// for messages to forward to a worker. Otherwise, we just look for a control message
|
||||
// or a worker coming back with a ready message.
|
||||
zmq::poll(pollitems.data(), pollitems.size(), poll_timeout);
|
||||
}
|
||||
|
||||
constexpr bool process_command = true, process_worker = true, process_zap = true, process_all = true;
|
||||
#endif
|
||||
|
||||
if (process_command) {
|
||||
OMQ_TRACE("processing control messages");
|
||||
while (size_t len = recv_message_parts(command, control_parts, zmq::recv_flags::dontwait))
|
||||
proxy_control_message(control_parts, len);
|
||||
}
|
||||
|
||||
if (process_worker) {
|
||||
OMQ_TRACE("processing worker messages");
|
||||
while (size_t len = recv_message_parts(workers_socket, control_parts, zmq::recv_flags::dontwait))
|
||||
proxy_worker_message(control_parts, len);
|
||||
}
|
||||
|
||||
OMQ_TRACE("processing timers");
|
||||
zmq_timers_execute(timers.get());
|
||||
|
||||
if (process_zap) {
|
||||
// Handle any zap authentication
|
||||
OMQ_TRACE("processing zap requests");
|
||||
process_zap_requests();
|
||||
}
|
||||
|
||||
// See if we can drain anything from the current queue before we potentially add to it
|
||||
// below.
|
||||
OMQ_TRACE("processing queued jobs and messages");
|
||||
proxy_process_queue();
|
||||
|
||||
OMQ_TRACE("processing new incoming messages");
|
||||
if (process_all) {
|
||||
queue.clear();
|
||||
for (auto& id_sock : connections)
|
||||
if (id_sock.second.get(zmq::sockopt::events) & ZMQ_POLLIN)
|
||||
queue.push_back(&id_sock);
|
||||
queue.push_back(nullptr);
|
||||
}
|
||||
|
||||
size_t end = queue.size() - 1;
|
||||
|
||||
for (size_t pos = 0; pos != end; ++pos %= queue.size()) {
|
||||
parts.clear();
|
||||
auto& [id, sock] = *queue[pos];
|
||||
|
||||
if (!recv_message_parts(sock, parts, zmq::recv_flags::dontwait))
|
||||
continue;
|
||||
|
||||
// We only pull this one message now but then requeue the socket so that after we check
|
||||
// all other sockets we come back to this one to check again.
|
||||
queue[end] = queue[pos];
|
||||
++end %= queue.size();
|
||||
|
||||
if (parts.empty()) {
|
||||
OMQ_LOG(warn, "Ignoring empty (0-part) incoming message");
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!proxy_handle_builtin(id, sock, parts))
|
||||
proxy_to_worker(id, sock, parts);
|
||||
|
||||
if (connections_updated) {
|
||||
// If connections got updated then our points are stale, so restart the proxy loop;
|
||||
// we'll immediately end up right back here at least once before we resume polling.
|
||||
OMQ_TRACE("connections became stale; short-circuiting incoming message loop");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef OXENMQ_USE_EPOLL
|
||||
// If any socket still has ZMQ_POLLIN (which is possible if something we did above changed
|
||||
// state on another socket, perhaps by writing to it) then we need to repeat the loop
|
||||
// *without* going back to epoll again, until we get through everything without any
|
||||
// ZMQ_POLLIN sockets. If we didn't, we could miss it and might end up deadlocked because
|
||||
// of ZMQ's edge-triggered notifications on zmq fd's.
|
||||
//
|
||||
// More info on the complexities here at https://github.com/zeromq/libzmq/issues/3641 and
|
||||
// https://funcptr.net/2012/09/10/zeromq---edge-triggered-notification/
|
||||
if (!connections_updated && !proxy_skip_one_poll) {
|
||||
for (auto* s : {&command, &workers_socket, &zap_auth}) {
|
||||
if (s->get(zmq::sockopt::events) & ZMQ_POLLIN) {
|
||||
proxy_skip_one_poll = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!proxy_skip_one_poll) {
|
||||
for (auto& [id, sock] : connections) {
|
||||
if (sock.get(zmq::sockopt::events) & ZMQ_POLLIN) {
|
||||
proxy_skip_one_poll = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
OMQ_TRACE("done proxy loop");
|
||||
}
|
||||
}
|
||||
|
||||
static bool is_error_response(std::string_view cmd) {
|
||||
return cmd == "FORBIDDEN" || cmd == "FORBIDDEN_SN" || cmd == "NOT_A_SERVICE_NODE" || cmd == "UNKNOWNCOMMAND" || cmd == "NO_REPLY_TAG";
|
||||
}
|
||||
|
||||
// Return true if we recognized/handled the builtin command (even if we reject it for whatever
|
||||
// reason)
|
||||
bool OxenMQ::proxy_handle_builtin(int64_t conn_id, zmq::socket_t& sock, std::vector<zmq::message_t>& parts) {
|
||||
// Doubling as a bool and an offset:
|
||||
size_t incoming = sock.get(zmq::sockopt::type) == ZMQ_ROUTER;
|
||||
|
||||
std::string_view route, cmd;
|
||||
if (parts.size() < 1 + incoming) {
|
||||
OMQ_LOG(warn, "Received empty message; ignoring");
|
||||
return true;
|
||||
}
|
||||
if (incoming) {
|
||||
route = view(parts[0]);
|
||||
cmd = view(parts[1]);
|
||||
} else {
|
||||
cmd = view(parts[0]);
|
||||
}
|
||||
OMQ_TRACE("Checking for builtins: '", cmd, "' from ", peer_address(parts.back()));
|
||||
|
||||
if (cmd == "REPLY") {
|
||||
size_t tag_pos = 1 + incoming;
|
||||
if (parts.size() <= tag_pos) {
|
||||
OMQ_LOG(warn, "Received REPLY without a reply tag; ignoring");
|
||||
return true;
|
||||
}
|
||||
std::string reply_tag{view(parts[tag_pos])};
|
||||
auto it = pending_requests.find(reply_tag);
|
||||
if (it != pending_requests.end()) {
|
||||
OMQ_LOG(debug, "Received REPLY for pending command ", oxenc::to_hex(reply_tag), "; scheduling callback");
|
||||
std::vector<std::string> data;
|
||||
data.reserve(parts.size() - (tag_pos + 1));
|
||||
for (auto it = parts.begin() + (tag_pos + 1); it != parts.end(); ++it)
|
||||
data.emplace_back(view(*it));
|
||||
proxy_schedule_reply_job([callback=std::move(it->second.second), data=std::move(data)] {
|
||||
callback(true, std::move(data));
|
||||
});
|
||||
pending_requests.erase(it);
|
||||
} else {
|
||||
OMQ_LOG(warn, "Received REPLY with unknown or already handled reply tag (", oxenc::to_hex(reply_tag), "); ignoring");
|
||||
}
|
||||
return true;
|
||||
} else if (cmd == "HI") {
|
||||
if (!incoming) {
|
||||
OMQ_LOG(warn, "Got invalid 'HI' message on an outgoing connection; ignoring");
|
||||
return true;
|
||||
}
|
||||
OMQ_LOG(debug, "Incoming client from ", peer_address(parts.back()), " sent HI, replying with HELLO");
|
||||
try {
|
||||
send_routed_message(sock, std::string{route}, "HELLO");
|
||||
} catch (const std::exception &e) { OMQ_LOG(warn, "Couldn't reply with HELLO: ", e.what()); }
|
||||
return true;
|
||||
} else if (cmd == "HELLO") {
|
||||
if (incoming) {
|
||||
OMQ_LOG(warn, "Got invalid 'HELLO' message on an incoming connection; ignoring");
|
||||
return true;
|
||||
}
|
||||
auto it = std::find_if(pending_connects.begin(), pending_connects.end(),
|
||||
[&](auto& pc) { return std::get<int64_t>(pc) == conn_id; });
|
||||
if (it == pending_connects.end()) {
|
||||
OMQ_LOG(warn, "Got invalid 'HELLO' message on an already handshaked incoming connection; ignoring");
|
||||
return true;
|
||||
}
|
||||
auto& pc = *it;
|
||||
auto pit = peers.find(std::get<int64_t>(pc));
|
||||
if (pit == peers.end()) {
|
||||
OMQ_LOG(warn, "Got invalid 'HELLO' message with invalid conn_id; ignoring");
|
||||
return true;
|
||||
}
|
||||
|
||||
OMQ_LOG(debug, "Got initial HELLO server response from ", peer_address(parts.back()));
|
||||
proxy_schedule_reply_job([on_success=std::move(std::get<ConnectSuccess>(pc)),
|
||||
conn=pit->first] {
|
||||
on_success(conn);
|
||||
});
|
||||
pending_connects.erase(it);
|
||||
return true;
|
||||
} else if (cmd == "BYE") {
|
||||
if (!incoming) {
|
||||
OMQ_LOG(debug, "BYE command received; disconnecting from ", peer_address(parts.back()));
|
||||
proxy_close_connection(conn_id, 0s);
|
||||
} else {
|
||||
OMQ_LOG(warn, "Got invalid 'BYE' command on an incoming socket; ignoring");
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
else if (is_error_response(cmd)) {
|
||||
// These messages (FORBIDDEN, UNKNOWNCOMMAND, etc.) are sent in response to us trying to
|
||||
// invoke something that doesn't exist or we don't have permission to access. These have
|
||||
// two forms (the latter is only sent by remotes running 1.1.0+).
|
||||
// - ["XXX", "whatever.command"]
|
||||
// - ["XXX", "REPLY", replytag]
|
||||
// (ignoring the routing prefix on incoming commands).
|
||||
// For the former, we log; for the latter we trigger the reply callback with a failure
|
||||
|
||||
if (parts.size() == (1 + incoming) && cmd == "UNKNOWNCOMMAND") {
|
||||
// pre-1.1.0 sent just a plain UNKNOWNCOMMAND (without the actual command); this was not
|
||||
// useful, but also this response is *expected* for things 1.0.5 didn't understand, like
|
||||
// FORBIDDEN_SN: so log it only at debug level and move on.
|
||||
OMQ_LOG(debug, "Received plain UNKNOWNCOMMAND; remote is probably an older oxenmq. Ignoring.");
|
||||
return true;
|
||||
}
|
||||
|
||||
if (parts.size() == (3 + incoming) && view(parts[1 + incoming]) == "REPLY") {
|
||||
std::string reply_tag{view(parts[2 + incoming])};
|
||||
auto it = pending_requests.find(reply_tag);
|
||||
if (it != pending_requests.end()) {
|
||||
OMQ_LOG(debug, "Received ", cmd, " REPLY for pending command ", oxenc::to_hex(reply_tag), "; scheduling failure callback");
|
||||
proxy_schedule_reply_job([callback=std::move(it->second.second), cmd=std::string{cmd}] {
|
||||
callback(false, {{std::move(cmd)}});
|
||||
});
|
||||
pending_requests.erase(it);
|
||||
} else {
|
||||
OMQ_LOG(warn, "Received REPLY with unknown or already handled reply tag (", oxenc::to_hex(reply_tag), "); ignoring");
|
||||
}
|
||||
} else {
|
||||
OMQ_LOG(warn, "Received ", cmd, ':', (parts.size() > 1 + incoming ? view(parts[1 + incoming]) : "(unknown command)"sv),
|
||||
" from ", peer_address(parts.back()));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void OxenMQ::proxy_process_queue() {
|
||||
if (max_workers == 0) // shutting down
|
||||
return;
|
||||
|
||||
// First: send any tagged thread tasks to the tagged threads, if idle
|
||||
for (auto& [run, busy, queue] : tagged_workers) {
|
||||
if (!busy && !queue.empty()) {
|
||||
busy = true;
|
||||
proxy_run_worker(run.load(std::move(queue.front()), false, run.worker_id));
|
||||
queue.pop_front();
|
||||
}
|
||||
}
|
||||
|
||||
// Second: process any batch jobs; since these are internal they are given higher priority.
|
||||
proxy_run_batch_jobs(batch_jobs, batch_jobs_reserved, batch_jobs_active, false);
|
||||
|
||||
// Next any reply batch jobs (which are a bit different from the above, since they are
|
||||
// externally triggered but for things we initiated locally).
|
||||
proxy_run_batch_jobs(reply_jobs, reply_jobs_reserved, reply_jobs_active, true);
|
||||
|
||||
// Finally general incoming commands
|
||||
for (auto it = pending_commands.begin(); it != pending_commands.end() && active_workers() < max_workers; ) {
|
||||
auto& pending = *it;
|
||||
if (pending.cat.active_threads < pending.cat.reserved_threads
|
||||
|| active_workers() < general_workers) {
|
||||
proxy_run_worker(get_idle_worker().load(std::move(pending)));
|
||||
pending.cat.queued--;
|
||||
pending.cat.active_threads++;
|
||||
assert(pending.cat.queued >= 0);
|
||||
it = pending_commands.erase(it);
|
||||
} else {
|
||||
++it; // no available general or reserved worker spots for this job right now
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,172 @@
|
|||
#pragma once
|
||||
|
||||
#include "connections.h"
|
||||
|
||||
#include <chrono>
|
||||
#include <mutex>
|
||||
#include <shared_mutex>
|
||||
#include <type_traits>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include <optional>
|
||||
|
||||
namespace oxenmq {
|
||||
|
||||
using namespace std::chrono_literals;
|
||||
|
||||
namespace detail {
|
||||
struct no_data_t final {};
|
||||
inline constexpr no_data_t no_data{};
|
||||
|
||||
template <typename UserData>
|
||||
struct SubData {
|
||||
std::chrono::steady_clock::time_point expiry;
|
||||
UserData user_data;
|
||||
explicit SubData(std::chrono::steady_clock::time_point _exp)
|
||||
: expiry{_exp}, user_data{} {}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct SubData<void> {
|
||||
std::chrono::steady_clock::time_point expiry;
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* OMQ Subscription class. Handles pub/sub connections such that the user only needs to call
|
||||
* methods to subscribe and publish.
|
||||
*
|
||||
* FIXME: do we want an unsubscribe, or is expiry / conn management sufficient?
|
||||
*
|
||||
* Type UserData can contain whatever information the user may need at publish time, for example if
|
||||
* the subscription is for logs the subscriber can specify log levels or categories, and the
|
||||
* publisher can choose to send or not based on those. The UserData type, if provided and non-void,
|
||||
* must be default constructible, must be comparable with ==, and must be movable.
|
||||
*/
|
||||
template <typename UserData = void>
|
||||
class Subscription {
|
||||
static constexpr bool have_user_data = !std::is_void_v<UserData>;
|
||||
using UserData_if_present = std::conditional_t<have_user_data, UserData, detail::no_data_t>;
|
||||
using subdata_t = detail::SubData<UserData>;
|
||||
|
||||
std::unordered_map<ConnectionID, subdata_t> subs;
|
||||
std::shared_mutex _mutex;
|
||||
const std::string description; // description of the sub for logging
|
||||
const std::chrono::milliseconds sub_duration; // extended by re-subscribe
|
||||
|
||||
public:
|
||||
|
||||
Subscription() = delete;
|
||||
Subscription(std::string description, std::chrono::milliseconds sub_duration = 30min)
|
||||
: description{std::move(description)}, sub_duration{sub_duration} {}
|
||||
|
||||
// returns true if new sub, false if refresh sub. throws on error. `data` will be checked
|
||||
// against the existing data: if there is existing data and it compares `==` to the given value,
|
||||
// false is returned (and the existing data is not replaced). Otherwise the given data gets
|
||||
// stored for this connection (replacing existing data, if present), and true is returned.
|
||||
bool subscribe(const ConnectionID& conn, UserData_if_present data) {
|
||||
std::unique_lock lock{_mutex};
|
||||
auto expiry = std::chrono::steady_clock::now() + sub_duration;
|
||||
auto [value, added] = subs.emplace(conn, subdata_t{expiry});
|
||||
if (added) {
|
||||
if constexpr (have_user_data)
|
||||
value->second.user_data = std::move(data);
|
||||
return true;
|
||||
}
|
||||
|
||||
value->second.expiry = expiry;
|
||||
|
||||
if constexpr (have_user_data) {
|
||||
// if user_data changed, consider it a new sub rather than refresh, and update
|
||||
// user_data in the mapped value.
|
||||
if (!(value->second.user_data == data)) {
|
||||
value->second.user_data = std::move(data);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// no-user-data version, only available for Subscription<void> (== Subscription without a
|
||||
// UserData type).
|
||||
template <bool enable = !have_user_data, std::enable_if_t<enable, int> = 0>
|
||||
bool subscribe(const ConnectionID& conn) {
|
||||
return subscribe(conn, detail::no_data);
|
||||
}
|
||||
|
||||
// unsubscribe a connection ID. return the user data, if a sub was present.
|
||||
template <bool enable = have_user_data, std::enable_if_t<enable, int> = 0>
|
||||
std::optional<UserData> unsubscribe(const ConnectionID& conn) {
|
||||
std::unique_lock lock{_mutex};
|
||||
|
||||
auto node = subs.extract(conn);
|
||||
if (!node.empty())
|
||||
return node.mapped().user_data;
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// no-user-data version, only available for Subscription<void> (== Subscription without a
|
||||
// UserData type).
|
||||
template <bool enable = !have_user_data, std::enable_if_t<enable, int> = 0>
|
||||
bool unsubscribe(const ConnectionID& conn) {
|
||||
std::unique_lock lock{_mutex};
|
||||
auto node = subs.extract(conn);
|
||||
return !node.empty(); // true if removed, false if wasn't present
|
||||
}
|
||||
|
||||
// force removal of expired subscriptions. removal will otherwise only happen on publish
|
||||
void remove_expired() {
|
||||
std::unique_lock lock{_mutex};
|
||||
auto now = std::chrono::steady_clock::now();
|
||||
for (auto itr = subs.begin(); itr != subs.end();) {
|
||||
if (itr->second.expiry < now)
|
||||
itr = subs.erase(itr);
|
||||
else
|
||||
itr++;
|
||||
}
|
||||
}
|
||||
|
||||
// Func is any callable which takes:
|
||||
// - (const ConnectionID&, const UserData&) for Subscription<UserData> with non-void UserData
|
||||
// - (const ConnectionID&) for Subscription<void>.
|
||||
template <typename Func>
|
||||
void publish(Func&& func) {
|
||||
std::vector<ConnectionID> to_remove;
|
||||
{
|
||||
std::shared_lock lock(_mutex);
|
||||
if (subs.empty())
|
||||
return;
|
||||
|
||||
auto now = std::chrono::steady_clock::now();
|
||||
|
||||
for (const auto& [conn, sub] : subs) {
|
||||
if (sub.expiry < now)
|
||||
to_remove.push_back(conn);
|
||||
else if constexpr (have_user_data)
|
||||
func(conn, sub.user_data);
|
||||
else
|
||||
func(conn);
|
||||
}
|
||||
}
|
||||
|
||||
if (to_remove.empty())
|
||||
return;
|
||||
|
||||
std::unique_lock lock{_mutex};
|
||||
auto now = std::chrono::steady_clock::now();
|
||||
for (auto& conn : to_remove) {
|
||||
auto it = subs.find(conn);
|
||||
if (it != subs.end() && it->second.expiry < now /* recheck: client might have resubscribed in between locks */) {
|
||||
subs.erase(it);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
|
||||
} // namespace oxenmq
|
||||
|
||||
// vim:sw=4:et
|
|
@ -0,0 +1,5 @@
|
|||
namespace oxenmq {
|
||||
constexpr int VERSION_MAJOR = @PROJECT_VERSION_MAJOR@;
|
||||
constexpr int VERSION_MINOR = @PROJECT_VERSION_MINOR@;
|
||||
constexpr int VERSION_PATCH = @PROJECT_VERSION_PATCH@;
|
||||
}
|
|
@ -0,0 +1,431 @@
|
|||
#include "oxenmq.h"
|
||||
#include "batch.h"
|
||||
#include "oxenmq-internal.h"
|
||||
|
||||
#if defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__)
|
||||
extern "C" {
|
||||
#include <pthread.h>
|
||||
#include <pthread_np.h>
|
||||
}
|
||||
#endif
|
||||
#include <oxenc/variant.h>
|
||||
|
||||
namespace oxenmq {
|
||||
|
||||
namespace {
|
||||
|
||||
// Waits for a specific command or "QUIT" on the given socket. Returns true if the command was
|
||||
// received. If "QUIT" was received, replies with "QUITTING" on the socket and closes it, then
|
||||
// returns false.
|
||||
[[gnu::always_inline]] inline
|
||||
bool worker_wait_for(OxenMQ& omq, zmq::socket_t& sock, std::vector<zmq::message_t>& parts, const std::string_view worker_id, const std::string_view expect) {
|
||||
while (true) {
|
||||
omq.log(LogLevel::trace, __FILE__, __LINE__, "worker ", worker_id, " waiting for ", expect);
|
||||
parts.clear();
|
||||
recv_message_parts(sock, parts);
|
||||
if (parts.size() != 1) {
|
||||
omq.log(LogLevel::error, __FILE__, __LINE__, "Internal error: worker ", worker_id, " received invalid ", parts.size(), "-part control msg");
|
||||
continue;
|
||||
}
|
||||
auto command = view(parts[0]);
|
||||
if (command == expect) {
|
||||
#ifndef NDEBUG
|
||||
omq.log(LogLevel::trace, __FILE__, __LINE__, "Worker ", worker_id, " received waited-for ", expect, " command");
|
||||
#endif
|
||||
return true;
|
||||
} else if (command == "QUIT"sv) {
|
||||
omq.log(LogLevel::debug, __FILE__, __LINE__, "Worker ", worker_id, " received QUIT command, shutting down");
|
||||
detail::send_control(sock, "QUITTING");
|
||||
sock.set(zmq::sockopt::linger, 1000);
|
||||
sock.close();
|
||||
return false;
|
||||
} else {
|
||||
omq.log(LogLevel::error, __FILE__, __LINE__, "Internal error: worker ", worker_id, " received invalid command: `", command, "'");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void OxenMQ::worker_thread(unsigned int index, std::optional<std::string> tagged, std::function<void()> start) {
|
||||
std::string routing_id = (tagged ? "t" : "w") +
|
||||
std::string(reinterpret_cast<const char*>(&index), sizeof(index)); // for routing
|
||||
std::string worker_id{tagged ? *tagged : "w" + std::to_string(index)}; // for debug
|
||||
|
||||
[[maybe_unused]] std::string thread_name = tagged.value_or("omq-" + worker_id);
|
||||
#if defined(__linux__) || defined(__sun) || defined(__MINGW32__)
|
||||
if (thread_name.size() > 15) thread_name.resize(15);
|
||||
pthread_setname_np(pthread_self(), thread_name.c_str());
|
||||
#elif defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__)
|
||||
pthread_set_name_np(pthread_self(), thread_name.c_str());
|
||||
#elif defined(__MACH__)
|
||||
pthread_setname_np(thread_name.c_str());
|
||||
#endif
|
||||
|
||||
std::optional<zmq::socket_t> tagged_socket;
|
||||
if (tagged) {
|
||||
// If we're a tagged worker then we got started up before OxenMQ started, so we need to wait
|
||||
// for an all-clear signal from OxenMQ first, then we fire our `start` callback, then we can
|
||||
// start waiting for commands in the main loop further down. (We also can't get the
|
||||
// reference to our `tagged_workers` element or create a socket until the main proxy thread
|
||||
// is running).
|
||||
{
|
||||
std::unique_lock lock{tagged_startup_mutex};
|
||||
tagged_cv.wait(lock, [this] { return tagged_go != tagged_go_mode::WAIT; });
|
||||
}
|
||||
if (tagged_go == tagged_go_mode::SHUTDOWN) // OxenMQ destroyed without starting
|
||||
return;
|
||||
tagged_socket.emplace(context, zmq::socket_type::dealer);
|
||||
}
|
||||
auto& sock = tagged ? *tagged_socket : worker_sockets[index];
|
||||
sock.set(zmq::sockopt::routing_id, routing_id);
|
||||
OMQ_LOG(debug, "New worker thread ", worker_id, " (", routing_id, ") started");
|
||||
sock.connect(SN_ADDR_WORKERS);
|
||||
if (tagged)
|
||||
detail::send_control(sock, "STARTING");
|
||||
|
||||
Message message{*this, 0, AuthLevel::none, ""s};
|
||||
std::vector<zmq::message_t> parts;
|
||||
|
||||
bool waiting_for_command;
|
||||
if (tagged) {
|
||||
waiting_for_command = true;
|
||||
|
||||
if (!worker_wait_for(*this, sock, parts, worker_id, "START"sv))
|
||||
return;
|
||||
if (start) start();
|
||||
} else {
|
||||
// Otherwise for a regular worker we can only be started by an active main proxy thread
|
||||
// which will have preloaded our first job so we can start off right away.
|
||||
waiting_for_command = false;
|
||||
}
|
||||
|
||||
// This will always contains the current job, and is guaranteed to never be invalidated.
|
||||
run_info& run = tagged ? std::get<run_info>(tagged_workers[index - 1]) : workers[index];
|
||||
|
||||
while (true) {
|
||||
if (waiting_for_command) {
|
||||
if (!worker_wait_for(*this, sock, parts, worker_id, "RUN"sv))
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
if (run.is_batch_job) {
|
||||
auto* batch = var::get<detail::Batch*>(run.to_run);
|
||||
if (run.batch_jobno >= 0) {
|
||||
OMQ_TRACE("worker thread ", worker_id, " running batch ", batch, "#", run.batch_jobno);
|
||||
batch->run_job(run.batch_jobno);
|
||||
} else if (run.batch_jobno == -1) {
|
||||
OMQ_TRACE("worker thread ", worker_id, " running batch ", batch, " completion");
|
||||
batch->job_completion();
|
||||
}
|
||||
} else if (run.is_injected) {
|
||||
auto& func = var::get<std::function<void()>>(run.to_run);
|
||||
OMQ_TRACE("worker thread ", worker_id, " invoking injected command ", run.command);
|
||||
func();
|
||||
func = nullptr;
|
||||
} else {
|
||||
message.conn = run.conn;
|
||||
message.access = run.access;
|
||||
message.remote = std::move(run.remote);
|
||||
message.data.clear();
|
||||
|
||||
OMQ_TRACE("Got incoming command from ", message.remote, "/", message.conn, message.conn.route.empty() ? " (outgoing)" : " (incoming)");
|
||||
|
||||
auto& [callback, is_request] = *var::get<const std::pair<CommandCallback, bool>*>(run.to_run);
|
||||
if (is_request) {
|
||||
message.reply_tag = {run.data_parts[0].data<char>(), run.data_parts[0].size()};
|
||||
for (auto it = run.data_parts.begin() + 1; it != run.data_parts.end(); ++it)
|
||||
message.data.emplace_back(it->data<char>(), it->size());
|
||||
} else {
|
||||
for (auto& m : run.data_parts)
|
||||
message.data.emplace_back(m.data<char>(), m.size());
|
||||
}
|
||||
|
||||
OMQ_TRACE("worker thread ", worker_id, " invoking ", run.command, " callback with ", message.data.size(), " message parts");
|
||||
callback(message);
|
||||
}
|
||||
}
|
||||
catch (const oxenc::bt_deserialize_invalid& e) {
|
||||
OMQ_LOG(warn, worker_id, " deserialization failed: ", e.what(), "; ignoring request");
|
||||
}
|
||||
#ifndef BROKEN_APPLE_VARIANT
|
||||
catch (const std::bad_variant_access& e) {
|
||||
OMQ_LOG(warn, worker_id, " deserialization failed: found unexpected serialized type (", e.what(), "); ignoring request");
|
||||
}
|
||||
#endif
|
||||
catch (const std::out_of_range& e) {
|
||||
OMQ_LOG(warn, worker_id, " deserialization failed: invalid data - required field missing (", e.what(), "); ignoring request");
|
||||
}
|
||||
catch (const std::exception& e) {
|
||||
OMQ_LOG(warn, worker_id, " caught exception when processing command: ", e.what());
|
||||
}
|
||||
catch (...) {
|
||||
OMQ_LOG(warn, worker_id, " caught non-standard exception when processing command");
|
||||
}
|
||||
|
||||
// Tell the proxy thread that we are ready for another job
|
||||
detail::send_control(sock, "RAN");
|
||||
waiting_for_command = true;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
OxenMQ::run_info& OxenMQ::get_idle_worker() {
|
||||
if (idle_worker_count == 0) {
|
||||
uint32_t id = workers.size();
|
||||
workers.emplace_back();
|
||||
auto& r = workers.back();
|
||||
r.worker_id = id;
|
||||
r.worker_routing_id = "w" + std::string(reinterpret_cast<const char*>(&id), sizeof(id));
|
||||
r.worker_routing_name = "w" + std::to_string(id);
|
||||
return r;
|
||||
}
|
||||
size_t id = idle_workers[--idle_worker_count];
|
||||
return workers[id];
|
||||
}
|
||||
|
||||
void OxenMQ::proxy_worker_message(OxenMQ::control_message_array& parts, size_t len) {
|
||||
// Process messages sent by workers
|
||||
if (len != 2) {
|
||||
OMQ_LOG(error, "Received send invalid ", len, "-part message");
|
||||
return;
|
||||
}
|
||||
auto route = view(parts[0]), cmd = view(parts[1]);
|
||||
if (route.size() != 5 || (route[0] != 'w' && route[0] != 't')) {
|
||||
OMQ_LOG(error, "Received malformed worker id in worker message; unable to process worker command");
|
||||
return;
|
||||
}
|
||||
bool tagged_worker = route[0] == 't';
|
||||
uint32_t worker_id;
|
||||
std::memcpy(&worker_id, route.data() + 1, 4);
|
||||
if (tagged_worker
|
||||
? 0 == worker_id || worker_id > tagged_workers.size() // tagged worker ids are indexed from 1 to N (0 means untagged)
|
||||
: worker_id >= workers.size()) { // regular worker ids are indexed from 0 to N-1
|
||||
OMQ_LOG(error, "Received invalid worker id w" + std::to_string(worker_id) + " in worker message; unable to process worker command");
|
||||
return;
|
||||
}
|
||||
|
||||
auto& run = tagged_worker ? std::get<run_info>(tagged_workers[worker_id - 1]) : workers[worker_id];
|
||||
|
||||
OMQ_TRACE("received ", cmd, " command from ", route);
|
||||
if (cmd == "RAN"sv) {
|
||||
OMQ_TRACE("Worker ", route, " finished ", run.is_batch_job ? "batch job" : run.command);
|
||||
if (run.is_batch_job) {
|
||||
if (tagged_worker) {
|
||||
std::get<bool>(tagged_workers[worker_id - 1]) = false;
|
||||
} else {
|
||||
auto& active = run.is_reply_job ? reply_jobs_active : batch_jobs_active;
|
||||
assert(active > 0);
|
||||
active--;
|
||||
}
|
||||
bool clear_job = false;
|
||||
auto* batch = var::get<detail::Batch*>(run.to_run);
|
||||
if (run.batch_jobno == -1) {
|
||||
// Returned from the completion function
|
||||
clear_job = true;
|
||||
} else {
|
||||
auto [state, thread] = batch->job_finished();
|
||||
if (state == detail::BatchState::complete) {
|
||||
if (thread == -1) { // run directly in proxy
|
||||
OMQ_TRACE("Completion job running directly in proxy");
|
||||
try {
|
||||
batch->job_completion(); // RUN DIRECTLY IN PROXY THREAD
|
||||
} catch (const std::exception &e) {
|
||||
// Raise these to error levels: the caller really shouldn't be doing
|
||||
// anything non-trivial in an in-proxy completion function!
|
||||
OMQ_LOG(error, "proxy thread caught exception when processing in-proxy completion command: ", e.what());
|
||||
} catch (...) {
|
||||
OMQ_LOG(error, "proxy thread caught non-standard exception when processing in-proxy completion command");
|
||||
}
|
||||
clear_job = true;
|
||||
} else {
|
||||
auto& jobs =
|
||||
thread > 0
|
||||
? std::get<batch_queue>(tagged_workers[thread - 1]) // run in tagged thread
|
||||
: run.is_reply_job
|
||||
? reply_jobs
|
||||
: batch_jobs;
|
||||
jobs.emplace_back(batch, -1);
|
||||
}
|
||||
} else if (state == detail::BatchState::done) {
|
||||
// No completion job
|
||||
clear_job = true;
|
||||
}
|
||||
// else the job is still running
|
||||
}
|
||||
|
||||
if (clear_job) {
|
||||
delete batch;
|
||||
}
|
||||
} else {
|
||||
assert(run.cat->active_threads > 0);
|
||||
run.cat->active_threads--;
|
||||
}
|
||||
if (max_workers == 0) { // Shutting down
|
||||
OMQ_TRACE("Telling worker ", route, " to quit");
|
||||
route_control(workers_socket, route, "QUIT");
|
||||
} else if (!tagged_worker) {
|
||||
idle_workers[idle_worker_count++] = worker_id;
|
||||
}
|
||||
} else if (cmd == "QUITTING"sv) {
|
||||
run.worker_thread.join();
|
||||
OMQ_LOG(debug, "Worker ", route, " exited normally");
|
||||
} else {
|
||||
OMQ_LOG(error, "Worker ", route, " sent unknown control message: `", cmd, "'");
|
||||
}
|
||||
}
|
||||
|
||||
void OxenMQ::proxy_run_worker(run_info& run) {
|
||||
if (!run.worker_thread.joinable())
|
||||
run.worker_thread = std::thread{&OxenMQ::worker_thread, this, run.worker_id, std::nullopt, nullptr};
|
||||
else
|
||||
send_routed_message(workers_socket, run.worker_routing_id, "RUN");
|
||||
}
|
||||
|
||||
void OxenMQ::proxy_to_worker(int64_t conn_id, zmq::socket_t& sock, std::vector<zmq::message_t>& parts) {
|
||||
bool outgoing = sock.get(zmq::sockopt::type) == ZMQ_DEALER;
|
||||
|
||||
peer_info tmp_peer;
|
||||
tmp_peer.conn_id = conn_id;
|
||||
if (!outgoing) tmp_peer.route = parts[0].to_string();
|
||||
peer_info* peer = nullptr;
|
||||
if (outgoing) {
|
||||
auto snit = outgoing_sn_conns.find(conn_id);
|
||||
auto it = snit != outgoing_sn_conns.end()
|
||||
? peers.find(snit->second)
|
||||
: peers.find(conn_id);
|
||||
|
||||
if (it == peers.end()) {
|
||||
OMQ_LOG(warn, "Internal error: connection id ", conn_id, " not found");
|
||||
return;
|
||||
}
|
||||
peer = &it->second;
|
||||
} else if (conn_id == inproc_listener_connid) {
|
||||
tmp_peer.auth_level = AuthLevel::admin;
|
||||
tmp_peer.pubkey = pubkey;
|
||||
tmp_peer.service_node = active_service_nodes.count(pubkey);
|
||||
peer = &tmp_peer;
|
||||
} else {
|
||||
std::tie(tmp_peer.pubkey, tmp_peer.auth_level) = detail::extract_metadata(parts.back());
|
||||
tmp_peer.service_node = tmp_peer.pubkey.size() == 32 && active_service_nodes.count(tmp_peer.pubkey);
|
||||
|
||||
if (tmp_peer.service_node) {
|
||||
// It's a service node so we should have a peer_info entry; see if we can find one with
|
||||
// the same route, and if not, add one.
|
||||
auto pr = peers.equal_range(tmp_peer.pubkey);
|
||||
for (auto it = pr.first; it != pr.second; ++it) {
|
||||
if (it->second.conn_id == tmp_peer.conn_id && it->second.route == tmp_peer.route) {
|
||||
peer = &it->second;
|
||||
// Update the stored auth level just in case the peer reconnected
|
||||
peer->auth_level = tmp_peer.auth_level;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!peer) {
|
||||
// We don't have a record: this is either a new SN connection or a new message on a
|
||||
// connection that recently gained SN status.
|
||||
peer = &peers.emplace(ConnectionID{tmp_peer.pubkey}, std::move(tmp_peer))->second;
|
||||
}
|
||||
} else {
|
||||
// Incoming, non-SN connection: we don't store a peer_info for this, so just use the
|
||||
// temporary one
|
||||
peer = &tmp_peer;
|
||||
}
|
||||
}
|
||||
|
||||
size_t command_part_index = outgoing ? 0 : 1;
|
||||
std::string command = parts[command_part_index].to_string();
|
||||
|
||||
// Steal any data message parts
|
||||
size_t data_part_index = command_part_index + 1;
|
||||
std::vector<zmq::message_t> data_parts;
|
||||
data_parts.reserve(parts.size() - data_part_index);
|
||||
for (auto it = parts.begin() + data_part_index; it != parts.end(); ++it)
|
||||
data_parts.push_back(std::move(*it));
|
||||
|
||||
auto cat_call = get_command(command);
|
||||
|
||||
// Check that command is valid, that we have permission, etc.
|
||||
if (!proxy_check_auth(conn_id, outgoing, *peer, parts[command_part_index], cat_call, data_parts))
|
||||
return;
|
||||
|
||||
auto& category = *cat_call.first;
|
||||
Access access{peer->auth_level, peer->service_node, local_service_node};
|
||||
|
||||
if (category.active_threads >= category.reserved_threads && active_workers() >= general_workers) {
|
||||
// No free reserved or general spots, try to queue it for later
|
||||
if (category.max_queue >= 0 && category.queued >= category.max_queue) {
|
||||
OMQ_LOG(warn, "No space to queue incoming command ", command, "; already have ", category.queued,
|
||||
"commands queued in that category (max ", category.max_queue, "); dropping message");
|
||||
return;
|
||||
}
|
||||
|
||||
OMQ_LOG(debug, "No available free workers, queuing ", command, " for later");
|
||||
ConnectionID conn{peer->service_node ? ConnectionID::SN_ID : conn_id, peer->pubkey, std::move(tmp_peer.route)};
|
||||
pending_commands.emplace_back(category, std::move(command), std::move(data_parts), cat_call.second,
|
||||
std::move(conn), std::move(access), peer_address(parts[command_part_index]));
|
||||
category.queued++;
|
||||
return;
|
||||
}
|
||||
|
||||
if (cat_call.second->second /*is_request*/ && data_parts.empty()) {
|
||||
OMQ_LOG(warn, "Received an invalid request command with no reply tag; dropping message");
|
||||
return;
|
||||
}
|
||||
|
||||
auto& run = get_idle_worker();
|
||||
{
|
||||
ConnectionID c{peer->service_node ? ConnectionID::SN_ID : conn_id, peer->pubkey};
|
||||
c.route = std::move(tmp_peer.route);
|
||||
if (outgoing || peer->service_node)
|
||||
tmp_peer.route.clear();
|
||||
run.load(&category, std::move(command), std::move(c), std::move(access), peer_address(parts[command_part_index]),
|
||||
std::move(data_parts), cat_call.second);
|
||||
}
|
||||
|
||||
if (outgoing)
|
||||
peer->activity(); // outgoing connection activity, pump the activity timer
|
||||
|
||||
OMQ_TRACE("Forwarding incoming ", run.command, " from ", run.conn, " @ ", peer_address(parts[command_part_index]),
|
||||
" to worker ", run.worker_routing_name);
|
||||
|
||||
proxy_run_worker(run);
|
||||
category.active_threads++;
|
||||
}
|
||||
|
||||
void OxenMQ::inject_task(const std::string& category, std::string command, std::string remote, std::function<void()> callback) {
|
||||
if (!callback) return;
|
||||
auto it = categories.find(category);
|
||||
if (it == categories.end())
|
||||
throw std::out_of_range{"Invalid category `" + category + "': category does not exist"};
|
||||
detail::send_control(get_control_socket(), "INJECT", oxenc::bt_serialize(detail::serialize_object(
|
||||
injected_task{it->second, std::move(command), std::move(remote), std::move(callback)})));
|
||||
}
|
||||
|
||||
void OxenMQ::proxy_inject_task(injected_task task) {
|
||||
auto& category = task.cat;
|
||||
if (category.active_threads >= category.reserved_threads && active_workers() >= general_workers) {
|
||||
// No free worker slot, queue for later
|
||||
if (category.max_queue >= 0 && category.queued >= category.max_queue) {
|
||||
OMQ_LOG(warn, "No space to queue injected task ", task.command, "; already have ", category.queued,
|
||||
"commands queued in that category (max ", category.max_queue, "); dropping task");
|
||||
return;
|
||||
}
|
||||
OMQ_LOG(debug, "No available free workers for injected task ", task.command, "; queuing for later");
|
||||
pending_commands.emplace_back(category, std::move(task.command), std::move(task.callback), std::move(task.remote));
|
||||
category.queued++;
|
||||
return;
|
||||
}
|
||||
|
||||
auto& run = get_idle_worker();
|
||||
OMQ_TRACE("Forwarding incoming injected task ", task.command, " from ", task.remote, " to worker ", run.worker_routing_name);
|
||||
run.load(&category, std::move(task.command), std::move(task.remote), std::move(task.callback));
|
||||
|
||||
proxy_run_worker(run);
|
||||
category.active_threads++;
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
|
@ -1,26 +1,27 @@
|
|||
|
||||
add_subdirectory(Catch2)
|
||||
|
||||
set(LMQ_TEST_SRC
|
||||
add_executable(tests
|
||||
main.cpp
|
||||
test_address.cpp
|
||||
test_batch.cpp
|
||||
test_connect.cpp
|
||||
test_commands.cpp
|
||||
test_failures.cpp
|
||||
test_inject.cpp
|
||||
test_pubsub.cpp
|
||||
test_requests.cpp
|
||||
test_string_view.cpp
|
||||
)
|
||||
|
||||
add_executable(tests ${LMQ_TEST_SRC})
|
||||
test_socket_limit.cpp
|
||||
test_tagged_threads.cpp
|
||||
test_timer.cpp
|
||||
)
|
||||
|
||||
find_package(Threads)
|
||||
find_package(PkgConfig REQUIRED)
|
||||
pkg_check_modules(SODIUM REQUIRED libsodium)
|
||||
|
||||
target_link_libraries(tests Catch2::Catch2 lokimq ${SODIUM_LIBRARIES} Threads::Threads)
|
||||
target_link_libraries(tests Catch2::Catch2 oxenmq Threads::Threads)
|
||||
|
||||
set_target_properties(tests PROPERTIES
|
||||
CXX_STANDARD 14
|
||||
CXX_STANDARD 17
|
||||
CXX_STANDARD_REQUIRED ON
|
||||
CXX_EXTENSIONS OFF
|
||||
)
|
||||
|
|
|
@ -1 +1 @@
|
|||
Subproject commit b3b07215d1ca2224aea6ff3e21d87ad0f7750df2
|
||||
Subproject commit dba29b60d639bf8d206a9a12c223e6ed4284fb13
|
|
@ -1,19 +1,49 @@
|
|||
#pragma once
|
||||
#include "lokimq/lokimq.h"
|
||||
#include "oxenmq/oxenmq.h"
|
||||
#include <catch2/catch.hpp>
|
||||
#include <chrono>
|
||||
|
||||
using namespace lokimq;
|
||||
using namespace oxenmq;
|
||||
|
||||
// Apple's mutexes, thread scheduling, and IO handling are garbage and it shows up with lots of
|
||||
// spurious failures in this test suite (because it expects a system to not suck that badly), so we
|
||||
// multiply the time-sensitive bits by this factor as a hack to make the test suite work.
|
||||
constexpr int TIME_DILATION =
|
||||
#ifdef __APPLE__
|
||||
5;
|
||||
#else
|
||||
1;
|
||||
#endif
|
||||
|
||||
static auto startup = std::chrono::steady_clock::now();
|
||||
|
||||
/// Waits up to 100ms for something to happen.
|
||||
/// Returns a localhost connection string to listen on. It can be considered random, though in
|
||||
/// practice in the current implementation is sequential starting at 25432.
|
||||
inline std::string random_localhost() {
|
||||
static std::atomic<uint16_t> last = 25432;
|
||||
last++;
|
||||
assert(last); // We should never call this enough to overflow
|
||||
return "tcp://127.0.0.1:" + std::to_string(last);
|
||||
}
|
||||
|
||||
|
||||
// Catch2 macros aren't thread safe, so guard with a mutex
|
||||
inline std::unique_lock<std::mutex> catch_lock() {
|
||||
static std::mutex mutex;
|
||||
return std::unique_lock<std::mutex>{mutex};
|
||||
}
|
||||
|
||||
/// Waits up to 200ms for something to happen.
|
||||
template <typename Func>
|
||||
inline void wait_for(Func f) {
|
||||
for (int i = 0; i < 10; i++) {
|
||||
auto start = std::chrono::steady_clock::now();
|
||||
for (int i = 0; i < 20; i++) {
|
||||
if (f())
|
||||
break;
|
||||
std::this_thread::sleep_for(10ms);
|
||||
std::this_thread::sleep_for(10ms * TIME_DILATION);
|
||||
}
|
||||
auto lock = catch_lock();
|
||||
UNSCOPED_INFO("done waiting after " << (std::chrono::steady_clock::now() - start).count() << "ns");
|
||||
}
|
||||
|
||||
/// Waits on an atomic bool for up to 100ms for an initial connection, which is more than enough
|
||||
|
@ -23,15 +53,9 @@ inline void wait_for_conn(std::atomic<bool> &c) {
|
|||
}
|
||||
|
||||
/// Waits enough time for us to receive a reply from a localhost remote.
|
||||
inline void reply_sleep() { std::this_thread::sleep_for(10ms); }
|
||||
inline void reply_sleep() { std::this_thread::sleep_for(10ms * TIME_DILATION); }
|
||||
|
||||
// Catch2 macros aren't thread safe, so guard with a mutex
|
||||
inline std::unique_lock<std::mutex> catch_lock() {
|
||||
static std::mutex mutex;
|
||||
return std::unique_lock<std::mutex>{mutex};
|
||||
}
|
||||
|
||||
inline LokiMQ::Logger get_logger(std::string prefix = "") {
|
||||
inline OxenMQ::Logger get_logger(std::string prefix = "") {
|
||||
std::string me = "tests/common.h";
|
||||
std::string strip = __FILE__;
|
||||
if (strip.substr(strip.size() - me.size()) == me)
|
||||
|
|
|
@ -0,0 +1,162 @@
|
|||
#include "oxenmq/address.h"
|
||||
#include "common.h"
|
||||
|
||||
const std::string pk = "\xf1\x6b\xa5\x59\x10\x39\xf0\x89\xb4\x2a\x83\x41\x75\x09\x30\x94\x07\x4d\x0d\x93\x7a\x79\xe5\x3e\x5c\xe7\x30\xf9\x46\xe1\x4b\x88";
|
||||
const std::string pk_hex = "f16ba5591039f089b42a834175093094074d0d937a79e53e5ce730f946e14b88";
|
||||
const std::string pk_HEX = "F16BA5591039F089B42A834175093094074D0D937A79E53E5CE730F946E14B88";
|
||||
const std::string pk_b32z = "6fi4kseo88aeupbkopyzknjo1odw4dcuxjh6kx1hhhax1tzbjqry";
|
||||
const std::string pk_B32Z = "6FI4KSEO88AEUPBKOPYZKNJO1ODW4DCUXJH6KX1HHHAX1TZBJQRY";
|
||||
const std::string pk_b64 = "8WulWRA58Im0KoNBdQkwlAdNDZN6eeU+XOcw+UbhS4g"; // NB: padding '=' omitted
|
||||
|
||||
TEST_CASE("tcp addresses", "[address][tcp]") {
|
||||
address a{"tcp://1.2.3.4:5678"};
|
||||
REQUIRE( a.host == "1.2.3.4" );
|
||||
REQUIRE( a.port == 5678 );
|
||||
REQUIRE_FALSE( a.curve() );
|
||||
REQUIRE( a.tcp() );
|
||||
REQUIRE( a.zmq_address() == "tcp://1.2.3.4:5678" );
|
||||
REQUIRE( a.full_address() == "tcp://1.2.3.4:5678" );
|
||||
REQUIRE( a.qr_address() == "TCP://1.2.3.4:5678" );
|
||||
|
||||
REQUIRE_THROWS_AS( address{"tcp://1:1:1"}, std::invalid_argument );
|
||||
REQUIRE_THROWS_AS( address{"tcpz://localhost:123"}, std::invalid_argument );
|
||||
REQUIRE_THROWS_AS( address{"tcp://abc"}, std::invalid_argument );
|
||||
REQUIRE_THROWS_AS( address{"tcpz://localhost:0"}, std::invalid_argument );
|
||||
REQUIRE_THROWS_AS( address{"tcpz://[::1:1080"}, std::invalid_argument );
|
||||
|
||||
address b = address::tcp("example.com", 80);
|
||||
REQUIRE( b.host == "example.com" );
|
||||
REQUIRE( b.port == 80 );
|
||||
REQUIRE_FALSE( b.curve() );
|
||||
REQUIRE( b.tcp() );
|
||||
REQUIRE( b.zmq_address() == "tcp://example.com:80" );
|
||||
REQUIRE( b.full_address() == "tcp://example.com:80" );
|
||||
REQUIRE( b.qr_address() == "TCP://EXAMPLE.COM:80" );
|
||||
|
||||
address c{"tcp://[::1]:1111"};
|
||||
REQUIRE( c.host == "[::1]" );
|
||||
REQUIRE( c.port == 1111 );
|
||||
}
|
||||
|
||||
TEST_CASE("unix sockets", "[address][ipc]") {
|
||||
address a{"ipc:///path/to/foo"};
|
||||
REQUIRE( a.socket == "/path/to/foo" );
|
||||
REQUIRE_FALSE( a.curve() );
|
||||
REQUIRE_FALSE( a.tcp() );
|
||||
REQUIRE( a.zmq_address() == "ipc:///path/to/foo" );
|
||||
REQUIRE( a.full_address() == "ipc:///path/to/foo" );
|
||||
|
||||
address b = address::ipc("../foo");
|
||||
REQUIRE( b.socket == "../foo" );
|
||||
REQUIRE_FALSE( b.curve() );
|
||||
REQUIRE_FALSE( b.tcp() );
|
||||
REQUIRE( b.zmq_address() == "ipc://../foo" );
|
||||
REQUIRE( b.full_address() == "ipc://../foo" );
|
||||
}
|
||||
|
||||
TEST_CASE("pubkey formats", "[address][curve][pubkey]") {
|
||||
address a{"tcp+curve://a:1/" + pk_hex};
|
||||
address b{"curve://a:1/" + pk_b32z};
|
||||
address c{"curve://a:1/" + pk_b64};
|
||||
address d{"CURVE://A:1/" + pk_B32Z};
|
||||
REQUIRE( a.curve() );
|
||||
REQUIRE( a.host == "a" );
|
||||
REQUIRE( a.port == 1 );
|
||||
REQUIRE((b.curve() && c.curve() && d.curve()));
|
||||
REQUIRE( a.pubkey == pk );
|
||||
REQUIRE( b.pubkey == pk );
|
||||
REQUIRE( c.pubkey == pk );
|
||||
REQUIRE( d.pubkey == pk );
|
||||
|
||||
address e{"ipc+curve://my.sock/" + pk_hex};
|
||||
address f{"ipc+curve://../my.sock/" + pk_b32z};
|
||||
address g{"ipc+curve:///my.sock/" + pk_B32Z};
|
||||
address h{"ipc+curve://./my.sock/" + pk_b64};
|
||||
REQUIRE( e.curve() );
|
||||
REQUIRE( e.ipc() );
|
||||
REQUIRE_FALSE( e.tcp() );
|
||||
REQUIRE((f.curve() && g.curve() && h.curve()));
|
||||
REQUIRE( e.socket == "my.sock" );
|
||||
REQUIRE( f.socket == "../my.sock" );
|
||||
REQUIRE( g.socket == "/my.sock" );
|
||||
REQUIRE( h.socket == "./my.sock" );
|
||||
REQUIRE( e.pubkey == pk );
|
||||
REQUIRE( f.pubkey == pk );
|
||||
REQUIRE( g.pubkey == pk );
|
||||
REQUIRE( h.pubkey == pk );
|
||||
|
||||
REQUIRE( d.full_address(address::encoding::hex) == "curve://a:1/" + pk_hex );
|
||||
REQUIRE( c.full_address(address::encoding::base32z) == "curve://a:1/" + pk_b32z );
|
||||
REQUIRE( b.full_address(address::encoding::BASE32Z) == "curve://a:1/" + pk_B32Z );
|
||||
REQUIRE( a.full_address(address::encoding::base64) == "curve://a:1/" + pk_b64 );
|
||||
|
||||
REQUIRE( h.full_address(address::encoding::hex) == "ipc+curve://./my.sock/" + pk_hex );
|
||||
REQUIRE( g.full_address(address::encoding::base32z) == "ipc+curve:///my.sock/" + pk_b32z );
|
||||
REQUIRE( f.full_address(address::encoding::BASE32Z) == "ipc+curve://../my.sock/" + pk_B32Z );
|
||||
REQUIRE( e.full_address(address::encoding::base64) == "ipc+curve://my.sock/" + pk_b64 );
|
||||
|
||||
REQUIRE_THROWS_AS(address{"ipc+curve://my.sock/" + pk_hex.substr(0, 63)}, std::invalid_argument);
|
||||
REQUIRE_THROWS_AS(address{"ipc+curve://my.sock/" + pk_b32z.substr(0, 51)}, std::invalid_argument);
|
||||
REQUIRE_THROWS_AS(address{"ipc+curve://my.sock/" + pk_B32Z.substr(0, 51)}, std::invalid_argument);
|
||||
REQUIRE_THROWS_AS(address{"ipc+curve://my.sock/" + pk_b64.substr(0, 42)}, std::invalid_argument);
|
||||
REQUIRE_THROWS_AS(address{"ipc+curve://my.sock"}, std::invalid_argument);
|
||||
REQUIRE_THROWS_AS(address{"ipc+curve://my.sock/"}, std::invalid_argument);
|
||||
}
|
||||
|
||||
|
||||
TEST_CASE("tcp QR-code friendly addresses", "[address][tcp][qr]") {
|
||||
address a{"tcp://public.loki.foundation:12345"};
|
||||
address a_qr{"TCP://PUBLIC.LOKI.FOUNDATION:12345"};
|
||||
address b{"tcp://PUBLIC.LOKI.FOUNDATION:12345"};
|
||||
REQUIRE( a == a_qr );
|
||||
REQUIRE( a != b );
|
||||
REQUIRE( a.host == "public.loki.foundation" );
|
||||
REQUIRE( a.qr_address() == "TCP://PUBLIC.LOKI.FOUNDATION:12345" );
|
||||
|
||||
address c = address::tcp_curve("public.loki.foundation", 12345, pk);
|
||||
REQUIRE( c.qr_address() == "CURVE://PUBLIC.LOKI.FOUNDATION:12345/" + pk_B32Z );
|
||||
REQUIRE( address{"CURVE://PUBLIC.LOKI.FOUNDATION:12345/" + pk_B32Z} == c );
|
||||
// We don't produce with upper-case hex, but we accept it:
|
||||
REQUIRE( address{"CURVE://PUBLIC.LOKI.FOUNDATION:12345/" + pk_HEX} == c );
|
||||
|
||||
// lower case not permitted: ▾
|
||||
REQUIRE_THROWS_AS(address{"CURVE://PUBLIC.LOKI.FOUNDATiON:12345/" + pk_B32Z}, std::invalid_argument);
|
||||
// also only accept upper-base base32z and hex:
|
||||
REQUIRE_THROWS_AS(address{"CURVE://PUBLIC.LOKI.FOUNDATION:12345/" + pk_b32z}, std::invalid_argument);
|
||||
REQUIRE_THROWS_AS(address{"CURVE://PUBLIC.LOKI.FOUNDATION:12345/" + pk_hex}, std::invalid_argument);
|
||||
// don't accept base64 even if it's upper-case (because case-converting it changes the value)
|
||||
REQUIRE_THROWS_AS(address{"CURVE://PUBLIC.LOKI.FOUNDATION:12345/AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"}, std::invalid_argument);
|
||||
REQUIRE_THROWS_AS(address{"CURVE://PUBLIC.LOKI.FOUNDATION:12345/AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA="}, std::invalid_argument);
|
||||
}
|
||||
|
||||
TEST_CASE("address hashing", "[address][hash]") {
|
||||
address a{"tcp://public.loki.foundation:12345"};
|
||||
address b{"tcp+curve://public.loki.foundation:12345/" + pk_hex};
|
||||
address c{"ipc:///tmp/some.sock"};
|
||||
address d{"ipc:///tmp/some.other.sock"};
|
||||
|
||||
std::hash<oxenmq::address> hasher{};
|
||||
REQUIRE( hasher(a) != hasher(b) );
|
||||
REQUIRE( hasher(a) != hasher(c) );
|
||||
REQUIRE( hasher(a) != hasher(d) );
|
||||
REQUIRE( hasher(b) != hasher(c) );
|
||||
REQUIRE( hasher(b) != hasher(d) );
|
||||
REQUIRE( hasher(c) != hasher(d) );
|
||||
|
||||
std::unordered_set<oxenmq::address> set;
|
||||
set.insert(a);
|
||||
set.insert(b);
|
||||
set.insert(c);
|
||||
set.insert(d);
|
||||
|
||||
CHECK( set.size() == 4 );
|
||||
std::unordered_map<oxenmq::address, int> count;
|
||||
for (const auto& addr : set)
|
||||
count[addr]++;
|
||||
|
||||
REQUIRE( count.size() == 4 );
|
||||
CHECK( count[a] == 1 );
|
||||
CHECK( count[b] == 1 );
|
||||
CHECK( count[c] == 1 );
|
||||
CHECK( count[d] == 1 );
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
#include "lokimq/batch.h"
|
||||
#include "oxenmq/batch.h"
|
||||
#include "common.h"
|
||||
#include <future>
|
||||
|
||||
|
@ -12,7 +12,7 @@ double do_my_task(int input) {
|
|||
|
||||
std::promise<std::pair<double, int>> done;
|
||||
|
||||
void continue_big_task(std::vector<lokimq::job_result<double>> results) {
|
||||
void continue_big_task(std::vector<oxenmq::job_result<double>> results) {
|
||||
double sum = 0;
|
||||
int exc_count = 0;
|
||||
for (auto& r : results) {
|
||||
|
@ -25,10 +25,10 @@ void continue_big_task(std::vector<lokimq::job_result<double>> results) {
|
|||
done.set_value({sum, exc_count});
|
||||
}
|
||||
|
||||
void start_big_task(lokimq::LokiMQ& lmq) {
|
||||
void start_big_task(oxenmq::OxenMQ& omq) {
|
||||
size_t num_jobs = 32;
|
||||
|
||||
lokimq::Batch<double /*return type*/> batch;
|
||||
oxenmq::Batch<double /*return type*/> batch;
|
||||
batch.reserve(num_jobs);
|
||||
|
||||
for (size_t i = 0; i < num_jobs; i++)
|
||||
|
@ -36,21 +36,21 @@ void start_big_task(lokimq::LokiMQ& lmq) {
|
|||
|
||||
batch.completion(&continue_big_task);
|
||||
|
||||
lmq.batch(std::move(batch));
|
||||
omq.batch(std::move(batch));
|
||||
}
|
||||
|
||||
|
||||
TEST_CASE("batching many small jobs", "[batch-many]") {
|
||||
lokimq::LokiMQ lmq{
|
||||
oxenmq::OxenMQ omq{
|
||||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
};
|
||||
lmq.set_general_threads(4);
|
||||
lmq.set_batch_threads(4);
|
||||
lmq.start();
|
||||
omq.set_general_threads(4);
|
||||
omq.set_batch_threads(4);
|
||||
omq.start();
|
||||
|
||||
start_big_task(lmq);
|
||||
start_big_task(omq);
|
||||
auto sum = done.get_future().get();
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( sum.first == 1337.0 );
|
||||
|
@ -58,14 +58,14 @@ TEST_CASE("batching many small jobs", "[batch-many]") {
|
|||
}
|
||||
|
||||
TEST_CASE("batch exception propagation", "[batch-exceptions]") {
|
||||
lokimq::LokiMQ lmq{
|
||||
oxenmq::OxenMQ omq{
|
||||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
};
|
||||
lmq.set_general_threads(4);
|
||||
lmq.set_batch_threads(4);
|
||||
lmq.start();
|
||||
omq.set_general_threads(4);
|
||||
omq.set_batch_threads(4);
|
||||
omq.start();
|
||||
|
||||
std::promise<void> done_promise;
|
||||
std::future<void> done_future = done_promise.get_future();
|
||||
|
@ -73,7 +73,7 @@ TEST_CASE("batch exception propagation", "[batch-exceptions]") {
|
|||
using Catch::Matchers::Message;
|
||||
|
||||
SECTION( "value return" ) {
|
||||
lokimq::Batch<int> batch;
|
||||
oxenmq::Batch<int> batch;
|
||||
for (int i : {1, 2})
|
||||
batch.add_job([i]() { if (i == 1) return 42; throw std::domain_error("bad value " + std::to_string(i)); });
|
||||
batch.completion([&done_promise](auto results) {
|
||||
|
@ -83,12 +83,12 @@ TEST_CASE("batch exception propagation", "[batch-exceptions]") {
|
|||
REQUIRE_THROWS_MATCHES( results[1].get() == 0, std::domain_error, Message("bad value 2") );
|
||||
done_promise.set_value();
|
||||
});
|
||||
lmq.batch(std::move(batch));
|
||||
omq.batch(std::move(batch));
|
||||
done_future.get();
|
||||
}
|
||||
|
||||
SECTION( "lvalue return" ) {
|
||||
lokimq::Batch<int&> batch;
|
||||
oxenmq::Batch<int&> batch;
|
||||
int forty_two = 42;
|
||||
for (int i : {1, 2})
|
||||
batch.add_job([i,&forty_two]() -> int& {
|
||||
|
@ -105,12 +105,12 @@ TEST_CASE("batch exception propagation", "[batch-exceptions]") {
|
|||
REQUIRE_THROWS_MATCHES( results[1].get(), std::domain_error, Message("bad value 2") );
|
||||
done_promise.set_value();
|
||||
});
|
||||
lmq.batch(std::move(batch));
|
||||
omq.batch(std::move(batch));
|
||||
done_future.get();
|
||||
}
|
||||
|
||||
SECTION( "void return" ) {
|
||||
lokimq::Batch<void> batch;
|
||||
oxenmq::Batch<void> batch;
|
||||
for (int i : {1, 2})
|
||||
batch.add_job([i]() { if (i != 1) throw std::domain_error("bad value " + std::to_string(i)); });
|
||||
batch.completion([&done_promise](auto results) {
|
||||
|
@ -120,7 +120,7 @@ TEST_CASE("batch exception propagation", "[batch-exceptions]") {
|
|||
REQUIRE_THROWS_MATCHES( results[1].get(), std::domain_error, Message("bad value 2") );
|
||||
done_promise.set_value();
|
||||
});
|
||||
lmq.batch(std::move(batch));
|
||||
omq.batch(std::move(batch));
|
||||
done_future.get();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
#include "common.h"
|
||||
#include <lokimq/hex.h>
|
||||
#include <oxenc/hex.h>
|
||||
#include <map>
|
||||
#include <set>
|
||||
|
||||
using namespace lokimq;
|
||||
using namespace oxenmq;
|
||||
|
||||
TEST_CASE("basic commands", "[commands]") {
|
||||
std::string listen = "tcp://127.0.0.1:4567";
|
||||
LokiMQ server{
|
||||
std::string listen = random_localhost();
|
||||
OxenMQ server{
|
||||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
|
@ -31,7 +31,7 @@ TEST_CASE("basic commands", "[commands]") {
|
|||
|
||||
server.start();
|
||||
|
||||
LokiMQ client{get_logger("C» "), LogLevel::trace};
|
||||
OxenMQ client{get_logger("C» "), LogLevel::trace};
|
||||
|
||||
client.add_category("public", Access{AuthLevel::none});
|
||||
client.add_command("public", "hi", [&](auto&) { his++; });
|
||||
|
@ -41,10 +41,9 @@ TEST_CASE("basic commands", "[commands]") {
|
|||
bool success = false, failed = false;
|
||||
std::string pubkey;
|
||||
|
||||
auto c = client.connect_remote(listen,
|
||||
auto c = client.connect_remote(address{listen, server.get_pubkey()},
|
||||
[&](auto conn) { pubkey = conn.pubkey(); success = true; got = true; },
|
||||
[&](auto conn, string_view) { failed = true; got = true; },
|
||||
server.get_pubkey());
|
||||
[&](auto conn, std::string_view) { failed = true; got = true; });
|
||||
|
||||
wait_for_conn(got);
|
||||
{
|
||||
|
@ -52,7 +51,7 @@ TEST_CASE("basic commands", "[commands]") {
|
|||
REQUIRE( got );
|
||||
REQUIRE( success );
|
||||
REQUIRE_FALSE( failed );
|
||||
REQUIRE( to_hex(pubkey) == to_hex(server.get_pubkey()) );
|
||||
REQUIRE( oxenc::to_hex(pubkey) == oxenc::to_hex(server.get_pubkey()) );
|
||||
}
|
||||
|
||||
client.send(c, "public.hello");
|
||||
|
@ -62,7 +61,7 @@ TEST_CASE("basic commands", "[commands]") {
|
|||
auto lock = catch_lock();
|
||||
REQUIRE( hellos == 1 );
|
||||
REQUIRE( his == 1 );
|
||||
REQUIRE( to_hex(client_pubkey) == to_hex(client.get_pubkey()) );
|
||||
REQUIRE( oxenc::to_hex(client_pubkey) == oxenc::to_hex(client.get_pubkey()) );
|
||||
}
|
||||
|
||||
for (int i = 0; i < 50; i++)
|
||||
|
@ -77,8 +76,8 @@ TEST_CASE("basic commands", "[commands]") {
|
|||
}
|
||||
|
||||
TEST_CASE("outgoing auth level", "[commands][auth]") {
|
||||
std::string listen = "tcp://127.0.0.1:4567";
|
||||
LokiMQ server{
|
||||
std::string listen = random_localhost();
|
||||
OxenMQ server{
|
||||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
|
@ -94,7 +93,7 @@ TEST_CASE("outgoing auth level", "[commands][auth]") {
|
|||
|
||||
server.start();
|
||||
|
||||
LokiMQ client{get_logger("C» "), LogLevel::trace};
|
||||
OxenMQ client{get_logger("C» "), LogLevel::trace};
|
||||
|
||||
std::atomic<int> public_hi{0}, basic_hi{0}, admin_hi{0};
|
||||
client.add_category("public", Access{AuthLevel::none});
|
||||
|
@ -105,11 +104,12 @@ TEST_CASE("outgoing auth level", "[commands][auth]") {
|
|||
client.add_command("admin", "hi", [&](auto&) { admin_hi++; });
|
||||
client.start();
|
||||
|
||||
client.PUBKEY_BASED_ROUTING_ID = false; // establishing multiple connections below, so we need unique routing ids
|
||||
client.EPHEMERAL_ROUTING_ID = true; // establishing multiple connections below, so we need unique routing ids
|
||||
|
||||
auto public_c = client.connect_remote(listen, [](...) {}, [](...) {}, server.get_pubkey());
|
||||
auto basic_c = client.connect_remote(listen, [](...) {}, [](...) {}, server.get_pubkey(), AuthLevel::basic);
|
||||
auto admin_c = client.connect_remote(listen, [](...) {}, [](...) {}, server.get_pubkey(), AuthLevel::admin);
|
||||
address server_addr{listen, server.get_pubkey()};
|
||||
auto public_c = client.connect_remote(server_addr, [](auto&&...) {}, [](auto&&...) {});
|
||||
auto basic_c = client.connect_remote(server_addr, [](auto&&...) {}, [](auto&&...) {}, AuthLevel::basic);
|
||||
auto admin_c = client.connect_remote(server_addr, [](auto&&...) {}, [](auto&&...) {}, AuthLevel::admin);
|
||||
|
||||
client.send(public_c, "public.reflect", "public.hi");
|
||||
wait_for([&] { return public_hi == 1; });
|
||||
|
@ -158,8 +158,8 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]")
|
|||
// Tests that the ConnectionID from a Message can be stored and reused later to contact the
|
||||
// original node.
|
||||
|
||||
std::string listen = "tcp://127.0.0.1:4567";
|
||||
LokiMQ server{
|
||||
std::string listen = random_localhost();
|
||||
OxenMQ server{
|
||||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
|
@ -173,18 +173,26 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]")
|
|||
|
||||
server.add_category("hey google", Access{AuthLevel::none});
|
||||
server.add_request_command("hey google", "remember", [&](Message& m) {
|
||||
auto l = catch_lock();
|
||||
subscribers.emplace_back(m.conn, std::string{m.data[0]});
|
||||
bool bd;
|
||||
{
|
||||
auto l = catch_lock();
|
||||
subscribers.emplace_back(m.conn, std::string{m.data[0]});
|
||||
bd = (bool) backdoor;
|
||||
}
|
||||
m.send_reply("Okay, I'll remember that.");
|
||||
|
||||
if (backdoor)
|
||||
m.lokimq.send(backdoor, "backdoor.data", m.data[0]);
|
||||
if (bd)
|
||||
m.oxenmq.send(backdoor, "backdoor.data", m.data[0]);
|
||||
});
|
||||
server.add_command("hey google", "recall", [&](Message& m) {
|
||||
auto l = catch_lock();
|
||||
for (auto& s : subscribers) {
|
||||
server.send(s.first, "personal.detail", s.second);
|
||||
decltype(subscribers) subs;
|
||||
{
|
||||
auto l = catch_lock();
|
||||
subs = subscribers;
|
||||
}
|
||||
|
||||
for (auto& s : subs)
|
||||
server.send(s.first, "personal.detail", s.second);
|
||||
});
|
||||
server.add_command("hey google", "install backdoor", [&](Message& m) {
|
||||
auto l = catch_lock();
|
||||
|
@ -193,19 +201,20 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]")
|
|||
|
||||
server.start();
|
||||
|
||||
auto connect_success = [&](...) { auto l = catch_lock(); REQUIRE(true); };
|
||||
auto connect_failure = [&](...) { auto l = catch_lock(); REQUIRE(false); };
|
||||
auto connect_success = [&](auto&&...) { auto l = catch_lock(); REQUIRE(true); };
|
||||
auto connect_failure = [&](auto&&...) { auto l = catch_lock(); REQUIRE(false); };
|
||||
|
||||
|
||||
std::set<std::string> backdoor_details;
|
||||
|
||||
LokiMQ nsa{get_logger("NSA» ")};
|
||||
OxenMQ nsa{get_logger("NSA» ")};
|
||||
nsa.add_category("backdoor", Access{AuthLevel::admin});
|
||||
nsa.add_command("backdoor", "data", [&](Message& m) {
|
||||
backdoor_details.emplace(m.data[0]);
|
||||
auto l = catch_lock();
|
||||
backdoor_details.emplace(m.data[0]);
|
||||
});
|
||||
nsa.start();
|
||||
auto nsa_c = nsa.connect_remote(listen, connect_success, connect_failure, server.get_pubkey(), AuthLevel::admin);
|
||||
auto nsa_c = nsa.connect_remote(address{listen, server.get_pubkey()}, connect_success, connect_failure, AuthLevel::admin);
|
||||
nsa.send(nsa_c, "hey google.install backdoor");
|
||||
|
||||
wait_for([&] { auto lock = catch_lock(); return (bool) backdoor; });
|
||||
|
@ -214,7 +223,7 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]")
|
|||
REQUIRE( backdoor );
|
||||
}
|
||||
|
||||
std::vector<std::unique_ptr<LokiMQ>> clients;
|
||||
std::vector<std::unique_ptr<OxenMQ>> clients;
|
||||
std::vector<ConnectionID> conns;
|
||||
std::map<int, std::set<std::string>> personal_details{
|
||||
{0, {"Loretta"s, "photos"s}},
|
||||
|
@ -226,10 +235,11 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]")
|
|||
std::set<std::string> all_the_things;
|
||||
for (auto& pd : personal_details) all_the_things.insert(pd.second.begin(), pd.second.end());
|
||||
|
||||
address server_addr{listen, server.get_pubkey()};
|
||||
std::map<int, std::set<std::string>> google_knows;
|
||||
int things_remembered{0};
|
||||
for (int i = 0; i < 5; i++) {
|
||||
clients.push_back(std::make_unique<LokiMQ>(
|
||||
clients.push_back(std::make_unique<OxenMQ>(
|
||||
get_logger("C" + std::to_string(i) + "» "), LogLevel::trace
|
||||
));
|
||||
auto& c = clients.back();
|
||||
|
@ -240,7 +250,7 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]")
|
|||
});
|
||||
c->start();
|
||||
conns.push_back(
|
||||
c->connect_remote(listen, connect_success, connect_failure, server.get_pubkey(), AuthLevel::basic));
|
||||
c->connect_remote(server_addr, connect_success, connect_failure, AuthLevel::basic));
|
||||
for (auto& personal_detail : personal_details[i])
|
||||
c->request(conns.back(), "hey google.remember",
|
||||
[&](bool success, std::vector<std::string> data) {
|
||||
|
@ -252,7 +262,7 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]")
|
|||
},
|
||||
personal_detail);
|
||||
}
|
||||
wait_for([&] { auto lock = catch_lock(); return things_remembered == all_the_things.size(); });
|
||||
wait_for([&] { auto lock = catch_lock(); return things_remembered == all_the_things.size() && things_remembered == backdoor_details.size(); });
|
||||
{
|
||||
auto l = catch_lock();
|
||||
REQUIRE( things_remembered == all_the_things.size() );
|
||||
|
@ -268,8 +278,8 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]")
|
|||
}
|
||||
|
||||
TEST_CASE("send failure callbacks", "[commands][queue_full]") {
|
||||
std::string listen = "tcp://127.0.0.1:4567";
|
||||
LokiMQ server{
|
||||
std::string listen = random_localhost();
|
||||
OxenMQ server{
|
||||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
|
@ -296,17 +306,18 @@ TEST_CASE("send failure callbacks", "[commands][queue_full]") {
|
|||
server.start();
|
||||
|
||||
// Use a raw socket here because I want to stall it by not reading from it at all, and that is
|
||||
// hard with LokiMQ.
|
||||
// hard with OxenMQ.
|
||||
zmq::context_t client_ctx;
|
||||
zmq::socket_t client{client_ctx, zmq::socket_type::dealer};
|
||||
client.connect(listen);
|
||||
// Handshake: we send HI, they reply HELLO.
|
||||
client.send(zmq::message_t{"HI", 2}, zmq::send_flags::none);
|
||||
zmq::message_t hello;
|
||||
client.recv(hello);
|
||||
string_view hello_sv{hello.data<char>(), hello.size()};
|
||||
auto recvd = client.recv(hello);
|
||||
std::string_view hello_sv{hello.data<char>(), hello.size()};
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( recvd );
|
||||
REQUIRE( hello_sv == "HELLO" );
|
||||
REQUIRE_FALSE( hello.more() );
|
||||
}
|
||||
|
@ -359,3 +370,156 @@ TEST_CASE("send failure callbacks", "[commands][queue_full]") {
|
|||
REQUIRE( send_failures.load() > 0 );
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("data parts", "[commands][send][data_parts]") {
|
||||
std::string listen = random_localhost();
|
||||
OxenMQ server{
|
||||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
get_logger("S» "),
|
||||
LogLevel::trace
|
||||
};
|
||||
server.listen_curve(listen);
|
||||
|
||||
std::mutex mut;
|
||||
std::vector<std::string> r;
|
||||
|
||||
server.add_category("public", Access{AuthLevel::none});
|
||||
server.add_command("public", "hello", [&](Message& m) {
|
||||
std::lock_guard l{mut};
|
||||
for (const auto& s : m.data)
|
||||
r.emplace_back(s);
|
||||
});
|
||||
server.start();
|
||||
|
||||
OxenMQ client{get_logger("C» "), LogLevel::trace};
|
||||
client.start();
|
||||
|
||||
std::atomic<bool> got{false};
|
||||
bool success = false, failed = false;
|
||||
std::string pubkey;
|
||||
|
||||
auto c = client.connect_remote(address{listen, server.get_pubkey()},
|
||||
[&](auto conn) { pubkey = conn.pubkey(); success = true; got = true; },
|
||||
[&](auto conn, std::string_view) { failed = true; got = true; });
|
||||
|
||||
wait_for_conn(got);
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( got );
|
||||
REQUIRE( success );
|
||||
REQUIRE_FALSE( failed );
|
||||
REQUIRE( oxenc::to_hex(pubkey) == oxenc::to_hex(server.get_pubkey()) );
|
||||
}
|
||||
|
||||
std::vector some_data{{"abc"s, "def"s, "omg123\0zzz"s}};
|
||||
client.send(c, "public.hello", oxenmq::send_option::data_parts(some_data.begin(), some_data.end()));
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
std::lock_guard l{mut};
|
||||
REQUIRE( r == some_data );
|
||||
r.clear();
|
||||
}
|
||||
|
||||
std::optional<std::string_view> opt1, opt2;
|
||||
std::optional<std::string> opt3, opt4;
|
||||
opt1 = "o1"sv;
|
||||
opt4 = "o4"s;
|
||||
std::vector some_data2{{"a"sv, "b"sv, "\0"sv}};
|
||||
client.send(c, "public.hello",
|
||||
"hi",
|
||||
oxenmq::send_option::data_parts(some_data2.begin(), some_data2.end()),
|
||||
"another",
|
||||
"string"sv,
|
||||
oxenmq::send_option::data_parts(some_data.begin(), some_data.end()),
|
||||
opt1, opt2, opt3, opt4
|
||||
);
|
||||
|
||||
std::vector<std::string> expected;
|
||||
expected.push_back("hi");
|
||||
expected.insert(expected.end(), some_data2.begin(), some_data2.end());
|
||||
expected.push_back("another");
|
||||
expected.push_back("string");
|
||||
expected.insert(expected.end(), some_data.begin(), some_data.end());
|
||||
expected.push_back("o1");
|
||||
expected.push_back("o4");
|
||||
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
std::lock_guard l{mut};
|
||||
REQUIRE( r == expected );
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("deferred replies", "[commands][send][deferred]") {
|
||||
std::string listen = random_localhost();
|
||||
OxenMQ server{
|
||||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
};
|
||||
server.listen_curve(listen);
|
||||
|
||||
std::atomic<int> hellos{0}, his{0};
|
||||
|
||||
server.add_category("public", Access{AuthLevel::none});
|
||||
server.add_request_command("public", "echo", [&](Message& m) {
|
||||
std::string msg = m.data.empty() ? ""s : std::string{m.data.front()};
|
||||
std::thread t{[send=m.send_later(), msg=std::move(msg)] {
|
||||
{ auto lock = catch_lock(); UNSCOPED_INFO("sleeping"); }
|
||||
std::this_thread::sleep_for(50ms * TIME_DILATION);
|
||||
{ auto lock = catch_lock(); UNSCOPED_INFO("sending"); }
|
||||
send(msg);
|
||||
}};
|
||||
t.detach();
|
||||
});
|
||||
server.set_general_threads(1);
|
||||
server.start();
|
||||
|
||||
OxenMQ client(
|
||||
get_logger("C» "),
|
||||
LogLevel::trace);
|
||||
//client.log_level(LogLevel::trace);
|
||||
|
||||
client.start();
|
||||
|
||||
std::atomic<bool> connected{false}, failed{false};
|
||||
std::string pubkey;
|
||||
|
||||
auto c = client.connect_remote(address{listen, server.get_pubkey()},
|
||||
[&](auto conn) { pubkey = conn.pubkey(); connected = true; },
|
||||
[&](auto, auto) { failed = true; });
|
||||
|
||||
wait_for([&] { return connected || failed; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( connected );
|
||||
REQUIRE_FALSE( failed );
|
||||
REQUIRE( oxenc::to_hex(pubkey) == oxenc::to_hex(server.get_pubkey()) );
|
||||
}
|
||||
|
||||
std::unordered_set<std::string> replies;
|
||||
std::mutex reply_mut;
|
||||
std::vector<std::string> data;
|
||||
for (auto str : {"hello", "world", "omg"})
|
||||
client.request(c, "public.echo", [&](bool ok, std::vector<std::string> data_) {
|
||||
std::lock_guard lock{reply_mut};
|
||||
replies.insert(std::move(data_[0]));
|
||||
}, str);
|
||||
|
||||
reply_sleep();
|
||||
{
|
||||
std::lock_guard lq{reply_mut};
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( replies.size() == 0 ); // The server waits 50ms before sending, so we shouldn't have any reply yet
|
||||
}
|
||||
std::this_thread::sleep_for(60ms * TIME_DILATION); // We're at least 70ms in now so the 50ms-delayed server responses should have arrived
|
||||
{
|
||||
std::lock_guard lq{reply_mut};
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( replies == std::unordered_set<std::string>{{"hello", "world", "omg"}} );
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,13 +1,12 @@
|
|||
#include "common.h"
|
||||
#include <lokimq/hex.h>
|
||||
extern "C" {
|
||||
#include <sodium.h>
|
||||
}
|
||||
|
||||
|
||||
TEST_CASE("connections with curve authentication", "[curve][connect]") {
|
||||
std::string listen = "tcp://127.0.0.1:4455";
|
||||
LokiMQ server{
|
||||
std::string listen = random_localhost();
|
||||
OxenMQ server{
|
||||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
|
@ -20,17 +19,16 @@ TEST_CASE("connections with curve authentication", "[curve][connect]") {
|
|||
server.add_request_command("public", "hello", [&](Message& m) { m.send_reply("hi"); });
|
||||
server.start();
|
||||
|
||||
LokiMQ client{get_logger("C» "), LogLevel::trace};
|
||||
OxenMQ client{get_logger("C» "), LogLevel::trace};
|
||||
|
||||
client.start();
|
||||
|
||||
auto pubkey = server.get_pubkey();
|
||||
std::atomic<bool> got{false};
|
||||
bool success = false;
|
||||
auto server_conn = client.connect_remote(listen,
|
||||
auto server_conn = client.connect_remote(address{listen, pubkey},
|
||||
[&](auto conn) { success = true; got = true; },
|
||||
[&](auto conn, string_view reason) { auto lock = catch_lock(); INFO("connection failed: " << reason); got = true; },
|
||||
pubkey);
|
||||
[&](auto conn, std::string_view reason) { auto lock = catch_lock(); INFO("connection failed: " << reason); got = true; });
|
||||
|
||||
wait_for_conn(got);
|
||||
{
|
||||
|
@ -53,16 +51,18 @@ TEST_CASE("self-connection SN optimization", "[connect][self]") {
|
|||
std::string pubkey, privkey;
|
||||
pubkey.resize(crypto_box_PUBLICKEYBYTES);
|
||||
privkey.resize(crypto_box_SECRETKEYBYTES);
|
||||
REQUIRE(sodium_init() != -1);
|
||||
auto listen_addr = random_localhost();
|
||||
crypto_box_keypair(reinterpret_cast<unsigned char*>(&pubkey[0]), reinterpret_cast<unsigned char*>(&privkey[0]));
|
||||
LokiMQ sn{
|
||||
OxenMQ sn{
|
||||
pubkey, privkey,
|
||||
true,
|
||||
[&](auto pk) { if (pk == pubkey) return "tcp://127.0.0.1:5544"; else return ""; },
|
||||
[&](auto pk) { if (pk == pubkey) return listen_addr; else return ""s; },
|
||||
get_logger("S» "),
|
||||
LogLevel::trace
|
||||
};
|
||||
|
||||
sn.listen_curve("tcp://127.0.0.1:5544", [&](auto ip, auto pk, auto sn) {
|
||||
sn.listen_curve(listen_addr, [&](auto ip, auto pk, auto sn) {
|
||||
auto lock = catch_lock();
|
||||
REQUIRE(ip == "127.0.0.1");
|
||||
REQUIRE(sn == (pk == pubkey));
|
||||
|
@ -90,8 +90,8 @@ TEST_CASE("self-connection SN optimization", "[connect][self]") {
|
|||
}
|
||||
|
||||
TEST_CASE("plain-text connections", "[plaintext][connect]") {
|
||||
std::string listen = "tcp://127.0.0.1:4455";
|
||||
LokiMQ server{get_logger("S» "), LogLevel::trace};
|
||||
std::string listen = random_localhost();
|
||||
OxenMQ server{get_logger("S» "), LogLevel::trace};
|
||||
|
||||
server.add_category("public", Access{AuthLevel::none});
|
||||
server.add_request_command("public", "hello", [&](Message& m) { m.send_reply("hi"); });
|
||||
|
@ -100,15 +100,15 @@ TEST_CASE("plain-text connections", "[plaintext][connect]") {
|
|||
|
||||
server.start();
|
||||
|
||||
LokiMQ client{get_logger("C» "), LogLevel::trace};
|
||||
OxenMQ client{get_logger("C» "), LogLevel::trace};
|
||||
|
||||
client.start();
|
||||
|
||||
std::atomic<bool> got{false};
|
||||
bool success = false;
|
||||
auto c = client.connect_remote(listen,
|
||||
auto c = client.connect_remote(address{listen},
|
||||
[&](auto conn) { success = true; got = true; },
|
||||
[&](auto conn, string_view reason) { auto lock = catch_lock(); INFO("connection failed: " << reason); got = true; }
|
||||
[&](auto conn, std::string_view reason) { auto lock = catch_lock(); INFO("connection failed: " << reason); got = true; }
|
||||
);
|
||||
|
||||
wait_for_conn(got);
|
||||
|
@ -128,27 +128,142 @@ TEST_CASE("plain-text connections", "[plaintext][connect]") {
|
|||
}
|
||||
}
|
||||
|
||||
TEST_CASE("post-start listening", "[connect][listen]") {
|
||||
OxenMQ server{get_logger("S» "), LogLevel::trace};
|
||||
server.add_category("x", AuthLevel::none)
|
||||
.add_request_command("y", [&](Message& m) { m.send_reply("hi", m.data[0]); });
|
||||
server.start();
|
||||
std::atomic<int> listens = 0;
|
||||
auto listen_curve = random_localhost();
|
||||
server.listen_curve(listen_curve, nullptr, [&](bool success) { if (success) listens++; });
|
||||
auto listen_plain = random_localhost();
|
||||
server.listen_plain(listen_plain, nullptr, [&](bool success) { if (success) listens += 10; });
|
||||
|
||||
wait_for([&] { return listens.load() >= 11; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( listens == 11 );
|
||||
}
|
||||
|
||||
// This should fail since we're already listening on it:
|
||||
server.listen_curve(listen_plain, nullptr, [&](bool success) { if (!success) listens++; });
|
||||
|
||||
wait_for([&] { return listens.load() >= 12; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( listens == 12 );
|
||||
}
|
||||
|
||||
|
||||
OxenMQ client{get_logger("C1» "), LogLevel::trace};
|
||||
client.start();
|
||||
std::atomic<int> conns = 0;
|
||||
auto c1 = client.connect_remote(address{listen_curve, server.get_pubkey()},
|
||||
[&](auto) { conns++; },
|
||||
[&](auto, auto why) { auto lock = catch_lock(); UNSCOPED_INFO("connection failed: " << why); });
|
||||
auto c2 = client.connect_remote(address{listen_plain},
|
||||
[&](auto) { conns += 10; },
|
||||
[&](auto, auto why) { auto lock = catch_lock(); UNSCOPED_INFO("connection failed: " << why); });
|
||||
|
||||
|
||||
wait_for([&] { return conns.load() >= 11; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( conns == 11 );
|
||||
}
|
||||
|
||||
std::atomic<int> replies = 0;
|
||||
std::string reply1, reply2;
|
||||
client.request(c1, "x.y", [&](auto success, auto parts) { replies++; for (auto& p : parts) reply1 += p; }, " world");
|
||||
client.request(c2, "x.y", [&](auto success, auto parts) { replies += 10; for (auto& p : parts) reply2 += p; }, " cat");
|
||||
|
||||
wait_for([&] { return replies.load() >= 11; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( replies == 11 );
|
||||
REQUIRE( reply1 == "hi world" );
|
||||
REQUIRE( reply2 == "hi cat" );
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("unique connection IDs", "[connect][id]") {
|
||||
std::string listen = random_localhost();
|
||||
OxenMQ server{get_logger("S» "), LogLevel::trace};
|
||||
|
||||
ConnectionID first, second;
|
||||
server.add_category("x", Access{AuthLevel::none})
|
||||
.add_request_command("x", [&](Message& m) { first = m.conn; m.send_reply("hi"); })
|
||||
.add_request_command("y", [&](Message& m) { second = m.conn; m.send_reply("hi"); })
|
||||
;
|
||||
|
||||
server.listen_plain(listen);
|
||||
|
||||
server.start();
|
||||
|
||||
OxenMQ client1{get_logger("C1» "), LogLevel::trace};
|
||||
OxenMQ client2{get_logger("C2» "), LogLevel::trace};
|
||||
client1.start();
|
||||
client2.start();
|
||||
|
||||
std::atomic<bool> good1{false}, good2{false};
|
||||
auto r1 = client1.connect_remote(address{listen},
|
||||
[&](auto conn) { good1 = true; },
|
||||
[&](auto conn, std::string_view reason) { auto lock = catch_lock(); INFO("connection failed: " << reason); }
|
||||
);
|
||||
auto r2 = client2.connect_remote(address{listen},
|
||||
[&](auto conn) { good2 = true; },
|
||||
[&](auto conn, std::string_view reason) { auto lock = catch_lock(); INFO("connection failed: " << reason); }
|
||||
);
|
||||
|
||||
wait_for_conn(good1);
|
||||
wait_for_conn(good2);
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( good1 );
|
||||
REQUIRE( good2 );
|
||||
REQUIRE( first == second );
|
||||
REQUIRE_FALSE( first );
|
||||
REQUIRE_FALSE( second );
|
||||
}
|
||||
|
||||
good1 = false;
|
||||
good2 = false;
|
||||
client1.request(r1, "x.x", [&](auto success_, auto parts_) { good1 = true; });
|
||||
client2.request(r2, "x.y", [&](auto success_, auto parts_) { good2 = true; });
|
||||
reply_sleep();
|
||||
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( good1 );
|
||||
REQUIRE( good2 );
|
||||
REQUIRE_FALSE( first == second );
|
||||
REQUIRE_FALSE( std::hash<ConnectionID>{}(first) == std::hash<ConnectionID>{}(second) );
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
TEST_CASE("SN disconnections", "[connect][disconnect]") {
|
||||
std::vector<std::unique_ptr<LokiMQ>> lmq;
|
||||
std::vector<std::unique_ptr<OxenMQ>> omq;
|
||||
std::vector<std::string> pubkey, privkey;
|
||||
std::unordered_map<std::string, std::string> conn;
|
||||
REQUIRE(sodium_init() != -1);
|
||||
for (int i = 0; i < 3; i++) {
|
||||
pubkey.emplace_back();
|
||||
privkey.emplace_back();
|
||||
pubkey[i].resize(crypto_box_PUBLICKEYBYTES);
|
||||
privkey[i].resize(crypto_box_SECRETKEYBYTES);
|
||||
crypto_box_keypair(reinterpret_cast<unsigned char*>(&pubkey[i][0]), reinterpret_cast<unsigned char*>(&privkey[i][0]));
|
||||
conn.emplace(pubkey[i], "tcp://127.0.0.1:" + std::to_string(4450 + i));
|
||||
conn.emplace(pubkey[i], random_localhost());
|
||||
}
|
||||
std::atomic<int> his{0};
|
||||
for (int i = 0; i < pubkey.size(); i++) {
|
||||
lmq.push_back(std::make_unique<LokiMQ>(
|
||||
omq.push_back(std::make_unique<OxenMQ>(
|
||||
pubkey[i], privkey[i], true,
|
||||
[conn](auto pk) { auto it = conn.find((std::string) pk); if (it != conn.end()) return it->second; return ""s; },
|
||||
get_logger("S" + std::to_string(i) + "» "),
|
||||
LogLevel::trace
|
||||
));
|
||||
auto& server = *lmq.back();
|
||||
auto& server = *omq.back();
|
||||
|
||||
server.listen_curve(conn[pubkey[i]]);
|
||||
server.add_category("sn", Access{AuthLevel::none, true})
|
||||
|
@ -157,13 +272,13 @@ TEST_CASE("SN disconnections", "[connect][disconnect]") {
|
|||
server.start();
|
||||
}
|
||||
|
||||
lmq[0]->send(pubkey[1], "sn.hi");
|
||||
lmq[0]->send(pubkey[2], "sn.hi");
|
||||
lmq[2]->send(pubkey[0], "sn.hi");
|
||||
lmq[2]->send(pubkey[1], "sn.hi");
|
||||
lmq[1]->send(pubkey[0], "BYE");
|
||||
lmq[0]->send(pubkey[2], "sn.hi");
|
||||
std::this_thread::sleep_for(50ms);
|
||||
omq[0]->send(pubkey[1], "sn.hi");
|
||||
omq[0]->send(pubkey[2], "sn.hi");
|
||||
omq[2]->send(pubkey[0], "sn.hi");
|
||||
omq[2]->send(pubkey[1], "sn.hi");
|
||||
omq[1]->send(pubkey[0], "BYE");
|
||||
omq[0]->send(pubkey[2], "sn.hi");
|
||||
std::this_thread::sleep_for(50ms * TIME_DILATION);
|
||||
|
||||
auto lock = catch_lock();
|
||||
REQUIRE(his == 5);
|
||||
|
@ -174,12 +289,13 @@ TEST_CASE("SN auth checks", "[sandwich][auth]") {
|
|||
// isn't recognized as a SN but tries to invoke a SN command it'll be told to disconnect; if it
|
||||
// tries to send again it should reconnect and reauthenticate. This test is meant to test this
|
||||
// pattern where the reconnection/reauthentication now authenticates it as a SN.
|
||||
std::string listen = "tcp://127.0.0.1:4455";
|
||||
std::string listen = random_localhost();
|
||||
std::string pubkey, privkey;
|
||||
pubkey.resize(crypto_box_PUBLICKEYBYTES);
|
||||
privkey.resize(crypto_box_SECRETKEYBYTES);
|
||||
REQUIRE(sodium_init() != -1);
|
||||
crypto_box_keypair(reinterpret_cast<unsigned char*>(&pubkey[0]), reinterpret_cast<unsigned char*>(&privkey[0]));
|
||||
LokiMQ server{
|
||||
OxenMQ server{
|
||||
pubkey, privkey,
|
||||
true, // service node
|
||||
[](auto) { return ""; },
|
||||
|
@ -206,7 +322,7 @@ TEST_CASE("SN auth checks", "[sandwich][auth]") {
|
|||
.add_request_command("make", [&](Message& m) { m.send_reply("okay"); });
|
||||
server.start();
|
||||
|
||||
LokiMQ client{
|
||||
OxenMQ client{
|
||||
"", "", false,
|
||||
[&](auto remote_pk) { if (remote_pk == pubkey) return listen; return ""s; },
|
||||
get_logger("B» "), LogLevel::trace};
|
||||
|
@ -288,3 +404,206 @@ TEST_CASE("SN auth checks", "[sandwich][auth]") {
|
|||
REQUIRE( data == dvec{{"FORBIDDEN_SN"}} );
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("SN single worker test", "[connect][worker]") {
|
||||
// Tests a failure case that could trigger when all workers are allocated (here we make that
|
||||
// simpler by just having one worker).
|
||||
std::string listen = random_localhost();
|
||||
OxenMQ server{
|
||||
"", "",
|
||||
false, // service node
|
||||
[](auto) { return ""; },
|
||||
get_logger("S» "),
|
||||
LogLevel::trace
|
||||
};
|
||||
server.set_general_threads(1);
|
||||
server.set_batch_threads(0);
|
||||
server.set_reply_threads(0);
|
||||
server.listen_plain(listen);
|
||||
server.add_category("c", Access{AuthLevel::none})
|
||||
.add_request_command("x", [&](Message& m) { m.send_reply(); })
|
||||
;
|
||||
server.start();
|
||||
|
||||
OxenMQ client{get_logger("B» "), LogLevel::trace};
|
||||
client.start();
|
||||
auto conn = client.connect_remote(address{listen}, [](auto) {}, [](auto, auto) {});
|
||||
|
||||
std::atomic<int> got{0};
|
||||
std::atomic<int> success{0};
|
||||
client.request(conn, "c.x", [&](auto success_, auto) { if (success_) ++success; ++got; });
|
||||
wait_for([&] { return got.load() >= 1; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( success == 1 );
|
||||
}
|
||||
client.request(conn, "c.x", [&](auto success_, auto) { if (success_) ++success; ++got; });
|
||||
wait_for([&] { return got.load() >= 2; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( success == 2 );
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
TEST_CASE("SN backchatter", "[connect][sn]") {
|
||||
// When we have a SN connection A -> B and then B sends a message to A on that existing
|
||||
// connection, A should see it as coming from B.
|
||||
std::vector<std::unique_ptr<OxenMQ>> omq;
|
||||
std::vector<std::string> pubkey, privkey;
|
||||
std::unordered_map<std::string, std::string> conn;
|
||||
REQUIRE(sodium_init() != -1);
|
||||
for (int i = 0; i < 2; i++) {
|
||||
pubkey.emplace_back();
|
||||
privkey.emplace_back();
|
||||
pubkey[i].resize(crypto_box_PUBLICKEYBYTES);
|
||||
privkey[i].resize(crypto_box_SECRETKEYBYTES);
|
||||
crypto_box_keypair(reinterpret_cast<unsigned char*>(&pubkey[i][0]), reinterpret_cast<unsigned char*>(&privkey[i][0]));
|
||||
conn.emplace(pubkey[i], random_localhost());
|
||||
}
|
||||
|
||||
for (int i = 0; i < pubkey.size(); i++) {
|
||||
omq.push_back(std::make_unique<OxenMQ>(
|
||||
pubkey[i], privkey[i], true,
|
||||
[conn](auto pk) { auto it = conn.find((std::string) pk); if (it != conn.end()) return it->second; return ""s; },
|
||||
get_logger("S" + std::to_string(i) + "» "),
|
||||
LogLevel::trace
|
||||
));
|
||||
auto& server = *omq.back();
|
||||
|
||||
server.listen_curve(conn[pubkey[i]]);
|
||||
server.set_active_sns({pubkey.begin(), pubkey.end()});
|
||||
}
|
||||
std::string f;
|
||||
omq[0]->add_category("a", Access{AuthLevel::none, true})
|
||||
.add_command("a", [&](Message& m) {
|
||||
m.oxenmq.send(m.conn, "b.b", "abc");
|
||||
//m.send_back("b.b", "abc");
|
||||
})
|
||||
.add_command("z", [&](Message& m) {
|
||||
auto lock = catch_lock();
|
||||
f = m.data[0];
|
||||
});
|
||||
omq[1]->add_category("b", Access{AuthLevel::none, true})
|
||||
.add_command("b", [&](Message& m) {
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
UNSCOPED_INFO("b.b from conn " << m.conn);
|
||||
}
|
||||
m.send_back("a.z", m.data[0]);
|
||||
});
|
||||
|
||||
for (auto& server : omq)
|
||||
server->start();
|
||||
|
||||
auto c = omq[1]->connect_sn(pubkey[0]);
|
||||
omq[1]->send(c, "a.a");
|
||||
std::this_thread::sleep_for(50ms);
|
||||
|
||||
auto lock = catch_lock();
|
||||
REQUIRE(f == "abc");
|
||||
}
|
||||
|
||||
TEST_CASE("inproc connections", "[connect][inproc]") {
|
||||
std::string inproc_name = "foo";
|
||||
OxenMQ omq{get_logger("OMQ» "), LogLevel::trace};
|
||||
|
||||
omq.add_category("public", Access{AuthLevel::none});
|
||||
omq.add_request_command("public", "hello", [&](Message& m) { m.send_reply("hi"); });
|
||||
|
||||
omq.start();
|
||||
|
||||
std::atomic<int> got{0};
|
||||
bool success = false;
|
||||
auto c_inproc = omq.connect_inproc(
|
||||
[&](auto conn) { success = true; got++; },
|
||||
[&](auto conn, std::string_view reason) { auto lock = catch_lock(); INFO("inproc connection failed: " << reason); got++; }
|
||||
);
|
||||
|
||||
wait_for([&got] { return got.load() > 0; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( success );
|
||||
REQUIRE( got == 1 );
|
||||
}
|
||||
|
||||
got = 0;
|
||||
success = false;
|
||||
omq.request(c_inproc, "public.hello", [&](auto success_, auto parts_) {
|
||||
success = success_ && parts_.size() == 1 && parts_.front() == "hi"; got++;
|
||||
});
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( got == 1 );
|
||||
REQUIRE( success );
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("no explicit inproc listening", "[connect][inproc]") {
|
||||
OxenMQ omq{get_logger("OMQ» "), LogLevel::trace};
|
||||
REQUIRE_THROWS_AS(omq.listen_plain("inproc://foo"), std::logic_error);
|
||||
REQUIRE_THROWS_AS(omq.listen_curve("inproc://foo"), std::logic_error);
|
||||
}
|
||||
|
||||
TEST_CASE("inproc connection permissions", "[connect][inproc]") {
|
||||
std::string listen = random_localhost();
|
||||
OxenMQ omq{get_logger("OMQ» "), LogLevel::trace};
|
||||
|
||||
omq.add_category("public", Access{AuthLevel::none});
|
||||
omq.add_request_command("public", "hello", [&](Message& m) { m.send_reply("hi"); });
|
||||
omq.add_category("private", Access{AuthLevel::admin});
|
||||
omq.add_request_command("private", "handshake", [&](Message& m) { m.send_reply("yo dude"); });
|
||||
|
||||
omq.listen_plain(listen);
|
||||
|
||||
omq.start();
|
||||
|
||||
std::atomic<int> got{0};
|
||||
bool success = false;
|
||||
auto c_inproc = omq.connect_inproc(
|
||||
[&](auto conn) { success = true; got++; },
|
||||
[&](auto conn, std::string_view reason) { auto lock = catch_lock(); INFO("inproc connection failed: " << reason); got++; }
|
||||
);
|
||||
|
||||
bool pub_success = false;
|
||||
auto c_pub = omq.connect_remote(address{listen},
|
||||
[&](auto conn) { pub_success = true; got++; },
|
||||
[&](auto conn, std::string_view reason) { auto lock = catch_lock(); INFO("tcp connection failed: " << reason); got++; }
|
||||
);
|
||||
|
||||
wait_for([&got] { return got.load() == 2; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( got == 2 );
|
||||
REQUIRE( success );
|
||||
REQUIRE( pub_success );
|
||||
}
|
||||
|
||||
got = 0;
|
||||
success = false;
|
||||
pub_success = false;
|
||||
bool success_private = false;
|
||||
bool pub_success_private = false;
|
||||
omq.request(c_inproc, "public.hello", [&](auto success_, auto parts_) {
|
||||
success = success_ && parts_.size() == 1 && parts_.front() == "hi"; got++;
|
||||
});
|
||||
omq.request(c_pub, "public.hello", [&](auto success_, auto parts_) {
|
||||
pub_success = success_ && parts_.size() == 1 && parts_.front() == "hi"; got++;
|
||||
});
|
||||
omq.request(c_inproc, "private.handshake", [&](auto success_, auto parts_) {
|
||||
success_private = success_ && parts_.size() == 1 && parts_.front() == "yo dude"; got++;
|
||||
});
|
||||
omq.request(c_pub, "private.handshake", [&](auto success_, auto parts_) {
|
||||
pub_success_private = success_; got++;
|
||||
});
|
||||
wait_for([&got] { return got.load() == 4; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( got == 4 );
|
||||
REQUIRE( success );
|
||||
REQUIRE( pub_success );
|
||||
REQUIRE( success_private );
|
||||
REQUIRE_FALSE( pub_success_private );
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,13 +1,12 @@
|
|||
#include "common.h"
|
||||
#include <lokimq/hex.h>
|
||||
#include <map>
|
||||
#include <set>
|
||||
|
||||
using namespace lokimq;
|
||||
using namespace oxenmq;
|
||||
|
||||
TEST_CASE("failure responses - UNKNOWNCOMMAND", "[failure][UNKNOWNCOMMAND]") {
|
||||
std::string listen = "tcp://127.0.0.1:4567";
|
||||
LokiMQ server{
|
||||
std::string listen = random_localhost();
|
||||
OxenMQ server{
|
||||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
|
@ -25,28 +24,30 @@ TEST_CASE("failure responses - UNKNOWNCOMMAND", "[failure][UNKNOWNCOMMAND]") {
|
|||
client.send(zmq::message_t{"HI", 2}, zmq::send_flags::none);
|
||||
{
|
||||
zmq::message_t hello;
|
||||
client.recv(hello);
|
||||
auto recvd = client.recv(hello);
|
||||
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( recvd );
|
||||
REQUIRE( hello.to_string() == "HELLO" );
|
||||
REQUIRE_FALSE( hello.more() );
|
||||
}
|
||||
|
||||
client.send(zmq::message_t{"a.a", 3}, zmq::send_flags::none);
|
||||
zmq::message_t resp;
|
||||
client.recv(resp);
|
||||
auto recvd = client.recv(resp);
|
||||
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( recvd );
|
||||
REQUIRE( resp.to_string() == "UNKNOWNCOMMAND" );
|
||||
REQUIRE( resp.more() );
|
||||
client.recv(resp);
|
||||
REQUIRE( client.recv(resp) );
|
||||
REQUIRE( resp.to_string() == "a.a" );
|
||||
REQUIRE_FALSE( resp.more() );
|
||||
}
|
||||
|
||||
TEST_CASE("failure responses - NO_REPLY_TAG", "[failure][NO_REPLY_TAG]") {
|
||||
std::string listen = "tcp://127.0.0.1:4567";
|
||||
LokiMQ server{
|
||||
std::string listen = random_localhost();
|
||||
OxenMQ server{
|
||||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
|
@ -66,45 +67,48 @@ TEST_CASE("failure responses - NO_REPLY_TAG", "[failure][NO_REPLY_TAG]") {
|
|||
client.send(zmq::message_t{"HI", 2}, zmq::send_flags::none);
|
||||
{
|
||||
zmq::message_t hello;
|
||||
client.recv(hello);
|
||||
auto recvd = client.recv(hello);
|
||||
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( recvd );
|
||||
REQUIRE( hello.to_string() == "HELLO" );
|
||||
REQUIRE_FALSE( hello.more() );
|
||||
}
|
||||
|
||||
client.send(zmq::message_t{"x.r", 3}, zmq::send_flags::none);
|
||||
zmq::message_t resp;
|
||||
client.recv(resp);
|
||||
auto recvd = client.recv(resp);
|
||||
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( recvd );
|
||||
REQUIRE( resp.to_string() == "NO_REPLY_TAG" );
|
||||
REQUIRE( resp.more() );
|
||||
client.recv(resp);
|
||||
REQUIRE( client.recv(resp) );
|
||||
REQUIRE( resp.to_string() == "x.r" );
|
||||
REQUIRE_FALSE( resp.more() );
|
||||
}
|
||||
|
||||
client.send(zmq::message_t{"x.r", 3}, zmq::send_flags::sndmore);
|
||||
client.send(zmq::message_t{"foo", 3}, zmq::send_flags::none);
|
||||
client.recv(resp);
|
||||
recvd = client.recv(resp);
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( recvd );
|
||||
REQUIRE( resp.to_string() == "REPLY" );
|
||||
REQUIRE( resp.more() );
|
||||
client.recv(resp);
|
||||
REQUIRE( client.recv(resp) );
|
||||
REQUIRE( resp.to_string() == "foo" );
|
||||
REQUIRE( resp.more() );
|
||||
client.recv(resp);
|
||||
REQUIRE( client.recv(resp) );
|
||||
REQUIRE( resp.to_string() == "a" );
|
||||
REQUIRE_FALSE( resp.more() );
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("failure responses - FORBIDDEN", "[failure][FORBIDDEN]") {
|
||||
std::string listen = "tcp://127.0.0.1:4567";
|
||||
LokiMQ server{
|
||||
std::string listen = random_localhost();
|
||||
OxenMQ server{
|
||||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
|
@ -132,9 +136,10 @@ TEST_CASE("failure responses - FORBIDDEN", "[failure][FORBIDDEN]") {
|
|||
client.send(zmq::message_t{"HI", 2}, zmq::send_flags::none);
|
||||
{
|
||||
zmq::message_t hello;
|
||||
client.recv(hello);
|
||||
auto recvd = client.recv(hello);
|
||||
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( recvd );
|
||||
REQUIRE( hello.to_string() == "HELLO" );
|
||||
REQUIRE_FALSE( hello.more() );
|
||||
}
|
||||
|
@ -144,18 +149,20 @@ TEST_CASE("failure responses - FORBIDDEN", "[failure][FORBIDDEN]") {
|
|||
c.send(zmq::message_t{"x.x", 3}, zmq::send_flags::none);
|
||||
|
||||
zmq::message_t resp;
|
||||
clients[0].recv(resp);
|
||||
auto recvd = clients[0].recv(resp);
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( recvd );
|
||||
REQUIRE( resp.to_string() == "FORBIDDEN" );
|
||||
REQUIRE( resp.more() );
|
||||
clients[0].recv(resp);
|
||||
REQUIRE( clients[0].recv(resp) );
|
||||
REQUIRE( resp.to_string() == "x.x" );
|
||||
REQUIRE_FALSE( resp.more() );
|
||||
}
|
||||
for (int i : {1, 2}) {
|
||||
clients[i].recv(resp);
|
||||
recvd = clients[i].recv(resp);
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( recvd );
|
||||
REQUIRE( resp.to_string() == "a" );
|
||||
REQUIRE_FALSE( resp.more() );
|
||||
}
|
||||
|
@ -164,25 +171,27 @@ TEST_CASE("failure responses - FORBIDDEN", "[failure][FORBIDDEN]") {
|
|||
c.send(zmq::message_t{"y.x", 3}, zmq::send_flags::none);
|
||||
|
||||
for (int i : {0, 1}) {
|
||||
clients[i].recv(resp);
|
||||
recvd = clients[i].recv(resp);
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( recvd );
|
||||
REQUIRE( resp.to_string() == "FORBIDDEN" );
|
||||
REQUIRE( resp.more() );
|
||||
clients[i].recv(resp);
|
||||
REQUIRE( clients[i].recv(resp) );
|
||||
REQUIRE( resp.to_string() == "y.x" );
|
||||
REQUIRE_FALSE( resp.more() );
|
||||
}
|
||||
clients[2].recv(resp);
|
||||
recvd = clients[2].recv(resp);
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( recvd );
|
||||
REQUIRE( resp.to_string() == "b" );
|
||||
REQUIRE_FALSE( resp.more() );
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("failure responses - NOT_A_SERVICE_NODE", "[failure][NOT_A_SERVICE_NODE]") {
|
||||
std::string listen = "tcp://127.0.0.1:4567";
|
||||
LokiMQ server{
|
||||
std::string listen = random_localhost();
|
||||
OxenMQ server{
|
||||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
|
@ -207,9 +216,10 @@ TEST_CASE("failure responses - NOT_A_SERVICE_NODE", "[failure][NOT_A_SERVICE_NOD
|
|||
client.send(zmq::message_t{"HI", 2}, zmq::send_flags::none);
|
||||
{
|
||||
zmq::message_t hello;
|
||||
client.recv(hello);
|
||||
auto recvd = client.recv(hello);
|
||||
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( recvd );
|
||||
REQUIRE( hello.to_string() == "HELLO" );
|
||||
REQUIRE_FALSE( hello.more() );
|
||||
}
|
||||
|
@ -217,12 +227,13 @@ TEST_CASE("failure responses - NOT_A_SERVICE_NODE", "[failure][NOT_A_SERVICE_NOD
|
|||
client.send(zmq::message_t{"x.x", 3}, zmq::send_flags::none);
|
||||
|
||||
zmq::message_t resp;
|
||||
client.recv(resp);
|
||||
auto recvd = client.recv(resp);
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( recvd );
|
||||
REQUIRE( resp.to_string() == "NOT_A_SERVICE_NODE" );
|
||||
REQUIRE( resp.more() );
|
||||
client.recv(resp);
|
||||
REQUIRE( client.recv(resp) );
|
||||
REQUIRE( resp.to_string() == "x.x" );
|
||||
REQUIRE_FALSE( resp.more() );
|
||||
}
|
||||
|
@ -230,23 +241,24 @@ TEST_CASE("failure responses - NOT_A_SERVICE_NODE", "[failure][NOT_A_SERVICE_NOD
|
|||
client.send(zmq::message_t{"x.r", 3}, zmq::send_flags::sndmore);
|
||||
client.send(zmq::message_t{"xyz123", 6}, zmq::send_flags::none); // reply tag
|
||||
|
||||
client.recv(resp);
|
||||
recvd = client.recv(resp);
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( recvd );
|
||||
REQUIRE( resp.to_string() == "NOT_A_SERVICE_NODE" );
|
||||
REQUIRE( resp.more() );
|
||||
client.recv(resp);
|
||||
REQUIRE( client.recv(resp) );
|
||||
REQUIRE( resp.to_string() == "REPLY" );
|
||||
REQUIRE( resp.more() );
|
||||
client.recv(resp);
|
||||
REQUIRE( client.recv(resp) );
|
||||
REQUIRE( resp.to_string() == "xyz123" );
|
||||
REQUIRE_FALSE( resp.more() );
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("failure responses - FORBIDDEN_SN", "[failure][FORBIDDEN_SN]") {
|
||||
std::string listen = "tcp://127.0.0.1:4567";
|
||||
LokiMQ server{
|
||||
std::string listen = random_localhost();
|
||||
OxenMQ server{
|
||||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
|
@ -271,9 +283,10 @@ TEST_CASE("failure responses - FORBIDDEN_SN", "[failure][FORBIDDEN_SN]") {
|
|||
client.send(zmq::message_t{"HI", 2}, zmq::send_flags::none);
|
||||
{
|
||||
zmq::message_t hello;
|
||||
client.recv(hello);
|
||||
auto recvd = client.recv(hello);
|
||||
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( recvd );
|
||||
REQUIRE( hello.to_string() == "HELLO" );
|
||||
REQUIRE_FALSE( hello.more() );
|
||||
}
|
||||
|
@ -281,12 +294,13 @@ TEST_CASE("failure responses - FORBIDDEN_SN", "[failure][FORBIDDEN_SN]") {
|
|||
client.send(zmq::message_t{"x.x", 3}, zmq::send_flags::none);
|
||||
|
||||
zmq::message_t resp;
|
||||
client.recv(resp);
|
||||
auto recvd = client.recv(resp);
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( recvd );
|
||||
REQUIRE( resp.to_string() == "FORBIDDEN_SN" );
|
||||
REQUIRE( resp.more() );
|
||||
client.recv(resp);
|
||||
REQUIRE( client.recv(resp) );
|
||||
REQUIRE( resp.to_string() == "x.x" );
|
||||
REQUIRE_FALSE( resp.more() );
|
||||
}
|
||||
|
@ -294,15 +308,16 @@ TEST_CASE("failure responses - FORBIDDEN_SN", "[failure][FORBIDDEN_SN]") {
|
|||
client.send(zmq::message_t{"x.r", 3}, zmq::send_flags::sndmore);
|
||||
client.send(zmq::message_t{"xyz123", 6}, zmq::send_flags::none); // reply tag
|
||||
|
||||
client.recv(resp);
|
||||
recvd = client.recv(resp);
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( recvd );
|
||||
REQUIRE( resp.to_string() == "FORBIDDEN_SN" );
|
||||
REQUIRE( resp.more() );
|
||||
client.recv(resp);
|
||||
REQUIRE( client.recv(resp) );
|
||||
REQUIRE( resp.to_string() == "REPLY" );
|
||||
REQUIRE( resp.more() );
|
||||
client.recv(resp);
|
||||
REQUIRE( client.recv(resp) );
|
||||
REQUIRE( resp.to_string() == "xyz123" );
|
||||
REQUIRE_FALSE( resp.more() );
|
||||
}
|
||||
|
|
|
@ -0,0 +1,96 @@
|
|||
#include "common.h"
|
||||
|
||||
using namespace oxenmq;
|
||||
|
||||
TEST_CASE("injected external commands", "[injected]") {
|
||||
std::string listen = random_localhost();
|
||||
OxenMQ server{
|
||||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
get_logger("S» "),
|
||||
LogLevel::trace
|
||||
};
|
||||
server.set_general_threads(1);
|
||||
server.listen_curve(listen);
|
||||
|
||||
std::atomic<int> hellos = 0;
|
||||
std::atomic<bool> done = false;
|
||||
server.add_category("public", AuthLevel::none, 3);
|
||||
server.add_command("public", "hello", [&](Message& m) {
|
||||
hellos++;
|
||||
while (!done) std::this_thread::sleep_for(10ms);
|
||||
});
|
||||
|
||||
server.start();
|
||||
|
||||
OxenMQ client{get_logger("C» "), LogLevel::trace};
|
||||
client.start();
|
||||
|
||||
std::atomic<bool> got{false};
|
||||
bool success = false;
|
||||
|
||||
// Deliberately using a deprecated command here, disable -Wdeprecated-declarations
|
||||
#ifdef __GNUG__
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
|
||||
#endif
|
||||
auto c = client.connect_remote(listen,
|
||||
[&](auto conn) { success = true; got = true; },
|
||||
[&](auto conn, std::string_view) { got = true; },
|
||||
server.get_pubkey());
|
||||
|
||||
#ifdef __GNUG__
|
||||
#pragma GCC diagnostic pop
|
||||
#endif
|
||||
|
||||
wait_for_conn(got);
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( got );
|
||||
REQUIRE( success );
|
||||
}
|
||||
|
||||
// First make sure that basic message respects the 3 thread limit
|
||||
client.send(c, "public.hello");
|
||||
client.send(c, "public.hello");
|
||||
client.send(c, "public.hello");
|
||||
client.send(c, "public.hello");
|
||||
wait_for([&] { return hellos >= 3; });
|
||||
std::this_thread::sleep_for(20ms);
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( hellos == 3 );
|
||||
}
|
||||
done = true;
|
||||
wait_for([&] { return hellos >= 4; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( hellos == 4 );
|
||||
}
|
||||
|
||||
// Now try injecting external commands
|
||||
done = false;
|
||||
hellos = 0;
|
||||
client.send(c, "public.hello");
|
||||
wait_for([&] { return hellos >= 1; });
|
||||
server.inject_task("public", "(injected)", "localhost", [&] { hellos += 10; while (!done) std::this_thread::sleep_for(10ms); });
|
||||
wait_for([&] { return hellos >= 11; });
|
||||
client.send(c, "public.hello");
|
||||
wait_for([&] { return hellos >= 12; });
|
||||
server.inject_task("public", "(injected)", "localhost", [&] { hellos += 10; while (!done) std::this_thread::sleep_for(10ms); });
|
||||
server.inject_task("public", "(injected)", "localhost", [&] { hellos += 10; while (!done) std::this_thread::sleep_for(10ms); });
|
||||
server.inject_task("public", "(injected)", "localhost", [&] { hellos += 10; while (!done) std::this_thread::sleep_for(10ms); });
|
||||
wait_for([&] { return hellos >= 12; });
|
||||
std::this_thread::sleep_for(20ms);
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( hellos == 12 );
|
||||
}
|
||||
done = true;
|
||||
wait_for([&] { return hellos >= 42; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( hellos == 42 );
|
||||
}
|
||||
}
|
|
@ -0,0 +1,611 @@
|
|||
#include "common.h"
|
||||
#include "oxenmq/pubsub.h"
|
||||
|
||||
#include <oxenc/hex.h>
|
||||
|
||||
using namespace oxenmq;
|
||||
using namespace std::chrono_literals;
|
||||
|
||||
TEST_CASE("sub OK", "[pubsub]") {
|
||||
std::string listen = random_localhost();
|
||||
OxenMQ server{
|
||||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
};
|
||||
server.listen_curve(listen);
|
||||
|
||||
Subscription<> greetings{"greetings"};
|
||||
|
||||
std::atomic<bool> is_new{false};
|
||||
server.add_category("public", Access{AuthLevel::none});
|
||||
server.add_request_command("public", "greetings", [&](Message& m) {
|
||||
is_new = greetings.subscribe(m.conn);
|
||||
m.send_reply("OK");
|
||||
});
|
||||
server.start();
|
||||
|
||||
OxenMQ client(
|
||||
[](LogLevel, const char* file, int line, std::string msg) { std::cerr << file << ":" << line << " --C-- " << msg << "\n"; }
|
||||
);
|
||||
|
||||
std::atomic<int> reply_count{0};
|
||||
client.add_category("notify", Access{AuthLevel::none});
|
||||
client.add_command("notify", "greetings", [&](Message& m) {
|
||||
const auto& data = m.data;
|
||||
if (!data.size())
|
||||
{
|
||||
std::cerr << "client received public.greetings with empty data\n";
|
||||
return;
|
||||
}
|
||||
if (data[0] == "hello")
|
||||
reply_count++;
|
||||
});
|
||||
|
||||
client.start();
|
||||
|
||||
std::atomic<bool> connected{false}, failed{false};
|
||||
std::string pubkey;
|
||||
|
||||
auto c = client.connect_remote(address{listen, server.get_pubkey()},
|
||||
[&](auto conn) { pubkey = conn.pubkey(); connected = true; },
|
||||
[&](auto, auto) { failed = true; });
|
||||
|
||||
wait_for([&] { return connected || failed; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( connected );
|
||||
REQUIRE_FALSE( failed );
|
||||
REQUIRE( oxenc::to_hex(pubkey) == oxenc::to_hex(server.get_pubkey()) );
|
||||
}
|
||||
|
||||
std::atomic<bool> got_reply{false};
|
||||
bool success;
|
||||
std::vector<std::string> data;
|
||||
client.request(c, "public.greetings", [&](bool ok, std::vector<std::string> data_) {
|
||||
got_reply = true;
|
||||
success = ok;
|
||||
data = std::move(data_);
|
||||
});
|
||||
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( got_reply.load() );
|
||||
REQUIRE( success );
|
||||
REQUIRE( data == std::vector<std::string>{{"OK"}} );
|
||||
}
|
||||
|
||||
greetings.publish([&](auto& conn) {
|
||||
server.send(conn, "notify.greetings", "hello");
|
||||
});
|
||||
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( reply_count == 1 );
|
||||
}
|
||||
|
||||
greetings.publish([&](auto& conn) {
|
||||
server.send(conn, "notify.greetings", "hello");
|
||||
});
|
||||
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( reply_count == 2 );
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
TEST_CASE("user data", "[pubsub]") {
|
||||
std::string listen = random_localhost();
|
||||
OxenMQ server{
|
||||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
};
|
||||
server.listen_curve(listen);
|
||||
|
||||
Subscription<std::string> greetings{"greetings"};
|
||||
|
||||
std::atomic<bool> is_new{false};
|
||||
server.add_category("public", Access{AuthLevel::none});
|
||||
server.add_request_command("public", "greetings", [&](Message& m) {
|
||||
is_new = greetings.subscribe(m.conn, std::string{m.data[0]});
|
||||
m.send_reply("OK");
|
||||
});
|
||||
server.start();
|
||||
|
||||
OxenMQ client(
|
||||
[](LogLevel, const char* file, int line, std::string msg) { std::cerr << file << ":" << line << " --C-- " << msg << "\n"; }
|
||||
);
|
||||
|
||||
std::string response{"foo"};
|
||||
std::atomic<int> reply_count{0};
|
||||
std::atomic<int> foo_count{0};
|
||||
client.add_category("notify", Access{AuthLevel::none});
|
||||
client.add_command("notify", "greetings", [&](Message& m) {
|
||||
const auto& data = m.data;
|
||||
if (!data.size())
|
||||
{
|
||||
std::cerr << "client received public.greetings with empty data\n";
|
||||
return;
|
||||
}
|
||||
if (data[0] == response)
|
||||
reply_count++;
|
||||
if (data[0] == "foo")
|
||||
foo_count++;
|
||||
});
|
||||
|
||||
client.start();
|
||||
|
||||
std::atomic<bool> connected{false}, failed{false};
|
||||
std::string pubkey;
|
||||
|
||||
auto c = client.connect_remote(address{listen, server.get_pubkey()},
|
||||
[&](auto conn) { pubkey = conn.pubkey(); connected = true; },
|
||||
[&](auto, auto) { failed = true; });
|
||||
|
||||
wait_for([&] { return connected || failed; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( connected );
|
||||
REQUIRE_FALSE( failed );
|
||||
REQUIRE( oxenc::to_hex(pubkey) == oxenc::to_hex(server.get_pubkey()) );
|
||||
}
|
||||
|
||||
std::atomic<bool> got_reply{false};
|
||||
std::atomic<bool> success;
|
||||
std::vector<std::string> data;
|
||||
client.request(c, "public.greetings", [&](bool ok, std::vector<std::string> data_) {
|
||||
got_reply = true;
|
||||
success = ok;
|
||||
data = std::move(data_);
|
||||
}, response);
|
||||
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( got_reply.load() );
|
||||
REQUIRE( success );
|
||||
REQUIRE( is_new );
|
||||
REQUIRE( data == std::vector<std::string>{{"OK"}} );
|
||||
}
|
||||
|
||||
got_reply = false;
|
||||
success = false;
|
||||
client.request(c, "public.greetings", [&](bool ok, std::vector<std::string> data_) {
|
||||
got_reply = true;
|
||||
success = ok;
|
||||
data = std::move(data_);
|
||||
}, response);
|
||||
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( got_reply.load() );
|
||||
REQUIRE( success );
|
||||
REQUIRE_FALSE( is_new );
|
||||
REQUIRE( data == std::vector<std::string>{{"OK"}} );
|
||||
}
|
||||
|
||||
greetings.publish([&](auto& conn, std::string user) {
|
||||
server.send(conn, "notify.greetings", user);
|
||||
});
|
||||
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( reply_count == 1 );
|
||||
REQUIRE( foo_count == 1 );
|
||||
}
|
||||
|
||||
got_reply = false;
|
||||
success = false;
|
||||
response = "bar";
|
||||
client.request(c, "public.greetings", [&](bool ok, std::vector<std::string> data_) {
|
||||
got_reply = true;
|
||||
success = ok;
|
||||
data = std::move(data_);
|
||||
}, response);
|
||||
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( got_reply.load() );
|
||||
REQUIRE( success );
|
||||
REQUIRE( is_new );
|
||||
REQUIRE( data == std::vector<std::string>{{"OK"}} );
|
||||
}
|
||||
|
||||
greetings.publish([&](auto& conn, std::string user) {
|
||||
server.send(conn, "notify.greetings", user);
|
||||
});
|
||||
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( reply_count == 2 );
|
||||
REQUIRE( foo_count == 1 );
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
TEST_CASE("unsubscribe", "[pubsub]") {
|
||||
std::string listen = random_localhost();
|
||||
OxenMQ server{
|
||||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
};
|
||||
server.listen_curve(listen);
|
||||
|
||||
Subscription<> greetings{"greetings"};
|
||||
|
||||
std::atomic<bool> was_subbed{false};
|
||||
server.add_category("public", Access{AuthLevel::none});
|
||||
server.add_request_command("public", "greetings", [&](Message& m) {
|
||||
greetings.subscribe(m.conn);
|
||||
m.send_reply("OK");
|
||||
});
|
||||
server.add_request_command("public", "goodbye", [&](Message& m) {
|
||||
was_subbed = greetings.unsubscribe(m.conn);
|
||||
m.send_reply("OK");
|
||||
});
|
||||
server.start();
|
||||
|
||||
OxenMQ client(
|
||||
[](LogLevel, const char* file, int line, std::string msg) { std::cerr << file << ":" << line << " --C-- " << msg << "\n"; }
|
||||
);
|
||||
|
||||
std::atomic<int> reply_count{0};
|
||||
client.add_category("notify", Access{AuthLevel::none});
|
||||
client.add_command("notify", "greetings", [&](Message& m) {
|
||||
const auto& data = m.data;
|
||||
if (!data.size())
|
||||
{
|
||||
std::cerr << "client received public.greetings with empty data\n";
|
||||
return;
|
||||
}
|
||||
if (data[0] == "hello")
|
||||
reply_count++;
|
||||
});
|
||||
|
||||
client.start();
|
||||
|
||||
std::atomic<bool> connected{false}, failed{false};
|
||||
std::string pubkey;
|
||||
|
||||
auto c = client.connect_remote(address{listen, server.get_pubkey()},
|
||||
[&](auto conn) { pubkey = conn.pubkey(); connected = true; },
|
||||
[&](auto, auto) { failed = true; });
|
||||
|
||||
wait_for([&] { return connected || failed; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( connected );
|
||||
REQUIRE_FALSE( failed );
|
||||
REQUIRE( oxenc::to_hex(pubkey) == oxenc::to_hex(server.get_pubkey()) );
|
||||
}
|
||||
|
||||
std::atomic<bool> got_reply{false};
|
||||
std::atomic<bool> success;
|
||||
std::vector<std::string> data;
|
||||
client.request(c, "public.greetings", [&](bool ok, std::vector<std::string> data_) {
|
||||
got_reply = true;
|
||||
success = ok;
|
||||
data = std::move(data_);
|
||||
});
|
||||
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( got_reply.load() );
|
||||
REQUIRE( success );
|
||||
REQUIRE( data == std::vector<std::string>{{"OK"}} );
|
||||
}
|
||||
|
||||
greetings.publish([&](auto& conn) {
|
||||
server.send(conn, "notify.greetings", "hello");
|
||||
});
|
||||
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( reply_count == 1 );
|
||||
}
|
||||
|
||||
got_reply = false;
|
||||
success = false;
|
||||
client.request(c, "public.goodbye", [&](bool ok, std::vector<std::string> data_) {
|
||||
got_reply = true;
|
||||
success = ok;
|
||||
data = std::move(data_);
|
||||
});
|
||||
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( got_reply.load() );
|
||||
REQUIRE( success );
|
||||
REQUIRE( data == std::vector<std::string>{{"OK"}} );
|
||||
REQUIRE( was_subbed );
|
||||
}
|
||||
|
||||
greetings.publish([&](auto& conn) {
|
||||
server.send(conn, "notify.greetings", "hello");
|
||||
});
|
||||
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( reply_count == 1 );
|
||||
}
|
||||
|
||||
got_reply = false;
|
||||
success = false;
|
||||
client.request(c, "public.goodbye", [&](bool ok, std::vector<std::string> data_) {
|
||||
got_reply = true;
|
||||
success = ok;
|
||||
data = std::move(data_);
|
||||
});
|
||||
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( got_reply.load() );
|
||||
REQUIRE( success );
|
||||
REQUIRE( data == std::vector<std::string>{{"OK"}} );
|
||||
REQUIRE( was_subbed == false);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
TEST_CASE("expire", "[pubsub]") {
|
||||
std::string listen = random_localhost();
|
||||
OxenMQ server{
|
||||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
};
|
||||
server.listen_curve(listen);
|
||||
|
||||
Subscription<> greetings{"greetings", 250ms};
|
||||
|
||||
std::atomic<bool> was_subbed{false};
|
||||
server.add_category("public", Access{AuthLevel::none});
|
||||
server.add_request_command("public", "greetings", [&](Message& m) {
|
||||
greetings.subscribe(m.conn);
|
||||
m.send_reply("OK");
|
||||
});
|
||||
server.add_request_command("public", "goodbye", [&](Message& m) {
|
||||
was_subbed = greetings.unsubscribe(m.conn);
|
||||
m.send_reply("OK");
|
||||
});
|
||||
server.start();
|
||||
|
||||
OxenMQ client(
|
||||
[](LogLevel, const char* file, int line, std::string msg) { std::cerr << file << ":" << line << " --C-- " << msg << "\n"; }
|
||||
);
|
||||
|
||||
std::atomic<int> reply_count{0};
|
||||
client.add_category("notify", Access{AuthLevel::none});
|
||||
client.add_command("notify", "greetings", [&](Message& m) {
|
||||
const auto& data = m.data;
|
||||
if (!data.size())
|
||||
{
|
||||
std::cerr << "client received public.greetings with empty data\n";
|
||||
return;
|
||||
}
|
||||
if (data[0] == "hello")
|
||||
reply_count++;
|
||||
});
|
||||
|
||||
client.start();
|
||||
|
||||
std::atomic<bool> connected{false}, failed{false};
|
||||
std::string pubkey;
|
||||
|
||||
auto c = client.connect_remote(address{listen, server.get_pubkey()},
|
||||
[&](auto conn) { pubkey = conn.pubkey(); connected = true; },
|
||||
[&](auto, auto) { failed = true; });
|
||||
|
||||
wait_for([&] { return connected || failed; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( connected );
|
||||
REQUIRE_FALSE( failed );
|
||||
REQUIRE( oxenc::to_hex(pubkey) == oxenc::to_hex(server.get_pubkey()) );
|
||||
}
|
||||
|
||||
std::atomic<bool> got_reply{false};
|
||||
bool success;
|
||||
std::vector<std::string> data;
|
||||
client.request(c, "public.greetings", [&](bool ok, std::vector<std::string> data_) {
|
||||
got_reply = true;
|
||||
success = ok;
|
||||
data = std::move(data_);
|
||||
});
|
||||
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( got_reply.load() );
|
||||
REQUIRE( success );
|
||||
REQUIRE( data == std::vector<std::string>{{"OK"}} );
|
||||
}
|
||||
|
||||
// should be expired by now
|
||||
std::this_thread::sleep_for(500ms);
|
||||
|
||||
greetings.remove_expired();
|
||||
|
||||
got_reply = false;
|
||||
success = false;
|
||||
client.request(c, "public.goodbye", [&](bool ok, std::vector<std::string> data_) {
|
||||
got_reply = true;
|
||||
success = ok;
|
||||
data = std::move(data_);
|
||||
});
|
||||
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( got_reply.load() );
|
||||
REQUIRE( success );
|
||||
REQUIRE( data == std::vector<std::string>{{"OK"}} );
|
||||
REQUIRE( was_subbed == false);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
TEST_CASE("multiple subs", "[pubsub]") {
|
||||
std::string listen = random_localhost();
|
||||
OxenMQ server{
|
||||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
};
|
||||
server.listen_curve(listen);
|
||||
|
||||
Subscription<> greetings{"greetings"};
|
||||
|
||||
std::atomic<bool> is_new{false};
|
||||
server.add_category("public", Access{AuthLevel::none});
|
||||
server.add_request_command("public", "greetings", [&](Message& m) {
|
||||
is_new = greetings.subscribe(m.conn);
|
||||
m.send_reply("OK");
|
||||
});
|
||||
server.start();
|
||||
|
||||
/* client 1 */
|
||||
std::atomic<int> reply_count_c1{0};
|
||||
std::atomic<bool> connected_c1{false}, failed_c1{false};
|
||||
std::atomic<bool> got_reply_c1{false};
|
||||
bool success_c1;
|
||||
std::vector<std::string> data_c1;
|
||||
std::string pubkey_c1;
|
||||
OxenMQ client1(
|
||||
[](LogLevel, const char* file, int line, std::string msg) { std::cerr << file << ":" << line << " --C-- " << msg << "\n"; }
|
||||
);
|
||||
|
||||
client1.add_category("notify", Access{AuthLevel::none});
|
||||
client1.add_command("notify", "greetings", [&](Message& m) {
|
||||
const auto& data = m.data;
|
||||
if (!data.size())
|
||||
{
|
||||
std::cerr << "client received public.greetings with empty data\n";
|
||||
return;
|
||||
}
|
||||
if (data[0] == "hello")
|
||||
reply_count_c1++;
|
||||
});
|
||||
|
||||
client1.start();
|
||||
|
||||
auto c1 = client1.connect_remote(address{listen, server.get_pubkey()},
|
||||
[&](auto conn) { pubkey_c1 = conn.pubkey(); connected_c1 = true; },
|
||||
[&](auto, auto) { failed_c1 = true; });
|
||||
|
||||
wait_for([&] { return connected_c1 || failed_c1; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( connected_c1 );
|
||||
REQUIRE_FALSE( failed_c1 );
|
||||
REQUIRE( oxenc::to_hex(pubkey_c1) == oxenc::to_hex(server.get_pubkey()) );
|
||||
}
|
||||
|
||||
client1.request(c1, "public.greetings", [&](bool ok, std::vector<std::string> data_) {
|
||||
got_reply_c1 = true;
|
||||
success_c1 = ok;
|
||||
data_c1 = std::move(data_);
|
||||
});
|
||||
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( got_reply_c1.load() );
|
||||
REQUIRE( success_c1 );
|
||||
REQUIRE( data_c1 == std::vector<std::string>{{"OK"}} );
|
||||
}
|
||||
/* end client 1 */
|
||||
|
||||
/* client 2 */
|
||||
std::atomic<int> reply_count_c2{0};
|
||||
std::atomic<bool> connected_c2{false}, failed_c2{false};
|
||||
std::atomic<bool> got_reply_c2{false};
|
||||
bool success_c2;
|
||||
std::vector<std::string> data_c2;
|
||||
std::string pubkey_c2;
|
||||
OxenMQ client2(
|
||||
[](LogLevel, const char* file, int line, std::string msg) { std::cerr << file << ":" << line << " --C-- " << msg << "\n"; }
|
||||
);
|
||||
|
||||
client2.add_category("notify", Access{AuthLevel::none});
|
||||
client2.add_command("notify", "greetings", [&](Message& m) {
|
||||
const auto& data = m.data;
|
||||
if (!data.size())
|
||||
{
|
||||
std::cerr << "client received public.greetings with empty data\n";
|
||||
return;
|
||||
}
|
||||
if (data[0] == "hello")
|
||||
reply_count_c2++;
|
||||
});
|
||||
|
||||
client2.start();
|
||||
|
||||
auto c2 = client2.connect_remote(address{listen, server.get_pubkey()},
|
||||
[&](auto conn) { pubkey_c2 = conn.pubkey(); connected_c2 = true; },
|
||||
[&](auto, auto) { failed_c2 = true; });
|
||||
|
||||
wait_for([&] { return connected_c2 || failed_c2; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( connected_c2 );
|
||||
REQUIRE_FALSE( failed_c2 );
|
||||
REQUIRE( oxenc::to_hex(pubkey_c2) == oxenc::to_hex(server.get_pubkey()) );
|
||||
}
|
||||
|
||||
client2.request(c2, "public.greetings", [&](bool ok, std::vector<std::string> data_) {
|
||||
got_reply_c2 = true;
|
||||
success_c2 = ok;
|
||||
data_c2 = std::move(data_);
|
||||
});
|
||||
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( got_reply_c2.load() );
|
||||
REQUIRE( success_c2 );
|
||||
REQUIRE( data_c2 == std::vector<std::string>{{"OK"}} );
|
||||
}
|
||||
/* end client2 */
|
||||
|
||||
greetings.publish([&](auto& conn) {
|
||||
server.send(conn, "notify.greetings", "hello");
|
||||
});
|
||||
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( reply_count_c1 == 1 );
|
||||
REQUIRE( reply_count_c2 == 1 );
|
||||
}
|
||||
|
||||
greetings.publish([&](auto& conn) {
|
||||
server.send(conn, "notify.greetings", "hello");
|
||||
});
|
||||
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( reply_count_c1 == 2 );
|
||||
REQUIRE( reply_count_c2 == 2 );
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
// vim:sw=4:et
|
|
@ -1,11 +1,11 @@
|
|||
#include "common.h"
|
||||
#include <lokimq/hex.h>
|
||||
#include <oxenc/hex.h>
|
||||
|
||||
using namespace lokimq;
|
||||
using namespace oxenmq;
|
||||
|
||||
TEST_CASE("basic requests", "[requests]") {
|
||||
std::string listen = "tcp://127.0.0.1:5678";
|
||||
LokiMQ server{
|
||||
std::string listen = random_localhost();
|
||||
OxenMQ server{
|
||||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
|
@ -20,7 +20,7 @@ TEST_CASE("basic requests", "[requests]") {
|
|||
});
|
||||
server.start();
|
||||
|
||||
LokiMQ client(
|
||||
OxenMQ client(
|
||||
[](LogLevel, const char* file, int line, std::string msg) { std::cerr << file << ":" << line << " --C-- " << msg << "\n"; }
|
||||
);
|
||||
//client.log_level(LogLevel::trace);
|
||||
|
@ -30,17 +30,16 @@ TEST_CASE("basic requests", "[requests]") {
|
|||
std::atomic<bool> connected{false}, failed{false};
|
||||
std::string pubkey;
|
||||
|
||||
auto c = client.connect_remote(listen,
|
||||
auto c = client.connect_remote(address{listen, server.get_pubkey()},
|
||||
[&](auto conn) { pubkey = conn.pubkey(); connected = true; },
|
||||
[&](auto, auto) { failed = true; },
|
||||
server.get_pubkey());
|
||||
[&](auto, auto) { failed = true; });
|
||||
|
||||
wait_for([&] { return connected || failed; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( connected );
|
||||
REQUIRE_FALSE( failed );
|
||||
REQUIRE( to_hex(pubkey) == to_hex(server.get_pubkey()) );
|
||||
REQUIRE( oxenc::to_hex(pubkey) == oxenc::to_hex(server.get_pubkey()) );
|
||||
}
|
||||
|
||||
std::atomic<bool> got_reply{false};
|
||||
|
@ -62,8 +61,8 @@ TEST_CASE("basic requests", "[requests]") {
|
|||
}
|
||||
|
||||
TEST_CASE("request from server to client", "[requests]") {
|
||||
std::string listen = "tcp://127.0.0.1:5678";
|
||||
LokiMQ server{
|
||||
std::string listen = random_localhost();
|
||||
OxenMQ server{
|
||||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
|
@ -78,7 +77,7 @@ TEST_CASE("request from server to client", "[requests]") {
|
|||
});
|
||||
server.start();
|
||||
|
||||
LokiMQ client(
|
||||
OxenMQ client(
|
||||
[](LogLevel, const char* file, int line, std::string msg) { std::cerr << file << ":" << line << " --C-- " << msg << "\n"; }
|
||||
);
|
||||
//client.log_level(LogLevel::trace);
|
||||
|
@ -88,10 +87,9 @@ TEST_CASE("request from server to client", "[requests]") {
|
|||
std::atomic<bool> connected{false}, failed{false};
|
||||
std::string pubkey;
|
||||
|
||||
auto c = client.connect_remote(listen,
|
||||
auto c = client.connect_remote(address{listen, server.get_pubkey()},
|
||||
[&](auto conn) { pubkey = conn.pubkey(); connected = true; },
|
||||
[&](auto, auto) { failed = true; },
|
||||
server.get_pubkey());
|
||||
[&](auto, auto) { failed = true; });
|
||||
|
||||
int i;
|
||||
for (i = 0; i < 5; i++) {
|
||||
|
@ -104,7 +102,7 @@ TEST_CASE("request from server to client", "[requests]") {
|
|||
REQUIRE( connected.load() );
|
||||
REQUIRE( !failed.load() );
|
||||
REQUIRE( i <= 1 );
|
||||
REQUIRE( to_hex(pubkey) == to_hex(server.get_pubkey()) );
|
||||
REQUIRE( oxenc::to_hex(pubkey) == oxenc::to_hex(server.get_pubkey()) );
|
||||
}
|
||||
|
||||
std::atomic<bool> got_reply{false};
|
||||
|
@ -126,8 +124,8 @@ TEST_CASE("request from server to client", "[requests]") {
|
|||
}
|
||||
|
||||
TEST_CASE("request timeouts", "[requests][timeout]") {
|
||||
std::string listen = "tcp://127.0.0.1:5678";
|
||||
LokiMQ server{
|
||||
std::string listen = random_localhost();
|
||||
OxenMQ server{
|
||||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
|
@ -140,7 +138,7 @@ TEST_CASE("request timeouts", "[requests][timeout]") {
|
|||
server.add_request_command("public", "blackhole", [&](Message& m) { /* doesn't reply */ });
|
||||
server.start();
|
||||
|
||||
LokiMQ client(
|
||||
OxenMQ client(
|
||||
[](LogLevel, const char* file, int line, std::string msg) { std::cerr << file << ":" << line << " --C-- " << msg << "\n"; }
|
||||
);
|
||||
//client.log_level(LogLevel::trace);
|
||||
|
@ -151,16 +149,15 @@ TEST_CASE("request timeouts", "[requests][timeout]") {
|
|||
std::atomic<bool> connected{false}, failed{false};
|
||||
std::string pubkey;
|
||||
|
||||
auto c = client.connect_remote(listen,
|
||||
auto c = client.connect_remote(address{listen, server.get_pubkey()},
|
||||
[&](auto conn) { pubkey = conn.pubkey(); connected = true; },
|
||||
[&](auto, auto) { failed = true; },
|
||||
server.get_pubkey());
|
||||
[&](auto, auto) { failed = true; });
|
||||
|
||||
wait_for([&] { return connected || failed; });
|
||||
|
||||
REQUIRE( connected );
|
||||
REQUIRE_FALSE( failed );
|
||||
REQUIRE( to_hex(pubkey) == to_hex(server.get_pubkey()) );
|
||||
REQUIRE( oxenc::to_hex(pubkey) == oxenc::to_hex(server.get_pubkey()) );
|
||||
|
||||
std::atomic<bool> got_triggered{false};
|
||||
bool success;
|
||||
|
@ -170,7 +167,7 @@ TEST_CASE("request timeouts", "[requests][timeout]") {
|
|||
success = ok;
|
||||
data = std::move(data_);
|
||||
},
|
||||
lokimq::send_option::request_timeout{20ms}
|
||||
oxenmq::send_option::request_timeout{10ms}
|
||||
);
|
||||
|
||||
std::atomic<bool> got_triggered2{false};
|
||||
|
@ -179,10 +176,10 @@ TEST_CASE("request timeouts", "[requests][timeout]") {
|
|||
success = ok;
|
||||
data = std::move(data_);
|
||||
},
|
||||
lokimq::send_option::request_timeout{100ms}
|
||||
oxenmq::send_option::request_timeout{200ms}
|
||||
);
|
||||
|
||||
std::this_thread::sleep_for(40ms);
|
||||
std::this_thread::sleep_for(100ms);
|
||||
REQUIRE( got_triggered );
|
||||
REQUIRE_FALSE( got_triggered2 );
|
||||
REQUIRE_FALSE( success );
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
#include "common.h"
|
||||
#include <oxenc/hex.h>
|
||||
|
||||
using namespace oxenmq;
|
||||
|
||||
TEST_CASE("zmq socket limit", "[zmq][socket-limit]") {
|
||||
// Make sure setting .MAX_SOCKETS works as expected. (This test was added when a bug was fixed
|
||||
// that was causing it not to be applied).
|
||||
std::string listen = random_localhost();
|
||||
OxenMQ server{
|
||||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
};
|
||||
server.listen_plain(listen);
|
||||
server.start();
|
||||
|
||||
std::atomic<int> failed = 0, good = 0, failed_toomany = 0;
|
||||
OxenMQ client;
|
||||
client.MAX_SOCKETS = 15;
|
||||
client.start();
|
||||
|
||||
std::vector<ConnectionID> conns;
|
||||
address server_addr{listen};
|
||||
for (int i = 0; i < 16; i++)
|
||||
client.connect_remote(server_addr,
|
||||
[&](auto) { good++; },
|
||||
[&](auto cid, auto msg) {
|
||||
if (msg == "connect() failed: Too many open files")
|
||||
failed_toomany++;
|
||||
else
|
||||
failed++;
|
||||
});
|
||||
|
||||
|
||||
wait_for([&] { return good > 0 && failed_toomany > 0; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( good > 0 );
|
||||
REQUIRE( failed == 0 );
|
||||
REQUIRE( failed_toomany > 0 );
|
||||
}
|
||||
}
|
|
@ -1,187 +0,0 @@
|
|||
#include <catch2/catch.hpp>
|
||||
#include "lokimq/string_view.h"
|
||||
#include <future>
|
||||
|
||||
using namespace lokimq;
|
||||
|
||||
using namespace std::literals;
|
||||
|
||||
TEST_CASE("string view", "[string_view]") {
|
||||
std::string foo = "abc 123 xyz";
|
||||
string_view f1{foo};
|
||||
string_view f2{"def 789 uvw"};
|
||||
string_view f3{"nu\0ll", 5};
|
||||
|
||||
REQUIRE( f1 == "abc 123 xyz" );
|
||||
REQUIRE( f2 == "def 789 uvw" );
|
||||
REQUIRE( f3.size() == 5 );
|
||||
REQUIRE( f3 == std::string{"nu\0ll", 5} );
|
||||
REQUIRE( f3 != "nu" );
|
||||
REQUIRE( f3.data() == "nu"s );
|
||||
REQUIRE( string_view(f3) == f3 );
|
||||
|
||||
auto f4 = f3;
|
||||
REQUIRE( f4 == f3 );
|
||||
f4 = f2;
|
||||
REQUIRE( f4 == "def 789 uvw" );
|
||||
|
||||
REQUIRE( f1.size() == 11 );
|
||||
REQUIRE( f3.length() == 5 );
|
||||
|
||||
string_view f5{""};
|
||||
REQUIRE( !f3.empty() );
|
||||
REQUIRE( f5.empty() );
|
||||
|
||||
REQUIRE( f1[5] == '2' );
|
||||
size_t i = 0;
|
||||
for (auto c : f3)
|
||||
REQUIRE(c == f3[i++]);
|
||||
|
||||
std::string backwards;
|
||||
for (auto it = std::rbegin(f2); it != f2.crend(); ++it)
|
||||
backwards += *it;
|
||||
|
||||
REQUIRE( backwards == "wvu 987 fed" );
|
||||
|
||||
REQUIRE( f1.at(10) == 'z' );
|
||||
REQUIRE_THROWS_AS( f1.at(15), std::out_of_range );
|
||||
REQUIRE_THROWS_AS( f1.at(11), std::out_of_range );
|
||||
|
||||
f4 = f1;
|
||||
f4.remove_prefix(2);
|
||||
REQUIRE( f4 == "c 123 xyz" );
|
||||
f4.remove_prefix(2);
|
||||
f4.remove_suffix(4);
|
||||
REQUIRE( f4 == "123" );
|
||||
f4.remove_prefix(1);
|
||||
REQUIRE( f4 == "23" );
|
||||
REQUIRE( f1 == "abc 123 xyz" );
|
||||
f4.swap(f1);
|
||||
REQUIRE( f1 == "23" );
|
||||
REQUIRE( f4 == "abc 123 xyz" );
|
||||
f1.remove_suffix(2);
|
||||
REQUIRE( f1.empty() );
|
||||
REQUIRE( f4 == "abc 123 xyz" );
|
||||
f1.swap(f4);
|
||||
REQUIRE( f4.empty() );
|
||||
REQUIRE( f1 == "abc 123 xyz" );
|
||||
|
||||
REQUIRE( f1.front() == 'a' );
|
||||
REQUIRE( f1.back() == 'z' );
|
||||
REQUIRE( f1.compare("abc") > 0 );
|
||||
REQUIRE( f1.compare("abd") < 0 );
|
||||
REQUIRE( f1.compare("abc 123 xyz") == 0 );
|
||||
REQUIRE( f1.compare("abc 123 xyza") < 0 );
|
||||
REQUIRE( f1.compare("abc 123 xy") > 0 );
|
||||
|
||||
std::string buf;
|
||||
buf.resize(5);
|
||||
f1.copy(&buf[0], 5, 2);
|
||||
REQUIRE( buf == "c 123" );
|
||||
buf.resize(100, 'X');
|
||||
REQUIRE( f1.copy(&buf[0], 100) == 11 );
|
||||
REQUIRE( buf.substr(0, 11) == f1 );
|
||||
REQUIRE( buf.substr(11) == std::string(89, 'X') );
|
||||
REQUIRE( f1.substr(4) == "123 xyz" );
|
||||
REQUIRE( f1.substr(4, 3) == "123" );
|
||||
REQUIRE_THROWS_AS( f1.substr(500, 3), std::out_of_range );
|
||||
REQUIRE( f1.substr(11, 2) == "" );
|
||||
REQUIRE( f1.substr(8, 500) == "xyz" );
|
||||
REQUIRE( f1.find("123") == 4 );
|
||||
REQUIRE( f1.find("abc") == 0 );
|
||||
REQUIRE( f1.find("xyz") == 8 );
|
||||
REQUIRE( f1.find("abc 123 xyz 7") == string_view::npos );
|
||||
REQUIRE( f1.find("23") == 5 );
|
||||
REQUIRE( f1.find("234") == string_view::npos );
|
||||
|
||||
string_view f6{"zz abc abcd abcde abcdef"};
|
||||
REQUIRE( f6.find("abc") == 3 );
|
||||
REQUIRE( f6.find("abc", 3) == 3 );
|
||||
REQUIRE( f6.find("abc", 4) == 7 );
|
||||
REQUIRE( f6.find("abc", 7) == 7 );
|
||||
REQUIRE( f6.find("abc", 8) == 12 );
|
||||
REQUIRE( f6.find("abc", 18) == 18 );
|
||||
REQUIRE( f6.find("abc", 19) == string_view::npos );
|
||||
REQUIRE( f6.find("abcd") == 7 );
|
||||
REQUIRE( f6.rfind("abc") == 18 );
|
||||
REQUIRE( f6.rfind("abcd") == 18 );
|
||||
REQUIRE( f6.rfind("bcd") == 19 );
|
||||
REQUIRE( f6.rfind("abc", 19) == 18 );
|
||||
REQUIRE( f6.rfind("abc", 18) == 18 );
|
||||
REQUIRE( f6.rfind("abc", 17) == 12 );
|
||||
REQUIRE( f6.rfind("abc", 17) == 12 );
|
||||
REQUIRE( f6.rfind("abc", 8) == 7 );
|
||||
REQUIRE( f6.rfind("abc", 7) == 7 );
|
||||
REQUIRE( f6.rfind("abc", 6) == 3 );
|
||||
REQUIRE( f6.rfind("abc", 3) == 3 );
|
||||
REQUIRE( f6.rfind("abc", 2) == string_view::npos );
|
||||
|
||||
REQUIRE( f6.find('a') == 3 );
|
||||
REQUIRE( f6.find('a', 17) == 18 );
|
||||
REQUIRE( f6.find('a', 20) == string_view::npos );
|
||||
|
||||
REQUIRE( f6.rfind('a') == 18 );
|
||||
REQUIRE( f6.rfind('a', 17) == 12 );
|
||||
REQUIRE( f6.rfind('a', 2) == string_view::npos );
|
||||
|
||||
string_view f7{"abc\0def", 7};
|
||||
REQUIRE( f7.find("c\0d", 0, 3) == 2 );
|
||||
REQUIRE( f7.find("c\0e", 0, 3) == string_view::npos );
|
||||
REQUIRE( f7.rfind("c\0d", string_view::npos, 3) == 2 );
|
||||
REQUIRE( f7.rfind("c\0e", 0, 3) == string_view::npos );
|
||||
|
||||
REQUIRE( f6.find_first_of("c789b") == 4 );
|
||||
REQUIRE( f6.find_first_of("c789") == 5 );
|
||||
REQUIRE( f2.find_first_of("c789b") == 4 );
|
||||
REQUIRE( f6.find_first_of("c789b", 6) == 8 );
|
||||
|
||||
REQUIRE( f6.find_last_of("c789b") == 20 );
|
||||
REQUIRE( f6.find_last_of("789b") == 19 );
|
||||
REQUIRE( f2.find_last_of("c789b") == 6 );
|
||||
REQUIRE( f6.find_last_of("c789b", 6) == 5 );
|
||||
REQUIRE( f6.find_last_of("c789b", 5) == 5 );
|
||||
REQUIRE( f6.find_last_of("c789b", 4) == 4 );
|
||||
REQUIRE( f6.find_last_of("c789b", 3) == string_view::npos );
|
||||
|
||||
REQUIRE( f2.find_first_of(f7) == 0 );
|
||||
REQUIRE( f3.find_first_of(f7) == 2 );
|
||||
REQUIRE( f3.find_first_of('\0') == 2 );
|
||||
REQUIRE( f3.find_first_of("jk\0", 0, 3) == 2 );
|
||||
|
||||
REQUIRE( f1.find_first_not_of("abc") == 3 );
|
||||
REQUIRE( f1.find_first_not_of("abc ", 3) == 4 );
|
||||
REQUIRE( f1.find_first_not_of(" 123", 3) == 8 );
|
||||
REQUIRE( f1.find_last_not_of("abc") == 10 );
|
||||
REQUIRE( f1.find_last_not_of("xyz") == 7 );
|
||||
REQUIRE( f1.find_last_not_of("xyz 321") == 2 );
|
||||
REQUIRE( f1.find_last_not_of("xay z1b2c3") == string_view::npos );
|
||||
REQUIRE( f6.find_last_not_of("def") == 20 );
|
||||
REQUIRE( f6.find_last_not_of("abcdef") == 17 );
|
||||
REQUIRE( f6.find_last_not_of("abcdef ") == 1 );
|
||||
REQUIRE( f6.find_first_not_of('z') == 2 );
|
||||
REQUIRE( f6.find_first_not_of("z ") == 3 );
|
||||
REQUIRE( f6.find_first_not_of("a ", 2) == 4 );
|
||||
REQUIRE( f6.find_last_not_of("abc ", 9) == 1 );
|
||||
|
||||
REQUIRE( string_view{"abc"} == string_view{"abc"} );
|
||||
REQUIRE_FALSE( string_view{"abc"} == string_view{"abd"} );
|
||||
REQUIRE_FALSE( string_view{"abc"} == string_view{"abcd"} );
|
||||
REQUIRE( string_view{"abc"} != string_view{"abd"} );
|
||||
REQUIRE( string_view{"abc"} != string_view{"abcd"} );
|
||||
REQUIRE( string_view{"abc"} < string_view{"abcd"} );
|
||||
REQUIRE( string_view{"abc"} < string_view{"abd"} );
|
||||
REQUIRE_FALSE( string_view{"abd"} < string_view{"abc"} );
|
||||
REQUIRE_FALSE( string_view{"abcd"} < string_view{"abc"} );
|
||||
REQUIRE_FALSE( string_view{"abc"} < string_view{"abc"} );
|
||||
REQUIRE( string_view{"abd"} > string_view{"abc"} );
|
||||
REQUIRE( string_view{"abcd"} > string_view{"abc"} );
|
||||
REQUIRE_FALSE( string_view{"abc"} > string_view{"abd"} );
|
||||
REQUIRE_FALSE( string_view{"abc"} > string_view{"abcd"} );
|
||||
REQUIRE_FALSE( string_view{"abc"} > string_view{"abc"} );
|
||||
REQUIRE( string_view{"abc"} <= string_view{"abcd"} );
|
||||
REQUIRE( string_view{"abc"} <= string_view{"abc"} );
|
||||
REQUIRE( string_view{"abc"} <= string_view{"abd"} );
|
||||
REQUIRE( string_view{"abd"} >= string_view{"abc"} );
|
||||
REQUIRE( string_view{"abc"} >= string_view{"abc"} );
|
||||
REQUIRE( string_view{"abcd"} >= string_view{"abc"} );
|
||||
}
|
|
@ -0,0 +1,165 @@
|
|||
#include "oxenmq/batch.h"
|
||||
#include "common.h"
|
||||
#include <future>
|
||||
|
||||
TEST_CASE("tagged thread start functions", "[tagged][start]") {
|
||||
oxenmq::OxenMQ omq{get_logger(""), LogLevel::trace};
|
||||
|
||||
omq.set_general_threads(2);
|
||||
omq.set_batch_threads(2);
|
||||
auto t_abc = omq.add_tagged_thread("abc");
|
||||
std::atomic<bool> start_called = false;
|
||||
auto t_def = omq.add_tagged_thread("def", [&] { start_called = true; });
|
||||
|
||||
std::this_thread::sleep_for(20ms);
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE_FALSE( start_called );
|
||||
}
|
||||
|
||||
omq.start();
|
||||
wait_for([&] { return start_called.load(); });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( start_called );
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("tagged threads quit-before-start", "[tagged][quit]") {
|
||||
auto omq = std::make_unique<oxenmq::OxenMQ>(get_logger(""), LogLevel::trace);
|
||||
auto t_abc = omq->add_tagged_thread("abc");
|
||||
REQUIRE_NOTHROW(omq.reset());
|
||||
}
|
||||
|
||||
TEST_CASE("batch jobs to tagged threads", "[tagged][batch]") {
|
||||
oxenmq::OxenMQ omq{get_logger(""), LogLevel::trace};
|
||||
|
||||
omq.set_general_threads(2);
|
||||
omq.set_batch_threads(2);
|
||||
std::thread::id id_abc, id_def;
|
||||
auto t_abc = omq.add_tagged_thread("abc", [&] { id_abc = std::this_thread::get_id(); });
|
||||
auto t_def = omq.add_tagged_thread("def", [&] { id_def = std::this_thread::get_id(); });
|
||||
omq.start();
|
||||
|
||||
std::atomic<bool> done = false;
|
||||
std::thread::id id;
|
||||
omq.job([&] { id = std::this_thread::get_id(); done = true; });
|
||||
wait_for([&] { return done.load(); });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( id != id_abc );
|
||||
REQUIRE( id != id_def );
|
||||
}
|
||||
|
||||
done = false;
|
||||
omq.job([&] { id = std::this_thread::get_id(); done = true; }, t_abc);
|
||||
wait_for([&] { return done.load(); });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( id == id_abc );
|
||||
}
|
||||
|
||||
done = false;
|
||||
omq.job([&] { id = std::this_thread::get_id(); done = true; }, t_def);
|
||||
wait_for([&] { return done.load(); });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( id == id_def );
|
||||
}
|
||||
|
||||
std::atomic<bool> sleep = true;
|
||||
auto sleeper = [&] { for (int i = 0; sleep && i < 10; i++) { std::this_thread::sleep_for(25ms); } };
|
||||
omq.job(sleeper);
|
||||
omq.job(sleeper);
|
||||
// This one should stall:
|
||||
std::atomic<bool> bad = false;
|
||||
omq.job([&] { bad = true; });
|
||||
|
||||
std::this_thread::sleep_for(50ms);
|
||||
|
||||
done = false;
|
||||
omq.job([&] { id = std::this_thread::get_id(); done = true; }, t_abc);
|
||||
wait_for([&] { return done.load(); });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( done.load() );
|
||||
REQUIRE_FALSE( bad.load() );
|
||||
}
|
||||
|
||||
done = false;
|
||||
// We can queue up a bunch of jobs which should all happen in order, and all on the abc thread.
|
||||
std::vector<int> v;
|
||||
for (int i = 0; i < 100; i++) {
|
||||
omq.job([&] { if (std::this_thread::get_id() == id_abc) v.push_back(v.size()); }, t_abc);
|
||||
}
|
||||
omq.job([&] { done = true; }, t_abc);
|
||||
wait_for([&] { return done.load(); });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( done.load() );
|
||||
REQUIRE_FALSE( bad.load() );
|
||||
REQUIRE( v.size() == 100 );
|
||||
for (int i = 0; i < 100; i++)
|
||||
REQUIRE( v[i] == i );
|
||||
}
|
||||
sleep = false;
|
||||
wait_for([&] { return bad.load(); });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( bad.load() );
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("batch job completion on tagged threads", "[tagged][batch-completion]") {
|
||||
oxenmq::OxenMQ omq{get_logger(""), LogLevel::trace};
|
||||
|
||||
omq.set_general_threads(4);
|
||||
omq.set_batch_threads(4);
|
||||
std::thread::id id_abc;
|
||||
auto t_abc = omq.add_tagged_thread("abc", [&] { id_abc = std::this_thread::get_id(); });
|
||||
omq.start();
|
||||
|
||||
oxenmq::Batch<int> batch;
|
||||
for (int i = 1; i < 10; i++)
|
||||
batch.add_job([i, &id_abc]() { if (std::this_thread::get_id() == id_abc) return 0; return i; });
|
||||
|
||||
std::atomic<int> result_sum = -1;
|
||||
batch.completion([&](auto result) {
|
||||
int sum = 0;
|
||||
for (auto& r : result)
|
||||
sum += r.get();
|
||||
result_sum = std::this_thread::get_id() == id_abc ? sum : -sum;
|
||||
}, t_abc);
|
||||
omq.batch(std::move(batch));
|
||||
wait_for([&] { return result_sum.load() != -1; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( result_sum == 45 );
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
TEST_CASE("timer job completion on tagged threads", "[tagged][timer]") {
|
||||
oxenmq::OxenMQ omq{get_logger(""), LogLevel::trace};
|
||||
|
||||
omq.set_general_threads(4);
|
||||
omq.set_batch_threads(4);
|
||||
|
||||
std::thread::id id_abc;
|
||||
auto t_abc = omq.add_tagged_thread("abc", [&] { id_abc = std::this_thread::get_id(); });
|
||||
omq.start();
|
||||
|
||||
std::atomic<int> ticks = 0;
|
||||
std::atomic<int> abc_ticks = 0;
|
||||
omq.add_timer([&] { ticks++; }, 10ms);
|
||||
omq.add_timer([&] { if (std::this_thread::get_id() == id_abc) abc_ticks++; }, 10ms, true, t_abc);
|
||||
|
||||
wait_for([&] { return ticks.load() > 2 && abc_ticks > 2; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( ticks.load() > 2 );
|
||||
REQUIRE( abc_ticks.load() > 2 );
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -0,0 +1,132 @@
|
|||
#include "oxenmq/oxenmq.h"
|
||||
#include "common.h"
|
||||
#include <chrono>
|
||||
#include <future>
|
||||
|
||||
TEST_CASE("timer test", "[timer][basic]") {
|
||||
oxenmq::OxenMQ omq{get_logger(""), LogLevel::trace};
|
||||
|
||||
omq.set_general_threads(1);
|
||||
omq.set_batch_threads(1);
|
||||
|
||||
std::atomic<int> ticks = 0;
|
||||
auto timer = omq.add_timer([&] { ticks++; }, 5ms);
|
||||
omq.start();
|
||||
auto start = std::chrono::steady_clock::now();
|
||||
wait_for([&] { return ticks.load() > 3; });
|
||||
{
|
||||
auto elapsed_ms = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - start).count();
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( ticks.load() > 3 );
|
||||
REQUIRE( elapsed_ms < 50 * TIME_DILATION );
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("timer squelch", "[timer][squelch]") {
|
||||
oxenmq::OxenMQ omq{get_logger(""), LogLevel::trace};
|
||||
|
||||
omq.set_general_threads(3);
|
||||
omq.set_batch_threads(3);
|
||||
|
||||
std::atomic<bool> first = true;
|
||||
std::atomic<bool> done = false;
|
||||
std::atomic<int> ticks = 0;
|
||||
|
||||
// Set up a timer with squelch on; the job shouldn't get rescheduled until the first call
|
||||
// finishes, by which point we set `done` and so should get exactly 1 tick.
|
||||
auto timer = omq.add_timer([&] {
|
||||
if (first.exchange(false)) {
|
||||
std::this_thread::sleep_for(30ms * TIME_DILATION);
|
||||
ticks++;
|
||||
done = true;
|
||||
} else if (!done) {
|
||||
ticks++;
|
||||
}
|
||||
}, 5ms * TIME_DILATION, true /* squelch */);
|
||||
omq.start();
|
||||
|
||||
wait_for([&] { return done.load(); });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( done.load() );
|
||||
REQUIRE( ticks.load() == 1 );
|
||||
}
|
||||
|
||||
// Start another timer with squelch *off*; the subsequent jobs should get scheduled even while
|
||||
// the first one blocks
|
||||
std::atomic<bool> first2 = true;
|
||||
std::atomic<bool> done2 = false;
|
||||
std::atomic<int> ticks2 = 0;
|
||||
auto timer2 = omq.add_timer([&] {
|
||||
if (first2.exchange(false)) {
|
||||
std::this_thread::sleep_for(40ms * TIME_DILATION);
|
||||
done2 = true;
|
||||
} else if (!done2) {
|
||||
ticks2++;
|
||||
}
|
||||
}, 5ms, false /* squelch */);
|
||||
|
||||
wait_for([&] { return done2.load(); });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( ticks2.load() > 2 );
|
||||
REQUIRE( done2.load() );
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("timer cancel", "[timer][cancel]") {
|
||||
oxenmq::OxenMQ omq{get_logger(""), LogLevel::trace};
|
||||
|
||||
omq.set_general_threads(1);
|
||||
omq.set_batch_threads(1);
|
||||
|
||||
std::atomic<int> ticks = 0;
|
||||
|
||||
// We set up *and cancel* this timer before omq starts, so it should never fire
|
||||
auto notimer = omq.add_timer([&] { ticks += 1000; }, 5ms * TIME_DILATION);
|
||||
omq.cancel_timer(notimer);
|
||||
|
||||
TimerID timer = omq.add_timer([&] {
|
||||
if (++ticks == 3)
|
||||
omq.cancel_timer(timer);
|
||||
}, 5ms * TIME_DILATION);
|
||||
|
||||
omq.start();
|
||||
|
||||
wait_for([&] { return ticks.load() >= 3; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( ticks.load() == 3 );
|
||||
}
|
||||
|
||||
// Test the alternative taking an lvalue reference instead of returning by value (see oxenmq.h
|
||||
// for why this is sometimes needed).
|
||||
std::atomic<int> ticks3 = 0;
|
||||
std::weak_ptr<TimerID> w_timer3;
|
||||
{
|
||||
auto timer3 = std::make_shared<TimerID>();
|
||||
auto& t3ref = *timer3; // Get this reference *before* we move the shared pointer into the lambda
|
||||
omq.add_timer(t3ref, [&ticks3, &omq, timer3=std::move(timer3)] {
|
||||
if (ticks3 == 0)
|
||||
ticks3++;
|
||||
else if (ticks3 > 1) {
|
||||
omq.cancel_timer(*timer3);
|
||||
ticks3++;
|
||||
}
|
||||
}, 1ms);
|
||||
}
|
||||
|
||||
wait_for([&] { return ticks3.load() >= 1; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( ticks3.load() == 1 );
|
||||
}
|
||||
ticks3++;
|
||||
wait_for([&] { return ticks3.load() >= 3 && w_timer3.expired(); });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( ticks3.load() == 3 );
|
||||
REQUIRE( w_timer3.expired() );
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue