mirror of
https://github.com/oxen-io/oxen-mq.git
synced 2023-12-13 21:00:31 +01:00
Compare commits
242 commits
Author | SHA1 | Date | |
---|---|---|---|
a27961d787 | |||
5878473f67 | |||
68b3420bad | |||
dc7fb35493 | |||
caadd35052 | |||
fd58ab9cac | |||
8f97add30f | |||
e1b66ced48 | |||
4f3ee28784 | |||
bd3e2cdfb0 | |||
b8bb10eac5 | |||
ff0e515c51 | |||
2e308d4f43 | |||
445f214840 | |||
358005df06 | |||
85437d167b | |||
b26fe8cb04 | |||
df19d1dd94 | |||
25f714371b | |||
0858dd278b | |||
057685b7c0 | |||
3a3ffa7d23 | |||
edcde9246a | |||
c854046684 | |||
c91e56cf2d | |||
61b7505304 | |||
b0c3bd4ee9 | |||
fd95919704 | |||
4671af3ca0 | |||
c4b7aa9b23 | |||
115c5550ca | |||
ace6ea9d8e | |||
62a803f371 | |||
d86ecb3a70 | |||
45791d3a19 | |||
b8e4eb148f | |||
fa6de369b2 | |||
371606cde0 | |||
3a51713396 | |||
5c7f6504d2 | |||
5a3c12e721 | |||
f0c2222d6e | |||
320a85ac0c | |||
7fca36b3a9 | |||
bbdf4af98f | |||
77c4840273 | |||
d7f5efebc1 | |||
a0a54ed461 | |||
045df9cb9b | |||
3d178ce3ea | |||
fe8a1f4306 | |||
3b634329ac | |||
f88691b7e9 | |||
9c022b29de | |||
4d68868482 | |||
430951bf3c | |||
03749c87f0 | |||
85d35fa505 | |||
e180187746 | |||
e382373f2e | |||
375cfab4ce | |||
f04bd72a4c | |||
31f64821f8 | |||
a53e1f1786 | |||
39b6d89037 | |||
f0bb2c3d3f | |||
09f3de2232 | |||
519a107542 | |||
23c2d537a3 | |||
6a386b7d4a | |||
5e9b8c0948 | |||
560d38d069 | |||
504d0d10ea | |||
7695e770a7 | |||
0d0ed8efa9 | |||
02a542b9c6 | |||
9a8adb5bfd | |||
ee1d69f333 | |||
24dd7a3854 | |||
cd56ad8e08 | |||
6100802f82 | |||
7cb7c2fd6d | |||
5a41e84378 | |||
377932607c | |||
cdd21a9e81 | |||
977bced84e | |||
9e3469d968 | |||
f12a48a195 | |||
e1b1a84c4b | |||
2ac4379fa6 | |||
ae884d2f13 | |||
45f358ab5f | |||
c6ae1faefa | |||
719d33f1cc | |||
f553085558 | |||
bae71ec6a8 | |||
29cd543af9 | |||
917c7d64c5 | |||
4a24ac9baa | |||
e1d21d3faf | |||
1d2246cda8 | |||
3bb32a81ff | |||
9e0d2e24f6 | |||
4a6bb3f702 | |||
ad04c53c0e | |||
7ba81a7d50 | |||
45db87f712 | |||
a0642a894e | |||
5dd7c12219 | |||
dccbd1e8cd | |||
780246858f | |||
0287f7834e | |||
cdc6a9709c | |||
3991f50547 | |||
26745299ed | |||
4ef1060e3f | |||
5ccacafdb1 | |||
6d20a3614a | |||
39dce56e14 | |||
ac58e5b574 | |||
99a3f1d840 | |||
dc40ebd428 | |||
e3e79e1fb7 | |||
f9ef827075 | |||
506bd65b05 | |||
86247bc5c7 | |||
396f591fae | |||
b49a94fb83 | |||
0738695eb9 | |||
2ae6b96016 | |||
bd9313bf19 | |||
1959f8747d | |||
90701e5d62 | |||
178bd4f674 | |||
b1543513bb | |||
253f1ee66e | |||
d889f308ae | |||
768a639dea | |||
ec0d44e143 | |||
ea484729c7 | |||
7049d3cb5a | |||
8ed529200b | |||
318781a6d4 | |||
f37e619d7b | |||
0ac1d48bc8 | |||
0938e1fc53 | |||
0c9eeeea43 | |||
9467c4682c | |||
8c28c52d41 | |||
faeeaa86d4 | |||
8d3ed4606f | |||
30faadf01a | |||
d8d1d8677c | |||
e5cf174b83 | |||
af189a8d72 | |||
d2f852c217 | |||
ee080e0550 | |||
7cd58e4677 | |||
9c54264321 | |||
932bbb33d7 | |||
07b31bd8a1 | |||
8a56b18cc6 | |||
1d56c3d44c | |||
66176d44d7 | |||
4e89dce5b6 | |||
0493f615b9 | |||
d0a73e5e68 | |||
278909db77 | |||
3edcab9344 | |||
ae8dd27cdd | |||
8caab97355 | |||
44b91534c2 | |||
29380922bf | |||
6356421488 | |||
d28e39ffeb | |||
9a103f1bf6 | |||
211d5211b0 | |||
9a283a148c | |||
65aa5940be | |||
ec9c58ea34 | |||
af59d58797 | |||
e072e68d84 | |||
e5a8d09127 | |||
a24e87d4d0 | |||
9ac47ec419 | |||
d0a07f7c08 | |||
86f5b463e9 | |||
68c1899cda | |||
1479a030d7 | |||
f296b82ba5 | |||
1f60abf50e | |||
de395af872 | |||
e970f14e55 | |||
1e38f3b1d1 | |||
c9cf833861 | |||
7b42537801 | |||
8984dfc4ea | |||
be4cbc6641 | |||
46d007e1ac | |||
59a41943d4 | |||
719a9b0b58 | |||
22559548fc | |||
7b552007df | |||
b905a8a4ff | |||
08a11bb9ba | |||
3a0508fdce | |||
f4f1506df0 | |||
a812abd422 | |||
730633bbae | |||
99bbf8dea9 | |||
1a65d7f5e5 | |||
e7cd2dedc2 | |||
6ddf033674 | |||
0ebfef2164 | |||
fc1ea66599 | |||
238dfa7f78 | |||
911c66140f | |||
2966427cc0 | |||
34bbaaf612 | |||
b2518b8eb3 | |||
712662f144 | |||
131bc95f65 | |||
3aa63c059d | |||
7de36da483 | |||
b081cf9331 | |||
84bd5544cc | |||
3b86eb1341 | |||
fb3bf9bd1f | |||
95540ec7d5 | |||
af42875e97 | |||
bc49b5e9a0 | |||
e3a86aaf71 | |||
b9e9f10f29 | |||
d4ffebebbd | |||
6ba70923b9 | |||
4c470f3e33 | |||
bd196d08b8 | |||
b66f653708 | |||
716d73d196 | |||
8e1b2dffa5 | |||
2493e2abd4 | |||
bcca8dd34e |
117
.drone.jsonnet
Normal file
117
.drone.jsonnet
Normal file
|
@ -0,0 +1,117 @@
|
|||
local docker_base = 'registry.oxen.rocks/lokinet-ci-';
|
||||
|
||||
local default_deps_nocxx = ['libsodium-dev', 'libzmq3-dev', 'liboxenc-dev'];
|
||||
|
||||
local submodule_commands = ['git fetch --tags', 'git submodule update --init --recursive --depth=1'];
|
||||
|
||||
local submodules = {
|
||||
name: 'submodules',
|
||||
image: 'drone/git',
|
||||
commands: submodule_commands,
|
||||
};
|
||||
|
||||
local apt_get_quiet = 'apt-get -o=Dpkg::Use-Pty=0 -q ';
|
||||
|
||||
local debian_pipeline(name,
|
||||
image,
|
||||
arch='amd64',
|
||||
deps=['g++'] + default_deps_nocxx,
|
||||
cmake_extra='',
|
||||
build_type='Release',
|
||||
extra_cmds=[],
|
||||
distro='$$(lsb_release -sc)',
|
||||
allow_fail=false) = {
|
||||
kind: 'pipeline',
|
||||
type: 'docker',
|
||||
name: name,
|
||||
platform: { arch: arch },
|
||||
environment: { CLICOLOR_FORCE: '1' }, // Lets color through ninja (1.9+)
|
||||
steps: [
|
||||
submodules,
|
||||
{
|
||||
name: 'build',
|
||||
image: image,
|
||||
pull: 'always',
|
||||
[if allow_fail then 'failure']: 'ignore',
|
||||
commands: [
|
||||
'echo "Building on ${DRONE_STAGE_MACHINE}"',
|
||||
'echo "man-db man-db/auto-update boolean false" | debconf-set-selections',
|
||||
apt_get_quiet + 'update',
|
||||
apt_get_quiet + 'install -y eatmydata',
|
||||
'eatmydata ' + apt_get_quiet + ' install --no-install-recommends -y lsb-release',
|
||||
'cp contrib/deb.oxen.io.gpg /etc/apt/trusted.gpg.d',
|
||||
'echo deb http://deb.oxen.io ' + distro + ' main >/etc/apt/sources.list.d/oxen.list',
|
||||
'eatmydata ' + apt_get_quiet + ' update',
|
||||
'eatmydata ' + apt_get_quiet + 'dist-upgrade -y',
|
||||
'eatmydata ' + apt_get_quiet + 'install -y cmake git ninja-build pkg-config ccache ' + std.join(' ', deps),
|
||||
'mkdir build',
|
||||
'cd build',
|
||||
'cmake .. -G Ninja -DCMAKE_CXX_FLAGS=-fdiagnostics-color=always -DCMAKE_BUILD_TYPE=' + build_type + ' -DCMAKE_CXX_COMPILER_LAUNCHER=ccache ' + cmake_extra,
|
||||
'ninja -v',
|
||||
'./tests/tests --use-colour yes',
|
||||
] + extra_cmds,
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
local clang(version) = debian_pipeline(
|
||||
'Debian sid/clang-' + version + ' (amd64)',
|
||||
docker_base + 'debian-sid-clang',
|
||||
distro='sid',
|
||||
deps=['clang-' + version] + default_deps_nocxx,
|
||||
cmake_extra='-DCMAKE_C_COMPILER=clang-' + version + ' -DCMAKE_CXX_COMPILER=clang++-' + version + ' '
|
||||
);
|
||||
|
||||
local full_llvm(version) = debian_pipeline(
|
||||
'Debian sid/llvm-' + version + ' (amd64)',
|
||||
docker_base + 'debian-sid-clang',
|
||||
distro='sid',
|
||||
deps=['clang-' + version, 'lld-' + version, 'libc++-' + version + '-dev', 'libc++abi-' + version + '-dev']
|
||||
+ default_deps_nocxx,
|
||||
cmake_extra='-DCMAKE_C_COMPILER=clang-' + version +
|
||||
' -DCMAKE_CXX_COMPILER=clang++-' + version +
|
||||
' -DCMAKE_CXX_FLAGS=-stdlib=libc++ ' +
|
||||
std.join(' ', [
|
||||
'-DCMAKE_' + type + '_LINKER_FLAGS=-fuse-ld=lld-' + version
|
||||
for type in ['EXE', 'MODULE', 'SHARED', 'STATIC']
|
||||
])
|
||||
);
|
||||
|
||||
|
||||
[
|
||||
debian_pipeline('Debian sid (amd64)', docker_base + 'debian-sid', distro='sid'),
|
||||
debian_pipeline('Debian sid/Debug (amd64)', docker_base + 'debian-sid', build_type='Debug', distro='sid'),
|
||||
clang(16),
|
||||
full_llvm(16),
|
||||
debian_pipeline('Debian buster (amd64)', docker_base + 'debian-buster'),
|
||||
debian_pipeline('Debian stable (i386)', docker_base + 'debian-stable/i386'),
|
||||
debian_pipeline('Debian sid (ARM64)', docker_base + 'debian-sid', arch='arm64', distro='sid'),
|
||||
debian_pipeline('Debian stable (armhf)', docker_base + 'debian-stable/arm32v7', arch='arm64'),
|
||||
debian_pipeline('Debian buster (armhf)', docker_base + 'debian-buster/arm32v7', arch='arm64'),
|
||||
debian_pipeline('Ubuntu focal (amd64)', docker_base + 'ubuntu-focal'),
|
||||
debian_pipeline('Ubuntu bionic (amd64)',
|
||||
docker_base + 'ubuntu-bionic',
|
||||
deps=default_deps_nocxx,
|
||||
cmake_extra='-DCMAKE_C_COMPILER=gcc-8 -DCMAKE_CXX_COMPILER=g++-8'),
|
||||
{
|
||||
kind: 'pipeline',
|
||||
type: 'exec',
|
||||
name: 'macOS (w/macports)',
|
||||
platform: { os: 'darwin', arch: 'amd64' },
|
||||
environment: { CLICOLOR_FORCE: '1' }, // Lets color through ninja (1.9+)
|
||||
steps: [
|
||||
{ name: 'submodules', commands: submodule_commands },
|
||||
{
|
||||
name: 'build',
|
||||
commands: [
|
||||
'mkdir build',
|
||||
'cd build',
|
||||
'ulimit -n 1024', // Because macOS has a stupid tiny default ulimit
|
||||
'cmake .. -G Ninja -DCMAKE_CXX_FLAGS=-fcolor-diagnostics -DCMAKE_BUILD_TYPE=Release -DCMAKE_CXX_COMPILER_LAUNCHER=ccache',
|
||||
'ninja -v',
|
||||
'./tests/tests --use-colour yes',
|
||||
],
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
6
.gitmodules
vendored
6
.gitmodules
vendored
|
@ -1,9 +1,9 @@
|
|||
[submodule "mapbox-variant"]
|
||||
path = mapbox-variant
|
||||
url = https://github.com/mapbox/variant.git
|
||||
[submodule "cppzmq"]
|
||||
path = cppzmq
|
||||
url = https://github.com/zeromq/cppzmq.git
|
||||
[submodule "Catch2"]
|
||||
path = tests/Catch2
|
||||
url = https://github.com/catchorg/Catch2.git
|
||||
[submodule "oxen-encoding"]
|
||||
path = oxen-encoding
|
||||
url = https://github.com/oxen-io/oxen-encoding.git
|
||||
|
|
252
CMakeLists.txt
252
CMakeLists.txt
|
@ -1,88 +1,163 @@
|
|||
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||
|
||||
find_program(CCACHE_PROGRAM ccache)
|
||||
if(CCACHE_PROGRAM)
|
||||
foreach(lang C CXX)
|
||||
if(NOT DEFINED CMAKE_${lang}_COMPILER_LAUNCHER AND NOT CMAKE_${lang}_COMPILER MATCHES ".*/ccache")
|
||||
message(STATUS "Enabling ccache for ${lang}")
|
||||
set(CMAKE_${lang}_COMPILER_LAUNCHER ${CCACHE_PROGRAM} CACHE STRING "")
|
||||
endif()
|
||||
endforeach()
|
||||
endif()
|
||||
|
||||
cmake_minimum_required(VERSION 3.7)
|
||||
|
||||
project(liblokimq CXX)
|
||||
# Has to be set before `project()`, and ignored on non-macos:
|
||||
set(CMAKE_OSX_DEPLOYMENT_TARGET 10.12 CACHE STRING "macOS deployment target (Apple clang only)")
|
||||
|
||||
project(liboxenmq
|
||||
VERSION 1.2.16
|
||||
LANGUAGES CXX C)
|
||||
|
||||
include(GNUInstallDirs)
|
||||
|
||||
set(LOKIMQ_VERSION_MAJOR 1)
|
||||
set(LOKIMQ_VERSION_MINOR 0)
|
||||
set(LOKIMQ_VERSION_PATCH 3)
|
||||
set(LOKIMQ_VERSION "${LOKIMQ_VERSION_MAJOR}.${LOKIMQ_VERSION_MINOR}.${LOKIMQ_VERSION_PATCH}")
|
||||
message(STATUS "lokimq v${LOKIMQ_VERSION}")
|
||||
message(STATUS "oxenmq v${PROJECT_VERSION}")
|
||||
|
||||
set(LOKIMQ_LIBVERSION 0)
|
||||
set(OXENMQ_LIBVERSION 0)
|
||||
|
||||
|
||||
if(CMAKE_PROJECT_NAME STREQUAL PROJECT_NAME)
|
||||
set(oxenmq_IS_TOPLEVEL_PROJECT TRUE)
|
||||
else()
|
||||
set(oxenmq_IS_TOPLEVEL_PROJECT FALSE)
|
||||
endif()
|
||||
|
||||
|
||||
option(BUILD_SHARED_LIBS "Build shared libraries instead of static ones" ON)
|
||||
set(oxenmq_INSTALL_DEFAULT OFF)
|
||||
if(BUILD_SHARED_LIBS OR oxenmq_IS_TOPLEVEL_PROJECT)
|
||||
set(oxenmq_INSTALL_DEFAULT ON)
|
||||
endif()
|
||||
set(oxenmq_EPOLL_DEFAULT OFF)
|
||||
if(CMAKE_SYSTEM_NAME STREQUAL "Linux" AND NOT CMAKE_CROSSCOMPILING)
|
||||
set(oxenmq_EPOLL_DEFAULT ON)
|
||||
endif()
|
||||
|
||||
option(OXENMQ_BUILD_TESTS "Building and perform oxenmq tests" ${oxenmq_IS_TOPLEVEL_PROJECT})
|
||||
option(OXENMQ_INSTALL "Add oxenmq libraries and headers to cmake install target; defaults to ON if BUILD_SHARED_LIBS is enabled or we are the top-level project; OFF for a static subdirectory build" ${oxenmq_INSTALL_DEFAULT})
|
||||
option(OXENMQ_INSTALL_CPPZMQ "Install cppzmq header with oxenmq/ headers (requires OXENMQ_INSTALL)" ON)
|
||||
option(OXENMQ_USE_EPOLL "Use epoll for socket polling (requires Linux)" ${oxenmq_EPOLL_DEFAULT})
|
||||
|
||||
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
|
||||
|
||||
configure_file(lokimq/version.h.in lokimq/version.h @ONLY)
|
||||
configure_file(liblokimq.pc.in liblokimq.pc @ONLY)
|
||||
configure_file(oxenmq/version.h.in oxenmq/version.h @ONLY)
|
||||
configure_file(liboxenmq.pc.in liboxenmq.pc @ONLY)
|
||||
|
||||
add_library(lokimq
|
||||
lokimq/auth.cpp
|
||||
lokimq/batch.cpp
|
||||
lokimq/bt_serialize.cpp
|
||||
lokimq/connections.cpp
|
||||
lokimq/jobs.cpp
|
||||
lokimq/lokimq.cpp
|
||||
lokimq/proxy.cpp
|
||||
lokimq/worker.cpp
|
||||
|
||||
add_library(oxenmq
|
||||
oxenmq/address.cpp
|
||||
oxenmq/auth.cpp
|
||||
oxenmq/connections.cpp
|
||||
oxenmq/jobs.cpp
|
||||
oxenmq/oxenmq.cpp
|
||||
oxenmq/proxy.cpp
|
||||
oxenmq/worker.cpp
|
||||
)
|
||||
set_target_properties(lokimq PROPERTIES SOVERSION ${LOKIMQ_LIBVERSION})
|
||||
set_target_properties(oxenmq PROPERTIES SOVERSION ${OXENMQ_LIBVERSION})
|
||||
if(OXENMQ_USE_EPOLL)
|
||||
target_compile_definitions(oxenmq PRIVATE OXENMQ_USE_EPOLL)
|
||||
endif()
|
||||
|
||||
set(THREADS_PREFER_PTHREAD_FLAG ON)
|
||||
find_package(Threads REQUIRED)
|
||||
target_link_libraries(lokimq PRIVATE Threads::Threads)
|
||||
target_link_libraries(oxenmq PRIVATE Threads::Threads)
|
||||
|
||||
|
||||
if(TARGET oxenc::oxenc)
|
||||
add_library(_oxenmq_external_oxenc INTERFACE IMPORTED)
|
||||
target_link_libraries(_oxenmq_external_oxenc INTERFACE oxenc::oxenc)
|
||||
target_link_libraries(oxenmq PUBLIC _oxenmq_external_oxenc)
|
||||
message(STATUS "using pre-existing oxenc::oxenc target")
|
||||
elseif(BUILD_SHARED_LIBS)
|
||||
include(FindPkgConfig)
|
||||
pkg_check_modules(oxenc liboxenc IMPORTED_TARGET)
|
||||
|
||||
if(oxenc_FOUND)
|
||||
# Work around cmake bug 22180 (PkgConfig::tgt not set if no flags needed)
|
||||
if(TARGET PkgConfig::oxenc OR CMAKE_VERSION VERSION_GREATER_EQUAL "3.21")
|
||||
target_link_libraries(oxenmq PUBLIC PkgConfig::oxenc)
|
||||
endif()
|
||||
else()
|
||||
add_subdirectory(oxen-encoding)
|
||||
target_link_libraries(oxenmq PUBLIC oxenc::oxenc)
|
||||
endif()
|
||||
else()
|
||||
add_subdirectory(oxen-encoding)
|
||||
target_link_libraries(oxenmq PUBLIC oxenc::oxenc)
|
||||
endif()
|
||||
|
||||
# libzmq is nearly impossible to link statically from a system-installed static library: it depends
|
||||
# on a ton of other libraries, some of which are not all statically available. If the caller wants
|
||||
# to mess with this, so be it: they can set up a libzmq target and we'll use it. Otherwise if they
|
||||
# asked us to do things statically, don't even try to find a system lib and just build it.
|
||||
set(lokimq_build_static_libzmq OFF)
|
||||
set(oxenmq_build_static_libzmq OFF)
|
||||
if(TARGET libzmq)
|
||||
target_link_libraries(lokimq PUBLIC libzmq)
|
||||
target_link_libraries(oxenmq PUBLIC libzmq)
|
||||
elseif(BUILD_SHARED_LIBS)
|
||||
include(FindPkgConfig)
|
||||
pkg_check_modules(libzmq libzmq>=4.3 IMPORTED_TARGET)
|
||||
|
||||
if(libzmq_FOUND)
|
||||
target_link_libraries(lokimq PUBLIC PkgConfig::libzmq)
|
||||
# Debian sid includes a -isystem in the mit-krb package that, starting with pkg-config 0.29.2,
|
||||
# breaks cmake's pkgconfig module because it stupidly thinks "-isystem" is a path, so if we find
|
||||
# -isystem in the include dirs then hack it out.
|
||||
get_property(zmq_inc TARGET PkgConfig::libzmq PROPERTY INTERFACE_INCLUDE_DIRECTORIES)
|
||||
list(FIND zmq_inc "-isystem" broken_isystem)
|
||||
if(NOT broken_isystem EQUAL -1)
|
||||
list(REMOVE_AT zmq_inc ${broken_isystem})
|
||||
set_property(TARGET PkgConfig::libzmq PROPERTY INTERFACE_INCLUDE_DIRECTORIES ${zmq_inc})
|
||||
endif()
|
||||
|
||||
target_link_libraries(oxenmq PUBLIC PkgConfig::libzmq)
|
||||
else()
|
||||
set(lokimq_build_static_libzmq ON)
|
||||
set(oxenmq_build_static_libzmq ON)
|
||||
endif()
|
||||
else()
|
||||
set(lokimq_build_static_libzmq ON)
|
||||
set(oxenmq_build_static_libzmq ON)
|
||||
endif()
|
||||
|
||||
if(lokimq_build_static_libzmq)
|
||||
message(STATUS "libzmq >= 4.3 not found or static build requested, building bundled 4.3.2")
|
||||
if(oxenmq_build_static_libzmq)
|
||||
message(STATUS "libzmq >= 4.3 not found or static build requested, building bundled version")
|
||||
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/local-libzmq")
|
||||
include(LocalLibzmq)
|
||||
target_link_libraries(lokimq PUBLIC libzmq_vendor)
|
||||
target_link_libraries(oxenmq PUBLIC libzmq_vendor)
|
||||
endif()
|
||||
|
||||
target_include_directories(lokimq
|
||||
target_include_directories(oxenmq
|
||||
PUBLIC
|
||||
$<INSTALL_INTERFACE:>
|
||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}>
|
||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/cppzmq>
|
||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/mapbox-variant/include>
|
||||
)
|
||||
|
||||
target_compile_options(lokimq PRIVATE -Wall -Wextra -Werror)
|
||||
set_target_properties(lokimq PROPERTIES
|
||||
CXX_STANDARD 14
|
||||
target_compile_options(oxenmq PRIVATE -Wall -Wextra)
|
||||
|
||||
option(WARNINGS_AS_ERRORS "treat all warnings as errors" ON)
|
||||
if(WARNINGS_AS_ERRORS)
|
||||
target_compile_options(oxenmq PRIVATE -Werror)
|
||||
endif()
|
||||
|
||||
set_target_properties(oxenmq PROPERTIES
|
||||
POSITION_INDEPENDENT_CODE ON
|
||||
CXX_STANDARD 17
|
||||
CXX_STANDARD_REQUIRED ON
|
||||
CXX_EXTENSIONS OFF
|
||||
POSITION_INDEPENDENT_CODE ON
|
||||
)
|
||||
|
||||
function(link_dep_libs target linktype libdirs)
|
||||
foreach(lib ${ARGN})
|
||||
find_library(link_lib-${lib} NAMES ${lib} PATHS ${libdirs})
|
||||
message(STATUS "FIND ${lib} FOUND ${link_lib-${lib}}")
|
||||
if(link_lib-${lib})
|
||||
target_link_libraries(${target} ${linktype} ${link_lib-${lib}})
|
||||
endif()
|
||||
|
@ -92,87 +167,72 @@ endfunction()
|
|||
# If the caller has already set up a sodium target then we will just link to it, otherwise we go
|
||||
# looking for it.
|
||||
if(TARGET sodium)
|
||||
target_link_libraries(lokimq PRIVATE sodium)
|
||||
if(lokimq_build_static_libzmq)
|
||||
target_link_libraries(oxenmq PUBLIC sodium)
|
||||
if(oxenmq_build_static_libzmq)
|
||||
target_link_libraries(libzmq_vendor INTERFACE sodium)
|
||||
endif()
|
||||
else()
|
||||
include(FindPkgConfig)
|
||||
pkg_check_modules(sodium REQUIRED libsodium IMPORTED_TARGET)
|
||||
|
||||
if(BUILD_SHARED_LIBS)
|
||||
target_link_libraries(lokimq PRIVATE PkgConfig::sodium)
|
||||
if(lokimq_build_static_libzmq)
|
||||
target_link_libraries(oxenmq PUBLIC PkgConfig::sodium)
|
||||
if(oxenmq_build_static_libzmq)
|
||||
target_link_libraries(libzmq_vendor INTERFACE PkgConfig::sodium)
|
||||
endif()
|
||||
else()
|
||||
link_dep_libs(lokimq PRIVATE "${sodium_STATIC_LIBRARY_DIRS}" ${sodium_STATIC_LIBRARIES})
|
||||
target_include_directories(lokimq PRIVATE ${sodium_STATIC_INCLUDE_DIRS})
|
||||
if(lokimq_build_static_libzmq)
|
||||
link_dep_libs(oxenmq PUBLIC "${sodium_STATIC_LIBRARY_DIRS}" ${sodium_STATIC_LIBRARIES})
|
||||
target_include_directories(oxenmq PUBLIC ${sodium_STATIC_INCLUDE_DIRS})
|
||||
if(oxenmq_build_static_libzmq)
|
||||
link_dep_libs(libzmq_vendor INTERFACE "${sodium_STATIC_LIBRARY_DIRS}" ${sodium_STATIC_LIBRARIES})
|
||||
target_link_libraries(libzmq_vendor INTERFACE ${sodium_STATIC_INCLUDE_DIRS})
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
|
||||
add_library(lokimq::lokimq ALIAS lokimq)
|
||||
add_library(oxenmq::oxenmq ALIAS oxenmq)
|
||||
|
||||
export(
|
||||
TARGETS lokimq
|
||||
NAMESPACE lokimq::
|
||||
FILE lokimqTargets.cmake
|
||||
)
|
||||
install(
|
||||
TARGETS lokimq
|
||||
EXPORT lokimqConfig
|
||||
DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
||||
TARGETS oxenmq
|
||||
NAMESPACE oxenmq::
|
||||
FILE oxenmqTargets.cmake
|
||||
)
|
||||
|
||||
install(
|
||||
FILES lokimq/auth.h
|
||||
lokimq/batch.h
|
||||
lokimq/bt_serialize.h
|
||||
lokimq/connections.h
|
||||
lokimq/hex.h
|
||||
lokimq/lokimq.h
|
||||
lokimq/message.h
|
||||
lokimq/string_view.h
|
||||
${CMAKE_CURRENT_BINARY_DIR}/lokimq/version.h
|
||||
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/lokimq
|
||||
)
|
||||
option(LOKIMQ_INSTALL_MAPBOX_VARIANT "Install mapbox-variant headers with lokimq/ headers" ON)
|
||||
if(LOKIMQ_INSTALL_MAPBOX_VARIANT)
|
||||
install(
|
||||
FILES mapbox-variant/include/mapbox/variant.hpp
|
||||
mapbox-variant/include/mapbox/variant_cast.hpp
|
||||
mapbox-variant/include/mapbox/variant_io.hpp
|
||||
mapbox-variant/include/mapbox/variant_visitor.hpp
|
||||
mapbox-variant/include/mapbox/recursive_wrapper.hpp
|
||||
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/lokimq/mapbox
|
||||
)
|
||||
if(OXENMQ_INSTALL)
|
||||
install(
|
||||
TARGETS oxenmq
|
||||
EXPORT oxenmqConfig
|
||||
DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
||||
)
|
||||
|
||||
install(
|
||||
FILES oxenmq/address.h
|
||||
oxenmq/auth.h
|
||||
oxenmq/batch.h
|
||||
oxenmq/connections.h
|
||||
oxenmq/fmt.h
|
||||
oxenmq/message.h
|
||||
oxenmq/oxenmq.h
|
||||
oxenmq/pubsub.h
|
||||
${CMAKE_CURRENT_BINARY_DIR}/oxenmq/version.h
|
||||
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/oxenmq
|
||||
)
|
||||
|
||||
if(OXENMQ_INSTALL_CPPZMQ)
|
||||
install(
|
||||
FILES cppzmq/zmq.hpp
|
||||
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/oxenmq
|
||||
)
|
||||
endif()
|
||||
|
||||
|
||||
install(
|
||||
FILES ${CMAKE_CURRENT_BINARY_DIR}/liboxenmq.pc
|
||||
DESTINATION ${CMAKE_INSTALL_LIBDIR}/pkgconfig
|
||||
)
|
||||
|
||||
endif()
|
||||
|
||||
option(LOKIMQ_INSTALL_CPPZMQ "Install cppzmq header with lokimq/ headers" ON)
|
||||
if(LOKIMQ_INSTALL_CPPZMQ)
|
||||
install(
|
||||
FILES cppzmq/zmq.hpp
|
||||
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/lokimq
|
||||
)
|
||||
endif()
|
||||
|
||||
|
||||
install(
|
||||
FILES ${CMAKE_CURRENT_BINARY_DIR}/liblokimq.pc
|
||||
DESTINATION ${CMAKE_INSTALL_LIBDIR}/pkgconfig
|
||||
)
|
||||
|
||||
if(CMAKE_PROJECT_NAME STREQUAL PROJECT_NAME)
|
||||
set(lokimq_IS_TOPLEVEL_PROJECT TRUE)
|
||||
else()
|
||||
set(lokimq_IS_TOPLEVEL_PROJECT FALSE)
|
||||
endif()
|
||||
|
||||
option(LOKIMQ_BUILD_TESTS "Building and perform lokimq tests" ${lokimq_IS_TOPLEVEL_PROJECT})
|
||||
if(LOKIMQ_BUILD_TESTS)
|
||||
if(OXENMQ_BUILD_TESTS)
|
||||
add_subdirectory(tests)
|
||||
endif()
|
||||
|
||||
|
|
83
README.md
83
README.md
|
@ -1,14 +1,15 @@
|
|||
# LokiMQ - zeromq-based message passing for Loki projects
|
||||
# OxenMQ - high-level zeromq-based message passing for network-based projects
|
||||
|
||||
This C++14 library contains an abstraction layer around ZeroMQ to support integration with Loki
|
||||
authentication, RPC, and message passing. It is designed to be usable as the underlying
|
||||
communication mechanism of SN-to-SN communication ("quorumnet"), the RPC interface used by wallets
|
||||
and local daemon commands, communication channels between lokid and auxiliary services (storage
|
||||
server, lokinet), and also provides a local multithreaded job scheduling within a process.
|
||||
This C++17 library contains an abstraction layer around ZeroMQ to provide a high-level interface to
|
||||
authentication, RPC, and message passing. It is used extensively within Oxen projects (hence the
|
||||
name) as the underlying communication mechanism of SN-to-SN communication ("quorumnet"), the RPC
|
||||
interface used by wallets and local daemon commands, communication channels between oxend and
|
||||
auxiliary services (storage server, lokinet), and also provides local multithreaded job scheduling
|
||||
within a process.
|
||||
|
||||
Messages channels can be encrypted (using x25519) or not -- however opening an encrypted channel
|
||||
requires knowing the server pubkey. All SN-to-SN traffic is encrypted, and other traffic can be
|
||||
encrypted as needed.
|
||||
requires knowing the server pubkey. Within Oxen, all SN-to-SN traffic is encrypted, and other
|
||||
traffic can be encrypted as needed.
|
||||
|
||||
This library makes minimal use of mutexes, and none in the hot paths of the code, instead mostly
|
||||
relying on ZMQ sockets for synchronization; for more information on this (and why this is generally
|
||||
|
@ -16,20 +17,20 @@ much better performing and more scalable) see the ZMQ guide documentation on the
|
|||
|
||||
## Basic message structure
|
||||
|
||||
LokiMQ messages come in two fundamental forms: "commands", consisting of a command named and
|
||||
OxenMQ messages come in two fundamental forms: "commands", consisting of a command named and
|
||||
optional arguments, and "requests", consisting of a request name, a request tag, and optional
|
||||
arguments.
|
||||
|
||||
All channels are capable of bidirectional communication, and multiple messages can be in transit in
|
||||
either direction at any time. LokiMQ sets up a "listener" and "client" connections, but these only
|
||||
either direction at any time. OxenMQ sets up a "listener" and "client" connections, but these only
|
||||
determine how connections are established: once established, commands can be issued by either party.
|
||||
|
||||
The command/request string is one of two types:
|
||||
|
||||
`category.command` - for commands/requests registered by the LokiMQ caller (e.g. lokid). Here
|
||||
`category.command` - for commands/requests registered by the OxenMQ caller (e.g. oxend). Here
|
||||
`category` must be at least one character not containing a `.` and `command` may be anything. These
|
||||
categories and commands are registered according to general function and authentication level (more
|
||||
on this below). For example, for lokid categories are:
|
||||
on this below). For example, for oxend categories are:
|
||||
|
||||
- `system` - is for RPC commands related to the system administration such as mining, getting
|
||||
sensitive statistics, accessing SN private keys, remote shutdown, etc.
|
||||
|
@ -42,14 +43,14 @@ on this below). For example, for lokid categories are:
|
|||
The difference between a request and a command is that a request includes an additional opaque tag
|
||||
value which is used to identify a reply. For example you could register a `general.backwards`
|
||||
request that takes a string that receives a reply containing that string reversed. When invoking
|
||||
the request via LokiMQ you provide a callback to be invoked when the reply arrives. On the wire
|
||||
the request via OxenMQ you provide a callback to be invoked when the reply arrives. On the wire
|
||||
this looks like:
|
||||
|
||||
<<< [general.backwards] [v71.&a] [hello world]
|
||||
>>> [REPLY] [v71.&a] [dlrow olleh]
|
||||
|
||||
where each [] denotes a message part and `v71.&a` is a unique randomly generated identifier handled
|
||||
by LokiMQ (both the invoker and the recipient code only see the `hello world`/`dlrow olleh` message
|
||||
by OxenMQ (both the invoker and the recipient code only see the `hello world`/`dlrow olleh` message
|
||||
parts).
|
||||
|
||||
In contrast, regular registered commands have no identifier or expected reply callback. For example
|
||||
|
@ -92,7 +93,7 @@ handled for you transparently.
|
|||
|
||||
## Command arguments
|
||||
|
||||
Optional command/request arguments are always strings on the wire. The LokiMQ-using developer is
|
||||
Optional command/request arguments are always strings on the wire. The OxenMQ-using developer is
|
||||
free to create whatever encoding she wants, and these can vary across commands. For example
|
||||
`wallet.tx` might be a request that returns a transaction in binary, while `wallet.tx_info` might
|
||||
return tx metadata in JSON, and `p2p.send_tx` might encode tx data and metadata in a bt-encoded
|
||||
|
@ -101,47 +102,49 @@ data string.
|
|||
No structure at all is imposed on message data to allow maximum flexibility; it is entirely up to
|
||||
the calling code to handle all encoding/decoding duties.
|
||||
|
||||
Internal commands passed between LokiMQ-managed threads use either plain strings or bt-encoded
|
||||
dictionaries. See `lokimq/bt_serialize.h` if you want a bt serializer/deserializer.
|
||||
Internal commands passed between OxenMQ-managed threads use either plain strings or bt-encoded
|
||||
dictionaries. See `oxenmq/bt_serialize.h` if you want a bt serializer/deserializer.
|
||||
|
||||
## Sending commands
|
||||
|
||||
Sending a command to a peer is done by using a connection ID, and generally falls into either a
|
||||
`send()` method or a `request()` method.
|
||||
|
||||
lmq.send(conn, "category.command", "some data");
|
||||
lmq.request(conn, "category.command", [](bool success, std::vector<std::string> data) {
|
||||
omq.send(conn, "category.command", "some data");
|
||||
omq.request(conn, "category.command", [](bool success, std::vector<std::string> data) {
|
||||
if (success) { std::cout << "Remote replied: " << data.at(0) << "\n"; } });
|
||||
|
||||
The connection ID generally has two possible values:
|
||||
|
||||
- a string containing a service node pubkey. In this mode LokiMQ will look for the given SN in
|
||||
- a string containing a service node pubkey. In this mode OxenMQ will look for the given SN in
|
||||
already-established connections, reusing a connection if one exists. If no connection already
|
||||
exists, a new connection to the given SN is attempted (this requires constructing the LokiMQ
|
||||
exists, a new connection to the given SN is attempted (this requires constructing the OxenMQ
|
||||
object with a callback to determine SN remote addresses).
|
||||
- a ConnectionID object, typically returned by the `connect_remote` method (although there are other
|
||||
places to get one, such as from the `Message` object passed to a command: see the following
|
||||
section).
|
||||
|
||||
```C++
|
||||
// Send to a service node, establishing a connection if necessary:
|
||||
std::string my_sn = ...; // 32-byte pubkey of a known SN
|
||||
lmq.send(my_sn, "sn.explode", "{ \"seconds\": 30 }");
|
||||
omq.send(my_sn, "sn.explode", "{ \"seconds\": 30 }");
|
||||
|
||||
// Connect to a remote by address then send it something
|
||||
auto conn = lmq.connect_remote("tcp://127.0.0.1:4567",
|
||||
auto conn = omq.connect_remote("tcp://127.0.0.1:4567",
|
||||
[](ConnectionID c) { std::cout << "Connected!\n"; },
|
||||
[](ConnectionID c, string_view f) { std::cout << "Connect failed: " << f << "\n" });
|
||||
lmq.request(conn, "rpc.get_height", [](bool s, std::vector<std::string> d) {
|
||||
omq.request(conn, "rpc.get_height", [](bool s, std::vector<std::string> d) {
|
||||
if (s && d.size() == 1)
|
||||
std::cout << "Current height: " << d[0] << "\n";
|
||||
else
|
||||
std::cout << "Timeout fetching height!";
|
||||
});
|
||||
```
|
||||
|
||||
## Command invocation
|
||||
|
||||
The application registers categories and registers commands within these categories with callbacks.
|
||||
The callbacks are passed a LokiMQ::Message object from which the message (plus various connection
|
||||
The callbacks are passed a OxenMQ::Message object from which the message (plus various connection
|
||||
information) can be obtained. There is no structure imposed at all on the data passed in subsequent
|
||||
message parts: it is up to the command itself to deserialize however it wishes (e.g. JSON,
|
||||
bt-encoded, or any other encoding).
|
||||
|
@ -149,13 +152,13 @@ bt-encoded, or any other encoding).
|
|||
The Message object also provides methods for replying to the caller. Simple replies queue a reply
|
||||
if the client is still connected. Replies to service nodes can also be "strong" replies: when
|
||||
replying to a SN that has closed connection with a strong reply we will attempt to reestablish a
|
||||
connection to deliver the message. In order for this to work the LokiMQ caller must provide a
|
||||
connection to deliver the message. In order for this to work the OxenMQ caller must provide a
|
||||
lookup function to retrieve the remote address given a SN x25519 pubkey.
|
||||
|
||||
### Callbacks
|
||||
|
||||
Invoked command functions are always invoked with exactly one arguments: a non-const LokiMQ::Message
|
||||
reference from which the connection info, LokiMQ object, and message data can be obtained.
|
||||
Invoked command functions are always invoked with exactly one arguments: a non-const OxenMQ::Message
|
||||
reference from which the connection info, OxenMQ object, and message data can be obtained.
|
||||
|
||||
The Message object also contains a `ConnectionID` object as the public `conn` member; it is safe to
|
||||
take a copy of this and then use it later to send commands to this peer. (For example, a wallet
|
||||
|
@ -185,7 +188,7 @@ logins.
|
|||
Configuration defaults allows controlling the default access for an incoming connection based on its
|
||||
remote address. Typically this is used to allow connections from localhost (or a unix domain
|
||||
socket) to automatically be an Admin connection without requiring explicit authentication. This
|
||||
also allows configuration of how public connections should be treated: for example, a lokid running
|
||||
also allows configuration of how public connections should be treated: for example, an oxend running
|
||||
as a public RPC server would do so by granting Basic access to all incoming connections.
|
||||
|
||||
Explicit logins allow the daemon to specify username/passwords with mapping to Basic or Admin
|
||||
|
@ -194,7 +197,7 @@ authentication levels.
|
|||
Thus, for example, a daemon could be configured to be allow Basic remote access with authentication
|
||||
(i.e. requiring a username/password login given out to people who should be able to access).
|
||||
|
||||
For example, in lokid the categories described above have authentication levels of:
|
||||
For example, in oxend the categories described above have authentication levels of:
|
||||
|
||||
- `system` - Admin
|
||||
- `sn` - ServiceNode
|
||||
|
@ -203,7 +206,7 @@ For example, in lokid the categories described above have authentication levels
|
|||
|
||||
### Service Node authentication
|
||||
|
||||
In order to handle ServiceNode authentication, LokiMQ uses an Allow callback invoked during
|
||||
In order to handle ServiceNode authentication, OxenMQ uses an Allow callback invoked during
|
||||
connection to determine both whether to allow the connection, and to determine whether the incoming
|
||||
connection is an active service node.
|
||||
|
||||
|
@ -224,7 +227,7 @@ such aliases be used only temporarily for version transitions.
|
|||
|
||||
## Threads
|
||||
|
||||
LokiMQ operates a pool of worker threads to handle jobs. The simplest use just allocates new jobs
|
||||
OxenMQ operates a pool of worker threads to handle jobs. The simplest use just allocates new jobs
|
||||
to a free worker thread, and we have a "general threads" value to configure how many such threads
|
||||
are available.
|
||||
|
||||
|
@ -239,7 +242,7 @@ Note that these actual reserved threads are not exclusive: reserving M of N tota
|
|||
category simply ensures that no more than (N-M) threads are being used for other categories at any
|
||||
given time, but the actual jobs may run on any worker thread.
|
||||
|
||||
As mentioned above, LokiMQ tries to avoid exceeding the configured general threads value (G)
|
||||
As mentioned above, OxenMQ tries to avoid exceeding the configured general threads value (G)
|
||||
whenever possible: the only time we will dispatch a job to a worker thread when we have >= G threads
|
||||
already running is when a new command arrives, the category reserves M threads, and the thread pool
|
||||
is currently processing fewer than M jobs for that category.
|
||||
|
@ -275,7 +278,7 @@ when a command with reserve threads arrived.
|
|||
A common pattern is one where a single thread suddenly has some work that can be be parallelized.
|
||||
You could employ some blocking, locking, mutex + condition variable monstrosity, but you shouldn't.
|
||||
|
||||
Instead LokiMQ provides a mechanism for this by allowing you to submit a batch of jobs with a
|
||||
Instead OxenMQ provides a mechanism for this by allowing you to submit a batch of jobs with a
|
||||
completion callback. All jobs will be queued and, when the last one finishes, the finalization
|
||||
callback will be queued to continue with the task.
|
||||
|
||||
|
@ -300,7 +303,7 @@ double do_my_task(int input) {
|
|||
return 3.0 * input;
|
||||
}
|
||||
|
||||
void continue_big_task(std::vector<lokimq::job_result<double>> results) {
|
||||
void continue_big_task(std::vector<oxenmq::job_result<double>> results) {
|
||||
double sum = 0;
|
||||
for (auto& r : results) {
|
||||
try {
|
||||
|
@ -321,7 +324,7 @@ void continue_big_task(std::vector<lokimq::job_result<double>> results) {
|
|||
void start_big_task() {
|
||||
size_t num_jobs = 32;
|
||||
|
||||
lokimq::Batch<double /*return type*/> batch;
|
||||
oxenmq::Batch<double /*return type*/> batch;
|
||||
batch.reserve(num_jobs);
|
||||
|
||||
for (size_t i = 0; i < num_jobs; i++)
|
||||
|
@ -329,7 +332,7 @@ void start_big_task() {
|
|||
|
||||
batch.completion(&continue_big_task);
|
||||
|
||||
lmq.batch(std::move(batch));
|
||||
omq.batch(std::move(batch));
|
||||
// ... to be continued in `continue_big_task` after all the jobs finish
|
||||
|
||||
// Can do other things here, but note that continue_big_task could run
|
||||
|
@ -339,12 +342,12 @@ void start_big_task() {
|
|||
|
||||
This code deliberately does not support blocking to wait for the tasks to finish: if you want such a
|
||||
poor design (which is a recipe for deadlocks: imagine jobs that queuing other jobs that can end up
|
||||
exhausting the worker threads with waiting jobs) then you can implement it yourself; LokiMQ isn't
|
||||
exhausting the worker threads with waiting jobs) then you can implement it yourself; OxenMQ isn't
|
||||
going to help you hurt yourself like that.
|
||||
|
||||
### Single-job queuing
|
||||
|
||||
As a shortcut there is a `lmq.job(...)` method that schedules a single task (with no return value)
|
||||
As a shortcut there is a `omq.job(...)` method that schedules a single task (with no return value)
|
||||
in the batch job queue. This is useful when some event requires triggering some other event, but
|
||||
you don't need to wait for or collect its result. (Internally this is just a convenience method
|
||||
around creating a single-job, no-completion Batch job).
|
||||
|
@ -356,7 +359,7 @@ either using your own thread or a periodic timer (see below) to shepherd those o
|
|||
|
||||
## Timers
|
||||
|
||||
LokiMQ supports scheduling periodic tasks via the `add_timer()` function. These timers have an
|
||||
OxenMQ supports scheduling periodic tasks via the `add_timer()` function. These timers have an
|
||||
interval and are scheduled as (single-job) batches when the timer fires. They also support
|
||||
"squelching" (enabled by default) that supresses the job being scheduled if a previously scheduled
|
||||
job is already scheduled or running.
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
set(LIBZMQ_PREFIX ${CMAKE_BINARY_DIR}/libzmq)
|
||||
set(ZeroMQ_VERSION 4.3.2)
|
||||
set(ZeroMQ_VERSION 4.3.5)
|
||||
set(LIBZMQ_URL https://github.com/zeromq/libzmq/releases/download/v${ZeroMQ_VERSION}/zeromq-${ZeroMQ_VERSION}.tar.gz)
|
||||
set(LIBZMQ_HASH SHA512=b6251641e884181db9e6b0b705cced7ea4038d404bdae812ff47bdd0eed12510b6af6846b85cb96898e253ccbac71eca7fe588673300ddb9c3109c973250c8e4)
|
||||
set(LIBZMQ_HASH SHA512=a71d48aa977ad8941c1609947d8db2679fc7a951e4cd0c3a1127ae026d883c11bd4203cf315de87f95f5031aec459a731aec34e5ce5b667b8d0559b157952541)
|
||||
|
||||
message(${LIBZMQ_URL})
|
||||
|
||||
|
@ -13,19 +13,30 @@ endif()
|
|||
|
||||
file(MAKE_DIRECTORY ${LIBZMQ_PREFIX}/include)
|
||||
|
||||
set(libzmq_compiler_args)
|
||||
foreach(lang C CXX)
|
||||
foreach(thing COMPILER FLAGS COMPILER_LAUNCHER)
|
||||
if(DEFINED CMAKE_${lang}_${thing})
|
||||
list(APPEND libzmq_compiler_args "-DCMAKE_${lang}_${thing}=${CMAKE_${lang}_${thing}}")
|
||||
endif()
|
||||
endforeach()
|
||||
endforeach()
|
||||
|
||||
include(ExternalProject)
|
||||
include(ProcessorCount)
|
||||
ExternalProject_Add(libzmq_external
|
||||
PREFIX ${LIBZMQ_PREFIX}
|
||||
URL ${LIBZMQ_URL}
|
||||
URL_HASH ${LIBZMQ_HASH}
|
||||
CMAKE_ARGS -DWITH_LIBSODIUM=ON -DZMQ_BUILD_TESTS=OFF -DWITH_PERF_TOOL=OFF -DENABLE_DRAFTS=OFF
|
||||
CMAKE_ARGS ${libzmq_compiler_args}
|
||||
-DCMAKE_BUILD_TYPE=Release
|
||||
-DWITH_LIBSODIUM=ON -DZMQ_BUILD_TESTS=OFF -DWITH_PERF_TOOL=OFF -DENABLE_DRAFTS=OFF
|
||||
-DBUILD_SHARED=OFF -DBUILD_STATIC=ON -DWITH_DOC=OFF -DCMAKE_INSTALL_PREFIX=${LIBZMQ_PREFIX}
|
||||
BUILD_BYPRODUCTS ${LIBZMQ_PREFIX}/lib/libzmq.a
|
||||
BUILD_BYPRODUCTS ${LIBZMQ_PREFIX}/${CMAKE_INSTALL_LIBDIR}/libzmq.a
|
||||
)
|
||||
|
||||
add_library(libzmq_vendor STATIC IMPORTED GLOBAL)
|
||||
add_dependencies(libzmq_vendor libzmq_external)
|
||||
set_target_properties(libzmq_vendor PROPERTIES
|
||||
INTERFACE_INCLUDE_DIRECTORIES ${LIBZMQ_PREFIX}/include
|
||||
IMPORTED_LOCATION ${LIBZMQ_PREFIX}/lib/libzmq.a)
|
||||
IMPORTED_LOCATION ${LIBZMQ_PREFIX}/${CMAKE_INSTALL_LIBDIR}/libzmq.a)
|
||||
|
|
BIN
contrib/deb.oxen.io.gpg
Normal file
BIN
contrib/deb.oxen.io.gpg
Normal file
Binary file not shown.
2
cppzmq
2
cppzmq
|
@ -1 +1 @@
|
|||
Subproject commit 8d5c9a88988dcbebb72939ca0939d432230ffde1
|
||||
Subproject commit 76bf169fd67b8e99c1b0e6490029d9cd5ef97666
|
|
@ -3,11 +3,12 @@ exec_prefix=${prefix}
|
|||
libdir=@CMAKE_INSTALL_FULL_LIBDIR@
|
||||
includedir=@CMAKE_INSTALL_FULL_INCLUDEDIR@
|
||||
|
||||
Name: liblokimq
|
||||
Description: ZeroMQ-based communication library for Loki
|
||||
Version: @LOKIMQ_VERSION@
|
||||
Name: liboxenmq
|
||||
Description: ZeroMQ-based communication library
|
||||
Version: @PROJECT_VERSION@
|
||||
|
||||
Libs: -L${libdir} -llokimq
|
||||
Libs: -L${libdir} -loxenmq
|
||||
Libs.private: @PRIVATE_LIBS@
|
||||
Requires: liboxenc
|
||||
Requires.private: libzmq libsodium
|
||||
Cflags: -I${includedir}
|
187
lokimq/auth.cpp
187
lokimq/auth.cpp
|
@ -1,187 +0,0 @@
|
|||
#include "lokimq.h"
|
||||
#include "hex.h"
|
||||
#include "lokimq-internal.h"
|
||||
|
||||
namespace lokimq {
|
||||
|
||||
std::ostream& operator<<(std::ostream& o, AuthLevel a) {
|
||||
return o << to_string(a);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// Builds a ZMTP metadata key-value pair. These will be available on every message from that peer.
|
||||
// Keys must start with X- and be <= 255 characters.
|
||||
std::string zmtp_metadata(string_view key, string_view value) {
|
||||
assert(key.size() > 2 && key.size() <= 255 && key[0] == 'X' && key[1] == '-');
|
||||
|
||||
std::string result;
|
||||
result.reserve(1 + key.size() + 4 + value.size());
|
||||
result += static_cast<char>(key.size()); // Size octet of key
|
||||
result.append(&key[0], key.size()); // key data
|
||||
for (int i = 24; i >= 0; i -= 8) // 4-byte size of value in network order
|
||||
result += static_cast<char>((value.size() >> i) & 0xff);
|
||||
result.append(&value[0], value.size()); // value data
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
bool LokiMQ::proxy_check_auth(size_t conn_index, bool outgoing, const peer_info& peer,
|
||||
const std::string& command, const category& cat, zmq::message_t& msg) {
|
||||
std::string reply;
|
||||
if (peer.auth_level < cat.access.auth) {
|
||||
LMQ_LOG(warn, "Access denied to ", command, " for peer [", to_hex(peer.pubkey), "]/", peer_address(msg),
|
||||
": peer auth level ", peer.auth_level, " < ", cat.access.auth);
|
||||
reply = "FORBIDDEN";
|
||||
}
|
||||
else if (cat.access.local_sn && !local_service_node) {
|
||||
LMQ_LOG(warn, "Access denied to ", command, " for peer [", to_hex(peer.pubkey), "]/", peer_address(msg),
|
||||
": that command is only available when this LokiMQ is running in service node mode");
|
||||
reply = "NOT_A_SERVICE_NODE";
|
||||
}
|
||||
else if (cat.access.remote_sn && !peer.service_node) {
|
||||
LMQ_LOG(warn, "Access denied to ", command, " for peer [", to_hex(peer.pubkey), "]/", peer_address(msg),
|
||||
": remote is not recognized as a service node");
|
||||
// Disconnect: we don't think the remote is a SN, but it issued a command only SNs should be
|
||||
// issuing. Drop the connection; if the remote has something important to relay it will
|
||||
// reconnect, at which point we will reassess the SN status on the new incoming connection.
|
||||
if (outgoing)
|
||||
proxy_disconnect(peer.service_node ? ConnectionID{peer.pubkey} : conn_index_to_id[conn_index], 1s);
|
||||
else
|
||||
send_routed_message(connections[conn_index], peer.route, "BYE");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (reply.empty())
|
||||
return true;
|
||||
|
||||
if (outgoing)
|
||||
send_direct_message(connections[conn_index], std::move(reply), command);
|
||||
else
|
||||
send_routed_message(connections[conn_index], peer.route, std::move(reply), command);
|
||||
return false;
|
||||
}
|
||||
|
||||
void LokiMQ::process_zap_requests() {
|
||||
for (std::vector<zmq::message_t> frames; recv_message_parts(zap_auth, frames, zmq::recv_flags::dontwait); frames.clear()) {
|
||||
#ifndef NDEBUG
|
||||
if (log_level() >= LogLevel::trace) {
|
||||
std::ostringstream o;
|
||||
o << "Processing ZAP authentication request:";
|
||||
for (size_t i = 0; i < frames.size(); i++) {
|
||||
o << "\n[" << i << "]: ";
|
||||
auto v = view(frames[i]);
|
||||
if (i == 1 || i == 6)
|
||||
o << to_hex(v);
|
||||
else
|
||||
o << v;
|
||||
}
|
||||
log_(LogLevel::trace, __FILE__, __LINE__, o.str());
|
||||
} else
|
||||
#endif
|
||||
LMQ_LOG(debug, "Processing ZAP authentication request");
|
||||
|
||||
// https://rfc.zeromq.org/spec:27/ZAP/
|
||||
//
|
||||
// The request message SHALL consist of the following message frames:
|
||||
//
|
||||
// The version frame, which SHALL contain the three octets "1.0".
|
||||
// The request id, which MAY contain an opaque binary blob.
|
||||
// The domain, which SHALL contain a (non-empty) string.
|
||||
// The address, the origin network IP address.
|
||||
// The identity, the connection Identity, if any.
|
||||
// The mechanism, which SHALL contain a string.
|
||||
// The credentials, which SHALL be zero or more opaque frames.
|
||||
//
|
||||
// The reply message SHALL consist of the following message frames:
|
||||
//
|
||||
// The version frame, which SHALL contain the three octets "1.0".
|
||||
// The request id, which MAY contain an opaque binary blob.
|
||||
// The status code, which SHALL contain a string.
|
||||
// The status text, which MAY contain a string.
|
||||
// The user id, which SHALL contain a string.
|
||||
// The metadata, which MAY contain a blob.
|
||||
//
|
||||
// (NB: there are also null address delimiters at the beginning of each mentioned in the
|
||||
// RFC, but those have already been removed through the use of a REP socket)
|
||||
|
||||
std::vector<std::string> response_vals(6);
|
||||
response_vals[0] = "1.0"; // version
|
||||
if (frames.size() >= 2)
|
||||
response_vals[1] = std::string{view(frames[1])}; // unique identifier
|
||||
std::string &status_code = response_vals[2], &status_text = response_vals[3];
|
||||
|
||||
if (frames.size() < 6 || view(frames[0]) != "1.0") {
|
||||
LMQ_LOG(error, "Bad ZAP authentication request: version != 1.0 or invalid ZAP message parts");
|
||||
status_code = "500";
|
||||
status_text = "Internal error: invalid auth request";
|
||||
} else {
|
||||
auto auth_domain = view(frames[2]);
|
||||
size_t bind_id = (size_t) -1;
|
||||
try {
|
||||
bind_id = bt_deserialize<size_t>(view(frames[2]));
|
||||
} catch (...) {}
|
||||
|
||||
if (bind_id >= bind.size()) {
|
||||
LMQ_LOG(error, "Bad ZAP authentication request: invalid auth domain '", auth_domain, "'");
|
||||
status_code = "400";
|
||||
status_text = "Unknown authentication domain: " + std::string{auth_domain};
|
||||
} else if (bind[bind_id].second.curve
|
||||
? !(frames.size() == 7 && view(frames[5]) == "CURVE")
|
||||
: !(frames.size() == 6 && view(frames[5]) == "NULL")) {
|
||||
LMQ_LOG(error, "Bad ZAP authentication request: invalid ",
|
||||
bind[bind_id].second.curve ? "CURVE" : "NULL", " authentication request");
|
||||
status_code = "500";
|
||||
status_text = "Invalid authentication request mechanism";
|
||||
} else if (bind[bind_id].second.curve && frames[6].size() != 32) {
|
||||
LMQ_LOG(error, "Bad ZAP authentication request: invalid request pubkey");
|
||||
status_code = "500";
|
||||
status_text = "Invalid public key size for CURVE authentication";
|
||||
} else {
|
||||
auto ip = view(frames[3]);
|
||||
string_view pubkey;
|
||||
if (bind[bind_id].second.curve)
|
||||
pubkey = view(frames[6]);
|
||||
auto result = bind[bind_id].second.allow(ip, pubkey);
|
||||
bool sn = result.remote_sn;
|
||||
auto& user_id = response_vals[4];
|
||||
if (bind[bind_id].second.curve) {
|
||||
user_id.reserve(64);
|
||||
to_hex(pubkey.begin(), pubkey.end(), std::back_inserter(user_id));
|
||||
}
|
||||
|
||||
if (result.auth <= AuthLevel::denied || result.auth > AuthLevel::admin) {
|
||||
LMQ_LOG(info, "Access denied for incoming ", view(frames[5]), (sn ? " service node" : " client"),
|
||||
" connection from ", !user_id.empty() ? user_id + " at " : ""s, ip,
|
||||
" with initial auth level ", result.auth);
|
||||
status_code = "400";
|
||||
status_text = "Access denied";
|
||||
user_id.clear();
|
||||
} else {
|
||||
LMQ_LOG(info, "Accepted incoming ", view(frames[5]), (sn ? " service node" : " client"),
|
||||
" connection with authentication level ", result.auth,
|
||||
" from ", !user_id.empty() ? user_id + " at " : ""s, ip);
|
||||
|
||||
auto& metadata = response_vals[5];
|
||||
metadata += zmtp_metadata("X-SN", result.remote_sn ? "1" : "0");
|
||||
metadata += zmtp_metadata("X-AuthLevel", to_string(result.auth));
|
||||
|
||||
status_code = "200";
|
||||
status_text = "";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
LMQ_TRACE("ZAP request result: ", status_code, " ", status_text);
|
||||
|
||||
std::vector<zmq::message_t> response;
|
||||
response.reserve(response_vals.size());
|
||||
for (auto &r : response_vals) response.push_back(create_message(std::move(r)));
|
||||
send_message_parts(zap_auth, response.begin(), response.end());
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@ -1,33 +0,0 @@
|
|||
#pragma once
|
||||
#include <iostream>
|
||||
|
||||
namespace lokimq {
|
||||
|
||||
/// Authentication levels for command categories and connections
|
||||
enum class AuthLevel {
|
||||
denied, ///< Not actually an auth level, but can be returned by the AllowFunc to deny an incoming connection.
|
||||
none, ///< No authentication at all; any random incoming ZMQ connection can invoke this command.
|
||||
basic, ///< Basic authentication commands require a login, or a node that is specifically configured to be a public node (e.g. for public RPC).
|
||||
admin, ///< Advanced authentication commands require an admin user, either via explicit login or by implicit login from localhost. This typically protects administrative commands like shutting down, starting mining, or access sensitive data.
|
||||
};
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, AuthLevel a);
|
||||
|
||||
/// The access level for a command category
|
||||
struct Access {
|
||||
/// Minimum access level required
|
||||
AuthLevel auth = AuthLevel::none;
|
||||
/// If true only remote SNs may call the category commands
|
||||
bool remote_sn = false;
|
||||
/// If true the category requires that the local node is a SN
|
||||
bool local_sn = false;
|
||||
};
|
||||
|
||||
/// Return type of the AllowFunc: this determines whether we allow the connection at all, and if so,
|
||||
/// sets the initial authentication level and tells LokiMQ whether the other end is an active SN.
|
||||
struct Allow {
|
||||
AuthLevel auth = AuthLevel::none;
|
||||
bool remote_sn = false;
|
||||
};
|
||||
|
||||
}
|
|
@ -1,7 +0,0 @@
|
|||
#include "batch.h"
|
||||
#include "lokimq-internal.h"
|
||||
|
||||
namespace lokimq {
|
||||
|
||||
|
||||
}
|
|
@ -1,233 +0,0 @@
|
|||
// Copyright (c) 2019-2020, The Loki Project
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without modification, are
|
||||
// permitted provided that the following conditions are met:
|
||||
//
|
||||
// 1. Redistributions of source code must retain the above copyright notice, this list of
|
||||
// conditions and the following disclaimer.
|
||||
//
|
||||
// 2. Redistributions in binary form must reproduce the above copyright notice, this list
|
||||
// of conditions and the following disclaimer in the documentation and/or other
|
||||
// materials provided with the distribution.
|
||||
//
|
||||
// 3. Neither the name of the copyright holder nor the names of its contributors may be
|
||||
// used to endorse or promote products derived from this software without specific
|
||||
// prior written permission.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
|
||||
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
|
||||
// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL
|
||||
// THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF
|
||||
// THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
#include "bt_serialize.h"
|
||||
|
||||
namespace lokimq {
|
||||
namespace detail {
|
||||
|
||||
/// Reads digits into an unsigned 64-bit int.
|
||||
uint64_t extract_unsigned(string_view& s) {
|
||||
if (s.empty())
|
||||
throw bt_deserialize_invalid{"Expected 0-9 but found end of string"};
|
||||
if (s[0] < '0' || s[0] > '9')
|
||||
throw bt_deserialize_invalid("Expected 0-9 but found '"s + s[0]);
|
||||
uint64_t uval = 0;
|
||||
while (!s.empty() && (s[0] >= '0' && s[0] <= '9')) {
|
||||
uint64_t bigger = uval * 10 + (s[0] - '0');
|
||||
s.remove_prefix(1);
|
||||
if (bigger < uval) // overflow
|
||||
throw bt_deserialize_invalid("Integer deserialization failed: value is too large for a 64-bit int");
|
||||
uval = bigger;
|
||||
}
|
||||
return uval;
|
||||
}
|
||||
|
||||
void bt_deserialize<string_view>::operator()(string_view& s, string_view& val) {
|
||||
if (s.size() < 2) throw bt_deserialize_invalid{"Deserialize failed: given data is not an bt-encoded string"};
|
||||
if (s[0] < '0' || s[0] > '9')
|
||||
throw bt_deserialize_invalid_type{"Expected 0-9 but found '"s + s[0] + "'"};
|
||||
auto len = static_cast<size_t>(extract_unsigned(s));
|
||||
if (s.empty() || s[0] != ':')
|
||||
throw bt_deserialize_invalid{"Did not find expected ':' during string deserialization"};
|
||||
s.remove_prefix(1);
|
||||
|
||||
if (len > s.size())
|
||||
throw bt_deserialize_invalid{"String deserialization failed: encoded string length is longer than the serialized data"};
|
||||
|
||||
val = {s.data(), len};
|
||||
s.remove_prefix(len);
|
||||
}
|
||||
|
||||
// Check that we are on a 2's complement architecture. It's highly unlikely that this code ever
|
||||
// runs on a non-2s-complement architecture (especially since C++20 requires a two's complement
|
||||
// signed value behaviour), but check at compile time anyway because we rely on these relations
|
||||
// below.
|
||||
static_assert(std::numeric_limits<int64_t>::min() + std::numeric_limits<int64_t>::max() == -1 &&
|
||||
static_cast<uint64_t>(std::numeric_limits<int64_t>::max()) + uint64_t{1} == (uint64_t{1} << 63),
|
||||
"Non 2s-complement architecture not supported!");
|
||||
|
||||
std::pair<maybe_signed_int64_t, bool> bt_deserialize_integer(string_view& s) {
|
||||
// Smallest possible encoded integer is 3 chars: "i0e"
|
||||
if (s.size() < 3) throw bt_deserialize_invalid("Deserialization failed: end of string found where integer expected");
|
||||
if (s[0] != 'i') throw bt_deserialize_invalid_type("Deserialization failed: expected 'i', found '"s + s[0] + '\'');
|
||||
s.remove_prefix(1);
|
||||
std::pair<maybe_signed_int64_t, bool> result;
|
||||
if (s[0] == '-') {
|
||||
result.second = true;
|
||||
s.remove_prefix(1);
|
||||
}
|
||||
|
||||
uint64_t uval = extract_unsigned(s);
|
||||
if (s.empty())
|
||||
throw bt_deserialize_invalid("Integer deserialization failed: encountered end of string before integer was finished");
|
||||
if (s[0] != 'e')
|
||||
throw bt_deserialize_invalid("Integer deserialization failed: expected digit or 'e', found '"s + s[0] + '\'');
|
||||
s.remove_prefix(1);
|
||||
if (result.second) { // negative
|
||||
if (uval > (uint64_t{1} << 63))
|
||||
throw bt_deserialize_invalid("Deserialization of integer failed: negative integer value is too large for a 64-bit signed int");
|
||||
result.first.i64 = -uval;
|
||||
} else {
|
||||
result.first.u64 = uval;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
template struct bt_deserialize<int64_t>;
|
||||
template struct bt_deserialize<uint64_t>;
|
||||
|
||||
void bt_deserialize<bt_value, void>::operator()(string_view& s, bt_value& val) {
|
||||
if (s.size() < 2) throw bt_deserialize_invalid("Deserialization failed: end of string found where bt-encoded value expected");
|
||||
|
||||
switch (s[0]) {
|
||||
case 'd': {
|
||||
bt_dict dict;
|
||||
bt_deserialize<bt_dict>{}(s, dict);
|
||||
val = std::move(dict);
|
||||
break;
|
||||
}
|
||||
case 'l': {
|
||||
bt_list list;
|
||||
bt_deserialize<bt_list>{}(s, list);
|
||||
val = std::move(list);
|
||||
break;
|
||||
}
|
||||
case 'i': {
|
||||
auto read = bt_deserialize_integer(s);
|
||||
val = read.first.i64; // We only store an i64, but can get a u64 out of it via get<uint64_t>(val)
|
||||
break;
|
||||
}
|
||||
case '0': case '1': case '2': case '3': case '4': case '5': case '6': case '7': case '8': case '9': {
|
||||
std::string str;
|
||||
bt_deserialize<std::string>{}(s, str);
|
||||
val = std::move(str);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
throw bt_deserialize_invalid("Deserialize failed: encountered invalid value '"s + s[0] + "'; expected one of [0-9idl]");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
|
||||
bt_list_consumer::bt_list_consumer(string_view data_) : data{std::move(data_)} {
|
||||
if (data.empty()) throw std::runtime_error{"Cannot create a bt_list_consumer with an empty string_view"};
|
||||
if (data[0] != 'l') throw std::runtime_error{"Cannot create a bt_list_consumer with non-list data"};
|
||||
data.remove_prefix(1);
|
||||
}
|
||||
|
||||
/// Attempt to parse the next value as a string (and advance just past it). Throws if the next
|
||||
/// value is not a string.
|
||||
string_view bt_list_consumer::consume_string_view() {
|
||||
if (data.empty())
|
||||
throw bt_deserialize_invalid{"expected a string, but reached end of data"};
|
||||
else if (!is_string())
|
||||
throw bt_deserialize_invalid_type{"expected a string, but found "s + data.front()};
|
||||
string_view next{data}, result;
|
||||
detail::bt_deserialize<string_view>{}(next, result);
|
||||
data = next;
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string bt_list_consumer::consume_string() {
|
||||
return std::string{consume_string_view()};
|
||||
}
|
||||
|
||||
/// Consumes a value without returning it.
|
||||
void bt_list_consumer::skip_value() {
|
||||
if (is_string())
|
||||
consume_string_view();
|
||||
else if (is_integer())
|
||||
detail::bt_deserialize_integer(data);
|
||||
else if (is_list())
|
||||
consume_list_data();
|
||||
else if (is_dict())
|
||||
consume_dict_data();
|
||||
else
|
||||
throw bt_deserialize_invalid_type{"next bt value has unknown type"};
|
||||
}
|
||||
|
||||
string_view bt_list_consumer::consume_list_data() {
|
||||
auto start = data.begin();
|
||||
if (data.size() < 2 || !is_list()) throw bt_deserialize_invalid_type{"next bt value is not a list"};
|
||||
data.remove_prefix(1); // Descend into the sublist, consume the "l"
|
||||
while (!is_finished()) {
|
||||
skip_value();
|
||||
if (data.empty())
|
||||
throw bt_deserialize_invalid{"bt list consumption failed: hit the end of string before the list was done"};
|
||||
}
|
||||
data.remove_prefix(1); // Back out from the sublist, consume the "e"
|
||||
return {start, static_cast<size_t>(std::distance(start, data.begin()))};
|
||||
}
|
||||
|
||||
string_view bt_list_consumer::consume_dict_data() {
|
||||
auto start = data.begin();
|
||||
if (data.size() < 2 || !is_dict()) throw bt_deserialize_invalid_type{"next bt value is not a dict"};
|
||||
data.remove_prefix(1); // Descent into the dict, consumer the "d"
|
||||
while (!is_finished()) {
|
||||
consume_string_view(); // Key is always a string
|
||||
if (!data.empty())
|
||||
skip_value();
|
||||
if (data.empty())
|
||||
throw bt_deserialize_invalid{"bt dict consumption failed: hit the end of string before the dict was done"};
|
||||
}
|
||||
data.remove_prefix(1); // Back out of the dict, consume the "e"
|
||||
return {start, static_cast<size_t>(std::distance(start, data.begin()))};
|
||||
}
|
||||
|
||||
bt_dict_consumer::bt_dict_consumer(string_view data_) {
|
||||
data = std::move(data_);
|
||||
if (data.empty()) throw std::runtime_error{"Cannot create a bt_dict_consumer with an empty string_view"};
|
||||
if (data.size() < 2 || data[0] != 'd') throw std::runtime_error{"Cannot create a bt_dict_consumer with non-dict data"};
|
||||
data.remove_prefix(1);
|
||||
}
|
||||
|
||||
bool bt_dict_consumer::consume_key() {
|
||||
if (key_.data())
|
||||
return true;
|
||||
if (data.empty()) throw bt_deserialize_invalid_type{"expected a key or dict end, found end of string"};
|
||||
if (data[0] == 'e') return false;
|
||||
key_ = bt_list_consumer::consume_string_view();
|
||||
if (data.empty() || data[0] == 'e')
|
||||
throw bt_deserialize_invalid{"dict key isn't followed by a value"};
|
||||
return true;
|
||||
}
|
||||
|
||||
std::pair<string_view, string_view> bt_dict_consumer::next_string() {
|
||||
if (!is_string())
|
||||
throw bt_deserialize_invalid_type{"expected a string, but found "s + data.front()};
|
||||
std::pair<string_view, string_view> ret;
|
||||
ret.second = bt_list_consumer::consume_string_view();
|
||||
ret.first = flush_key();
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
||||
} // namespace lokimq
|
|
@ -1,829 +0,0 @@
|
|||
// Copyright (c) 2019-2020, The Loki Project
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without modification, are
|
||||
// permitted provided that the following conditions are met:
|
||||
//
|
||||
// 1. Redistributions of source code must retain the above copyright notice, this list of
|
||||
// conditions and the following disclaimer.
|
||||
//
|
||||
// 2. Redistributions in binary form must reproduce the above copyright notice, this list
|
||||
// of conditions and the following disclaimer in the documentation and/or other
|
||||
// materials provided with the distribution.
|
||||
//
|
||||
// 3. Neither the name of the copyright holder nor the names of its contributors may be
|
||||
// used to endorse or promote products derived from this software without specific
|
||||
// prior written permission.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
|
||||
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
|
||||
// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL
|
||||
// THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF
|
||||
// THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <list>
|
||||
#include <unordered_map>
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <cstring>
|
||||
#include <ostream>
|
||||
#include <sstream>
|
||||
#include "string_view.h"
|
||||
#include "mapbox/variant.hpp"
|
||||
|
||||
|
||||
namespace lokimq {
|
||||
|
||||
using namespace std::literals;
|
||||
|
||||
/** \file
|
||||
* LokiMQ serialization for internal commands is very simple: we support two primitive types,
|
||||
* strings and integers, and two container types, lists and dicts with string keys. On the wire
|
||||
* these go in BitTorrent byte encoding as described in BEP-0003
|
||||
* (https://www.bittorrent.org/beps/bep_0003.html#bencoding).
|
||||
*
|
||||
* On the C++ side, on input we allow strings, integral types, STL-like containers of these types,
|
||||
* and STL-like containers of pairs with a string first value and any of these types as second
|
||||
* value. We also accept std::variants (if compiled with std::variant support, i.e. in C++17 mode)
|
||||
* that contain any of these, and mapbox::util::variants (the internal type used for its recursive
|
||||
* support).
|
||||
*
|
||||
* One minor deviation from BEP-0003 is that we don't support serializing values that don't fit in a
|
||||
* 64-bit integer (BEP-0003 specifies arbitrary precision integers).
|
||||
*
|
||||
* On deserialization we can either deserialize into a mapbox::util::variant that supports everything, or
|
||||
* we can fill a container of your given type (though this fails if the container isn't compatible
|
||||
* with the deserialized data).
|
||||
*/
|
||||
|
||||
/// Exception throw if deserialization fails
|
||||
class bt_deserialize_invalid : public std::invalid_argument {
|
||||
using std::invalid_argument::invalid_argument;
|
||||
};
|
||||
|
||||
/// A more specific subclass that is thown if the serialization type is an initial mismatch: for
|
||||
/// example, trying deserializing an int but the next thing in input is a list. This is not,
|
||||
/// however, thrown if the type initially looks fine but, say, a nested serialization fails. This
|
||||
/// error will only be thrown when the input stream has not been advanced (and so can be tried for a
|
||||
/// different type).
|
||||
class bt_deserialize_invalid_type : public bt_deserialize_invalid {
|
||||
using bt_deserialize_invalid::bt_deserialize_invalid;
|
||||
};
|
||||
|
||||
class bt_list;
|
||||
class bt_dict;
|
||||
|
||||
/// Recursive generic type that can fully represent everything valid for a BT serialization.
|
||||
using bt_value = mapbox::util::variant<
|
||||
std::string,
|
||||
string_view,
|
||||
int64_t,
|
||||
mapbox::util::recursive_wrapper<bt_list>,
|
||||
mapbox::util::recursive_wrapper<bt_dict>
|
||||
>;
|
||||
|
||||
/// Very thin wrapper around a std::list<bt_value> that holds a list of generic values (though *any*
|
||||
/// compatible data type can be used).
|
||||
class bt_list : public std::list<bt_value> {
|
||||
using std::list<bt_value>::list;
|
||||
};
|
||||
/// Very thin wrapper around a std::unordered_map<bt_value> that holds a list of string -> generic
|
||||
/// value pairs (though *any* compatible data type can be used).
|
||||
class bt_dict : public std::unordered_map<std::string, bt_value> {
|
||||
using std::unordered_map<std::string, bt_value>::unordered_map;
|
||||
};
|
||||
|
||||
#ifdef __cpp_lib_void_t
|
||||
using std::void_t;
|
||||
#else
|
||||
/// C++17 void_t backport
|
||||
template <typename... Ts> struct void_t_impl { using type = void; };
|
||||
template <typename... Ts> using void_t = typename void_t_impl<Ts...>::type;
|
||||
#endif
|
||||
|
||||
namespace detail {
|
||||
|
||||
/// Reads digits into an unsigned 64-bit int.
|
||||
uint64_t extract_unsigned(string_view& s);
|
||||
inline uint64_t extract_unsigned(string_view&& s) { return extract_unsigned(s); }
|
||||
|
||||
// Fallback base case; we only get here if none of the partial specializations below work
|
||||
template <typename T, typename SFINAE = void>
|
||||
struct bt_serialize { static_assert(!std::is_same<T, T>::value, "Cannot serialize T: unsupported type for bt serialization"); };
|
||||
|
||||
template <typename T, typename SFINAE = void>
|
||||
struct bt_deserialize { static_assert(!std::is_same<T, T>::value, "Cannot deserialize T: unsupported type for bt deserialization"); };
|
||||
|
||||
/// Checks that we aren't at the end of a string view and throws if we are.
|
||||
inline void bt_need_more(const string_view &s) {
|
||||
if (s.empty())
|
||||
throw bt_deserialize_invalid{"Unexpected end of string while deserializing"};
|
||||
}
|
||||
|
||||
union maybe_signed_int64_t { int64_t i64; uint64_t u64; };
|
||||
|
||||
/// Deserializes a signed or unsigned 64-bit integer from a string. Sets the second bool to true
|
||||
/// iff the value is int64_t because a negative value was read. Throws an exception if the read
|
||||
/// value doesn't fit in a int64_t (if negative) or a uint64_t (if positive). Removes consumed
|
||||
/// characters from the string_view.
|
||||
std::pair<maybe_signed_int64_t, bool> bt_deserialize_integer(string_view& s);
|
||||
|
||||
/// Integer specializations
|
||||
template <typename T>
|
||||
struct bt_serialize<T, std::enable_if_t<std::is_integral<T>::value>> {
|
||||
static_assert(sizeof(T) <= sizeof(uint64_t), "Serialization of integers larger than uint64_t is not supported");
|
||||
void operator()(std::ostream &os, const T &val) {
|
||||
// Cast 1-byte types to a larger type to avoid iostream interpreting them as single characters
|
||||
using output_type = std::conditional_t<(sizeof(T) > 1), T, std::conditional_t<std::is_signed<T>::value, int, unsigned>>;
|
||||
os << 'i' << static_cast<output_type>(val) << 'e';
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct bt_deserialize<T, std::enable_if_t<std::is_integral<T>::value>> {
|
||||
void operator()(string_view& s, T &val) {
|
||||
constexpr uint64_t umax = static_cast<uint64_t>(std::numeric_limits<T>::max());
|
||||
constexpr int64_t smin = static_cast<int64_t>(std::numeric_limits<T>::min()),
|
||||
smax = static_cast<int64_t>(std::numeric_limits<T>::max());
|
||||
|
||||
auto read = bt_deserialize_integer(s);
|
||||
if (std::is_signed<T>::value) {
|
||||
if (!read.second) { // read a positive value
|
||||
if (read.first.u64 > umax)
|
||||
throw bt_deserialize_invalid("Integer deserialization failed: found too-large value " + std::to_string(read.first.u64) + " > " + std::to_string(umax));
|
||||
val = static_cast<T>(read.first.u64);
|
||||
} else {
|
||||
bool oob = read.first.i64 < smin || read.first.i64 > smax;
|
||||
if (sizeof(T) < sizeof(int64_t) && oob)
|
||||
throw bt_deserialize_invalid("Integer deserialization failed: found out-of-range value " + std::to_string(read.first.i64) + " not in [" + std::to_string(smin) + "," + std::to_string(smax) + "]");
|
||||
val = static_cast<T>(read.first.i64);
|
||||
}
|
||||
} else {
|
||||
if (read.second)
|
||||
throw bt_deserialize_invalid("Integer deserialization failed: found negative value " + std::to_string(read.first.i64) + " but type is unsigned");
|
||||
if (sizeof(T) < sizeof(uint64_t) && read.first.u64 > umax)
|
||||
throw bt_deserialize_invalid("Integer deserialization failed: found too-large value " + std::to_string(read.first.u64) + " > " + std::to_string(umax));
|
||||
val = static_cast<T>(read.first.u64);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
extern template struct bt_deserialize<int64_t>;
|
||||
extern template struct bt_deserialize<uint64_t>;
|
||||
|
||||
template <>
|
||||
struct bt_serialize<string_view> {
|
||||
void operator()(std::ostream &os, const string_view &val) { os << val.size(); os.put(':'); os.write(val.data(), val.size()); }
|
||||
};
|
||||
template <>
|
||||
struct bt_deserialize<string_view> {
|
||||
void operator()(string_view& s, string_view& val);
|
||||
};
|
||||
|
||||
/// String specialization
|
||||
template <>
|
||||
struct bt_serialize<std::string> {
|
||||
void operator()(std::ostream &os, const std::string &val) { bt_serialize<string_view>{}(os, val); }
|
||||
};
|
||||
template <>
|
||||
struct bt_deserialize<std::string> {
|
||||
void operator()(string_view& s, std::string& val) { string_view view; bt_deserialize<string_view>{}(s, view); val = {view.data(), view.size()}; }
|
||||
};
|
||||
|
||||
/// char * and string literals -- we allow serialization for convenience, but not deserialization
|
||||
template <>
|
||||
struct bt_serialize<char *> {
|
||||
void operator()(std::ostream &os, const char *str) { bt_serialize<string_view>{}(os, {str, std::strlen(str)}); }
|
||||
};
|
||||
template <size_t N>
|
||||
struct bt_serialize<char[N]> {
|
||||
void operator()(std::ostream &os, const char *str) { bt_serialize<string_view>{}(os, {str, N-1}); }
|
||||
};
|
||||
|
||||
/// Partial dict validity; we don't check the second type for serializability, that will be handled
|
||||
/// via the base case static_assert if invalid.
|
||||
template <typename T, typename = void> struct is_bt_input_dict_container : std::false_type {};
|
||||
template <typename T>
|
||||
struct is_bt_input_dict_container<T, std::enable_if_t<
|
||||
std::is_same<std::string, std::remove_cv_t<typename T::value_type::first_type>>::value,
|
||||
void_t<typename T::const_iterator /* is const iterable */,
|
||||
typename T::value_type::second_type /* has a second type */>>>
|
||||
: std::true_type {};
|
||||
|
||||
/// Determines whether the type looks like something we can insert into (using `v.insert(v.end(), x)`)
|
||||
template <typename T, typename = void> struct is_bt_insertable : std::false_type {};
|
||||
template <typename T>
|
||||
struct is_bt_insertable<T,
|
||||
void_t<decltype(std::declval<T>().insert(std::declval<T>().end(), std::declval<typename T::value_type>()))>>
|
||||
: std::true_type {};
|
||||
|
||||
/// Determines whether the given type looks like a compatible map (i.e. has std::string keys) that
|
||||
/// we can insert into.
|
||||
template <typename T, typename = void> struct is_bt_output_dict_container : std::false_type {};
|
||||
template <typename T>
|
||||
struct is_bt_output_dict_container<T, std::enable_if_t<
|
||||
std::is_same<std::string, std::remove_cv_t<typename T::key_type>>::value &&
|
||||
is_bt_insertable<T>::value,
|
||||
void_t<typename T::value_type::second_type /* has a second type */>>>
|
||||
: std::true_type {};
|
||||
|
||||
|
||||
/// Specialization for a dict-like container (such as an unordered_map). We accept anything for a
|
||||
/// dict that is const iterable over something that looks like a pair with std::string for first
|
||||
/// value type. The value (i.e. second element of the pair) also must be serializable.
|
||||
template <typename T>
|
||||
struct bt_serialize<T, std::enable_if_t<is_bt_input_dict_container<T>::value>> {
|
||||
using second_type = typename T::value_type::second_type;
|
||||
using ref_pair = std::reference_wrapper<const typename T::value_type>;
|
||||
void operator()(std::ostream &os, const T &dict) {
|
||||
os << 'd';
|
||||
std::vector<ref_pair> pairs;
|
||||
pairs.reserve(dict.size());
|
||||
for (const auto &pair : dict)
|
||||
pairs.emplace(pairs.end(), pair);
|
||||
std::sort(pairs.begin(), pairs.end(), [](ref_pair a, ref_pair b) { return a.get().first < b.get().first; });
|
||||
for (auto &ref : pairs) {
|
||||
bt_serialize<std::string>{}(os, ref.get().first);
|
||||
bt_serialize<second_type>{}(os, ref.get().second);
|
||||
}
|
||||
os << 'e';
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct bt_deserialize<T, std::enable_if_t<is_bt_output_dict_container<T>::value>> {
|
||||
using second_type = typename T::value_type::second_type;
|
||||
void operator()(string_view& s, T& dict) {
|
||||
// Smallest dict is 2 bytes "de", for an empty dict.
|
||||
if (s.size() < 2) throw bt_deserialize_invalid("Deserialization failed: end of string found where dict expected");
|
||||
if (s[0] != 'd') throw bt_deserialize_invalid_type("Deserialization failed: expected 'd', found '"s + s[0] + "'"s);
|
||||
s.remove_prefix(1);
|
||||
dict.clear();
|
||||
bt_deserialize<std::string> key_deserializer;
|
||||
bt_deserialize<second_type> val_deserializer;
|
||||
|
||||
while (!s.empty() && s[0] != 'e') {
|
||||
std::string key;
|
||||
second_type val;
|
||||
key_deserializer(s, key);
|
||||
val_deserializer(s, val);
|
||||
dict.insert(dict.end(), typename T::value_type{std::move(key), std::move(val)});
|
||||
}
|
||||
if (s.empty())
|
||||
throw bt_deserialize_invalid("Deserialization failed: encountered end of string before dict was finished");
|
||||
s.remove_prefix(1); // Consume the 'e'
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/// Accept anything that looks iterable; value serialization validity isn't checked here (it fails
|
||||
/// via the base case static assert).
|
||||
template <typename T, typename = void> struct is_bt_input_list_container : std::false_type {};
|
||||
template <typename T>
|
||||
struct is_bt_input_list_container<T, std::enable_if_t<
|
||||
!std::is_same<T, std::string>::value &&
|
||||
!is_bt_input_dict_container<T>::value,
|
||||
void_t<typename T::const_iterator, typename T::value_type>>>
|
||||
: std::true_type {};
|
||||
|
||||
template <typename T, typename = void> struct is_bt_output_list_container : std::false_type {};
|
||||
template <typename T>
|
||||
struct is_bt_output_list_container<T, std::enable_if_t<
|
||||
!std::is_same<T, std::string>::value &&
|
||||
!is_bt_output_dict_container<T>::value &&
|
||||
is_bt_insertable<T>::value>>
|
||||
: std::true_type {};
|
||||
|
||||
|
||||
/// List specialization
|
||||
template <typename T>
|
||||
struct bt_serialize<T, std::enable_if_t<is_bt_input_list_container<T>::value>> {
|
||||
void operator()(std::ostream& os, const T& list) {
|
||||
os << 'l';
|
||||
for (const auto &v : list)
|
||||
bt_serialize<std::remove_cv_t<typename T::value_type>>{}(os, v);
|
||||
os << 'e';
|
||||
}
|
||||
};
|
||||
template <typename T>
|
||||
struct bt_deserialize<T, std::enable_if_t<is_bt_output_list_container<T>::value>> {
|
||||
using value_type = typename T::value_type;
|
||||
void operator()(string_view& s, T& list) {
|
||||
// Smallest list is 2 bytes "le", for an empty list.
|
||||
if (s.size() < 2) throw bt_deserialize_invalid("Deserialization failed: end of string found where list expected");
|
||||
if (s[0] != 'l') throw bt_deserialize_invalid_type("Deserialization failed: expected 'l', found '"s + s[0] + "'"s);
|
||||
s.remove_prefix(1);
|
||||
list.clear();
|
||||
bt_deserialize<value_type> deserializer;
|
||||
while (!s.empty() && s[0] != 'e') {
|
||||
value_type v;
|
||||
deserializer(s, v);
|
||||
list.insert(list.end(), std::move(v));
|
||||
}
|
||||
if (s.empty())
|
||||
throw bt_deserialize_invalid("Deserialization failed: encountered end of string before list was finished");
|
||||
s.remove_prefix(1); // Consume the 'e'
|
||||
}
|
||||
};
|
||||
|
||||
/// variant visitor; serializes whatever is contained
|
||||
class bt_serialize_visitor {
|
||||
std::ostream &os;
|
||||
public:
|
||||
using result_type = void;
|
||||
bt_serialize_visitor(std::ostream &os) : os{os} {}
|
||||
template <typename T> void operator()(const T &val) const {
|
||||
bt_serialize<T>{}(os, val);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
using is_bt_deserializable = std::integral_constant<bool,
|
||||
std::is_same<T, std::string>::value || std::is_integral<T>::value ||
|
||||
is_bt_output_dict_container<T>::value || is_bt_output_list_container<T>::value>;
|
||||
|
||||
// General template and base case; this base will only actually be invoked when Ts... is empty,
|
||||
// which means we reached the end without finding any variant type capable of holding the value.
|
||||
template <typename SFINAE, typename Variant, typename... Ts>
|
||||
struct bt_deserialize_try_variant_impl {
|
||||
void operator()(string_view&, Variant&) {
|
||||
throw bt_deserialize_invalid("Deserialization failed: could not deserialize value into any variant type");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename... Ts, typename Variant>
|
||||
void bt_deserialize_try_variant(string_view& s, Variant& variant) {
|
||||
bt_deserialize_try_variant_impl<void, Variant, Ts...>{}(s, variant);
|
||||
}
|
||||
|
||||
|
||||
template <typename Variant, typename T, typename... Ts>
|
||||
struct bt_deserialize_try_variant_impl<std::enable_if_t<is_bt_deserializable<T>::value>, Variant, T, Ts...> {
|
||||
void operator()(string_view& s, Variant& variant) {
|
||||
if ( is_bt_output_list_container<T>::value ? s[0] == 'l' :
|
||||
is_bt_output_dict_container<T>::value ? s[0] == 'd' :
|
||||
std::is_integral<T>::value ? s[0] == 'i' :
|
||||
std::is_same<T, std::string>::value ? s[0] >= '0' && s[0] <= '9' :
|
||||
false) {
|
||||
T val;
|
||||
bt_deserialize<T>{}(s, val);
|
||||
variant = std::move(val);
|
||||
} else {
|
||||
bt_deserialize_try_variant<Ts...>(s, variant);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Variant, typename T, typename... Ts>
|
||||
struct bt_deserialize_try_variant_impl<std::enable_if_t<!is_bt_deserializable<T>::value>, Variant, T, Ts...> {
|
||||
void operator()(string_view& s, Variant& variant) {
|
||||
// Unsupported deserialization type, skip it
|
||||
bt_deserialize_try_variant<Ts...>(s, variant);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct bt_deserialize<bt_value, void> {
|
||||
void operator()(string_view& s, bt_value& val);
|
||||
};
|
||||
|
||||
template <typename... Ts>
|
||||
struct bt_serialize<mapbox::util::variant<Ts...>> {
|
||||
void operator()(std::ostream& os, const mapbox::util::variant<Ts...>& val) {
|
||||
mapbox::util::apply_visitor(bt_serialize_visitor{os}, val);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename... Ts>
|
||||
struct bt_deserialize<mapbox::util::variant<Ts...>> {
|
||||
void operator()(string_view& s, mapbox::util::variant<Ts...>& val) {
|
||||
bt_deserialize_try_variant<Ts...>(s, val);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
#ifdef __cpp_lib_variant
|
||||
/// C++17 std::variant support
|
||||
template <typename... Ts>
|
||||
struct bt_serialize<std::variant<Ts...>> {
|
||||
void operator()(std::ostream &os, const std::variant<Ts...>& val) {
|
||||
mapbox::util::apply_visitor(bt_serialize_visitor{os}, val);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename... Ts>
|
||||
struct bt_deserialize<std::variant<Ts...>> {
|
||||
void operator()(string_view& s, std::variant<Ts...>& val) {
|
||||
bt_deserialize_try_variant<Ts...>(s, val);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
struct bt_stream_serializer {
|
||||
const T &val;
|
||||
explicit bt_stream_serializer(const T &val) : val{val} {}
|
||||
operator std::string() const {
|
||||
std::ostringstream oss;
|
||||
oss << *this;
|
||||
return oss.str();
|
||||
}
|
||||
};
|
||||
template <typename T>
|
||||
std::ostream &operator<<(std::ostream &os, const bt_stream_serializer<T> &s) {
|
||||
bt_serialize<T>{}(os, s.val);
|
||||
return os;
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
|
||||
/// Returns a wrapper around a value reference that can serialize the value directly to an output
|
||||
/// stream. This class is intended to be used inline (i.e. without being stored) as in:
|
||||
///
|
||||
/// std::list<int> my_list{{1,2,3}};
|
||||
/// std::cout << bt_serializer(my_list);
|
||||
///
|
||||
/// While it is possible to store the returned object and use it, such as:
|
||||
///
|
||||
/// auto encoded = bt_serializer(42);
|
||||
/// std::cout << encoded;
|
||||
///
|
||||
/// this approach is not generally recommended: the returned object stores a reference to the
|
||||
/// passed-in type, which may not survive. If doing this note that it is the caller's
|
||||
/// responsibility to ensure the serializer is not used past the end of the lifetime of the value
|
||||
/// being serialized.
|
||||
///
|
||||
/// Also note that serializing directly to an output stream is more efficient as no intermediate
|
||||
/// string containing the entire serialization has to be constructed.
|
||||
///
|
||||
template <typename T>
|
||||
detail::bt_stream_serializer<T> bt_serializer(const T &val) { return detail::bt_stream_serializer<T>{val}; }
|
||||
|
||||
/// Serializes the given value into a std::string.
|
||||
///
|
||||
/// int number = 42;
|
||||
/// std::string encoded = bt_serialize(number);
|
||||
/// // Equivalent:
|
||||
/// //auto encoded = (std::string) bt_serialize(number);
|
||||
///
|
||||
/// This takes any serializable type: integral types, strings, lists of serializable types, and
|
||||
/// string->value maps of serializable types.
|
||||
template <typename T>
|
||||
std::string bt_serialize(const T &val) { return bt_serializer(val); }
|
||||
|
||||
/// Deserializes the given string view directly into `val`. Usage:
|
||||
///
|
||||
/// std::string encoded = "i42e";
|
||||
/// int value;
|
||||
/// bt_deserialize(encoded, value); // Sets value to 42
|
||||
///
|
||||
template <typename T, std::enable_if_t<!std::is_const<T>::value, int> = 0>
|
||||
void bt_deserialize(string_view s, T& val) {
|
||||
return detail::bt_deserialize<T>{}(s, val);
|
||||
}
|
||||
|
||||
|
||||
/// Deserializes the given string_view into a `T`, which is returned.
|
||||
///
|
||||
/// std::string encoded = "li1ei2ei3ee"; // bt-encoded list of ints: [1,2,3]
|
||||
/// auto mylist = bt_deserialize<std::list<int>>(encoded);
|
||||
///
|
||||
template <typename T>
|
||||
T bt_deserialize(string_view s) {
|
||||
T val;
|
||||
bt_deserialize(s, val);
|
||||
return val;
|
||||
}
|
||||
|
||||
/// Deserializes the given value into a generic `bt_value` type (mapbox::util::variant) which is capable
|
||||
/// of holding all possible BT-encoded values (including recursion).
|
||||
///
|
||||
/// Example:
|
||||
///
|
||||
/// std::string encoded = "i42e";
|
||||
/// auto val = bt_get(encoded);
|
||||
/// int v = get_int<int>(val); // fails unless the encoded value was actually an integer that
|
||||
/// // fits into an `int`
|
||||
///
|
||||
inline bt_value bt_get(string_view s) {
|
||||
return bt_deserialize<bt_value>(s);
|
||||
}
|
||||
|
||||
/// Helper functions to extract a value of some integral type from a bt_value which contains an
|
||||
/// integer. Does range checking, throwing std::overflow_error if the stored value is outside the
|
||||
/// range of the target type.
|
||||
///
|
||||
/// Example:
|
||||
///
|
||||
/// std::string encoded = "i123456789e";
|
||||
/// auto val = bt_get(encoded);
|
||||
/// auto v = get_int<uint32_t>(val); // throws if the decoded value doesn't fit in a uint32_t
|
||||
template <typename IntType, std::enable_if_t<std::is_integral<IntType>::value, int> = 0>
|
||||
IntType get_int(const bt_value &v) {
|
||||
// It's highly unlikely that this code ever runs on a non-2s-complement architecture, but check
|
||||
// at compile time if converting to a uint64_t (because while int64_t -> uint64_t is
|
||||
// well-defined, uint64_t -> int64_t only does the right thing under 2's complement).
|
||||
static_assert(!std::is_unsigned<IntType>::value || sizeof(IntType) != sizeof(int64_t) || -1 == ~0,
|
||||
"Non 2s-complement architecture not supported!");
|
||||
int64_t value = mapbox::util::get<int64_t>(v);
|
||||
if (sizeof(IntType) < sizeof(int64_t)) {
|
||||
if (value > static_cast<int64_t>(std::numeric_limits<IntType>::max())
|
||||
|| value < static_cast<int64_t>(std::numeric_limits<IntType>::min()))
|
||||
throw std::overflow_error("Unable to extract integer value: stored value is outside the range of the requested type");
|
||||
}
|
||||
return static_cast<IntType>(value);
|
||||
}
|
||||
|
||||
/// Class that allows you to walk through a bt-encoded list in memory without copying or allocating
|
||||
/// memory. It accesses existing memory directly and so the caller must ensure that the referenced
|
||||
/// memory stays valid for the lifetime of the bt_list_consumer object.
|
||||
class bt_list_consumer {
|
||||
protected:
|
||||
string_view data;
|
||||
bt_list_consumer() = default;
|
||||
public:
|
||||
bt_list_consumer(string_view data_);
|
||||
|
||||
/// Copy constructor. Making a copy copies the current position so can be used for multipass
|
||||
/// iteration through a list.
|
||||
bt_list_consumer(const bt_list_consumer&) = default;
|
||||
bt_list_consumer& operator=(const bt_list_consumer&) = default;
|
||||
|
||||
/// Returns true if the next value indicates the end of the list
|
||||
bool is_finished() const { return data.front() == 'e'; }
|
||||
/// Returns true if the next element looks like an encoded string
|
||||
bool is_string() const { return data.front() >= '0' && data.front() <= '9'; }
|
||||
/// Returns true if the next element looks like an encoded integer
|
||||
bool is_integer() const { return data.front() == 'i'; }
|
||||
/// Returns true if the next element looks like an encoded list
|
||||
bool is_list() const { return data.front() == 'l'; }
|
||||
/// Returns true if the next element looks like an encoded dict
|
||||
bool is_dict() const { return data.front() == 'd'; }
|
||||
|
||||
/// Attempt to parse the next value as a string (and advance just past it). Throws if the next
|
||||
/// value is not a string.
|
||||
std::string consume_string();
|
||||
string_view consume_string_view();
|
||||
|
||||
/// Attempts to parse the next value as an integer (and advance just past it). Throws if the
|
||||
/// next value is not an integer.
|
||||
template <typename IntType>
|
||||
IntType consume_integer() {
|
||||
if (!is_integer()) throw bt_deserialize_invalid_type{"next value is not an integer"};
|
||||
string_view next{data};
|
||||
IntType ret;
|
||||
detail::bt_deserialize<IntType>{}(next, ret);
|
||||
data = next;
|
||||
return ret;
|
||||
}
|
||||
|
||||
/// Consumes a list, return it as a list-like type. This typically requires dynamic allocation,
|
||||
/// but only has to parse the data once. Compare with consume_list_data() which allows
|
||||
/// alloc-free traversal, but requires parsing twice (if the contents are to be used).
|
||||
template <typename T = bt_list>
|
||||
T consume_list() {
|
||||
T list;
|
||||
consume_list(list);
|
||||
return list;
|
||||
}
|
||||
|
||||
/// Same as above, but takes a pre-existing list-like data type.
|
||||
template <typename T>
|
||||
void consume_list(T& list) {
|
||||
if (!is_list()) throw bt_deserialize_invalid_type{"next bt value is not a list"};
|
||||
string_view n{data};
|
||||
detail::bt_deserialize<T>{}(n, list);
|
||||
data = n;
|
||||
}
|
||||
|
||||
/// Consumes a dict, return it as a dict-like type. This typically requires dynamic allocation,
|
||||
/// but only has to parse the data once. Compare with consume_dict_data() which allows
|
||||
/// alloc-free traversal, but requires parsing twice (if the contents are to be used).
|
||||
template <typename T = bt_dict>
|
||||
T consume_dict() {
|
||||
T dict;
|
||||
consume_dict(dict);
|
||||
return dict;
|
||||
}
|
||||
|
||||
/// Same as above, but takes a pre-existing dict-like data type.
|
||||
template <typename T>
|
||||
void consume_dict(T& dict) {
|
||||
if (!is_dict()) throw bt_deserialize_invalid_type{"next bt value is not a dict"};
|
||||
string_view n{data};
|
||||
detail::bt_deserialize<T>{}(n, dict);
|
||||
data = n;
|
||||
}
|
||||
|
||||
/// Consumes a value without returning it.
|
||||
void skip_value();
|
||||
|
||||
/// Attempts to parse the next value as a list and returns the string_view that contains the
|
||||
/// entire thing. This is recursive into both lists and dicts and likely to be quite
|
||||
/// inefficient for large, nested structures (unless the values only need to be skipped but
|
||||
/// aren't separately needed). This, however, does not require dynamic memory allocation.
|
||||
string_view consume_list_data();
|
||||
|
||||
/// Attempts to parse the next value as a dict and returns the string_view that contains the
|
||||
/// entire thing. This is recursive into both lists and dicts and likely to be quite
|
||||
/// inefficient for large, nested structures (unless the values only need to be skipped but
|
||||
/// aren't separately needed). This, however, does not require dynamic memory allocation.
|
||||
string_view consume_dict_data();
|
||||
};
|
||||
|
||||
|
||||
/// Class that allows you to walk through key-value pairs of a bt-encoded dict in memory without
|
||||
/// copying or allocating memory. It accesses existing memory directly and so the caller must
|
||||
/// ensure that the referenced memory stays valid for the lifetime of the bt_dict_consumer object.
|
||||
class bt_dict_consumer : private bt_list_consumer {
|
||||
string_view key_;
|
||||
|
||||
/// Consume the key if not already consumed and there is a key present (rather than 'e').
|
||||
/// Throws exception if what should be a key isn't a string, or if the key consumes the entire
|
||||
/// data (i.e. requires that it be followed by something). Returns true if the key was consumed
|
||||
/// (either now or previously and cached).
|
||||
bool consume_key();
|
||||
|
||||
/// Clears the cached key and returns it. Must have already called consume_key directly or
|
||||
/// indirectly via one of the `is_{...}` methods.
|
||||
string_view flush_key() {
|
||||
string_view k;
|
||||
k.swap(key_);
|
||||
return k;
|
||||
}
|
||||
|
||||
public:
|
||||
bt_dict_consumer(string_view data_);
|
||||
|
||||
/// Copy constructor. Making a copy copies the current position so can be used for multipass
|
||||
/// iteration through a list.
|
||||
bt_dict_consumer(const bt_dict_consumer&) = default;
|
||||
bt_dict_consumer& operator=(const bt_dict_consumer&) = default;
|
||||
|
||||
/// Returns true if the next value indicates the end of the dict
|
||||
bool is_finished() { return !consume_key() && data.front() == 'e'; }
|
||||
/// Operator bool is an alias for `!is_finished()`
|
||||
operator bool() { return !is_finished(); }
|
||||
/// Returns true if the next value looks like an encoded string
|
||||
bool is_string() { return consume_key() && data.front() >= '0' && data.front() <= '9'; }
|
||||
/// Returns true if the next element looks like an encoded integer
|
||||
bool is_integer() { return consume_key() && data.front() == 'i'; }
|
||||
/// Returns true if the next element looks like an encoded list
|
||||
bool is_list() { return consume_key() && data.front() == 'l'; }
|
||||
/// Returns true if the next element looks like an encoded dict
|
||||
bool is_dict() { return consume_key() && data.front() == 'd'; }
|
||||
/// Returns the key of the next pair. This does not have to be called; it is also returned by
|
||||
/// all of the other consume_* methods. The value is cached whether called here or by some
|
||||
/// other method; accessing it multiple times simple accesses the cache until the next value is
|
||||
/// consumed.
|
||||
string_view key() {
|
||||
if (!consume_key())
|
||||
throw bt_deserialize_invalid{"Cannot access next key: at the end of the dict"};
|
||||
return key_;
|
||||
}
|
||||
|
||||
/// Attempt to parse the next value as a string->string pair (and advance just past it). Throws
|
||||
/// if the next value is not a string.
|
||||
std::pair<string_view, string_view> next_string();
|
||||
|
||||
/// Attempts to parse the next value as an string->integer pair (and advance just past it).
|
||||
/// Throws if the next value is not an integer.
|
||||
template <typename IntType>
|
||||
std::pair<string_view, IntType> next_integer() {
|
||||
if (!is_integer()) throw bt_deserialize_invalid_type{"next bt dict value is not an integer"};
|
||||
std::pair<string_view, IntType> ret;
|
||||
ret.second = bt_list_consumer::consume_integer<IntType>();
|
||||
ret.first = flush_key();
|
||||
return ret;
|
||||
}
|
||||
|
||||
/// Consumes a string->list pair, return it as a list-like type. This typically requires
|
||||
/// dynamic allocation, but only has to parse the data once. Compare with consume_list_data()
|
||||
/// which allows alloc-free traversal, but requires parsing twice (if the contents are to be
|
||||
/// used).
|
||||
template <typename T = bt_list>
|
||||
std::pair<string_view, T> next_list() {
|
||||
std::pair<string_view, T> pair;
|
||||
pair.first = consume_list(pair.second);
|
||||
return pair;
|
||||
}
|
||||
|
||||
/// Same as above, but takes a pre-existing list-like data type. Returns the key.
|
||||
template <typename T>
|
||||
string_view next_list(T& list) {
|
||||
if (!is_list()) throw bt_deserialize_invalid_type{"next bt value is not a list"};
|
||||
bt_list_consumer::consume_list(list);
|
||||
return flush_key();
|
||||
}
|
||||
|
||||
/// Consumes a string->dict pair, return it as a dict-like type. This typically requires
|
||||
/// dynamic allocation, but only has to parse the data once. Compare with consume_dict_data()
|
||||
/// which allows alloc-free traversal, but requires parsing twice (if the contents are to be
|
||||
/// used).
|
||||
template <typename T = bt_dict>
|
||||
std::pair<string_view, T> next_dict() {
|
||||
std::pair<string_view, T> pair;
|
||||
pair.first = consume_dict(pair.second);
|
||||
return pair;
|
||||
}
|
||||
|
||||
/// Same as above, but takes a pre-existing dict-like data type. Returns the key.
|
||||
template <typename T>
|
||||
string_view next_dict(T& dict) {
|
||||
if (!is_dict()) throw bt_deserialize_invalid_type{"next bt value is not a dict"};
|
||||
bt_list_consumer::consume_dict(dict);
|
||||
return flush_key();
|
||||
}
|
||||
|
||||
/// Attempts to parse the next value as a string->list pair and returns the string_view that
|
||||
/// contains the entire thing. This is recursive into both lists and dicts and likely to be
|
||||
/// quite inefficient for large, nested structures (unless the values only need to be skipped
|
||||
/// but aren't separately needed). This, however, does not require dynamic memory allocation.
|
||||
std::pair<string_view, string_view> next_list_data() {
|
||||
if (data.size() < 2 || !is_list()) throw bt_deserialize_invalid_type{"next bt dict value is not a list"};
|
||||
return {flush_key(), bt_list_consumer::consume_list_data()};
|
||||
}
|
||||
|
||||
/// Same as next_list_data(), but wraps the value in a bt_list_consumer for convenience
|
||||
std::pair<string_view, bt_list_consumer> next_list_consumer() { return next_list_data(); }
|
||||
|
||||
/// Attempts to parse the next value as a string->dict pair and returns the string_view that
|
||||
/// contains the entire thing. This is recursive into both lists and dicts and likely to be
|
||||
/// quite inefficient for large, nested structures (unless the values only need to be skipped
|
||||
/// but aren't separately needed). This, however, does not require dynamic memory allocation.
|
||||
std::pair<string_view, string_view> next_dict_data() {
|
||||
if (data.size() < 2 || !is_dict()) throw bt_deserialize_invalid_type{"next bt dict value is not a dict"};
|
||||
return {flush_key(), bt_list_consumer::consume_dict_data()};
|
||||
}
|
||||
|
||||
/// Same as next_dict_data(), but wraps the value in a bt_dict_consumer for convenience
|
||||
std::pair<string_view, bt_dict_consumer> next_dict_consumer() { return next_dict_data(); }
|
||||
|
||||
/// Skips ahead until we find the first key >= the given key or reach the end of the dict.
|
||||
/// Returns true if we found an exact match, false if we reached some greater value or the end.
|
||||
/// If we didn't hit the end, the next `consumer_*()` call will return the key-value pair we
|
||||
/// found (either the exact match or the first key greater than the requested key).
|
||||
///
|
||||
/// Two important notes:
|
||||
///
|
||||
/// - properly encoded bt dicts must have lexicographically sorted keys, and this method assumes
|
||||
/// that the input is correctly sorted (and thus if we find a greater value then your key does
|
||||
/// not exist).
|
||||
/// - this is irreversible; you cannot returned to skipped values without reparsing. (You *can*
|
||||
/// however, make a copy of the bt_dict_consumer before calling and use the copy to return to
|
||||
/// the pre-skipped position).
|
||||
bool skip_until(string_view find) {
|
||||
while (consume_key() && key_ < find) {
|
||||
flush_key();
|
||||
skip_value();
|
||||
}
|
||||
return key_ == find;
|
||||
}
|
||||
|
||||
/// The `consume_*` functions are wrappers around next_whatever that discard the returned key.
|
||||
///
|
||||
/// Intended for use with skip_until such as:
|
||||
///
|
||||
/// std::string value;
|
||||
/// if (d.skip_until("key"))
|
||||
/// value = d.consume_string();
|
||||
///
|
||||
|
||||
auto consume_string_view() { return next_string().second; }
|
||||
auto consume_string() { return std::string{consume_string_view()}; }
|
||||
|
||||
template <typename IntType>
|
||||
auto consume_integer() { return next_integer<IntType>().second; }
|
||||
|
||||
template <typename T = bt_list>
|
||||
auto consume_list() { return next_list<T>().second; }
|
||||
|
||||
template <typename T>
|
||||
void consume_list(T& list) { next_list(list); }
|
||||
|
||||
template <typename T = bt_dict>
|
||||
auto consume_dict() { return next_dict<T>().second; }
|
||||
|
||||
template <typename T>
|
||||
void consume_dict(T& dict) { next_dict(dict); }
|
||||
|
||||
string_view consume_list_data() { return next_list_data().second; }
|
||||
string_view consume_dict_data() { return next_dict_data().second; }
|
||||
|
||||
bt_list_consumer consume_list_consumer() { return consume_list_data(); }
|
||||
bt_dict_consumer consume_dict_consumer() { return consume_dict_data(); }
|
||||
};
|
||||
|
||||
|
||||
} // namespace lokimq
|
|
@ -1,339 +0,0 @@
|
|||
#include "lokimq.h"
|
||||
#include "lokimq-internal.h"
|
||||
#include "hex.h"
|
||||
|
||||
namespace lokimq {
|
||||
|
||||
std::ostream& operator<<(std::ostream& o, const ConnectionID& conn) {
|
||||
if (!conn.pk.empty())
|
||||
return o << (conn.sn() ? "SN " : "non-SN authenticated remote ") << to_hex(conn.pk);
|
||||
else
|
||||
return o << "unauthenticated remote [" << conn.id << "]";
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
void add_pollitem(std::vector<zmq::pollitem_t>& pollitems, zmq::socket_t& sock) {
|
||||
pollitems.emplace_back();
|
||||
auto &p = pollitems.back();
|
||||
p.socket = static_cast<void *>(sock);
|
||||
p.fd = 0;
|
||||
p.events = ZMQ_POLLIN;
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
|
||||
void LokiMQ::rebuild_pollitems() {
|
||||
pollitems.clear();
|
||||
add_pollitem(pollitems, command);
|
||||
add_pollitem(pollitems, workers_socket);
|
||||
add_pollitem(pollitems, zap_auth);
|
||||
|
||||
for (auto& s : connections)
|
||||
add_pollitem(pollitems, s);
|
||||
}
|
||||
|
||||
void LokiMQ::setup_outgoing_socket(zmq::socket_t& socket, string_view remote_pubkey) {
|
||||
if (!remote_pubkey.empty()) {
|
||||
socket.setsockopt(ZMQ_CURVE_SERVERKEY, remote_pubkey.data(), remote_pubkey.size());
|
||||
socket.setsockopt(ZMQ_CURVE_PUBLICKEY, pubkey.data(), pubkey.size());
|
||||
socket.setsockopt(ZMQ_CURVE_SECRETKEY, privkey.data(), privkey.size());
|
||||
}
|
||||
socket.setsockopt(ZMQ_HANDSHAKE_IVL, (int) HANDSHAKE_TIME.count());
|
||||
socket.setsockopt<int64_t>(ZMQ_MAXMSGSIZE, MAX_MSG_SIZE);
|
||||
if (PUBKEY_BASED_ROUTING_ID) {
|
||||
std::string routing_id;
|
||||
routing_id.reserve(33);
|
||||
routing_id += 'L'; // Prefix because routing id's starting with \0 are reserved by zmq (and our pubkey might start with \0)
|
||||
routing_id.append(pubkey.begin(), pubkey.end());
|
||||
socket.setsockopt(ZMQ_ROUTING_ID, routing_id.data(), routing_id.size());
|
||||
}
|
||||
// else let ZMQ pick a random one
|
||||
}
|
||||
|
||||
std::pair<zmq::socket_t *, std::string>
|
||||
LokiMQ::proxy_connect_sn(string_view remote, string_view connect_hint, bool optional, bool incoming_only, std::chrono::milliseconds keep_alive) {
|
||||
ConnectionID remote_cid{remote};
|
||||
auto its = peers.equal_range(remote_cid);
|
||||
peer_info* peer = nullptr;
|
||||
for (auto it = its.first; it != its.second; ++it) {
|
||||
if (incoming_only && it->second.route.empty())
|
||||
continue; // outgoing connection but we were asked to only use incoming connections
|
||||
peer = &it->second;
|
||||
break;
|
||||
}
|
||||
|
||||
if (peer) {
|
||||
LMQ_TRACE("proxy asked to connect to ", to_hex(remote), "; reusing existing connection");
|
||||
if (peer->route.empty() /* == outgoing*/) {
|
||||
if (peer->idle_expiry < keep_alive) {
|
||||
LMQ_LOG(debug, "updating existing outgoing peer connection idle expiry time from ",
|
||||
peer->idle_expiry.count(), "ms to ", keep_alive.count(), "ms");
|
||||
peer->idle_expiry = keep_alive;
|
||||
}
|
||||
peer->activity();
|
||||
}
|
||||
return {&connections[peer->conn_index], peer->route};
|
||||
} else if (optional || incoming_only) {
|
||||
LMQ_LOG(debug, "proxy asked for optional or incoming connection, but no appropriate connection exists so aborting connection attempt");
|
||||
return {nullptr, ""s};
|
||||
}
|
||||
|
||||
// No connection so establish a new one
|
||||
LMQ_LOG(debug, "proxy establishing new outbound connection to ", to_hex(remote));
|
||||
std::string addr;
|
||||
bool to_self = false && remote == pubkey; // FIXME; need to use a separate listening socket for this, otherwise we can't easily
|
||||
// tell it wasn't from a remote.
|
||||
if (to_self) {
|
||||
// special inproc connection if self that doesn't need any external connection
|
||||
addr = SN_ADDR_SELF;
|
||||
} else {
|
||||
addr = std::string{connect_hint};
|
||||
if (addr.empty())
|
||||
addr = sn_lookup(remote);
|
||||
else
|
||||
LMQ_LOG(debug, "using connection hint ", connect_hint);
|
||||
|
||||
if (addr.empty()) {
|
||||
LMQ_LOG(error, "peer lookup failed for ", to_hex(remote));
|
||||
return {nullptr, ""s};
|
||||
}
|
||||
}
|
||||
|
||||
LMQ_LOG(debug, to_hex(pubkey), " (me) connecting to ", addr, " to reach ", to_hex(remote));
|
||||
zmq::socket_t socket{context, zmq::socket_type::dealer};
|
||||
setup_outgoing_socket(socket, remote);
|
||||
socket.connect(addr);
|
||||
peer_info p{};
|
||||
p.service_node = true;
|
||||
p.pubkey = std::string{remote};
|
||||
p.conn_index = connections.size();
|
||||
p.idle_expiry = keep_alive;
|
||||
p.activity();
|
||||
conn_index_to_id.push_back(remote_cid);
|
||||
peers.emplace(std::move(remote_cid), std::move(p));
|
||||
connections.push_back(std::move(socket));
|
||||
|
||||
return {&connections.back(), ""s};
|
||||
}
|
||||
|
||||
std::pair<zmq::socket_t *, std::string> LokiMQ::proxy_connect_sn(bt_dict_consumer data) {
|
||||
string_view hint, remote_pk;
|
||||
std::chrono::milliseconds keep_alive;
|
||||
bool optional = false, incoming_only = false;
|
||||
|
||||
// Alphabetical order
|
||||
if (data.skip_until("hint"))
|
||||
hint = data.consume_string();
|
||||
if (data.skip_until("incoming"))
|
||||
incoming_only = data.consume_integer<bool>();
|
||||
if (data.skip_until("keep_alive"))
|
||||
keep_alive = std::chrono::milliseconds{data.consume_integer<uint64_t>()};
|
||||
if (data.skip_until("optional"))
|
||||
optional = data.consume_integer<bool>();
|
||||
if (!data.skip_until("pubkey"))
|
||||
throw std::runtime_error("Internal error: Invalid proxy_connect_sn command; pubkey missing");
|
||||
remote_pk = data.consume_string();
|
||||
|
||||
return proxy_connect_sn(remote_pk, hint, optional, incoming_only, keep_alive);
|
||||
}
|
||||
|
||||
template <typename Container, typename AccessIndex>
|
||||
void update_connection_indices(Container& c, size_t index, AccessIndex get_index) {
|
||||
for (auto it = c.begin(); it != c.end(); ) {
|
||||
size_t& i = get_index(*it);
|
||||
if (index == i) {
|
||||
it = c.erase(it);
|
||||
continue;
|
||||
}
|
||||
if (i > index)
|
||||
--i;
|
||||
++it;
|
||||
}
|
||||
}
|
||||
|
||||
/// Closes outgoing connections and removes all references. Note that this will invalidate
|
||||
/// iterators on the various connection containers - if you don't want that, delete it first so that
|
||||
/// the container won't contain the element being deleted.
|
||||
void LokiMQ::proxy_close_connection(size_t index, std::chrono::milliseconds linger) {
|
||||
connections[index].setsockopt<int>(ZMQ_LINGER, linger > 0ms ? linger.count() : 0);
|
||||
pollitems_stale = true;
|
||||
connections.erase(connections.begin() + index);
|
||||
|
||||
LMQ_LOG(debug, "Closing conn index ", index);
|
||||
update_connection_indices(peers, index,
|
||||
[](auto& p) -> size_t& { return p.second.conn_index; });
|
||||
update_connection_indices(pending_connects, index,
|
||||
[](auto& pc) -> size_t& { return std::get<size_t>(pc); });
|
||||
update_connection_indices(bind, index,
|
||||
[](auto& b) -> size_t& { return b.second.index; });
|
||||
update_connection_indices(incoming_conn_index, index,
|
||||
[](auto& oci) -> size_t& { return oci.second; });
|
||||
assert(index < conn_index_to_id.size());
|
||||
conn_index_to_id.erase(conn_index_to_id.begin() + index);
|
||||
}
|
||||
|
||||
void LokiMQ::proxy_expire_idle_peers() {
|
||||
for (auto it = peers.begin(); it != peers.end(); ) {
|
||||
auto &info = it->second;
|
||||
if (info.outgoing()) {
|
||||
auto idle = info.last_activity - std::chrono::steady_clock::now();
|
||||
if (idle <= info.idle_expiry) {
|
||||
++it;
|
||||
continue;
|
||||
}
|
||||
LMQ_LOG(info, "Closing outgoing connection to ", it->first, ": idle timeout reached");
|
||||
proxy_close_connection(info.conn_index, CLOSE_LINGER);
|
||||
} else {
|
||||
++it;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void LokiMQ::proxy_conn_cleanup() {
|
||||
LMQ_TRACE("starting proxy connections cleanup");
|
||||
|
||||
// Drop idle connections (if we haven't done it in a while) but *only* if we have some idle
|
||||
// general workers: if we don't have any idle workers then we may still have incoming messages which
|
||||
// we haven't processed yet and those messages might end up resetting the last activity time.
|
||||
if (static_cast<int>(workers.size()) < general_workers) {
|
||||
LMQ_TRACE("closing idle connections");
|
||||
proxy_expire_idle_peers();
|
||||
}
|
||||
|
||||
auto now = std::chrono::steady_clock::now();
|
||||
|
||||
// FIXME - check other outgoing connections to see if they died and if so purge them
|
||||
|
||||
LMQ_TRACE("Timing out pending outgoing connections");
|
||||
// Check any pending outgoing connections for timeout
|
||||
for (auto it = pending_connects.begin(); it != pending_connects.end(); ) {
|
||||
auto& pc = *it;
|
||||
if (std::get<std::chrono::steady_clock::time_point>(pc) < now) {
|
||||
job([cid = ConnectionID{std::get<long long>(pc)}, callback = std::move(std::get<ConnectFailure>(pc))] { callback(cid, "connection attempt timed out"); });
|
||||
it = pending_connects.erase(it); // Don't let the below erase it (because it invalidates iterators)
|
||||
proxy_close_connection(std::get<size_t>(pc), CLOSE_LINGER);
|
||||
} else {
|
||||
++it;
|
||||
}
|
||||
}
|
||||
|
||||
LMQ_TRACE("Timing out pending requests");
|
||||
// Remove any expired pending requests and schedule their callback with a failure
|
||||
for (auto it = pending_requests.begin(); it != pending_requests.end(); ) {
|
||||
auto& callback = it->second;
|
||||
if (callback.first < now) {
|
||||
LMQ_LOG(debug, "pending request ", to_hex(it->first), " expired, invoking callback with failure status and removing");
|
||||
job([callback = std::move(callback.second)] { callback(false, {}); });
|
||||
it = pending_requests.erase(it);
|
||||
} else {
|
||||
++it;
|
||||
}
|
||||
}
|
||||
|
||||
LMQ_TRACE("done proxy connections cleanup");
|
||||
};
|
||||
|
||||
void LokiMQ::proxy_connect_remote(bt_dict_consumer data) {
|
||||
AuthLevel auth_level = AuthLevel::none;
|
||||
long long conn_id = -1;
|
||||
ConnectSuccess on_connect;
|
||||
ConnectFailure on_failure;
|
||||
std::string remote;
|
||||
std::string remote_pubkey;
|
||||
std::chrono::milliseconds timeout = REMOTE_CONNECT_TIMEOUT;
|
||||
|
||||
if (data.skip_until("auth_level"))
|
||||
auth_level = static_cast<AuthLevel>(data.consume_integer<std::underlying_type_t<AuthLevel>>());
|
||||
if (data.skip_until("conn_id"))
|
||||
conn_id = data.consume_integer<long long>();
|
||||
if (data.skip_until("connect")) {
|
||||
auto* ptr = reinterpret_cast<ConnectSuccess*>(data.consume_integer<uintptr_t>());
|
||||
on_connect = std::move(*ptr);
|
||||
delete ptr;
|
||||
}
|
||||
if (data.skip_until("failure")) {
|
||||
auto* ptr = reinterpret_cast<ConnectFailure*>(data.consume_integer<uintptr_t>());
|
||||
on_failure = std::move(*ptr);
|
||||
delete ptr;
|
||||
}
|
||||
if (data.skip_until("pubkey")) {
|
||||
remote_pubkey = data.consume_string();
|
||||
assert(remote_pubkey.size() == 32 || remote_pubkey.empty());
|
||||
}
|
||||
if (data.skip_until("remote"))
|
||||
remote = data.consume_string();
|
||||
if (data.skip_until("timeout"))
|
||||
timeout = std::chrono::milliseconds{data.consume_integer<uint64_t>()};
|
||||
|
||||
if (conn_id == -1 || remote.empty())
|
||||
throw std::runtime_error("Internal error: CONNECT_REMOTE proxy command missing required 'conn_id' and/or 'remote' value");
|
||||
|
||||
LMQ_LOG(info, "Establishing remote connection to ", remote, remote_pubkey.empty() ? " (NULL auth)" : " via CURVE expecting pubkey " + to_hex(remote_pubkey));
|
||||
|
||||
assert(conn_index_to_id.size() == connections.size());
|
||||
|
||||
zmq::socket_t sock{context, zmq::socket_type::dealer};
|
||||
try {
|
||||
setup_outgoing_socket(sock, remote_pubkey);
|
||||
sock.connect(remote);
|
||||
} catch (const zmq::error_t &e) {
|
||||
proxy_schedule_reply_job([conn_id, on_failure=std::move(on_failure), what="connect() failed: "s+e.what()] {
|
||||
on_failure(conn_id, std::move(what));
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
connections.push_back(std::move(sock));
|
||||
LMQ_LOG(debug, "Opened new zmq socket to ", remote, ", conn_id ", conn_id, "; sending HI");
|
||||
send_direct_message(connections.back(), "HI");
|
||||
pending_connects.emplace_back(connections.size()-1, conn_id, std::chrono::steady_clock::now() + timeout,
|
||||
std::move(on_connect), std::move(on_failure));
|
||||
peer_info peer;
|
||||
peer.pubkey = std::move(remote_pubkey);
|
||||
peer.service_node = false;
|
||||
peer.auth_level = auth_level;
|
||||
peer.conn_index = connections.size() - 1;
|
||||
ConnectionID conn{conn_id, peer.pubkey};
|
||||
conn_index_to_id.push_back(conn);
|
||||
assert(connections.size() == conn_index_to_id.size());
|
||||
peer.idle_expiry = 24h * 10 * 365; // "forever"
|
||||
peer.activity();
|
||||
peers.emplace(std::move(conn), std::move(peer));
|
||||
}
|
||||
|
||||
void LokiMQ::proxy_disconnect(bt_dict_consumer data) {
|
||||
ConnectionID connid{-1};
|
||||
std::chrono::milliseconds linger = 1s;
|
||||
|
||||
if (data.skip_until("conn_id"))
|
||||
connid.id = data.consume_integer<long long>();
|
||||
if (data.skip_until("linger_ms"))
|
||||
linger = std::chrono::milliseconds(data.consume_integer<long long>());
|
||||
if (data.skip_until("pubkey"))
|
||||
connid.pk = data.consume_string();
|
||||
|
||||
if (connid.sn() && connid.pk.size() != 32)
|
||||
throw std::runtime_error("Error: invalid disconnect of SN without a valid pubkey");
|
||||
|
||||
proxy_disconnect(std::move(connid), linger);
|
||||
}
|
||||
void LokiMQ::proxy_disconnect(ConnectionID conn, std::chrono::milliseconds linger) {
|
||||
LMQ_TRACE("Disconnecting outgoing connection to ", conn);
|
||||
auto pr = peers.equal_range(conn);
|
||||
for (auto it = pr.first; it != pr.second; ++it) {
|
||||
auto& peer = it->second;
|
||||
if (peer.outgoing()) {
|
||||
LMQ_LOG(info, "Closing outgoing connection to ", conn);
|
||||
proxy_close_connection(peer.conn_index, linger);
|
||||
return;
|
||||
}
|
||||
}
|
||||
LMQ_LOG(warn, "Failed to disconnect ", conn, ": no such outgoing connection");
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
}
|
122
lokimq/hex.h
122
lokimq/hex.h
|
@ -1,122 +0,0 @@
|
|||
// Copyright (c) 2019-2020, The Loki Project
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without modification, are
|
||||
// permitted provided that the following conditions are met:
|
||||
//
|
||||
// 1. Redistributions of source code must retain the above copyright notice, this list of
|
||||
// conditions and the following disclaimer.
|
||||
//
|
||||
// 2. Redistributions in binary form must reproduce the above copyright notice, this list
|
||||
// of conditions and the following disclaimer in the documentation and/or other
|
||||
// materials provided with the distribution.
|
||||
//
|
||||
// 3. Neither the name of the copyright holder nor the names of its contributors may be
|
||||
// used to endorse or promote products derived from this software without specific
|
||||
// prior written permission.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
|
||||
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
|
||||
// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL
|
||||
// THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF
|
||||
// THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
#pragma once
|
||||
#include "string_view.h"
|
||||
#include <array>
|
||||
#include <iterator>
|
||||
#include <cassert>
|
||||
|
||||
namespace lokimq {
|
||||
|
||||
namespace detail {
|
||||
|
||||
/// Compile-time generated lookup tables hex conversion
|
||||
struct hex_table {
|
||||
char from_hex_lut[256];
|
||||
char to_hex_lut[16];
|
||||
constexpr hex_table() noexcept : from_hex_lut{}, to_hex_lut{} {
|
||||
for (unsigned char c = 0; c < 10; c++) {
|
||||
from_hex_lut[(unsigned char)('0' + c)] = 0 + c;
|
||||
to_hex_lut[ (unsigned char)( 0 + c)] = '0' + c;
|
||||
}
|
||||
for (unsigned char c = 0; c < 6; c++) {
|
||||
from_hex_lut[(unsigned char)('a' + c)] = 10 + c;
|
||||
from_hex_lut[(unsigned char)('A' + c)] = 10 + c;
|
||||
to_hex_lut[ (unsigned char)(10 + c)] = 'a' + c;
|
||||
}
|
||||
}
|
||||
constexpr char from_hex(unsigned char c) const noexcept { return from_hex_lut[c]; }
|
||||
constexpr char to_hex(unsigned char b) const noexcept { return to_hex_lut[b]; }
|
||||
} constexpr hex_lut;
|
||||
|
||||
} // namespace detail
|
||||
|
||||
/// Creates hex digits from a character sequence.
|
||||
template <typename InputIt, typename OutputIt>
|
||||
void to_hex(InputIt begin, InputIt end, OutputIt out) {
|
||||
for (; begin != end; ++begin) {
|
||||
auto c = *begin;
|
||||
*out++ = detail::hex_lut.to_hex((c & 0xf0) >> 4);
|
||||
*out++ = detail::hex_lut.to_hex(c & 0x0f);
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a hex string from an iterable, std::string-like object
|
||||
inline std::string to_hex(string_view s) {
|
||||
std::string hex;
|
||||
hex.reserve(s.size() * 2);
|
||||
to_hex(s.begin(), s.end(), std::back_inserter(hex));
|
||||
return hex;
|
||||
}
|
||||
|
||||
/// Returns true if all elements in the range are hex characters
|
||||
template <typename It>
|
||||
constexpr bool is_hex(It begin, It end) {
|
||||
for (; begin != end; ++begin) {
|
||||
if (detail::hex_lut.from_hex(*begin) == 0 && *begin != '0')
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Returns true if all elements in the string-like value are hex characters
|
||||
constexpr bool is_hex(string_view s) { return is_hex(s.begin(), s.end()); }
|
||||
|
||||
/// Convert a hex digit into its numeric (0-15) value
|
||||
constexpr char from_hex_digit(unsigned char x) noexcept {
|
||||
return detail::hex_lut.from_hex(x);
|
||||
}
|
||||
|
||||
/// Constructs a byte value from a pair of hex digits
|
||||
constexpr char from_hex_pair(unsigned char a, unsigned char b) noexcept { return (from_hex_digit(a) << 4) | from_hex_digit(b); }
|
||||
|
||||
/// Converts a sequence of hex digits to bytes. Undefined behaviour if any characters are not in
|
||||
/// [0-9a-fA-F] or if the input sequence length is not even. It is permitted for the input and
|
||||
/// output ranges to overlap as long as out is no earlier than begin.
|
||||
template <typename InputIt, typename OutputIt>
|
||||
void from_hex(InputIt begin, InputIt end, OutputIt out) {
|
||||
using std::distance;
|
||||
assert(distance(begin, end) % 2 == 0);
|
||||
while (begin != end) {
|
||||
auto a = *begin++;
|
||||
auto b = *begin++;
|
||||
*out++ = from_hex_pair(a, b);
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts hex digits from a std::string-like object into a std::string of bytes. Undefined
|
||||
/// behaviour if any characters are not in [0-9a-fA-F] or if the input sequence length is not even.
|
||||
inline std::string from_hex(string_view s) {
|
||||
std::string bytes;
|
||||
bytes.reserve(s.size() / 2);
|
||||
from_hex(s.begin(), s.end(), std::back_inserter(bytes));
|
||||
return bytes;
|
||||
}
|
||||
|
||||
}
|
109
lokimq/jobs.cpp
109
lokimq/jobs.cpp
|
@ -1,109 +0,0 @@
|
|||
#include "lokimq.h"
|
||||
#include "batch.h"
|
||||
#include "lokimq-internal.h"
|
||||
|
||||
namespace lokimq {
|
||||
|
||||
void LokiMQ::proxy_batch(detail::Batch* batch) {
|
||||
batches.insert(batch);
|
||||
const int jobs = batch->size();
|
||||
for (int i = 0; i < jobs; i++)
|
||||
batch_jobs.emplace(batch, i);
|
||||
proxy_skip_one_poll = true;
|
||||
}
|
||||
|
||||
void LokiMQ::job(std::function<void()> f) {
|
||||
auto* b = new Batch<void>;
|
||||
b->add_job(std::move(f));
|
||||
auto* baseptr = static_cast<detail::Batch*>(b);
|
||||
detail::send_control(get_control_socket(), "BATCH", bt_serialize(reinterpret_cast<uintptr_t>(baseptr)));
|
||||
}
|
||||
|
||||
void LokiMQ::proxy_schedule_reply_job(std::function<void()> f) {
|
||||
auto* b = new Batch<void>;
|
||||
b->add_job(std::move(f));
|
||||
batches.insert(b);
|
||||
reply_jobs.emplace(static_cast<detail::Batch*>(b), 0);
|
||||
proxy_skip_one_poll = true;
|
||||
}
|
||||
|
||||
void LokiMQ::proxy_run_batch_jobs(std::queue<batch_job>& jobs, const int reserved, int& active, bool reply) {
|
||||
while (!jobs.empty() &&
|
||||
(active < reserved || static_cast<int>(workers.size() - idle_workers.size()) < general_workers)) {
|
||||
proxy_run_worker(get_idle_worker().load(std::move(jobs.front()), reply));
|
||||
jobs.pop();
|
||||
active++;
|
||||
}
|
||||
}
|
||||
|
||||
// Called either within the proxy thread, or before the proxy thread has been created; actually adds
|
||||
// the timer. If the timer object hasn't been set up yet it gets set up here.
|
||||
void LokiMQ::proxy_timer(std::function<void()> job, std::chrono::milliseconds interval, bool squelch) {
|
||||
if (!timers)
|
||||
timers.reset(zmq_timers_new());
|
||||
|
||||
int timer_id = zmq_timers_add(timers.get(),
|
||||
interval.count(),
|
||||
[](int timer_id, void* self) { static_cast<LokiMQ*>(self)->_queue_timer_job(timer_id); },
|
||||
this);
|
||||
if (timer_id == -1)
|
||||
throw zmq::error_t{};
|
||||
timer_jobs[timer_id] = std::make_tuple(std::move(job), squelch, false);
|
||||
}
|
||||
|
||||
void LokiMQ::proxy_timer(bt_list_consumer timer_data) {
|
||||
std::unique_ptr<std::function<void()>> func{reinterpret_cast<std::function<void()>*>(timer_data.consume_integer<uintptr_t>())};
|
||||
auto interval = std::chrono::milliseconds{timer_data.consume_integer<uint64_t>()};
|
||||
auto squelch = timer_data.consume_integer<bool>();
|
||||
if (!timer_data.is_finished())
|
||||
throw std::runtime_error("Internal error: proxied timer request contains unexpected data");
|
||||
proxy_timer(std::move(*func), interval, squelch);
|
||||
}
|
||||
|
||||
void LokiMQ::_queue_timer_job(int timer_id) {
|
||||
auto it = timer_jobs.find(timer_id);
|
||||
if (it == timer_jobs.end()) {
|
||||
LMQ_LOG(warn, "Could not find timer job ", timer_id);
|
||||
return;
|
||||
}
|
||||
auto& timer = it->second;
|
||||
auto& squelch = std::get<1>(timer);
|
||||
auto& running = std::get<2>(timer);
|
||||
if (squelch && running) {
|
||||
LMQ_LOG(debug, "Not running timer job ", timer_id, " because a job for that timer is still running");
|
||||
return;
|
||||
}
|
||||
|
||||
auto* b = new Batch<void>;
|
||||
b->add_job(std::get<0>(timer));
|
||||
if (squelch) {
|
||||
running = true;
|
||||
b->completion_proxy([this,timer_id](auto results) {
|
||||
try { results[0].get(); }
|
||||
catch (const std::exception &e) { LMQ_LOG(warn, "timer job ", timer_id, " raised an exception: ", e.what()); }
|
||||
catch (...) { LMQ_LOG(warn, "timer job ", timer_id, " raised a non-std exception"); }
|
||||
auto it = timer_jobs.find(timer_id);
|
||||
if (it != timer_jobs.end())
|
||||
std::get<2>(it->second)/*running*/ = false;
|
||||
});
|
||||
}
|
||||
batches.insert(b);
|
||||
batch_jobs.emplace(static_cast<detail::Batch*>(b), 0);
|
||||
assert(b->size() == 1);
|
||||
}
|
||||
|
||||
void LokiMQ::add_timer(std::function<void()> job, std::chrono::milliseconds interval, bool squelch) {
|
||||
if (proxy_thread.joinable()) {
|
||||
auto *jobptr = new std::function<void()>{std::move(job)};
|
||||
detail::send_control(get_control_socket(), "TIMER", bt_serialize(bt_list{{
|
||||
reinterpret_cast<uintptr_t>(jobptr),
|
||||
interval.count(),
|
||||
squelch}}));
|
||||
} else {
|
||||
proxy_timer(std::move(job), interval, squelch);
|
||||
}
|
||||
}
|
||||
|
||||
void LokiMQ::TimersDeleter::operator()(void* timers) { zmq_timers_destroy(&timers); }
|
||||
|
||||
}
|
1200
lokimq/lokimq.h
1200
lokimq/lokimq.h
File diff suppressed because it is too large
Load diff
|
@ -1,54 +0,0 @@
|
|||
#pragma once
|
||||
#include <vector>
|
||||
#include "connections.h"
|
||||
|
||||
namespace lokimq {
|
||||
|
||||
class LokiMQ;
|
||||
|
||||
/// Encapsulates an incoming message from a remote connection with message details plus extra
|
||||
/// info need to send a reply back through the proxy thread via the `reply()` method. Note that
|
||||
/// this object gets reused: callbacks should use but not store any reference beyond the callback.
|
||||
class Message {
|
||||
public:
|
||||
LokiMQ& lokimq; ///< The owning LokiMQ object
|
||||
std::vector<string_view> data; ///< The provided command data parts, if any.
|
||||
ConnectionID conn; ///< The connection info for routing a reply; also contains the pubkey/sn status.
|
||||
std::string reply_tag; ///< If the invoked command is a request command this is the required reply tag that will be prepended by `send_reply()`.
|
||||
|
||||
/// Constructor
|
||||
Message(LokiMQ& lmq, ConnectionID cid) : lokimq{lmq}, conn{std::move(cid)} {}
|
||||
|
||||
// Non-copyable
|
||||
Message(const Message&) = delete;
|
||||
Message& operator=(const Message&) = delete;
|
||||
|
||||
/// Sends a command back to whomever sent this message. Arguments are forwarded to send() but
|
||||
/// with send_option::optional{} added if the originator is not a SN. For SN messages (i.e.
|
||||
/// where `sn` is true) this is a "strong" reply by default in that the proxy will attempt to
|
||||
/// establish a new connection to the SN if no longer connected. For non-SN messages the reply
|
||||
/// will be attempted using the available routing information, but if the connection has already
|
||||
/// been closed the reply will be dropped.
|
||||
///
|
||||
/// If you want to send a non-strong reply even when the remote is a service node then add
|
||||
/// an explicit `send_option::optional()` argument.
|
||||
template <typename... Args>
|
||||
void send_back(string_view, Args&&... args);
|
||||
|
||||
/// Sends a reply to a request. This takes no command: the command is always the built-in
|
||||
/// "REPLY" command, followed by the unique reply tag, then any reply data parts. All other
|
||||
/// arguments are as in `send_back()`. You should only send one reply for a command expecting
|
||||
/// replies, though this is not enforced: attempting to send multiple replies will simply be
|
||||
/// dropped when received by the remote. (Note, however, that it is possible to send multiple
|
||||
/// messages -- e.g. you could send a reply and then also call send_back() and/or send_request()
|
||||
/// to send more requests back to the sender).
|
||||
template <typename... Args>
|
||||
void send_reply(Args&&... args);
|
||||
|
||||
/// Sends a request back to whomever sent this message. This is effectively a wrapper around
|
||||
/// lmq.request() that takes care of setting up the recipient arguments.
|
||||
template <typename ReplyCallback, typename... Args>
|
||||
void send_request(string_view cmd, ReplyCallback&& callback, Args&&... args);
|
||||
};
|
||||
|
||||
}
|
564
lokimq/proxy.cpp
564
lokimq/proxy.cpp
|
@ -1,564 +0,0 @@
|
|||
#include "lokimq.h"
|
||||
#include "lokimq-internal.h"
|
||||
#include "hex.h"
|
||||
|
||||
namespace lokimq {
|
||||
|
||||
void LokiMQ::proxy_quit() {
|
||||
LMQ_LOG(debug, "Received quit command, shutting down proxy thread");
|
||||
|
||||
assert(std::none_of(workers.begin(), workers.end(), [](auto& worker) { return worker.worker_thread.joinable(); }));
|
||||
|
||||
command.setsockopt<int>(ZMQ_LINGER, 0);
|
||||
command.close();
|
||||
{
|
||||
std::lock_guard<std::mutex> lock{control_sockets_mutex};
|
||||
for (auto &control : thread_control_sockets)
|
||||
control->close();
|
||||
proxy_shutting_down = true; // To prevent threads from opening new control sockets
|
||||
}
|
||||
workers_socket.close();
|
||||
int linger = std::chrono::milliseconds{CLOSE_LINGER}.count();
|
||||
for (auto& s : connections)
|
||||
s.setsockopt(ZMQ_LINGER, linger);
|
||||
connections.clear();
|
||||
peers.clear();
|
||||
|
||||
LMQ_LOG(debug, "Proxy thread teardown complete");
|
||||
}
|
||||
|
||||
void LokiMQ::proxy_send(bt_dict_consumer data) {
|
||||
// NB: bt_dict_consumer goes in alphabetical order
|
||||
string_view hint;
|
||||
std::chrono::milliseconds keep_alive{DEFAULT_SEND_KEEP_ALIVE};
|
||||
std::chrono::milliseconds request_timeout{DEFAULT_REQUEST_TIMEOUT};
|
||||
bool optional = false;
|
||||
bool incoming = false;
|
||||
bool request = false;
|
||||
bool have_conn_id = false;
|
||||
ConnectionID conn_id;
|
||||
|
||||
std::string request_tag;
|
||||
ReplyCallback request_callback;
|
||||
if (data.skip_until("conn_id")) {
|
||||
conn_id.id = data.consume_integer<long long>();
|
||||
if (conn_id.id == -1)
|
||||
throw std::runtime_error("Invalid error: invalid conn_id value (-1)");
|
||||
have_conn_id = true;
|
||||
}
|
||||
if (data.skip_until("conn_pubkey")) {
|
||||
if (have_conn_id)
|
||||
throw std::runtime_error("Internal error: Invalid proxy send command; conn_id and conn_pubkey are exclusive");
|
||||
conn_id.pk = data.consume_string();
|
||||
conn_id.id = ConnectionID::SN_ID;
|
||||
} else if (!have_conn_id)
|
||||
throw std::runtime_error("Internal error: Invalid proxy send command; conn_pubkey or conn_id missing");
|
||||
if (data.skip_until("conn_route"))
|
||||
conn_id.route = data.consume_string();
|
||||
if (data.skip_until("hint"))
|
||||
hint = data.consume_string_view();
|
||||
if (data.skip_until("incoming"))
|
||||
incoming = data.consume_integer<bool>();
|
||||
if (data.skip_until("keep_alive"))
|
||||
keep_alive = std::chrono::milliseconds{data.consume_integer<uint64_t>()};
|
||||
if (data.skip_until("optional"))
|
||||
optional = data.consume_integer<bool>();
|
||||
|
||||
if (data.skip_until("request"))
|
||||
request = data.consume_integer<bool>();
|
||||
if (request) {
|
||||
if (!data.skip_until("request_callback"))
|
||||
throw std::runtime_error("Internal error: received request without request_callback");
|
||||
|
||||
// The initiator gives up ownership of the callback to us (serializing it through a
|
||||
// uintptr_t), so we take the pointer, move the value out of it, then destroy the pointer we
|
||||
// were given. Further down, if we are able to send the request successfully, we set up the
|
||||
// pending request.
|
||||
auto* cbptr = reinterpret_cast<ReplyCallback*>(data.consume_integer<uintptr_t>());
|
||||
request_callback = std::move(*cbptr);
|
||||
delete cbptr;
|
||||
|
||||
if (!data.skip_until("request_tag"))
|
||||
throw std::runtime_error("Internal error: received request without request_name");
|
||||
request_tag = data.consume_string();
|
||||
if (data.skip_until("request_timeout"))
|
||||
request_timeout = std::chrono::milliseconds{data.consume_integer<uint64_t>()};
|
||||
}
|
||||
if (!data.skip_until("send"))
|
||||
throw std::runtime_error("Internal error: Invalid proxy send command; send parts missing");
|
||||
bt_list_consumer send = data.consume_list_consumer();
|
||||
|
||||
// Now figure out which socket to send to and do the actual sending. We can repeat this loop
|
||||
// multiple times, if we're sending to a SN, because it's possible that we have multiple
|
||||
// connections open to that SN (e.g. one out + one in) so if one fails we can clean up that
|
||||
// connection and try the next one.
|
||||
bool retry = true, sent = false;
|
||||
while (retry) {
|
||||
retry = false;
|
||||
zmq::socket_t *send_to;
|
||||
if (conn_id.sn()) {
|
||||
auto sock_route = proxy_connect_sn(conn_id.pk, hint, optional, incoming, keep_alive);
|
||||
if (!sock_route.first) {
|
||||
if (optional)
|
||||
LMQ_LOG(debug, "Not sending: send is optional and no connection to ",
|
||||
to_hex(conn_id.pk), " is currently established");
|
||||
else
|
||||
LMQ_LOG(error, "Unable to send to ", to_hex(conn_id.pk), ": no connection address found");
|
||||
break;
|
||||
}
|
||||
send_to = sock_route.first;
|
||||
conn_id.route = std::move(sock_route.second);
|
||||
} else if (!conn_id.route.empty()) { // incoming non-SN connection
|
||||
auto it = incoming_conn_index.find(conn_id);
|
||||
if (it == incoming_conn_index.end()) {
|
||||
LMQ_LOG(warn, "Unable to send to ", conn_id, ": incoming listening socket not found");
|
||||
break;
|
||||
}
|
||||
send_to = &connections[it->second];
|
||||
} else {
|
||||
auto pr = peers.equal_range(conn_id);
|
||||
if (pr.first == peers.end()) {
|
||||
LMQ_LOG(warn, "Unable to send: connection id ", conn_id, " is not (or is no longer) a valid outgoing connection");
|
||||
break;
|
||||
}
|
||||
auto& peer = pr.first->second;
|
||||
send_to = &connections[peer.conn_index];
|
||||
}
|
||||
|
||||
try {
|
||||
send_message_parts(*send_to, build_send_parts(send, conn_id.route));
|
||||
sent = true;
|
||||
} catch (const zmq::error_t &e) {
|
||||
if (e.num() == EHOSTUNREACH && !conn_id.route.empty() /*= incoming conn*/) {
|
||||
|
||||
LMQ_LOG(debug, "Incoming connection is no longer valid; removing peer details");
|
||||
|
||||
auto pr = peers.equal_range(conn_id);
|
||||
if (pr.first != peers.end()) {
|
||||
if (!conn_id.sn()) {
|
||||
peers.erase(pr.first);
|
||||
} else {
|
||||
bool removed;
|
||||
for (auto it = pr.first; it != pr.second; ) {
|
||||
auto& peer = it->second;
|
||||
if (peer.route == conn_id.route) {
|
||||
peers.erase(it);
|
||||
removed = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
// The incoming connection to the SN is no longer good, but we can retry because
|
||||
// we may have another active connection with the SN (or may want to open one).
|
||||
if (removed) {
|
||||
LMQ_LOG(debug, "Retrying sending to SN ", to_hex(conn_id.pk), " using other sockets");
|
||||
retry = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!retry) {
|
||||
LMQ_LOG(warn, "Unable to send message to ", conn_id, ": ", e.what());
|
||||
}
|
||||
}
|
||||
}
|
||||
if (request) {
|
||||
if (sent) {
|
||||
LMQ_LOG(debug, "Added new pending request ", to_hex(request_tag));
|
||||
pending_requests.insert({ request_tag, {
|
||||
std::chrono::steady_clock::now() + request_timeout, std::move(request_callback) }});
|
||||
} else {
|
||||
LMQ_LOG(debug, "Could not send request, scheduling request callback failure");
|
||||
job([callback = std::move(request_callback)] { callback(false, {}); });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void LokiMQ::proxy_reply(bt_dict_consumer data) {
|
||||
bool have_conn_id = false;
|
||||
ConnectionID conn_id{0};
|
||||
if (data.skip_until("conn_id")) {
|
||||
conn_id.id = data.consume_integer<long long>();
|
||||
if (conn_id.id == -1)
|
||||
throw std::runtime_error("Invalid error: invalid conn_id value (-1)");
|
||||
have_conn_id = true;
|
||||
}
|
||||
if (data.skip_until("conn_pubkey")) {
|
||||
if (have_conn_id)
|
||||
throw std::runtime_error("Internal error: Invalid proxy reply command; conn_id and conn_pubkey are exclusive");
|
||||
conn_id.pk = data.consume_string();
|
||||
conn_id.id = ConnectionID::SN_ID;
|
||||
} else if (!have_conn_id)
|
||||
throw std::runtime_error("Internal error: Invalid proxy reply command; conn_pubkey or conn_id missing");
|
||||
if (!data.skip_until("send"))
|
||||
throw std::runtime_error("Internal error: Invalid proxy reply command; send parts missing");
|
||||
|
||||
bt_list_consumer send = data.consume_list_consumer();
|
||||
|
||||
auto pr = peers.equal_range(conn_id);
|
||||
if (pr.first == pr.second) {
|
||||
LMQ_LOG(warn, "Unable to send tagged reply: the connection is no longer valid");
|
||||
return;
|
||||
}
|
||||
|
||||
// We try any connections until one works (for ordinary remotes there will be just one, but for
|
||||
// SNs there might be one incoming and one outgoing).
|
||||
for (auto it = pr.first; it != pr.second; ) {
|
||||
try {
|
||||
send_message_parts(connections[it->second.conn_index], build_send_parts(send, it->second.route));
|
||||
break;
|
||||
} catch (const zmq::error_t &err) {
|
||||
if (err.num() == EHOSTUNREACH) {
|
||||
LMQ_LOG(info, "Unable to send reply to incoming non-SN request: remote is no longer connected");
|
||||
LMQ_LOG(debug, "Incoming connection is no longer valid; removing peer details");
|
||||
it = peers.erase(it);
|
||||
} else {
|
||||
LMQ_LOG(warn, "Unable to send reply to incoming non-SN request: ", err.what());
|
||||
++it;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void LokiMQ::proxy_control_message(std::vector<zmq::message_t>& parts) {
|
||||
if (parts.size() < 2)
|
||||
throw std::logic_error("Expected 2-3 message parts for a proxy control message");
|
||||
auto route = view(parts[0]), cmd = view(parts[1]);
|
||||
LMQ_TRACE("control message: ", cmd);
|
||||
if (parts.size() == 3) {
|
||||
LMQ_TRACE("...: ", parts[2]);
|
||||
if (cmd == "SEND") {
|
||||
LMQ_TRACE("proxying message");
|
||||
return proxy_send(view(parts[2]));
|
||||
} else if (cmd == "REPLY") {
|
||||
LMQ_TRACE("proxying reply to non-SN incoming message");
|
||||
return proxy_reply(view(parts[2]));
|
||||
} else if (cmd == "BATCH") {
|
||||
LMQ_TRACE("proxy batch jobs");
|
||||
auto ptrval = bt_deserialize<uintptr_t>(view(parts[2]));
|
||||
return proxy_batch(reinterpret_cast<detail::Batch*>(ptrval));
|
||||
} else if (cmd == "CONNECT_SN") {
|
||||
proxy_connect_sn(view(parts[2]));
|
||||
return;
|
||||
} else if (cmd == "CONNECT_REMOTE") {
|
||||
return proxy_connect_remote(view(parts[2]));
|
||||
} else if (cmd == "DISCONNECT") {
|
||||
return proxy_disconnect(view(parts[2]));
|
||||
} else if (cmd == "TIMER") {
|
||||
return proxy_timer(view(parts[2]));
|
||||
}
|
||||
} else if (parts.size() == 2) {
|
||||
if (cmd == "START") {
|
||||
// Command send by the owning thread during startup; we send back a simple READY reply to
|
||||
// let it know we are running.
|
||||
return route_control(command, route, "READY");
|
||||
} else if (cmd == "QUIT") {
|
||||
// Asked to quit: set max_workers to zero and tell any idle ones to quit. We will
|
||||
// close workers as they come back to READY status, and then close external
|
||||
// connections once all workers are done.
|
||||
max_workers = 0;
|
||||
for (const auto &route : idle_workers)
|
||||
route_control(workers_socket, workers[route].worker_routing_id, "QUIT");
|
||||
idle_workers.clear();
|
||||
return;
|
||||
}
|
||||
}
|
||||
throw std::runtime_error("Proxy received invalid control command: " + std::string{cmd} +
|
||||
" (" + std::to_string(parts.size()) + ")");
|
||||
}
|
||||
|
||||
void LokiMQ::proxy_loop() {
|
||||
|
||||
zap_auth.setsockopt<int>(ZMQ_LINGER, 0);
|
||||
zap_auth.bind(ZMQ_ADDR_ZAP);
|
||||
|
||||
workers_socket.setsockopt<int>(ZMQ_ROUTER_MANDATORY, 1);
|
||||
workers_socket.bind(SN_ADDR_WORKERS);
|
||||
|
||||
assert(general_workers > 0);
|
||||
if (batch_jobs_reserved < 0)
|
||||
batch_jobs_reserved = (general_workers + 1) / 2;
|
||||
if (reply_jobs_reserved < 0)
|
||||
reply_jobs_reserved = (general_workers + 7) / 8;
|
||||
|
||||
max_workers = general_workers + batch_jobs_reserved + reply_jobs_reserved;
|
||||
for (const auto& cat : categories) {
|
||||
max_workers += cat.second.reserved_threads;
|
||||
}
|
||||
|
||||
#ifndef NDEBUG
|
||||
if (log_level() >= LogLevel::trace) {
|
||||
LMQ_TRACE("Reserving space for ", max_workers, " max workers = ", general_workers, " general plus reservations for:");
|
||||
for (const auto& cat : categories)
|
||||
LMQ_TRACE(" - ", cat.first, ": ", cat.second.reserved_threads);
|
||||
LMQ_TRACE(" - (batch jobs): ", batch_jobs_reserved);
|
||||
LMQ_TRACE(" - (reply jobs): ", reply_jobs_reserved);
|
||||
}
|
||||
#endif
|
||||
|
||||
workers.reserve(max_workers);
|
||||
if (!workers.empty())
|
||||
throw std::logic_error("Internal error: proxy thread started with active worker threads");
|
||||
|
||||
for (size_t i = 0; i < bind.size(); i++) {
|
||||
auto& b = bind[i].second;
|
||||
zmq::socket_t listener{context, zmq::socket_type::router};
|
||||
|
||||
std::string auth_domain = bt_serialize(i);
|
||||
listener.setsockopt(ZMQ_ZAP_DOMAIN, auth_domain.c_str(), auth_domain.size());
|
||||
if (b.curve) {
|
||||
listener.setsockopt<int>(ZMQ_CURVE_SERVER, 1);
|
||||
listener.setsockopt(ZMQ_CURVE_PUBLICKEY, pubkey.data(), pubkey.size());
|
||||
listener.setsockopt(ZMQ_CURVE_SECRETKEY, privkey.data(), privkey.size());
|
||||
}
|
||||
listener.setsockopt(ZMQ_HANDSHAKE_IVL, (int) HANDSHAKE_TIME.count());
|
||||
listener.setsockopt<int64_t>(ZMQ_MAXMSGSIZE, MAX_MSG_SIZE);
|
||||
listener.setsockopt<int>(ZMQ_ROUTER_HANDOVER, 1);
|
||||
listener.setsockopt<int>(ZMQ_ROUTER_MANDATORY, 1);
|
||||
|
||||
listener.bind(bind[i].first);
|
||||
LMQ_LOG(info, "LokiMQ listening on ", bind[i].first);
|
||||
|
||||
connections.push_back(std::move(listener));
|
||||
auto conn_id = next_conn_id++;
|
||||
conn_index_to_id.push_back(conn_id);
|
||||
incoming_conn_index[conn_id] = connections.size() - 1;
|
||||
b.index = connections.size() - 1;
|
||||
}
|
||||
pollitems_stale = true;
|
||||
|
||||
// Also add an internal connection to self so that calling code can avoid needing to
|
||||
// special-case rare situations where we are supposed to talk to a quorum member that happens to
|
||||
// be ourselves (which can happen, for example, with cross-quoum Blink communication)
|
||||
// FIXME: not working
|
||||
//listener.bind(SN_ADDR_SELF);
|
||||
|
||||
if (!timers)
|
||||
timers.reset(zmq_timers_new());
|
||||
|
||||
auto do_conn_cleanup = [this] { proxy_conn_cleanup(); };
|
||||
using CleanupLambda = decltype(do_conn_cleanup);
|
||||
if (-1 == zmq_timers_add(timers.get(),
|
||||
std::chrono::milliseconds{CONN_CHECK_INTERVAL}.count(),
|
||||
// Wrap our lambda into a C function pointer where we pass in the lambda pointer as extra arg
|
||||
[](int /*timer_id*/, void* cleanup) { (*static_cast<CleanupLambda*>(cleanup))(); },
|
||||
&do_conn_cleanup)) {
|
||||
throw zmq::error_t{};
|
||||
}
|
||||
|
||||
std::vector<zmq::message_t> parts;
|
||||
|
||||
while (true) {
|
||||
std::chrono::milliseconds poll_timeout;
|
||||
if (max_workers == 0) { // Will be 0 only if we are quitting
|
||||
if (std::none_of(workers.begin(), workers.end(), [](auto &w) { return w.worker_thread.joinable(); })) {
|
||||
// All the workers have finished, so we can finish shutting down
|
||||
return proxy_quit();
|
||||
}
|
||||
poll_timeout = 1s; // We don't keep running timers when we're quitting, so don't have a timer to check
|
||||
} else {
|
||||
poll_timeout = std::chrono::milliseconds{zmq_timers_timeout(timers.get())};
|
||||
}
|
||||
|
||||
if (proxy_skip_one_poll)
|
||||
proxy_skip_one_poll = false;
|
||||
else {
|
||||
LMQ_TRACE("polling for new messages");
|
||||
|
||||
if (pollitems_stale)
|
||||
rebuild_pollitems();
|
||||
|
||||
// We poll the control socket and worker socket for any incoming messages. If we have
|
||||
// available worker room then also poll incoming connections and outgoing connections
|
||||
// for messages to forward to a worker. Otherwise, we just look for a control message
|
||||
// or a worker coming back with a ready message.
|
||||
zmq::poll(pollitems.data(), pollitems.size(), poll_timeout);
|
||||
}
|
||||
|
||||
LMQ_TRACE("processing control messages");
|
||||
// Retrieve any waiting incoming control messages
|
||||
for (parts.clear(); recv_message_parts(command, parts, zmq::recv_flags::dontwait); parts.clear()) {
|
||||
proxy_control_message(parts);
|
||||
}
|
||||
|
||||
LMQ_TRACE("processing worker messages");
|
||||
for (parts.clear(); recv_message_parts(workers_socket, parts, zmq::recv_flags::dontwait); parts.clear()) {
|
||||
proxy_worker_message(parts);
|
||||
}
|
||||
|
||||
LMQ_TRACE("processing timers");
|
||||
zmq_timers_execute(timers.get());
|
||||
|
||||
// Handle any zap authentication
|
||||
LMQ_TRACE("processing zap requests");
|
||||
process_zap_requests();
|
||||
|
||||
// See if we can drain anything from the current queue before we potentially add to it
|
||||
// below.
|
||||
LMQ_TRACE("processing queued jobs and messages");
|
||||
proxy_process_queue();
|
||||
|
||||
LMQ_TRACE("processing new incoming messages");
|
||||
|
||||
// We round-robin connections when pulling off pending messages one-by-one rather than
|
||||
// pulling off all messages from one connection before moving to the next; thus in cases of
|
||||
// contention we end up fairly distributing.
|
||||
const int num_sockets = connections.size();
|
||||
std::queue<int> queue_index;
|
||||
for (int i = 0; i < num_sockets; i++)
|
||||
queue_index.push(i);
|
||||
|
||||
for (parts.clear(); !queue_index.empty() && static_cast<int>(workers.size()) < max_workers; parts.clear()) {
|
||||
size_t i = queue_index.front();
|
||||
queue_index.pop();
|
||||
auto& sock = connections[i];
|
||||
|
||||
if (!recv_message_parts(sock, parts, zmq::recv_flags::dontwait))
|
||||
continue;
|
||||
|
||||
// We only pull this one message now but then requeue the socket so that after we check
|
||||
// all other sockets we come back to this one to check again.
|
||||
queue_index.push(i);
|
||||
|
||||
if (parts.empty()) {
|
||||
LMQ_LOG(warn, "Ignoring empty (0-part) incoming message");
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!proxy_handle_builtin(i, parts))
|
||||
proxy_to_worker(i, parts);
|
||||
|
||||
if (pollitems_stale) {
|
||||
// If our items became stale then we may have just closed a connection and so our
|
||||
// queue index maybe also be stale, so restart the proxy loop (so that we rebuild
|
||||
// pollitems).
|
||||
LMQ_TRACE("pollitems became stale; short-circuiting incoming message loop");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
LMQ_TRACE("done proxy loop");
|
||||
}
|
||||
}
|
||||
|
||||
// Return true if we recognized/handled the builtin command (even if we reject it for whatever
|
||||
// reason)
|
||||
bool LokiMQ::proxy_handle_builtin(size_t conn_index, std::vector<zmq::message_t>& parts) {
|
||||
bool outgoing = connections[conn_index].getsockopt<int>(ZMQ_TYPE) == ZMQ_DEALER;
|
||||
|
||||
string_view route, cmd;
|
||||
if (parts.size() < (outgoing ? 1 : 2)) {
|
||||
LMQ_LOG(warn, "Received empty message; ignoring");
|
||||
return true;
|
||||
}
|
||||
if (outgoing) {
|
||||
cmd = view(parts[0]);
|
||||
} else {
|
||||
route = view(parts[0]);
|
||||
cmd = view(parts[1]);
|
||||
}
|
||||
LMQ_TRACE("Checking for builtins: ", cmd, " from ", peer_address(parts.back()));
|
||||
|
||||
if (cmd == "REPLY") {
|
||||
size_t tag_pos = (outgoing ? 1 : 2);
|
||||
if (parts.size() <= tag_pos) {
|
||||
LMQ_LOG(warn, "Received REPLY without a reply tag; ignoring");
|
||||
return true;
|
||||
}
|
||||
std::string reply_tag{view(parts[tag_pos])};
|
||||
auto it = pending_requests.find(reply_tag);
|
||||
if (it != pending_requests.end()) {
|
||||
LMQ_LOG(debug, "Received REPLY for pending command", to_hex(reply_tag), "; scheduling callback");
|
||||
std::vector<std::string> data;
|
||||
data.reserve(parts.size() - (tag_pos + 1));
|
||||
for (auto it = parts.begin() + (tag_pos + 1); it != parts.end(); ++it)
|
||||
data.emplace_back(view(*it));
|
||||
proxy_schedule_reply_job([callback=std::move(it->second.second), data=std::move(data)] {
|
||||
callback(true, std::move(data));
|
||||
});
|
||||
pending_requests.erase(it);
|
||||
} else {
|
||||
LMQ_LOG(warn, "Received REPLY with unknown or already handled reply tag (", to_hex(reply_tag), "); ignoring");
|
||||
}
|
||||
return true;
|
||||
} else if (cmd == "HI") {
|
||||
if (outgoing) {
|
||||
LMQ_LOG(warn, "Got invalid 'HI' message on an outgoing connection; ignoring");
|
||||
return true;
|
||||
}
|
||||
LMQ_LOG(info, "Incoming client from ", peer_address(parts.back()), " sent HI, replying with HELLO");
|
||||
try {
|
||||
send_routed_message(connections[conn_index], std::string{route}, "HELLO");
|
||||
} catch (const std::exception &e) { LMQ_LOG(warn, "Couldn't reply with HELLO: ", e.what()); }
|
||||
return true;
|
||||
} else if (cmd == "HELLO") {
|
||||
if (!outgoing) {
|
||||
LMQ_LOG(warn, "Got invalid 'HELLO' message on an incoming connection; ignoring");
|
||||
return true;
|
||||
}
|
||||
auto it = std::find_if(pending_connects.begin(), pending_connects.end(),
|
||||
[&](auto& pc) { return std::get<size_t>(pc) == conn_index; });
|
||||
if (it == pending_connects.end()) {
|
||||
LMQ_LOG(warn, "Got invalid 'HELLO' message on an already handshaked incoming connection; ignoring");
|
||||
return true;
|
||||
}
|
||||
auto& pc = *it;
|
||||
auto pit = peers.find(std::get<long long>(pc));
|
||||
if (pit == peers.end()) {
|
||||
LMQ_LOG(warn, "Got invalid 'HELLO' message with invalid conn_id; ignoring");
|
||||
return true;
|
||||
}
|
||||
|
||||
LMQ_LOG(info, "Got initial HELLO server response from ", peer_address(parts.back()));
|
||||
proxy_schedule_reply_job([on_success=std::move(std::get<ConnectSuccess>(pc)),
|
||||
conn=conn_index_to_id[conn_index]] {
|
||||
on_success(conn);
|
||||
});
|
||||
pending_connects.erase(it);
|
||||
return true;
|
||||
} else if (cmd == "BYE") {
|
||||
if (outgoing) {
|
||||
std::string pk;
|
||||
bool sn;
|
||||
AuthLevel a;
|
||||
std::tie(pk, sn, a) = detail::extract_metadata(parts.back());
|
||||
ConnectionID conn = sn ? ConnectionID{std::move(pk)} : conn_index_to_id[conn_index];
|
||||
LMQ_LOG(info, "BYE command received; disconnecting from ", conn);
|
||||
proxy_disconnect(conn, 1s);
|
||||
} else {
|
||||
LMQ_LOG(warn, "Got invalid 'BYE' command on an incoming socket; ignoring");
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
else if (cmd == "FORBIDDEN" || cmd == "NOT_A_SERVICE_NODE") {
|
||||
return true; // FIXME - ignore these? Log?
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void LokiMQ::proxy_process_queue() {
|
||||
// First up: process any batch jobs; since these are internal they are given higher priority.
|
||||
proxy_run_batch_jobs(batch_jobs, batch_jobs_reserved, batch_jobs_active, false);
|
||||
|
||||
// Next any reply batch jobs (which are a bit different from the above, since they are
|
||||
// externally triggered but for things we initiated locally).
|
||||
proxy_run_batch_jobs(reply_jobs, reply_jobs_reserved, reply_jobs_active, true);
|
||||
|
||||
// Finally general incoming commands
|
||||
for (auto it = pending_commands.begin(); it != pending_commands.end() && active_workers() < max_workers; ) {
|
||||
auto& pending = *it;
|
||||
if (pending.cat.active_threads < pending.cat.reserved_threads
|
||||
|| active_workers() < general_workers) {
|
||||
proxy_run_worker(get_idle_worker().load(std::move(pending)));
|
||||
pending.cat.queued--;
|
||||
pending.cat.active_threads++;
|
||||
assert(pending.cat.queued >= 0);
|
||||
it = pending_commands.erase(it);
|
||||
} else {
|
||||
++it; // no available general or reserved worker spots for this job right now
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
|
@ -1,247 +0,0 @@
|
|||
// Copyright (c) 2019-2020, The Loki Project
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without modification, are
|
||||
// permitted provided that the following conditions are met:
|
||||
//
|
||||
// 1. Redistributions of source code must retain the above copyright notice, this list of
|
||||
// conditions and the following disclaimer.
|
||||
//
|
||||
// 2. Redistributions in binary form must reproduce the above copyright notice, this list
|
||||
// of conditions and the following disclaimer in the documentation and/or other
|
||||
// materials provided with the distribution.
|
||||
//
|
||||
// 3. Neither the name of the copyright holder nor the names of its contributors may be
|
||||
// used to endorse or promote products derived from this software without specific
|
||||
// prior written permission.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
|
||||
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
|
||||
// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL
|
||||
// THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF
|
||||
// THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
#ifdef __cpp_lib_string_view
|
||||
|
||||
#include <string_view>
|
||||
namespace lokimq { using string_view = std::string_view; }
|
||||
|
||||
#else
|
||||
|
||||
#include <ostream>
|
||||
#include <limits>
|
||||
|
||||
namespace lokimq {
|
||||
|
||||
/// Basic implementation of std::string_view (except for std::hash support).
|
||||
class simple_string_view {
|
||||
const char *data_;
|
||||
size_t size_;
|
||||
public:
|
||||
using traits_type = std::char_traits<char>;
|
||||
using value_type = char;
|
||||
using pointer = char*;
|
||||
using const_pointer = const char*;
|
||||
using reference = char&;
|
||||
using const_reference = const char&;
|
||||
using const_iterator = const_pointer;
|
||||
using iterator = const_iterator;
|
||||
using const_reverse_iterator = std::reverse_iterator<const_iterator>;
|
||||
using reverse_iterator = const_reverse_iterator;
|
||||
using size_type = std::size_t;
|
||||
using different_type = std::ptrdiff_t;
|
||||
|
||||
static constexpr auto& npos = std::string::npos;
|
||||
|
||||
constexpr simple_string_view() noexcept : data_{nullptr}, size_{0} {}
|
||||
constexpr simple_string_view(const simple_string_view&) noexcept = default;
|
||||
simple_string_view(const std::string& str) : data_{str.data()}, size_{str.size()} {}
|
||||
constexpr simple_string_view(const char* data, size_t size) noexcept : data_{data}, size_{size} {}
|
||||
simple_string_view(const char* data) : data_{data}, size_{traits_type::length(data)} {}
|
||||
simple_string_view& operator=(const simple_string_view&) = default;
|
||||
constexpr const char* data() const noexcept { return data_; }
|
||||
constexpr size_t size() const noexcept { return size_; }
|
||||
constexpr size_t length() const noexcept { return size_; }
|
||||
constexpr size_t max_size() const noexcept { return std::numeric_limits<size_t>::max(); }
|
||||
constexpr bool empty() const noexcept { return size_ == 0; }
|
||||
explicit operator std::string() const { return {data_, size_}; }
|
||||
constexpr const char* begin() const noexcept { return data_; }
|
||||
constexpr const char* cbegin() const noexcept { return data_; }
|
||||
constexpr const char* end() const noexcept { return data_ + size_; }
|
||||
constexpr const char* cend() const noexcept { return data_ + size_; }
|
||||
reverse_iterator rbegin() const { return reverse_iterator{end()}; }
|
||||
reverse_iterator crbegin() const { return reverse_iterator{end()}; }
|
||||
reverse_iterator rend() const { return reverse_iterator{begin()}; }
|
||||
reverse_iterator crend() const { return reverse_iterator{begin()}; }
|
||||
constexpr const char& operator[](size_t pos) const { return data_[pos]; }
|
||||
constexpr const char& front() const { return *data_; }
|
||||
constexpr const char& back() const { return data_[size_ - 1]; }
|
||||
int compare(simple_string_view s) const;
|
||||
constexpr void remove_prefix(size_t n) { data_ += n; size_ -= n; }
|
||||
constexpr void remove_suffix(size_t n) { size_ -= n; }
|
||||
void swap(simple_string_view &s) noexcept { std::swap(data_, s.data_); std::swap(size_, s.size_); }
|
||||
|
||||
#if defined(__clang__) || !defined(__GNUG__) || __GNUC__ >= 6
|
||||
constexpr // GCC 5.x is buggy wrt constexpr throwing
|
||||
#endif
|
||||
const char& at(size_t pos) const {
|
||||
if (pos >= size())
|
||||
throw std::out_of_range{"invalid string_view index"};
|
||||
return data_[pos];
|
||||
};
|
||||
|
||||
size_t copy(char* dest, size_t count, size_t pos = 0) const {
|
||||
if (pos > size()) throw std::out_of_range{"invalid copy pos"};
|
||||
size_t rcount = std::min(count, size_ - pos);
|
||||
traits_type::copy(dest, data_ + pos, rcount);
|
||||
return rcount;
|
||||
}
|
||||
|
||||
#if defined(__clang__) || !defined(__GNUG__) || __GNUC__ >= 6
|
||||
constexpr // GCC 5.x is buggy wrt constexpr throwing
|
||||
#endif
|
||||
simple_string_view substr(size_t pos = 0, size_t count = npos) const {
|
||||
if (pos > size()) throw std::out_of_range{"invalid substr range"};
|
||||
simple_string_view result = *this;
|
||||
if (pos > 0) result.remove_prefix(pos);
|
||||
if (count < result.size()) result.remove_suffix(result.size() - count);
|
||||
return result;
|
||||
}
|
||||
|
||||
size_t find(simple_string_view v, size_t pos = 0) const {
|
||||
if (pos > size_ || v.size_ > size_) return npos;
|
||||
for (const size_t max_pos = size_ - v.size_; pos <= max_pos; ++pos) {
|
||||
if (0 == traits_type::compare(v.data_, data_ + pos, v.size_))
|
||||
return pos;
|
||||
}
|
||||
return npos;
|
||||
}
|
||||
size_t find(char c, size_t pos = 0) const { return find({&c, 1}, pos); }
|
||||
size_t find(const char* c, size_t pos, size_t count) const { return find({c, count}, pos); }
|
||||
size_t find(const char* c, size_t pos = 0) const { return find(simple_string_view(c), pos); }
|
||||
|
||||
size_t rfind(simple_string_view v, size_t pos = npos) const {
|
||||
if (v.size_ > size_) return npos;
|
||||
const size_t max_pos = size_ - v.size_;
|
||||
for (pos = std::min(pos, max_pos); pos <= max_pos; --pos) {
|
||||
if (0 == traits_type::compare(v.data_, data_ + pos, v.size_))
|
||||
return pos;
|
||||
}
|
||||
return npos;
|
||||
}
|
||||
size_t rfind(char c, size_t pos = npos) const { return rfind({&c, 1}, pos); }
|
||||
size_t rfind(const char* c, size_t pos, size_t count) const { return rfind({c, count}, pos); }
|
||||
size_t rfind(const char* c, size_t pos = npos) const { return rfind(simple_string_view(c), pos); }
|
||||
|
||||
constexpr size_t find_first_of(simple_string_view v, size_t pos = 0) const noexcept {
|
||||
for (; pos < size_; ++pos)
|
||||
for (char c : v)
|
||||
if (data_[pos] == c)
|
||||
return pos;
|
||||
return npos;
|
||||
}
|
||||
constexpr size_t find_first_of(char c, size_t pos = 0) const noexcept { return find_first_of({&c, 1}, pos); }
|
||||
constexpr size_t find_first_of(const char* c, size_t pos, size_t count) const { return find_first_of({c, count}, pos); }
|
||||
size_t find_first_of(const char* c, size_t pos = 0) const { return find_first_of(simple_string_view(c), pos); }
|
||||
|
||||
constexpr size_t find_last_of(simple_string_view v, const size_t pos = npos) const noexcept {
|
||||
if (size_ == 0) return npos;
|
||||
const size_t last_pos = std::min(pos, size_-1);
|
||||
for (size_t i = last_pos; i <= last_pos; --i)
|
||||
for (char c : v)
|
||||
if (data_[i] == c)
|
||||
return i;
|
||||
return npos;
|
||||
}
|
||||
constexpr size_t find_last_of(char c, size_t pos = npos) const noexcept { return find_last_of({&c, 1}, pos); }
|
||||
constexpr size_t find_last_of(const char* c, size_t pos, size_t count) const { return find_last_of({c, count}, pos); }
|
||||
size_t find_last_of(const char* c, size_t pos = npos) const { return find_last_of(simple_string_view(c), pos); }
|
||||
|
||||
constexpr size_t find_first_not_of(simple_string_view v, size_t pos = 0) const noexcept {
|
||||
for (; pos < size_; ++pos) {
|
||||
bool none = true;
|
||||
for (char c : v) {
|
||||
if (data_[pos] == c) {
|
||||
none = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (none) return pos;
|
||||
}
|
||||
return npos;
|
||||
}
|
||||
constexpr size_t find_first_not_of(char c, size_t pos = 0) const noexcept { return find_first_not_of({&c, 1}, pos); }
|
||||
constexpr size_t find_first_not_of(const char* c, size_t pos, size_t count) const { return find_first_not_of({c, count}, pos); }
|
||||
size_t find_first_not_of(const char* c, size_t pos = 0) const { return find_first_not_of(simple_string_view(c), pos); }
|
||||
|
||||
constexpr size_t find_last_not_of(simple_string_view v, const size_t pos = npos) const noexcept {
|
||||
if (size_ == 0) return npos;
|
||||
const size_t last_pos = std::min(pos, size_-1);
|
||||
for (size_t i = last_pos; i <= last_pos; --i) {
|
||||
bool none = true;
|
||||
for (char c : v) {
|
||||
if (data_[i] == c) {
|
||||
none = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (none) return i;
|
||||
}
|
||||
return npos;
|
||||
}
|
||||
constexpr size_t find_last_not_of(char c, size_t pos = npos) const noexcept { return find_last_not_of({&c, 1}, pos); }
|
||||
constexpr size_t find_last_not_of(const char* c, size_t pos, size_t count) const { return find_last_not_of({c, count}, pos); }
|
||||
size_t find_last_not_of(const char* c, size_t pos = npos) const { return find_last_not_of(simple_string_view(c), pos); }
|
||||
};
|
||||
inline bool operator==(simple_string_view lhs, simple_string_view rhs) {
|
||||
return lhs.size() == rhs.size() && 0 == std::char_traits<char>::compare(lhs.data(), rhs.data(), lhs.size());
|
||||
};
|
||||
inline bool operator!=(simple_string_view lhs, simple_string_view rhs) {
|
||||
return !(lhs == rhs);
|
||||
}
|
||||
inline int simple_string_view::compare(simple_string_view s) const {
|
||||
int cmp = std::char_traits<char>::compare(data_, s.data(), std::min(size_, s.size()));
|
||||
if (cmp) return cmp;
|
||||
if (size_ < s.size()) return -1;
|
||||
else if (size_ > s.size()) return 1;
|
||||
return 0;
|
||||
}
|
||||
inline bool operator<(simple_string_view lhs, simple_string_view rhs) {
|
||||
return lhs.compare(rhs) < 0;
|
||||
};
|
||||
inline bool operator<=(simple_string_view lhs, simple_string_view rhs) {
|
||||
return lhs.compare(rhs) <= 0;
|
||||
};
|
||||
inline bool operator>(simple_string_view lhs, simple_string_view rhs) {
|
||||
return lhs.compare(rhs) > 0;
|
||||
};
|
||||
inline bool operator>=(simple_string_view lhs, simple_string_view rhs) {
|
||||
return lhs.compare(rhs) >= 0;
|
||||
};
|
||||
inline std::ostream& operator<<(std::ostream& os, const simple_string_view& s) {
|
||||
os.write(s.data(), s.size());
|
||||
return os;
|
||||
}
|
||||
|
||||
using string_view = simple_string_view;
|
||||
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
// Add a "foo"_sv literal that works exactly like the C++17 "foo"sv literal, but works with out
|
||||
// implementation in pre-C++17.
|
||||
namespace lokimq {
|
||||
inline namespace literals {
|
||||
inline string_view operator""_sv(const char* str, size_t len) { return {str, len}; }
|
||||
}
|
||||
}
|
|
@ -1,5 +0,0 @@
|
|||
namespace lokimq {
|
||||
constexpr int VERSION_MAJOR = @LOKIMQ_VERSION_MAJOR@;
|
||||
constexpr int VERSION_MINOR = @LOKIMQ_VERSION_MINOR@;
|
||||
constexpr int VERSION_PATCH = @LOKIMQ_VERSION_PATCH@;
|
||||
}
|
|
@ -1,292 +0,0 @@
|
|||
#include "lokimq.h"
|
||||
#include "batch.h"
|
||||
#include "lokimq-internal.h"
|
||||
|
||||
namespace lokimq {
|
||||
|
||||
void LokiMQ::worker_thread(unsigned int index) {
|
||||
std::string worker_id = "w" + std::to_string(index);
|
||||
zmq::socket_t sock{context, zmq::socket_type::dealer};
|
||||
sock.setsockopt(ZMQ_ROUTING_ID, worker_id.data(), worker_id.size());
|
||||
LMQ_LOG(debug, "New worker thread ", worker_id, " started");
|
||||
sock.connect(SN_ADDR_WORKERS);
|
||||
|
||||
Message message{*this, 0};
|
||||
std::vector<zmq::message_t> parts;
|
||||
run_info& run = workers[index]; // This contains our first job, and will be updated later with subsequent jobs
|
||||
|
||||
while (true) {
|
||||
try {
|
||||
if (run.is_batch_job) {
|
||||
if (run.batch_jobno >= 0) {
|
||||
LMQ_TRACE("worker thread ", worker_id, " running batch ", run.batch, "#", run.batch_jobno);
|
||||
run.batch->run_job(run.batch_jobno);
|
||||
} else if (run.batch_jobno == -1) {
|
||||
LMQ_TRACE("worker thread ", worker_id, " running batch ", run.batch, " completion");
|
||||
run.batch->job_completion();
|
||||
}
|
||||
} else {
|
||||
message.conn = run.conn;
|
||||
message.data.clear();
|
||||
|
||||
LMQ_TRACE("Got incoming command from ", message.conn, message.conn.route.empty() ? "(outgoing)" : "(incoming)");
|
||||
|
||||
if (run.callback->second /*is_request*/) {
|
||||
message.reply_tag = {run.data_parts[0].data<char>(), run.data_parts[0].size()};
|
||||
for (auto it = run.data_parts.begin() + 1; it != run.data_parts.end(); ++it)
|
||||
message.data.emplace_back(it->data<char>(), it->size());
|
||||
} else {
|
||||
for (auto& m : run.data_parts)
|
||||
message.data.emplace_back(m.data<char>(), m.size());
|
||||
}
|
||||
|
||||
LMQ_TRACE("worker thread ", worker_id, " invoking ", run.command, " callback with ", message.data.size(), " message parts");
|
||||
run.callback->first(message);
|
||||
}
|
||||
}
|
||||
catch (const bt_deserialize_invalid& e) {
|
||||
LMQ_LOG(warn, worker_id, " deserialization failed: ", e.what(), "; ignoring request");
|
||||
}
|
||||
catch (const mapbox::util::bad_variant_access& e) {
|
||||
LMQ_LOG(warn, worker_id, " deserialization failed: found unexpected serialized type (", e.what(), "); ignoring request");
|
||||
}
|
||||
catch (const std::out_of_range& e) {
|
||||
LMQ_LOG(warn, worker_id, " deserialization failed: invalid data - required field missing (", e.what(), "); ignoring request");
|
||||
}
|
||||
catch (const std::exception& e) {
|
||||
LMQ_LOG(warn, worker_id, " caught exception when processing command: ", e.what());
|
||||
}
|
||||
catch (...) {
|
||||
LMQ_LOG(warn, worker_id, " caught non-standard exception when processing command");
|
||||
}
|
||||
|
||||
while (true) {
|
||||
// Signal that we are ready for another job and wait for it. (We do this down here
|
||||
// because our first job gets set up when the thread is started).
|
||||
detail::send_control(sock, "RAN");
|
||||
LMQ_TRACE("worker ", worker_id, " waiting for requests");
|
||||
parts.clear();
|
||||
recv_message_parts(sock, parts);
|
||||
|
||||
if (parts.size() != 1) {
|
||||
LMQ_LOG(error, "Internal error: worker ", worker_id, " received invalid ", parts.size(), "-part worker instruction");
|
||||
continue;
|
||||
}
|
||||
auto command = view(parts[0]);
|
||||
if (command == "RUN") {
|
||||
LMQ_LOG(debug, "worker ", worker_id, " running command ", run.command);
|
||||
break; // proxy has set up a command for us, go back and run it.
|
||||
} else if (command == "QUIT") {
|
||||
LMQ_LOG(debug, "worker ", worker_id, " shutting down");
|
||||
detail::send_control(sock, "QUITTING");
|
||||
sock.setsockopt<int>(ZMQ_LINGER, 1000);
|
||||
sock.close();
|
||||
return;
|
||||
} else {
|
||||
LMQ_LOG(error, "Internal error: worker ", worker_id, " received invalid command: `", command, "'");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
LokiMQ::run_info& LokiMQ::get_idle_worker() {
|
||||
if (idle_workers.empty()) {
|
||||
size_t id = workers.size();
|
||||
assert(workers.capacity() > id);
|
||||
workers.emplace_back();
|
||||
auto& r = workers.back();
|
||||
r.worker_id = id;
|
||||
r.worker_routing_id = "w" + std::to_string(id);
|
||||
return r;
|
||||
}
|
||||
size_t id = idle_workers.back();
|
||||
idle_workers.pop_back();
|
||||
return workers[id];
|
||||
}
|
||||
|
||||
void LokiMQ::proxy_worker_message(std::vector<zmq::message_t>& parts) {
|
||||
// Process messages sent by workers
|
||||
if (parts.size() != 2) {
|
||||
LMQ_LOG(error, "Received send invalid ", parts.size(), "-part message");
|
||||
return;
|
||||
}
|
||||
auto route = view(parts[0]), cmd = view(parts[1]);
|
||||
LMQ_TRACE("worker message from ", route);
|
||||
assert(route.size() >= 2 && route[0] == 'w' && route[1] >= '0' && route[1] <= '9');
|
||||
string_view worker_id_str{&route[1], route.size()-1}; // Chop off the leading "w"
|
||||
unsigned int worker_id = detail::extract_unsigned(worker_id_str);
|
||||
if (!worker_id_str.empty() /* didn't consume everything */ || worker_id >= workers.size()) {
|
||||
LMQ_LOG(error, "Worker id '", route, "' is invalid, unable to process worker command");
|
||||
return;
|
||||
}
|
||||
|
||||
auto& run = workers[worker_id];
|
||||
|
||||
LMQ_TRACE("received ", cmd, " command from ", route);
|
||||
if (cmd == "RAN") {
|
||||
LMQ_LOG(debug, "Worker ", route, " finished ", run.command);
|
||||
if (run.is_batch_job) {
|
||||
auto& jobs = run.is_reply_job ? reply_jobs : batch_jobs;
|
||||
auto& active = run.is_reply_job ? reply_jobs_active : batch_jobs_active;
|
||||
assert(active > 0);
|
||||
active--;
|
||||
bool clear_job = false;
|
||||
if (run.batch_jobno == -1) {
|
||||
// Returned from the completion function
|
||||
clear_job = true;
|
||||
} else {
|
||||
auto status = run.batch->job_finished();
|
||||
if (status == detail::BatchStatus::complete) {
|
||||
jobs.emplace(run.batch, -1);
|
||||
} else if (status == detail::BatchStatus::complete_proxy) {
|
||||
try {
|
||||
run.batch->job_completion(); // RUN DIRECTLY IN PROXY THREAD
|
||||
} catch (const std::exception &e) {
|
||||
// Raise these to error levels: the caller really shouldn't be doing
|
||||
// anything non-trivial in an in-proxy completion function!
|
||||
LMQ_LOG(error, "proxy thread caught exception when processing in-proxy completion command: ", e.what());
|
||||
} catch (...) {
|
||||
LMQ_LOG(error, "proxy thread caught non-standard exception when processing in-proxy completion command");
|
||||
}
|
||||
clear_job = true;
|
||||
} else if (status == detail::BatchStatus::done) {
|
||||
clear_job = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (clear_job) {
|
||||
batches.erase(run.batch);
|
||||
delete run.batch;
|
||||
run.batch = nullptr;
|
||||
}
|
||||
} else {
|
||||
assert(run.cat->active_threads > 0);
|
||||
run.cat->active_threads--;
|
||||
}
|
||||
if (max_workers == 0) { // Shutting down
|
||||
LMQ_TRACE("Telling worker ", route, " to quit");
|
||||
route_control(workers_socket, route, "QUIT");
|
||||
} else {
|
||||
idle_workers.push_back(worker_id);
|
||||
}
|
||||
} else if (cmd == "QUITTING") {
|
||||
workers[worker_id].worker_thread.join();
|
||||
LMQ_LOG(debug, "Worker ", route, " exited normally");
|
||||
} else {
|
||||
LMQ_LOG(error, "Worker ", route, " sent unknown control message: `", cmd, "'");
|
||||
}
|
||||
}
|
||||
|
||||
void LokiMQ::proxy_run_worker(run_info& run) {
|
||||
if (!run.worker_thread.joinable())
|
||||
run.worker_thread = std::thread{&LokiMQ::worker_thread, this, run.worker_id};
|
||||
else
|
||||
send_routed_message(workers_socket, run.worker_routing_id, "RUN");
|
||||
}
|
||||
|
||||
void LokiMQ::proxy_to_worker(size_t conn_index, std::vector<zmq::message_t>& parts) {
|
||||
bool outgoing = connections[conn_index].getsockopt<int>(ZMQ_TYPE) == ZMQ_DEALER;
|
||||
|
||||
peer_info tmp_peer;
|
||||
tmp_peer.conn_index = conn_index;
|
||||
if (!outgoing) tmp_peer.route = parts[0].to_string();
|
||||
peer_info* peer = nullptr;
|
||||
if (outgoing) {
|
||||
auto it = peers.find(conn_index_to_id[conn_index]);
|
||||
if (it == peers.end()) {
|
||||
LMQ_LOG(warn, "Internal error: connection index ", conn_index, " not found");
|
||||
return;
|
||||
}
|
||||
peer = &it->second;
|
||||
} else {
|
||||
std::tie(tmp_peer.pubkey, tmp_peer.service_node, tmp_peer.auth_level) = detail::extract_metadata(parts.back());
|
||||
if (tmp_peer.service_node) {
|
||||
// It's a service node so we should have a peer_info entry; see if we can find one with
|
||||
// the same route, and if not, add one.
|
||||
auto pr = peers.equal_range(tmp_peer.pubkey);
|
||||
for (auto it = pr.first; it != pr.second; ++it) {
|
||||
if (it->second.route == tmp_peer.route) {
|
||||
peer = &it->second;
|
||||
// Upgrade permissions in case we have something higher on the socket
|
||||
peer->service_node |= tmp_peer.service_node;
|
||||
if (tmp_peer.auth_level > peer->auth_level)
|
||||
peer->auth_level = tmp_peer.auth_level;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!peer) {
|
||||
peer = &peers.emplace(ConnectionID{tmp_peer.pubkey}, std::move(tmp_peer))->second;
|
||||
}
|
||||
} else {
|
||||
// Incoming, non-SN connection: we don't store a peer_info for this, so just use the
|
||||
// temporary one
|
||||
peer = &tmp_peer;
|
||||
}
|
||||
}
|
||||
|
||||
size_t command_part_index = outgoing ? 0 : 1;
|
||||
std::string command = parts[command_part_index].to_string();
|
||||
auto cat_call = get_command(command);
|
||||
|
||||
if (!cat_call.first) {
|
||||
if (outgoing)
|
||||
send_direct_message(connections[conn_index], "UNKNOWNCOMMAND", command);
|
||||
else
|
||||
send_routed_message(connections[conn_index], peer->route, "UNKNOWNCOMMAND", command);
|
||||
return;
|
||||
}
|
||||
|
||||
auto& category = *cat_call.first;
|
||||
|
||||
if (!proxy_check_auth(conn_index, outgoing, *peer, command, category, parts.back()))
|
||||
return;
|
||||
|
||||
// Steal any data message parts
|
||||
size_t data_part_index = command_part_index + 1;
|
||||
std::vector<zmq::message_t> data_parts;
|
||||
data_parts.reserve(parts.size() - data_part_index);
|
||||
for (auto it = parts.begin() + data_part_index; it != parts.end(); ++it)
|
||||
data_parts.push_back(std::move(*it));
|
||||
|
||||
if (category.active_threads >= category.reserved_threads && active_workers() >= general_workers) {
|
||||
// No free reserved or general spots, try to queue it for later
|
||||
if (category.max_queue >= 0 && category.queued >= category.max_queue) {
|
||||
LMQ_LOG(warn, "No space to queue incoming command ", command, "; already have ", category.queued,
|
||||
"commands queued in that category (max ", category.max_queue, "); dropping message");
|
||||
return;
|
||||
}
|
||||
|
||||
LMQ_LOG(debug, "No available free workers, queuing ", command, " for later");
|
||||
ConnectionID conn{peer->service_node ? ConnectionID::SN_ID : conn_index_to_id[conn_index].id, peer->pubkey, std::move(tmp_peer.route)};
|
||||
pending_commands.emplace_back(category, std::move(command), std::move(data_parts), cat_call.second, std::move(conn));
|
||||
category.queued++;
|
||||
return;
|
||||
}
|
||||
|
||||
if (cat_call.second->second /*is_request*/ && data_parts.empty()) {
|
||||
LMQ_LOG(warn, "Received an invalid request command with no reply tag; dropping message");
|
||||
return;
|
||||
}
|
||||
|
||||
auto& run = get_idle_worker();
|
||||
{
|
||||
ConnectionID c{peer->service_node ? ConnectionID::SN_ID : conn_index_to_id[conn_index].id, peer->pubkey};
|
||||
c.route = std::move(tmp_peer.route);
|
||||
if (outgoing || peer->service_node)
|
||||
tmp_peer.route.clear();
|
||||
run.load(&category, std::move(command), std::move(c), std::move(data_parts), cat_call.second);
|
||||
}
|
||||
|
||||
if (outgoing)
|
||||
peer->activity(); // outgoing connection activity, pump the activity timer
|
||||
|
||||
LMQ_TRACE("Forwarding incoming ", run.command, " from ", run.conn, " @ ", peer_address(parts[command_part_index]),
|
||||
" to worker ", run.worker_routing_id);
|
||||
|
||||
proxy_run_worker(run);
|
||||
category.active_threads++;
|
||||
}
|
||||
|
||||
|
||||
}
|
|
@ -1 +0,0 @@
|
|||
Subproject commit c94634bbd294204c9ba3f5b267a39582a52e8e5a
|
1
oxen-encoding
Submodule
1
oxen-encoding
Submodule
|
@ -0,0 +1 @@
|
|||
Subproject commit d6f300d7d250ae0a9708090c0011c0f495377e6a
|
351
oxenmq/address.cpp
Normal file
351
oxenmq/address.cpp
Normal file
|
@ -0,0 +1,351 @@
|
|||
#include "address.h"
|
||||
#include <tuple>
|
||||
#include <limits>
|
||||
#include <cstddef>
|
||||
#include <utility>
|
||||
#include <stdexcept>
|
||||
#include <ostream>
|
||||
#include <oxenc/hex.h>
|
||||
#include <oxenc/base32z.h>
|
||||
#include <oxenc/base64.h>
|
||||
|
||||
namespace oxenmq {
|
||||
|
||||
constexpr size_t enc_length(address::encoding enc) {
|
||||
return enc == address::encoding::hex ? 64 :
|
||||
enc == address::encoding::base64 ? 43 : // this can be 44 with a padding byte, but we don't need it
|
||||
52 /*base32z*/;
|
||||
};
|
||||
|
||||
// Parses an encoding pubkey from the given string_view. Advanced the string_view beyond the
|
||||
// consumed pubkey data, and returns the pubkey (as a 32-byte string). Throws if no valid pubkey
|
||||
// was found at the beginning of addr. We look for hex, base32z, or base64 pubkeys *unless* qr is
|
||||
// given: for QR-friendly we only accept hex or base32z (since QR cannot handle base64's alphabet).
|
||||
std::string decode_pubkey(std::string_view& in, bool qr) {
|
||||
std::string pubkey;
|
||||
if (in.size() >= 64 && oxenc::is_hex(in.substr(0, 64))) {
|
||||
pubkey = oxenc::from_hex(in.substr(0, 64));
|
||||
in.remove_prefix(64);
|
||||
} else if (in.size() >= 52 && oxenc::is_base32z(in.substr(0, 52))) {
|
||||
pubkey = oxenc::from_base32z(in.substr(0, 52));
|
||||
in.remove_prefix(52);
|
||||
} else if (!qr && in.size() >= 43 && oxenc::is_base64(in.substr(0, 43))) {
|
||||
pubkey = oxenc::from_base64(in.substr(0, 43));
|
||||
in.remove_prefix(43);
|
||||
if (!in.empty() && in.front() == '=')
|
||||
in.remove_prefix(1); // allow (and eat) a padding byte at the end
|
||||
} else {
|
||||
throw std::invalid_argument{"No pubkey found"};
|
||||
}
|
||||
return pubkey;
|
||||
}
|
||||
|
||||
// Parse the host, port, and optionally pubkey from a string view, mutating it to remove the parsed
|
||||
// sections. qr should be true if we should accept $IPv6$ as a QR-encoding-friendly alternative to
|
||||
// [IPv6] (the returned host will have the $ replaced, i.e. [IPv6]).
|
||||
std::tuple<std::string, uint16_t, std::string> parse_tcp(std::string_view& addr, bool qr, bool expect_pubkey) {
|
||||
std::tuple<std::string, uint16_t, std::string> result;
|
||||
auto& host = std::get<0>(result);
|
||||
if (addr.front() == '[' || (qr && addr.front() == '$')) { // IPv6 addr (though this is far from complete validation)
|
||||
auto pos = addr.find_first_not_of(":.1234567890abcdefABCDEF", 1);
|
||||
if (pos == std::string_view::npos)
|
||||
throw std::invalid_argument("Could not find terminating ] while parsing an IPv6 address");
|
||||
if (!(addr[pos] == ']' || (qr && addr[pos] == '$')))
|
||||
throw std::invalid_argument{"Expected " + (qr ? "$"s : "]"s) + " to close IPv6 address but found " + std::string(1, addr[pos])};
|
||||
host = std::string{addr.substr(0, pos+1)};
|
||||
if (qr) {
|
||||
if (host.front() == '$')
|
||||
host.front() = '[';
|
||||
if (host.back() == '$')
|
||||
host.back() = ']';
|
||||
}
|
||||
addr.remove_prefix(pos+1);
|
||||
} else {
|
||||
auto port_pos = addr.find(':');
|
||||
if (port_pos == std::string_view::npos)
|
||||
throw std::invalid_argument{"Could not determine host (no following ':port' found)"};
|
||||
if (port_pos == 0)
|
||||
throw std::invalid_argument{"Host cannot be empty"};
|
||||
host = std::string{addr.substr(0, port_pos)};
|
||||
addr.remove_prefix(port_pos);
|
||||
}
|
||||
|
||||
if (qr)
|
||||
// Lower-case the host because upper case hostnames are ugly
|
||||
for (char& c : host)
|
||||
if (c >= 'A' && c <= 'Z')
|
||||
c = c - 'A' + 'a';
|
||||
|
||||
if (addr.size() < 2 || addr[0] != ':')
|
||||
throw std::invalid_argument{"Could not find :port in address string"};
|
||||
addr.remove_prefix(1);
|
||||
auto pos = addr.find_first_not_of("1234567890");
|
||||
if (pos == 0)
|
||||
throw std::invalid_argument{"Could not find numeric port in address string"};
|
||||
if (pos == std::string_view::npos)
|
||||
pos = addr.size();
|
||||
size_t processed;
|
||||
int port_int = std::stoi(std::string{addr.substr(0, pos)}, &processed);
|
||||
if (port_int == 0 || processed != pos)
|
||||
throw std::invalid_argument{"Could not parse numeric port in address string"};
|
||||
if (port_int < 0 || port_int > std::numeric_limits<uint16_t>::max())
|
||||
throw std::invalid_argument{"Invalid port: port must be in range 1-65535"};
|
||||
std::get<1>(result) = static_cast<uint16_t>(port_int);
|
||||
addr.remove_prefix(pos);
|
||||
|
||||
if (expect_pubkey) {
|
||||
if (addr.size() < 1 + enc_length(qr ? address::encoding::base32z : address::encoding::base64)
|
||||
|| addr.front() != '/')
|
||||
throw std::invalid_argument{"Invalid address: expected /PUBKEY after port"};
|
||||
addr.remove_prefix(1);
|
||||
|
||||
std::get<2>(result) = decode_pubkey(addr, qr);
|
||||
if (!addr.empty())
|
||||
throw std::invalid_argument{"Invalid address: found unexpected trailing data after pubkey"};
|
||||
} else if (!addr.empty()) {
|
||||
throw std::invalid_argument{"Invalid address: found unexpected trailing data after port"};
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// Parse the socket path and (possibly) pubkey, mutating it to remove the parsed sections.
|
||||
// Currently the /pubkey *must* be at the end of the string, but this might not always be the case
|
||||
// (e.g. we could in the future support query string-like arguments).
|
||||
std::pair<std::string, std::string> parse_unix(std::string_view& addr, bool expect_pubkey) {
|
||||
std::pair<std::string, std::string> result;
|
||||
if (expect_pubkey) {
|
||||
size_t b64_len = addr.size() > 0 && addr.back() == '=' ? 44 : 43;
|
||||
if (addr.size() > 64 && addr[addr.size() - 65] == '/' && oxenc::is_hex(addr.substr(addr.size() - 64))) {
|
||||
result.first = std::string{addr.substr(0, addr.size() - 65)};
|
||||
result.second = oxenc::from_hex(addr.substr(addr.size() - 64));
|
||||
} else if (addr.size() > 52 && addr[addr.size() - 53] == '/' && oxenc::is_base32z(addr.substr(addr.size() - 52))) {
|
||||
result.first = std::string{addr.substr(0, addr.size() - 53)};
|
||||
result.second = oxenc::from_base32z(addr.substr(addr.size() - 52));
|
||||
} else if (addr.size() > b64_len && addr[addr.size() - b64_len - 1] == '/' && oxenc::is_base64(addr.substr(addr.size() - b64_len))) {
|
||||
result.first = std::string{addr.substr(0, addr.size() - b64_len - 1)};
|
||||
result.second = oxenc::from_base64(addr.substr(addr.size() - b64_len));
|
||||
} else {
|
||||
throw std::invalid_argument{"icp+curve:// requires a trailing /PUBKEY value, got: " + std::string{addr}};
|
||||
}
|
||||
} else {
|
||||
// Anything goes
|
||||
result.first = std::string{addr};
|
||||
}
|
||||
|
||||
// Any path above consumes everything:
|
||||
addr.remove_prefix(addr.size());
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
address::address(std::string_view addr) {
|
||||
auto protoend = addr.find("://"sv);
|
||||
if (protoend == std::string_view::npos || protoend == 0)
|
||||
throw std::invalid_argument("Invalid address: no protocol found");
|
||||
auto pro = addr.substr(0, protoend);
|
||||
addr.remove_prefix(protoend + 3);
|
||||
if (addr.empty())
|
||||
throw std::invalid_argument("Invalid address: no value specified after protocol");
|
||||
bool qr = false;
|
||||
if (pro == "tcp") protocol = proto::tcp;
|
||||
else if (pro == "tcp+curve" || pro == "curve") protocol = proto::tcp_curve;
|
||||
else if (pro == "ipc") protocol = proto::ipc;
|
||||
else if (pro == "ipc+curve") protocol = proto::ipc_curve;
|
||||
else if (pro == "TCP") {
|
||||
protocol = proto::tcp;
|
||||
qr = true;
|
||||
} else if (pro == "CURVE") {
|
||||
protocol = proto::tcp_curve;
|
||||
qr = true;
|
||||
} else {
|
||||
throw std::invalid_argument("Invalid protocol '" + std::string{pro} + "'");
|
||||
}
|
||||
|
||||
if (qr) {
|
||||
// The QR variations only allow QR-alphanumeric characters (upper-case letters, numbers, and
|
||||
// a few symbols):
|
||||
for (char c : addr) {
|
||||
// QR alphanumeric also allows space, %, *, +, but we don't need or allow any of those here.
|
||||
if (!((c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '$' || c == ':' || c == '/' || c == '.' || c == '-'))
|
||||
throw std::invalid_argument("Found non-QR-alphanumeric value in QR TCP:// or CURVE:// address");
|
||||
}
|
||||
}
|
||||
|
||||
if (tcp())
|
||||
std::tie(host, port, pubkey) = parse_tcp(addr, qr, curve());
|
||||
else
|
||||
std::tie(socket, pubkey) = parse_unix(addr, curve());
|
||||
|
||||
if (!addr.empty())
|
||||
throw std::invalid_argument{"Invalid trailing garbage '" + std::string{addr} + "' in address"};
|
||||
}
|
||||
|
||||
address& address::set_pubkey(std::string_view pk) {
|
||||
if (pk.size() == 0) {
|
||||
if (protocol == proto::tcp_curve) protocol = proto::tcp;
|
||||
else if (protocol == proto::ipc_curve) protocol = proto::ipc;
|
||||
} else if (pk.size() == 32) {
|
||||
if (protocol == proto::tcp) protocol = proto::tcp_curve;
|
||||
else if (protocol == proto::ipc) protocol = proto::ipc_curve;
|
||||
} else {
|
||||
throw std::invalid_argument{"Invalid pubkey passed to set_pubkey(): require 0- or 32-byte pubkey"};
|
||||
}
|
||||
pubkey = pk;
|
||||
return *this;
|
||||
}
|
||||
|
||||
std::string address::encode_pubkey(encoding enc) const {
|
||||
std::string pk;
|
||||
if (enc == encoding::hex)
|
||||
pk = oxenc::to_hex(pubkey);
|
||||
else if (enc == encoding::base32z)
|
||||
pk = oxenc::to_base32z(pubkey);
|
||||
else if (enc == encoding::BASE32Z) {
|
||||
pk = oxenc::to_base32z(pubkey);
|
||||
for (char& c : pk)
|
||||
if (c >= 'a' && c <= 'z')
|
||||
c = c - 'a' + 'A';
|
||||
} else if (enc == encoding::base64) {
|
||||
pk = oxenc::to_base64(pubkey);
|
||||
if (pk.size() == 44 && pk.back() == '=')
|
||||
pk.resize(43);
|
||||
} else {
|
||||
throw std::logic_error{"Invalid encoding"};
|
||||
}
|
||||
return pk;
|
||||
}
|
||||
|
||||
std::string address::full_address(encoding enc) const {
|
||||
std::string result;
|
||||
std::string pk;
|
||||
if (curve())
|
||||
pk = encode_pubkey(enc);
|
||||
|
||||
if (protocol == proto::tcp) {
|
||||
result.reserve(6 /*tcp:// */ + host.size() + 6 /*:port*/);
|
||||
result += "tcp://";
|
||||
result += host;
|
||||
result += ':';
|
||||
result += std::to_string(port);
|
||||
} else if (protocol == proto::tcp_curve) {
|
||||
result.reserve(8 /*curve:// */ + host.size() + 6 /*:port*/ + 1 /* / */ + pk.size());
|
||||
result += "curve://";
|
||||
result += host;
|
||||
result += ':';
|
||||
result += std::to_string(port);
|
||||
result += '/';
|
||||
result += pk;
|
||||
} else if (protocol == proto::ipc) {
|
||||
result.reserve(6 /*ipc:// */ + socket.size());
|
||||
result += "ipc://";
|
||||
result += socket;
|
||||
} else if (protocol == proto::ipc_curve) {
|
||||
result.reserve(12 /*ipc+curve:// */ + socket.size() + 1 /* / */ + pk.size());
|
||||
result += "ipc+curve://";
|
||||
result += socket;
|
||||
result += '/';
|
||||
result += pk;
|
||||
} else {
|
||||
throw std::logic_error{"Invalid protocol"};
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string address::zmq_address() const {
|
||||
std::string result;
|
||||
if (tcp()) {
|
||||
result.reserve(6 /*tcp:// */ + host.size() + 6 /*:port*/);
|
||||
result += "tcp://";
|
||||
result += host;
|
||||
result += ':';
|
||||
result += std::to_string(port);
|
||||
} else {
|
||||
result.reserve(6 /*ipc:// */ + socket.size());
|
||||
result += "ipc://";
|
||||
result += socket;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string address::qr_address() const {
|
||||
if (protocol != proto::tcp && protocol != proto::tcp_curve)
|
||||
throw std::logic_error("Cannot construct a QR-friendly address for a non-TCP address");
|
||||
if (host.empty())
|
||||
throw std::logic_error("Cannot construct a QR-friendly address with an empty TCP host");
|
||||
std::string result;
|
||||
result.reserve((curve() ? 8 /*CURVE:// */ : 6 /*TCP:// */) + host.size() + 6 /*:port*/ +
|
||||
(curve() ? 1 + enc_length(encoding::BASE32Z) : 0));
|
||||
result += curve() ? "CURVE://" : "TCP://";
|
||||
std::string uc_host = host;
|
||||
for (auto& c : uc_host)
|
||||
if (c >= 'a' && c <= 'z')
|
||||
c = c - 'a' + 'A';
|
||||
|
||||
if (uc_host.front() == '[' && uc_host.back() == ']') {
|
||||
uc_host.front() = '$';
|
||||
uc_host.back() = '$';
|
||||
}
|
||||
result += uc_host;
|
||||
result += ':';
|
||||
result += std::to_string(port);
|
||||
|
||||
if (curve()) {
|
||||
result += '/';
|
||||
result += encode_pubkey(encoding::BASE32Z);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
bool address::operator==(const address& other) const {
|
||||
if (protocol != other.protocol)
|
||||
return false;
|
||||
if (tcp())
|
||||
if (host != other.host || port != other.port)
|
||||
return false;
|
||||
if (ipc())
|
||||
if (socket != other.socket)
|
||||
return false;
|
||||
if (curve())
|
||||
if (pubkey != other.pubkey)
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
address address::tcp(std::string host, uint16_t port) {
|
||||
address a;
|
||||
a.protocol = proto::tcp;
|
||||
a.host = std::move(host);
|
||||
a.port = port;
|
||||
return a;
|
||||
}
|
||||
|
||||
address address::tcp_curve(std::string host, uint16_t port, std::string pubkey) {
|
||||
address a;
|
||||
a.protocol = proto::tcp_curve;
|
||||
a.host = std::move(host);
|
||||
a.port = port;
|
||||
a.pubkey = std::move(pubkey);
|
||||
return a;
|
||||
}
|
||||
|
||||
address address::ipc(std::string path) {
|
||||
address a;
|
||||
a.protocol = proto::ipc;
|
||||
a.socket = std::move(path);
|
||||
return a;
|
||||
}
|
||||
|
||||
address address::ipc_curve(std::string path, std::string pubkey) {
|
||||
address a;
|
||||
a.protocol = proto::ipc_curve;
|
||||
a.socket = std::move(path);
|
||||
a.pubkey = std::move(pubkey);
|
||||
return a;
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& o, const address& a) { return o << a.full_address(); }
|
||||
|
||||
}
|
218
oxenmq/address.h
Normal file
218
oxenmq/address.h
Normal file
|
@ -0,0 +1,218 @@
|
|||
// Copyright (c) 2020-2021, The Oxen Project
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without modification, are
|
||||
// permitted provided that the following conditions are met:
|
||||
//
|
||||
// 1. Redistributions of source code must retain the above copyright notice, this list of
|
||||
// conditions and the following disclaimer.
|
||||
//
|
||||
// 2. Redistributions in binary form must reproduce the above copyright notice, this list
|
||||
// of conditions and the following disclaimer in the documentation and/or other
|
||||
// materials provided with the distribution.
|
||||
//
|
||||
// 3. Neither the name of the copyright holder nor the names of its contributors may be
|
||||
// used to endorse or promote products derived from this software without specific
|
||||
// prior written permission.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
|
||||
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
|
||||
// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL
|
||||
// THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF
|
||||
// THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
#pragma once
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <cstdint>
|
||||
#include <iosfwd>
|
||||
#include <functional>
|
||||
|
||||
namespace oxenmq {
|
||||
|
||||
using namespace std::literals;
|
||||
|
||||
/** OxenMQ address abstraction class. This class uses and extends standard ZMQ addresses allowing
|
||||
* extra parameters to be passed in in a relative standard way.
|
||||
*
|
||||
* External ZMQ addresses generally have two forms that we are concerned with: one for TCP and one
|
||||
* for Unix sockets:
|
||||
*
|
||||
* tcp://HOST:PORT -- HOST can be a hostname, IPv4 address, or IPv6 address in [...]
|
||||
* ipc://PATH -- PATH can be absolute (ipc:///path/to/some.sock) or relative (ipc://some.sock)
|
||||
*
|
||||
* but this doesn't carry enough info: in particular, we can connect with two very different
|
||||
* protocols: curve25519-encrypted, or plaintext, but for curve25519-encrypted we require the
|
||||
* remote's public key as well to verify the connection.
|
||||
*
|
||||
* This class, then, handles this by allowing addresses of:
|
||||
*
|
||||
* Standard ZMQ address: these carry no pubkey and so the connection will be unencrypted:
|
||||
*
|
||||
* tcp://HOSTNAME:PORT
|
||||
* ipc://PATH
|
||||
*
|
||||
* Non-ZMQ address formats that specify that the connection shall be x25519 encrypted:
|
||||
*
|
||||
* curve://HOSTNAME:PORT/PUBKEY -- PUBKEY must be specified in hex (64 characters), base32z (52)
|
||||
* or base64 (43 or 44 with one '=' trailing padding)
|
||||
* ipc+curve:///path/to/my.sock/PUBKEY -- same requirements on PUBKEY as above.
|
||||
* tcp+curve://(whatever) -- alias for curve://(whatever)
|
||||
*
|
||||
* We also accept special upper-case TCP-only variants which *only* accept uppercase characters and
|
||||
* a few required symbols (:, /, $, ., and -) in the string:
|
||||
*
|
||||
* TCP://HOSTNAME:PORT
|
||||
* CURVE://HOSTNAME:PORT/B32ZPUBKEY
|
||||
*
|
||||
* These versions are explicitly meant to be used with QR codes; the upper-case-only requirement
|
||||
* allows a smaller QR code by allowing QR's alphanumeric mode (which allows only [A-Z0-9 $%*+./:-])
|
||||
* to be used. Such a QR-friendly address can be created from the qr_address() method. To support
|
||||
* literal IPv6 addresses we surround the address with $...$ instead of the usual [...].
|
||||
*
|
||||
* Note that this class does very little validate the host argument at all, and no socket path
|
||||
* validation whatsoever. The only constraint on host is when parsing an encoded address: we check
|
||||
* that it contains no : at all, or must be a [bracketed] expression that contains only hex
|
||||
* characters, :'s, or .'s. Otherwise, if you pass broken crap into the hostname, expect broken
|
||||
* crap out.
|
||||
*/
|
||||
struct address {
|
||||
/// Supported address protocols: TCP connections (tcp), or unix sockets (ipc).
|
||||
enum class proto {
|
||||
tcp,
|
||||
tcp_curve,
|
||||
ipc,
|
||||
ipc_curve
|
||||
};
|
||||
/// Supported public key encodings (used when regenerating an augmented address).
|
||||
enum class encoding {
|
||||
hex, ///< hexadecimal encoded
|
||||
base32z, ///< base32z encoded
|
||||
base64, ///< base64 encoded (*without* trailing = padding)
|
||||
BASE32Z ///< upper-case base32z encoding, meant for QR encoding
|
||||
};
|
||||
|
||||
/// The protocol: one of the `protocol` enum values for tcp or ipc (unix sockets), with or
|
||||
/// without _curve encryption.
|
||||
proto protocol = proto::tcp;
|
||||
/// The host for tcp connections; can be a hostname or IP address. If this is an IPv6 it must be surrounded with [ ].
|
||||
std::string host;
|
||||
/// The port (for tcp connections)
|
||||
uint16_t port = 0;
|
||||
/// The socket path (for unix socket connections)
|
||||
std::string socket;
|
||||
/// If a curve connection, this is the required remote public key (in bytes)
|
||||
std::string pubkey;
|
||||
|
||||
/// Default constructor; this gives you an unusable address.
|
||||
address() = default;
|
||||
|
||||
/**
|
||||
* Constructs an address by parsing a string_view containing one of the formats listed in the
|
||||
* class description. This is intentionally implicitly constructible so that you can pass a
|
||||
* string_view into anything expecting an `address`.
|
||||
*
|
||||
* Throw std::invalid_argument if the given address is not parseable.
|
||||
*/
|
||||
address(std::string_view addr);
|
||||
|
||||
/** Constructs an address from a remote string and a separate pubkey. Typically `remote` is a
|
||||
* basic ZMQ connect string, though this is not enforced. Any pubkey information embedded in
|
||||
* the remote string will be discarded and replaced with the given pubkey string. The result
|
||||
* will be curve encrypted if `pubkey` is non-empty, plaintext if `pubkey` is empty.
|
||||
*
|
||||
* Throws an exception if either addr or pubkey is invalid.
|
||||
*
|
||||
* Exactly equivalent to `address a{remote}; a.set_pubkey(pubkey);`
|
||||
*/
|
||||
address(std::string_view addr, std::string_view pubkey) : address(addr) { set_pubkey(pubkey); }
|
||||
|
||||
/// Replaces the address's pubkey (if any) with the given pubkey (or no pubkey if empty). If
|
||||
/// changing from pubkey to no-pubkey or no-pubkey to pubkey then the protocol is update to
|
||||
/// switch to or from curve encryption.
|
||||
///
|
||||
/// pubkey should be the 32-byte binary pubkey, or an empty string to remove an existing pubkey.
|
||||
///
|
||||
/// Returns the object itself, so that you can chain it.
|
||||
address& set_pubkey(std::string_view pubkey);
|
||||
|
||||
/// Constructs and builds the ZMQ connection address from the stored connection details. This
|
||||
/// does not contain any of the curve-related details; those must be specified separately when
|
||||
/// interfacing with ZMQ.
|
||||
std::string zmq_address() const;
|
||||
|
||||
/// Returns true if the connection was specified as a curve-encryption-enabled connection, false
|
||||
/// otherwise.
|
||||
bool curve() const { return protocol == proto::tcp_curve || protocol == proto::ipc_curve; }
|
||||
|
||||
/// True if the protocol is TCP (either with or without curve)
|
||||
bool tcp() const { return protocol == proto::tcp || protocol == proto::tcp_curve; }
|
||||
|
||||
/// True if the protocol is unix socket (either with or without curve)
|
||||
bool ipc() const { return !tcp(); }
|
||||
|
||||
/// Returns the full "augmented" address string (i.e. that could be passed in to the
|
||||
/// constructor). This will be equivalent (but not necessarily identical) to an augmented
|
||||
/// string passed into the constructor. Takes an optional encoding format for the pubkey (if
|
||||
/// any), which defaults to base32z.
|
||||
std::string full_address(encoding enc = encoding::base32z) const;
|
||||
|
||||
/// Returns a QR-code friendly address string. This returns an all-uppercase version of the
|
||||
/// address with "TCP://" or "CURVE://" for the protocol string, and uses upper-case base32z
|
||||
/// encoding for the pubkey (for curve addresses). For literal IPv6 addresses we surround the
|
||||
/// address with $ instead of [...]
|
||||
///
|
||||
/// \throws std::logic_error if called on a unix socket address.
|
||||
std::string qr_address() const;
|
||||
|
||||
/// Returns `.pubkey` but encoded in the given format
|
||||
std::string encode_pubkey(encoding enc) const;
|
||||
|
||||
/// Returns true if two addresses are identical (i.e. same protocol and relevant protocol
|
||||
/// arguments).
|
||||
///
|
||||
/// Note that it is possible for addresses to connect to the same socket without being
|
||||
/// identical: for example, using "foo.sock" and "./foo.sock", or writing IPv6 addresses (or
|
||||
/// even IPv4 addresses) in slightly different ways). Such equivalent but non-equal values will
|
||||
/// result in a false return here.
|
||||
///
|
||||
/// Note also that we ignore irrelevant arguments: for example, we don't care whether pubkeys
|
||||
/// match when comparing two non-curve TCP addresses.
|
||||
bool operator==(const address& other) const;
|
||||
/// Negation of ==
|
||||
bool operator!=(const address& other) const { return !operator==(other); }
|
||||
|
||||
/// Factory function that constructs a TCP address from a host and port. The connection will be
|
||||
/// plaintext. If the host is an IPv6 address it *must* be surrounded with [ and ].
|
||||
static address tcp(std::string host, uint16_t port);
|
||||
|
||||
/// Factory function that constructs a curve-encrypted TCP address from a host, port, and remote
|
||||
/// pubkey. The pubkey must be 32 bytes. As above, IPv6 addresses must be specified as [addr].
|
||||
static address tcp_curve(std::string host, uint16_t, std::string pubkey);
|
||||
|
||||
/// Factory function that constructs a unix socket address from a path. The connection will be
|
||||
/// plaintext (which is usually fine for a socket since unix sockets are local machine).
|
||||
static address ipc(std::string path);
|
||||
|
||||
/// Factory function that constructs a unix socket address from a path and remote pubkey. The
|
||||
/// connection will be curve25519 encrypted; the remote pubkey must be 32 bytes.
|
||||
static address ipc_curve(std::string path, std::string pubkey);
|
||||
};
|
||||
|
||||
// Outputs address.full_address() when sent to an ostream.
|
||||
std::ostream& operator<<(std::ostream& o, const address& a);
|
||||
|
||||
} // namespace oxenmq
|
||||
|
||||
namespace std {
|
||||
template<> struct hash<oxenmq::address> {
|
||||
std::size_t operator()(const oxenmq::address& a) const noexcept {
|
||||
return std::hash<std::string>{}(a.full_address(oxenmq::address::encoding::hex));
|
||||
}
|
||||
};
|
||||
} // namespace std
|
318
oxenmq/auth.cpp
Normal file
318
oxenmq/auth.cpp
Normal file
|
@ -0,0 +1,318 @@
|
|||
#include "oxenmq.h"
|
||||
#include <oxenc/hex.h>
|
||||
#include "oxenmq-internal.h"
|
||||
#include <ostream>
|
||||
#include <sstream>
|
||||
|
||||
namespace oxenmq {
|
||||
|
||||
std::ostream& operator<<(std::ostream& o, AuthLevel a) {
|
||||
return o << to_string(a);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// Builds a ZMTP metadata key-value pair. These will be available on every message from that peer.
|
||||
// Keys must start with X- and be <= 255 characters.
|
||||
std::string zmtp_metadata(std::string_view key, std::string_view value) {
|
||||
assert(key.size() > 2 && key.size() <= 255 && key[0] == 'X' && key[1] == '-');
|
||||
|
||||
std::string result;
|
||||
result.reserve(1 + key.size() + 4 + value.size());
|
||||
result += static_cast<char>(key.size()); // Size octet of key
|
||||
result.append(&key[0], key.size()); // key data
|
||||
for (int i = 24; i >= 0; i -= 8) // 4-byte size of value in network order
|
||||
result += static_cast<char>((value.size() >> i) & 0xff);
|
||||
result.append(&value[0], value.size()); // value data
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
bool OxenMQ::proxy_check_auth(int64_t conn_id, bool outgoing, const peer_info& peer,
|
||||
zmq::message_t& cmd, const cat_call_t& cat_call, std::vector<zmq::message_t>& data) {
|
||||
auto command = view(cmd);
|
||||
std::string reply;
|
||||
|
||||
if (!cat_call.first) {
|
||||
OMQ_LOG(warn, "Invalid command '", command, "' sent by remote [", oxenc::to_hex(peer.pubkey), "]/", peer_address(cmd));
|
||||
reply = "UNKNOWNCOMMAND";
|
||||
} else if (peer.auth_level < cat_call.first->access.auth) {
|
||||
OMQ_LOG(warn, "Access denied to ", command, " for peer [", oxenc::to_hex(peer.pubkey), "]/", peer_address(cmd),
|
||||
": peer auth level ", peer.auth_level, " < ", cat_call.first->access.auth);
|
||||
reply = "FORBIDDEN";
|
||||
} else if (cat_call.first->access.local_sn && !local_service_node) {
|
||||
OMQ_LOG(warn, "Access denied to ", command, " for peer [", oxenc::to_hex(peer.pubkey), "]/", peer_address(cmd),
|
||||
": that command is only available when this OxenMQ is running in service node mode");
|
||||
reply = "NOT_A_SERVICE_NODE";
|
||||
} else if (cat_call.first->access.remote_sn && !peer.service_node) {
|
||||
OMQ_LOG(warn, "Access denied to ", command, " for peer [", oxenc::to_hex(peer.pubkey), "]/", peer_address(cmd),
|
||||
": remote is not recognized as a service node");
|
||||
reply = "FORBIDDEN_SN";
|
||||
} else if (cat_call.second->second /*is_request*/ && data.empty()) {
|
||||
OMQ_LOG(warn, "Received an invalid request for '", command, "' with no reply tag from remote [",
|
||||
oxenc::to_hex(peer.pubkey), "]/", peer_address(cmd));
|
||||
reply = "NO_REPLY_TAG";
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<zmq::message_t> msgs;
|
||||
msgs.reserve(4);
|
||||
if (!outgoing)
|
||||
msgs.push_back(create_message(peer.route));
|
||||
msgs.push_back(create_message(reply));
|
||||
if (cat_call.second && cat_call.second->second /*request command*/ && !data.empty()) {
|
||||
msgs.push_back(create_message("REPLY"sv));
|
||||
msgs.push_back(create_message(view(data.front()))); // reply tag
|
||||
} else {
|
||||
msgs.push_back(create_message(view(cmd)));
|
||||
}
|
||||
|
||||
try {
|
||||
send_message_parts(connections.at(conn_id), msgs);
|
||||
} catch (const zmq::error_t& err) {
|
||||
/* can't send: possibly already disconnected. Ignore. */
|
||||
OMQ_LOG(debug, "Couldn't send auth failure message ", reply, " to peer [", oxenc::to_hex(peer.pubkey), "]/", peer_address(cmd), ": ", err.what());
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
void OxenMQ::set_active_sns(pubkey_set pubkeys) {
|
||||
if (proxy_thread.joinable()) {
|
||||
auto data = oxenc::bt_serialize(detail::serialize_object(std::move(pubkeys)));
|
||||
detail::send_control(get_control_socket(), "SET_SNS", data);
|
||||
} else {
|
||||
proxy_set_active_sns(std::move(pubkeys));
|
||||
}
|
||||
}
|
||||
void OxenMQ::proxy_set_active_sns(std::string_view data) {
|
||||
proxy_set_active_sns(detail::deserialize_object<pubkey_set>(oxenc::bt_deserialize<uintptr_t>(data)));
|
||||
}
|
||||
void OxenMQ::proxy_set_active_sns(pubkey_set pubkeys) {
|
||||
pubkey_set added, removed;
|
||||
for (auto it = pubkeys.begin(); it != pubkeys.end(); ) {
|
||||
auto& pk = *it;
|
||||
if (pk.size() != 32) {
|
||||
OMQ_LOG(warn, "Invalid private key of length ", pk.size(), " (", oxenc::to_hex(pk), ") passed to set_active_sns");
|
||||
it = pubkeys.erase(it);
|
||||
continue;
|
||||
}
|
||||
if (!active_service_nodes.count(pk))
|
||||
added.insert(std::move(pk));
|
||||
++it;
|
||||
}
|
||||
if (added.empty() && active_service_nodes.size() == pubkeys.size()) {
|
||||
OMQ_LOG(debug, "set_active_sns(): new set of SNs is unchanged, skipping update");
|
||||
return;
|
||||
}
|
||||
for (const auto& pk : active_service_nodes) {
|
||||
if (!pubkeys.count(pk))
|
||||
removed.insert(pk);
|
||||
if (active_service_nodes.size() + added.size() - removed.size() == pubkeys.size())
|
||||
break;
|
||||
}
|
||||
proxy_update_active_sns_clean(std::move(added), std::move(removed));
|
||||
}
|
||||
|
||||
void OxenMQ::update_active_sns(pubkey_set added, pubkey_set removed) {
|
||||
if (proxy_thread.joinable()) {
|
||||
std::array<uintptr_t, 2> data;
|
||||
data[0] = detail::serialize_object(std::move(added));
|
||||
data[1] = detail::serialize_object(std::move(removed));
|
||||
detail::send_control(get_control_socket(), "UPDATE_SNS", oxenc::bt_serialize(data));
|
||||
} else {
|
||||
proxy_update_active_sns(std::move(added), std::move(removed));
|
||||
}
|
||||
}
|
||||
void OxenMQ::proxy_update_active_sns(oxenc::bt_list_consumer data) {
|
||||
auto added = detail::deserialize_object<pubkey_set>(data.consume_integer<uintptr_t>());
|
||||
auto remed = detail::deserialize_object<pubkey_set>(data.consume_integer<uintptr_t>());
|
||||
proxy_update_active_sns(std::move(added), std::move(remed));
|
||||
}
|
||||
void OxenMQ::proxy_update_active_sns(pubkey_set added, pubkey_set removed) {
|
||||
// We take a caller-provided set of added/removed then filter out any junk (bad pks, conflicting
|
||||
// values, pubkeys that already(added) or do not(removed) exist), then pass the purified lists
|
||||
// to the _clean version.
|
||||
|
||||
for (auto it = removed.begin(); it != removed.end(); ) {
|
||||
const auto& pk = *it;
|
||||
if (pk.size() != 32) {
|
||||
OMQ_LOG(warn, "Invalid private key of length ", pk.size(), " (", oxenc::to_hex(pk), ") passed to update_active_sns (removed)");
|
||||
it = removed.erase(it);
|
||||
} else if (!active_service_nodes.count(pk) || added.count(pk) /* added wins if in both */) {
|
||||
it = removed.erase(it);
|
||||
} else {
|
||||
++it;
|
||||
}
|
||||
}
|
||||
|
||||
for (auto it = added.begin(); it != added.end(); ) {
|
||||
const auto& pk = *it;
|
||||
if (pk.size() != 32) {
|
||||
OMQ_LOG(warn, "Invalid private key of length ", pk.size(), " (", oxenc::to_hex(pk), ") passed to update_active_sns (added)");
|
||||
it = added.erase(it);
|
||||
} else if (active_service_nodes.count(pk)) {
|
||||
it = added.erase(it);
|
||||
} else {
|
||||
++it;
|
||||
}
|
||||
}
|
||||
|
||||
proxy_update_active_sns_clean(std::move(added), std::move(removed));
|
||||
}
|
||||
|
||||
void OxenMQ::proxy_update_active_sns_clean(pubkey_set added, pubkey_set removed) {
|
||||
OMQ_LOG(debug, "Updating SN auth status with +", added.size(), "/-", removed.size(), " pubkeys");
|
||||
|
||||
// For anything we remove we want close the connection to the SN (if outgoing), and remove the
|
||||
// stored peer_info (incoming or outgoing).
|
||||
for (const auto& pk : removed) {
|
||||
ConnectionID c{pk};
|
||||
active_service_nodes.erase(pk);
|
||||
auto range = peers.equal_range(c);
|
||||
for (auto it = range.first; it != range.second; ) {
|
||||
bool outgoing = it->second.outgoing();
|
||||
auto conn_id = it->second.conn_id;
|
||||
it = peers.erase(it);
|
||||
if (outgoing) {
|
||||
OMQ_LOG(debug, "Closing outgoing connection to ", c);
|
||||
proxy_close_connection(conn_id, CLOSE_LINGER);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// For pubkeys we add there's nothing special to be done beyond adding them to the pubkey set
|
||||
for (auto& pk : added)
|
||||
active_service_nodes.insert(std::move(pk));
|
||||
}
|
||||
|
||||
void OxenMQ::process_zap_requests() {
|
||||
for (std::vector<zmq::message_t> frames; recv_message_parts(zap_auth, frames, zmq::recv_flags::dontwait); frames.clear()) {
|
||||
#ifndef NDEBUG
|
||||
if (log_level() >= LogLevel::trace) {
|
||||
std::ostringstream o;
|
||||
o << "Processing ZAP authentication request:";
|
||||
for (size_t i = 0; i < frames.size(); i++) {
|
||||
o << "\n[" << i << "]: ";
|
||||
auto v = view(frames[i]);
|
||||
if (i == 1 || i == 6)
|
||||
o << oxenc::to_hex(v);
|
||||
else
|
||||
o << v;
|
||||
}
|
||||
log(LogLevel::trace, __FILE__, __LINE__, o.str());
|
||||
} else
|
||||
#endif
|
||||
OMQ_LOG(debug, "Processing ZAP authentication request");
|
||||
|
||||
// https://rfc.zeromq.org/spec:27/ZAP/
|
||||
//
|
||||
// The request message SHALL consist of the following message frames:
|
||||
//
|
||||
// The version frame, which SHALL contain the three octets "1.0".
|
||||
// The request id, which MAY contain an opaque binary blob.
|
||||
// The domain, which SHALL contain a (non-empty) string.
|
||||
// The address, the origin network IP address.
|
||||
// The identity, the connection Identity, if any.
|
||||
// The mechanism, which SHALL contain a string.
|
||||
// The credentials, which SHALL be zero or more opaque frames.
|
||||
//
|
||||
// The reply message SHALL consist of the following message frames:
|
||||
//
|
||||
// The version frame, which SHALL contain the three octets "1.0".
|
||||
// The request id, which MAY contain an opaque binary blob.
|
||||
// The status code, which SHALL contain a string.
|
||||
// The status text, which MAY contain a string.
|
||||
// The user id, which SHALL contain a string.
|
||||
// The metadata, which MAY contain a blob.
|
||||
//
|
||||
// (NB: there are also null address delimiters at the beginning of each mentioned in the
|
||||
// RFC, but those have already been removed through the use of a REP socket)
|
||||
|
||||
std::vector<std::string> response_vals(6);
|
||||
response_vals[0] = "1.0"; // version
|
||||
if (frames.size() >= 2)
|
||||
response_vals[1] = std::string{view(frames[1])}; // unique identifier
|
||||
std::string &status_code = response_vals[2], &status_text = response_vals[3];
|
||||
|
||||
if (frames.size() < 6 || view(frames[0]) != "1.0") {
|
||||
OMQ_LOG(error, "Bad ZAP authentication request: version != 1.0 or invalid ZAP message parts");
|
||||
status_code = "500";
|
||||
status_text = "Internal error: invalid auth request";
|
||||
} else {
|
||||
auto auth_domain = view(frames[2]);
|
||||
size_t bind_id = (size_t) -1;
|
||||
try {
|
||||
bind_id = oxenc::bt_deserialize<size_t>(view(frames[2]));
|
||||
} catch (...) {}
|
||||
|
||||
if (bind_id >= bind.size()) {
|
||||
OMQ_LOG(error, "Bad ZAP authentication request: invalid auth domain '", auth_domain, "'");
|
||||
status_code = "400";
|
||||
status_text = "Unknown authentication domain: " + std::string{auth_domain};
|
||||
} else if (bind[bind_id].curve
|
||||
? !(frames.size() == 7 && view(frames[5]) == "CURVE")
|
||||
: !(frames.size() == 6 && view(frames[5]) == "NULL")) {
|
||||
OMQ_LOG(error, "Bad ZAP authentication request: invalid ",
|
||||
bind[bind_id].curve ? "CURVE" : "NULL", " authentication request");
|
||||
status_code = "500";
|
||||
status_text = "Invalid authentication request mechanism";
|
||||
} else if (bind[bind_id].curve && frames[6].size() != 32) {
|
||||
OMQ_LOG(error, "Bad ZAP authentication request: invalid request pubkey");
|
||||
status_code = "500";
|
||||
status_text = "Invalid public key size for CURVE authentication";
|
||||
} else {
|
||||
auto ip = view(frames[3]);
|
||||
// If we're in dual stack mode IPv4 address might be IPv4-mapped IPv6 address (e.g.
|
||||
// ::ffff:192.168.0.1); if so, remove the prefix to get a proper IPv4 address:
|
||||
if (ip.size() >= 14 && ip.substr(0, 7) == "::ffff:"sv && ip.find_last_not_of("0123456789."sv) == 6)
|
||||
ip = ip.substr(7);
|
||||
|
||||
std::string_view pubkey;
|
||||
bool sn = false;
|
||||
if (bind[bind_id].curve) {
|
||||
pubkey = view(frames[6]);
|
||||
sn = active_service_nodes.count(std::string{pubkey});
|
||||
}
|
||||
auto auth = bind[bind_id].allow(ip, pubkey, sn);
|
||||
auto& user_id = response_vals[4];
|
||||
if (bind[bind_id].curve) {
|
||||
user_id.reserve(64);
|
||||
oxenc::to_hex(pubkey.begin(), pubkey.end(), std::back_inserter(user_id));
|
||||
}
|
||||
|
||||
if (auth <= AuthLevel::denied || auth > AuthLevel::admin) {
|
||||
OMQ_LOG(info, "Access denied for incoming ", view(frames[5]), (sn ? " service node" : " client"),
|
||||
" connection from ", !user_id.empty() ? user_id + " at " : ""s, ip,
|
||||
" with initial auth level ", auth);
|
||||
status_code = "400";
|
||||
status_text = "Access denied";
|
||||
user_id.clear();
|
||||
} else {
|
||||
OMQ_LOG(debug, "Accepted incoming ", view(frames[5]), (sn ? " service node" : " client"),
|
||||
" connection with authentication level ", auth,
|
||||
" from ", !user_id.empty() ? user_id + " at " : ""s, ip);
|
||||
|
||||
auto& metadata = response_vals[5];
|
||||
metadata += zmtp_metadata("X-AuthLevel", to_string(auth));
|
||||
|
||||
status_code = "200";
|
||||
status_text = "";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
OMQ_TRACE("ZAP request result: ", status_code, " ", status_text);
|
||||
|
||||
std::vector<zmq::message_t> response;
|
||||
response.reserve(response_vals.size());
|
||||
for (auto &r : response_vals) response.push_back(create_message(std::move(r)));
|
||||
send_message_parts(zap_auth, response.begin(), response.end());
|
||||
}
|
||||
}
|
||||
|
||||
}
|
73
oxenmq/auth.h
Normal file
73
oxenmq/auth.h
Normal file
|
@ -0,0 +1,73 @@
|
|||
#pragma once
|
||||
#include <iosfwd>
|
||||
#include <string>
|
||||
#include <cstring>
|
||||
#include <unordered_set>
|
||||
|
||||
namespace oxenmq {
|
||||
|
||||
/// Authentication levels for command categories and connections
|
||||
enum class AuthLevel {
|
||||
denied, ///< Not actually an auth level, but can be returned by the AllowFunc to deny an incoming connection.
|
||||
none, ///< No authentication at all; any random incoming ZMQ connection can invoke this command.
|
||||
basic, ///< Basic authentication commands require a login, or a node that is specifically configured to be a public node (e.g. for public RPC).
|
||||
admin, ///< Advanced authentication commands require an admin user, either via explicit login or by implicit login from localhost. This typically protects administrative commands like shutting down, starting mining, or access sensitive data.
|
||||
};
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, AuthLevel a);
|
||||
|
||||
/// The access level for a command category
|
||||
struct Access {
|
||||
/// Minimum access level required
|
||||
AuthLevel auth;
|
||||
/// If true only remote SNs may call the category commands
|
||||
bool remote_sn;
|
||||
/// If true the category requires that the local node is a SN
|
||||
bool local_sn;
|
||||
|
||||
/// Constructor. Intentionally allows implicit conversion from an AuthLevel so that an
|
||||
/// AuthLevel can be passed anywhere an Access is required (the resulting Access will have both
|
||||
/// remote and local sn set to false).
|
||||
Access(AuthLevel auth = AuthLevel::none, bool remote_sn = false, bool local_sn = false)
|
||||
: auth{auth}, remote_sn{remote_sn}, local_sn{local_sn} {}
|
||||
};
|
||||
|
||||
/// Simple hash implementation for a string that is *already* a hash-like value (such as a pubkey).
|
||||
/// Falls back to std::hash<std::string> if given a string smaller than a size_t.
|
||||
struct already_hashed {
|
||||
size_t operator()(const std::string& s) const {
|
||||
if (s.size() < sizeof(size_t))
|
||||
return std::hash<std::string>{}(s);
|
||||
size_t hash;
|
||||
std::memcpy(&hash, &s[0], sizeof(hash));
|
||||
return hash;
|
||||
}
|
||||
};
|
||||
|
||||
/// std::unordered_set specialization for specifying pubkeys (used, in particular, by
|
||||
/// OxenMQ::set_active_sns and OxenMQ::update_active_sns); this is a std::string unordered_set that
|
||||
/// also uses a specialized trivial hash function that uses part of the value itself (i.e. the
|
||||
/// pubkey) directly as a hash value. (This is nice and fast for uniformly distributed values like
|
||||
/// pubkeys and a terrible hash choice for anything else).
|
||||
using pubkey_set = std::unordered_set<std::string, already_hashed>;
|
||||
|
||||
inline constexpr std::string_view to_string(AuthLevel a) {
|
||||
switch (a) {
|
||||
case AuthLevel::denied: return "denied";
|
||||
case AuthLevel::none: return "none";
|
||||
case AuthLevel::basic: return "basic";
|
||||
case AuthLevel::admin: return "admin";
|
||||
default: return "(unknown)";
|
||||
}
|
||||
}
|
||||
|
||||
inline AuthLevel auth_from_string(std::string_view a) {
|
||||
if (a == "none") return AuthLevel::none;
|
||||
if (a == "basic") return AuthLevel::basic;
|
||||
if (a == "admin") return AuthLevel::admin;
|
||||
return AuthLevel::denied;
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
// Copyright (c) 2020, The Loki Project
|
||||
// Copyright (c) 2020-2021, The Oxen Project
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
|
@ -30,24 +30,31 @@
|
|||
#include <exception>
|
||||
#include <functional>
|
||||
#include <vector>
|
||||
#include "lokimq.h"
|
||||
#include "oxenmq.h"
|
||||
|
||||
namespace lokimq {
|
||||
namespace oxenmq {
|
||||
|
||||
namespace detail {
|
||||
|
||||
enum class BatchStatus {
|
||||
enum class BatchState {
|
||||
running, // there are still jobs to run (or running)
|
||||
complete, // the batch is complete but still has a completion job to call
|
||||
complete_proxy, // same as `complete`, but the completion job should be invoked immediately in the proxy thread (be very careful)
|
||||
done // the batch is complete and has no completion function
|
||||
};
|
||||
|
||||
struct BatchStatus {
|
||||
BatchState state;
|
||||
int thread;
|
||||
};
|
||||
|
||||
// Virtual base class for Batch<R>
|
||||
class Batch {
|
||||
public:
|
||||
// Returns the number of jobs in this batch
|
||||
virtual size_t size() const = 0;
|
||||
// Returns the number of jobs in this batch and whether any of them are thread-specific
|
||||
virtual std::pair<size_t, bool> size() const = 0;
|
||||
// Returns a vector of exactly the same length of size().first containing the tagged thread ids
|
||||
// of the batch jobs or 0 for general jobs.
|
||||
virtual std::vector<int> threads() const = 0;
|
||||
// Called in a worker thread to run the job
|
||||
virtual void run_job(int i) = 0;
|
||||
// Called in the main proxy thread when the worker returns from finishing a job. The return
|
||||
|
@ -71,7 +78,7 @@ public:
|
|||
* This is designed to be like a very stripped down version of a std::promise/std::future pair. We
|
||||
* reimplemented it, however, because by ditching all the thread synchronization that promise/future
|
||||
* guarantees we can substantially reduce call overhead (by a factor of ~8 according to benchmarking
|
||||
* code). Since LokiMQ's proxy<->worker communication channel already gives us thread that overhead
|
||||
* code). Since OxenMQ's proxy<->worker communication channel already gives us thread that overhead
|
||||
* would just be wasted.
|
||||
*
|
||||
* @tparam R the value type held by the result; must be default constructible. Note, however, that
|
||||
|
@ -128,13 +135,13 @@ public:
|
|||
void get() { if (exc) std::rethrow_exception(exc); }
|
||||
};
|
||||
|
||||
/// Helper class used to set up batches of jobs to be scheduled via the lokimq job handler.
|
||||
/// Helper class used to set up batches of jobs to be scheduled via the oxenmq job handler.
|
||||
///
|
||||
/// @tparam R - the return type of the individual jobs
|
||||
///
|
||||
template <typename R>
|
||||
class Batch final : private detail::Batch {
|
||||
friend class LokiMQ;
|
||||
friend class OxenMQ;
|
||||
public:
|
||||
/// The completion function type, called after all jobs have finished.
|
||||
using CompletionFunc = std::function<void(std::vector<job_result<R>> results)>;
|
||||
|
@ -151,16 +158,17 @@ public:
|
|||
Batch &operator=(const Batch&) = delete;
|
||||
|
||||
private:
|
||||
std::vector<std::function<R()>> jobs;
|
||||
std::vector<std::pair<std::function<R()>, int>> jobs;
|
||||
std::vector<job_result<R>> results;
|
||||
CompletionFunc complete;
|
||||
std::size_t jobs_outstanding = 0;
|
||||
bool complete_in_proxy = false;
|
||||
int complete_in_thread = 0;
|
||||
bool started = false;
|
||||
bool tagged_thread_jobs = false;
|
||||
|
||||
void check_not_started() {
|
||||
if (started)
|
||||
throw std::logic_error("Cannot add jobs or completion function after starting a lokimq::Batch!");
|
||||
throw std::logic_error("Cannot add jobs or completion function after starting a oxenmq::Batch!");
|
||||
}
|
||||
|
||||
public:
|
||||
|
@ -175,39 +183,61 @@ public:
|
|||
/// available. The called function may throw exceptions (which will be propagated to the
|
||||
/// completion function through the job_result values). There is no guarantee on the order of
|
||||
/// invocation of the jobs.
|
||||
void add_job(std::function<R()> job) {
|
||||
///
|
||||
/// \param job the callback
|
||||
/// \param thread an optional TaggedThreadID indicating a thread in which this job must run
|
||||
void add_job(std::function<R()> job, std::optional<TaggedThreadID> thread = std::nullopt) {
|
||||
check_not_started();
|
||||
jobs.emplace_back(std::move(job));
|
||||
results.emplace_back();
|
||||
jobs_outstanding++;
|
||||
if (thread && thread->_id == -1)
|
||||
// There are some special case internal jobs where we allow this, but they use the
|
||||
// private method below that doesn't have this check.
|
||||
throw std::logic_error{"Cannot add a proxy thread batch job -- this makes no sense"};
|
||||
add_job(std::move(job), thread ? thread->_id : 0);
|
||||
}
|
||||
|
||||
/// Sets the completion function to invoke after all jobs have finished. If this is not set
|
||||
/// then jobs simply run and results are discarded.
|
||||
void completion(CompletionFunc comp) {
|
||||
///
|
||||
/// \param comp - function to call when all jobs have finished
|
||||
/// \param thread - optional tagged thread in which to schedule the completion job. If not
|
||||
/// provided then the completion job is scheduled in the pool of batch job threads.
|
||||
///
|
||||
/// `thread` can be provided the value &OxenMQ::run_in_proxy to invoke the completion function
|
||||
/// *IN THE PROXY THREAD* itself after all jobs have finished. Be very, very careful: this
|
||||
/// should be a nearly trivial job that does not require any substantial CPU time and does not
|
||||
/// block for any reason. This is only intended for the case where the completion job is so
|
||||
/// trivial that it will take less time than simply queuing the job to be executed by another
|
||||
/// thread.
|
||||
void completion(CompletionFunc comp, std::optional<TaggedThreadID> thread = std::nullopt) {
|
||||
check_not_started();
|
||||
if (complete)
|
||||
throw std::logic_error("Completion function can only be set once");
|
||||
complete = std::move(comp);
|
||||
}
|
||||
|
||||
/// Sets a completion function to invoke *IN THE PROXY THREAD* after all jobs have finished. Be
|
||||
/// very, very careful: this should not be a job that takes any significant amount of CPU time
|
||||
/// or can block for any reason (NO MUTEXES).
|
||||
void completion_proxy(CompletionFunc comp) {
|
||||
check_not_started();
|
||||
if (complete)
|
||||
throw std::logic_error("Completion function can only be set once");
|
||||
complete = std::move(comp);
|
||||
complete_in_proxy = true;
|
||||
complete_in_thread = thread ? thread->_id : 0;
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
std::size_t size() const override {
|
||||
return jobs.size();
|
||||
void add_job(std::function<R()> job, int thread_id) {
|
||||
jobs.emplace_back(std::move(job), thread_id);
|
||||
results.emplace_back();
|
||||
jobs_outstanding++;
|
||||
if (thread_id != 0)
|
||||
tagged_thread_jobs = true;
|
||||
}
|
||||
|
||||
std::pair<std::size_t, bool> size() const override {
|
||||
return {jobs.size(), tagged_thread_jobs};
|
||||
}
|
||||
|
||||
std::vector<int> threads() const override {
|
||||
std::vector<int> t;
|
||||
t.reserve(jobs.size());
|
||||
for (auto& j : jobs)
|
||||
t.push_back(j.second);
|
||||
return t;
|
||||
};
|
||||
|
||||
template <typename S = R>
|
||||
void set_value(job_result<S>& r, std::function<S()>& f) { r.set_value(f()); }
|
||||
void set_value(job_result<void>&, std::function<void()>& f) { f(); }
|
||||
|
@ -216,7 +246,7 @@ private:
|
|||
// called by worker thread
|
||||
auto& r = results[i];
|
||||
try {
|
||||
set_value(r, jobs[i]);
|
||||
set_value(r, jobs[i].first);
|
||||
} catch (...) {
|
||||
r.set_exception(std::current_exception());
|
||||
}
|
||||
|
@ -225,12 +255,10 @@ private:
|
|||
detail::BatchStatus job_finished() override {
|
||||
--jobs_outstanding;
|
||||
if (jobs_outstanding)
|
||||
return detail::BatchStatus::running;
|
||||
return {detail::BatchState::running, 0};
|
||||
if (complete)
|
||||
return complete_in_proxy
|
||||
? detail::BatchStatus::complete_proxy
|
||||
: detail::BatchStatus::complete;
|
||||
return detail::BatchStatus::done;
|
||||
return {detail::BatchState::complete, complete_in_thread};
|
||||
return {detail::BatchState::done, 0};
|
||||
}
|
||||
|
||||
void job_completion() override {
|
||||
|
@ -238,14 +266,60 @@ private:
|
|||
}
|
||||
};
|
||||
|
||||
// Similar to Batch<void>, but doesn't support a completion function and only handles a single task.
|
||||
class Job final : private detail::Batch {
|
||||
friend class OxenMQ;
|
||||
public:
|
||||
/// Constructs the Job to run a single task. Takes any callable invokable with no arguments and
|
||||
/// having no return value. The task will be scheduled and run when the next worker thread is
|
||||
/// available. Any exceptions thrown by the job will be caught and squelched (the exception
|
||||
/// terminates/completes the job).
|
||||
|
||||
explicit Job(std::function<void()> f, std::optional<TaggedThreadID> thread = std::nullopt)
|
||||
: Job{std::move(f), thread ? thread->_id : 0}
|
||||
{
|
||||
if (thread && thread->_id == -1)
|
||||
// There are some special case internal jobs where we allow this, but they use the
|
||||
// private ctor below that doesn't have this check.
|
||||
throw std::logic_error{"Cannot add a proxy thread job -- this makes no sense"};
|
||||
}
|
||||
|
||||
// movable
|
||||
Job(Job&&) = default;
|
||||
Job &operator=(Job&&) = default;
|
||||
|
||||
// non-copyable
|
||||
Job(const Job&) = delete;
|
||||
Job &operator=(const Job&) = delete;
|
||||
|
||||
private:
|
||||
explicit Job(std::function<void()> f, int thread_id)
|
||||
: job{std::move(f), thread_id} {}
|
||||
|
||||
std::pair<std::function<void()>, int> job;
|
||||
bool done = false;
|
||||
|
||||
std::pair<size_t, bool> size() const override { return {1, job.second != 0}; }
|
||||
std::vector<int> threads() const override { return {job.second}; }
|
||||
|
||||
void run_job(const int /*i*/) override {
|
||||
try { job.first(); }
|
||||
catch (...) {}
|
||||
}
|
||||
|
||||
detail::BatchStatus job_finished() override { return {detail::BatchState::done, 0}; }
|
||||
|
||||
void job_completion() override {} // Never called because we return ::done (not ::complete) above.
|
||||
|
||||
};
|
||||
|
||||
template <typename R>
|
||||
void LokiMQ::batch(Batch<R>&& batch) {
|
||||
if (batch.size() == 0)
|
||||
void OxenMQ::batch(Batch<R>&& batch) {
|
||||
if (batch.size().first == 0)
|
||||
throw std::logic_error("Cannot batch a a job batch with 0 jobs");
|
||||
// Need to send this over to the proxy thread via the base class pointer. It assumes ownership.
|
||||
auto* baseptr = static_cast<detail::Batch*>(new Batch<R>(std::move(batch)));
|
||||
detail::send_control(get_control_socket(), "BATCH", bt_serialize(reinterpret_cast<uintptr_t>(baseptr)));
|
||||
detail::send_control(get_control_socket(), "BATCH", oxenc::bt_serialize(reinterpret_cast<uintptr_t>(baseptr)));
|
||||
}
|
||||
|
||||
}
|
429
oxenmq/connections.cpp
Normal file
429
oxenmq/connections.cpp
Normal file
|
@ -0,0 +1,429 @@
|
|||
#include "oxenmq.h"
|
||||
#include "oxenmq-internal.h"
|
||||
#include <oxenc/hex.h>
|
||||
#include <optional>
|
||||
|
||||
#ifdef OXENMQ_USE_EPOLL
|
||||
extern "C" {
|
||||
#include <sys/epoll.h>
|
||||
#include <unistd.h>
|
||||
}
|
||||
#endif
|
||||
|
||||
namespace oxenmq {
|
||||
|
||||
std::ostream& operator<<(std::ostream& o, const ConnectionID& conn) {
|
||||
return o << conn.to_string();
|
||||
}
|
||||
|
||||
#ifdef OXENMQ_USE_EPOLL
|
||||
|
||||
void OxenMQ::rebuild_pollitems() {
|
||||
|
||||
if (epoll_fd != -1)
|
||||
close(epoll_fd);
|
||||
epoll_fd = epoll_create1(0);
|
||||
|
||||
struct epoll_event ev;
|
||||
ev.events = EPOLLIN | EPOLLET;
|
||||
ev.data.u64 = EPOLL_COMMAND_ID;
|
||||
epoll_ctl(epoll_fd, EPOLL_CTL_ADD, command.get(zmq::sockopt::fd), &ev);
|
||||
|
||||
ev.data.u64 = EPOLL_WORKER_ID;
|
||||
epoll_ctl(epoll_fd, EPOLL_CTL_ADD, workers_socket.get(zmq::sockopt::fd), &ev);
|
||||
|
||||
ev.data.u64 = EPOLL_ZAP_ID;
|
||||
epoll_ctl(epoll_fd, EPOLL_CTL_ADD, zap_auth.get(zmq::sockopt::fd), &ev);
|
||||
|
||||
for (auto& [id, s] : connections) {
|
||||
ev.data.u64 = id;
|
||||
epoll_ctl(epoll_fd, EPOLL_CTL_ADD, s.get(zmq::sockopt::fd), &ev);
|
||||
}
|
||||
connections_updated = false;
|
||||
}
|
||||
|
||||
#else // !OXENMQ_USE_EPOLL
|
||||
|
||||
namespace {
|
||||
|
||||
void add_pollitem(std::vector<zmq::pollitem_t>& pollitems, zmq::socket_t& sock) {
|
||||
pollitems.emplace_back();
|
||||
auto &p = pollitems.back();
|
||||
p.socket = static_cast<void *>(sock);
|
||||
p.fd = 0;
|
||||
p.events = ZMQ_POLLIN;
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
|
||||
void OxenMQ::rebuild_pollitems() {
|
||||
pollitems.clear();
|
||||
add_pollitem(pollitems, command);
|
||||
add_pollitem(pollitems, workers_socket);
|
||||
add_pollitem(pollitems, zap_auth);
|
||||
|
||||
for (auto& [id, s] : connections)
|
||||
add_pollitem(pollitems, s);
|
||||
connections_updated = false;
|
||||
}
|
||||
|
||||
#endif // OXENMQ_USE_EPOLL
|
||||
|
||||
void OxenMQ::setup_external_socket(zmq::socket_t& socket) {
|
||||
socket.set(zmq::sockopt::reconnect_ivl, (int) RECONNECT_INTERVAL.count());
|
||||
socket.set(zmq::sockopt::reconnect_ivl_max, (int) RECONNECT_INTERVAL_MAX.count());
|
||||
socket.set(zmq::sockopt::handshake_ivl, (int) HANDSHAKE_TIME.count());
|
||||
socket.set(zmq::sockopt::maxmsgsize, MAX_MSG_SIZE);
|
||||
if (IPV6)
|
||||
socket.set(zmq::sockopt::ipv6, 1);
|
||||
|
||||
if (CONN_HEARTBEAT > 0s) {
|
||||
socket.set(zmq::sockopt::heartbeat_ivl, (int) CONN_HEARTBEAT.count());
|
||||
if (CONN_HEARTBEAT_TIMEOUT > 0s)
|
||||
socket.set(zmq::sockopt::heartbeat_timeout, (int) CONN_HEARTBEAT_TIMEOUT.count());
|
||||
}
|
||||
}
|
||||
|
||||
void OxenMQ::setup_outgoing_socket(zmq::socket_t& socket, std::string_view remote_pubkey, bool use_ephemeral_routing_id) {
|
||||
|
||||
setup_external_socket(socket);
|
||||
|
||||
if (!remote_pubkey.empty()) {
|
||||
socket.set(zmq::sockopt::curve_serverkey, remote_pubkey);
|
||||
socket.set(zmq::sockopt::curve_publickey, pubkey);
|
||||
socket.set(zmq::sockopt::curve_secretkey, privkey);
|
||||
}
|
||||
|
||||
if (!use_ephemeral_routing_id) {
|
||||
std::string routing_id;
|
||||
routing_id.reserve(33);
|
||||
routing_id += 'L'; // Prefix because routing id's starting with \0 are reserved by zmq (and our pubkey might start with \0)
|
||||
routing_id.append(pubkey.begin(), pubkey.end());
|
||||
socket.set(zmq::sockopt::routing_id, routing_id);
|
||||
}
|
||||
// else let ZMQ pick a random one
|
||||
}
|
||||
|
||||
|
||||
void OxenMQ::setup_incoming_socket(zmq::socket_t& listener, bool curve, std::string_view pubkey, std::string_view privkey, size_t bind_index) {
|
||||
|
||||
setup_external_socket(listener);
|
||||
|
||||
listener.set(zmq::sockopt::zap_domain, oxenc::bt_serialize(bind_index));
|
||||
if (curve) {
|
||||
listener.set(zmq::sockopt::curve_server, true);
|
||||
listener.set(zmq::sockopt::curve_publickey, pubkey);
|
||||
listener.set(zmq::sockopt::curve_secretkey, privkey);
|
||||
}
|
||||
listener.set(zmq::sockopt::router_handover, true);
|
||||
listener.set(zmq::sockopt::router_mandatory, true);
|
||||
}
|
||||
|
||||
// Deprecated versions:
|
||||
ConnectionID OxenMQ::connect_remote(std::string_view remote, ConnectSuccess on_connect,
|
||||
ConnectFailure on_failure, AuthLevel auth_level, std::chrono::milliseconds timeout) {
|
||||
return connect_remote(address{remote}, std::move(on_connect), std::move(on_failure),
|
||||
auth_level, connect_option::timeout{timeout});
|
||||
}
|
||||
|
||||
ConnectionID OxenMQ::connect_remote(std::string_view remote, ConnectSuccess on_connect,
|
||||
ConnectFailure on_failure, std::string_view pubkey, AuthLevel auth_level,
|
||||
std::chrono::milliseconds timeout) {
|
||||
return connect_remote(address{remote}.set_pubkey(pubkey), std::move(on_connect),
|
||||
std::move(on_failure), auth_level, connect_option::timeout{timeout});
|
||||
}
|
||||
|
||||
void OxenMQ::disconnect(ConnectionID id, std::chrono::milliseconds linger) {
|
||||
detail::send_control(get_control_socket(), "DISCONNECT", oxenc::bt_serialize<oxenc::bt_dict>({
|
||||
{"conn_id", id.id},
|
||||
{"linger_ms", linger.count()},
|
||||
{"pubkey", id.pk},
|
||||
}));
|
||||
}
|
||||
|
||||
std::pair<zmq::socket_t *, std::string>
|
||||
OxenMQ::proxy_connect_sn(std::string_view remote, std::string_view connect_hint, bool optional, bool incoming_only, bool outgoing_only, bool use_ephemeral_routing_id, std::chrono::milliseconds keep_alive) {
|
||||
ConnectionID remote_cid{remote};
|
||||
auto its = peers.equal_range(remote_cid);
|
||||
peer_info* peer = nullptr;
|
||||
for (auto it = its.first; it != its.second; ++it) {
|
||||
if (incoming_only && it->second.route.empty())
|
||||
continue; // outgoing connection but we were asked to only use incoming connections
|
||||
if (outgoing_only && !it->second.route.empty())
|
||||
continue;
|
||||
peer = &it->second;
|
||||
break;
|
||||
}
|
||||
|
||||
if (peer) {
|
||||
OMQ_TRACE("proxy asked to connect to ", oxenc::to_hex(remote), "; reusing existing connection");
|
||||
if (peer->route.empty() /* == outgoing*/) {
|
||||
if (peer->idle_expiry < keep_alive) {
|
||||
OMQ_LOG(debug, "updating existing outgoing peer connection idle expiry time from ",
|
||||
peer->idle_expiry.count(), "ms to ", keep_alive.count(), "ms");
|
||||
peer->idle_expiry = keep_alive;
|
||||
}
|
||||
peer->activity();
|
||||
}
|
||||
return {&connections[peer->conn_id], peer->route};
|
||||
} else if (optional || incoming_only) {
|
||||
OMQ_LOG(debug, "proxy asked for optional or incoming connection, but no appropriate connection exists so aborting connection attempt");
|
||||
return {nullptr, ""s};
|
||||
}
|
||||
|
||||
// No connection so establish a new one
|
||||
OMQ_LOG(debug, "proxy establishing new outbound connection to ", oxenc::to_hex(remote));
|
||||
std::string addr;
|
||||
bool to_self = false && remote == pubkey; // FIXME; need to use a separate listening socket for this, otherwise we can't easily
|
||||
// tell it wasn't from a remote.
|
||||
if (to_self) {
|
||||
// special inproc connection if self that doesn't need any external connection
|
||||
addr = SN_ADDR_SELF;
|
||||
} else {
|
||||
addr = std::string{connect_hint};
|
||||
if (addr.empty())
|
||||
addr = sn_lookup(remote);
|
||||
else
|
||||
OMQ_LOG(debug, "using connection hint ", connect_hint);
|
||||
|
||||
if (addr.empty()) {
|
||||
OMQ_LOG(error, "peer lookup failed for ", oxenc::to_hex(remote));
|
||||
return {nullptr, ""s};
|
||||
}
|
||||
}
|
||||
|
||||
OMQ_LOG(debug, oxenc::to_hex(pubkey), " (me) connecting to ", addr, " to reach ", oxenc::to_hex(remote));
|
||||
std::optional<zmq::socket_t> socket;
|
||||
try {
|
||||
socket.emplace(context, zmq::socket_type::dealer);
|
||||
setup_outgoing_socket(*socket, remote, use_ephemeral_routing_id);
|
||||
socket->connect(addr);
|
||||
} catch (const zmq::error_t& e) {
|
||||
// Note that this failure cases indicates something serious went wrong that means zmq isn't
|
||||
// even going to try connecting (for example an unparseable remote address).
|
||||
OMQ_LOG(error, "Outgoing connection to ", addr, " failed: ", e.what());
|
||||
return {nullptr, ""s};
|
||||
}
|
||||
|
||||
auto& p = peers.emplace(std::move(remote_cid), peer_info{})->second;
|
||||
p.service_node = true;
|
||||
p.pubkey = std::string{remote};
|
||||
p.conn_id = next_conn_id++;
|
||||
p.idle_expiry = keep_alive;
|
||||
p.activity();
|
||||
connections_updated = true;
|
||||
outgoing_sn_conns.emplace_hint(outgoing_sn_conns.end(), p.conn_id, ConnectionID{remote});
|
||||
auto it = connections.emplace_hint(connections.end(), p.conn_id, *std::move(socket));
|
||||
|
||||
return {&it->second, ""s};
|
||||
}
|
||||
|
||||
std::pair<zmq::socket_t *, std::string> OxenMQ::proxy_connect_sn(oxenc::bt_dict_consumer data) {
|
||||
std::string_view hint, remote_pk;
|
||||
std::chrono::milliseconds keep_alive;
|
||||
bool optional = false, incoming_only = false, outgoing_only = false, ephemeral_rid = EPHEMERAL_ROUTING_ID;
|
||||
|
||||
// Alphabetical order
|
||||
if (data.skip_until("ephemeral_rid"))
|
||||
ephemeral_rid = data.consume_integer<bool>();
|
||||
if (data.skip_until("hint"))
|
||||
hint = data.consume_string_view();
|
||||
if (data.skip_until("incoming"))
|
||||
incoming_only = data.consume_integer<bool>();
|
||||
if (data.skip_until("keep_alive"))
|
||||
keep_alive = std::chrono::milliseconds{data.consume_integer<uint64_t>()};
|
||||
if (data.skip_until("optional"))
|
||||
optional = data.consume_integer<bool>();
|
||||
if (data.skip_until("outgoing_only"))
|
||||
outgoing_only = data.consume_integer<bool>();
|
||||
if (!data.skip_until("pubkey"))
|
||||
throw std::runtime_error("Internal error: Invalid proxy_connect_sn command; pubkey missing");
|
||||
remote_pk = data.consume_string_view();
|
||||
|
||||
return proxy_connect_sn(remote_pk, hint, optional, incoming_only, outgoing_only, ephemeral_rid, keep_alive);
|
||||
}
|
||||
|
||||
/// Closes outgoing connections and removes all references. Note that this will call `erase()`
|
||||
/// which can invalidate iterators on the various connection containers - if you don't want that,
|
||||
/// delete it first so that the container won't contain the element being deleted.
|
||||
void OxenMQ::proxy_close_connection(int64_t id, std::chrono::milliseconds linger) {
|
||||
auto it = connections.find(id);
|
||||
if (it == connections.end()) {
|
||||
OMQ_LOG(warn, "internal error: connection to close (", id, ") doesn't exist!");
|
||||
return;
|
||||
}
|
||||
OMQ_LOG(debug, "Closing conn ", id);
|
||||
it->second.set(zmq::sockopt::linger, linger > 0ms ? (int) linger.count() : 0);
|
||||
connections.erase(it);
|
||||
connections_updated = true;
|
||||
|
||||
outgoing_sn_conns.erase(id);
|
||||
}
|
||||
|
||||
void OxenMQ::proxy_expire_idle_peers() {
|
||||
for (auto it = peers.begin(); it != peers.end(); ) {
|
||||
auto &info = it->second;
|
||||
if (info.outgoing()) {
|
||||
auto idle = std::chrono::steady_clock::now() - info.last_activity;
|
||||
if (idle > info.idle_expiry) {
|
||||
OMQ_LOG(debug, "Closing outgoing connection to ", it->first, ": idle time (",
|
||||
std::chrono::duration_cast<std::chrono::milliseconds>(idle).count(), "ms) reached connection timeout (",
|
||||
info.idle_expiry.count(), "ms)");
|
||||
proxy_close_connection(info.conn_id, CLOSE_LINGER);
|
||||
it = peers.erase(it);
|
||||
} else {
|
||||
OMQ_LOG(trace, "Not closing ", it->first, ": ", std::chrono::duration_cast<std::chrono::milliseconds>(idle).count(),
|
||||
"ms <= ", info.idle_expiry.count(), "ms");
|
||||
++it;
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
++it;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void OxenMQ::proxy_conn_cleanup() {
|
||||
OMQ_TRACE("starting proxy connections cleanup");
|
||||
|
||||
// Drop idle connections (if we haven't done it in a while)
|
||||
OMQ_TRACE("closing idle connections");
|
||||
proxy_expire_idle_peers();
|
||||
|
||||
auto now = std::chrono::steady_clock::now();
|
||||
|
||||
// FIXME - check other outgoing connections to see if they died and if so purge them
|
||||
|
||||
OMQ_TRACE("Timing out pending outgoing connections");
|
||||
// Check any pending outgoing connections for timeout
|
||||
for (auto it = pending_connects.begin(); it != pending_connects.end(); ) {
|
||||
auto& pc = *it;
|
||||
if (std::get<std::chrono::steady_clock::time_point>(pc) < now) {
|
||||
auto id = std::get<int64_t>(pc);
|
||||
job([cid = ConnectionID{id}, callback = std::move(std::get<ConnectFailure>(pc))] { callback(cid, "connection attempt timed out"); });
|
||||
it = pending_connects.erase(it); // Don't let the below erase it (because it invalidates iterators)
|
||||
proxy_close_connection(id, CLOSE_LINGER);
|
||||
} else {
|
||||
++it;
|
||||
}
|
||||
}
|
||||
|
||||
OMQ_TRACE("Timing out pending requests");
|
||||
// Remove any expired pending requests and schedule their callback with a failure
|
||||
for (auto it = pending_requests.begin(); it != pending_requests.end(); ) {
|
||||
auto& callback = it->second;
|
||||
if (callback.first < now) {
|
||||
OMQ_LOG(debug, "pending request ", oxenc::to_hex(it->first), " expired, invoking callback with failure status and removing");
|
||||
job([callback = std::move(callback.second)] { callback(false, {{"TIMEOUT"s}}); });
|
||||
it = pending_requests.erase(it);
|
||||
} else {
|
||||
++it;
|
||||
}
|
||||
}
|
||||
|
||||
OMQ_TRACE("done proxy connections cleanup");
|
||||
};
|
||||
|
||||
void OxenMQ::proxy_connect_remote(oxenc::bt_dict_consumer data) {
|
||||
AuthLevel auth_level = AuthLevel::none;
|
||||
long long conn_id = -1;
|
||||
ConnectSuccess on_connect;
|
||||
ConnectFailure on_failure;
|
||||
std::string remote;
|
||||
std::string remote_pubkey;
|
||||
std::chrono::milliseconds timeout = REMOTE_CONNECT_TIMEOUT;
|
||||
bool ephemeral_rid = EPHEMERAL_ROUTING_ID;
|
||||
|
||||
if (data.skip_until("auth_level"))
|
||||
auth_level = static_cast<AuthLevel>(data.consume_integer<std::underlying_type_t<AuthLevel>>());
|
||||
if (data.skip_until("conn_id"))
|
||||
conn_id = data.consume_integer<long long>();
|
||||
if (data.skip_until("connect"))
|
||||
on_connect = detail::deserialize_object<ConnectSuccess>(data.consume_integer<uintptr_t>());
|
||||
if (data.skip_until("ephemeral_rid"))
|
||||
ephemeral_rid = data.consume_integer<bool>();
|
||||
if (data.skip_until("failure"))
|
||||
on_failure = detail::deserialize_object<ConnectFailure>(data.consume_integer<uintptr_t>());
|
||||
if (data.skip_until("pubkey")) {
|
||||
remote_pubkey = data.consume_string();
|
||||
assert(remote_pubkey.size() == 32 || remote_pubkey.empty());
|
||||
}
|
||||
if (data.skip_until("remote"))
|
||||
remote = data.consume_string();
|
||||
if (data.skip_until("timeout"))
|
||||
timeout = std::chrono::milliseconds{data.consume_integer<uint64_t>()};
|
||||
|
||||
if (conn_id == -1 || remote.empty())
|
||||
throw std::runtime_error("Internal error: CONNECT_REMOTE proxy command missing required 'conn_id' and/or 'remote' value");
|
||||
|
||||
OMQ_LOG(debug, "Establishing remote connection to ", remote,
|
||||
remote_pubkey.empty() ? " (NULL auth)" : " via CURVE expecting pubkey " + oxenc::to_hex(remote_pubkey));
|
||||
|
||||
std::optional<zmq::socket_t> sock;
|
||||
try {
|
||||
sock.emplace(context, zmq::socket_type::dealer);
|
||||
setup_outgoing_socket(*sock, remote_pubkey, ephemeral_rid);
|
||||
sock->connect(remote);
|
||||
} catch (const zmq::error_t &e) {
|
||||
proxy_schedule_reply_job([conn_id, on_failure=std::move(on_failure), what="connect() failed: "s+e.what()] {
|
||||
on_failure(conn_id, std::move(what));
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
auto &s = connections.emplace_hint(connections.end(), conn_id, std::move(*sock))->second;
|
||||
connections_updated = true;
|
||||
OMQ_LOG(debug, "Opened new zmq socket to ", remote, ", conn_id ", conn_id, "; sending HI");
|
||||
send_direct_message(s, "HI");
|
||||
pending_connects.emplace_back(conn_id, std::chrono::steady_clock::now() + timeout,
|
||||
std::move(on_connect), std::move(on_failure));
|
||||
auto& peer = peers.emplace(ConnectionID{conn_id, remote_pubkey}, peer_info{})->second;
|
||||
peer.pubkey = std::move(remote_pubkey);
|
||||
peer.service_node = false;
|
||||
peer.auth_level = auth_level;
|
||||
peer.conn_id = conn_id;
|
||||
peer.idle_expiry = 24h * 10 * 365; // "forever"
|
||||
peer.activity();
|
||||
}
|
||||
|
||||
void OxenMQ::proxy_disconnect(oxenc::bt_dict_consumer data) {
|
||||
ConnectionID connid{-1};
|
||||
std::chrono::milliseconds linger = 1s;
|
||||
|
||||
if (data.skip_until("conn_id"))
|
||||
connid.id = data.consume_integer<long long>();
|
||||
if (data.skip_until("linger_ms"))
|
||||
linger = std::chrono::milliseconds(data.consume_integer<long long>());
|
||||
if (data.skip_until("pubkey"))
|
||||
connid.pk = data.consume_string();
|
||||
|
||||
if (connid.sn() && connid.pk.size() != 32)
|
||||
throw std::runtime_error("Error: invalid disconnect of SN without a valid pubkey");
|
||||
|
||||
proxy_disconnect(std::move(connid), linger);
|
||||
}
|
||||
void OxenMQ::proxy_disconnect(ConnectionID conn, std::chrono::milliseconds linger) {
|
||||
OMQ_TRACE("Disconnecting outgoing connection to ", conn);
|
||||
auto pr = peers.equal_range(conn);
|
||||
for (auto it = pr.first; it != pr.second; ++it) {
|
||||
auto& peer = it->second;
|
||||
if (peer.outgoing()) {
|
||||
OMQ_LOG(debug, "Closing outgoing connection to ", conn);
|
||||
proxy_close_connection(peer.conn_id, linger);
|
||||
peers.erase(it);
|
||||
return;
|
||||
}
|
||||
}
|
||||
OMQ_LOG(warn, "Failed to disconnect ", conn, ": no such outgoing connection");
|
||||
}
|
||||
|
||||
std::string ConnectionID::to_string() const {
|
||||
if (!pk.empty())
|
||||
return (sn() ? std::string("SN ") : std::string("non-SN authenticated remote ")) + oxenc::to_hex(pk);
|
||||
else
|
||||
return std::string("unauthenticated remote [") + std::to_string(id) + "]";
|
||||
}
|
||||
|
||||
|
||||
}
|
|
@ -1,14 +1,20 @@
|
|||
#pragma once
|
||||
#include "string_view.h"
|
||||
#include "auth.h"
|
||||
#include <oxenc/bt_value.h>
|
||||
#include <string_view>
|
||||
#include <iosfwd>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
|
||||
namespace lokimq {
|
||||
namespace oxenmq {
|
||||
|
||||
class bt_dict;
|
||||
struct ConnectionID;
|
||||
|
||||
namespace detail {
|
||||
template <typename... T>
|
||||
bt_dict build_send(ConnectionID to, string_view cmd, const T&... opts);
|
||||
oxenc::bt_dict build_send(ConnectionID to, std::string_view cmd, T&&... opts);
|
||||
}
|
||||
|
||||
/// Opaque data structure representing a connection which supports ==, !=, < and std::hash. For
|
||||
|
@ -16,11 +22,17 @@ bt_dict build_send(ConnectionID to, string_view cmd, const T&... opts);
|
|||
/// anywhere a ConnectionID is called for). For non-SN remote connections you need to keep a copy
|
||||
/// of the ConnectionID returned by connect_remote().
|
||||
struct ConnectionID {
|
||||
// Default construction; creates a ConnectionID with an invalid internal ID that will not match
|
||||
// an actual connection.
|
||||
ConnectionID() : ConnectionID(0) {}
|
||||
// Construction from a service node pubkey
|
||||
ConnectionID(std::string pubkey_) : id{SN_ID}, pk{std::move(pubkey_)} {
|
||||
if (pk.size() != 32)
|
||||
throw std::runtime_error{"Invalid pubkey: expected 32 bytes"};
|
||||
}
|
||||
ConnectionID(string_view pubkey_) : ConnectionID(std::string{pubkey_}) {}
|
||||
// Construction from a service node pubkey
|
||||
ConnectionID(std::string_view pubkey_) : ConnectionID(std::string{pubkey_}) {}
|
||||
|
||||
ConnectionID(const ConnectionID&) = default;
|
||||
ConnectionID(ConnectionID&&) = default;
|
||||
ConnectionID& operator=(const ConnectionID&) = default;
|
||||
|
@ -32,52 +44,55 @@ struct ConnectionID {
|
|||
}
|
||||
|
||||
// Two ConnectionIDs are equal if they are both SNs and have matching pubkeys, or they are both
|
||||
// not SNs and have matching internal IDs. (Pubkeys do not have to match for non-SNs, and
|
||||
// routes are not considered for equality at all).
|
||||
// not SNs and have matching internal IDs and routes. (Pubkeys do not have to match for
|
||||
// non-SNs).
|
||||
bool operator==(const ConnectionID &o) const {
|
||||
if (id == SN_ID && o.id == SN_ID)
|
||||
if (sn() && o.sn())
|
||||
return pk == o.pk;
|
||||
return id == o.id;
|
||||
return id == o.id && route == o.route;
|
||||
}
|
||||
bool operator!=(const ConnectionID &o) const { return !(*this == o); }
|
||||
bool operator<(const ConnectionID &o) const {
|
||||
if (id == SN_ID && o.id == SN_ID)
|
||||
if (sn() && o.sn())
|
||||
return pk < o.pk;
|
||||
return id < o.id;
|
||||
return id < o.id || (id == o.id && route < o.route);
|
||||
}
|
||||
|
||||
// Returns true if this ConnectionID represents a SN connection
|
||||
bool sn() const { return id == SN_ID; }
|
||||
|
||||
// Returns this connection's pubkey, if any. (Note that it is possible to have a pubkey and not
|
||||
// be a SN when connecting to secure remotes: having a non-empty pubkey does not imply that
|
||||
// `sn()` is true).
|
||||
// Returns this connection's pubkey, if any. (Note that all curve connections have pubkeys, not
|
||||
// only SNs).
|
||||
const std::string& pubkey() const { return pk; }
|
||||
// Default construction; creates a ConnectionID with an invalid internal ID that will not match
|
||||
// an actual connection.
|
||||
ConnectionID() : ConnectionID(0) {}
|
||||
|
||||
// Returns a copy of the ConnectionID with the route set to empty.
|
||||
ConnectionID unrouted() { return ConnectionID{id, pk, ""}; }
|
||||
|
||||
std::string to_string() const;
|
||||
|
||||
|
||||
private:
|
||||
ConnectionID(long long id) : id{id} {}
|
||||
ConnectionID(long long id, std::string pubkey, std::string route = "")
|
||||
ConnectionID(int64_t id) : id{id} {}
|
||||
ConnectionID(int64_t id, std::string pubkey, std::string route = "")
|
||||
: id{id}, pk{std::move(pubkey)}, route{std::move(route)} {}
|
||||
|
||||
constexpr static long long SN_ID = -1;
|
||||
long long id = 0;
|
||||
constexpr static int64_t SN_ID = -1;
|
||||
int64_t id = 0;
|
||||
std::string pk;
|
||||
std::string route;
|
||||
friend class LokiMQ;
|
||||
friend class OxenMQ;
|
||||
friend struct std::hash<ConnectionID>;
|
||||
template <typename... T>
|
||||
friend bt_dict detail::build_send(ConnectionID to, string_view cmd, const T&... opts);
|
||||
friend oxenc::bt_dict detail::build_send(ConnectionID to, std::string_view cmd, T&&... opts);
|
||||
friend std::ostream& operator<<(std::ostream& o, const ConnectionID& conn);
|
||||
};
|
||||
|
||||
} // namespace lokimq
|
||||
} // namespace oxenmq
|
||||
namespace std {
|
||||
template <> struct hash<lokimq::ConnectionID> {
|
||||
size_t operator()(const lokimq::ConnectionID &c) const {
|
||||
return c.sn() ? std::hash<std::string>{}(c.pk) :
|
||||
std::hash<long long>{}(c.id);
|
||||
template <> struct hash<oxenmq::ConnectionID> {
|
||||
size_t operator()(const oxenmq::ConnectionID &c) const {
|
||||
return c.sn() ? oxenmq::already_hashed{}(c.pk) :
|
||||
std::hash<int64_t>{}(c.id) + std::hash<std::string>{}(c.route);
|
||||
}
|
||||
};
|
||||
} // namespace std
|
||||
|
28
oxenmq/fmt.h
Normal file
28
oxenmq/fmt.h
Normal file
|
@ -0,0 +1,28 @@
|
|||
#pragma once
|
||||
|
||||
#include <fmt/format.h>
|
||||
#include "connections.h"
|
||||
#include "auth.h"
|
||||
#include "address.h"
|
||||
|
||||
template <>
|
||||
struct fmt::formatter<oxenmq::AuthLevel> : fmt::formatter<std::string> {
|
||||
auto format(oxenmq::AuthLevel v, format_context& ctx) {
|
||||
return formatter<std::string>::format(
|
||||
fmt::format("{}", to_string(v)), ctx);
|
||||
}
|
||||
};
|
||||
template <>
|
||||
struct fmt::formatter<oxenmq::ConnectionID> : fmt::formatter<std::string> {
|
||||
auto format(oxenmq::ConnectionID conn, format_context& ctx) {
|
||||
return formatter<std::string>::format(
|
||||
fmt::format("{}", conn.to_string()), ctx);
|
||||
}
|
||||
};
|
||||
template <>
|
||||
struct fmt::formatter<oxenmq::address> : fmt::formatter<std::string> {
|
||||
auto format(oxenmq::address addr, format_context& ctx) {
|
||||
return formatter<std::string>::format(
|
||||
fmt::format("{}", addr.full_address()), ctx);
|
||||
}
|
||||
};
|
182
oxenmq/jobs.cpp
Normal file
182
oxenmq/jobs.cpp
Normal file
|
@ -0,0 +1,182 @@
|
|||
#include "oxenmq.h"
|
||||
#include "batch.h"
|
||||
#include "oxenmq-internal.h"
|
||||
|
||||
namespace oxenmq {
|
||||
|
||||
void OxenMQ::proxy_batch(detail::Batch* batch) {
|
||||
const auto [jobs, tagged_threads] = batch->size();
|
||||
OMQ_TRACE("proxy queuing batch job with ", jobs, " jobs", tagged_threads ? " (job uses tagged thread(s))" : "");
|
||||
if (!tagged_threads) {
|
||||
for (size_t i = 0; i < jobs; i++)
|
||||
batch_jobs.emplace_back(batch, i);
|
||||
} else {
|
||||
// Some (or all) jobs have a specific thread target so queue any such jobs in the tagged
|
||||
// worker queue.
|
||||
auto threads = batch->threads();
|
||||
for (size_t i = 0; i < jobs; i++) {
|
||||
auto& jobs = threads[i] > 0
|
||||
? std::get<batch_queue>(tagged_workers[threads[i] - 1])
|
||||
: batch_jobs;
|
||||
jobs.emplace_back(batch, i);
|
||||
}
|
||||
}
|
||||
|
||||
proxy_skip_one_poll = true;
|
||||
}
|
||||
|
||||
void OxenMQ::job(std::function<void()> f, std::optional<TaggedThreadID> thread) {
|
||||
if (thread && thread->_id == -1)
|
||||
throw std::logic_error{"job() cannot be used to queue an in-proxy job"};
|
||||
auto* j = new Job(std::move(f), thread);
|
||||
auto* baseptr = static_cast<detail::Batch*>(j);
|
||||
detail::send_control(get_control_socket(), "BATCH", oxenc::bt_serialize(reinterpret_cast<uintptr_t>(baseptr)));
|
||||
}
|
||||
|
||||
void OxenMQ::proxy_schedule_reply_job(std::function<void()> f) {
|
||||
auto* j = new Job(std::move(f));
|
||||
reply_jobs.emplace_back(static_cast<detail::Batch*>(j), 0);
|
||||
proxy_skip_one_poll = true;
|
||||
}
|
||||
|
||||
void OxenMQ::proxy_run_batch_jobs(batch_queue& jobs, const int reserved, int& active, bool reply) {
|
||||
while (!jobs.empty() && active_workers() < max_workers &&
|
||||
(active < reserved || active_workers() < general_workers)) {
|
||||
proxy_run_worker(get_idle_worker().load(std::move(jobs.front()), reply));
|
||||
jobs.pop_front();
|
||||
active++;
|
||||
}
|
||||
}
|
||||
|
||||
// Called either within the proxy thread, or before the proxy thread has been created; actually adds
|
||||
// the timer. If the timer object hasn't been set up yet it gets set up here.
|
||||
void OxenMQ::proxy_timer(int id, std::function<void()> job, std::chrono::milliseconds interval, bool squelch, int thread) {
|
||||
if (!timers)
|
||||
timers.reset(zmq_timers_new());
|
||||
|
||||
int zmq_timer_id = zmq_timers_add(timers.get(),
|
||||
interval.count(),
|
||||
[](int timer_id, void* self) { static_cast<OxenMQ*>(self)->_queue_timer_job(timer_id); },
|
||||
this);
|
||||
if (zmq_timer_id == -1)
|
||||
throw zmq::error_t{};
|
||||
timer_jobs[zmq_timer_id] = { std::move(job), squelch, false, thread };
|
||||
timer_zmq_id[id] = zmq_timer_id;
|
||||
}
|
||||
|
||||
void OxenMQ::proxy_timer(oxenc::bt_list_consumer timer_data) {
|
||||
auto timer_id = timer_data.consume_integer<int>();
|
||||
std::unique_ptr<std::function<void()>> func{reinterpret_cast<std::function<void()>*>(timer_data.consume_integer<uintptr_t>())};
|
||||
auto interval = std::chrono::milliseconds{timer_data.consume_integer<uint64_t>()};
|
||||
auto squelch = timer_data.consume_integer<bool>();
|
||||
auto thread = timer_data.consume_integer<int>();
|
||||
if (!timer_data.is_finished())
|
||||
throw std::runtime_error("Internal error: proxied timer request contains unexpected data");
|
||||
proxy_timer(timer_id, std::move(*func), interval, squelch, thread);
|
||||
}
|
||||
|
||||
void OxenMQ::_queue_timer_job(int timer_id) {
|
||||
auto it = timer_jobs.find(timer_id);
|
||||
if (it == timer_jobs.end()) {
|
||||
OMQ_LOG(warn, "Could not find timer job ", timer_id);
|
||||
return;
|
||||
}
|
||||
auto& [func, squelch, running, thread] = it->second;
|
||||
if (squelch && running) {
|
||||
OMQ_LOG(debug, "Not running timer job ", timer_id, " because a job for that timer is still running");
|
||||
return;
|
||||
}
|
||||
|
||||
if (thread == -1) { // Run directly in proxy thread
|
||||
try { func(); }
|
||||
catch (const std::exception &e) { OMQ_LOG(warn, "timer job ", timer_id, " raised an exception: ", e.what()); }
|
||||
catch (...) { OMQ_LOG(warn, "timer job ", timer_id, " raised a non-std exception"); }
|
||||
return;
|
||||
}
|
||||
|
||||
detail::Batch* b;
|
||||
if (squelch) {
|
||||
auto* bv = new Batch<void>;
|
||||
bv->add_job(func, thread);
|
||||
running = true;
|
||||
bv->completion([this,timer_id](auto results) {
|
||||
try { results[0].get(); }
|
||||
catch (const std::exception &e) { OMQ_LOG(warn, "timer job ", timer_id, " raised an exception: ", e.what()); }
|
||||
catch (...) { OMQ_LOG(warn, "timer job ", timer_id, " raised a non-std exception"); }
|
||||
auto it = timer_jobs.find(timer_id);
|
||||
if (it != timer_jobs.end())
|
||||
it->second.running = false;
|
||||
}, OxenMQ::run_in_proxy);
|
||||
b = bv;
|
||||
} else {
|
||||
b = new Job(func, thread);
|
||||
}
|
||||
OMQ_TRACE("b: ", b->size().first, ", ", b->size().second, "; thread = ", thread);
|
||||
assert(b->size() == std::make_pair(size_t{1}, thread > 0));
|
||||
auto& queue = thread > 0
|
||||
? std::get<batch_queue>(tagged_workers[thread - 1])
|
||||
: batch_jobs;
|
||||
queue.emplace_back(static_cast<detail::Batch*>(b), 0);
|
||||
}
|
||||
|
||||
void OxenMQ::add_timer(TimerID& timer, std::function<void()> job, std::chrono::milliseconds interval, bool squelch, std::optional<TaggedThreadID> thread) {
|
||||
int th_id = thread ? thread->_id : 0;
|
||||
timer._id = next_timer_id++;
|
||||
if (proxy_thread.joinable()) {
|
||||
detail::send_control(get_control_socket(), "TIMER", oxenc::bt_serialize(oxenc::bt_list{{
|
||||
timer._id,
|
||||
detail::serialize_object(std::move(job)),
|
||||
interval.count(),
|
||||
squelch,
|
||||
th_id}}));
|
||||
} else {
|
||||
proxy_timer(timer._id, std::move(job), interval, squelch, th_id);
|
||||
}
|
||||
}
|
||||
|
||||
TimerID OxenMQ::add_timer(std::function<void()> job, std::chrono::milliseconds interval, bool squelch, std::optional<TaggedThreadID> thread) {
|
||||
TimerID tid;
|
||||
add_timer(tid, std::move(job), interval, squelch, std::move(thread));
|
||||
return tid;
|
||||
}
|
||||
|
||||
void OxenMQ::proxy_timer_del(int id) {
|
||||
if (!timers)
|
||||
return;
|
||||
auto it = timer_zmq_id.find(id);
|
||||
if (it == timer_zmq_id.end())
|
||||
return;
|
||||
zmq_timers_cancel(timers.get(), it->second);
|
||||
timer_zmq_id.erase(it);
|
||||
}
|
||||
|
||||
void OxenMQ::cancel_timer(TimerID timer_id) {
|
||||
if (proxy_thread.joinable()) {
|
||||
detail::send_control(get_control_socket(), "TIMER_DEL", oxenc::bt_serialize(timer_id._id));
|
||||
} else {
|
||||
proxy_timer_del(timer_id._id);
|
||||
}
|
||||
}
|
||||
|
||||
void OxenMQ::TimersDeleter::operator()(void* timers) { zmq_timers_destroy(&timers); }
|
||||
|
||||
TaggedThreadID OxenMQ::add_tagged_thread(std::string name, std::function<void()> start) {
|
||||
if (proxy_thread.joinable())
|
||||
throw std::logic_error{"Cannot add tagged threads after calling `start()`"};
|
||||
|
||||
if (name == "_proxy"sv || name.empty() || name.find('\0') != std::string::npos)
|
||||
throw std::logic_error{"Invalid tagged thread name `" + name + "'"};
|
||||
|
||||
auto& [run, busy, queue] = tagged_workers.emplace_back();
|
||||
busy = false;
|
||||
run.worker_id = tagged_workers.size(); // We want index + 1 (b/c 0 is used for non-tagged jobs)
|
||||
run.worker_routing_name = "t" + std::to_string(run.worker_id);
|
||||
run.worker_routing_id = "t" + std::string{reinterpret_cast<const char*>(&run.worker_id), sizeof(run.worker_id)};
|
||||
OMQ_TRACE("Created new tagged thread ", name, " with routing id ", run.worker_routing_name);
|
||||
|
||||
run.worker_thread = std::thread{&OxenMQ::worker_thread, this, run.worker_id, name, std::move(start)};
|
||||
|
||||
return TaggedThreadID{static_cast<int>(run.worker_id)};
|
||||
}
|
||||
|
||||
}
|
106
oxenmq/message.h
Normal file
106
oxenmq/message.h
Normal file
|
@ -0,0 +1,106 @@
|
|||
#pragma once
|
||||
#include <vector>
|
||||
#include "connections.h"
|
||||
|
||||
namespace oxenmq {
|
||||
|
||||
class OxenMQ;
|
||||
|
||||
/// Encapsulates an incoming message from a remote connection with message details plus extra
|
||||
/// info need to send a reply back through the proxy thread via the `reply()` method. Note that
|
||||
/// this object gets reused: callbacks should use but not store any reference beyond the callback.
|
||||
class Message {
|
||||
public:
|
||||
OxenMQ& oxenmq; ///< The owning OxenMQ object
|
||||
std::vector<std::string_view> data; ///< The provided command data parts, if any.
|
||||
ConnectionID conn; ///< The connection info for routing a reply; also contains the pubkey/sn status.
|
||||
std::string reply_tag; ///< If the invoked command is a request command this is the required reply tag that will be prepended by `send_reply()`.
|
||||
Access access; ///< The access level of the invoker. This can be higher than the access level of the command, for example for an admin invoking a basic command.
|
||||
std::string remote; ///< Some sort of remote address from which the request came. Often "IP" for TCP connections and "localhost:UID:GID:PID" for unix socket connections.
|
||||
|
||||
/// Constructor
|
||||
Message(OxenMQ& omq, ConnectionID cid, Access access, std::string remote)
|
||||
: oxenmq{omq}, conn{std::move(cid)}, access{std::move(access)}, remote{std::move(remote)} {}
|
||||
|
||||
// Non-copyable
|
||||
Message(const Message&) = delete;
|
||||
Message& operator=(const Message&) = delete;
|
||||
|
||||
/// Sends a command back to whomever sent this message. Arguments are forwarded to send() but
|
||||
/// with send_option::optional{} added if the originator is not a SN. For SN messages (i.e.
|
||||
/// where `sn` is true) this is a "strong" reply by default in that the proxy will attempt to
|
||||
/// establish a new connection to the SN if no longer connected. For non-SN messages the reply
|
||||
/// will be attempted using the available routing information, but if the connection has already
|
||||
/// been closed the reply will be dropped.
|
||||
///
|
||||
/// If you want to send a non-strong reply even when the remote is a service node then add
|
||||
/// an explicit `send_option::optional()` argument.
|
||||
template <typename... Args>
|
||||
void send_back(std::string_view command, Args&&... args);
|
||||
|
||||
/// Sends a reply to a request. This takes no command: the command is always the built-in
|
||||
/// "REPLY" command, followed by the unique reply tag, then any reply data parts. All other
|
||||
/// arguments are as in `send_back()`. You should only send one reply for a command expecting
|
||||
/// replies, though this is not enforced: attempting to send multiple replies will simply be
|
||||
/// dropped when received by the remote. (Note, however, that it is possible to send multiple
|
||||
/// messages -- e.g. you could send a reply and then also call send_back() and/or send_request()
|
||||
/// to send more requests back to the sender).
|
||||
template <typename... Args>
|
||||
void send_reply(Args&&... args);
|
||||
|
||||
/// Sends a request back to whomever sent this message. This is effectively a wrapper around
|
||||
/// omq.request() that takes care of setting up the recipient arguments.
|
||||
template <typename ReplyCallback, typename... Args>
|
||||
void send_request(std::string_view command, ReplyCallback&& callback, Args&&... args);
|
||||
|
||||
/** Class returned by `send_later()` that can be used to call `back()`, `reply()`, or
|
||||
* `request()` beyond the lifetime of the Message instance as if calling `msg.send_back()`,
|
||||
* `msg.send_reply()`, or `msg.send_request()`. For example:
|
||||
*
|
||||
* auto send = msg.send_later();
|
||||
* // ... later, perhaps in a lambda or scheduled job:
|
||||
* send.reply("content");
|
||||
*
|
||||
* is equivalent to
|
||||
*
|
||||
* msg.send_reply("content");
|
||||
*
|
||||
* except that it is valid even if `msg` is no longer valid.
|
||||
*/
|
||||
class DeferredSend {
|
||||
public:
|
||||
OxenMQ& oxenmq; ///< The owning OxenMQ object
|
||||
ConnectionID conn; ///< The connection info for routing a reply; also contains the pubkey/sn status
|
||||
std::string reply_tag; ///< If the invoked command is a request command this is the required reply tag that will be prepended by `reply()`.
|
||||
|
||||
explicit DeferredSend(Message& m) : oxenmq{m.oxenmq}, conn{m.conn}, reply_tag{m.reply_tag} {}
|
||||
|
||||
template <typename... Args>
|
||||
void operator()(Args &&...args) const {
|
||||
if (reply_tag.empty())
|
||||
back(std::forward<Args>(args)...);
|
||||
else
|
||||
reply(std::forward<Args>(args)...);
|
||||
};
|
||||
|
||||
|
||||
/// Equivalent to msg.send_back(...), but can be invoked later.
|
||||
template <typename... Args>
|
||||
void back(std::string_view command, Args&&... args) const;
|
||||
|
||||
/// Equivalent to msg.send_reply(...), but can be invoked later.
|
||||
template <typename... Args>
|
||||
void reply(Args&&... args) const;
|
||||
|
||||
/// Equivalent to msg.send_request(...), but can be invoked later.
|
||||
template <typename ReplyCallback, typename... Args>
|
||||
void request(std::string_view command, ReplyCallback&& callback, Args&&... args) const;
|
||||
};
|
||||
|
||||
/// Returns a DeferredSend object that can be used to send replies to this message even if the
|
||||
/// message expires. Typically this is used when sending a reply requires waiting on another
|
||||
/// task to complete without needing to block the handler thread.
|
||||
DeferredSend send_later() { return DeferredSend{*this}; }
|
||||
};
|
||||
|
||||
}
|
|
@ -1,82 +1,91 @@
|
|||
#pragma once
|
||||
#include "lokimq.h"
|
||||
#include <limits>
|
||||
#include "oxenmq.h"
|
||||
|
||||
// Inside some method:
|
||||
// LMQ_LOG(warn, "bad ", 42, " stuff");
|
||||
// OMQ_LOG(warn, "bad ", 42, " stuff");
|
||||
//
|
||||
// (The "this->" is here to work around gcc 5 bugginess when called in a `this`-capturing lambda.)
|
||||
#define LMQ_LOG(level, ...) this->log_(LogLevel::level, __FILE__, __LINE__, __VA_ARGS__)
|
||||
#define OMQ_LOG(level, ...) log(LogLevel::level, __FILE__, __LINE__, __VA_ARGS__)
|
||||
|
||||
#ifndef NDEBUG
|
||||
// Same as LMQ_LOG(trace, ...) when not doing a release build; nothing under a release build.
|
||||
# define LMQ_TRACE(...) this->log_(LogLevel::trace, __FILE__, __LINE__, __VA_ARGS__)
|
||||
// Same as OMQ_LOG(trace, ...) when not doing a release build; nothing under a release build.
|
||||
# define OMQ_TRACE(...) log(LogLevel::trace, __FILE__, __LINE__, __VA_ARGS__)
|
||||
#else
|
||||
# define LMQ_TRACE(...)
|
||||
# define OMQ_TRACE(...)
|
||||
#endif
|
||||
|
||||
namespace lokimq {
|
||||
namespace oxenmq {
|
||||
|
||||
constexpr char SN_ADDR_COMMAND[] = "inproc://sn-command";
|
||||
constexpr char SN_ADDR_WORKERS[] = "inproc://sn-workers";
|
||||
constexpr char SN_ADDR_SELF[] = "inproc://sn-self";
|
||||
constexpr char ZMQ_ADDR_ZAP[] = "inproc://zeromq.zap.01";
|
||||
|
||||
#ifdef OXENMQ_USE_EPOLL
|
||||
|
||||
constexpr auto EPOLL_COMMAND_ID = std::numeric_limits<uint64_t>::max();
|
||||
constexpr auto EPOLL_WORKER_ID = std::numeric_limits<uint64_t>::max() - 1;
|
||||
constexpr auto EPOLL_ZAP_ID = std::numeric_limits<uint64_t>::max() - 2;
|
||||
|
||||
#endif
|
||||
|
||||
/// Destructor for create_message(std::string&&) that zmq calls when it's done with the message.
|
||||
extern "C" inline void message_buffer_destroy(void*, void* hint) {
|
||||
delete reinterpret_cast<std::string*>(hint);
|
||||
}
|
||||
|
||||
/// Creates a message without needing to reallocate the provided string data
|
||||
inline zmq::message_t create_message(std::string &&data) {
|
||||
inline zmq::message_t create_message(std::string&& data) {
|
||||
auto *buffer = new std::string(std::move(data));
|
||||
return zmq::message_t{&(*buffer)[0], buffer->size(), message_buffer_destroy, buffer};
|
||||
};
|
||||
|
||||
/// Create a message copying from a string_view
|
||||
inline zmq::message_t create_message(string_view data) {
|
||||
inline zmq::message_t create_message(std::string_view data) {
|
||||
return zmq::message_t{data.begin(), data.end()};
|
||||
}
|
||||
|
||||
template <typename It>
|
||||
void send_message_parts(zmq::socket_t &sock, It begin, It end) {
|
||||
bool send_message_parts(zmq::socket_t& sock, It begin, It end) {
|
||||
while (begin != end) {
|
||||
// FIXME: for outgoing connections on ZMQ_DEALER we want to use ZMQ_DONTWAIT and handle
|
||||
// EAGAIN error (which either means the peer HWM is hit -- probably indicating a connection
|
||||
// failure -- or the underlying connect() system call failed). Assuming it's an outgoing
|
||||
// connection, we should destroy it.
|
||||
zmq::message_t &msg = *begin++;
|
||||
sock.send(msg, begin == end ? zmq::send_flags::none : zmq::send_flags::sndmore);
|
||||
if (!sock.send(msg, begin == end ? zmq::send_flags::dontwait : zmq::send_flags::dontwait | zmq::send_flags::sndmore))
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename Container>
|
||||
void send_message_parts(zmq::socket_t &sock, Container &&c) {
|
||||
send_message_parts(sock, c.begin(), c.end());
|
||||
bool send_message_parts(zmq::socket_t& sock, Container&& c) {
|
||||
return send_message_parts(sock, c.begin(), c.end());
|
||||
}
|
||||
|
||||
/// Sends a message with an initial route. `msg` and `data` can be empty: if `msg` is empty then
|
||||
/// the msg frame will be an empty message; if `data` is empty then the data frame will be omitted.
|
||||
inline void send_routed_message(zmq::socket_t &socket, std::string route, std::string msg = {}, std::string data = {}) {
|
||||
/// `flags` is passed through to zmq: typically given `zmq::send_flags::dontwait` to throw rather
|
||||
/// than block if a message can't be queued.
|
||||
inline bool send_routed_message(zmq::socket_t& socket, std::string route, std::string msg = {}, std::string data = {}) {
|
||||
assert(!route.empty());
|
||||
std::array<zmq::message_t, 3> msgs{{create_message(std::move(route))}};
|
||||
if (!msg.empty())
|
||||
msgs[1] = create_message(std::move(msg));
|
||||
if (!data.empty())
|
||||
msgs[2] = create_message(std::move(data));
|
||||
send_message_parts(socket, msgs.begin(), data.empty() ? std::prev(msgs.end()) : msgs.end());
|
||||
return send_message_parts(socket, msgs.begin(), data.empty() ? std::prev(msgs.end()) : msgs.end());
|
||||
}
|
||||
|
||||
// Sends some stuff to a socket directly.
|
||||
inline void send_direct_message(zmq::socket_t &socket, std::string msg, std::string data = {}) {
|
||||
// Sends some stuff to a socket directly. If dontwait is true then we throw instead of blocking if
|
||||
// the message cannot be accepted by zmq (i.e. because the outgoing buffer is full).
|
||||
inline bool send_direct_message(zmq::socket_t& socket, std::string msg, std::string data = {}) {
|
||||
std::array<zmq::message_t, 2> msgs{{create_message(std::move(msg))}};
|
||||
if (!data.empty())
|
||||
msgs[1] = create_message(std::move(data));
|
||||
send_message_parts(socket, msgs.begin(), data.empty() ? std::prev(msgs.end()) : msgs.end());
|
||||
return send_message_parts(socket, msgs.begin(), data.empty() ? std::prev(msgs.end()) : msgs.end());
|
||||
}
|
||||
|
||||
// Receive all the parts of a single message from the given socket. Returns true if a message was
|
||||
// received, false if called with flags=zmq::recv_flags::dontwait and no message was available.
|
||||
inline bool recv_message_parts(zmq::socket_t &sock, std::vector<zmq::message_t>& parts, const zmq::recv_flags flags = zmq::recv_flags::none) {
|
||||
inline bool recv_message_parts(zmq::socket_t& sock, std::vector<zmq::message_t>& parts, const zmq::recv_flags flags = zmq::recv_flags::none) {
|
||||
do {
|
||||
zmq::message_t msg;
|
||||
if (!sock.recv(msg, flags))
|
||||
|
@ -86,6 +95,21 @@ inline bool recv_message_parts(zmq::socket_t &sock, std::vector<zmq::message_t>&
|
|||
return true;
|
||||
}
|
||||
|
||||
// Same as above, but using a fixed sized array; this is only used for internal jobs (e.g. control
|
||||
// messages) where we know the message parts should never exceed a given size (this function does
|
||||
// not bounds check except in debug builds). Returns the number of message parts received, or 0 on
|
||||
// read error.
|
||||
template <size_t N>
|
||||
inline size_t recv_message_parts(zmq::socket_t& sock, std::array<zmq::message_t, N>& parts, const zmq::recv_flags flags = zmq::recv_flags::none) {
|
||||
for (size_t count = 0; ; count++) {
|
||||
assert(count < N);
|
||||
if (!sock.recv(parts[count], flags))
|
||||
return 0;
|
||||
if (!parts[count].more())
|
||||
return count + 1;
|
||||
}
|
||||
}
|
||||
|
||||
inline const char* peer_address(zmq::message_t& msg) {
|
||||
try { return msg.gets("Peer-Address"); } catch (...) {}
|
||||
return "(unknown)";
|
||||
|
@ -93,29 +117,12 @@ inline const char* peer_address(zmq::message_t& msg) {
|
|||
|
||||
// Returns a string view of the given message data. It's the caller's responsibility to keep the
|
||||
// referenced message alive. If you want a std::string instead just call `m.to_string()`
|
||||
inline string_view view(const zmq::message_t& m) {
|
||||
inline std::string_view view(const zmq::message_t& m) {
|
||||
return {m.data<char>(), m.size()};
|
||||
}
|
||||
|
||||
inline std::string to_string(AuthLevel a) {
|
||||
switch (a) {
|
||||
case AuthLevel::denied: return "denied";
|
||||
case AuthLevel::none: return "none";
|
||||
case AuthLevel::basic: return "basic";
|
||||
case AuthLevel::admin: return "admin";
|
||||
default: return "(unknown)";
|
||||
}
|
||||
}
|
||||
|
||||
inline AuthLevel auth_from_string(string_view a) {
|
||||
if (a == "none") return AuthLevel::none;
|
||||
if (a == "basic") return AuthLevel::basic;
|
||||
if (a == "admin") return AuthLevel::admin;
|
||||
return AuthLevel::denied;
|
||||
}
|
||||
|
||||
// Extracts and builds the "send" part of a message for proxy_send/proxy_reply
|
||||
inline std::list<zmq::message_t> build_send_parts(bt_list_consumer send, string_view route) {
|
||||
inline std::list<zmq::message_t> build_send_parts(oxenc::bt_list_consumer send, std::string_view route) {
|
||||
std::list<zmq::message_t> parts;
|
||||
if (!route.empty())
|
||||
parts.push_back(create_message(route));
|
||||
|
@ -127,7 +134,7 @@ inline std::list<zmq::message_t> build_send_parts(bt_list_consumer send, string_
|
|||
/// Sends a control message to a specific destination by prefixing the worker name (or identity)
|
||||
/// then appending the command and optional data (if non-empty). (This is needed when sending the control message
|
||||
/// to a router socket, i.e. inside the proxy thread).
|
||||
inline void route_control(zmq::socket_t& sock, string_view identity, string_view cmd, const std::string& data = {}) {
|
||||
inline void route_control(zmq::socket_t& sock, std::string_view identity, std::string_view cmd, const std::string& data = {}) {
|
||||
sock.send(create_message(identity), zmq::send_flags::sndmore);
|
||||
detail::send_control(sock, cmd, data);
|
||||
}
|
|
@ -1,21 +1,29 @@
|
|||
#include "lokimq.h"
|
||||
#include "lokimq-internal.h"
|
||||
#include "oxenmq.h"
|
||||
#include "oxenmq-internal.h"
|
||||
#include "zmq.hpp"
|
||||
#include <map>
|
||||
#include <mutex>
|
||||
#include <random>
|
||||
#include <ostream>
|
||||
#include <thread>
|
||||
#include <future>
|
||||
|
||||
extern "C" {
|
||||
#include <sodium.h>
|
||||
#include <sodium/core.h>
|
||||
#include <sodium/crypto_box.h>
|
||||
#include <sodium/crypto_scalarmult.h>
|
||||
}
|
||||
#include "hex.h"
|
||||
#include <oxenc/hex.h>
|
||||
#include <oxenc/variant.h>
|
||||
|
||||
namespace lokimq {
|
||||
namespace oxenmq {
|
||||
|
||||
namespace {
|
||||
|
||||
|
||||
/// Creates a message by bt-serializing the given value (string, number, list, or dict)
|
||||
template <typename T>
|
||||
zmq::message_t create_bt_message(T&& data) { return create_message(bt_serialize(std::forward<T>(data))); }
|
||||
zmq::message_t create_bt_message(T&& data) { return create_message(oxenc::bt_serialize(std::forward<T>(data))); }
|
||||
|
||||
template <typename MessageContainer>
|
||||
std::vector<std::string> as_strings(const MessageContainer& msgs) {
|
||||
|
@ -26,11 +34,6 @@ std::vector<std::string> as_strings(const MessageContainer& msgs) {
|
|||
return result;
|
||||
}
|
||||
|
||||
void check_started(const std::thread& proxy_thread, const std::string &verb) {
|
||||
if (!proxy_thread.joinable())
|
||||
throw std::logic_error("Cannot " + verb + " before calling `start()`");
|
||||
}
|
||||
|
||||
void check_not_started(const std::thread& proxy_thread, const std::string &verb) {
|
||||
if (proxy_thread.joinable())
|
||||
throw std::logic_error("Cannot " + verb + " after calling `start()`");
|
||||
|
@ -43,7 +46,7 @@ namespace detail {
|
|||
|
||||
// Sends a control messages between proxy and threads or between proxy and workers consisting of a
|
||||
// single command codes with an optional data part (the data frame is omitted if empty).
|
||||
void send_control(zmq::socket_t& sock, string_view cmd, std::string data) {
|
||||
void send_control(zmq::socket_t& sock, std::string_view cmd, std::string data) {
|
||||
auto c = create_message(std::move(cmd));
|
||||
if (data.empty()) {
|
||||
sock.send(c, zmq::send_flags::none);
|
||||
|
@ -54,29 +57,20 @@ void send_control(zmq::socket_t& sock, string_view cmd, std::string data) {
|
|||
}
|
||||
}
|
||||
|
||||
/// Extracts a pubkey, SN status, and auth level from a zmq message received on a *listening*
|
||||
/// socket.
|
||||
std::tuple<std::string, bool, AuthLevel> extract_metadata(zmq::message_t& msg) {
|
||||
auto result = std::make_tuple(""s, false, AuthLevel::none);
|
||||
/// Extracts a pubkey and and auth level from a zmq message received on a *listening* socket.
|
||||
std::pair<std::string, AuthLevel> extract_metadata(zmq::message_t& msg) {
|
||||
auto result = std::make_pair(""s, AuthLevel::none);
|
||||
try {
|
||||
string_view pubkey_hex{msg.gets("User-Id")};
|
||||
std::string_view pubkey_hex{msg.gets("User-Id")};
|
||||
if (pubkey_hex.size() != 64)
|
||||
throw std::logic_error("bad user-id");
|
||||
assert(is_hex(pubkey_hex.begin(), pubkey_hex.end()));
|
||||
auto& pubkey = std::get<std::string>(result);
|
||||
pubkey.resize(32, 0);
|
||||
from_hex(pubkey_hex.begin(), pubkey_hex.end(), pubkey.begin());
|
||||
assert(oxenc::is_hex(pubkey_hex.begin(), pubkey_hex.end()));
|
||||
result.first.resize(32, 0);
|
||||
oxenc::from_hex(pubkey_hex.begin(), pubkey_hex.end(), result.first.begin());
|
||||
} catch (...) {}
|
||||
|
||||
try {
|
||||
string_view is_sn{msg.gets("X-SN")};
|
||||
if (is_sn.size() == 1 && is_sn[0] == '1')
|
||||
std::get<bool>(result) = true;
|
||||
} catch (...) {}
|
||||
|
||||
try {
|
||||
string_view auth_level{msg.gets("X-AuthLevel")};
|
||||
std::get<AuthLevel>(result) = auth_from_string(auth_level);
|
||||
result.second = auth_from_string(msg.gets("X-AuthLevel"));
|
||||
} catch (...) {}
|
||||
|
||||
return result;
|
||||
|
@ -85,16 +79,20 @@ std::tuple<std::string, bool, AuthLevel> extract_metadata(zmq::message_t& msg) {
|
|||
|
||||
} // namespace detail
|
||||
|
||||
void LokiMQ::log_level(LogLevel level) {
|
||||
void OxenMQ::set_zmq_context_option(zmq::ctxopt option, int value) {
|
||||
context.set(option, value);
|
||||
}
|
||||
|
||||
void OxenMQ::log_level(LogLevel level) {
|
||||
log_lvl.store(level, std::memory_order_relaxed);
|
||||
}
|
||||
|
||||
LogLevel LokiMQ::log_level() const {
|
||||
LogLevel OxenMQ::log_level() const {
|
||||
return log_lvl.load(std::memory_order_relaxed);
|
||||
}
|
||||
|
||||
|
||||
CatHelper LokiMQ::add_category(std::string name, Access access_level, unsigned int reserved_threads, int max_queue) {
|
||||
CatHelper OxenMQ::add_category(std::string name, Access access_level, unsigned int reserved_threads, int max_queue) {
|
||||
check_not_started(proxy_thread, "add a category");
|
||||
|
||||
if (name.size() > MAX_CATEGORY_LENGTH)
|
||||
|
@ -112,7 +110,7 @@ CatHelper LokiMQ::add_category(std::string name, Access access_level, unsigned i
|
|||
return ret;
|
||||
}
|
||||
|
||||
void LokiMQ::add_command(const std::string& category, std::string name, CommandCallback callback) {
|
||||
void OxenMQ::add_command(const std::string& category, std::string name, CommandCallback callback) {
|
||||
check_not_started(proxy_thread, "add a command");
|
||||
|
||||
if (name.size() > MAX_COMMAND_LENGTH)
|
||||
|
@ -131,12 +129,12 @@ void LokiMQ::add_command(const std::string& category, std::string name, CommandC
|
|||
throw std::runtime_error("Cannot add command `" + fullname + "': that command already exists");
|
||||
}
|
||||
|
||||
void LokiMQ::add_request_command(const std::string& category, std::string name, CommandCallback callback) {
|
||||
void OxenMQ::add_request_command(const std::string& category, std::string name, CommandCallback callback) {
|
||||
add_command(category, name, std::move(callback));
|
||||
categories.at(category).commands.at(name).second = true;
|
||||
}
|
||||
|
||||
void LokiMQ::add_command_alias(std::string from, std::string to) {
|
||||
void OxenMQ::add_command_alias(std::string from, std::string to) {
|
||||
check_not_started(proxy_thread, "add a command alias");
|
||||
|
||||
if (from.empty())
|
||||
|
@ -165,57 +163,56 @@ std::atomic<int> next_id{1};
|
|||
/// Accesses a thread-local command socket connected to the proxy's command socket used to issue
|
||||
/// commands in a thread-safe manner. A mutex is only required here the first time a thread
|
||||
/// accesses the control socket.
|
||||
zmq::socket_t& LokiMQ::get_control_socket() {
|
||||
zmq::socket_t& OxenMQ::get_control_socket() {
|
||||
assert(proxy_thread.joinable());
|
||||
|
||||
// Maps the LokiMQ unique ID to a local thread command socket.
|
||||
static thread_local std::map<int, std::shared_ptr<zmq::socket_t>> control_sockets;
|
||||
static thread_local std::pair<int, std::shared_ptr<zmq::socket_t>> last{-1, nullptr};
|
||||
|
||||
// Optimize by caching the last value; LokiMQ is often a singleton and in that case we're
|
||||
// Optimize by caching the last value; OxenMQ is often a singleton and in that case we're
|
||||
// going to *always* hit this optimization. Even if it isn't, we're probably likely to need the
|
||||
// same control socket from the same thread multiple times sequentially so this may still help.
|
||||
if (object_id == last.first)
|
||||
return *last.second;
|
||||
static thread_local int last_id = -1;
|
||||
static thread_local zmq::socket_t* last_socket = nullptr;
|
||||
if (object_id == last_id)
|
||||
return *last_socket;
|
||||
|
||||
auto it = control_sockets.find(object_id);
|
||||
if (it != control_sockets.end()) {
|
||||
last = *it;
|
||||
return *last.second;
|
||||
}
|
||||
std::lock_guard lock{control_sockets_mutex};
|
||||
|
||||
std::lock_guard<std::mutex> lock{control_sockets_mutex};
|
||||
if (proxy_shutting_down)
|
||||
throw std::runtime_error("Unable to obtain LokiMQ control socket: proxy thread is shutting down");
|
||||
auto control = std::make_shared<zmq::socket_t>(context, zmq::socket_type::dealer);
|
||||
control->setsockopt<int>(ZMQ_LINGER, 0);
|
||||
control->connect(SN_ADDR_COMMAND);
|
||||
thread_control_sockets.push_back(control);
|
||||
control_sockets.emplace(object_id, control);
|
||||
last.first = object_id;
|
||||
last.second = std::move(control);
|
||||
return *last.second;
|
||||
throw std::runtime_error("Unable to obtain OxenMQ control socket: proxy thread is shutting down");
|
||||
|
||||
auto& socket = control_sockets[std::this_thread::get_id()];
|
||||
if (!socket) {
|
||||
socket = std::make_unique<zmq::socket_t>(context, zmq::socket_type::dealer);
|
||||
socket->set(zmq::sockopt::linger, 0);
|
||||
socket->connect(SN_ADDR_COMMAND);
|
||||
}
|
||||
last_id = object_id;
|
||||
last_socket = socket.get();
|
||||
return *last_socket;
|
||||
}
|
||||
|
||||
|
||||
LokiMQ::LokiMQ(
|
||||
OxenMQ::OxenMQ(
|
||||
std::string pubkey_,
|
||||
std::string privkey_,
|
||||
bool service_node,
|
||||
SNRemoteAddress lookup,
|
||||
Logger logger)
|
||||
Logger logger,
|
||||
LogLevel level)
|
||||
: object_id{next_id++}, pubkey{std::move(pubkey_)}, privkey{std::move(privkey_)}, local_service_node{service_node},
|
||||
sn_lookup{std::move(lookup)}, logger{std::move(logger)}
|
||||
sn_lookup{std::move(lookup)}, log_lvl{level}, logger{std::move(logger)}
|
||||
{
|
||||
|
||||
LMQ_TRACE("Constructing listening LokiMQ, id=", object_id, ", this=", this);
|
||||
OMQ_TRACE("Constructing OxenMQ, id=", object_id, ", this=", this);
|
||||
|
||||
if (sodium_init() == -1)
|
||||
throw std::runtime_error{"libsodium initialization failed"};
|
||||
|
||||
if (pubkey.empty() != privkey.empty()) {
|
||||
throw std::invalid_argument("LokiMQ construction failed: one (and only one) of pubkey/privkey is empty. Both must be specified, or both empty to generate a key.");
|
||||
throw std::invalid_argument("OxenMQ construction failed: one (and only one) of pubkey/privkey is empty. Both must be specified, or both empty to generate a key.");
|
||||
} else if (pubkey.empty()) {
|
||||
if (service_node)
|
||||
throw std::invalid_argument("Cannot construct a service node mode LokiMQ without a keypair");
|
||||
LMQ_LOG(debug, "generating x25519 keypair for remote-only LokiMQ instance");
|
||||
throw std::invalid_argument("Cannot construct a service node mode OxenMQ without a keypair");
|
||||
OMQ_LOG(debug, "generating x25519 keypair for remote-only OxenMQ instance");
|
||||
pubkey.resize(crypto_box_PUBLICKEYBYTES);
|
||||
privkey.resize(crypto_box_SECRETKEYBYTES);
|
||||
crypto_box_keypair(reinterpret_cast<unsigned char*>(&pubkey[0]), reinterpret_cast<unsigned char*>(&privkey[0]));
|
||||
|
@ -230,63 +227,101 @@ LokiMQ::LokiMQ(
|
|||
std::string verify_pubkey(crypto_box_PUBLICKEYBYTES, 0);
|
||||
crypto_scalarmult_base(reinterpret_cast<unsigned char*>(&verify_pubkey[0]), reinterpret_cast<unsigned char*>(&privkey[0]));
|
||||
if (verify_pubkey != pubkey)
|
||||
throw std::invalid_argument("Invalid pubkey/privkey values given to LokiMQ construction: pubkey verification failed");
|
||||
throw std::invalid_argument("Invalid pubkey/privkey values given to OxenMQ construction: pubkey verification failed");
|
||||
}
|
||||
}
|
||||
|
||||
void LokiMQ::start() {
|
||||
void OxenMQ::start() {
|
||||
if (proxy_thread.joinable())
|
||||
throw std::logic_error("Cannot call start() multiple times!");
|
||||
|
||||
// If we're not binding to anything then we don't listen, i.e. we can only establish outbound
|
||||
// connections. Don't allow this if we are in service_node mode because, if we aren't
|
||||
// listening, we are useless as a service node.
|
||||
if (bind.empty() && local_service_node)
|
||||
throw std::invalid_argument{"Cannot create a service node listener with no address(es) to bind"};
|
||||
OMQ_LOG(info, "Initializing OxenMQ ", bind.empty() ? "remote-only" : "listener", " with pubkey ", oxenc::to_hex(pubkey));
|
||||
|
||||
LMQ_LOG(info, "Initializing LokiMQ ", bind.empty() ? "remote-only" : "listener", " with pubkey ", to_hex(pubkey));
|
||||
assert(general_workers > 0);
|
||||
if (batch_jobs_reserved < 0)
|
||||
batch_jobs_reserved = (general_workers + 1) / 2;
|
||||
if (reply_jobs_reserved < 0)
|
||||
reply_jobs_reserved = (general_workers + 7) / 8;
|
||||
|
||||
max_workers = general_workers + batch_jobs_reserved + reply_jobs_reserved;
|
||||
for (const auto& cat : categories) {
|
||||
max_workers += cat.second.reserved_threads;
|
||||
}
|
||||
|
||||
if (log_level() >= LogLevel::debug) {
|
||||
OMQ_LOG(debug, "Reserving space for ", max_workers, " max workers = ", general_workers, " general plus reservations for:");
|
||||
for (const auto& cat : categories)
|
||||
OMQ_LOG(debug, " - ", cat.first, ": ", cat.second.reserved_threads);
|
||||
OMQ_LOG(debug, " - (batch jobs): ", batch_jobs_reserved);
|
||||
OMQ_LOG(debug, " - (reply jobs): ", reply_jobs_reserved);
|
||||
OMQ_LOG(debug, "Plus ", tagged_workers.size(), " tagged worker threads");
|
||||
}
|
||||
|
||||
if (MAX_SOCKETS != 0) {
|
||||
// The max sockets setting we apply to the context here is used during zmq context
|
||||
// initialization, which happens when the first socket is constructed using this context:
|
||||
// hence we set this *before* constructing any socket_t on the context.
|
||||
int zmq_socket_limit = context.get(zmq::ctxopt::socket_limit);
|
||||
int want_sockets = MAX_SOCKETS < 0 ? zmq_socket_limit :
|
||||
std::min<int>(zmq_socket_limit,
|
||||
MAX_SOCKETS + max_workers + tagged_workers.size()
|
||||
+ 4 /* zap_auth, workers_socket, command, inproc_listener */);
|
||||
context.set(zmq::ctxopt::max_sockets, want_sockets);
|
||||
}
|
||||
|
||||
// We bind `command` here so that the `get_control_socket()` below is always connecting to a
|
||||
// bound socket, but we do nothing else here: the proxy thread is responsible for everything
|
||||
// except binding it.
|
||||
command = zmq::socket_t{context, zmq::socket_type::router};
|
||||
command.bind(SN_ADDR_COMMAND);
|
||||
proxy_thread = std::thread{&LokiMQ::proxy_loop, this};
|
||||
std::promise<void> startup_prom;
|
||||
auto proxy_startup = startup_prom.get_future();
|
||||
proxy_thread = std::thread{&OxenMQ::proxy_loop, this, std::move(startup_prom)};
|
||||
|
||||
LMQ_LOG(debug, "Waiting for proxy thread to get ready...");
|
||||
OMQ_LOG(debug, "Waiting for proxy thread to initialize...");
|
||||
proxy_startup.get(); // Rethrows exceptions from the proxy startup (e.g. failure to bind)
|
||||
|
||||
OMQ_LOG(debug, "Waiting for proxy thread to get ready...");
|
||||
auto &control = get_control_socket();
|
||||
detail::send_control(control, "START");
|
||||
LMQ_TRACE("Sent START command");
|
||||
OMQ_TRACE("Sent START command");
|
||||
|
||||
zmq::message_t ready_msg;
|
||||
std::vector<zmq::message_t> parts;
|
||||
try { recv_message_parts(control, parts); }
|
||||
catch (const zmq::error_t &e) { throw std::runtime_error("Failure reading from LokiMQ::Proxy thread: "s + e.what()); }
|
||||
catch (const zmq::error_t &e) { throw std::runtime_error("Failure reading from OxenMQ::Proxy thread: "s + e.what()); }
|
||||
|
||||
if (!(parts.size() == 1 && view(parts.front()) == "READY"))
|
||||
throw std::runtime_error("Invalid startup message from proxy thread (didn't get expected READY message)");
|
||||
LMQ_LOG(debug, "Proxy thread is ready");
|
||||
OMQ_LOG(debug, "Proxy thread is ready");
|
||||
}
|
||||
|
||||
void LokiMQ::listen_curve(std::string bind_addr, AllowFunc allow_connection) {
|
||||
// TODO: there's no particular reason we can't start listening after starting up; just needs to
|
||||
// be implemented. (But if we can start we'll probably also want to be able to stop, so it's
|
||||
// more than just binding that needs implementing).
|
||||
check_not_started(proxy_thread, "start listening");
|
||||
|
||||
bind.emplace_back(std::move(bind_addr), bind_data{true, std::move(allow_connection)});
|
||||
void OxenMQ::listen_curve(std::string bind_addr, AllowFunc allow_connection, std::function<void(bool)> on_bind) {
|
||||
if (std::string_view{bind_addr}.substr(0, 9) == "inproc://")
|
||||
throw std::logic_error{"inproc:// cannot be used with listen_curve"};
|
||||
if (!allow_connection) allow_connection = [](auto&&...) { return AuthLevel::none; };
|
||||
bind_data d{std::move(bind_addr), true, std::move(allow_connection), std::move(on_bind)};
|
||||
if (proxy_thread.joinable())
|
||||
detail::send_control(get_control_socket(), "BIND", oxenc::bt_serialize(detail::serialize_object(std::move(d))));
|
||||
else
|
||||
bind.push_back(std::move(d));
|
||||
}
|
||||
|
||||
void LokiMQ::listen_plain(std::string bind_addr, AllowFunc allow_connection) {
|
||||
// TODO: As above.
|
||||
check_not_started(proxy_thread, "start listening");
|
||||
|
||||
bind.emplace_back(std::move(bind_addr), bind_data{false, std::move(allow_connection)});
|
||||
void OxenMQ::listen_plain(std::string bind_addr, AllowFunc allow_connection, std::function<void(bool)> on_bind) {
|
||||
if (std::string_view{bind_addr}.substr(0, 9) == "inproc://")
|
||||
throw std::logic_error{"inproc:// cannot be used with listen_plain"};
|
||||
if (!allow_connection) allow_connection = [](auto&&...) { return AuthLevel::none; };
|
||||
bind_data d{std::move(bind_addr), false, std::move(allow_connection), std::move(on_bind)};
|
||||
if (proxy_thread.joinable())
|
||||
detail::send_control(get_control_socket(), "BIND", oxenc::bt_serialize(detail::serialize_object(std::move(d))));
|
||||
else
|
||||
bind.push_back(std::move(d));
|
||||
}
|
||||
|
||||
|
||||
std::pair<LokiMQ::category*, const std::pair<LokiMQ::CommandCallback, bool>*> LokiMQ::get_command(std::string& command) {
|
||||
std::pair<OxenMQ::category*, const std::pair<OxenMQ::CommandCallback, bool>*> OxenMQ::get_command(std::string& command) {
|
||||
if (command.size() > MAX_CATEGORY_LENGTH + 1 + MAX_COMMAND_LENGTH) {
|
||||
LMQ_LOG(warn, "Invalid command '", command, "': command too long");
|
||||
OMQ_LOG(warn, "Invalid command '", command, "': command too long");
|
||||
return {};
|
||||
}
|
||||
|
||||
|
@ -298,7 +333,7 @@ std::pair<LokiMQ::category*, const std::pair<LokiMQ::CommandCallback, bool>*> Lo
|
|||
|
||||
auto dot = command.find('.');
|
||||
if (dot == 0 || dot == std::string::npos) {
|
||||
LMQ_LOG(warn, "Invalid command '", command, "': expected <category>.<command>");
|
||||
OMQ_LOG(warn, "Invalid command '", command, "': expected <category>.<command>");
|
||||
return {};
|
||||
}
|
||||
std::string catname = command.substr(0, dot);
|
||||
|
@ -306,21 +341,21 @@ std::pair<LokiMQ::category*, const std::pair<LokiMQ::CommandCallback, bool>*> Lo
|
|||
|
||||
auto catit = categories.find(catname);
|
||||
if (catit == categories.end()) {
|
||||
LMQ_LOG(warn, "Invalid command category '", catname, "'");
|
||||
OMQ_LOG(warn, "Invalid command category '", catname, "'");
|
||||
return {};
|
||||
}
|
||||
|
||||
const auto& category = catit->second;
|
||||
auto callback_it = category.commands.find(cmd);
|
||||
if (callback_it == category.commands.end()) {
|
||||
LMQ_LOG(warn, "Invalid command '", command, "'");
|
||||
OMQ_LOG(warn, "Invalid command '", command, "'");
|
||||
return {};
|
||||
}
|
||||
|
||||
return {&catit->second, &callback_it->second};
|
||||
}
|
||||
|
||||
void LokiMQ::set_batch_threads(int threads) {
|
||||
void OxenMQ::set_batch_threads(int threads) {
|
||||
if (proxy_thread.joinable())
|
||||
throw std::logic_error("Cannot change reserved batch threads after calling `start()`");
|
||||
if (threads < -1) // -1 is the default which is based on general threads
|
||||
|
@ -328,7 +363,7 @@ void LokiMQ::set_batch_threads(int threads) {
|
|||
batch_jobs_reserved = threads;
|
||||
}
|
||||
|
||||
void LokiMQ::set_reply_threads(int threads) {
|
||||
void OxenMQ::set_reply_threads(int threads) {
|
||||
if (proxy_thread.joinable())
|
||||
throw std::logic_error("Cannot change reserved reply threads after calling `start()`");
|
||||
if (threads < -1) // -1 is the default which is based on general threads
|
||||
|
@ -336,7 +371,7 @@ void LokiMQ::set_reply_threads(int threads) {
|
|||
reply_jobs_reserved = threads;
|
||||
}
|
||||
|
||||
void LokiMQ::set_general_threads(int threads) {
|
||||
void OxenMQ::set_general_threads(int threads) {
|
||||
if (proxy_thread.joinable())
|
||||
throw std::logic_error("Cannot change general thread count after calling `start()`");
|
||||
if (threads < 1)
|
||||
|
@ -344,80 +379,72 @@ void LokiMQ::set_general_threads(int threads) {
|
|||
general_workers = threads;
|
||||
}
|
||||
|
||||
LokiMQ::run_info& LokiMQ::run_info::load(category* cat_, std::string command_, ConnectionID conn_,
|
||||
OxenMQ::run_info& OxenMQ::run_info::load(category* cat_, std::string command_, ConnectionID conn_, Access access_, std::string remote_,
|
||||
std::vector<zmq::message_t> data_parts_, const std::pair<CommandCallback, bool>* callback_) {
|
||||
is_batch_job = false;
|
||||
is_reply_job = false;
|
||||
reset();
|
||||
cat = cat_;
|
||||
command = std::move(command_);
|
||||
conn = std::move(conn_);
|
||||
access = std::move(access_);
|
||||
remote = std::move(remote_);
|
||||
data_parts = std::move(data_parts_);
|
||||
callback = callback_;
|
||||
to_run = callback_;
|
||||
return *this;
|
||||
}
|
||||
|
||||
LokiMQ::run_info& LokiMQ::run_info::load(pending_command&& pending) {
|
||||
return load(&pending.cat, std::move(pending.command), std::move(pending.conn),
|
||||
std::move(pending.data_parts), pending.callback);
|
||||
OxenMQ::run_info& OxenMQ::run_info::load(category* cat_, std::string command_, std::string remote_, std::function<void()> callback) {
|
||||
reset();
|
||||
is_injected = true;
|
||||
cat = cat_;
|
||||
command = std::move(command_);
|
||||
conn = {};
|
||||
access = {};
|
||||
remote = std::move(remote_);
|
||||
to_run = std::move(callback);
|
||||
return *this;
|
||||
}
|
||||
|
||||
LokiMQ::run_info& LokiMQ::run_info::load(batch_job&& bj, bool reply_job) {
|
||||
OxenMQ::run_info& OxenMQ::run_info::load(pending_command&& pending) {
|
||||
if (auto *f = std::get_if<std::function<void()>>(&pending.callback))
|
||||
return load(&pending.cat, std::move(pending.command), std::move(pending.remote), std::move(*f));
|
||||
|
||||
assert(pending.callback.index() == 0);
|
||||
return load(&pending.cat, std::move(pending.command), std::move(pending.conn), std::move(pending.access),
|
||||
std::move(pending.remote), std::move(pending.data_parts), var::get<0>(pending.callback));
|
||||
}
|
||||
|
||||
OxenMQ::run_info& OxenMQ::run_info::load(batch_job&& bj, bool reply_job, int tagged_thread) {
|
||||
reset();
|
||||
is_batch_job = true;
|
||||
is_reply_job = reply_job;
|
||||
is_tagged_thread_job = tagged_thread > 0;
|
||||
batch_jobno = bj.second;
|
||||
batch = bj.first;
|
||||
to_run = bj.first;
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
||||
LokiMQ::~LokiMQ() {
|
||||
if (!proxy_thread.joinable())
|
||||
return;
|
||||
OxenMQ::~OxenMQ() {
|
||||
if (!proxy_thread.joinable()) {
|
||||
if (!tagged_workers.empty()) {
|
||||
// We have tagged workers that are waiting on a signal for startup, but we didn't start
|
||||
// up, so signal them so that they can end themselves.
|
||||
{
|
||||
std::lock_guard lock{tagged_startup_mutex};
|
||||
tagged_go = tagged_go_mode::SHUTDOWN;
|
||||
}
|
||||
tagged_cv.notify_all();
|
||||
for (auto& [run, busy, queue] : tagged_workers)
|
||||
run.worker_thread.join();
|
||||
}
|
||||
|
||||
LMQ_LOG(info, "LokiMQ shutting down proxy thread");
|
||||
return;
|
||||
}
|
||||
|
||||
OMQ_LOG(info, "OxenMQ shutting down proxy thread");
|
||||
detail::send_control(get_control_socket(), "QUIT");
|
||||
proxy_thread.join();
|
||||
LMQ_LOG(info, "LokiMQ proxy thread has stopped");
|
||||
}
|
||||
|
||||
ConnectionID LokiMQ::connect_sn(string_view pubkey, std::chrono::milliseconds keep_alive, string_view hint) {
|
||||
check_started(proxy_thread, "connect");
|
||||
|
||||
detail::send_control(get_control_socket(), "CONNECT_SN", bt_serialize<bt_dict>({{"pubkey",pubkey}, {"keep_alive",keep_alive.count()}, {"hint",hint}}));
|
||||
|
||||
return pubkey;
|
||||
}
|
||||
|
||||
ConnectionID LokiMQ::connect_remote(string_view remote, ConnectSuccess on_connect, ConnectFailure on_failure,
|
||||
string_view pubkey, AuthLevel auth_level, std::chrono::milliseconds timeout) {
|
||||
if (!proxy_thread.joinable())
|
||||
LMQ_LOG(warn, "connect_remote() called before start(); this won't take effect until start() is called");
|
||||
|
||||
if (remote.size() < 7 || !(remote.substr(0, 6) == "tcp://" || remote.substr(0, 6) == "ipc://" /* unix domain sockets */))
|
||||
throw std::runtime_error("Invalid connect_remote: remote address '" + std::string{remote} + "' is not a valid or supported zmq connect string");
|
||||
|
||||
auto id = next_conn_id++;
|
||||
LMQ_TRACE("telling proxy to connect to ", remote, ", id ", id,
|
||||
pubkey.empty() ? "using NULL auth" : ", using CURVE with remote pubkey [" + to_hex(pubkey) + "]");
|
||||
detail::send_control(get_control_socket(), "CONNECT_REMOTE", bt_serialize<bt_dict>({
|
||||
{"auth_level", static_cast<std::underlying_type_t<AuthLevel>>(auth_level)},
|
||||
{"conn_id", id},
|
||||
{"connect", reinterpret_cast<uintptr_t>(new ConnectSuccess{std::move(on_connect)})},
|
||||
{"failure", reinterpret_cast<uintptr_t>(new ConnectFailure{std::move(on_failure)})},
|
||||
{"pubkey", pubkey},
|
||||
{"remote", remote},
|
||||
{"timeout", timeout.count()},
|
||||
}));
|
||||
|
||||
return id;
|
||||
}
|
||||
|
||||
void LokiMQ::disconnect(ConnectionID id, std::chrono::milliseconds linger) {
|
||||
detail::send_control(get_control_socket(), "DISCONNECT", bt_serialize<bt_dict>({
|
||||
{"conn_id", id.id},
|
||||
{"linger_ms", linger.count()},
|
||||
{"pubkey", id.pk},
|
||||
}));
|
||||
OMQ_LOG(info, "OxenMQ proxy thread has stopped");
|
||||
}
|
||||
|
||||
std::ostream &operator<<(std::ostream &os, LogLevel lvl) {
|
||||
|
@ -433,13 +460,14 @@ std::ostream &operator<<(std::ostream &os, LogLevel lvl) {
|
|||
|
||||
std::string make_random_string(size_t size) {
|
||||
static thread_local std::mt19937_64 rng{std::random_device{}()};
|
||||
static thread_local std::uniform_int_distribution<char> dist{std::numeric_limits<char>::min(), std::numeric_limits<char>::max()};
|
||||
std::string rando;
|
||||
rando.reserve(size);
|
||||
for (size_t i = 0; i < size; i++)
|
||||
rando += dist(rng);
|
||||
while (rando.size() < size) {
|
||||
uint64_t x = rng();
|
||||
rando.append(reinterpret_cast<const char*>(&x), std::min<size_t>(size - rando.size(), 8));
|
||||
}
|
||||
return rando;
|
||||
}
|
||||
|
||||
} // namespace lokimq
|
||||
} // namespace oxenmq
|
||||
// vim:sw=4:et
|
1835
oxenmq/oxenmq.h
Normal file
1835
oxenmq/oxenmq.h
Normal file
File diff suppressed because it is too large
Load diff
843
oxenmq/proxy.cpp
Normal file
843
oxenmq/proxy.cpp
Normal file
|
@ -0,0 +1,843 @@
|
|||
#include "oxenmq.h"
|
||||
#include "oxenmq-internal.h"
|
||||
#include <oxenc/hex.h>
|
||||
#include <exception>
|
||||
#include <future>
|
||||
|
||||
#if defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__)
|
||||
extern "C" {
|
||||
#include <pthread.h>
|
||||
#include <pthread_np.h>
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef OXENMQ_USE_EPOLL
|
||||
#include <sys/epoll.h>
|
||||
#endif
|
||||
|
||||
#ifndef _WIN32
|
||||
extern "C" {
|
||||
#include <sys/stat.h>
|
||||
#include <sys/types.h>
|
||||
#include <unistd.h>
|
||||
}
|
||||
#endif
|
||||
|
||||
namespace oxenmq {
|
||||
|
||||
void OxenMQ::proxy_quit() {
|
||||
OMQ_LOG(debug, "Received quit command, shutting down proxy thread");
|
||||
|
||||
assert(std::none_of(workers.begin(), workers.end(), [](auto& worker) { return worker.worker_thread.joinable(); }));
|
||||
assert(std::none_of(tagged_workers.begin(), tagged_workers.end(), [](auto& worker) { return std::get<0>(worker).worker_thread.joinable(); }));
|
||||
|
||||
command.set(zmq::sockopt::linger, 0);
|
||||
command.close();
|
||||
{
|
||||
std::lock_guard lock{control_sockets_mutex};
|
||||
proxy_shutting_down = true; // To prevent threads from opening new control sockets
|
||||
}
|
||||
workers_socket.close();
|
||||
int linger = std::chrono::milliseconds{CLOSE_LINGER}.count();
|
||||
for (auto& [id, s] : connections)
|
||||
s.set(zmq::sockopt::linger, linger);
|
||||
connections.clear();
|
||||
peers.clear();
|
||||
|
||||
OMQ_LOG(debug, "Proxy thread teardown complete");
|
||||
}
|
||||
|
||||
void OxenMQ::proxy_send(oxenc::bt_dict_consumer data) {
|
||||
// NB: bt_dict_consumer goes in alphabetical order
|
||||
std::string_view hint;
|
||||
std::chrono::milliseconds keep_alive{DEFAULT_SEND_KEEP_ALIVE};
|
||||
std::chrono::milliseconds request_timeout{DEFAULT_REQUEST_TIMEOUT};
|
||||
bool optional = false;
|
||||
bool outgoing = false;
|
||||
bool incoming = false;
|
||||
bool request = false;
|
||||
bool have_conn_id = false;
|
||||
ConnectionID conn_id;
|
||||
|
||||
std::string request_tag;
|
||||
ReplyCallback request_callback;
|
||||
if (data.skip_until("conn_id")) {
|
||||
conn_id.id = data.consume_integer<long long>();
|
||||
if (conn_id.id == -1)
|
||||
throw std::runtime_error("Invalid error: invalid conn_id value (-1)");
|
||||
have_conn_id = true;
|
||||
}
|
||||
if (data.skip_until("conn_pubkey")) {
|
||||
if (have_conn_id)
|
||||
throw std::runtime_error("Internal error: Invalid proxy send command; conn_id and conn_pubkey are exclusive");
|
||||
conn_id.pk = data.consume_string();
|
||||
conn_id.id = ConnectionID::SN_ID;
|
||||
} else if (!have_conn_id)
|
||||
throw std::runtime_error("Internal error: Invalid proxy send command; conn_pubkey or conn_id missing");
|
||||
if (data.skip_until("conn_route"))
|
||||
conn_id.route = data.consume_string();
|
||||
if (data.skip_until("hint"))
|
||||
hint = data.consume_string_view();
|
||||
if (data.skip_until("incoming"))
|
||||
incoming = data.consume_integer<bool>();
|
||||
if (data.skip_until("keep_alive"))
|
||||
keep_alive = std::chrono::milliseconds{data.consume_integer<uint64_t>()};
|
||||
if (data.skip_until("optional"))
|
||||
optional = data.consume_integer<bool>();
|
||||
if (data.skip_until("outgoing"))
|
||||
outgoing = data.consume_integer<bool>();
|
||||
|
||||
if (data.skip_until("request"))
|
||||
request = data.consume_integer<bool>();
|
||||
if (request) {
|
||||
if (!data.skip_until("request_callback"))
|
||||
throw std::runtime_error("Internal error: received request without request_callback");
|
||||
|
||||
request_callback = detail::deserialize_object<ReplyCallback>(data.consume_integer<uintptr_t>());
|
||||
|
||||
if (!data.skip_until("request_tag"))
|
||||
throw std::runtime_error("Internal error: received request without request_name");
|
||||
request_tag = data.consume_string();
|
||||
if (data.skip_until("request_timeout"))
|
||||
request_timeout = std::chrono::milliseconds{data.consume_integer<uint64_t>()};
|
||||
}
|
||||
if (!data.skip_until("send"))
|
||||
throw std::runtime_error("Internal error: Invalid proxy send command; send parts missing");
|
||||
oxenc::bt_list_consumer send = data.consume_list_consumer();
|
||||
|
||||
send_option::queue_failure::callback_t callback_nosend;
|
||||
if (data.skip_until("send_fail"))
|
||||
callback_nosend = detail::deserialize_object<decltype(callback_nosend)>(data.consume_integer<uintptr_t>());
|
||||
|
||||
send_option::queue_full::callback_t callback_noqueue;
|
||||
if (data.skip_until("send_full_q"))
|
||||
callback_noqueue = detail::deserialize_object<decltype(callback_noqueue)>(data.consume_integer<uintptr_t>());
|
||||
|
||||
// Now figure out which socket to send to and do the actual sending. We can repeat this loop
|
||||
// multiple times, if we're sending to a SN, because it's possible that we have multiple
|
||||
// connections open to that SN (e.g. one out + one in) so if one fails we can clean up that
|
||||
// connection and try the next one.
|
||||
bool retry = true, sent = false, nowarn = false;
|
||||
while (retry) {
|
||||
retry = false;
|
||||
zmq::socket_t *send_to;
|
||||
if (conn_id.sn()) {
|
||||
auto sock_route = proxy_connect_sn(conn_id.pk, hint, optional, incoming, outgoing, EPHEMERAL_ROUTING_ID, keep_alive);
|
||||
if (!sock_route.first) {
|
||||
nowarn = true;
|
||||
if (optional)
|
||||
OMQ_LOG(debug, "Not sending: send is optional and no connection to ",
|
||||
oxenc::to_hex(conn_id.pk), " is currently established");
|
||||
else
|
||||
OMQ_LOG(error, "Unable to send to ", oxenc::to_hex(conn_id.pk), ": no valid connection address found");
|
||||
break;
|
||||
}
|
||||
send_to = sock_route.first;
|
||||
conn_id.route = std::move(sock_route.second);
|
||||
} else if (!conn_id.route.empty()) { // incoming non-SN connection
|
||||
auto it = connections.find(conn_id.id);
|
||||
if (it == connections.end()) {
|
||||
OMQ_LOG(warn, "Unable to send to ", conn_id, ": incoming listening socket not found");
|
||||
break;
|
||||
}
|
||||
send_to = &it->second;
|
||||
} else {
|
||||
auto pr = peers.equal_range(conn_id);
|
||||
if (pr.first == peers.end()) {
|
||||
OMQ_LOG(warn, "Unable to send: connection id ", conn_id, " is not (or is no longer) a valid outgoing connection");
|
||||
break;
|
||||
}
|
||||
auto& peer = pr.first->second;
|
||||
auto it = connections.find(peer.conn_id);
|
||||
if (it == connections.end()) {
|
||||
OMQ_LOG(warn, "Unable to send: peer connection id ", conn_id, " is not (or is no longer) a valid outgoing connection");
|
||||
break;
|
||||
}
|
||||
send_to = &it->second;
|
||||
}
|
||||
|
||||
try {
|
||||
sent = send_message_parts(*send_to, build_send_parts(send, conn_id.route));
|
||||
} catch (const zmq::error_t &e) {
|
||||
if (e.num() == EHOSTUNREACH && !conn_id.route.empty() /*= incoming conn*/) {
|
||||
|
||||
OMQ_LOG(debug, "Incoming connection is no longer valid; removing peer details");
|
||||
|
||||
auto pr = peers.equal_range(conn_id);
|
||||
if (pr.first != peers.end()) {
|
||||
if (!conn_id.sn()) {
|
||||
peers.erase(pr.first);
|
||||
} else {
|
||||
bool removed;
|
||||
for (auto it = pr.first; it != pr.second; ) {
|
||||
auto& peer = it->second;
|
||||
if (peer.route == conn_id.route) {
|
||||
peers.erase(it);
|
||||
removed = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
// The incoming connection to the SN is no longer good, but we can retry because
|
||||
// we may have another active connection with the SN (or may want to open one).
|
||||
if (removed) {
|
||||
OMQ_LOG(debug, "Retrying sending to SN ", oxenc::to_hex(conn_id.pk), " using other sockets");
|
||||
retry = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!retry) {
|
||||
if (!conn_id.sn() && !conn_id.route.empty()) { // incoming non-SN connection
|
||||
OMQ_LOG(debug, "Unable to send message to incoming connection ", conn_id, ": ", e.what(),
|
||||
"; remote has probably disconnected");
|
||||
} else {
|
||||
OMQ_LOG(warn, "Unable to send message to ", conn_id, ": ", e.what());
|
||||
}
|
||||
nowarn = true;
|
||||
if (callback_nosend) {
|
||||
job([callback = std::move(callback_nosend), error = e] { callback(&error); });
|
||||
callback_nosend = nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (request) {
|
||||
if (sent) {
|
||||
OMQ_LOG(debug, "Added new pending request ", oxenc::to_hex(request_tag));
|
||||
pending_requests.insert({ request_tag, {
|
||||
std::chrono::steady_clock::now() + request_timeout, std::move(request_callback) }});
|
||||
} else {
|
||||
OMQ_LOG(debug, "Could not send request, scheduling request callback failure");
|
||||
job([callback = std::move(request_callback)] { callback(false, {{"TIMEOUT"s}}); });
|
||||
}
|
||||
}
|
||||
if (!sent) {
|
||||
if (callback_nosend)
|
||||
job([callback = std::move(callback_nosend)] { callback(nullptr); });
|
||||
else if (callback_noqueue)
|
||||
job(std::move(callback_noqueue));
|
||||
else if (!nowarn)
|
||||
OMQ_LOG(warn, "Unable to send message to ", conn_id, ": sending would block");
|
||||
}
|
||||
}
|
||||
|
||||
void OxenMQ::proxy_reply(oxenc::bt_dict_consumer data) {
|
||||
bool have_conn_id = false;
|
||||
ConnectionID conn_id{0};
|
||||
if (data.skip_until("conn_id")) {
|
||||
conn_id.id = data.consume_integer<long long>();
|
||||
if (conn_id.id == -1)
|
||||
throw std::runtime_error("Invalid error: invalid conn_id value (-1)");
|
||||
have_conn_id = true;
|
||||
}
|
||||
if (data.skip_until("conn_pubkey")) {
|
||||
if (have_conn_id)
|
||||
throw std::runtime_error("Internal error: Invalid proxy reply command; conn_id and conn_pubkey are exclusive");
|
||||
conn_id.pk = data.consume_string();
|
||||
conn_id.id = ConnectionID::SN_ID;
|
||||
} else if (!have_conn_id)
|
||||
throw std::runtime_error("Internal error: Invalid proxy reply command; conn_pubkey or conn_id missing");
|
||||
if (!data.skip_until("send"))
|
||||
throw std::runtime_error("Internal error: Invalid proxy reply command; send parts missing");
|
||||
|
||||
oxenc::bt_list_consumer send = data.consume_list_consumer();
|
||||
|
||||
auto pr = peers.equal_range(conn_id);
|
||||
if (pr.first == pr.second) {
|
||||
OMQ_LOG(warn, "Unable to send tagged reply: the connection is no longer valid");
|
||||
return;
|
||||
}
|
||||
|
||||
// We try any connections until one works (for ordinary remotes there will be just one, but for
|
||||
// SNs there might be one incoming and one outgoing).
|
||||
for (auto it = pr.first; it != pr.second; ) {
|
||||
try {
|
||||
send_message_parts(connections[it->second.conn_id], build_send_parts(send, it->second.route));
|
||||
break;
|
||||
} catch (const zmq::error_t &err) {
|
||||
if (err.num() == EHOSTUNREACH) {
|
||||
if (it->second.outgoing()) {
|
||||
OMQ_LOG(debug, "Unable to send reply to non-SN request on outgoing socket: "
|
||||
"remote is no longer connected; closing connection");
|
||||
proxy_close_connection(it->second.conn_id, CLOSE_LINGER);
|
||||
it = peers.erase(it);
|
||||
++it;
|
||||
} else {
|
||||
OMQ_LOG(debug, "Unable to send reply to non-SN request on incoming socket: "
|
||||
"remote is no longer connected; removing peer details");
|
||||
it = peers.erase(it);
|
||||
}
|
||||
} else {
|
||||
OMQ_LOG(warn, "Unable to send reply to incoming non-SN request: ", err.what());
|
||||
++it;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void OxenMQ::proxy_control_message(OxenMQ::control_message_array& parts, size_t len) {
|
||||
// We throw an uncaught exception here because we only generate control messages internally in
|
||||
// oxenmq code: if one of these condition fail it's a oxenmq bug.
|
||||
if (len < 2)
|
||||
throw std::logic_error("OxenMQ bug: Expected 2-3 message parts for a proxy control message");
|
||||
auto route = view(parts[0]), cmd = view(parts[1]);
|
||||
OMQ_TRACE("control message: ", cmd);
|
||||
if (len == 3) {
|
||||
OMQ_TRACE("...: ", parts[2]);
|
||||
auto data = view(parts[2]);
|
||||
if (cmd == "SEND") {
|
||||
OMQ_TRACE("proxying message");
|
||||
return proxy_send(data);
|
||||
} else if (cmd == "REPLY") {
|
||||
OMQ_TRACE("proxying reply to non-SN incoming message");
|
||||
return proxy_reply(data);
|
||||
} else if (cmd == "BATCH") {
|
||||
OMQ_TRACE("proxy batch jobs");
|
||||
auto ptrval = oxenc::bt_deserialize<uintptr_t>(data);
|
||||
return proxy_batch(reinterpret_cast<detail::Batch*>(ptrval));
|
||||
} else if (cmd == "INJECT") {
|
||||
OMQ_TRACE("proxy inject");
|
||||
return proxy_inject_task(detail::deserialize_object<injected_task>(oxenc::bt_deserialize<uintptr_t>(data)));
|
||||
} else if (cmd == "SET_SNS") {
|
||||
return proxy_set_active_sns(data);
|
||||
} else if (cmd == "UPDATE_SNS") {
|
||||
return proxy_update_active_sns(data);
|
||||
} else if (cmd == "CONNECT_SN") {
|
||||
proxy_connect_sn(data);
|
||||
return;
|
||||
} else if (cmd == "CONNECT_REMOTE") {
|
||||
return proxy_connect_remote(data);
|
||||
} else if (cmd == "DISCONNECT") {
|
||||
return proxy_disconnect(data);
|
||||
} else if (cmd == "TIMER") {
|
||||
return proxy_timer(data);
|
||||
} else if (cmd == "TIMER_DEL") {
|
||||
return proxy_timer_del(oxenc::bt_deserialize<int>(data));
|
||||
} else if (cmd == "BIND") {
|
||||
auto b = detail::deserialize_object<bind_data>(oxenc::bt_deserialize<uintptr_t>(data));
|
||||
if (proxy_bind(b, bind.size()))
|
||||
bind.push_back(std::move(b));
|
||||
return;
|
||||
}
|
||||
} else if (len == 2) {
|
||||
if (cmd == "START") {
|
||||
// Command send by the owning thread during startup; we send back a simple READY reply to
|
||||
// let it know we are running.
|
||||
return route_control(command, route, "READY");
|
||||
} else if (cmd == "QUIT") {
|
||||
// Asked to quit: set max_workers to zero and tell any idle ones to quit. We will
|
||||
// close workers as they come back to READY status, and then close external
|
||||
// connections once all workers are done.
|
||||
max_workers = 0;
|
||||
for (size_t i = 0; i < idle_worker_count; i++)
|
||||
route_control(workers_socket, workers[idle_workers[i]].worker_routing_id, "QUIT");
|
||||
idle_worker_count = 0;
|
||||
for (auto& [run, busy, queue] : tagged_workers)
|
||||
if (!busy)
|
||||
route_control(workers_socket, run.worker_routing_id, "QUIT");
|
||||
return;
|
||||
}
|
||||
}
|
||||
throw std::runtime_error("OxenMQ bug: Proxy received invalid control command: " +
|
||||
std::string{cmd} + " (" + std::to_string(len) + ")");
|
||||
}
|
||||
|
||||
bool OxenMQ::proxy_bind(bind_data& b, size_t bind_index) {
|
||||
zmq::socket_t listener{context, zmq::socket_type::router};
|
||||
setup_incoming_socket(listener, b.curve, pubkey, privkey, bind_index);
|
||||
|
||||
bool good = true;
|
||||
try {
|
||||
listener.bind(b.address);
|
||||
} catch (const zmq::error_t&) {
|
||||
good = false;
|
||||
}
|
||||
if (b.on_bind) {
|
||||
b.on_bind(good);
|
||||
b.on_bind = nullptr;
|
||||
}
|
||||
if (!good) {
|
||||
OMQ_LOG(warn, "OxenMQ failed to listen on ", b.address);
|
||||
return false;
|
||||
}
|
||||
|
||||
OMQ_LOG(info, "OxenMQ listening on ", b.address);
|
||||
|
||||
b.conn_id = next_conn_id++;
|
||||
connections.emplace_hint(connections.end(), b.conn_id, std::move(listener));
|
||||
|
||||
connections_updated = true;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void OxenMQ::proxy_loop_init() {
|
||||
|
||||
#if defined(__linux__) || defined(__sun) || defined(__MINGW32__)
|
||||
pthread_setname_np(pthread_self(), "omq-proxy");
|
||||
#elif defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__)
|
||||
pthread_set_name_np(pthread_self(), "omq-proxy");
|
||||
#elif defined(__MACH__)
|
||||
pthread_setname_np("omq-proxy");
|
||||
#endif
|
||||
|
||||
zap_auth = zmq::socket_t{context, zmq::socket_type::rep};
|
||||
zap_auth.set(zmq::sockopt::linger, 0);
|
||||
zap_auth.bind(ZMQ_ADDR_ZAP);
|
||||
|
||||
workers_socket = zmq::socket_t{context, zmq::socket_type::router};
|
||||
workers_socket.set(zmq::sockopt::router_mandatory, true);
|
||||
workers_socket.bind(SN_ADDR_WORKERS);
|
||||
|
||||
workers.reserve(max_workers);
|
||||
idle_workers.resize(max_workers);
|
||||
if (!workers.empty() || !worker_sockets.empty())
|
||||
throw std::logic_error("Internal error: proxy thread started with active worker threads");
|
||||
worker_sockets.reserve(max_workers);
|
||||
// Pre-initialize these worker sockets rather than creating during thread initialization so that
|
||||
// we can't hit the zmq socket limit during worker thread startup.
|
||||
for (int i = 0; i < max_workers; i++)
|
||||
worker_sockets.emplace_back(context, zmq::socket_type::dealer);
|
||||
|
||||
#ifndef _WIN32
|
||||
int saved_umask = -1;
|
||||
if (STARTUP_UMASK >= 0)
|
||||
saved_umask = umask(STARTUP_UMASK);
|
||||
#endif
|
||||
|
||||
{
|
||||
zmq::socket_t inproc_listener{context, zmq::socket_type::router};
|
||||
inproc_listener.bind(SN_ADDR_SELF);
|
||||
inproc_listener_connid = next_conn_id++;
|
||||
connections.emplace_hint(connections.end(), inproc_listener_connid, std::move(inproc_listener));
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < bind.size(); i++) {
|
||||
if (!proxy_bind(bind[i], i)) {
|
||||
OMQ_LOG(fatal, "OxenMQ failed to listen on ", bind[i].address);
|
||||
throw zmq::error_t{};
|
||||
}
|
||||
}
|
||||
|
||||
#ifndef _WIN32
|
||||
if (saved_umask != -1)
|
||||
umask(saved_umask);
|
||||
|
||||
// set socket gid / uid if it is provided
|
||||
if (SOCKET_GID != -1 or SOCKET_UID != -1) {
|
||||
for (auto& b : bind) {
|
||||
const address addr(b.address);
|
||||
if (addr.ipc())
|
||||
if (chown(addr.socket.c_str(), SOCKET_UID, SOCKET_GID) == -1)
|
||||
throw std::runtime_error("cannot set group on " + addr.socket + ": " + strerror(errno));
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
connections_updated = true;
|
||||
|
||||
// Also add an internal connection to self so that calling code can avoid needing to
|
||||
// special-case rare situations where we are supposed to talk to a quorum member that happens to
|
||||
// be ourselves (which can happen, for example, with cross-quoum Blink communication)
|
||||
// FIXME: not working
|
||||
//listener.bind(SN_ADDR_SELF);
|
||||
|
||||
if (!timers)
|
||||
timers.reset(zmq_timers_new());
|
||||
|
||||
if (-1 == zmq_timers_add(timers.get(),
|
||||
std::chrono::milliseconds{CONN_CHECK_INTERVAL}.count(),
|
||||
[](int /*timer_id*/, void* self) { static_cast<OxenMQ*>(self)->proxy_conn_cleanup(); },
|
||||
this)) {
|
||||
throw zmq::error_t{};
|
||||
}
|
||||
|
||||
// Wait for tagged worker threads to get ready and connect to us (we get a "STARTING" message)
|
||||
// and send them back a "START" to let them know to go ahead with startup. We need this
|
||||
// synchronization dance to guarantee that the workers are routable before we can proceed.
|
||||
if (!tagged_workers.empty()) {
|
||||
OMQ_LOG(debug, "Waiting for tagged workers");
|
||||
{
|
||||
std::unique_lock lock{tagged_startup_mutex};
|
||||
tagged_go = tagged_go_mode::GO;
|
||||
}
|
||||
tagged_cv.notify_all();
|
||||
std::unordered_set<std::string_view> waiting_on;
|
||||
for (auto& w : tagged_workers)
|
||||
waiting_on.emplace(std::get<run_info>(w).worker_routing_id);
|
||||
for (std::vector<zmq::message_t> parts; !waiting_on.empty(); parts.clear()) {
|
||||
recv_message_parts(workers_socket, parts);
|
||||
if (parts.size() != 2 || view(parts[1]) != "STARTING"sv) {
|
||||
OMQ_LOG(error, "Received invalid message on worker socket while waiting for tagged thread startup");
|
||||
continue;
|
||||
}
|
||||
OMQ_LOG(debug, "Received STARTING message from ", view(parts[0]));
|
||||
if (auto it = waiting_on.find(view(parts[0])); it != waiting_on.end())
|
||||
waiting_on.erase(it);
|
||||
else
|
||||
OMQ_LOG(error, "Received STARTING message from unknown worker ", view(parts[0]));
|
||||
}
|
||||
|
||||
for (auto&w : tagged_workers) {
|
||||
OMQ_LOG(debug, "Telling tagged thread worker ", std::get<run_info>(w).worker_routing_name, " to finish startup");
|
||||
route_control(workers_socket, std::get<run_info>(w).worker_routing_id, "START");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void OxenMQ::proxy_loop(std::promise<void> startup) {
|
||||
try {
|
||||
proxy_loop_init();
|
||||
} catch (...) {
|
||||
startup.set_exception(std::current_exception());
|
||||
return;
|
||||
}
|
||||
startup.set_value();
|
||||
|
||||
// Fixed array used for worker and control messages: these are never longer than 3 parts:
|
||||
std::array<zmq::message_t, 3> control_parts;
|
||||
|
||||
// General vector for handling incoming messages:
|
||||
std::vector<zmq::message_t> parts;
|
||||
|
||||
std::vector<std::pair<const int64_t, zmq::socket_t>*> queue; // Used as a circular buffer
|
||||
|
||||
#ifdef OXENMQ_USE_EPOLL
|
||||
std::vector<struct epoll_event> evs;
|
||||
#endif
|
||||
|
||||
while (true) {
|
||||
std::chrono::milliseconds poll_timeout;
|
||||
if (max_workers == 0) { // Will be 0 only if we are quitting
|
||||
if (std::none_of(workers.begin(), workers.end(), [](auto &w) { return w.worker_thread.joinable(); }) &&
|
||||
std::none_of(tagged_workers.begin(), tagged_workers.end(), [](auto &w) { return std::get<0>(w).worker_thread.joinable(); })) {
|
||||
// All the workers have finished, so we can finish shutting down
|
||||
return proxy_quit();
|
||||
}
|
||||
poll_timeout = 1s; // We don't keep running timers when we're quitting, so don't have a timer to check
|
||||
} else {
|
||||
poll_timeout = std::chrono::milliseconds{zmq_timers_timeout(timers.get())};
|
||||
}
|
||||
|
||||
if (connections_updated) {
|
||||
rebuild_pollitems();
|
||||
// If we just rebuilt the queue then do a full check of everything, because we might
|
||||
// have sockets that already edge-triggered that we need to fully drain before we start
|
||||
// polling.
|
||||
proxy_skip_one_poll = true;
|
||||
}
|
||||
|
||||
// We round-robin connections when pulling off pending messages one-by-one rather than
|
||||
// pulling off all messages from one connection before moving to the next; thus in cases of
|
||||
// contention we end up fairly distributing.
|
||||
queue.reserve(connections.size() + 1);
|
||||
|
||||
#ifdef OXENMQ_USE_EPOLL
|
||||
bool process_command = false, process_worker = false, process_zap = false, process_all = false;
|
||||
|
||||
if (proxy_skip_one_poll) {
|
||||
proxy_skip_one_poll = false;
|
||||
|
||||
process_command = command.get(zmq::sockopt::events) & ZMQ_POLLIN;
|
||||
process_worker = workers_socket.get(zmq::sockopt::events) & ZMQ_POLLIN;
|
||||
process_zap = zap_auth.get(zmq::sockopt::events) & ZMQ_POLLIN;
|
||||
process_all = true;
|
||||
}
|
||||
else {
|
||||
OMQ_TRACE("polling for new messages via epoll");
|
||||
|
||||
evs.resize(3 + connections.size());
|
||||
const int max = epoll_wait(epoll_fd, evs.data(), evs.size(), poll_timeout.count());
|
||||
|
||||
queue.clear();
|
||||
for (int i = 0; i < max; i++) {
|
||||
const auto conn_id = evs[i].data.u64;
|
||||
if (conn_id == EPOLL_COMMAND_ID)
|
||||
process_command = true;
|
||||
else if (conn_id == EPOLL_WORKER_ID)
|
||||
process_worker = true;
|
||||
else if (conn_id == EPOLL_ZAP_ID)
|
||||
process_zap = true;
|
||||
else if (auto it = connections.find(conn_id); it != connections.end())
|
||||
queue.push_back(&*it);
|
||||
}
|
||||
queue.push_back(nullptr);
|
||||
}
|
||||
|
||||
#else
|
||||
if (proxy_skip_one_poll)
|
||||
proxy_skip_one_poll = false;
|
||||
else {
|
||||
OMQ_TRACE("polling for new messages");
|
||||
|
||||
// We poll the control socket and worker socket for any incoming messages. If we have
|
||||
// available worker room then also poll incoming connections and outgoing connections
|
||||
// for messages to forward to a worker. Otherwise, we just look for a control message
|
||||
// or a worker coming back with a ready message.
|
||||
zmq::poll(pollitems.data(), pollitems.size(), poll_timeout);
|
||||
}
|
||||
|
||||
constexpr bool process_command = true, process_worker = true, process_zap = true, process_all = true;
|
||||
#endif
|
||||
|
||||
if (process_command) {
|
||||
OMQ_TRACE("processing control messages");
|
||||
while (size_t len = recv_message_parts(command, control_parts, zmq::recv_flags::dontwait))
|
||||
proxy_control_message(control_parts, len);
|
||||
}
|
||||
|
||||
if (process_worker) {
|
||||
OMQ_TRACE("processing worker messages");
|
||||
while (size_t len = recv_message_parts(workers_socket, control_parts, zmq::recv_flags::dontwait))
|
||||
proxy_worker_message(control_parts, len);
|
||||
}
|
||||
|
||||
OMQ_TRACE("processing timers");
|
||||
zmq_timers_execute(timers.get());
|
||||
|
||||
if (process_zap) {
|
||||
// Handle any zap authentication
|
||||
OMQ_TRACE("processing zap requests");
|
||||
process_zap_requests();
|
||||
}
|
||||
|
||||
// See if we can drain anything from the current queue before we potentially add to it
|
||||
// below.
|
||||
OMQ_TRACE("processing queued jobs and messages");
|
||||
proxy_process_queue();
|
||||
|
||||
OMQ_TRACE("processing new incoming messages");
|
||||
if (process_all) {
|
||||
queue.clear();
|
||||
for (auto& id_sock : connections)
|
||||
if (id_sock.second.get(zmq::sockopt::events) & ZMQ_POLLIN)
|
||||
queue.push_back(&id_sock);
|
||||
queue.push_back(nullptr);
|
||||
}
|
||||
|
||||
size_t end = queue.size() - 1;
|
||||
|
||||
for (size_t pos = 0; pos != end; ++pos %= queue.size()) {
|
||||
parts.clear();
|
||||
auto& [id, sock] = *queue[pos];
|
||||
|
||||
if (!recv_message_parts(sock, parts, zmq::recv_flags::dontwait))
|
||||
continue;
|
||||
|
||||
// We only pull this one message now but then requeue the socket so that after we check
|
||||
// all other sockets we come back to this one to check again.
|
||||
queue[end] = queue[pos];
|
||||
++end %= queue.size();
|
||||
|
||||
if (parts.empty()) {
|
||||
OMQ_LOG(warn, "Ignoring empty (0-part) incoming message");
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!proxy_handle_builtin(id, sock, parts))
|
||||
proxy_to_worker(id, sock, parts);
|
||||
|
||||
if (connections_updated) {
|
||||
// If connections got updated then our points are stale, so restart the proxy loop;
|
||||
// we'll immediately end up right back here at least once before we resume polling.
|
||||
OMQ_TRACE("connections became stale; short-circuiting incoming message loop");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef OXENMQ_USE_EPOLL
|
||||
// If any socket still has ZMQ_POLLIN (which is possible if something we did above changed
|
||||
// state on another socket, perhaps by writing to it) then we need to repeat the loop
|
||||
// *without* going back to epoll again, until we get through everything without any
|
||||
// ZMQ_POLLIN sockets. If we didn't, we could miss it and might end up deadlocked because
|
||||
// of ZMQ's edge-triggered notifications on zmq fd's.
|
||||
//
|
||||
// More info on the complexities here at https://github.com/zeromq/libzmq/issues/3641 and
|
||||
// https://funcptr.net/2012/09/10/zeromq---edge-triggered-notification/
|
||||
if (!connections_updated && !proxy_skip_one_poll) {
|
||||
for (auto* s : {&command, &workers_socket, &zap_auth}) {
|
||||
if (s->get(zmq::sockopt::events) & ZMQ_POLLIN) {
|
||||
proxy_skip_one_poll = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!proxy_skip_one_poll) {
|
||||
for (auto& [id, sock] : connections) {
|
||||
if (sock.get(zmq::sockopt::events) & ZMQ_POLLIN) {
|
||||
proxy_skip_one_poll = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
OMQ_TRACE("done proxy loop");
|
||||
}
|
||||
}
|
||||
|
||||
static bool is_error_response(std::string_view cmd) {
|
||||
return cmd == "FORBIDDEN" || cmd == "FORBIDDEN_SN" || cmd == "NOT_A_SERVICE_NODE" || cmd == "UNKNOWNCOMMAND" || cmd == "NO_REPLY_TAG";
|
||||
}
|
||||
|
||||
// Return true if we recognized/handled the builtin command (even if we reject it for whatever
|
||||
// reason)
|
||||
bool OxenMQ::proxy_handle_builtin(int64_t conn_id, zmq::socket_t& sock, std::vector<zmq::message_t>& parts) {
|
||||
// Doubling as a bool and an offset:
|
||||
size_t incoming = sock.get(zmq::sockopt::type) == ZMQ_ROUTER;
|
||||
|
||||
std::string_view route, cmd;
|
||||
if (parts.size() < 1 + incoming) {
|
||||
OMQ_LOG(warn, "Received empty message; ignoring");
|
||||
return true;
|
||||
}
|
||||
if (incoming) {
|
||||
route = view(parts[0]);
|
||||
cmd = view(parts[1]);
|
||||
} else {
|
||||
cmd = view(parts[0]);
|
||||
}
|
||||
OMQ_TRACE("Checking for builtins: '", cmd, "' from ", peer_address(parts.back()));
|
||||
|
||||
if (cmd == "REPLY") {
|
||||
size_t tag_pos = 1 + incoming;
|
||||
if (parts.size() <= tag_pos) {
|
||||
OMQ_LOG(warn, "Received REPLY without a reply tag; ignoring");
|
||||
return true;
|
||||
}
|
||||
std::string reply_tag{view(parts[tag_pos])};
|
||||
auto it = pending_requests.find(reply_tag);
|
||||
if (it != pending_requests.end()) {
|
||||
OMQ_LOG(debug, "Received REPLY for pending command ", oxenc::to_hex(reply_tag), "; scheduling callback");
|
||||
std::vector<std::string> data;
|
||||
data.reserve(parts.size() - (tag_pos + 1));
|
||||
for (auto it = parts.begin() + (tag_pos + 1); it != parts.end(); ++it)
|
||||
data.emplace_back(view(*it));
|
||||
proxy_schedule_reply_job([callback=std::move(it->second.second), data=std::move(data)] {
|
||||
callback(true, std::move(data));
|
||||
});
|
||||
pending_requests.erase(it);
|
||||
} else {
|
||||
OMQ_LOG(warn, "Received REPLY with unknown or already handled reply tag (", oxenc::to_hex(reply_tag), "); ignoring");
|
||||
}
|
||||
return true;
|
||||
} else if (cmd == "HI") {
|
||||
if (!incoming) {
|
||||
OMQ_LOG(warn, "Got invalid 'HI' message on an outgoing connection; ignoring");
|
||||
return true;
|
||||
}
|
||||
OMQ_LOG(debug, "Incoming client from ", peer_address(parts.back()), " sent HI, replying with HELLO");
|
||||
try {
|
||||
send_routed_message(sock, std::string{route}, "HELLO");
|
||||
} catch (const std::exception &e) { OMQ_LOG(warn, "Couldn't reply with HELLO: ", e.what()); }
|
||||
return true;
|
||||
} else if (cmd == "HELLO") {
|
||||
if (incoming) {
|
||||
OMQ_LOG(warn, "Got invalid 'HELLO' message on an incoming connection; ignoring");
|
||||
return true;
|
||||
}
|
||||
auto it = std::find_if(pending_connects.begin(), pending_connects.end(),
|
||||
[&](auto& pc) { return std::get<int64_t>(pc) == conn_id; });
|
||||
if (it == pending_connects.end()) {
|
||||
OMQ_LOG(warn, "Got invalid 'HELLO' message on an already handshaked incoming connection; ignoring");
|
||||
return true;
|
||||
}
|
||||
auto& pc = *it;
|
||||
auto pit = peers.find(std::get<int64_t>(pc));
|
||||
if (pit == peers.end()) {
|
||||
OMQ_LOG(warn, "Got invalid 'HELLO' message with invalid conn_id; ignoring");
|
||||
return true;
|
||||
}
|
||||
|
||||
OMQ_LOG(debug, "Got initial HELLO server response from ", peer_address(parts.back()));
|
||||
proxy_schedule_reply_job([on_success=std::move(std::get<ConnectSuccess>(pc)),
|
||||
conn=pit->first] {
|
||||
on_success(conn);
|
||||
});
|
||||
pending_connects.erase(it);
|
||||
return true;
|
||||
} else if (cmd == "BYE") {
|
||||
if (!incoming) {
|
||||
OMQ_LOG(debug, "BYE command received; disconnecting from ", peer_address(parts.back()));
|
||||
proxy_close_connection(conn_id, 0s);
|
||||
} else {
|
||||
OMQ_LOG(warn, "Got invalid 'BYE' command on an incoming socket; ignoring");
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
else if (is_error_response(cmd)) {
|
||||
// These messages (FORBIDDEN, UNKNOWNCOMMAND, etc.) are sent in response to us trying to
|
||||
// invoke something that doesn't exist or we don't have permission to access. These have
|
||||
// two forms (the latter is only sent by remotes running 1.1.0+).
|
||||
// - ["XXX", "whatever.command"]
|
||||
// - ["XXX", "REPLY", replytag]
|
||||
// (ignoring the routing prefix on incoming commands).
|
||||
// For the former, we log; for the latter we trigger the reply callback with a failure
|
||||
|
||||
if (parts.size() == (1 + incoming) && cmd == "UNKNOWNCOMMAND") {
|
||||
// pre-1.1.0 sent just a plain UNKNOWNCOMMAND (without the actual command); this was not
|
||||
// useful, but also this response is *expected* for things 1.0.5 didn't understand, like
|
||||
// FORBIDDEN_SN: so log it only at debug level and move on.
|
||||
OMQ_LOG(debug, "Received plain UNKNOWNCOMMAND; remote is probably an older oxenmq. Ignoring.");
|
||||
return true;
|
||||
}
|
||||
|
||||
if (parts.size() == (3 + incoming) && view(parts[1 + incoming]) == "REPLY") {
|
||||
std::string reply_tag{view(parts[2 + incoming])};
|
||||
auto it = pending_requests.find(reply_tag);
|
||||
if (it != pending_requests.end()) {
|
||||
OMQ_LOG(debug, "Received ", cmd, " REPLY for pending command ", oxenc::to_hex(reply_tag), "; scheduling failure callback");
|
||||
proxy_schedule_reply_job([callback=std::move(it->second.second), cmd=std::string{cmd}] {
|
||||
callback(false, {{std::move(cmd)}});
|
||||
});
|
||||
pending_requests.erase(it);
|
||||
} else {
|
||||
OMQ_LOG(warn, "Received REPLY with unknown or already handled reply tag (", oxenc::to_hex(reply_tag), "); ignoring");
|
||||
}
|
||||
} else {
|
||||
OMQ_LOG(warn, "Received ", cmd, ':', (parts.size() > 1 + incoming ? view(parts[1 + incoming]) : "(unknown command)"sv),
|
||||
" from ", peer_address(parts.back()));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void OxenMQ::proxy_process_queue() {
|
||||
if (max_workers == 0) // shutting down
|
||||
return;
|
||||
|
||||
// First: send any tagged thread tasks to the tagged threads, if idle
|
||||
for (auto& [run, busy, queue] : tagged_workers) {
|
||||
if (!busy && !queue.empty()) {
|
||||
busy = true;
|
||||
proxy_run_worker(run.load(std::move(queue.front()), false, run.worker_id));
|
||||
queue.pop_front();
|
||||
}
|
||||
}
|
||||
|
||||
// Second: process any batch jobs; since these are internal they are given higher priority.
|
||||
proxy_run_batch_jobs(batch_jobs, batch_jobs_reserved, batch_jobs_active, false);
|
||||
|
||||
// Next any reply batch jobs (which are a bit different from the above, since they are
|
||||
// externally triggered but for things we initiated locally).
|
||||
proxy_run_batch_jobs(reply_jobs, reply_jobs_reserved, reply_jobs_active, true);
|
||||
|
||||
// Finally general incoming commands
|
||||
for (auto it = pending_commands.begin(); it != pending_commands.end() && active_workers() < max_workers; ) {
|
||||
auto& pending = *it;
|
||||
if (pending.cat.active_threads < pending.cat.reserved_threads
|
||||
|| active_workers() < general_workers) {
|
||||
proxy_run_worker(get_idle_worker().load(std::move(pending)));
|
||||
pending.cat.queued--;
|
||||
pending.cat.active_threads++;
|
||||
assert(pending.cat.queued >= 0);
|
||||
it = pending_commands.erase(it);
|
||||
} else {
|
||||
++it; // no available general or reserved worker spots for this job right now
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
172
oxenmq/pubsub.h
Normal file
172
oxenmq/pubsub.h
Normal file
|
@ -0,0 +1,172 @@
|
|||
#pragma once
|
||||
|
||||
#include "connections.h"
|
||||
|
||||
#include <chrono>
|
||||
#include <mutex>
|
||||
#include <shared_mutex>
|
||||
#include <type_traits>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include <optional>
|
||||
|
||||
namespace oxenmq {
|
||||
|
||||
using namespace std::chrono_literals;
|
||||
|
||||
namespace detail {
|
||||
struct no_data_t final {};
|
||||
inline constexpr no_data_t no_data{};
|
||||
|
||||
template <typename UserData>
|
||||
struct SubData {
|
||||
std::chrono::steady_clock::time_point expiry;
|
||||
UserData user_data;
|
||||
explicit SubData(std::chrono::steady_clock::time_point _exp)
|
||||
: expiry{_exp}, user_data{} {}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct SubData<void> {
|
||||
std::chrono::steady_clock::time_point expiry;
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* OMQ Subscription class. Handles pub/sub connections such that the user only needs to call
|
||||
* methods to subscribe and publish.
|
||||
*
|
||||
* FIXME: do we want an unsubscribe, or is expiry / conn management sufficient?
|
||||
*
|
||||
* Type UserData can contain whatever information the user may need at publish time, for example if
|
||||
* the subscription is for logs the subscriber can specify log levels or categories, and the
|
||||
* publisher can choose to send or not based on those. The UserData type, if provided and non-void,
|
||||
* must be default constructible, must be comparable with ==, and must be movable.
|
||||
*/
|
||||
template <typename UserData = void>
|
||||
class Subscription {
|
||||
static constexpr bool have_user_data = !std::is_void_v<UserData>;
|
||||
using UserData_if_present = std::conditional_t<have_user_data, UserData, detail::no_data_t>;
|
||||
using subdata_t = detail::SubData<UserData>;
|
||||
|
||||
std::unordered_map<ConnectionID, subdata_t> subs;
|
||||
std::shared_mutex _mutex;
|
||||
const std::string description; // description of the sub for logging
|
||||
const std::chrono::milliseconds sub_duration; // extended by re-subscribe
|
||||
|
||||
public:
|
||||
|
||||
Subscription() = delete;
|
||||
Subscription(std::string description, std::chrono::milliseconds sub_duration = 30min)
|
||||
: description{std::move(description)}, sub_duration{sub_duration} {}
|
||||
|
||||
// returns true if new sub, false if refresh sub. throws on error. `data` will be checked
|
||||
// against the existing data: if there is existing data and it compares `==` to the given value,
|
||||
// false is returned (and the existing data is not replaced). Otherwise the given data gets
|
||||
// stored for this connection (replacing existing data, if present), and true is returned.
|
||||
bool subscribe(const ConnectionID& conn, UserData_if_present data) {
|
||||
std::unique_lock lock{_mutex};
|
||||
auto expiry = std::chrono::steady_clock::now() + sub_duration;
|
||||
auto [value, added] = subs.emplace(conn, subdata_t{expiry});
|
||||
if (added) {
|
||||
if constexpr (have_user_data)
|
||||
value->second.user_data = std::move(data);
|
||||
return true;
|
||||
}
|
||||
|
||||
value->second.expiry = expiry;
|
||||
|
||||
if constexpr (have_user_data) {
|
||||
// if user_data changed, consider it a new sub rather than refresh, and update
|
||||
// user_data in the mapped value.
|
||||
if (!(value->second.user_data == data)) {
|
||||
value->second.user_data = std::move(data);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// no-user-data version, only available for Subscription<void> (== Subscription without a
|
||||
// UserData type).
|
||||
template <bool enable = !have_user_data, std::enable_if_t<enable, int> = 0>
|
||||
bool subscribe(const ConnectionID& conn) {
|
||||
return subscribe(conn, detail::no_data);
|
||||
}
|
||||
|
||||
// unsubscribe a connection ID. return the user data, if a sub was present.
|
||||
template <bool enable = have_user_data, std::enable_if_t<enable, int> = 0>
|
||||
std::optional<UserData> unsubscribe(const ConnectionID& conn) {
|
||||
std::unique_lock lock{_mutex};
|
||||
|
||||
auto node = subs.extract(conn);
|
||||
if (!node.empty())
|
||||
return node.mapped().user_data;
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// no-user-data version, only available for Subscription<void> (== Subscription without a
|
||||
// UserData type).
|
||||
template <bool enable = !have_user_data, std::enable_if_t<enable, int> = 0>
|
||||
bool unsubscribe(const ConnectionID& conn) {
|
||||
std::unique_lock lock{_mutex};
|
||||
auto node = subs.extract(conn);
|
||||
return !node.empty(); // true if removed, false if wasn't present
|
||||
}
|
||||
|
||||
// force removal of expired subscriptions. removal will otherwise only happen on publish
|
||||
void remove_expired() {
|
||||
std::unique_lock lock{_mutex};
|
||||
auto now = std::chrono::steady_clock::now();
|
||||
for (auto itr = subs.begin(); itr != subs.end();) {
|
||||
if (itr->second.expiry < now)
|
||||
itr = subs.erase(itr);
|
||||
else
|
||||
itr++;
|
||||
}
|
||||
}
|
||||
|
||||
// Func is any callable which takes:
|
||||
// - (const ConnectionID&, const UserData&) for Subscription<UserData> with non-void UserData
|
||||
// - (const ConnectionID&) for Subscription<void>.
|
||||
template <typename Func>
|
||||
void publish(Func&& func) {
|
||||
std::vector<ConnectionID> to_remove;
|
||||
{
|
||||
std::shared_lock lock(_mutex);
|
||||
if (subs.empty())
|
||||
return;
|
||||
|
||||
auto now = std::chrono::steady_clock::now();
|
||||
|
||||
for (const auto& [conn, sub] : subs) {
|
||||
if (sub.expiry < now)
|
||||
to_remove.push_back(conn);
|
||||
else if constexpr (have_user_data)
|
||||
func(conn, sub.user_data);
|
||||
else
|
||||
func(conn);
|
||||
}
|
||||
}
|
||||
|
||||
if (to_remove.empty())
|
||||
return;
|
||||
|
||||
std::unique_lock lock{_mutex};
|
||||
auto now = std::chrono::steady_clock::now();
|
||||
for (auto& conn : to_remove) {
|
||||
auto it = subs.find(conn);
|
||||
if (it != subs.end() && it->second.expiry < now /* recheck: client might have resubscribed in between locks */) {
|
||||
subs.erase(it);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
|
||||
} // namespace oxenmq
|
||||
|
||||
// vim:sw=4:et
|
5
oxenmq/version.h.in
Normal file
5
oxenmq/version.h.in
Normal file
|
@ -0,0 +1,5 @@
|
|||
namespace oxenmq {
|
||||
constexpr int VERSION_MAJOR = @PROJECT_VERSION_MAJOR@;
|
||||
constexpr int VERSION_MINOR = @PROJECT_VERSION_MINOR@;
|
||||
constexpr int VERSION_PATCH = @PROJECT_VERSION_PATCH@;
|
||||
}
|
431
oxenmq/worker.cpp
Normal file
431
oxenmq/worker.cpp
Normal file
|
@ -0,0 +1,431 @@
|
|||
#include "oxenmq.h"
|
||||
#include "batch.h"
|
||||
#include "oxenmq-internal.h"
|
||||
|
||||
#if defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__)
|
||||
extern "C" {
|
||||
#include <pthread.h>
|
||||
#include <pthread_np.h>
|
||||
}
|
||||
#endif
|
||||
#include <oxenc/variant.h>
|
||||
|
||||
namespace oxenmq {
|
||||
|
||||
namespace {
|
||||
|
||||
// Waits for a specific command or "QUIT" on the given socket. Returns true if the command was
|
||||
// received. If "QUIT" was received, replies with "QUITTING" on the socket and closes it, then
|
||||
// returns false.
|
||||
[[gnu::always_inline]] inline
|
||||
bool worker_wait_for(OxenMQ& omq, zmq::socket_t& sock, std::vector<zmq::message_t>& parts, const std::string_view worker_id, const std::string_view expect) {
|
||||
while (true) {
|
||||
omq.log(LogLevel::trace, __FILE__, __LINE__, "worker ", worker_id, " waiting for ", expect);
|
||||
parts.clear();
|
||||
recv_message_parts(sock, parts);
|
||||
if (parts.size() != 1) {
|
||||
omq.log(LogLevel::error, __FILE__, __LINE__, "Internal error: worker ", worker_id, " received invalid ", parts.size(), "-part control msg");
|
||||
continue;
|
||||
}
|
||||
auto command = view(parts[0]);
|
||||
if (command == expect) {
|
||||
#ifndef NDEBUG
|
||||
omq.log(LogLevel::trace, __FILE__, __LINE__, "Worker ", worker_id, " received waited-for ", expect, " command");
|
||||
#endif
|
||||
return true;
|
||||
} else if (command == "QUIT"sv) {
|
||||
omq.log(LogLevel::debug, __FILE__, __LINE__, "Worker ", worker_id, " received QUIT command, shutting down");
|
||||
detail::send_control(sock, "QUITTING");
|
||||
sock.set(zmq::sockopt::linger, 1000);
|
||||
sock.close();
|
||||
return false;
|
||||
} else {
|
||||
omq.log(LogLevel::error, __FILE__, __LINE__, "Internal error: worker ", worker_id, " received invalid command: `", command, "'");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void OxenMQ::worker_thread(unsigned int index, std::optional<std::string> tagged, std::function<void()> start) {
|
||||
std::string routing_id = (tagged ? "t" : "w") +
|
||||
std::string(reinterpret_cast<const char*>(&index), sizeof(index)); // for routing
|
||||
std::string worker_id{tagged ? *tagged : "w" + std::to_string(index)}; // for debug
|
||||
|
||||
[[maybe_unused]] std::string thread_name = tagged.value_or("omq-" + worker_id);
|
||||
#if defined(__linux__) || defined(__sun) || defined(__MINGW32__)
|
||||
if (thread_name.size() > 15) thread_name.resize(15);
|
||||
pthread_setname_np(pthread_self(), thread_name.c_str());
|
||||
#elif defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__)
|
||||
pthread_set_name_np(pthread_self(), thread_name.c_str());
|
||||
#elif defined(__MACH__)
|
||||
pthread_setname_np(thread_name.c_str());
|
||||
#endif
|
||||
|
||||
std::optional<zmq::socket_t> tagged_socket;
|
||||
if (tagged) {
|
||||
// If we're a tagged worker then we got started up before OxenMQ started, so we need to wait
|
||||
// for an all-clear signal from OxenMQ first, then we fire our `start` callback, then we can
|
||||
// start waiting for commands in the main loop further down. (We also can't get the
|
||||
// reference to our `tagged_workers` element or create a socket until the main proxy thread
|
||||
// is running).
|
||||
{
|
||||
std::unique_lock lock{tagged_startup_mutex};
|
||||
tagged_cv.wait(lock, [this] { return tagged_go != tagged_go_mode::WAIT; });
|
||||
}
|
||||
if (tagged_go == tagged_go_mode::SHUTDOWN) // OxenMQ destroyed without starting
|
||||
return;
|
||||
tagged_socket.emplace(context, zmq::socket_type::dealer);
|
||||
}
|
||||
auto& sock = tagged ? *tagged_socket : worker_sockets[index];
|
||||
sock.set(zmq::sockopt::routing_id, routing_id);
|
||||
OMQ_LOG(debug, "New worker thread ", worker_id, " (", routing_id, ") started");
|
||||
sock.connect(SN_ADDR_WORKERS);
|
||||
if (tagged)
|
||||
detail::send_control(sock, "STARTING");
|
||||
|
||||
Message message{*this, 0, AuthLevel::none, ""s};
|
||||
std::vector<zmq::message_t> parts;
|
||||
|
||||
bool waiting_for_command;
|
||||
if (tagged) {
|
||||
waiting_for_command = true;
|
||||
|
||||
if (!worker_wait_for(*this, sock, parts, worker_id, "START"sv))
|
||||
return;
|
||||
if (start) start();
|
||||
} else {
|
||||
// Otherwise for a regular worker we can only be started by an active main proxy thread
|
||||
// which will have preloaded our first job so we can start off right away.
|
||||
waiting_for_command = false;
|
||||
}
|
||||
|
||||
// This will always contains the current job, and is guaranteed to never be invalidated.
|
||||
run_info& run = tagged ? std::get<run_info>(tagged_workers[index - 1]) : workers[index];
|
||||
|
||||
while (true) {
|
||||
if (waiting_for_command) {
|
||||
if (!worker_wait_for(*this, sock, parts, worker_id, "RUN"sv))
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
if (run.is_batch_job) {
|
||||
auto* batch = var::get<detail::Batch*>(run.to_run);
|
||||
if (run.batch_jobno >= 0) {
|
||||
OMQ_TRACE("worker thread ", worker_id, " running batch ", batch, "#", run.batch_jobno);
|
||||
batch->run_job(run.batch_jobno);
|
||||
} else if (run.batch_jobno == -1) {
|
||||
OMQ_TRACE("worker thread ", worker_id, " running batch ", batch, " completion");
|
||||
batch->job_completion();
|
||||
}
|
||||
} else if (run.is_injected) {
|
||||
auto& func = var::get<std::function<void()>>(run.to_run);
|
||||
OMQ_TRACE("worker thread ", worker_id, " invoking injected command ", run.command);
|
||||
func();
|
||||
func = nullptr;
|
||||
} else {
|
||||
message.conn = run.conn;
|
||||
message.access = run.access;
|
||||
message.remote = std::move(run.remote);
|
||||
message.data.clear();
|
||||
|
||||
OMQ_TRACE("Got incoming command from ", message.remote, "/", message.conn, message.conn.route.empty() ? " (outgoing)" : " (incoming)");
|
||||
|
||||
auto& [callback, is_request] = *var::get<const std::pair<CommandCallback, bool>*>(run.to_run);
|
||||
if (is_request) {
|
||||
message.reply_tag = {run.data_parts[0].data<char>(), run.data_parts[0].size()};
|
||||
for (auto it = run.data_parts.begin() + 1; it != run.data_parts.end(); ++it)
|
||||
message.data.emplace_back(it->data<char>(), it->size());
|
||||
} else {
|
||||
for (auto& m : run.data_parts)
|
||||
message.data.emplace_back(m.data<char>(), m.size());
|
||||
}
|
||||
|
||||
OMQ_TRACE("worker thread ", worker_id, " invoking ", run.command, " callback with ", message.data.size(), " message parts");
|
||||
callback(message);
|
||||
}
|
||||
}
|
||||
catch (const oxenc::bt_deserialize_invalid& e) {
|
||||
OMQ_LOG(warn, worker_id, " deserialization failed: ", e.what(), "; ignoring request");
|
||||
}
|
||||
#ifndef BROKEN_APPLE_VARIANT
|
||||
catch (const std::bad_variant_access& e) {
|
||||
OMQ_LOG(warn, worker_id, " deserialization failed: found unexpected serialized type (", e.what(), "); ignoring request");
|
||||
}
|
||||
#endif
|
||||
catch (const std::out_of_range& e) {
|
||||
OMQ_LOG(warn, worker_id, " deserialization failed: invalid data - required field missing (", e.what(), "); ignoring request");
|
||||
}
|
||||
catch (const std::exception& e) {
|
||||
OMQ_LOG(warn, worker_id, " caught exception when processing command: ", e.what());
|
||||
}
|
||||
catch (...) {
|
||||
OMQ_LOG(warn, worker_id, " caught non-standard exception when processing command");
|
||||
}
|
||||
|
||||
// Tell the proxy thread that we are ready for another job
|
||||
detail::send_control(sock, "RAN");
|
||||
waiting_for_command = true;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
OxenMQ::run_info& OxenMQ::get_idle_worker() {
|
||||
if (idle_worker_count == 0) {
|
||||
uint32_t id = workers.size();
|
||||
workers.emplace_back();
|
||||
auto& r = workers.back();
|
||||
r.worker_id = id;
|
||||
r.worker_routing_id = "w" + std::string(reinterpret_cast<const char*>(&id), sizeof(id));
|
||||
r.worker_routing_name = "w" + std::to_string(id);
|
||||
return r;
|
||||
}
|
||||
size_t id = idle_workers[--idle_worker_count];
|
||||
return workers[id];
|
||||
}
|
||||
|
||||
void OxenMQ::proxy_worker_message(OxenMQ::control_message_array& parts, size_t len) {
|
||||
// Process messages sent by workers
|
||||
if (len != 2) {
|
||||
OMQ_LOG(error, "Received send invalid ", len, "-part message");
|
||||
return;
|
||||
}
|
||||
auto route = view(parts[0]), cmd = view(parts[1]);
|
||||
if (route.size() != 5 || (route[0] != 'w' && route[0] != 't')) {
|
||||
OMQ_LOG(error, "Received malformed worker id in worker message; unable to process worker command");
|
||||
return;
|
||||
}
|
||||
bool tagged_worker = route[0] == 't';
|
||||
uint32_t worker_id;
|
||||
std::memcpy(&worker_id, route.data() + 1, 4);
|
||||
if (tagged_worker
|
||||
? 0 == worker_id || worker_id > tagged_workers.size() // tagged worker ids are indexed from 1 to N (0 means untagged)
|
||||
: worker_id >= workers.size()) { // regular worker ids are indexed from 0 to N-1
|
||||
OMQ_LOG(error, "Received invalid worker id w" + std::to_string(worker_id) + " in worker message; unable to process worker command");
|
||||
return;
|
||||
}
|
||||
|
||||
auto& run = tagged_worker ? std::get<run_info>(tagged_workers[worker_id - 1]) : workers[worker_id];
|
||||
|
||||
OMQ_TRACE("received ", cmd, " command from ", route);
|
||||
if (cmd == "RAN"sv) {
|
||||
OMQ_TRACE("Worker ", route, " finished ", run.is_batch_job ? "batch job" : run.command);
|
||||
if (run.is_batch_job) {
|
||||
if (tagged_worker) {
|
||||
std::get<bool>(tagged_workers[worker_id - 1]) = false;
|
||||
} else {
|
||||
auto& active = run.is_reply_job ? reply_jobs_active : batch_jobs_active;
|
||||
assert(active > 0);
|
||||
active--;
|
||||
}
|
||||
bool clear_job = false;
|
||||
auto* batch = var::get<detail::Batch*>(run.to_run);
|
||||
if (run.batch_jobno == -1) {
|
||||
// Returned from the completion function
|
||||
clear_job = true;
|
||||
} else {
|
||||
auto [state, thread] = batch->job_finished();
|
||||
if (state == detail::BatchState::complete) {
|
||||
if (thread == -1) { // run directly in proxy
|
||||
OMQ_TRACE("Completion job running directly in proxy");
|
||||
try {
|
||||
batch->job_completion(); // RUN DIRECTLY IN PROXY THREAD
|
||||
} catch (const std::exception &e) {
|
||||
// Raise these to error levels: the caller really shouldn't be doing
|
||||
// anything non-trivial in an in-proxy completion function!
|
||||
OMQ_LOG(error, "proxy thread caught exception when processing in-proxy completion command: ", e.what());
|
||||
} catch (...) {
|
||||
OMQ_LOG(error, "proxy thread caught non-standard exception when processing in-proxy completion command");
|
||||
}
|
||||
clear_job = true;
|
||||
} else {
|
||||
auto& jobs =
|
||||
thread > 0
|
||||
? std::get<batch_queue>(tagged_workers[thread - 1]) // run in tagged thread
|
||||
: run.is_reply_job
|
||||
? reply_jobs
|
||||
: batch_jobs;
|
||||
jobs.emplace_back(batch, -1);
|
||||
}
|
||||
} else if (state == detail::BatchState::done) {
|
||||
// No completion job
|
||||
clear_job = true;
|
||||
}
|
||||
// else the job is still running
|
||||
}
|
||||
|
||||
if (clear_job) {
|
||||
delete batch;
|
||||
}
|
||||
} else {
|
||||
assert(run.cat->active_threads > 0);
|
||||
run.cat->active_threads--;
|
||||
}
|
||||
if (max_workers == 0) { // Shutting down
|
||||
OMQ_TRACE("Telling worker ", route, " to quit");
|
||||
route_control(workers_socket, route, "QUIT");
|
||||
} else if (!tagged_worker) {
|
||||
idle_workers[idle_worker_count++] = worker_id;
|
||||
}
|
||||
} else if (cmd == "QUITTING"sv) {
|
||||
run.worker_thread.join();
|
||||
OMQ_LOG(debug, "Worker ", route, " exited normally");
|
||||
} else {
|
||||
OMQ_LOG(error, "Worker ", route, " sent unknown control message: `", cmd, "'");
|
||||
}
|
||||
}
|
||||
|
||||
void OxenMQ::proxy_run_worker(run_info& run) {
|
||||
if (!run.worker_thread.joinable())
|
||||
run.worker_thread = std::thread{&OxenMQ::worker_thread, this, run.worker_id, std::nullopt, nullptr};
|
||||
else
|
||||
send_routed_message(workers_socket, run.worker_routing_id, "RUN");
|
||||
}
|
||||
|
||||
void OxenMQ::proxy_to_worker(int64_t conn_id, zmq::socket_t& sock, std::vector<zmq::message_t>& parts) {
|
||||
bool outgoing = sock.get(zmq::sockopt::type) == ZMQ_DEALER;
|
||||
|
||||
peer_info tmp_peer;
|
||||
tmp_peer.conn_id = conn_id;
|
||||
if (!outgoing) tmp_peer.route = parts[0].to_string();
|
||||
peer_info* peer = nullptr;
|
||||
if (outgoing) {
|
||||
auto snit = outgoing_sn_conns.find(conn_id);
|
||||
auto it = snit != outgoing_sn_conns.end()
|
||||
? peers.find(snit->second)
|
||||
: peers.find(conn_id);
|
||||
|
||||
if (it == peers.end()) {
|
||||
OMQ_LOG(warn, "Internal error: connection id ", conn_id, " not found");
|
||||
return;
|
||||
}
|
||||
peer = &it->second;
|
||||
} else if (conn_id == inproc_listener_connid) {
|
||||
tmp_peer.auth_level = AuthLevel::admin;
|
||||
tmp_peer.pubkey = pubkey;
|
||||
tmp_peer.service_node = active_service_nodes.count(pubkey);
|
||||
peer = &tmp_peer;
|
||||
} else {
|
||||
std::tie(tmp_peer.pubkey, tmp_peer.auth_level) = detail::extract_metadata(parts.back());
|
||||
tmp_peer.service_node = tmp_peer.pubkey.size() == 32 && active_service_nodes.count(tmp_peer.pubkey);
|
||||
|
||||
if (tmp_peer.service_node) {
|
||||
// It's a service node so we should have a peer_info entry; see if we can find one with
|
||||
// the same route, and if not, add one.
|
||||
auto pr = peers.equal_range(tmp_peer.pubkey);
|
||||
for (auto it = pr.first; it != pr.second; ++it) {
|
||||
if (it->second.conn_id == tmp_peer.conn_id && it->second.route == tmp_peer.route) {
|
||||
peer = &it->second;
|
||||
// Update the stored auth level just in case the peer reconnected
|
||||
peer->auth_level = tmp_peer.auth_level;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!peer) {
|
||||
// We don't have a record: this is either a new SN connection or a new message on a
|
||||
// connection that recently gained SN status.
|
||||
peer = &peers.emplace(ConnectionID{tmp_peer.pubkey}, std::move(tmp_peer))->second;
|
||||
}
|
||||
} else {
|
||||
// Incoming, non-SN connection: we don't store a peer_info for this, so just use the
|
||||
// temporary one
|
||||
peer = &tmp_peer;
|
||||
}
|
||||
}
|
||||
|
||||
size_t command_part_index = outgoing ? 0 : 1;
|
||||
std::string command = parts[command_part_index].to_string();
|
||||
|
||||
// Steal any data message parts
|
||||
size_t data_part_index = command_part_index + 1;
|
||||
std::vector<zmq::message_t> data_parts;
|
||||
data_parts.reserve(parts.size() - data_part_index);
|
||||
for (auto it = parts.begin() + data_part_index; it != parts.end(); ++it)
|
||||
data_parts.push_back(std::move(*it));
|
||||
|
||||
auto cat_call = get_command(command);
|
||||
|
||||
// Check that command is valid, that we have permission, etc.
|
||||
if (!proxy_check_auth(conn_id, outgoing, *peer, parts[command_part_index], cat_call, data_parts))
|
||||
return;
|
||||
|
||||
auto& category = *cat_call.first;
|
||||
Access access{peer->auth_level, peer->service_node, local_service_node};
|
||||
|
||||
if (category.active_threads >= category.reserved_threads && active_workers() >= general_workers) {
|
||||
// No free reserved or general spots, try to queue it for later
|
||||
if (category.max_queue >= 0 && category.queued >= category.max_queue) {
|
||||
OMQ_LOG(warn, "No space to queue incoming command ", command, "; already have ", category.queued,
|
||||
"commands queued in that category (max ", category.max_queue, "); dropping message");
|
||||
return;
|
||||
}
|
||||
|
||||
OMQ_LOG(debug, "No available free workers, queuing ", command, " for later");
|
||||
ConnectionID conn{peer->service_node ? ConnectionID::SN_ID : conn_id, peer->pubkey, std::move(tmp_peer.route)};
|
||||
pending_commands.emplace_back(category, std::move(command), std::move(data_parts), cat_call.second,
|
||||
std::move(conn), std::move(access), peer_address(parts[command_part_index]));
|
||||
category.queued++;
|
||||
return;
|
||||
}
|
||||
|
||||
if (cat_call.second->second /*is_request*/ && data_parts.empty()) {
|
||||
OMQ_LOG(warn, "Received an invalid request command with no reply tag; dropping message");
|
||||
return;
|
||||
}
|
||||
|
||||
auto& run = get_idle_worker();
|
||||
{
|
||||
ConnectionID c{peer->service_node ? ConnectionID::SN_ID : conn_id, peer->pubkey};
|
||||
c.route = std::move(tmp_peer.route);
|
||||
if (outgoing || peer->service_node)
|
||||
tmp_peer.route.clear();
|
||||
run.load(&category, std::move(command), std::move(c), std::move(access), peer_address(parts[command_part_index]),
|
||||
std::move(data_parts), cat_call.second);
|
||||
}
|
||||
|
||||
if (outgoing)
|
||||
peer->activity(); // outgoing connection activity, pump the activity timer
|
||||
|
||||
OMQ_TRACE("Forwarding incoming ", run.command, " from ", run.conn, " @ ", peer_address(parts[command_part_index]),
|
||||
" to worker ", run.worker_routing_name);
|
||||
|
||||
proxy_run_worker(run);
|
||||
category.active_threads++;
|
||||
}
|
||||
|
||||
void OxenMQ::inject_task(const std::string& category, std::string command, std::string remote, std::function<void()> callback) {
|
||||
if (!callback) return;
|
||||
auto it = categories.find(category);
|
||||
if (it == categories.end())
|
||||
throw std::out_of_range{"Invalid category `" + category + "': category does not exist"};
|
||||
detail::send_control(get_control_socket(), "INJECT", oxenc::bt_serialize(detail::serialize_object(
|
||||
injected_task{it->second, std::move(command), std::move(remote), std::move(callback)})));
|
||||
}
|
||||
|
||||
void OxenMQ::proxy_inject_task(injected_task task) {
|
||||
auto& category = task.cat;
|
||||
if (category.active_threads >= category.reserved_threads && active_workers() >= general_workers) {
|
||||
// No free worker slot, queue for later
|
||||
if (category.max_queue >= 0 && category.queued >= category.max_queue) {
|
||||
OMQ_LOG(warn, "No space to queue injected task ", task.command, "; already have ", category.queued,
|
||||
"commands queued in that category (max ", category.max_queue, "); dropping task");
|
||||
return;
|
||||
}
|
||||
OMQ_LOG(debug, "No available free workers for injected task ", task.command, "; queuing for later");
|
||||
pending_commands.emplace_back(category, std::move(task.command), std::move(task.callback), std::move(task.remote));
|
||||
category.queued++;
|
||||
return;
|
||||
}
|
||||
|
||||
auto& run = get_idle_worker();
|
||||
OMQ_TRACE("Forwarding incoming injected task ", task.command, " from ", task.remote, " to worker ", run.worker_routing_name);
|
||||
run.load(&category, std::move(task.command), std::move(task.remote), std::move(task.callback));
|
||||
|
||||
proxy_run_worker(run);
|
||||
category.active_threads++;
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
|
@ -1,25 +1,27 @@
|
|||
|
||||
add_subdirectory(Catch2)
|
||||
|
||||
set(LMQ_TEST_SRC
|
||||
add_executable(tests
|
||||
main.cpp
|
||||
test_address.cpp
|
||||
test_batch.cpp
|
||||
test_connect.cpp
|
||||
test_commands.cpp
|
||||
test_failures.cpp
|
||||
test_inject.cpp
|
||||
test_pubsub.cpp
|
||||
test_requests.cpp
|
||||
test_string_view.cpp
|
||||
)
|
||||
|
||||
add_executable(tests ${LMQ_TEST_SRC})
|
||||
test_socket_limit.cpp
|
||||
test_tagged_threads.cpp
|
||||
test_timer.cpp
|
||||
)
|
||||
|
||||
find_package(Threads)
|
||||
find_package(PkgConfig REQUIRED)
|
||||
pkg_check_modules(SODIUM REQUIRED libsodium)
|
||||
|
||||
target_link_libraries(tests Catch2::Catch2 lokimq ${SODIUM_LIBRARIES} Threads::Threads)
|
||||
target_link_libraries(tests Catch2::Catch2 oxenmq Threads::Threads)
|
||||
|
||||
set_target_properties(tests PROPERTIES
|
||||
CXX_STANDARD 14
|
||||
CXX_STANDARD 17
|
||||
CXX_STANDARD_REQUIRED ON
|
||||
CXX_EXTENSIONS OFF
|
||||
)
|
||||
|
|
|
@ -1 +1 @@
|
|||
Subproject commit b3b07215d1ca2224aea6ff3e21d87ad0f7750df2
|
||||
Subproject commit dba29b60d639bf8d206a9a12c223e6ed4284fb13
|
|
@ -1,18 +1,61 @@
|
|||
#pragma once
|
||||
#include "lokimq/lokimq.h"
|
||||
#include "oxenmq/oxenmq.h"
|
||||
#include <catch2/catch.hpp>
|
||||
#include <chrono>
|
||||
|
||||
using namespace lokimq;
|
||||
using namespace oxenmq;
|
||||
|
||||
// Apple's mutexes, thread scheduling, and IO handling are garbage and it shows up with lots of
|
||||
// spurious failures in this test suite (because it expects a system to not suck that badly), so we
|
||||
// multiply the time-sensitive bits by this factor as a hack to make the test suite work.
|
||||
constexpr int TIME_DILATION =
|
||||
#ifdef __APPLE__
|
||||
5;
|
||||
#else
|
||||
1;
|
||||
#endif
|
||||
|
||||
static auto startup = std::chrono::steady_clock::now();
|
||||
|
||||
/// Returns a localhost connection string to listen on. It can be considered random, though in
|
||||
/// practice in the current implementation is sequential starting at 25432.
|
||||
inline std::string random_localhost() {
|
||||
static std::atomic<uint16_t> last = 25432;
|
||||
last++;
|
||||
assert(last); // We should never call this enough to overflow
|
||||
return "tcp://127.0.0.1:" + std::to_string(last);
|
||||
}
|
||||
|
||||
|
||||
// Catch2 macros aren't thread safe, so guard with a mutex
|
||||
inline std::unique_lock<std::mutex> catch_lock() {
|
||||
static std::mutex mutex;
|
||||
return std::unique_lock<std::mutex>{mutex};
|
||||
}
|
||||
|
||||
inline LokiMQ::Logger get_logger(std::string prefix = "") {
|
||||
/// Waits up to 200ms for something to happen.
|
||||
template <typename Func>
|
||||
inline void wait_for(Func f) {
|
||||
auto start = std::chrono::steady_clock::now();
|
||||
for (int i = 0; i < 20; i++) {
|
||||
if (f())
|
||||
break;
|
||||
std::this_thread::sleep_for(10ms * TIME_DILATION);
|
||||
}
|
||||
auto lock = catch_lock();
|
||||
UNSCOPED_INFO("done waiting after " << (std::chrono::steady_clock::now() - start).count() << "ns");
|
||||
}
|
||||
|
||||
/// Waits on an atomic bool for up to 100ms for an initial connection, which is more than enough
|
||||
/// time for an initial connection + request.
|
||||
inline void wait_for_conn(std::atomic<bool> &c) {
|
||||
wait_for([&c] { return c.load(); });
|
||||
}
|
||||
|
||||
/// Waits enough time for us to receive a reply from a localhost remote.
|
||||
inline void reply_sleep() { std::this_thread::sleep_for(10ms * TIME_DILATION); }
|
||||
|
||||
inline OxenMQ::Logger get_logger(std::string prefix = "") {
|
||||
std::string me = "tests/common.h";
|
||||
std::string strip = __FILE__;
|
||||
if (strip.substr(strip.size() - me.size()) == me)
|
||||
|
|
162
tests/test_address.cpp
Normal file
162
tests/test_address.cpp
Normal file
|
@ -0,0 +1,162 @@
|
|||
#include "oxenmq/address.h"
|
||||
#include "common.h"
|
||||
|
||||
const std::string pk = "\xf1\x6b\xa5\x59\x10\x39\xf0\x89\xb4\x2a\x83\x41\x75\x09\x30\x94\x07\x4d\x0d\x93\x7a\x79\xe5\x3e\x5c\xe7\x30\xf9\x46\xe1\x4b\x88";
|
||||
const std::string pk_hex = "f16ba5591039f089b42a834175093094074d0d937a79e53e5ce730f946e14b88";
|
||||
const std::string pk_HEX = "F16BA5591039F089B42A834175093094074D0D937A79E53E5CE730F946E14B88";
|
||||
const std::string pk_b32z = "6fi4kseo88aeupbkopyzknjo1odw4dcuxjh6kx1hhhax1tzbjqry";
|
||||
const std::string pk_B32Z = "6FI4KSEO88AEUPBKOPYZKNJO1ODW4DCUXJH6KX1HHHAX1TZBJQRY";
|
||||
const std::string pk_b64 = "8WulWRA58Im0KoNBdQkwlAdNDZN6eeU+XOcw+UbhS4g"; // NB: padding '=' omitted
|
||||
|
||||
TEST_CASE("tcp addresses", "[address][tcp]") {
|
||||
address a{"tcp://1.2.3.4:5678"};
|
||||
REQUIRE( a.host == "1.2.3.4" );
|
||||
REQUIRE( a.port == 5678 );
|
||||
REQUIRE_FALSE( a.curve() );
|
||||
REQUIRE( a.tcp() );
|
||||
REQUIRE( a.zmq_address() == "tcp://1.2.3.4:5678" );
|
||||
REQUIRE( a.full_address() == "tcp://1.2.3.4:5678" );
|
||||
REQUIRE( a.qr_address() == "TCP://1.2.3.4:5678" );
|
||||
|
||||
REQUIRE_THROWS_AS( address{"tcp://1:1:1"}, std::invalid_argument );
|
||||
REQUIRE_THROWS_AS( address{"tcpz://localhost:123"}, std::invalid_argument );
|
||||
REQUIRE_THROWS_AS( address{"tcp://abc"}, std::invalid_argument );
|
||||
REQUIRE_THROWS_AS( address{"tcpz://localhost:0"}, std::invalid_argument );
|
||||
REQUIRE_THROWS_AS( address{"tcpz://[::1:1080"}, std::invalid_argument );
|
||||
|
||||
address b = address::tcp("example.com", 80);
|
||||
REQUIRE( b.host == "example.com" );
|
||||
REQUIRE( b.port == 80 );
|
||||
REQUIRE_FALSE( b.curve() );
|
||||
REQUIRE( b.tcp() );
|
||||
REQUIRE( b.zmq_address() == "tcp://example.com:80" );
|
||||
REQUIRE( b.full_address() == "tcp://example.com:80" );
|
||||
REQUIRE( b.qr_address() == "TCP://EXAMPLE.COM:80" );
|
||||
|
||||
address c{"tcp://[::1]:1111"};
|
||||
REQUIRE( c.host == "[::1]" );
|
||||
REQUIRE( c.port == 1111 );
|
||||
}
|
||||
|
||||
TEST_CASE("unix sockets", "[address][ipc]") {
|
||||
address a{"ipc:///path/to/foo"};
|
||||
REQUIRE( a.socket == "/path/to/foo" );
|
||||
REQUIRE_FALSE( a.curve() );
|
||||
REQUIRE_FALSE( a.tcp() );
|
||||
REQUIRE( a.zmq_address() == "ipc:///path/to/foo" );
|
||||
REQUIRE( a.full_address() == "ipc:///path/to/foo" );
|
||||
|
||||
address b = address::ipc("../foo");
|
||||
REQUIRE( b.socket == "../foo" );
|
||||
REQUIRE_FALSE( b.curve() );
|
||||
REQUIRE_FALSE( b.tcp() );
|
||||
REQUIRE( b.zmq_address() == "ipc://../foo" );
|
||||
REQUIRE( b.full_address() == "ipc://../foo" );
|
||||
}
|
||||
|
||||
TEST_CASE("pubkey formats", "[address][curve][pubkey]") {
|
||||
address a{"tcp+curve://a:1/" + pk_hex};
|
||||
address b{"curve://a:1/" + pk_b32z};
|
||||
address c{"curve://a:1/" + pk_b64};
|
||||
address d{"CURVE://A:1/" + pk_B32Z};
|
||||
REQUIRE( a.curve() );
|
||||
REQUIRE( a.host == "a" );
|
||||
REQUIRE( a.port == 1 );
|
||||
REQUIRE((b.curve() && c.curve() && d.curve()));
|
||||
REQUIRE( a.pubkey == pk );
|
||||
REQUIRE( b.pubkey == pk );
|
||||
REQUIRE( c.pubkey == pk );
|
||||
REQUIRE( d.pubkey == pk );
|
||||
|
||||
address e{"ipc+curve://my.sock/" + pk_hex};
|
||||
address f{"ipc+curve://../my.sock/" + pk_b32z};
|
||||
address g{"ipc+curve:///my.sock/" + pk_B32Z};
|
||||
address h{"ipc+curve://./my.sock/" + pk_b64};
|
||||
REQUIRE( e.curve() );
|
||||
REQUIRE( e.ipc() );
|
||||
REQUIRE_FALSE( e.tcp() );
|
||||
REQUIRE((f.curve() && g.curve() && h.curve()));
|
||||
REQUIRE( e.socket == "my.sock" );
|
||||
REQUIRE( f.socket == "../my.sock" );
|
||||
REQUIRE( g.socket == "/my.sock" );
|
||||
REQUIRE( h.socket == "./my.sock" );
|
||||
REQUIRE( e.pubkey == pk );
|
||||
REQUIRE( f.pubkey == pk );
|
||||
REQUIRE( g.pubkey == pk );
|
||||
REQUIRE( h.pubkey == pk );
|
||||
|
||||
REQUIRE( d.full_address(address::encoding::hex) == "curve://a:1/" + pk_hex );
|
||||
REQUIRE( c.full_address(address::encoding::base32z) == "curve://a:1/" + pk_b32z );
|
||||
REQUIRE( b.full_address(address::encoding::BASE32Z) == "curve://a:1/" + pk_B32Z );
|
||||
REQUIRE( a.full_address(address::encoding::base64) == "curve://a:1/" + pk_b64 );
|
||||
|
||||
REQUIRE( h.full_address(address::encoding::hex) == "ipc+curve://./my.sock/" + pk_hex );
|
||||
REQUIRE( g.full_address(address::encoding::base32z) == "ipc+curve:///my.sock/" + pk_b32z );
|
||||
REQUIRE( f.full_address(address::encoding::BASE32Z) == "ipc+curve://../my.sock/" + pk_B32Z );
|
||||
REQUIRE( e.full_address(address::encoding::base64) == "ipc+curve://my.sock/" + pk_b64 );
|
||||
|
||||
REQUIRE_THROWS_AS(address{"ipc+curve://my.sock/" + pk_hex.substr(0, 63)}, std::invalid_argument);
|
||||
REQUIRE_THROWS_AS(address{"ipc+curve://my.sock/" + pk_b32z.substr(0, 51)}, std::invalid_argument);
|
||||
REQUIRE_THROWS_AS(address{"ipc+curve://my.sock/" + pk_B32Z.substr(0, 51)}, std::invalid_argument);
|
||||
REQUIRE_THROWS_AS(address{"ipc+curve://my.sock/" + pk_b64.substr(0, 42)}, std::invalid_argument);
|
||||
REQUIRE_THROWS_AS(address{"ipc+curve://my.sock"}, std::invalid_argument);
|
||||
REQUIRE_THROWS_AS(address{"ipc+curve://my.sock/"}, std::invalid_argument);
|
||||
}
|
||||
|
||||
|
||||
TEST_CASE("tcp QR-code friendly addresses", "[address][tcp][qr]") {
|
||||
address a{"tcp://public.loki.foundation:12345"};
|
||||
address a_qr{"TCP://PUBLIC.LOKI.FOUNDATION:12345"};
|
||||
address b{"tcp://PUBLIC.LOKI.FOUNDATION:12345"};
|
||||
REQUIRE( a == a_qr );
|
||||
REQUIRE( a != b );
|
||||
REQUIRE( a.host == "public.loki.foundation" );
|
||||
REQUIRE( a.qr_address() == "TCP://PUBLIC.LOKI.FOUNDATION:12345" );
|
||||
|
||||
address c = address::tcp_curve("public.loki.foundation", 12345, pk);
|
||||
REQUIRE( c.qr_address() == "CURVE://PUBLIC.LOKI.FOUNDATION:12345/" + pk_B32Z );
|
||||
REQUIRE( address{"CURVE://PUBLIC.LOKI.FOUNDATION:12345/" + pk_B32Z} == c );
|
||||
// We don't produce with upper-case hex, but we accept it:
|
||||
REQUIRE( address{"CURVE://PUBLIC.LOKI.FOUNDATION:12345/" + pk_HEX} == c );
|
||||
|
||||
// lower case not permitted: ▾
|
||||
REQUIRE_THROWS_AS(address{"CURVE://PUBLIC.LOKI.FOUNDATiON:12345/" + pk_B32Z}, std::invalid_argument);
|
||||
// also only accept upper-base base32z and hex:
|
||||
REQUIRE_THROWS_AS(address{"CURVE://PUBLIC.LOKI.FOUNDATION:12345/" + pk_b32z}, std::invalid_argument);
|
||||
REQUIRE_THROWS_AS(address{"CURVE://PUBLIC.LOKI.FOUNDATION:12345/" + pk_hex}, std::invalid_argument);
|
||||
// don't accept base64 even if it's upper-case (because case-converting it changes the value)
|
||||
REQUIRE_THROWS_AS(address{"CURVE://PUBLIC.LOKI.FOUNDATION:12345/AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"}, std::invalid_argument);
|
||||
REQUIRE_THROWS_AS(address{"CURVE://PUBLIC.LOKI.FOUNDATION:12345/AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA="}, std::invalid_argument);
|
||||
}
|
||||
|
||||
TEST_CASE("address hashing", "[address][hash]") {
|
||||
address a{"tcp://public.loki.foundation:12345"};
|
||||
address b{"tcp+curve://public.loki.foundation:12345/" + pk_hex};
|
||||
address c{"ipc:///tmp/some.sock"};
|
||||
address d{"ipc:///tmp/some.other.sock"};
|
||||
|
||||
std::hash<oxenmq::address> hasher{};
|
||||
REQUIRE( hasher(a) != hasher(b) );
|
||||
REQUIRE( hasher(a) != hasher(c) );
|
||||
REQUIRE( hasher(a) != hasher(d) );
|
||||
REQUIRE( hasher(b) != hasher(c) );
|
||||
REQUIRE( hasher(b) != hasher(d) );
|
||||
REQUIRE( hasher(c) != hasher(d) );
|
||||
|
||||
std::unordered_set<oxenmq::address> set;
|
||||
set.insert(a);
|
||||
set.insert(b);
|
||||
set.insert(c);
|
||||
set.insert(d);
|
||||
|
||||
CHECK( set.size() == 4 );
|
||||
std::unordered_map<oxenmq::address, int> count;
|
||||
for (const auto& addr : set)
|
||||
count[addr]++;
|
||||
|
||||
REQUIRE( count.size() == 4 );
|
||||
CHECK( count[a] == 1 );
|
||||
CHECK( count[b] == 1 );
|
||||
CHECK( count[c] == 1 );
|
||||
CHECK( count[d] == 1 );
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
#include "lokimq/batch.h"
|
||||
#include "oxenmq/batch.h"
|
||||
#include "common.h"
|
||||
#include <future>
|
||||
|
||||
|
@ -12,7 +12,7 @@ double do_my_task(int input) {
|
|||
|
||||
std::promise<std::pair<double, int>> done;
|
||||
|
||||
void continue_big_task(std::vector<lokimq::job_result<double>> results) {
|
||||
void continue_big_task(std::vector<oxenmq::job_result<double>> results) {
|
||||
double sum = 0;
|
||||
int exc_count = 0;
|
||||
for (auto& r : results) {
|
||||
|
@ -25,10 +25,10 @@ void continue_big_task(std::vector<lokimq::job_result<double>> results) {
|
|||
done.set_value({sum, exc_count});
|
||||
}
|
||||
|
||||
void start_big_task(lokimq::LokiMQ& lmq) {
|
||||
void start_big_task(oxenmq::OxenMQ& omq) {
|
||||
size_t num_jobs = 32;
|
||||
|
||||
lokimq::Batch<double /*return type*/> batch;
|
||||
oxenmq::Batch<double /*return type*/> batch;
|
||||
batch.reserve(num_jobs);
|
||||
|
||||
for (size_t i = 0; i < num_jobs; i++)
|
||||
|
@ -36,21 +36,21 @@ void start_big_task(lokimq::LokiMQ& lmq) {
|
|||
|
||||
batch.completion(&continue_big_task);
|
||||
|
||||
lmq.batch(std::move(batch));
|
||||
omq.batch(std::move(batch));
|
||||
}
|
||||
|
||||
|
||||
TEST_CASE("batching many small jobs", "[batch-many]") {
|
||||
lokimq::LokiMQ lmq{
|
||||
oxenmq::OxenMQ omq{
|
||||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
};
|
||||
lmq.set_general_threads(4);
|
||||
lmq.set_batch_threads(4);
|
||||
lmq.start();
|
||||
omq.set_general_threads(4);
|
||||
omq.set_batch_threads(4);
|
||||
omq.start();
|
||||
|
||||
start_big_task(lmq);
|
||||
start_big_task(omq);
|
||||
auto sum = done.get_future().get();
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( sum.first == 1337.0 );
|
||||
|
@ -58,14 +58,14 @@ TEST_CASE("batching many small jobs", "[batch-many]") {
|
|||
}
|
||||
|
||||
TEST_CASE("batch exception propagation", "[batch-exceptions]") {
|
||||
lokimq::LokiMQ lmq{
|
||||
oxenmq::OxenMQ omq{
|
||||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
};
|
||||
lmq.set_general_threads(4);
|
||||
lmq.set_batch_threads(4);
|
||||
lmq.start();
|
||||
omq.set_general_threads(4);
|
||||
omq.set_batch_threads(4);
|
||||
omq.start();
|
||||
|
||||
std::promise<void> done_promise;
|
||||
std::future<void> done_future = done_promise.get_future();
|
||||
|
@ -73,7 +73,7 @@ TEST_CASE("batch exception propagation", "[batch-exceptions]") {
|
|||
using Catch::Matchers::Message;
|
||||
|
||||
SECTION( "value return" ) {
|
||||
lokimq::Batch<int> batch;
|
||||
oxenmq::Batch<int> batch;
|
||||
for (int i : {1, 2})
|
||||
batch.add_job([i]() { if (i == 1) return 42; throw std::domain_error("bad value " + std::to_string(i)); });
|
||||
batch.completion([&done_promise](auto results) {
|
||||
|
@ -83,12 +83,12 @@ TEST_CASE("batch exception propagation", "[batch-exceptions]") {
|
|||
REQUIRE_THROWS_MATCHES( results[1].get() == 0, std::domain_error, Message("bad value 2") );
|
||||
done_promise.set_value();
|
||||
});
|
||||
lmq.batch(std::move(batch));
|
||||
omq.batch(std::move(batch));
|
||||
done_future.get();
|
||||
}
|
||||
|
||||
SECTION( "lvalue return" ) {
|
||||
lokimq::Batch<int&> batch;
|
||||
oxenmq::Batch<int&> batch;
|
||||
int forty_two = 42;
|
||||
for (int i : {1, 2})
|
||||
batch.add_job([i,&forty_two]() -> int& {
|
||||
|
@ -105,12 +105,12 @@ TEST_CASE("batch exception propagation", "[batch-exceptions]") {
|
|||
REQUIRE_THROWS_MATCHES( results[1].get(), std::domain_error, Message("bad value 2") );
|
||||
done_promise.set_value();
|
||||
});
|
||||
lmq.batch(std::move(batch));
|
||||
omq.batch(std::move(batch));
|
||||
done_future.get();
|
||||
}
|
||||
|
||||
SECTION( "void return" ) {
|
||||
lokimq::Batch<void> batch;
|
||||
oxenmq::Batch<void> batch;
|
||||
for (int i : {1, 2})
|
||||
batch.add_job([i]() { if (i != 1) throw std::domain_error("bad value " + std::to_string(i)); });
|
||||
batch.completion([&done_promise](auto results) {
|
||||
|
@ -120,7 +120,7 @@ TEST_CASE("batch exception propagation", "[batch-exceptions]") {
|
|||
REQUIRE_THROWS_MATCHES( results[1].get(), std::domain_error, Message("bad value 2") );
|
||||
done_promise.set_value();
|
||||
});
|
||||
lmq.batch(std::move(batch));
|
||||
omq.batch(std::move(batch));
|
||||
done_future.get();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,21 +1,20 @@
|
|||
#include "common.h"
|
||||
#include <future>
|
||||
#include <lokimq/hex.h>
|
||||
#include <oxenc/hex.h>
|
||||
#include <map>
|
||||
#include <set>
|
||||
|
||||
using namespace lokimq;
|
||||
using namespace oxenmq;
|
||||
|
||||
TEST_CASE("basic commands", "[commands]") {
|
||||
std::string listen = "tcp://127.0.0.1:4567";
|
||||
LokiMQ server{
|
||||
std::string listen = random_localhost();
|
||||
OxenMQ server{
|
||||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
get_logger("S» ")
|
||||
get_logger("S» "),
|
||||
LogLevel::trace
|
||||
};
|
||||
server.log_level(LogLevel::trace);
|
||||
server.listen_curve(listen, [](auto /*ip*/, auto /*pk*/) { return Allow{AuthLevel::none, false}; });
|
||||
server.listen_curve(listen);
|
||||
|
||||
std::atomic<int> hellos{0}, his{0};
|
||||
|
||||
|
@ -32,52 +31,43 @@ TEST_CASE("basic commands", "[commands]") {
|
|||
|
||||
server.start();
|
||||
|
||||
LokiMQ client{
|
||||
get_logger("C» ")
|
||||
};
|
||||
client.log_level(LogLevel::trace);
|
||||
OxenMQ client{get_logger("C» "), LogLevel::trace};
|
||||
|
||||
client.add_category("public", Access{AuthLevel::none});
|
||||
client.add_command("public", "hi", [&](auto&) { his++; });
|
||||
client.start();
|
||||
|
||||
std::atomic<bool> connected{false}, failed{false};
|
||||
std::atomic<bool> got{false};
|
||||
bool success = false, failed = false;
|
||||
std::string pubkey;
|
||||
|
||||
auto c = client.connect_remote(listen,
|
||||
[&](auto conn) { pubkey = conn.pubkey(); connected = true; },
|
||||
[&](auto conn, string_view) { failed = true; },
|
||||
server.get_pubkey());
|
||||
auto c = client.connect_remote(address{listen, server.get_pubkey()},
|
||||
[&](auto conn) { pubkey = conn.pubkey(); success = true; got = true; },
|
||||
[&](auto conn, std::string_view) { failed = true; got = true; });
|
||||
|
||||
int i;
|
||||
for (i = 0; i < 5; i++) {
|
||||
if (connected.load())
|
||||
break;
|
||||
std::this_thread::sleep_for(50ms);
|
||||
}
|
||||
wait_for_conn(got);
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( connected.load() );
|
||||
REQUIRE( i <= 1 ); // should be fast
|
||||
REQUIRE( !failed.load() );
|
||||
REQUIRE( to_hex(pubkey) == to_hex(server.get_pubkey()) );
|
||||
REQUIRE( got );
|
||||
REQUIRE( success );
|
||||
REQUIRE_FALSE( failed );
|
||||
REQUIRE( oxenc::to_hex(pubkey) == oxenc::to_hex(server.get_pubkey()) );
|
||||
}
|
||||
|
||||
client.send(c, "public.hello");
|
||||
client.send(c, "public.client.pubkey");
|
||||
|
||||
std::this_thread::sleep_for(50ms);
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( hellos == 1 );
|
||||
REQUIRE( his == 1 );
|
||||
REQUIRE( to_hex(client_pubkey) == to_hex(client.get_pubkey()) );
|
||||
REQUIRE( oxenc::to_hex(client_pubkey) == oxenc::to_hex(client.get_pubkey()) );
|
||||
}
|
||||
|
||||
for (int i = 0; i < 50; i++)
|
||||
client.send(c, "public.hello");
|
||||
|
||||
std::this_thread::sleep_for(100ms);
|
||||
wait_for([&] { return his == 26; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( hellos == 51 );
|
||||
|
@ -86,15 +76,15 @@ TEST_CASE("basic commands", "[commands]") {
|
|||
}
|
||||
|
||||
TEST_CASE("outgoing auth level", "[commands][auth]") {
|
||||
std::string listen = "tcp://127.0.0.1:4567";
|
||||
LokiMQ server{
|
||||
std::string listen = random_localhost();
|
||||
OxenMQ server{
|
||||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
get_logger("S» ")
|
||||
get_logger("S» "),
|
||||
LogLevel::trace
|
||||
};
|
||||
server.log_level(LogLevel::trace);
|
||||
server.listen_curve(listen, [](auto /*ip*/, auto /*pk*/) { return Allow{AuthLevel::none, false}; });
|
||||
server.listen_curve(listen);
|
||||
|
||||
std::atomic<int> hellos{0};
|
||||
|
||||
|
@ -103,10 +93,7 @@ TEST_CASE("outgoing auth level", "[commands][auth]") {
|
|||
|
||||
server.start();
|
||||
|
||||
LokiMQ client{
|
||||
get_logger("C» ")
|
||||
};
|
||||
client.log_level(LogLevel::trace);
|
||||
OxenMQ client{get_logger("C» "), LogLevel::trace};
|
||||
|
||||
std::atomic<int> public_hi{0}, basic_hi{0}, admin_hi{0};
|
||||
client.add_category("public", Access{AuthLevel::none});
|
||||
|
@ -117,15 +104,15 @@ TEST_CASE("outgoing auth level", "[commands][auth]") {
|
|||
client.add_command("admin", "hi", [&](auto&) { admin_hi++; });
|
||||
client.start();
|
||||
|
||||
client.PUBKEY_BASED_ROUTING_ID = false; // establishing multiple connections below, so we need unique routing ids
|
||||
client.EPHEMERAL_ROUTING_ID = true; // establishing multiple connections below, so we need unique routing ids
|
||||
|
||||
auto public_c = client.connect_remote(listen, [](...) {}, [](...) {}, server.get_pubkey());
|
||||
auto basic_c = client.connect_remote(listen, [](...) {}, [](...) {}, server.get_pubkey(), AuthLevel::basic);
|
||||
auto admin_c = client.connect_remote(listen, [](...) {}, [](...) {}, server.get_pubkey(), AuthLevel::admin);
|
||||
address server_addr{listen, server.get_pubkey()};
|
||||
auto public_c = client.connect_remote(server_addr, [](auto&&...) {}, [](auto&&...) {});
|
||||
auto basic_c = client.connect_remote(server_addr, [](auto&&...) {}, [](auto&&...) {}, AuthLevel::basic);
|
||||
auto admin_c = client.connect_remote(server_addr, [](auto&&...) {}, [](auto&&...) {}, AuthLevel::admin);
|
||||
|
||||
client.send(public_c, "public.reflect", "public.hi");
|
||||
std::this_thread::sleep_for(50ms);
|
||||
|
||||
wait_for([&] { return public_hi == 1; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( public_hi == 1 );
|
||||
|
@ -138,7 +125,7 @@ TEST_CASE("outgoing auth level", "[commands][auth]") {
|
|||
client.send(admin_c, "public.reflect", "admin.hi");
|
||||
client.send(basic_c, "public.reflect", "basic.hi");
|
||||
|
||||
std::this_thread::sleep_for(50ms);
|
||||
wait_for([&] { return basic_hi == 2; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( admin_hi == 3 );
|
||||
|
@ -160,7 +147,7 @@ TEST_CASE("outgoing auth level", "[commands][auth]") {
|
|||
client.send(admin_c, "public.reflect", "basic.hi");
|
||||
client.send(admin_c, "public.reflect", "public.hi");
|
||||
|
||||
std::this_thread::sleep_for(50ms);
|
||||
wait_for([&] { return public_hi == 3; });
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( admin_hi == 1 );
|
||||
REQUIRE( basic_hi == 2 );
|
||||
|
@ -171,33 +158,41 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]")
|
|||
// Tests that the ConnectionID from a Message can be stored and reused later to contact the
|
||||
// original node.
|
||||
|
||||
std::string listen = "tcp://127.0.0.1:4567";
|
||||
LokiMQ server{
|
||||
std::string listen = random_localhost();
|
||||
OxenMQ server{
|
||||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
get_logger("S» ")
|
||||
get_logger("S» "),
|
||||
LogLevel::trace
|
||||
};
|
||||
server.log_level(LogLevel::trace);
|
||||
server.listen_curve(listen, [](auto /*ip*/, auto /*pk*/) { return Allow{AuthLevel::none, false}; });
|
||||
server.listen_curve(listen);
|
||||
|
||||
std::vector<std::pair<ConnectionID, std::string>> subscribers;
|
||||
ConnectionID backdoor;
|
||||
|
||||
server.add_category("hey google", Access{AuthLevel::none});
|
||||
server.add_request_command("hey google", "remember", [&](Message& m) {
|
||||
auto l = catch_lock();
|
||||
subscribers.emplace_back(m.conn, std::string{m.data[0]});
|
||||
bool bd;
|
||||
{
|
||||
auto l = catch_lock();
|
||||
subscribers.emplace_back(m.conn, std::string{m.data[0]});
|
||||
bd = (bool) backdoor;
|
||||
}
|
||||
m.send_reply("Okay, I'll remember that.");
|
||||
|
||||
if (backdoor)
|
||||
m.lokimq.send(backdoor, "backdoor.data", m.data[0]);
|
||||
if (bd)
|
||||
m.oxenmq.send(backdoor, "backdoor.data", m.data[0]);
|
||||
});
|
||||
server.add_command("hey google", "recall", [&](Message& m) {
|
||||
auto l = catch_lock();
|
||||
for (auto& s : subscribers) {
|
||||
server.send(s.first, "personal.detail", s.second);
|
||||
decltype(subscribers) subs;
|
||||
{
|
||||
auto l = catch_lock();
|
||||
subs = subscribers;
|
||||
}
|
||||
|
||||
for (auto& s : subs)
|
||||
server.send(s.first, "personal.detail", s.second);
|
||||
});
|
||||
server.add_command("hey google", "install backdoor", [&](Message& m) {
|
||||
auto l = catch_lock();
|
||||
|
@ -206,29 +201,29 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]")
|
|||
|
||||
server.start();
|
||||
|
||||
auto connect_success = [&](...) { auto l = catch_lock(); REQUIRE(true); };
|
||||
auto connect_failure = [&](...) { auto l = catch_lock(); REQUIRE(false); };
|
||||
auto connect_success = [&](auto&&...) { auto l = catch_lock(); REQUIRE(true); };
|
||||
auto connect_failure = [&](auto&&...) { auto l = catch_lock(); REQUIRE(false); };
|
||||
|
||||
|
||||
std::set<std::string> backdoor_details;
|
||||
|
||||
LokiMQ nsa{get_logger("NSA» ")};
|
||||
OxenMQ nsa{get_logger("NSA» ")};
|
||||
nsa.add_category("backdoor", Access{AuthLevel::admin});
|
||||
nsa.add_command("backdoor", "data", [&](Message& m) {
|
||||
backdoor_details.emplace(m.data[0]);
|
||||
auto l = catch_lock();
|
||||
backdoor_details.emplace(m.data[0]);
|
||||
});
|
||||
nsa.start();
|
||||
auto nsa_c = nsa.connect_remote(listen, connect_success, connect_failure, server.get_pubkey(), AuthLevel::admin);
|
||||
auto nsa_c = nsa.connect_remote(address{listen, server.get_pubkey()}, connect_success, connect_failure, AuthLevel::admin);
|
||||
nsa.send(nsa_c, "hey google.install backdoor");
|
||||
|
||||
std::this_thread::sleep_for(50ms);
|
||||
|
||||
wait_for([&] { auto lock = catch_lock(); return (bool) backdoor; });
|
||||
{
|
||||
auto l = catch_lock();
|
||||
REQUIRE( backdoor );
|
||||
}
|
||||
|
||||
std::vector<std::unique_ptr<LokiMQ>> clients;
|
||||
std::vector<std::unique_ptr<OxenMQ>> clients;
|
||||
std::vector<ConnectionID> conns;
|
||||
std::map<int, std::set<std::string>> personal_details{
|
||||
{0, {"Loretta"s, "photos"s}},
|
||||
|
@ -240,12 +235,14 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]")
|
|||
std::set<std::string> all_the_things;
|
||||
for (auto& pd : personal_details) all_the_things.insert(pd.second.begin(), pd.second.end());
|
||||
|
||||
address server_addr{listen, server.get_pubkey()};
|
||||
std::map<int, std::set<std::string>> google_knows;
|
||||
int things_remembered{0};
|
||||
for (int i = 0; i < 5; i++) {
|
||||
clients.push_back(std::make_unique<LokiMQ>(get_logger("C" + std::to_string(i) + "» ")));
|
||||
clients.push_back(std::make_unique<OxenMQ>(
|
||||
get_logger("C" + std::to_string(i) + "» "), LogLevel::trace
|
||||
));
|
||||
auto& c = clients.back();
|
||||
c->log_level(LogLevel::trace);
|
||||
c->add_category("personal", Access{AuthLevel::basic});
|
||||
c->add_command("personal", "detail", [&,i](Message& m) {
|
||||
auto l = catch_lock();
|
||||
|
@ -253,7 +250,7 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]")
|
|||
});
|
||||
c->start();
|
||||
conns.push_back(
|
||||
c->connect_remote(listen, connect_success, connect_failure, server.get_pubkey(), AuthLevel::basic));
|
||||
c->connect_remote(server_addr, connect_success, connect_failure, AuthLevel::basic));
|
||||
for (auto& personal_detail : personal_details[i])
|
||||
c->request(conns.back(), "hey google.remember",
|
||||
[&](bool success, std::vector<std::string> data) {
|
||||
|
@ -265,7 +262,7 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]")
|
|||
},
|
||||
personal_detail);
|
||||
}
|
||||
std::this_thread::sleep_for(50ms);
|
||||
wait_for([&] { auto lock = catch_lock(); return things_remembered == all_the_things.size() && things_remembered == backdoor_details.size(); });
|
||||
{
|
||||
auto l = catch_lock();
|
||||
REQUIRE( things_remembered == all_the_things.size() );
|
||||
|
@ -273,9 +270,256 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]")
|
|||
}
|
||||
|
||||
clients[0]->send(conns[0], "hey google.recall");
|
||||
std::this_thread::sleep_for(50ms);
|
||||
reply_sleep();
|
||||
{
|
||||
auto l = catch_lock();
|
||||
REQUIRE( google_knows == personal_details );
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("send failure callbacks", "[commands][queue_full]") {
|
||||
std::string listen = random_localhost();
|
||||
OxenMQ server{
|
||||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
get_logger("S» "),
|
||||
LogLevel::debug // This test traces so much that it takes 2.5-3s of CPU time at trace level, so don't do that.
|
||||
};
|
||||
server.listen_plain(listen);
|
||||
|
||||
std::atomic<int> send_attempts{0};
|
||||
std::atomic<int> send_failures{0};
|
||||
// ZMQ TCP sockets' HWM is complicated and OS dependent; sender and receiver (probably) each
|
||||
// have 1000 message queues, but there is also the TCP queue to worry about which means we can
|
||||
// have more queued before we fill up, so we send 4kiB of null with each message so that we
|
||||
// don't get too much TCP queuing.
|
||||
std::string junk(4096, '0');
|
||||
server.add_category("x", Access{AuthLevel::none})
|
||||
.add_command("x", [&](Message& m) {
|
||||
for (int x = 0; x < 500; x++) {
|
||||
++send_attempts;
|
||||
m.send_back("y.y", junk, send_option::queue_full{[&]() { ++send_failures; }});
|
||||
}
|
||||
});
|
||||
|
||||
server.start();
|
||||
|
||||
// Use a raw socket here because I want to stall it by not reading from it at all, and that is
|
||||
// hard with OxenMQ.
|
||||
zmq::context_t client_ctx;
|
||||
zmq::socket_t client{client_ctx, zmq::socket_type::dealer};
|
||||
client.connect(listen);
|
||||
// Handshake: we send HI, they reply HELLO.
|
||||
client.send(zmq::message_t{"HI", 2}, zmq::send_flags::none);
|
||||
zmq::message_t hello;
|
||||
auto recvd = client.recv(hello);
|
||||
std::string_view hello_sv{hello.data<char>(), hello.size()};
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( recvd );
|
||||
REQUIRE( hello_sv == "HELLO" );
|
||||
REQUIRE_FALSE( hello.more() );
|
||||
}
|
||||
|
||||
// Tell the remote to queue up a batch of messages
|
||||
client.send(zmq::message_t{"x.x", 3}, zmq::send_flags::none);
|
||||
|
||||
int i;
|
||||
for (i = 0; i < 20; i++) {
|
||||
if (send_attempts.load() >= 500)
|
||||
break;
|
||||
std::this_thread::sleep_for(10ms);
|
||||
}
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( i < 20 ); // should be not too slow
|
||||
// We have two buffers here: 1000 on the receiver, and 1000 on the client, which means we
|
||||
// should be able to get 2000 out before we hit HWM. We should only have been sent 501 so
|
||||
// far (the "HELLO" handshake + 500 "y.y" messages).
|
||||
REQUIRE( send_attempts.load() == 500 );
|
||||
REQUIRE( send_failures.load() == 0 );
|
||||
}
|
||||
|
||||
// Now we want to tell the server to send enough to fill the outgoing queue and start stalling.
|
||||
// This is complicated as it depends on ZMQ internals *and* OS-level TCP buffers, so we really
|
||||
// don't know precisely where this will start failing.
|
||||
//
|
||||
// In practice, I seem to reach HWM (for this test, with this amount of data being sent, on my
|
||||
// Debian desktop) after 2499 messages (that is, queuing 2500 gives 1 failure).
|
||||
int expected_attempts = 500;
|
||||
for (int i = 0; i < 10; i++) {
|
||||
client.send(zmq::message_t{"x.x", 3}, zmq::send_flags::none);
|
||||
expected_attempts += 500;
|
||||
if (i >= 4) {
|
||||
if (send_failures.load() > 0)
|
||||
break;
|
||||
std::this_thread::sleep_for(25ms);
|
||||
}
|
||||
}
|
||||
|
||||
for (i = 0; i < 100; i++) {
|
||||
if (send_attempts.load() >= expected_attempts)
|
||||
break;
|
||||
std::this_thread::sleep_for(10ms);
|
||||
}
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( i < 100 );
|
||||
REQUIRE( send_attempts.load() == expected_attempts );
|
||||
REQUIRE( send_failures.load() > 0 );
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("data parts", "[commands][send][data_parts]") {
|
||||
std::string listen = random_localhost();
|
||||
OxenMQ server{
|
||||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
get_logger("S» "),
|
||||
LogLevel::trace
|
||||
};
|
||||
server.listen_curve(listen);
|
||||
|
||||
std::mutex mut;
|
||||
std::vector<std::string> r;
|
||||
|
||||
server.add_category("public", Access{AuthLevel::none});
|
||||
server.add_command("public", "hello", [&](Message& m) {
|
||||
std::lock_guard l{mut};
|
||||
for (const auto& s : m.data)
|
||||
r.emplace_back(s);
|
||||
});
|
||||
server.start();
|
||||
|
||||
OxenMQ client{get_logger("C» "), LogLevel::trace};
|
||||
client.start();
|
||||
|
||||
std::atomic<bool> got{false};
|
||||
bool success = false, failed = false;
|
||||
std::string pubkey;
|
||||
|
||||
auto c = client.connect_remote(address{listen, server.get_pubkey()},
|
||||
[&](auto conn) { pubkey = conn.pubkey(); success = true; got = true; },
|
||||
[&](auto conn, std::string_view) { failed = true; got = true; });
|
||||
|
||||
wait_for_conn(got);
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( got );
|
||||
REQUIRE( success );
|
||||
REQUIRE_FALSE( failed );
|
||||
REQUIRE( oxenc::to_hex(pubkey) == oxenc::to_hex(server.get_pubkey()) );
|
||||
}
|
||||
|
||||
std::vector some_data{{"abc"s, "def"s, "omg123\0zzz"s}};
|
||||
client.send(c, "public.hello", oxenmq::send_option::data_parts(some_data.begin(), some_data.end()));
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
std::lock_guard l{mut};
|
||||
REQUIRE( r == some_data );
|
||||
r.clear();
|
||||
}
|
||||
|
||||
std::optional<std::string_view> opt1, opt2;
|
||||
std::optional<std::string> opt3, opt4;
|
||||
opt1 = "o1"sv;
|
||||
opt4 = "o4"s;
|
||||
std::vector some_data2{{"a"sv, "b"sv, "\0"sv}};
|
||||
client.send(c, "public.hello",
|
||||
"hi",
|
||||
oxenmq::send_option::data_parts(some_data2.begin(), some_data2.end()),
|
||||
"another",
|
||||
"string"sv,
|
||||
oxenmq::send_option::data_parts(some_data.begin(), some_data.end()),
|
||||
opt1, opt2, opt3, opt4
|
||||
);
|
||||
|
||||
std::vector<std::string> expected;
|
||||
expected.push_back("hi");
|
||||
expected.insert(expected.end(), some_data2.begin(), some_data2.end());
|
||||
expected.push_back("another");
|
||||
expected.push_back("string");
|
||||
expected.insert(expected.end(), some_data.begin(), some_data.end());
|
||||
expected.push_back("o1");
|
||||
expected.push_back("o4");
|
||||
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
std::lock_guard l{mut};
|
||||
REQUIRE( r == expected );
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("deferred replies", "[commands][send][deferred]") {
|
||||
std::string listen = random_localhost();
|
||||
OxenMQ server{
|
||||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
};
|
||||
server.listen_curve(listen);
|
||||
|
||||
std::atomic<int> hellos{0}, his{0};
|
||||
|
||||
server.add_category("public", Access{AuthLevel::none});
|
||||
server.add_request_command("public", "echo", [&](Message& m) {
|
||||
std::string msg = m.data.empty() ? ""s : std::string{m.data.front()};
|
||||
std::thread t{[send=m.send_later(), msg=std::move(msg)] {
|
||||
{ auto lock = catch_lock(); UNSCOPED_INFO("sleeping"); }
|
||||
std::this_thread::sleep_for(50ms * TIME_DILATION);
|
||||
{ auto lock = catch_lock(); UNSCOPED_INFO("sending"); }
|
||||
send(msg);
|
||||
}};
|
||||
t.detach();
|
||||
});
|
||||
server.set_general_threads(1);
|
||||
server.start();
|
||||
|
||||
OxenMQ client(
|
||||
get_logger("C» "),
|
||||
LogLevel::trace);
|
||||
//client.log_level(LogLevel::trace);
|
||||
|
||||
client.start();
|
||||
|
||||
std::atomic<bool> connected{false}, failed{false};
|
||||
std::string pubkey;
|
||||
|
||||
auto c = client.connect_remote(address{listen, server.get_pubkey()},
|
||||
[&](auto conn) { pubkey = conn.pubkey(); connected = true; },
|
||||
[&](auto, auto) { failed = true; });
|
||||
|
||||
wait_for([&] { return connected || failed; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( connected );
|
||||
REQUIRE_FALSE( failed );
|
||||
REQUIRE( oxenc::to_hex(pubkey) == oxenc::to_hex(server.get_pubkey()) );
|
||||
}
|
||||
|
||||
std::unordered_set<std::string> replies;
|
||||
std::mutex reply_mut;
|
||||
std::vector<std::string> data;
|
||||
for (auto str : {"hello", "world", "omg"})
|
||||
client.request(c, "public.echo", [&](bool ok, std::vector<std::string> data_) {
|
||||
std::lock_guard lock{reply_mut};
|
||||
replies.insert(std::move(data_[0]));
|
||||
}, str);
|
||||
|
||||
reply_sleep();
|
||||
{
|
||||
std::lock_guard lq{reply_mut};
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( replies.size() == 0 ); // The server waits 50ms before sending, so we shouldn't have any reply yet
|
||||
}
|
||||
std::this_thread::sleep_for(60ms * TIME_DILATION); // We're at least 70ms in now so the 50ms-delayed server responses should have arrived
|
||||
{
|
||||
std::lock_guard lq{reply_mut};
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( replies == std::unordered_set<std::string>{{"hello", "world", "omg"}} );
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,74 +1,75 @@
|
|||
#include "common.h"
|
||||
#include <future>
|
||||
#include <lokimq/hex.h>
|
||||
extern "C" {
|
||||
#include <sodium.h>
|
||||
}
|
||||
|
||||
|
||||
TEST_CASE("connections with curve authentication", "[curve][connect]") {
|
||||
std::string listen = "tcp://127.0.0.1:4455";
|
||||
LokiMQ server{
|
||||
std::string listen = random_localhost();
|
||||
OxenMQ server{
|
||||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
get_logger("S» ")
|
||||
get_logger("S» "),
|
||||
LogLevel::trace
|
||||
};
|
||||
server.log_level(LogLevel::trace);
|
||||
|
||||
server.listen_curve(listen, [](auto /*ip*/, auto /*pk*/) { return Allow{AuthLevel::none, false}; });
|
||||
server.listen_curve(listen);
|
||||
server.add_category("public", Access{AuthLevel::none});
|
||||
server.add_request_command("public", "hello", [&](Message& m) { m.send_reply("hi"); });
|
||||
server.start();
|
||||
|
||||
LokiMQ client{get_logger("C» ")};
|
||||
client.log_level(LogLevel::trace);
|
||||
OxenMQ client{get_logger("C» "), LogLevel::trace};
|
||||
|
||||
client.start();
|
||||
|
||||
auto pubkey = server.get_pubkey();
|
||||
std::atomic<int> connected{0};
|
||||
auto server_conn = client.connect_remote(listen,
|
||||
[&](auto conn) { connected = 1; },
|
||||
[&](auto conn, string_view reason) { auto lock = catch_lock(); INFO("connection failed: " << reason); },
|
||||
pubkey);
|
||||
std::atomic<bool> got{false};
|
||||
bool success = false;
|
||||
auto server_conn = client.connect_remote(address{listen, pubkey},
|
||||
[&](auto conn) { success = true; got = true; },
|
||||
[&](auto conn, std::string_view reason) { auto lock = catch_lock(); INFO("connection failed: " << reason); got = true; });
|
||||
|
||||
int i;
|
||||
for (i = 0; i < 5; i++) {
|
||||
if (connected.load())
|
||||
break;
|
||||
std::this_thread::sleep_for(50ms);
|
||||
}
|
||||
wait_for_conn(got);
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( i <= 1 );
|
||||
REQUIRE( connected.load() );
|
||||
REQUIRE( got );
|
||||
REQUIRE( success );
|
||||
}
|
||||
|
||||
bool success = false;
|
||||
success = false;
|
||||
std::vector<std::string> parts;
|
||||
client.request(server_conn, "public.hello", [&](auto success_, auto parts_) { success = success_; parts = parts_; });
|
||||
std::this_thread::sleep_for(50ms);
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( success );
|
||||
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( success );
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("self-connection SN optimization", "[connect][self]") {
|
||||
std::string pubkey, privkey;
|
||||
pubkey.resize(crypto_box_PUBLICKEYBYTES);
|
||||
privkey.resize(crypto_box_SECRETKEYBYTES);
|
||||
REQUIRE(sodium_init() != -1);
|
||||
auto listen_addr = random_localhost();
|
||||
crypto_box_keypair(reinterpret_cast<unsigned char*>(&pubkey[0]), reinterpret_cast<unsigned char*>(&privkey[0]));
|
||||
LokiMQ sn{
|
||||
OxenMQ sn{
|
||||
pubkey, privkey,
|
||||
true,
|
||||
[&](auto pk) { if (pk == pubkey) return "tcp://127.0.0.1:5544"; else return ""; },
|
||||
get_logger("S» ")
|
||||
[&](auto pk) { if (pk == pubkey) return listen_addr; else return ""s; },
|
||||
get_logger("S» "),
|
||||
LogLevel::trace
|
||||
};
|
||||
|
||||
sn.listen_curve("tcp://127.0.0.1:5544", [&](auto ip, auto pk) { REQUIRE(ip == "127.0.0.1"); return Allow{AuthLevel::none, pk == pubkey}; });
|
||||
sn.listen_curve(listen_addr, [&](auto ip, auto pk, auto sn) {
|
||||
auto lock = catch_lock();
|
||||
REQUIRE(ip == "127.0.0.1");
|
||||
REQUIRE(sn == (pk == pubkey));
|
||||
return AuthLevel::none;
|
||||
});
|
||||
sn.add_category("a", Access{AuthLevel::none});
|
||||
bool invoked = false;
|
||||
std::atomic<bool> invoked{false};
|
||||
sn.add_command("a", "b", [&](const Message& m) {
|
||||
invoked = true;
|
||||
auto lock = catch_lock();
|
||||
|
@ -77,98 +78,532 @@ TEST_CASE("self-connection SN optimization", "[connect][self]") {
|
|||
REQUIRE(!m.data.empty());
|
||||
REQUIRE(m.data[0] == "my data");
|
||||
});
|
||||
sn.log_level(LogLevel::trace);
|
||||
sn.set_active_sns({{pubkey}});
|
||||
|
||||
sn.start();
|
||||
std::this_thread::sleep_for(50ms);
|
||||
sn.send(pubkey, "a.b", "my data");
|
||||
std::this_thread::sleep_for(50ms);
|
||||
auto lock = catch_lock();
|
||||
REQUIRE(invoked);
|
||||
wait_for_conn(invoked);
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE(invoked);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("plain-text connections", "[plaintext][connect]") {
|
||||
std::string listen = "tcp://127.0.0.1:4455";
|
||||
LokiMQ server{get_logger("S» ")};
|
||||
server.log_level(LogLevel::trace);
|
||||
std::string listen = random_localhost();
|
||||
OxenMQ server{get_logger("S» "), LogLevel::trace};
|
||||
|
||||
server.add_category("public", Access{AuthLevel::none});
|
||||
server.add_request_command("public", "hello", [&](Message& m) { m.send_reply("hi"); });
|
||||
|
||||
server.listen_plain(listen, [](auto /*ip*/, auto /*pk*/) { return Allow{AuthLevel::none, false}; });
|
||||
server.listen_plain(listen);
|
||||
|
||||
server.start();
|
||||
|
||||
LokiMQ client{get_logger("C» ")};
|
||||
client.log_level(LogLevel::trace);
|
||||
OxenMQ client{get_logger("C» "), LogLevel::trace};
|
||||
|
||||
client.start();
|
||||
|
||||
std::atomic<int> connected{0};
|
||||
auto c = client.connect_remote(listen,
|
||||
[&](auto conn) { connected = 1; },
|
||||
[&](auto conn, string_view reason) { auto lock = catch_lock(); INFO("connection failed: " << reason); }
|
||||
std::atomic<bool> got{false};
|
||||
bool success = false;
|
||||
auto c = client.connect_remote(address{listen},
|
||||
[&](auto conn) { success = true; got = true; },
|
||||
[&](auto conn, std::string_view reason) { auto lock = catch_lock(); INFO("connection failed: " << reason); got = true; }
|
||||
);
|
||||
|
||||
int i;
|
||||
for (i = 0; i < 5; i++) {
|
||||
if (connected.load())
|
||||
break;
|
||||
std::this_thread::sleep_for(50ms);
|
||||
}
|
||||
wait_for_conn(got);
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( i <= 1 );
|
||||
REQUIRE( connected.load() );
|
||||
REQUIRE( got );
|
||||
REQUIRE( success );
|
||||
}
|
||||
|
||||
bool success = false;
|
||||
success = false;
|
||||
std::vector<std::string> parts;
|
||||
client.request(c, "public.hello", [&](auto success_, auto parts_) { success = success_; parts = parts_; });
|
||||
std::this_thread::sleep_for(50ms);
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( success );
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( success );
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("post-start listening", "[connect][listen]") {
|
||||
OxenMQ server{get_logger("S» "), LogLevel::trace};
|
||||
server.add_category("x", AuthLevel::none)
|
||||
.add_request_command("y", [&](Message& m) { m.send_reply("hi", m.data[0]); });
|
||||
server.start();
|
||||
std::atomic<int> listens = 0;
|
||||
auto listen_curve = random_localhost();
|
||||
server.listen_curve(listen_curve, nullptr, [&](bool success) { if (success) listens++; });
|
||||
auto listen_plain = random_localhost();
|
||||
server.listen_plain(listen_plain, nullptr, [&](bool success) { if (success) listens += 10; });
|
||||
|
||||
wait_for([&] { return listens.load() >= 11; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( listens == 11 );
|
||||
}
|
||||
|
||||
// This should fail since we're already listening on it:
|
||||
server.listen_curve(listen_plain, nullptr, [&](bool success) { if (!success) listens++; });
|
||||
|
||||
wait_for([&] { return listens.load() >= 12; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( listens == 12 );
|
||||
}
|
||||
|
||||
|
||||
OxenMQ client{get_logger("C1» "), LogLevel::trace};
|
||||
client.start();
|
||||
std::atomic<int> conns = 0;
|
||||
auto c1 = client.connect_remote(address{listen_curve, server.get_pubkey()},
|
||||
[&](auto) { conns++; },
|
||||
[&](auto, auto why) { auto lock = catch_lock(); UNSCOPED_INFO("connection failed: " << why); });
|
||||
auto c2 = client.connect_remote(address{listen_plain},
|
||||
[&](auto) { conns += 10; },
|
||||
[&](auto, auto why) { auto lock = catch_lock(); UNSCOPED_INFO("connection failed: " << why); });
|
||||
|
||||
|
||||
wait_for([&] { return conns.load() >= 11; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( conns == 11 );
|
||||
}
|
||||
|
||||
std::atomic<int> replies = 0;
|
||||
std::string reply1, reply2;
|
||||
client.request(c1, "x.y", [&](auto success, auto parts) { replies++; for (auto& p : parts) reply1 += p; }, " world");
|
||||
client.request(c2, "x.y", [&](auto success, auto parts) { replies += 10; for (auto& p : parts) reply2 += p; }, " cat");
|
||||
|
||||
wait_for([&] { return replies.load() >= 11; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( replies == 11 );
|
||||
REQUIRE( reply1 == "hi world" );
|
||||
REQUIRE( reply2 == "hi cat" );
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("unique connection IDs", "[connect][id]") {
|
||||
std::string listen = random_localhost();
|
||||
OxenMQ server{get_logger("S» "), LogLevel::trace};
|
||||
|
||||
ConnectionID first, second;
|
||||
server.add_category("x", Access{AuthLevel::none})
|
||||
.add_request_command("x", [&](Message& m) { first = m.conn; m.send_reply("hi"); })
|
||||
.add_request_command("y", [&](Message& m) { second = m.conn; m.send_reply("hi"); })
|
||||
;
|
||||
|
||||
server.listen_plain(listen);
|
||||
|
||||
server.start();
|
||||
|
||||
OxenMQ client1{get_logger("C1» "), LogLevel::trace};
|
||||
OxenMQ client2{get_logger("C2» "), LogLevel::trace};
|
||||
client1.start();
|
||||
client2.start();
|
||||
|
||||
std::atomic<bool> good1{false}, good2{false};
|
||||
auto r1 = client1.connect_remote(address{listen},
|
||||
[&](auto conn) { good1 = true; },
|
||||
[&](auto conn, std::string_view reason) { auto lock = catch_lock(); INFO("connection failed: " << reason); }
|
||||
);
|
||||
auto r2 = client2.connect_remote(address{listen},
|
||||
[&](auto conn) { good2 = true; },
|
||||
[&](auto conn, std::string_view reason) { auto lock = catch_lock(); INFO("connection failed: " << reason); }
|
||||
);
|
||||
|
||||
wait_for_conn(good1);
|
||||
wait_for_conn(good2);
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( good1 );
|
||||
REQUIRE( good2 );
|
||||
REQUIRE( first == second );
|
||||
REQUIRE_FALSE( first );
|
||||
REQUIRE_FALSE( second );
|
||||
}
|
||||
|
||||
good1 = false;
|
||||
good2 = false;
|
||||
client1.request(r1, "x.x", [&](auto success_, auto parts_) { good1 = true; });
|
||||
client2.request(r2, "x.y", [&](auto success_, auto parts_) { good2 = true; });
|
||||
reply_sleep();
|
||||
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( good1 );
|
||||
REQUIRE( good2 );
|
||||
REQUIRE_FALSE( first == second );
|
||||
REQUIRE_FALSE( std::hash<ConnectionID>{}(first) == std::hash<ConnectionID>{}(second) );
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
TEST_CASE("SN disconnections", "[connect][disconnect]") {
|
||||
std::vector<std::unique_ptr<LokiMQ>> lmq;
|
||||
std::vector<std::unique_ptr<OxenMQ>> omq;
|
||||
std::vector<std::string> pubkey, privkey;
|
||||
std::unordered_map<std::string, std::string> conn;
|
||||
REQUIRE(sodium_init() != -1);
|
||||
for (int i = 0; i < 3; i++) {
|
||||
pubkey.emplace_back();
|
||||
privkey.emplace_back();
|
||||
pubkey[i].resize(crypto_box_PUBLICKEYBYTES);
|
||||
privkey[i].resize(crypto_box_SECRETKEYBYTES);
|
||||
crypto_box_keypair(reinterpret_cast<unsigned char*>(&pubkey[i][0]), reinterpret_cast<unsigned char*>(&privkey[i][0]));
|
||||
conn.emplace(pubkey[i], "tcp://127.0.0.1:" + std::to_string(4450 + i));
|
||||
conn.emplace(pubkey[i], random_localhost());
|
||||
}
|
||||
std::atomic<int> his{0};
|
||||
for (int i = 0; i < pubkey.size(); i++) {
|
||||
lmq.push_back(std::make_unique<LokiMQ>(
|
||||
omq.push_back(std::make_unique<OxenMQ>(
|
||||
pubkey[i], privkey[i], true,
|
||||
[conn](auto pk) { auto it = conn.find((std::string) pk); if (it != conn.end()) return it->second; return ""s; },
|
||||
get_logger("S" + std::to_string(i) + "» ")
|
||||
get_logger("S" + std::to_string(i) + "» "),
|
||||
LogLevel::trace
|
||||
));
|
||||
auto& server = *lmq.back();
|
||||
server.log_level(LogLevel::debug);
|
||||
auto& server = *omq.back();
|
||||
|
||||
server.listen_curve(conn[pubkey[i]], [](auto /*ip*/, auto /*pk*/) { return Allow{AuthLevel::none, true}; });
|
||||
server.listen_curve(conn[pubkey[i]]);
|
||||
server.add_category("sn", Access{AuthLevel::none, true})
|
||||
.add_command("hi", [&](Message& m) { his++; });
|
||||
server.set_active_sns({pubkey.begin(), pubkey.end()});
|
||||
server.start();
|
||||
}
|
||||
std::this_thread::sleep_for(50ms);
|
||||
|
||||
lmq[0]->send(pubkey[1], "sn.hi");
|
||||
lmq[0]->send(pubkey[2], "sn.hi");
|
||||
std::this_thread::sleep_for(50ms);
|
||||
lmq[2]->send(pubkey[0], "sn.hi");
|
||||
lmq[2]->send(pubkey[1], "sn.hi");
|
||||
lmq[1]->send(pubkey[0], "BYE");
|
||||
std::this_thread::sleep_for(50ms);
|
||||
lmq[0]->send(pubkey[2], "sn.hi");
|
||||
std::this_thread::sleep_for(50ms);
|
||||
omq[0]->send(pubkey[1], "sn.hi");
|
||||
omq[0]->send(pubkey[2], "sn.hi");
|
||||
omq[2]->send(pubkey[0], "sn.hi");
|
||||
omq[2]->send(pubkey[1], "sn.hi");
|
||||
omq[1]->send(pubkey[0], "BYE");
|
||||
omq[0]->send(pubkey[2], "sn.hi");
|
||||
std::this_thread::sleep_for(50ms * TIME_DILATION);
|
||||
|
||||
auto lock = catch_lock();
|
||||
REQUIRE(his == 5);
|
||||
}
|
||||
|
||||
TEST_CASE("SN auth checks", "[sandwich][auth]") {
|
||||
// When a remote connects, we check its authentication level; if at the time of connection it
|
||||
// isn't recognized as a SN but tries to invoke a SN command it'll be told to disconnect; if it
|
||||
// tries to send again it should reconnect and reauthenticate. This test is meant to test this
|
||||
// pattern where the reconnection/reauthentication now authenticates it as a SN.
|
||||
std::string listen = random_localhost();
|
||||
std::string pubkey, privkey;
|
||||
pubkey.resize(crypto_box_PUBLICKEYBYTES);
|
||||
privkey.resize(crypto_box_SECRETKEYBYTES);
|
||||
REQUIRE(sodium_init() != -1);
|
||||
crypto_box_keypair(reinterpret_cast<unsigned char*>(&pubkey[0]), reinterpret_cast<unsigned char*>(&privkey[0]));
|
||||
OxenMQ server{
|
||||
pubkey, privkey,
|
||||
true, // service node
|
||||
[](auto) { return ""; },
|
||||
get_logger("A» "),
|
||||
LogLevel::trace
|
||||
};
|
||||
|
||||
std::atomic<bool> incoming_is_sn{false};
|
||||
server.listen_curve(listen);
|
||||
server.add_category("public", Access{AuthLevel::none})
|
||||
.add_request_command("hello", [&](Message& m) { m.send_reply("hi"); })
|
||||
.add_request_command("sudo", [&](Message& m) {
|
||||
server.update_active_sns({{m.conn.pubkey()}}, {});
|
||||
m.send_reply("making sandwiches");
|
||||
})
|
||||
.add_request_command("nosudo", [&](Message& m) {
|
||||
// Send the reply *first* because if we do it the other way we'll have just removed
|
||||
// ourselves from the list of SNs and thus would try to open an outbound connection
|
||||
// to deliver it since it's still queued as a message to a SN.
|
||||
m.send_reply("make them yourself");
|
||||
server.update_active_sns({}, {{m.conn.pubkey()}});
|
||||
});
|
||||
server.add_category("sandwich", Access{AuthLevel::none, true})
|
||||
.add_request_command("make", [&](Message& m) { m.send_reply("okay"); });
|
||||
server.start();
|
||||
|
||||
OxenMQ client{
|
||||
"", "", false,
|
||||
[&](auto remote_pk) { if (remote_pk == pubkey) return listen; return ""s; },
|
||||
get_logger("B» "), LogLevel::trace};
|
||||
client.start();
|
||||
|
||||
std::atomic<bool> got{false};
|
||||
bool success;
|
||||
client.request(pubkey, "public.hello", [&](auto success_, auto) { success = success_; got = true; });
|
||||
wait_for_conn(got);
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( got );
|
||||
REQUIRE( success );
|
||||
}
|
||||
|
||||
got = false;
|
||||
using dvec = std::vector<std::string>;
|
||||
dvec data;
|
||||
client.request(pubkey, "sandwich.make", [&](auto success_, auto data_) {
|
||||
success = success_;
|
||||
data = std::move(data_);
|
||||
got = true;
|
||||
});
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( got );
|
||||
REQUIRE_FALSE( success );
|
||||
REQUIRE( data == dvec{{"FORBIDDEN_SN"}} );
|
||||
}
|
||||
|
||||
// Somebody set up us the bomb. Main sudo turn on.
|
||||
got = false;
|
||||
client.request(pubkey, "public.sudo", [&](auto success_, auto data_) { success = success_; data = data_; got = true; });
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( got );
|
||||
REQUIRE( success );
|
||||
REQUIRE( data == dvec{{"making sandwiches"}} );
|
||||
}
|
||||
|
||||
got = false;
|
||||
client.request(pubkey, "sandwich.make", [&](auto success_, auto data_) {
|
||||
success = success_;
|
||||
data = std::move(data_);
|
||||
got = true;
|
||||
});
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( got );
|
||||
REQUIRE( success );
|
||||
REQUIRE( data == dvec{{"okay"}} );
|
||||
}
|
||||
|
||||
// Take off every 'SUDO', You [not] know what you doing
|
||||
got = false;
|
||||
client.request(pubkey, "public.nosudo", [&](auto success_, auto data_) { success = success_; data = data_; got = true; });
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( got );
|
||||
REQUIRE( success );
|
||||
REQUIRE( data == dvec{{"make them yourself"}} );
|
||||
}
|
||||
|
||||
got = false;
|
||||
client.request(pubkey, "sandwich.make", [&](auto success_, auto data_) {
|
||||
success = success_;
|
||||
data = std::move(data_);
|
||||
got = true;
|
||||
});
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( got );
|
||||
REQUIRE_FALSE( success );
|
||||
REQUIRE( data == dvec{{"FORBIDDEN_SN"}} );
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("SN single worker test", "[connect][worker]") {
|
||||
// Tests a failure case that could trigger when all workers are allocated (here we make that
|
||||
// simpler by just having one worker).
|
||||
std::string listen = random_localhost();
|
||||
OxenMQ server{
|
||||
"", "",
|
||||
false, // service node
|
||||
[](auto) { return ""; },
|
||||
get_logger("S» "),
|
||||
LogLevel::trace
|
||||
};
|
||||
server.set_general_threads(1);
|
||||
server.set_batch_threads(0);
|
||||
server.set_reply_threads(0);
|
||||
server.listen_plain(listen);
|
||||
server.add_category("c", Access{AuthLevel::none})
|
||||
.add_request_command("x", [&](Message& m) { m.send_reply(); })
|
||||
;
|
||||
server.start();
|
||||
|
||||
OxenMQ client{get_logger("B» "), LogLevel::trace};
|
||||
client.start();
|
||||
auto conn = client.connect_remote(address{listen}, [](auto) {}, [](auto, auto) {});
|
||||
|
||||
std::atomic<int> got{0};
|
||||
std::atomic<int> success{0};
|
||||
client.request(conn, "c.x", [&](auto success_, auto) { if (success_) ++success; ++got; });
|
||||
wait_for([&] { return got.load() >= 1; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( success == 1 );
|
||||
}
|
||||
client.request(conn, "c.x", [&](auto success_, auto) { if (success_) ++success; ++got; });
|
||||
wait_for([&] { return got.load() >= 2; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( success == 2 );
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
TEST_CASE("SN backchatter", "[connect][sn]") {
|
||||
// When we have a SN connection A -> B and then B sends a message to A on that existing
|
||||
// connection, A should see it as coming from B.
|
||||
std::vector<std::unique_ptr<OxenMQ>> omq;
|
||||
std::vector<std::string> pubkey, privkey;
|
||||
std::unordered_map<std::string, std::string> conn;
|
||||
REQUIRE(sodium_init() != -1);
|
||||
for (int i = 0; i < 2; i++) {
|
||||
pubkey.emplace_back();
|
||||
privkey.emplace_back();
|
||||
pubkey[i].resize(crypto_box_PUBLICKEYBYTES);
|
||||
privkey[i].resize(crypto_box_SECRETKEYBYTES);
|
||||
crypto_box_keypair(reinterpret_cast<unsigned char*>(&pubkey[i][0]), reinterpret_cast<unsigned char*>(&privkey[i][0]));
|
||||
conn.emplace(pubkey[i], random_localhost());
|
||||
}
|
||||
|
||||
for (int i = 0; i < pubkey.size(); i++) {
|
||||
omq.push_back(std::make_unique<OxenMQ>(
|
||||
pubkey[i], privkey[i], true,
|
||||
[conn](auto pk) { auto it = conn.find((std::string) pk); if (it != conn.end()) return it->second; return ""s; },
|
||||
get_logger("S" + std::to_string(i) + "» "),
|
||||
LogLevel::trace
|
||||
));
|
||||
auto& server = *omq.back();
|
||||
|
||||
server.listen_curve(conn[pubkey[i]]);
|
||||
server.set_active_sns({pubkey.begin(), pubkey.end()});
|
||||
}
|
||||
std::string f;
|
||||
omq[0]->add_category("a", Access{AuthLevel::none, true})
|
||||
.add_command("a", [&](Message& m) {
|
||||
m.oxenmq.send(m.conn, "b.b", "abc");
|
||||
//m.send_back("b.b", "abc");
|
||||
})
|
||||
.add_command("z", [&](Message& m) {
|
||||
auto lock = catch_lock();
|
||||
f = m.data[0];
|
||||
});
|
||||
omq[1]->add_category("b", Access{AuthLevel::none, true})
|
||||
.add_command("b", [&](Message& m) {
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
UNSCOPED_INFO("b.b from conn " << m.conn);
|
||||
}
|
||||
m.send_back("a.z", m.data[0]);
|
||||
});
|
||||
|
||||
for (auto& server : omq)
|
||||
server->start();
|
||||
|
||||
auto c = omq[1]->connect_sn(pubkey[0]);
|
||||
omq[1]->send(c, "a.a");
|
||||
std::this_thread::sleep_for(50ms);
|
||||
|
||||
auto lock = catch_lock();
|
||||
REQUIRE(f == "abc");
|
||||
}
|
||||
|
||||
TEST_CASE("inproc connections", "[connect][inproc]") {
|
||||
std::string inproc_name = "foo";
|
||||
OxenMQ omq{get_logger("OMQ» "), LogLevel::trace};
|
||||
|
||||
omq.add_category("public", Access{AuthLevel::none});
|
||||
omq.add_request_command("public", "hello", [&](Message& m) { m.send_reply("hi"); });
|
||||
|
||||
omq.start();
|
||||
|
||||
std::atomic<int> got{0};
|
||||
bool success = false;
|
||||
auto c_inproc = omq.connect_inproc(
|
||||
[&](auto conn) { success = true; got++; },
|
||||
[&](auto conn, std::string_view reason) { auto lock = catch_lock(); INFO("inproc connection failed: " << reason); got++; }
|
||||
);
|
||||
|
||||
wait_for([&got] { return got.load() > 0; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( success );
|
||||
REQUIRE( got == 1 );
|
||||
}
|
||||
|
||||
got = 0;
|
||||
success = false;
|
||||
omq.request(c_inproc, "public.hello", [&](auto success_, auto parts_) {
|
||||
success = success_ && parts_.size() == 1 && parts_.front() == "hi"; got++;
|
||||
});
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( got == 1 );
|
||||
REQUIRE( success );
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("no explicit inproc listening", "[connect][inproc]") {
|
||||
OxenMQ omq{get_logger("OMQ» "), LogLevel::trace};
|
||||
REQUIRE_THROWS_AS(omq.listen_plain("inproc://foo"), std::logic_error);
|
||||
REQUIRE_THROWS_AS(omq.listen_curve("inproc://foo"), std::logic_error);
|
||||
}
|
||||
|
||||
TEST_CASE("inproc connection permissions", "[connect][inproc]") {
|
||||
std::string listen = random_localhost();
|
||||
OxenMQ omq{get_logger("OMQ» "), LogLevel::trace};
|
||||
|
||||
omq.add_category("public", Access{AuthLevel::none});
|
||||
omq.add_request_command("public", "hello", [&](Message& m) { m.send_reply("hi"); });
|
||||
omq.add_category("private", Access{AuthLevel::admin});
|
||||
omq.add_request_command("private", "handshake", [&](Message& m) { m.send_reply("yo dude"); });
|
||||
|
||||
omq.listen_plain(listen);
|
||||
|
||||
omq.start();
|
||||
|
||||
std::atomic<int> got{0};
|
||||
bool success = false;
|
||||
auto c_inproc = omq.connect_inproc(
|
||||
[&](auto conn) { success = true; got++; },
|
||||
[&](auto conn, std::string_view reason) { auto lock = catch_lock(); INFO("inproc connection failed: " << reason); got++; }
|
||||
);
|
||||
|
||||
bool pub_success = false;
|
||||
auto c_pub = omq.connect_remote(address{listen},
|
||||
[&](auto conn) { pub_success = true; got++; },
|
||||
[&](auto conn, std::string_view reason) { auto lock = catch_lock(); INFO("tcp connection failed: " << reason); got++; }
|
||||
);
|
||||
|
||||
wait_for([&got] { return got.load() == 2; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( got == 2 );
|
||||
REQUIRE( success );
|
||||
REQUIRE( pub_success );
|
||||
}
|
||||
|
||||
got = 0;
|
||||
success = false;
|
||||
pub_success = false;
|
||||
bool success_private = false;
|
||||
bool pub_success_private = false;
|
||||
omq.request(c_inproc, "public.hello", [&](auto success_, auto parts_) {
|
||||
success = success_ && parts_.size() == 1 && parts_.front() == "hi"; got++;
|
||||
});
|
||||
omq.request(c_pub, "public.hello", [&](auto success_, auto parts_) {
|
||||
pub_success = success_ && parts_.size() == 1 && parts_.front() == "hi"; got++;
|
||||
});
|
||||
omq.request(c_inproc, "private.handshake", [&](auto success_, auto parts_) {
|
||||
success_private = success_ && parts_.size() == 1 && parts_.front() == "yo dude"; got++;
|
||||
});
|
||||
omq.request(c_pub, "private.handshake", [&](auto success_, auto parts_) {
|
||||
pub_success_private = success_; got++;
|
||||
});
|
||||
wait_for([&got] { return got.load() == 4; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( got == 4 );
|
||||
REQUIRE( success );
|
||||
REQUIRE( pub_success );
|
||||
REQUIRE( success_private );
|
||||
REQUIRE_FALSE( pub_success_private );
|
||||
}
|
||||
}
|
||||
|
|
324
tests/test_failures.cpp
Normal file
324
tests/test_failures.cpp
Normal file
|
@ -0,0 +1,324 @@
|
|||
#include "common.h"
|
||||
#include <map>
|
||||
#include <set>
|
||||
|
||||
using namespace oxenmq;
|
||||
|
||||
TEST_CASE("failure responses - UNKNOWNCOMMAND", "[failure][UNKNOWNCOMMAND]") {
|
||||
std::string listen = random_localhost();
|
||||
OxenMQ server{
|
||||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
get_logger("S» "),
|
||||
LogLevel::trace
|
||||
};
|
||||
server.listen_plain(listen);
|
||||
server.start();
|
||||
|
||||
// Use a raw socket here because I want to see the raw commands coming on the wire
|
||||
zmq::context_t client_ctx;
|
||||
zmq::socket_t client{client_ctx, zmq::socket_type::dealer};
|
||||
client.connect(listen);
|
||||
// Handshake: we send HI, they reply HELLO.
|
||||
client.send(zmq::message_t{"HI", 2}, zmq::send_flags::none);
|
||||
{
|
||||
zmq::message_t hello;
|
||||
auto recvd = client.recv(hello);
|
||||
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( recvd );
|
||||
REQUIRE( hello.to_string() == "HELLO" );
|
||||
REQUIRE_FALSE( hello.more() );
|
||||
}
|
||||
|
||||
client.send(zmq::message_t{"a.a", 3}, zmq::send_flags::none);
|
||||
zmq::message_t resp;
|
||||
auto recvd = client.recv(resp);
|
||||
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( recvd );
|
||||
REQUIRE( resp.to_string() == "UNKNOWNCOMMAND" );
|
||||
REQUIRE( resp.more() );
|
||||
REQUIRE( client.recv(resp) );
|
||||
REQUIRE( resp.to_string() == "a.a" );
|
||||
REQUIRE_FALSE( resp.more() );
|
||||
}
|
||||
|
||||
TEST_CASE("failure responses - NO_REPLY_TAG", "[failure][NO_REPLY_TAG]") {
|
||||
std::string listen = random_localhost();
|
||||
OxenMQ server{
|
||||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
get_logger("S» "),
|
||||
LogLevel::trace
|
||||
};
|
||||
server.listen_plain(listen);
|
||||
server.add_category("x", AuthLevel::none)
|
||||
.add_request_command("r", [] (auto& m) { m.send_reply("a"); });
|
||||
server.start();
|
||||
|
||||
// Use a raw socket here because I want to see the raw commands coming on the wire
|
||||
zmq::context_t client_ctx;
|
||||
zmq::socket_t client{client_ctx, zmq::socket_type::dealer};
|
||||
client.connect(listen);
|
||||
// Handshake: we send HI, they reply HELLO.
|
||||
client.send(zmq::message_t{"HI", 2}, zmq::send_flags::none);
|
||||
{
|
||||
zmq::message_t hello;
|
||||
auto recvd = client.recv(hello);
|
||||
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( recvd );
|
||||
REQUIRE( hello.to_string() == "HELLO" );
|
||||
REQUIRE_FALSE( hello.more() );
|
||||
}
|
||||
|
||||
client.send(zmq::message_t{"x.r", 3}, zmq::send_flags::none);
|
||||
zmq::message_t resp;
|
||||
auto recvd = client.recv(resp);
|
||||
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( recvd );
|
||||
REQUIRE( resp.to_string() == "NO_REPLY_TAG" );
|
||||
REQUIRE( resp.more() );
|
||||
REQUIRE( client.recv(resp) );
|
||||
REQUIRE( resp.to_string() == "x.r" );
|
||||
REQUIRE_FALSE( resp.more() );
|
||||
}
|
||||
|
||||
client.send(zmq::message_t{"x.r", 3}, zmq::send_flags::sndmore);
|
||||
client.send(zmq::message_t{"foo", 3}, zmq::send_flags::none);
|
||||
recvd = client.recv(resp);
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( recvd );
|
||||
REQUIRE( resp.to_string() == "REPLY" );
|
||||
REQUIRE( resp.more() );
|
||||
REQUIRE( client.recv(resp) );
|
||||
REQUIRE( resp.to_string() == "foo" );
|
||||
REQUIRE( resp.more() );
|
||||
REQUIRE( client.recv(resp) );
|
||||
REQUIRE( resp.to_string() == "a" );
|
||||
REQUIRE_FALSE( resp.more() );
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("failure responses - FORBIDDEN", "[failure][FORBIDDEN]") {
|
||||
std::string listen = random_localhost();
|
||||
OxenMQ server{
|
||||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
get_logger("S» "),
|
||||
LogLevel::trace
|
||||
};
|
||||
server.listen_plain(listen, [](auto, auto, auto) {
|
||||
static int count = 0;
|
||||
++count;
|
||||
return count == 1 ? AuthLevel::none : count == 2 ? AuthLevel::basic : AuthLevel::admin;
|
||||
});
|
||||
server.add_category("x", AuthLevel::basic)
|
||||
.add_command("x", [] (auto& m) { m.send_back("a"); });
|
||||
server.add_category("y", AuthLevel::admin)
|
||||
.add_command("x", [] (auto& m) { m.send_back("b"); });
|
||||
server.start();
|
||||
|
||||
zmq::context_t client_ctx;
|
||||
std::array<zmq::socket_t, 3> clients;
|
||||
// Client 0 should get none auth level, client 1 should get basic, client 2 should get admin
|
||||
for (auto& client : clients) {
|
||||
client = {client_ctx, zmq::socket_type::dealer};
|
||||
client.connect(listen);
|
||||
// Handshake: we send HI, they reply HELLO.
|
||||
client.send(zmq::message_t{"HI", 2}, zmq::send_flags::none);
|
||||
{
|
||||
zmq::message_t hello;
|
||||
auto recvd = client.recv(hello);
|
||||
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( recvd );
|
||||
REQUIRE( hello.to_string() == "HELLO" );
|
||||
REQUIRE_FALSE( hello.more() );
|
||||
}
|
||||
}
|
||||
|
||||
for (auto& c : clients)
|
||||
c.send(zmq::message_t{"x.x", 3}, zmq::send_flags::none);
|
||||
|
||||
zmq::message_t resp;
|
||||
auto recvd = clients[0].recv(resp);
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( recvd );
|
||||
REQUIRE( resp.to_string() == "FORBIDDEN" );
|
||||
REQUIRE( resp.more() );
|
||||
REQUIRE( clients[0].recv(resp) );
|
||||
REQUIRE( resp.to_string() == "x.x" );
|
||||
REQUIRE_FALSE( resp.more() );
|
||||
}
|
||||
for (int i : {1, 2}) {
|
||||
recvd = clients[i].recv(resp);
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( recvd );
|
||||
REQUIRE( resp.to_string() == "a" );
|
||||
REQUIRE_FALSE( resp.more() );
|
||||
}
|
||||
|
||||
for (auto& c : clients)
|
||||
c.send(zmq::message_t{"y.x", 3}, zmq::send_flags::none);
|
||||
|
||||
for (int i : {0, 1}) {
|
||||
recvd = clients[i].recv(resp);
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( recvd );
|
||||
REQUIRE( resp.to_string() == "FORBIDDEN" );
|
||||
REQUIRE( resp.more() );
|
||||
REQUIRE( clients[i].recv(resp) );
|
||||
REQUIRE( resp.to_string() == "y.x" );
|
||||
REQUIRE_FALSE( resp.more() );
|
||||
}
|
||||
recvd = clients[2].recv(resp);
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( recvd );
|
||||
REQUIRE( resp.to_string() == "b" );
|
||||
REQUIRE_FALSE( resp.more() );
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("failure responses - NOT_A_SERVICE_NODE", "[failure][NOT_A_SERVICE_NODE]") {
|
||||
std::string listen = random_localhost();
|
||||
OxenMQ server{
|
||||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
get_logger("S» "),
|
||||
LogLevel::trace
|
||||
};
|
||||
server.listen_plain(listen, [](auto, auto, auto) {
|
||||
static int count = 0;
|
||||
++count;
|
||||
return count == 1 ? AuthLevel::none : count == 2 ? AuthLevel::basic : AuthLevel::admin;
|
||||
});
|
||||
server.add_category("x", Access{AuthLevel::none, false, true})
|
||||
.add_command("x", [] (auto&) {})
|
||||
.add_request_command("r", [] (auto& m) { m.send_reply(); })
|
||||
;
|
||||
server.start();
|
||||
|
||||
zmq::context_t client_ctx;
|
||||
zmq::socket_t client{client_ctx, zmq::socket_type::dealer};
|
||||
client.connect(listen);
|
||||
// Handshake: we send HI, they reply HELLO.
|
||||
client.send(zmq::message_t{"HI", 2}, zmq::send_flags::none);
|
||||
{
|
||||
zmq::message_t hello;
|
||||
auto recvd = client.recv(hello);
|
||||
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( recvd );
|
||||
REQUIRE( hello.to_string() == "HELLO" );
|
||||
REQUIRE_FALSE( hello.more() );
|
||||
}
|
||||
|
||||
client.send(zmq::message_t{"x.x", 3}, zmq::send_flags::none);
|
||||
|
||||
zmq::message_t resp;
|
||||
auto recvd = client.recv(resp);
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( recvd );
|
||||
REQUIRE( resp.to_string() == "NOT_A_SERVICE_NODE" );
|
||||
REQUIRE( resp.more() );
|
||||
REQUIRE( client.recv(resp) );
|
||||
REQUIRE( resp.to_string() == "x.x" );
|
||||
REQUIRE_FALSE( resp.more() );
|
||||
}
|
||||
|
||||
client.send(zmq::message_t{"x.r", 3}, zmq::send_flags::sndmore);
|
||||
client.send(zmq::message_t{"xyz123", 6}, zmq::send_flags::none); // reply tag
|
||||
|
||||
recvd = client.recv(resp);
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( recvd );
|
||||
REQUIRE( resp.to_string() == "NOT_A_SERVICE_NODE" );
|
||||
REQUIRE( resp.more() );
|
||||
REQUIRE( client.recv(resp) );
|
||||
REQUIRE( resp.to_string() == "REPLY" );
|
||||
REQUIRE( resp.more() );
|
||||
REQUIRE( client.recv(resp) );
|
||||
REQUIRE( resp.to_string() == "xyz123" );
|
||||
REQUIRE_FALSE( resp.more() );
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("failure responses - FORBIDDEN_SN", "[failure][FORBIDDEN_SN]") {
|
||||
std::string listen = random_localhost();
|
||||
OxenMQ server{
|
||||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
get_logger("S» "),
|
||||
LogLevel::trace
|
||||
};
|
||||
server.listen_plain(listen, [](auto, auto, auto) {
|
||||
static int count = 0;
|
||||
++count;
|
||||
return count == 1 ? AuthLevel::none : count == 2 ? AuthLevel::basic : AuthLevel::admin;
|
||||
});
|
||||
server.add_category("x", Access{AuthLevel::none, true, false})
|
||||
.add_command("x", [] (auto&) {})
|
||||
.add_request_command("r", [] (auto& m) { m.send_reply(); })
|
||||
;
|
||||
server.start();
|
||||
|
||||
zmq::context_t client_ctx;
|
||||
zmq::socket_t client{client_ctx, zmq::socket_type::dealer};
|
||||
client.connect(listen);
|
||||
// Handshake: we send HI, they reply HELLO.
|
||||
client.send(zmq::message_t{"HI", 2}, zmq::send_flags::none);
|
||||
{
|
||||
zmq::message_t hello;
|
||||
auto recvd = client.recv(hello);
|
||||
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( recvd );
|
||||
REQUIRE( hello.to_string() == "HELLO" );
|
||||
REQUIRE_FALSE( hello.more() );
|
||||
}
|
||||
|
||||
client.send(zmq::message_t{"x.x", 3}, zmq::send_flags::none);
|
||||
|
||||
zmq::message_t resp;
|
||||
auto recvd = client.recv(resp);
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( recvd );
|
||||
REQUIRE( resp.to_string() == "FORBIDDEN_SN" );
|
||||
REQUIRE( resp.more() );
|
||||
REQUIRE( client.recv(resp) );
|
||||
REQUIRE( resp.to_string() == "x.x" );
|
||||
REQUIRE_FALSE( resp.more() );
|
||||
}
|
||||
|
||||
client.send(zmq::message_t{"x.r", 3}, zmq::send_flags::sndmore);
|
||||
client.send(zmq::message_t{"xyz123", 6}, zmq::send_flags::none); // reply tag
|
||||
|
||||
recvd = client.recv(resp);
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( recvd );
|
||||
REQUIRE( resp.to_string() == "FORBIDDEN_SN" );
|
||||
REQUIRE( resp.more() );
|
||||
REQUIRE( client.recv(resp) );
|
||||
REQUIRE( resp.to_string() == "REPLY" );
|
||||
REQUIRE( resp.more() );
|
||||
REQUIRE( client.recv(resp) );
|
||||
REQUIRE( resp.to_string() == "xyz123" );
|
||||
REQUIRE_FALSE( resp.more() );
|
||||
}
|
||||
}
|
96
tests/test_inject.cpp
Normal file
96
tests/test_inject.cpp
Normal file
|
@ -0,0 +1,96 @@
|
|||
#include "common.h"
|
||||
|
||||
using namespace oxenmq;
|
||||
|
||||
TEST_CASE("injected external commands", "[injected]") {
|
||||
std::string listen = random_localhost();
|
||||
OxenMQ server{
|
||||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
get_logger("S» "),
|
||||
LogLevel::trace
|
||||
};
|
||||
server.set_general_threads(1);
|
||||
server.listen_curve(listen);
|
||||
|
||||
std::atomic<int> hellos = 0;
|
||||
std::atomic<bool> done = false;
|
||||
server.add_category("public", AuthLevel::none, 3);
|
||||
server.add_command("public", "hello", [&](Message& m) {
|
||||
hellos++;
|
||||
while (!done) std::this_thread::sleep_for(10ms);
|
||||
});
|
||||
|
||||
server.start();
|
||||
|
||||
OxenMQ client{get_logger("C» "), LogLevel::trace};
|
||||
client.start();
|
||||
|
||||
std::atomic<bool> got{false};
|
||||
bool success = false;
|
||||
|
||||
// Deliberately using a deprecated command here, disable -Wdeprecated-declarations
|
||||
#ifdef __GNUG__
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
|
||||
#endif
|
||||
auto c = client.connect_remote(listen,
|
||||
[&](auto conn) { success = true; got = true; },
|
||||
[&](auto conn, std::string_view) { got = true; },
|
||||
server.get_pubkey());
|
||||
|
||||
#ifdef __GNUG__
|
||||
#pragma GCC diagnostic pop
|
||||
#endif
|
||||
|
||||
wait_for_conn(got);
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( got );
|
||||
REQUIRE( success );
|
||||
}
|
||||
|
||||
// First make sure that basic message respects the 3 thread limit
|
||||
client.send(c, "public.hello");
|
||||
client.send(c, "public.hello");
|
||||
client.send(c, "public.hello");
|
||||
client.send(c, "public.hello");
|
||||
wait_for([&] { return hellos >= 3; });
|
||||
std::this_thread::sleep_for(20ms);
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( hellos == 3 );
|
||||
}
|
||||
done = true;
|
||||
wait_for([&] { return hellos >= 4; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( hellos == 4 );
|
||||
}
|
||||
|
||||
// Now try injecting external commands
|
||||
done = false;
|
||||
hellos = 0;
|
||||
client.send(c, "public.hello");
|
||||
wait_for([&] { return hellos >= 1; });
|
||||
server.inject_task("public", "(injected)", "localhost", [&] { hellos += 10; while (!done) std::this_thread::sleep_for(10ms); });
|
||||
wait_for([&] { return hellos >= 11; });
|
||||
client.send(c, "public.hello");
|
||||
wait_for([&] { return hellos >= 12; });
|
||||
server.inject_task("public", "(injected)", "localhost", [&] { hellos += 10; while (!done) std::this_thread::sleep_for(10ms); });
|
||||
server.inject_task("public", "(injected)", "localhost", [&] { hellos += 10; while (!done) std::this_thread::sleep_for(10ms); });
|
||||
server.inject_task("public", "(injected)", "localhost", [&] { hellos += 10; while (!done) std::this_thread::sleep_for(10ms); });
|
||||
wait_for([&] { return hellos >= 12; });
|
||||
std::this_thread::sleep_for(20ms);
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( hellos == 12 );
|
||||
}
|
||||
done = true;
|
||||
wait_for([&] { return hellos >= 42; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( hellos == 42 );
|
||||
}
|
||||
}
|
611
tests/test_pubsub.cpp
Normal file
611
tests/test_pubsub.cpp
Normal file
|
@ -0,0 +1,611 @@
|
|||
#include "common.h"
|
||||
#include "oxenmq/pubsub.h"
|
||||
|
||||
#include <oxenc/hex.h>
|
||||
|
||||
using namespace oxenmq;
|
||||
using namespace std::chrono_literals;
|
||||
|
||||
TEST_CASE("sub OK", "[pubsub]") {
|
||||
std::string listen = random_localhost();
|
||||
OxenMQ server{
|
||||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
};
|
||||
server.listen_curve(listen);
|
||||
|
||||
Subscription<> greetings{"greetings"};
|
||||
|
||||
std::atomic<bool> is_new{false};
|
||||
server.add_category("public", Access{AuthLevel::none});
|
||||
server.add_request_command("public", "greetings", [&](Message& m) {
|
||||
is_new = greetings.subscribe(m.conn);
|
||||
m.send_reply("OK");
|
||||
});
|
||||
server.start();
|
||||
|
||||
OxenMQ client(
|
||||
[](LogLevel, const char* file, int line, std::string msg) { std::cerr << file << ":" << line << " --C-- " << msg << "\n"; }
|
||||
);
|
||||
|
||||
std::atomic<int> reply_count{0};
|
||||
client.add_category("notify", Access{AuthLevel::none});
|
||||
client.add_command("notify", "greetings", [&](Message& m) {
|
||||
const auto& data = m.data;
|
||||
if (!data.size())
|
||||
{
|
||||
std::cerr << "client received public.greetings with empty data\n";
|
||||
return;
|
||||
}
|
||||
if (data[0] == "hello")
|
||||
reply_count++;
|
||||
});
|
||||
|
||||
client.start();
|
||||
|
||||
std::atomic<bool> connected{false}, failed{false};
|
||||
std::string pubkey;
|
||||
|
||||
auto c = client.connect_remote(address{listen, server.get_pubkey()},
|
||||
[&](auto conn) { pubkey = conn.pubkey(); connected = true; },
|
||||
[&](auto, auto) { failed = true; });
|
||||
|
||||
wait_for([&] { return connected || failed; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( connected );
|
||||
REQUIRE_FALSE( failed );
|
||||
REQUIRE( oxenc::to_hex(pubkey) == oxenc::to_hex(server.get_pubkey()) );
|
||||
}
|
||||
|
||||
std::atomic<bool> got_reply{false};
|
||||
bool success;
|
||||
std::vector<std::string> data;
|
||||
client.request(c, "public.greetings", [&](bool ok, std::vector<std::string> data_) {
|
||||
got_reply = true;
|
||||
success = ok;
|
||||
data = std::move(data_);
|
||||
});
|
||||
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( got_reply.load() );
|
||||
REQUIRE( success );
|
||||
REQUIRE( data == std::vector<std::string>{{"OK"}} );
|
||||
}
|
||||
|
||||
greetings.publish([&](auto& conn) {
|
||||
server.send(conn, "notify.greetings", "hello");
|
||||
});
|
||||
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( reply_count == 1 );
|
||||
}
|
||||
|
||||
greetings.publish([&](auto& conn) {
|
||||
server.send(conn, "notify.greetings", "hello");
|
||||
});
|
||||
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( reply_count == 2 );
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
TEST_CASE("user data", "[pubsub]") {
|
||||
std::string listen = random_localhost();
|
||||
OxenMQ server{
|
||||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
};
|
||||
server.listen_curve(listen);
|
||||
|
||||
Subscription<std::string> greetings{"greetings"};
|
||||
|
||||
std::atomic<bool> is_new{false};
|
||||
server.add_category("public", Access{AuthLevel::none});
|
||||
server.add_request_command("public", "greetings", [&](Message& m) {
|
||||
is_new = greetings.subscribe(m.conn, std::string{m.data[0]});
|
||||
m.send_reply("OK");
|
||||
});
|
||||
server.start();
|
||||
|
||||
OxenMQ client(
|
||||
[](LogLevel, const char* file, int line, std::string msg) { std::cerr << file << ":" << line << " --C-- " << msg << "\n"; }
|
||||
);
|
||||
|
||||
std::string response{"foo"};
|
||||
std::atomic<int> reply_count{0};
|
||||
std::atomic<int> foo_count{0};
|
||||
client.add_category("notify", Access{AuthLevel::none});
|
||||
client.add_command("notify", "greetings", [&](Message& m) {
|
||||
const auto& data = m.data;
|
||||
if (!data.size())
|
||||
{
|
||||
std::cerr << "client received public.greetings with empty data\n";
|
||||
return;
|
||||
}
|
||||
if (data[0] == response)
|
||||
reply_count++;
|
||||
if (data[0] == "foo")
|
||||
foo_count++;
|
||||
});
|
||||
|
||||
client.start();
|
||||
|
||||
std::atomic<bool> connected{false}, failed{false};
|
||||
std::string pubkey;
|
||||
|
||||
auto c = client.connect_remote(address{listen, server.get_pubkey()},
|
||||
[&](auto conn) { pubkey = conn.pubkey(); connected = true; },
|
||||
[&](auto, auto) { failed = true; });
|
||||
|
||||
wait_for([&] { return connected || failed; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( connected );
|
||||
REQUIRE_FALSE( failed );
|
||||
REQUIRE( oxenc::to_hex(pubkey) == oxenc::to_hex(server.get_pubkey()) );
|
||||
}
|
||||
|
||||
std::atomic<bool> got_reply{false};
|
||||
std::atomic<bool> success;
|
||||
std::vector<std::string> data;
|
||||
client.request(c, "public.greetings", [&](bool ok, std::vector<std::string> data_) {
|
||||
got_reply = true;
|
||||
success = ok;
|
||||
data = std::move(data_);
|
||||
}, response);
|
||||
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( got_reply.load() );
|
||||
REQUIRE( success );
|
||||
REQUIRE( is_new );
|
||||
REQUIRE( data == std::vector<std::string>{{"OK"}} );
|
||||
}
|
||||
|
||||
got_reply = false;
|
||||
success = false;
|
||||
client.request(c, "public.greetings", [&](bool ok, std::vector<std::string> data_) {
|
||||
got_reply = true;
|
||||
success = ok;
|
||||
data = std::move(data_);
|
||||
}, response);
|
||||
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( got_reply.load() );
|
||||
REQUIRE( success );
|
||||
REQUIRE_FALSE( is_new );
|
||||
REQUIRE( data == std::vector<std::string>{{"OK"}} );
|
||||
}
|
||||
|
||||
greetings.publish([&](auto& conn, std::string user) {
|
||||
server.send(conn, "notify.greetings", user);
|
||||
});
|
||||
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( reply_count == 1 );
|
||||
REQUIRE( foo_count == 1 );
|
||||
}
|
||||
|
||||
got_reply = false;
|
||||
success = false;
|
||||
response = "bar";
|
||||
client.request(c, "public.greetings", [&](bool ok, std::vector<std::string> data_) {
|
||||
got_reply = true;
|
||||
success = ok;
|
||||
data = std::move(data_);
|
||||
}, response);
|
||||
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( got_reply.load() );
|
||||
REQUIRE( success );
|
||||
REQUIRE( is_new );
|
||||
REQUIRE( data == std::vector<std::string>{{"OK"}} );
|
||||
}
|
||||
|
||||
greetings.publish([&](auto& conn, std::string user) {
|
||||
server.send(conn, "notify.greetings", user);
|
||||
});
|
||||
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( reply_count == 2 );
|
||||
REQUIRE( foo_count == 1 );
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
TEST_CASE("unsubscribe", "[pubsub]") {
|
||||
std::string listen = random_localhost();
|
||||
OxenMQ server{
|
||||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
};
|
||||
server.listen_curve(listen);
|
||||
|
||||
Subscription<> greetings{"greetings"};
|
||||
|
||||
std::atomic<bool> was_subbed{false};
|
||||
server.add_category("public", Access{AuthLevel::none});
|
||||
server.add_request_command("public", "greetings", [&](Message& m) {
|
||||
greetings.subscribe(m.conn);
|
||||
m.send_reply("OK");
|
||||
});
|
||||
server.add_request_command("public", "goodbye", [&](Message& m) {
|
||||
was_subbed = greetings.unsubscribe(m.conn);
|
||||
m.send_reply("OK");
|
||||
});
|
||||
server.start();
|
||||
|
||||
OxenMQ client(
|
||||
[](LogLevel, const char* file, int line, std::string msg) { std::cerr << file << ":" << line << " --C-- " << msg << "\n"; }
|
||||
);
|
||||
|
||||
std::atomic<int> reply_count{0};
|
||||
client.add_category("notify", Access{AuthLevel::none});
|
||||
client.add_command("notify", "greetings", [&](Message& m) {
|
||||
const auto& data = m.data;
|
||||
if (!data.size())
|
||||
{
|
||||
std::cerr << "client received public.greetings with empty data\n";
|
||||
return;
|
||||
}
|
||||
if (data[0] == "hello")
|
||||
reply_count++;
|
||||
});
|
||||
|
||||
client.start();
|
||||
|
||||
std::atomic<bool> connected{false}, failed{false};
|
||||
std::string pubkey;
|
||||
|
||||
auto c = client.connect_remote(address{listen, server.get_pubkey()},
|
||||
[&](auto conn) { pubkey = conn.pubkey(); connected = true; },
|
||||
[&](auto, auto) { failed = true; });
|
||||
|
||||
wait_for([&] { return connected || failed; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( connected );
|
||||
REQUIRE_FALSE( failed );
|
||||
REQUIRE( oxenc::to_hex(pubkey) == oxenc::to_hex(server.get_pubkey()) );
|
||||
}
|
||||
|
||||
std::atomic<bool> got_reply{false};
|
||||
std::atomic<bool> success;
|
||||
std::vector<std::string> data;
|
||||
client.request(c, "public.greetings", [&](bool ok, std::vector<std::string> data_) {
|
||||
got_reply = true;
|
||||
success = ok;
|
||||
data = std::move(data_);
|
||||
});
|
||||
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( got_reply.load() );
|
||||
REQUIRE( success );
|
||||
REQUIRE( data == std::vector<std::string>{{"OK"}} );
|
||||
}
|
||||
|
||||
greetings.publish([&](auto& conn) {
|
||||
server.send(conn, "notify.greetings", "hello");
|
||||
});
|
||||
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( reply_count == 1 );
|
||||
}
|
||||
|
||||
got_reply = false;
|
||||
success = false;
|
||||
client.request(c, "public.goodbye", [&](bool ok, std::vector<std::string> data_) {
|
||||
got_reply = true;
|
||||
success = ok;
|
||||
data = std::move(data_);
|
||||
});
|
||||
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( got_reply.load() );
|
||||
REQUIRE( success );
|
||||
REQUIRE( data == std::vector<std::string>{{"OK"}} );
|
||||
REQUIRE( was_subbed );
|
||||
}
|
||||
|
||||
greetings.publish([&](auto& conn) {
|
||||
server.send(conn, "notify.greetings", "hello");
|
||||
});
|
||||
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( reply_count == 1 );
|
||||
}
|
||||
|
||||
got_reply = false;
|
||||
success = false;
|
||||
client.request(c, "public.goodbye", [&](bool ok, std::vector<std::string> data_) {
|
||||
got_reply = true;
|
||||
success = ok;
|
||||
data = std::move(data_);
|
||||
});
|
||||
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( got_reply.load() );
|
||||
REQUIRE( success );
|
||||
REQUIRE( data == std::vector<std::string>{{"OK"}} );
|
||||
REQUIRE( was_subbed == false);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
TEST_CASE("expire", "[pubsub]") {
|
||||
std::string listen = random_localhost();
|
||||
OxenMQ server{
|
||||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
};
|
||||
server.listen_curve(listen);
|
||||
|
||||
Subscription<> greetings{"greetings", 250ms};
|
||||
|
||||
std::atomic<bool> was_subbed{false};
|
||||
server.add_category("public", Access{AuthLevel::none});
|
||||
server.add_request_command("public", "greetings", [&](Message& m) {
|
||||
greetings.subscribe(m.conn);
|
||||
m.send_reply("OK");
|
||||
});
|
||||
server.add_request_command("public", "goodbye", [&](Message& m) {
|
||||
was_subbed = greetings.unsubscribe(m.conn);
|
||||
m.send_reply("OK");
|
||||
});
|
||||
server.start();
|
||||
|
||||
OxenMQ client(
|
||||
[](LogLevel, const char* file, int line, std::string msg) { std::cerr << file << ":" << line << " --C-- " << msg << "\n"; }
|
||||
);
|
||||
|
||||
std::atomic<int> reply_count{0};
|
||||
client.add_category("notify", Access{AuthLevel::none});
|
||||
client.add_command("notify", "greetings", [&](Message& m) {
|
||||
const auto& data = m.data;
|
||||
if (!data.size())
|
||||
{
|
||||
std::cerr << "client received public.greetings with empty data\n";
|
||||
return;
|
||||
}
|
||||
if (data[0] == "hello")
|
||||
reply_count++;
|
||||
});
|
||||
|
||||
client.start();
|
||||
|
||||
std::atomic<bool> connected{false}, failed{false};
|
||||
std::string pubkey;
|
||||
|
||||
auto c = client.connect_remote(address{listen, server.get_pubkey()},
|
||||
[&](auto conn) { pubkey = conn.pubkey(); connected = true; },
|
||||
[&](auto, auto) { failed = true; });
|
||||
|
||||
wait_for([&] { return connected || failed; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( connected );
|
||||
REQUIRE_FALSE( failed );
|
||||
REQUIRE( oxenc::to_hex(pubkey) == oxenc::to_hex(server.get_pubkey()) );
|
||||
}
|
||||
|
||||
std::atomic<bool> got_reply{false};
|
||||
bool success;
|
||||
std::vector<std::string> data;
|
||||
client.request(c, "public.greetings", [&](bool ok, std::vector<std::string> data_) {
|
||||
got_reply = true;
|
||||
success = ok;
|
||||
data = std::move(data_);
|
||||
});
|
||||
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( got_reply.load() );
|
||||
REQUIRE( success );
|
||||
REQUIRE( data == std::vector<std::string>{{"OK"}} );
|
||||
}
|
||||
|
||||
// should be expired by now
|
||||
std::this_thread::sleep_for(500ms);
|
||||
|
||||
greetings.remove_expired();
|
||||
|
||||
got_reply = false;
|
||||
success = false;
|
||||
client.request(c, "public.goodbye", [&](bool ok, std::vector<std::string> data_) {
|
||||
got_reply = true;
|
||||
success = ok;
|
||||
data = std::move(data_);
|
||||
});
|
||||
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( got_reply.load() );
|
||||
REQUIRE( success );
|
||||
REQUIRE( data == std::vector<std::string>{{"OK"}} );
|
||||
REQUIRE( was_subbed == false);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
TEST_CASE("multiple subs", "[pubsub]") {
|
||||
std::string listen = random_localhost();
|
||||
OxenMQ server{
|
||||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
};
|
||||
server.listen_curve(listen);
|
||||
|
||||
Subscription<> greetings{"greetings"};
|
||||
|
||||
std::atomic<bool> is_new{false};
|
||||
server.add_category("public", Access{AuthLevel::none});
|
||||
server.add_request_command("public", "greetings", [&](Message& m) {
|
||||
is_new = greetings.subscribe(m.conn);
|
||||
m.send_reply("OK");
|
||||
});
|
||||
server.start();
|
||||
|
||||
/* client 1 */
|
||||
std::atomic<int> reply_count_c1{0};
|
||||
std::atomic<bool> connected_c1{false}, failed_c1{false};
|
||||
std::atomic<bool> got_reply_c1{false};
|
||||
bool success_c1;
|
||||
std::vector<std::string> data_c1;
|
||||
std::string pubkey_c1;
|
||||
OxenMQ client1(
|
||||
[](LogLevel, const char* file, int line, std::string msg) { std::cerr << file << ":" << line << " --C-- " << msg << "\n"; }
|
||||
);
|
||||
|
||||
client1.add_category("notify", Access{AuthLevel::none});
|
||||
client1.add_command("notify", "greetings", [&](Message& m) {
|
||||
const auto& data = m.data;
|
||||
if (!data.size())
|
||||
{
|
||||
std::cerr << "client received public.greetings with empty data\n";
|
||||
return;
|
||||
}
|
||||
if (data[0] == "hello")
|
||||
reply_count_c1++;
|
||||
});
|
||||
|
||||
client1.start();
|
||||
|
||||
auto c1 = client1.connect_remote(address{listen, server.get_pubkey()},
|
||||
[&](auto conn) { pubkey_c1 = conn.pubkey(); connected_c1 = true; },
|
||||
[&](auto, auto) { failed_c1 = true; });
|
||||
|
||||
wait_for([&] { return connected_c1 || failed_c1; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( connected_c1 );
|
||||
REQUIRE_FALSE( failed_c1 );
|
||||
REQUIRE( oxenc::to_hex(pubkey_c1) == oxenc::to_hex(server.get_pubkey()) );
|
||||
}
|
||||
|
||||
client1.request(c1, "public.greetings", [&](bool ok, std::vector<std::string> data_) {
|
||||
got_reply_c1 = true;
|
||||
success_c1 = ok;
|
||||
data_c1 = std::move(data_);
|
||||
});
|
||||
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( got_reply_c1.load() );
|
||||
REQUIRE( success_c1 );
|
||||
REQUIRE( data_c1 == std::vector<std::string>{{"OK"}} );
|
||||
}
|
||||
/* end client 1 */
|
||||
|
||||
/* client 2 */
|
||||
std::atomic<int> reply_count_c2{0};
|
||||
std::atomic<bool> connected_c2{false}, failed_c2{false};
|
||||
std::atomic<bool> got_reply_c2{false};
|
||||
bool success_c2;
|
||||
std::vector<std::string> data_c2;
|
||||
std::string pubkey_c2;
|
||||
OxenMQ client2(
|
||||
[](LogLevel, const char* file, int line, std::string msg) { std::cerr << file << ":" << line << " --C-- " << msg << "\n"; }
|
||||
);
|
||||
|
||||
client2.add_category("notify", Access{AuthLevel::none});
|
||||
client2.add_command("notify", "greetings", [&](Message& m) {
|
||||
const auto& data = m.data;
|
||||
if (!data.size())
|
||||
{
|
||||
std::cerr << "client received public.greetings with empty data\n";
|
||||
return;
|
||||
}
|
||||
if (data[0] == "hello")
|
||||
reply_count_c2++;
|
||||
});
|
||||
|
||||
client2.start();
|
||||
|
||||
auto c2 = client2.connect_remote(address{listen, server.get_pubkey()},
|
||||
[&](auto conn) { pubkey_c2 = conn.pubkey(); connected_c2 = true; },
|
||||
[&](auto, auto) { failed_c2 = true; });
|
||||
|
||||
wait_for([&] { return connected_c2 || failed_c2; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( connected_c2 );
|
||||
REQUIRE_FALSE( failed_c2 );
|
||||
REQUIRE( oxenc::to_hex(pubkey_c2) == oxenc::to_hex(server.get_pubkey()) );
|
||||
}
|
||||
|
||||
client2.request(c2, "public.greetings", [&](bool ok, std::vector<std::string> data_) {
|
||||
got_reply_c2 = true;
|
||||
success_c2 = ok;
|
||||
data_c2 = std::move(data_);
|
||||
});
|
||||
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( got_reply_c2.load() );
|
||||
REQUIRE( success_c2 );
|
||||
REQUIRE( data_c2 == std::vector<std::string>{{"OK"}} );
|
||||
}
|
||||
/* end client2 */
|
||||
|
||||
greetings.publish([&](auto& conn) {
|
||||
server.send(conn, "notify.greetings", "hello");
|
||||
});
|
||||
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( reply_count_c1 == 1 );
|
||||
REQUIRE( reply_count_c2 == 1 );
|
||||
}
|
||||
|
||||
greetings.publish([&](auto& conn) {
|
||||
server.send(conn, "notify.greetings", "hello");
|
||||
});
|
||||
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( reply_count_c1 == 2 );
|
||||
REQUIRE( reply_count_c2 == 2 );
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
// vim:sw=4:et
|
|
@ -1,12 +1,11 @@
|
|||
#include "common.h"
|
||||
#include <future>
|
||||
#include <lokimq/hex.h>
|
||||
#include <oxenc/hex.h>
|
||||
|
||||
using namespace lokimq;
|
||||
using namespace oxenmq;
|
||||
|
||||
TEST_CASE("basic requests", "[requests]") {
|
||||
std::string listen = "tcp://127.0.0.1:5678";
|
||||
LokiMQ server{
|
||||
std::string listen = random_localhost();
|
||||
OxenMQ server{
|
||||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
|
@ -21,7 +20,7 @@ TEST_CASE("basic requests", "[requests]") {
|
|||
});
|
||||
server.start();
|
||||
|
||||
LokiMQ client(
|
||||
OxenMQ client(
|
||||
[](LogLevel, const char* file, int line, std::string msg) { std::cerr << file << ":" << line << " --C-- " << msg << "\n"; }
|
||||
);
|
||||
//client.log_level(LogLevel::trace);
|
||||
|
@ -31,23 +30,16 @@ TEST_CASE("basic requests", "[requests]") {
|
|||
std::atomic<bool> connected{false}, failed{false};
|
||||
std::string pubkey;
|
||||
|
||||
auto c = client.connect_remote(listen,
|
||||
auto c = client.connect_remote(address{listen, server.get_pubkey()},
|
||||
[&](auto conn) { pubkey = conn.pubkey(); connected = true; },
|
||||
[&](auto, auto) { failed = true; },
|
||||
server.get_pubkey());
|
||||
[&](auto, auto) { failed = true; });
|
||||
|
||||
int i;
|
||||
for (i = 0; i < 5; i++) {
|
||||
if (connected.load())
|
||||
break;
|
||||
std::this_thread::sleep_for(50ms);
|
||||
}
|
||||
wait_for([&] { return connected || failed; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( connected.load() );
|
||||
REQUIRE( !failed.load() );
|
||||
REQUIRE( i <= 1 );
|
||||
REQUIRE( to_hex(pubkey) == to_hex(server.get_pubkey()) );
|
||||
REQUIRE( connected );
|
||||
REQUIRE_FALSE( failed );
|
||||
REQUIRE( oxenc::to_hex(pubkey) == oxenc::to_hex(server.get_pubkey()) );
|
||||
}
|
||||
|
||||
std::atomic<bool> got_reply{false};
|
||||
|
@ -59,16 +51,18 @@ TEST_CASE("basic requests", "[requests]") {
|
|||
data = std::move(data_);
|
||||
});
|
||||
|
||||
std::this_thread::sleep_for(50ms);
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( got_reply.load() );
|
||||
REQUIRE( success );
|
||||
REQUIRE( data == std::vector<std::string>{{"123"}} );
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( got_reply.load() );
|
||||
REQUIRE( success );
|
||||
REQUIRE( data == std::vector<std::string>{{"123"}} );
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("request from server to client", "[requests]") {
|
||||
std::string listen = "tcp://127.0.0.1:5678";
|
||||
LokiMQ server{
|
||||
std::string listen = random_localhost();
|
||||
OxenMQ server{
|
||||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
|
@ -83,7 +77,7 @@ TEST_CASE("request from server to client", "[requests]") {
|
|||
});
|
||||
server.start();
|
||||
|
||||
LokiMQ client(
|
||||
OxenMQ client(
|
||||
[](LogLevel, const char* file, int line, std::string msg) { std::cerr << file << ":" << line << " --C-- " << msg << "\n"; }
|
||||
);
|
||||
//client.log_level(LogLevel::trace);
|
||||
|
@ -93,10 +87,9 @@ TEST_CASE("request from server to client", "[requests]") {
|
|||
std::atomic<bool> connected{false}, failed{false};
|
||||
std::string pubkey;
|
||||
|
||||
auto c = client.connect_remote(listen,
|
||||
auto c = client.connect_remote(address{listen, server.get_pubkey()},
|
||||
[&](auto conn) { pubkey = conn.pubkey(); connected = true; },
|
||||
[&](auto, auto) { failed = true; },
|
||||
server.get_pubkey());
|
||||
[&](auto, auto) { failed = true; });
|
||||
|
||||
int i;
|
||||
for (i = 0; i < 5; i++) {
|
||||
|
@ -109,7 +102,7 @@ TEST_CASE("request from server to client", "[requests]") {
|
|||
REQUIRE( connected.load() );
|
||||
REQUIRE( !failed.load() );
|
||||
REQUIRE( i <= 1 );
|
||||
REQUIRE( to_hex(pubkey) == to_hex(server.get_pubkey()) );
|
||||
REQUIRE( oxenc::to_hex(pubkey) == oxenc::to_hex(server.get_pubkey()) );
|
||||
}
|
||||
|
||||
std::atomic<bool> got_reply{false};
|
||||
|
@ -131,8 +124,8 @@ TEST_CASE("request from server to client", "[requests]") {
|
|||
}
|
||||
|
||||
TEST_CASE("request timeouts", "[requests][timeout]") {
|
||||
std::string listen = "tcp://127.0.0.1:5678";
|
||||
LokiMQ server{
|
||||
std::string listen = random_localhost();
|
||||
OxenMQ server{
|
||||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
|
@ -145,7 +138,7 @@ TEST_CASE("request timeouts", "[requests][timeout]") {
|
|||
server.add_request_command("public", "blackhole", [&](Message& m) { /* doesn't reply */ });
|
||||
server.start();
|
||||
|
||||
LokiMQ client(
|
||||
OxenMQ client(
|
||||
[](LogLevel, const char* file, int line, std::string msg) { std::cerr << file << ":" << line << " --C-- " << msg << "\n"; }
|
||||
);
|
||||
//client.log_level(LogLevel::trace);
|
||||
|
@ -156,21 +149,15 @@ TEST_CASE("request timeouts", "[requests][timeout]") {
|
|||
std::atomic<bool> connected{false}, failed{false};
|
||||
std::string pubkey;
|
||||
|
||||
auto c = client.connect_remote(listen,
|
||||
auto c = client.connect_remote(address{listen, server.get_pubkey()},
|
||||
[&](auto conn) { pubkey = conn.pubkey(); connected = true; },
|
||||
[&](auto, auto) { failed = true; },
|
||||
server.get_pubkey());
|
||||
[&](auto, auto) { failed = true; });
|
||||
|
||||
int i;
|
||||
for (i = 0; i < 5; i++) {
|
||||
if (connected.load())
|
||||
break;
|
||||
std::this_thread::sleep_for(50ms);
|
||||
}
|
||||
REQUIRE( connected.load() );
|
||||
REQUIRE( !failed.load() );
|
||||
REQUIRE( i <= 1 );
|
||||
REQUIRE( to_hex(pubkey) == to_hex(server.get_pubkey()) );
|
||||
wait_for([&] { return connected || failed; });
|
||||
|
||||
REQUIRE( connected );
|
||||
REQUIRE_FALSE( failed );
|
||||
REQUIRE( oxenc::to_hex(pubkey) == oxenc::to_hex(server.get_pubkey()) );
|
||||
|
||||
std::atomic<bool> got_triggered{false};
|
||||
bool success;
|
||||
|
@ -180,7 +167,7 @@ TEST_CASE("request timeouts", "[requests][timeout]") {
|
|||
success = ok;
|
||||
data = std::move(data_);
|
||||
},
|
||||
lokimq::send_option::request_timeout{30ms}
|
||||
oxenmq::send_option::request_timeout{10ms}
|
||||
);
|
||||
|
||||
std::atomic<bool> got_triggered2{false};
|
||||
|
@ -189,13 +176,13 @@ TEST_CASE("request timeouts", "[requests][timeout]") {
|
|||
success = ok;
|
||||
data = std::move(data_);
|
||||
},
|
||||
lokimq::send_option::request_timeout{100ms}
|
||||
oxenmq::send_option::request_timeout{200ms}
|
||||
);
|
||||
|
||||
std::this_thread::sleep_for(50ms);
|
||||
REQUIRE( got_triggered.load() );
|
||||
std::this_thread::sleep_for(100ms);
|
||||
REQUIRE( got_triggered );
|
||||
REQUIRE_FALSE( got_triggered2 );
|
||||
REQUIRE_FALSE( success );
|
||||
REQUIRE( data.size() == 0 );
|
||||
REQUIRE( data == std::vector<std::string>{{"TIMEOUT"}} );
|
||||
|
||||
REQUIRE_FALSE( got_triggered2.load() );
|
||||
}
|
||||
|
|
43
tests/test_socket_limit.cpp
Normal file
43
tests/test_socket_limit.cpp
Normal file
|
@ -0,0 +1,43 @@
|
|||
#include "common.h"
|
||||
#include <oxenc/hex.h>
|
||||
|
||||
using namespace oxenmq;
|
||||
|
||||
TEST_CASE("zmq socket limit", "[zmq][socket-limit]") {
|
||||
// Make sure setting .MAX_SOCKETS works as expected. (This test was added when a bug was fixed
|
||||
// that was causing it not to be applied).
|
||||
std::string listen = random_localhost();
|
||||
OxenMQ server{
|
||||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
};
|
||||
server.listen_plain(listen);
|
||||
server.start();
|
||||
|
||||
std::atomic<int> failed = 0, good = 0, failed_toomany = 0;
|
||||
OxenMQ client;
|
||||
client.MAX_SOCKETS = 15;
|
||||
client.start();
|
||||
|
||||
std::vector<ConnectionID> conns;
|
||||
address server_addr{listen};
|
||||
for (int i = 0; i < 16; i++)
|
||||
client.connect_remote(server_addr,
|
||||
[&](auto) { good++; },
|
||||
[&](auto cid, auto msg) {
|
||||
if (msg == "connect() failed: Too many open files")
|
||||
failed_toomany++;
|
||||
else
|
||||
failed++;
|
||||
});
|
||||
|
||||
|
||||
wait_for([&] { return good > 0 && failed_toomany > 0; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( good > 0 );
|
||||
REQUIRE( failed == 0 );
|
||||
REQUIRE( failed_toomany > 0 );
|
||||
}
|
||||
}
|
|
@ -1,187 +0,0 @@
|
|||
#include <catch2/catch.hpp>
|
||||
#include "lokimq/string_view.h"
|
||||
#include <future>
|
||||
|
||||
using namespace lokimq;
|
||||
|
||||
using namespace std::literals;
|
||||
|
||||
TEST_CASE("string view", "[string_view]") {
|
||||
std::string foo = "abc 123 xyz";
|
||||
string_view f1{foo};
|
||||
string_view f2{"def 789 uvw"};
|
||||
string_view f3{"nu\0ll", 5};
|
||||
|
||||
REQUIRE( f1 == "abc 123 xyz" );
|
||||
REQUIRE( f2 == "def 789 uvw" );
|
||||
REQUIRE( f3.size() == 5 );
|
||||
REQUIRE( f3 == std::string{"nu\0ll", 5} );
|
||||
REQUIRE( f3 != "nu" );
|
||||
REQUIRE( f3.data() == "nu"s );
|
||||
REQUIRE( string_view(f3) == f3 );
|
||||
|
||||
auto f4 = f3;
|
||||
REQUIRE( f4 == f3 );
|
||||
f4 = f2;
|
||||
REQUIRE( f4 == "def 789 uvw" );
|
||||
|
||||
REQUIRE( f1.size() == 11 );
|
||||
REQUIRE( f3.length() == 5 );
|
||||
|
||||
string_view f5{""};
|
||||
REQUIRE( !f3.empty() );
|
||||
REQUIRE( f5.empty() );
|
||||
|
||||
REQUIRE( f1[5] == '2' );
|
||||
size_t i = 0;
|
||||
for (auto c : f3)
|
||||
REQUIRE(c == f3[i++]);
|
||||
|
||||
std::string backwards;
|
||||
for (auto it = std::rbegin(f2); it != f2.crend(); ++it)
|
||||
backwards += *it;
|
||||
|
||||
REQUIRE( backwards == "wvu 987 fed" );
|
||||
|
||||
REQUIRE( f1.at(10) == 'z' );
|
||||
REQUIRE_THROWS_AS( f1.at(15), std::out_of_range );
|
||||
REQUIRE_THROWS_AS( f1.at(11), std::out_of_range );
|
||||
|
||||
f4 = f1;
|
||||
f4.remove_prefix(2);
|
||||
REQUIRE( f4 == "c 123 xyz" );
|
||||
f4.remove_prefix(2);
|
||||
f4.remove_suffix(4);
|
||||
REQUIRE( f4 == "123" );
|
||||
f4.remove_prefix(1);
|
||||
REQUIRE( f4 == "23" );
|
||||
REQUIRE( f1 == "abc 123 xyz" );
|
||||
f4.swap(f1);
|
||||
REQUIRE( f1 == "23" );
|
||||
REQUIRE( f4 == "abc 123 xyz" );
|
||||
f1.remove_suffix(2);
|
||||
REQUIRE( f1.empty() );
|
||||
REQUIRE( f4 == "abc 123 xyz" );
|
||||
f1.swap(f4);
|
||||
REQUIRE( f4.empty() );
|
||||
REQUIRE( f1 == "abc 123 xyz" );
|
||||
|
||||
REQUIRE( f1.front() == 'a' );
|
||||
REQUIRE( f1.back() == 'z' );
|
||||
REQUIRE( f1.compare("abc") > 0 );
|
||||
REQUIRE( f1.compare("abd") < 0 );
|
||||
REQUIRE( f1.compare("abc 123 xyz") == 0 );
|
||||
REQUIRE( f1.compare("abc 123 xyza") < 0 );
|
||||
REQUIRE( f1.compare("abc 123 xy") > 0 );
|
||||
|
||||
std::string buf;
|
||||
buf.resize(5);
|
||||
f1.copy(&buf[0], 5, 2);
|
||||
REQUIRE( buf == "c 123" );
|
||||
buf.resize(100, 'X');
|
||||
REQUIRE( f1.copy(&buf[0], 100) == 11 );
|
||||
REQUIRE( buf.substr(0, 11) == f1 );
|
||||
REQUIRE( buf.substr(11) == std::string(89, 'X') );
|
||||
REQUIRE( f1.substr(4) == "123 xyz" );
|
||||
REQUIRE( f1.substr(4, 3) == "123" );
|
||||
REQUIRE_THROWS_AS( f1.substr(500, 3), std::out_of_range );
|
||||
REQUIRE( f1.substr(11, 2) == "" );
|
||||
REQUIRE( f1.substr(8, 500) == "xyz" );
|
||||
REQUIRE( f1.find("123") == 4 );
|
||||
REQUIRE( f1.find("abc") == 0 );
|
||||
REQUIRE( f1.find("xyz") == 8 );
|
||||
REQUIRE( f1.find("abc 123 xyz 7") == string_view::npos );
|
||||
REQUIRE( f1.find("23") == 5 );
|
||||
REQUIRE( f1.find("234") == string_view::npos );
|
||||
|
||||
string_view f6{"zz abc abcd abcde abcdef"};
|
||||
REQUIRE( f6.find("abc") == 3 );
|
||||
REQUIRE( f6.find("abc", 3) == 3 );
|
||||
REQUIRE( f6.find("abc", 4) == 7 );
|
||||
REQUIRE( f6.find("abc", 7) == 7 );
|
||||
REQUIRE( f6.find("abc", 8) == 12 );
|
||||
REQUIRE( f6.find("abc", 18) == 18 );
|
||||
REQUIRE( f6.find("abc", 19) == string_view::npos );
|
||||
REQUIRE( f6.find("abcd") == 7 );
|
||||
REQUIRE( f6.rfind("abc") == 18 );
|
||||
REQUIRE( f6.rfind("abcd") == 18 );
|
||||
REQUIRE( f6.rfind("bcd") == 19 );
|
||||
REQUIRE( f6.rfind("abc", 19) == 18 );
|
||||
REQUIRE( f6.rfind("abc", 18) == 18 );
|
||||
REQUIRE( f6.rfind("abc", 17) == 12 );
|
||||
REQUIRE( f6.rfind("abc", 17) == 12 );
|
||||
REQUIRE( f6.rfind("abc", 8) == 7 );
|
||||
REQUIRE( f6.rfind("abc", 7) == 7 );
|
||||
REQUIRE( f6.rfind("abc", 6) == 3 );
|
||||
REQUIRE( f6.rfind("abc", 3) == 3 );
|
||||
REQUIRE( f6.rfind("abc", 2) == string_view::npos );
|
||||
|
||||
REQUIRE( f6.find('a') == 3 );
|
||||
REQUIRE( f6.find('a', 17) == 18 );
|
||||
REQUIRE( f6.find('a', 20) == string_view::npos );
|
||||
|
||||
REQUIRE( f6.rfind('a') == 18 );
|
||||
REQUIRE( f6.rfind('a', 17) == 12 );
|
||||
REQUIRE( f6.rfind('a', 2) == string_view::npos );
|
||||
|
||||
string_view f7{"abc\0def", 7};
|
||||
REQUIRE( f7.find("c\0d", 0, 3) == 2 );
|
||||
REQUIRE( f7.find("c\0e", 0, 3) == string_view::npos );
|
||||
REQUIRE( f7.rfind("c\0d", string_view::npos, 3) == 2 );
|
||||
REQUIRE( f7.rfind("c\0e", 0, 3) == string_view::npos );
|
||||
|
||||
REQUIRE( f6.find_first_of("c789b") == 4 );
|
||||
REQUIRE( f6.find_first_of("c789") == 5 );
|
||||
REQUIRE( f2.find_first_of("c789b") == 4 );
|
||||
REQUIRE( f6.find_first_of("c789b", 6) == 8 );
|
||||
|
||||
REQUIRE( f6.find_last_of("c789b") == 20 );
|
||||
REQUIRE( f6.find_last_of("789b") == 19 );
|
||||
REQUIRE( f2.find_last_of("c789b") == 6 );
|
||||
REQUIRE( f6.find_last_of("c789b", 6) == 5 );
|
||||
REQUIRE( f6.find_last_of("c789b", 5) == 5 );
|
||||
REQUIRE( f6.find_last_of("c789b", 4) == 4 );
|
||||
REQUIRE( f6.find_last_of("c789b", 3) == string_view::npos );
|
||||
|
||||
REQUIRE( f2.find_first_of(f7) == 0 );
|
||||
REQUIRE( f3.find_first_of(f7) == 2 );
|
||||
REQUIRE( f3.find_first_of('\0') == 2 );
|
||||
REQUIRE( f3.find_first_of("jk\0", 0, 3) == 2 );
|
||||
|
||||
REQUIRE( f1.find_first_not_of("abc") == 3 );
|
||||
REQUIRE( f1.find_first_not_of("abc ", 3) == 4 );
|
||||
REQUIRE( f1.find_first_not_of(" 123", 3) == 8 );
|
||||
REQUIRE( f1.find_last_not_of("abc") == 10 );
|
||||
REQUIRE( f1.find_last_not_of("xyz") == 7 );
|
||||
REQUIRE( f1.find_last_not_of("xyz 321") == 2 );
|
||||
REQUIRE( f1.find_last_not_of("xay z1b2c3") == string_view::npos );
|
||||
REQUIRE( f6.find_last_not_of("def") == 20 );
|
||||
REQUIRE( f6.find_last_not_of("abcdef") == 17 );
|
||||
REQUIRE( f6.find_last_not_of("abcdef ") == 1 );
|
||||
REQUIRE( f6.find_first_not_of('z') == 2 );
|
||||
REQUIRE( f6.find_first_not_of("z ") == 3 );
|
||||
REQUIRE( f6.find_first_not_of("a ", 2) == 4 );
|
||||
REQUIRE( f6.find_last_not_of("abc ", 9) == 1 );
|
||||
|
||||
REQUIRE( string_view{"abc"} == string_view{"abc"} );
|
||||
REQUIRE_FALSE( string_view{"abc"} == string_view{"abd"} );
|
||||
REQUIRE_FALSE( string_view{"abc"} == string_view{"abcd"} );
|
||||
REQUIRE( string_view{"abc"} != string_view{"abd"} );
|
||||
REQUIRE( string_view{"abc"} != string_view{"abcd"} );
|
||||
REQUIRE( string_view{"abc"} < string_view{"abcd"} );
|
||||
REQUIRE( string_view{"abc"} < string_view{"abd"} );
|
||||
REQUIRE_FALSE( string_view{"abd"} < string_view{"abc"} );
|
||||
REQUIRE_FALSE( string_view{"abcd"} < string_view{"abc"} );
|
||||
REQUIRE_FALSE( string_view{"abc"} < string_view{"abc"} );
|
||||
REQUIRE( string_view{"abd"} > string_view{"abc"} );
|
||||
REQUIRE( string_view{"abcd"} > string_view{"abc"} );
|
||||
REQUIRE_FALSE( string_view{"abc"} > string_view{"abd"} );
|
||||
REQUIRE_FALSE( string_view{"abc"} > string_view{"abcd"} );
|
||||
REQUIRE_FALSE( string_view{"abc"} > string_view{"abc"} );
|
||||
REQUIRE( string_view{"abc"} <= string_view{"abcd"} );
|
||||
REQUIRE( string_view{"abc"} <= string_view{"abc"} );
|
||||
REQUIRE( string_view{"abc"} <= string_view{"abd"} );
|
||||
REQUIRE( string_view{"abd"} >= string_view{"abc"} );
|
||||
REQUIRE( string_view{"abc"} >= string_view{"abc"} );
|
||||
REQUIRE( string_view{"abcd"} >= string_view{"abc"} );
|
||||
}
|
165
tests/test_tagged_threads.cpp
Normal file
165
tests/test_tagged_threads.cpp
Normal file
|
@ -0,0 +1,165 @@
|
|||
#include "oxenmq/batch.h"
|
||||
#include "common.h"
|
||||
#include <future>
|
||||
|
||||
TEST_CASE("tagged thread start functions", "[tagged][start]") {
|
||||
oxenmq::OxenMQ omq{get_logger(""), LogLevel::trace};
|
||||
|
||||
omq.set_general_threads(2);
|
||||
omq.set_batch_threads(2);
|
||||
auto t_abc = omq.add_tagged_thread("abc");
|
||||
std::atomic<bool> start_called = false;
|
||||
auto t_def = omq.add_tagged_thread("def", [&] { start_called = true; });
|
||||
|
||||
std::this_thread::sleep_for(20ms);
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE_FALSE( start_called );
|
||||
}
|
||||
|
||||
omq.start();
|
||||
wait_for([&] { return start_called.load(); });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( start_called );
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("tagged threads quit-before-start", "[tagged][quit]") {
|
||||
auto omq = std::make_unique<oxenmq::OxenMQ>(get_logger(""), LogLevel::trace);
|
||||
auto t_abc = omq->add_tagged_thread("abc");
|
||||
REQUIRE_NOTHROW(omq.reset());
|
||||
}
|
||||
|
||||
TEST_CASE("batch jobs to tagged threads", "[tagged][batch]") {
|
||||
oxenmq::OxenMQ omq{get_logger(""), LogLevel::trace};
|
||||
|
||||
omq.set_general_threads(2);
|
||||
omq.set_batch_threads(2);
|
||||
std::thread::id id_abc, id_def;
|
||||
auto t_abc = omq.add_tagged_thread("abc", [&] { id_abc = std::this_thread::get_id(); });
|
||||
auto t_def = omq.add_tagged_thread("def", [&] { id_def = std::this_thread::get_id(); });
|
||||
omq.start();
|
||||
|
||||
std::atomic<bool> done = false;
|
||||
std::thread::id id;
|
||||
omq.job([&] { id = std::this_thread::get_id(); done = true; });
|
||||
wait_for([&] { return done.load(); });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( id != id_abc );
|
||||
REQUIRE( id != id_def );
|
||||
}
|
||||
|
||||
done = false;
|
||||
omq.job([&] { id = std::this_thread::get_id(); done = true; }, t_abc);
|
||||
wait_for([&] { return done.load(); });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( id == id_abc );
|
||||
}
|
||||
|
||||
done = false;
|
||||
omq.job([&] { id = std::this_thread::get_id(); done = true; }, t_def);
|
||||
wait_for([&] { return done.load(); });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( id == id_def );
|
||||
}
|
||||
|
||||
std::atomic<bool> sleep = true;
|
||||
auto sleeper = [&] { for (int i = 0; sleep && i < 10; i++) { std::this_thread::sleep_for(25ms); } };
|
||||
omq.job(sleeper);
|
||||
omq.job(sleeper);
|
||||
// This one should stall:
|
||||
std::atomic<bool> bad = false;
|
||||
omq.job([&] { bad = true; });
|
||||
|
||||
std::this_thread::sleep_for(50ms);
|
||||
|
||||
done = false;
|
||||
omq.job([&] { id = std::this_thread::get_id(); done = true; }, t_abc);
|
||||
wait_for([&] { return done.load(); });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( done.load() );
|
||||
REQUIRE_FALSE( bad.load() );
|
||||
}
|
||||
|
||||
done = false;
|
||||
// We can queue up a bunch of jobs which should all happen in order, and all on the abc thread.
|
||||
std::vector<int> v;
|
||||
for (int i = 0; i < 100; i++) {
|
||||
omq.job([&] { if (std::this_thread::get_id() == id_abc) v.push_back(v.size()); }, t_abc);
|
||||
}
|
||||
omq.job([&] { done = true; }, t_abc);
|
||||
wait_for([&] { return done.load(); });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( done.load() );
|
||||
REQUIRE_FALSE( bad.load() );
|
||||
REQUIRE( v.size() == 100 );
|
||||
for (int i = 0; i < 100; i++)
|
||||
REQUIRE( v[i] == i );
|
||||
}
|
||||
sleep = false;
|
||||
wait_for([&] { return bad.load(); });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( bad.load() );
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("batch job completion on tagged threads", "[tagged][batch-completion]") {
|
||||
oxenmq::OxenMQ omq{get_logger(""), LogLevel::trace};
|
||||
|
||||
omq.set_general_threads(4);
|
||||
omq.set_batch_threads(4);
|
||||
std::thread::id id_abc;
|
||||
auto t_abc = omq.add_tagged_thread("abc", [&] { id_abc = std::this_thread::get_id(); });
|
||||
omq.start();
|
||||
|
||||
oxenmq::Batch<int> batch;
|
||||
for (int i = 1; i < 10; i++)
|
||||
batch.add_job([i, &id_abc]() { if (std::this_thread::get_id() == id_abc) return 0; return i; });
|
||||
|
||||
std::atomic<int> result_sum = -1;
|
||||
batch.completion([&](auto result) {
|
||||
int sum = 0;
|
||||
for (auto& r : result)
|
||||
sum += r.get();
|
||||
result_sum = std::this_thread::get_id() == id_abc ? sum : -sum;
|
||||
}, t_abc);
|
||||
omq.batch(std::move(batch));
|
||||
wait_for([&] { return result_sum.load() != -1; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( result_sum == 45 );
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
TEST_CASE("timer job completion on tagged threads", "[tagged][timer]") {
|
||||
oxenmq::OxenMQ omq{get_logger(""), LogLevel::trace};
|
||||
|
||||
omq.set_general_threads(4);
|
||||
omq.set_batch_threads(4);
|
||||
|
||||
std::thread::id id_abc;
|
||||
auto t_abc = omq.add_tagged_thread("abc", [&] { id_abc = std::this_thread::get_id(); });
|
||||
omq.start();
|
||||
|
||||
std::atomic<int> ticks = 0;
|
||||
std::atomic<int> abc_ticks = 0;
|
||||
omq.add_timer([&] { ticks++; }, 10ms);
|
||||
omq.add_timer([&] { if (std::this_thread::get_id() == id_abc) abc_ticks++; }, 10ms, true, t_abc);
|
||||
|
||||
wait_for([&] { return ticks.load() > 2 && abc_ticks > 2; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( ticks.load() > 2 );
|
||||
REQUIRE( abc_ticks.load() > 2 );
|
||||
}
|
||||
}
|
||||
|
||||
|
132
tests/test_timer.cpp
Normal file
132
tests/test_timer.cpp
Normal file
|
@ -0,0 +1,132 @@
|
|||
#include "oxenmq/oxenmq.h"
|
||||
#include "common.h"
|
||||
#include <chrono>
|
||||
#include <future>
|
||||
|
||||
TEST_CASE("timer test", "[timer][basic]") {
|
||||
oxenmq::OxenMQ omq{get_logger(""), LogLevel::trace};
|
||||
|
||||
omq.set_general_threads(1);
|
||||
omq.set_batch_threads(1);
|
||||
|
||||
std::atomic<int> ticks = 0;
|
||||
auto timer = omq.add_timer([&] { ticks++; }, 5ms);
|
||||
omq.start();
|
||||
auto start = std::chrono::steady_clock::now();
|
||||
wait_for([&] { return ticks.load() > 3; });
|
||||
{
|
||||
auto elapsed_ms = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - start).count();
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( ticks.load() > 3 );
|
||||
REQUIRE( elapsed_ms < 50 * TIME_DILATION );
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("timer squelch", "[timer][squelch]") {
|
||||
oxenmq::OxenMQ omq{get_logger(""), LogLevel::trace};
|
||||
|
||||
omq.set_general_threads(3);
|
||||
omq.set_batch_threads(3);
|
||||
|
||||
std::atomic<bool> first = true;
|
||||
std::atomic<bool> done = false;
|
||||
std::atomic<int> ticks = 0;
|
||||
|
||||
// Set up a timer with squelch on; the job shouldn't get rescheduled until the first call
|
||||
// finishes, by which point we set `done` and so should get exactly 1 tick.
|
||||
auto timer = omq.add_timer([&] {
|
||||
if (first.exchange(false)) {
|
||||
std::this_thread::sleep_for(30ms * TIME_DILATION);
|
||||
ticks++;
|
||||
done = true;
|
||||
} else if (!done) {
|
||||
ticks++;
|
||||
}
|
||||
}, 5ms * TIME_DILATION, true /* squelch */);
|
||||
omq.start();
|
||||
|
||||
wait_for([&] { return done.load(); });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( done.load() );
|
||||
REQUIRE( ticks.load() == 1 );
|
||||
}
|
||||
|
||||
// Start another timer with squelch *off*; the subsequent jobs should get scheduled even while
|
||||
// the first one blocks
|
||||
std::atomic<bool> first2 = true;
|
||||
std::atomic<bool> done2 = false;
|
||||
std::atomic<int> ticks2 = 0;
|
||||
auto timer2 = omq.add_timer([&] {
|
||||
if (first2.exchange(false)) {
|
||||
std::this_thread::sleep_for(40ms * TIME_DILATION);
|
||||
done2 = true;
|
||||
} else if (!done2) {
|
||||
ticks2++;
|
||||
}
|
||||
}, 5ms, false /* squelch */);
|
||||
|
||||
wait_for([&] { return done2.load(); });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( ticks2.load() > 2 );
|
||||
REQUIRE( done2.load() );
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("timer cancel", "[timer][cancel]") {
|
||||
oxenmq::OxenMQ omq{get_logger(""), LogLevel::trace};
|
||||
|
||||
omq.set_general_threads(1);
|
||||
omq.set_batch_threads(1);
|
||||
|
||||
std::atomic<int> ticks = 0;
|
||||
|
||||
// We set up *and cancel* this timer before omq starts, so it should never fire
|
||||
auto notimer = omq.add_timer([&] { ticks += 1000; }, 5ms * TIME_DILATION);
|
||||
omq.cancel_timer(notimer);
|
||||
|
||||
TimerID timer = omq.add_timer([&] {
|
||||
if (++ticks == 3)
|
||||
omq.cancel_timer(timer);
|
||||
}, 5ms * TIME_DILATION);
|
||||
|
||||
omq.start();
|
||||
|
||||
wait_for([&] { return ticks.load() >= 3; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( ticks.load() == 3 );
|
||||
}
|
||||
|
||||
// Test the alternative taking an lvalue reference instead of returning by value (see oxenmq.h
|
||||
// for why this is sometimes needed).
|
||||
std::atomic<int> ticks3 = 0;
|
||||
std::weak_ptr<TimerID> w_timer3;
|
||||
{
|
||||
auto timer3 = std::make_shared<TimerID>();
|
||||
auto& t3ref = *timer3; // Get this reference *before* we move the shared pointer into the lambda
|
||||
omq.add_timer(t3ref, [&ticks3, &omq, timer3=std::move(timer3)] {
|
||||
if (ticks3 == 0)
|
||||
ticks3++;
|
||||
else if (ticks3 > 1) {
|
||||
omq.cancel_timer(*timer3);
|
||||
ticks3++;
|
||||
}
|
||||
}, 1ms);
|
||||
}
|
||||
|
||||
wait_for([&] { return ticks3.load() >= 1; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( ticks3.load() == 1 );
|
||||
}
|
||||
ticks3++;
|
||||
wait_for([&] { return ticks3.load() >= 3 && w_timer3.expired(); });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( ticks3.load() == 3 );
|
||||
REQUIRE( w_timer3.expired() );
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in a new issue