Merge remote-tracking branch 'origin/master' into ubuntu/bionic

This commit is contained in:
Jason Rhinelander 2020-04-14 13:06:44 -03:00
commit 1662279a20
16 changed files with 981 additions and 325 deletions

View file

@ -5,8 +5,8 @@ project(liblokimq CXX)
include(GNUInstallDirs)
set(LOKIMQ_VERSION_MAJOR 1)
set(LOKIMQ_VERSION_MINOR 0)
set(LOKIMQ_VERSION_PATCH 5)
set(LOKIMQ_VERSION_MINOR 1)
set(LOKIMQ_VERSION_PATCH 0)
set(LOKIMQ_VERSION "${LOKIMQ_VERSION_MAJOR}.${LOKIMQ_VERSION_MINOR}.${LOKIMQ_VERSION_PATCH}")
message(STATUS "lokimq v${LOKIMQ_VERSION}")

View file

@ -30,45 +30,166 @@ std::string zmtp_metadata(string_view key, string_view value) {
bool LokiMQ::proxy_check_auth(size_t conn_index, bool outgoing, const peer_info& peer,
const std::string& command, const category& cat, zmq::message_t& msg) {
zmq::message_t& cmd, const cat_call_t& cat_call, std::vector<zmq::message_t>& data) {
auto command = view(cmd);
std::string reply;
if (peer.auth_level < cat.access.auth) {
LMQ_LOG(warn, "Access denied to ", command, " for peer [", to_hex(peer.pubkey), "]/", peer_address(msg),
": peer auth level ", peer.auth_level, " < ", cat.access.auth);
if (!cat_call.first) {
LMQ_LOG(warn, "Invalid command '", command, "' sent by remote [", to_hex(peer.pubkey), "]/", peer_address(cmd));
reply = "UNKNOWNCOMMAND";
} else if (peer.auth_level < cat_call.first->access.auth) {
LMQ_LOG(warn, "Access denied to ", command, " for peer [", to_hex(peer.pubkey), "]/", peer_address(cmd),
": peer auth level ", peer.auth_level, " < ", cat_call.first->access.auth);
reply = "FORBIDDEN";
}
else if (cat.access.local_sn && !local_service_node) {
LMQ_LOG(warn, "Access denied to ", command, " for peer [", to_hex(peer.pubkey), "]/", peer_address(msg),
} else if (cat_call.first->access.local_sn && !local_service_node) {
LMQ_LOG(warn, "Access denied to ", command, " for peer [", to_hex(peer.pubkey), "]/", peer_address(cmd),
": that command is only available when this LokiMQ is running in service node mode");
reply = "NOT_A_SERVICE_NODE";
}
else if (cat.access.remote_sn && !peer.service_node) {
LMQ_LOG(warn, "Access denied to ", command, " for peer [", to_hex(peer.pubkey), "]/", peer_address(msg),
} else if (cat_call.first->access.remote_sn && !peer.service_node) {
LMQ_LOG(warn, "Access denied to ", command, " for peer [", to_hex(peer.pubkey), "]/", peer_address(cmd),
": remote is not recognized as a service node");
// Disconnect: we don't think the remote is a SN, but it issued a command only SNs should be
// issuing. Drop the connection; if the remote has something important to relay it will
// reconnect, at which point we will reassess the SN status on the new incoming connection.
if (outgoing) {
proxy_disconnect(peer.service_node ? ConnectionID{peer.pubkey} : conn_index_to_id[conn_index], 1s);
return false;
}
else
reply = "BYE";
reply = "FORBIDDEN_SN";
} else if (cat_call.second->second /*is_request*/ && data.empty()) {
LMQ_LOG(warn, "Received an invalid request for '", command, "' with no reply tag from remote [",
to_hex(peer.pubkey), "]/", peer_address(cmd));
reply = "NO_REPLY_TAG";
} else {
return true;
}
if (reply.empty())
return true;
std::vector<zmq::message_t> msgs;
msgs.reserve(4);
if (!outgoing)
msgs.push_back(create_message(peer.route));
msgs.push_back(create_message(reply));
if (cat_call.second && cat_call.second->second /*request command*/ && !data.empty()) {
msgs.push_back(create_message("REPLY"_sv));
msgs.push_back(create_message(view(data.front()))); // reply tag
} else {
msgs.push_back(create_message(view(cmd)));
}
try {
if (outgoing)
send_direct_message(connections[conn_index], std::move(reply), command);
else
send_routed_message(connections[conn_index], peer.route, std::move(reply), command);
} catch (const zmq::error_t&) { /* can't send: possibly already disconnected. Ignore. */ }
send_message_parts(connections[conn_index], msgs);
} catch (const zmq::error_t& err) {
/* can't send: possibly already disconnected. Ignore. */
LMQ_LOG(debug, "Couldn't send auth failure message ", reply, " to peer [", to_hex(peer.pubkey), "]/", peer_address(cmd), ": ", err.what());
}
return false;
}
void LokiMQ::set_active_sns(pubkey_set pubkeys) {
if (proxy_thread.joinable()) {
auto data = bt_serialize(detail::serialize_object(std::move(pubkeys)));
detail::send_control(get_control_socket(), "SET_SNS", data);
} else {
proxy_set_active_sns(std::move(pubkeys));
}
}
void LokiMQ::proxy_set_active_sns(string_view data) {
proxy_set_active_sns(detail::deserialize_object<pubkey_set>(bt_deserialize<uintptr_t>(data)));
}
void LokiMQ::proxy_set_active_sns(pubkey_set pubkeys) {
pubkey_set added, removed;
for (auto it = pubkeys.begin(); it != pubkeys.end(); ) {
auto& pk = *it;
if (pk.size() != 32) {
LMQ_LOG(warn, "Invalid private key of length ", pk.size(), " (", to_hex(pk), ") passed to set_active_sns");
it = pubkeys.erase(it);
continue;
}
if (!active_service_nodes.count(pk))
added.insert(std::move(pk));
++it;
}
if (added.empty() && active_service_nodes.size() == pubkeys.size()) {
LMQ_LOG(debug, "set_active_sns(): new set of SNs is unchanged, skipping update");
return;
}
for (const auto& pk : active_service_nodes) {
if (!pubkeys.count(pk))
removed.insert(pk);
if (active_service_nodes.size() + added.size() - removed.size() == pubkeys.size())
break;
}
proxy_update_active_sns_clean(std::move(added), std::move(removed));
}
void LokiMQ::update_active_sns(pubkey_set added, pubkey_set removed) {
LMQ_LOG(info, "uh, ", added.size());
if (proxy_thread.joinable()) {
std::array<uintptr_t, 2> data;
data[0] = detail::serialize_object(std::move(added));
data[1] = detail::serialize_object(std::move(removed));
detail::send_control(get_control_socket(), "UPDATE_SNS", bt_serialize(data));
} else {
proxy_update_active_sns(std::move(added), std::move(removed));
}
}
void LokiMQ::proxy_update_active_sns(bt_list_consumer data) {
auto added = detail::deserialize_object<pubkey_set>(data.consume_integer<uintptr_t>());
auto remed = detail::deserialize_object<pubkey_set>(data.consume_integer<uintptr_t>());
proxy_update_active_sns(std::move(added), std::move(remed));
}
void LokiMQ::proxy_update_active_sns(pubkey_set added, pubkey_set removed) {
// We take a caller-provided set of added/removed then filter out any junk (bad pks, conflicting
// values, pubkeys that already(added) or do not(removed) exist), then pass the purified lists
// to the _clean version.
LMQ_LOG(info, "uh, ", added.size(), ", ", removed.size());
for (auto it = removed.begin(); it != removed.end(); ) {
const auto& pk = *it;
if (pk.size() != 32) {
LMQ_LOG(warn, "Invalid private key of length ", pk.size(), " (", to_hex(pk), ") passed to update_active_sns (removed)");
it = removed.erase(it);
} else if (!active_service_nodes.count(pk) || added.count(pk) /* added wins if in both */) {
it = removed.erase(it);
} else {
++it;
}
}
for (auto it = added.begin(); it != added.end(); ) {
const auto& pk = *it;
if (pk.size() != 32) {
LMQ_LOG(warn, "Invalid private key of length ", pk.size(), " (", to_hex(pk), ") passed to update_active_sns (added)");
it = added.erase(it);
} else if (active_service_nodes.count(pk)) {
it = added.erase(it);
} else {
++it;
}
}
proxy_update_active_sns_clean(std::move(added), std::move(removed));
}
void LokiMQ::proxy_update_active_sns_clean(pubkey_set added, pubkey_set removed) {
LMQ_LOG(debug, "Updating SN auth status with +", added.size(), "/-", removed.size(), " pubkeys");
// For anything we remove we want close the connection to the SN (if outgoing), and remove the
// stored peer_info (incoming or outgoing).
for (const auto& pk : removed) {
ConnectionID c{pk};
active_service_nodes.erase(pk);
auto range = peers.equal_range(c);
for (auto it = range.first; it != range.second; ) {
bool outgoing = it->second.outgoing();
size_t conn_index = it->second.conn_index;
it = peers.erase(it);
if (outgoing) {
LMQ_LOG(debug, "Closing outgoing connection to ", c);
proxy_close_connection(conn_index, CLOSE_LINGER);
}
}
}
// For pubkeys we add there's nothing special to be done beyond adding them to the pubkey set
for (auto& pk : added)
active_service_nodes.insert(std::move(pk));
}
void LokiMQ::process_zap_requests() {
for (std::vector<zmq::message_t> frames; recv_message_parts(zap_auth, frames, zmq::recv_flags::dontwait); frames.clear()) {
#ifndef NDEBUG
@ -147,31 +268,32 @@ void LokiMQ::process_zap_requests() {
} else {
auto ip = view(frames[3]);
string_view pubkey;
if (bind[bind_id].second.curve)
bool sn = false;
if (bind[bind_id].second.curve) {
pubkey = view(frames[6]);
auto result = bind[bind_id].second.allow(ip, pubkey);
bool sn = result.remote_sn;
sn = active_service_nodes.count(std::string{pubkey});
}
auto auth = bind[bind_id].second.allow(ip, pubkey, sn);
auto& user_id = response_vals[4];
if (bind[bind_id].second.curve) {
user_id.reserve(64);
to_hex(pubkey.begin(), pubkey.end(), std::back_inserter(user_id));
}
if (result.auth <= AuthLevel::denied || result.auth > AuthLevel::admin) {
if (auth <= AuthLevel::denied || auth > AuthLevel::admin) {
LMQ_LOG(info, "Access denied for incoming ", view(frames[5]), (sn ? " service node" : " client"),
" connection from ", !user_id.empty() ? user_id + " at " : ""s, ip,
" with initial auth level ", result.auth);
" with initial auth level ", auth);
status_code = "400";
status_text = "Access denied";
user_id.clear();
} else {
LMQ_LOG(debug, "Accepted incoming ", view(frames[5]), (sn ? " service node" : " client"),
" connection with authentication level ", result.auth,
" connection with authentication level ", auth,
" from ", !user_id.empty() ? user_id + " at " : ""s, ip);
auto& metadata = response_vals[5];
metadata += zmtp_metadata("X-SN", result.remote_sn ? "1" : "0");
metadata += zmtp_metadata("X-AuthLevel", to_string(result.auth));
metadata += zmtp_metadata("X-AuthLevel", to_string(auth));
status_code = "200";
status_text = "";

View file

@ -1,5 +1,8 @@
#pragma once
#include <iostream>
#include <string>
#include <cstring>
#include <unordered_set>
namespace lokimq {
@ -21,13 +24,32 @@ struct Access {
bool remote_sn = false;
/// If true the category requires that the local node is a SN
bool local_sn = false;
/// Constructor. Intentionally allows implicit conversion from an AuthLevel so that an
/// AuthLevel can be passed anywhere an Access is required (the resulting Access will have both
/// remote and local sn set to false).
Access(AuthLevel auth, bool remote_sn = false, bool local_sn = false)
: auth{auth}, remote_sn{remote_sn}, local_sn{local_sn} {}
};
/// Return type of the AllowFunc: this determines whether we allow the connection at all, and if so,
/// sets the initial authentication level and tells LokiMQ whether the other end is an active SN.
struct Allow {
AuthLevel auth = AuthLevel::none;
bool remote_sn = false;
/// Simple hash implementation for a string that is *already* a hash-like value (such as a pubkey).
/// Falls back to std::hash<std::string> if given a string smaller than a size_t.
struct already_hashed {
size_t operator()(const std::string& s) const {
if (s.size() < sizeof(size_t))
return std::hash<std::string>{}(s);
size_t hash;
std::memcpy(&hash, &s[0], sizeof(hash));
return hash;
}
};
/// std::unordered_set specialization for specifying pubkeys (used, in particular, by
/// LokiMQ::set_active_sns and LokiMQ::update_active_sns); this is a std::string unordered_set that
/// also uses a specialized trivial hash function that uses part of the value itself (i.e. the
/// pubkey) directly as a hash value. (This is nice and fast for uniformly distributed values like
/// pubkeys and a terrible hash choice for anything else).
using pubkey_set = std::unordered_set<std::string, already_hashed>;
}

View file

@ -53,6 +53,47 @@ void LokiMQ::setup_outgoing_socket(zmq::socket_t& socket, string_view remote_pub
// else let ZMQ pick a random one
}
ConnectionID LokiMQ::connect_sn(string_view pubkey, std::chrono::milliseconds keep_alive, string_view hint) {
if (!proxy_thread.joinable())
throw std::logic_error("Cannot call connect_sn() before calling `start()`");
detail::send_control(get_control_socket(), "CONNECT_SN", bt_serialize<bt_dict>({{"pubkey",pubkey}, {"keep_alive",keep_alive.count()}, {"hint",hint}}));
return pubkey;
}
ConnectionID LokiMQ::connect_remote(string_view remote, ConnectSuccess on_connect, ConnectFailure on_failure,
string_view pubkey, AuthLevel auth_level, std::chrono::milliseconds timeout) {
if (!proxy_thread.joinable())
throw std::logic_error("Cannot call connect_remote() before calling `start()`");
if (remote.size() < 7 || !(remote.substr(0, 6) == "tcp://" || remote.substr(0, 6) == "ipc://" /* unix domain sockets */))
throw std::runtime_error("Invalid connect_remote: remote address '" + std::string{remote} + "' is not a valid or supported zmq connect string");
auto id = next_conn_id++;
LMQ_TRACE("telling proxy to connect to ", remote, ", id ", id,
pubkey.empty() ? "using NULL auth" : ", using CURVE with remote pubkey [" + to_hex(pubkey) + "]");
detail::send_control(get_control_socket(), "CONNECT_REMOTE", bt_serialize<bt_dict>({
{"auth_level", static_cast<std::underlying_type_t<AuthLevel>>(auth_level)},
{"conn_id", id},
{"connect", detail::serialize_object(std::move(on_connect))},
{"failure", detail::serialize_object(std::move(on_failure))},
{"pubkey", pubkey},
{"remote", remote},
{"timeout", timeout.count()},
}));
return id;
}
void LokiMQ::disconnect(ConnectionID id, std::chrono::milliseconds linger) {
detail::send_control(get_control_socket(), "DISCONNECT", bt_serialize<bt_dict>({
{"conn_id", id.id},
{"linger_ms", linger.count()},
{"pubkey", id.pk},
}));
}
std::pair<zmq::socket_t *, std::string>
LokiMQ::proxy_connect_sn(string_view remote, string_view connect_hint, bool optional, bool incoming_only, bool outgoing_only, std::chrono::milliseconds keep_alive) {
ConnectionID remote_cid{remote};
@ -166,9 +207,9 @@ void update_connection_indices(Container& c, size_t index, AccessIndex get_index
}
}
/// Closes outgoing connections and removes all references. Note that this will invalidate
/// iterators on the various connection containers - if you don't want that, delete it first so that
/// the container won't contain the element being deleted.
/// Closes outgoing connections and removes all references. Note that this will call `erase()`
/// which can invalidate iterators on the various connection containers - if you don't want that,
/// delete it first so that the container won't contain the element being deleted.
void LokiMQ::proxy_close_connection(size_t index, std::chrono::milliseconds linger) {
connections[index].setsockopt<int>(ZMQ_LINGER, linger > 0ms ? linger.count() : 0);
pollitems_stale = true;
@ -197,6 +238,7 @@ void LokiMQ::proxy_expire_idle_peers() {
continue;
}
LMQ_LOG(debug, "Closing outgoing connection to ", it->first, ": idle timeout reached");
++it; // The below is going to delete our current element
proxy_close_connection(info.conn_index, CLOSE_LINGER);
} else {
++it;
@ -238,7 +280,7 @@ void LokiMQ::proxy_conn_cleanup() {
auto& callback = it->second;
if (callback.first < now) {
LMQ_LOG(debug, "pending request ", to_hex(it->first), " expired, invoking callback with failure status and removing");
job([callback = std::move(callback.second)] { callback(false, {}); });
job([callback = std::move(callback.second)] { callback(false, {{"TIMEOUT"s}}); });
it = pending_requests.erase(it);
} else {
++it;
@ -262,14 +304,10 @@ void LokiMQ::proxy_connect_remote(bt_dict_consumer data) {
if (data.skip_until("conn_id"))
conn_id = data.consume_integer<long long>();
if (data.skip_until("connect")) {
auto* ptr = reinterpret_cast<ConnectSuccess*>(data.consume_integer<uintptr_t>());
on_connect = std::move(*ptr);
delete ptr;
on_connect = detail::deserialize_object<ConnectSuccess>(data.consume_integer<uintptr_t>());
}
if (data.skip_until("failure")) {
auto* ptr = reinterpret_cast<ConnectFailure*>(data.consume_integer<uintptr_t>());
on_failure = std::move(*ptr);
delete ptr;
on_failure = detail::deserialize_object<ConnectFailure>(data.consume_integer<uintptr_t>());
}
if (data.skip_until("pubkey")) {
remote_pubkey = data.consume_string();

View file

@ -1,4 +1,5 @@
#pragma once
#include "auth.h"
#include "string_view.h"
namespace lokimq {
@ -75,7 +76,7 @@ private:
namespace std {
template <> struct hash<lokimq::ConnectionID> {
size_t operator()(const lokimq::ConnectionID &c) const {
return c.sn() ? std::hash<std::string>{}(c.pk) :
return c.sn() ? lokimq::already_hashed{}(c.pk) :
std::hash<long long>{}(c.id);
}
};

View file

@ -27,7 +27,7 @@ extern "C" inline void message_buffer_destroy(void*, void* hint) {
}
/// Creates a message without needing to reallocate the provided string data
inline zmq::message_t create_message(std::string &&data) {
inline zmq::message_t create_message(std::string&& data) {
auto *buffer = new std::string(std::move(data));
return zmq::message_t{&(*buffer)[0], buffer->size(), message_buffer_destroy, buffer};
};
@ -38,7 +38,7 @@ inline zmq::message_t create_message(string_view data) {
}
template <typename It>
bool send_message_parts(zmq::socket_t &sock, It begin, It end) {
bool send_message_parts(zmq::socket_t& sock, It begin, It end) {
while (begin != end) {
zmq::message_t &msg = *begin++;
if (!sock.send(msg, begin == end ? zmq::send_flags::dontwait : zmq::send_flags::dontwait | zmq::send_flags::sndmore))
@ -48,7 +48,7 @@ bool send_message_parts(zmq::socket_t &sock, It begin, It end) {
}
template <typename Container>
bool send_message_parts(zmq::socket_t &sock, Container &&c) {
bool send_message_parts(zmq::socket_t& sock, Container&& c) {
return send_message_parts(sock, c.begin(), c.end());
}
@ -56,7 +56,7 @@ bool send_message_parts(zmq::socket_t &sock, Container &&c) {
/// the msg frame will be an empty message; if `data` is empty then the data frame will be omitted.
/// `flags` is passed through to zmq: typically given `zmq::send_flags::dontwait` to throw rather
/// than block if a message can't be queued.
inline bool send_routed_message(zmq::socket_t &socket, std::string route, std::string msg = {}, std::string data = {}) {
inline bool send_routed_message(zmq::socket_t& socket, std::string route, std::string msg = {}, std::string data = {}) {
assert(!route.empty());
std::array<zmq::message_t, 3> msgs{{create_message(std::move(route))}};
if (!msg.empty())
@ -68,7 +68,7 @@ inline bool send_routed_message(zmq::socket_t &socket, std::string route, std::s
// Sends some stuff to a socket directly. If dontwait is true then we throw instead of blocking if
// the message cannot be accepted by zmq (i.e. because the outgoing buffer is full).
inline bool send_direct_message(zmq::socket_t &socket, std::string msg, std::string data = {}) {
inline bool send_direct_message(zmq::socket_t& socket, std::string msg, std::string data = {}) {
std::array<zmq::message_t, 2> msgs{{create_message(std::move(msg))}};
if (!data.empty())
msgs[1] = create_message(std::move(data));
@ -77,7 +77,7 @@ inline bool send_direct_message(zmq::socket_t &socket, std::string msg, std::str
// Receive all the parts of a single message from the given socket. Returns true if a message was
// received, false if called with flags=zmq::recv_flags::dontwait and no message was available.
inline bool recv_message_parts(zmq::socket_t &sock, std::vector<zmq::message_t>& parts, const zmq::recv_flags flags = zmq::recv_flags::none) {
inline bool recv_message_parts(zmq::socket_t& sock, std::vector<zmq::message_t>& parts, const zmq::recv_flags flags = zmq::recv_flags::none) {
do {
zmq::message_t msg;
if (!sock.recv(msg, flags))

View file

@ -26,11 +26,6 @@ std::vector<std::string> as_strings(const MessageContainer& msgs) {
return result;
}
void check_started(const std::thread& proxy_thread, const std::string &verb) {
if (!proxy_thread.joinable())
throw std::logic_error("Cannot " + verb + " before calling `start()`");
}
void check_not_started(const std::thread& proxy_thread, const std::string &verb) {
if (proxy_thread.joinable())
throw std::logic_error("Cannot " + verb + " after calling `start()`");
@ -54,29 +49,20 @@ void send_control(zmq::socket_t& sock, string_view cmd, std::string data) {
}
}
/// Extracts a pubkey, SN status, and auth level from a zmq message received on a *listening*
/// socket.
std::tuple<std::string, bool, AuthLevel> extract_metadata(zmq::message_t& msg) {
auto result = std::make_tuple(""s, false, AuthLevel::none);
/// Extracts a pubkey and and auth level from a zmq message received on a *listening* socket.
std::pair<std::string, AuthLevel> extract_metadata(zmq::message_t& msg) {
auto result = std::make_pair(""s, AuthLevel::none);
try {
string_view pubkey_hex{msg.gets("User-Id")};
if (pubkey_hex.size() != 64)
throw std::logic_error("bad user-id");
assert(is_hex(pubkey_hex.begin(), pubkey_hex.end()));
auto& pubkey = std::get<std::string>(result);
pubkey.resize(32, 0);
from_hex(pubkey_hex.begin(), pubkey_hex.end(), pubkey.begin());
result.first.resize(32, 0);
from_hex(pubkey_hex.begin(), pubkey_hex.end(), result.first.begin());
} catch (...) {}
try {
string_view is_sn{msg.gets("X-SN")};
if (is_sn.size() == 1 && is_sn[0] == '1')
std::get<bool>(result) = true;
} catch (...) {}
try {
string_view auth_level{msg.gets("X-AuthLevel")};
std::get<AuthLevel>(result) = auth_from_string(auth_level);
result.second = auth_from_string(msg.gets("X-AuthLevel"));
} catch (...) {}
return result;
@ -385,46 +371,6 @@ LokiMQ::~LokiMQ() {
LMQ_LOG(info, "LokiMQ proxy thread has stopped");
}
ConnectionID LokiMQ::connect_sn(string_view pubkey, std::chrono::milliseconds keep_alive, string_view hint) {
check_started(proxy_thread, "connect");
detail::send_control(get_control_socket(), "CONNECT_SN", bt_serialize<bt_dict>({{"pubkey",pubkey}, {"keep_alive",keep_alive.count()}, {"hint",hint}}));
return pubkey;
}
ConnectionID LokiMQ::connect_remote(string_view remote, ConnectSuccess on_connect, ConnectFailure on_failure,
string_view pubkey, AuthLevel auth_level, std::chrono::milliseconds timeout) {
if (!proxy_thread.joinable())
LMQ_LOG(warn, "connect_remote() called before start(); this won't take effect until start() is called");
if (remote.size() < 7 || !(remote.substr(0, 6) == "tcp://" || remote.substr(0, 6) == "ipc://" /* unix domain sockets */))
throw std::runtime_error("Invalid connect_remote: remote address '" + std::string{remote} + "' is not a valid or supported zmq connect string");
auto id = next_conn_id++;
LMQ_TRACE("telling proxy to connect to ", remote, ", id ", id,
pubkey.empty() ? "using NULL auth" : ", using CURVE with remote pubkey [" + to_hex(pubkey) + "]");
detail::send_control(get_control_socket(), "CONNECT_REMOTE", bt_serialize<bt_dict>({
{"auth_level", static_cast<std::underlying_type_t<AuthLevel>>(auth_level)},
{"conn_id", id},
{"connect", reinterpret_cast<uintptr_t>(new ConnectSuccess{std::move(on_connect)})},
{"failure", reinterpret_cast<uintptr_t>(new ConnectFailure{std::move(on_failure)})},
{"pubkey", pubkey},
{"remote", remote},
{"timeout", timeout.count()},
}));
return id;
}
void LokiMQ::disconnect(ConnectionID id, std::chrono::milliseconds linger) {
detail::send_control(get_control_socket(), "DISCONNECT", bt_serialize<bt_dict>({
{"conn_id", id.id},
{"linger_ms", linger.count()},
{"pubkey", id.pk},
}));
}
std::ostream &operator<<(std::ostream &os, LogLevel lvl) {
os << (lvl == LogLevel::trace ? "trace" :
lvl == LogLevel::debug ? "debug" :

View file

@ -134,16 +134,18 @@ private:
public:
/// Callback type invoked to determine whether the given new incoming connection is allowed to
/// connect to us and to set its initial authentication level.
/// connect to us and to set its authentication level.
///
/// @param ip - the ip address of the incoming connection
/// @param pubkey - the x25519 pubkey of the connecting client (32 byte string). Note that this
/// will only be non-empty for incoming connections on `listen_curve` sockets; `listen_plain`
/// sockets do not have a pubkey.
/// @param service_node - will be true if the `pubkey` is in the set of known active service
/// nodes.
///
/// @returns an `AuthLevel` enum value indicating the default auth level for the incoming
/// connection, or AuthLevel::denied if the connection should be refused.
using AllowFunc = std::function<Allow(string_view ip, string_view pubkey)>;
using AllowFunc = std::function<AuthLevel(string_view ip, string_view pubkey, bool service_node)>;
/// Callback that is invoked when we need to send a "strong" message to a SN that we aren't
/// already connected to and need to establish a connection. This callback returns the ZMQ
@ -244,7 +246,9 @@ private:
std::vector<std::pair<std::string, bind_data>> bind;
/// Info about a peer's established connection with us. Note that "established" means both
/// connected and authenticated.
/// connected and authenticated. Note that we only store peer info data for SN connections (in
/// or out), and outgoing non-SN connections. Incoming non-SN connections are handled on the
/// fly.
struct peer_info {
/// Pubkey of the remote, if this connection is a curve25519 connection; empty otherwise.
std::string pubkey;
@ -521,15 +525,26 @@ private:
/// done).
std::unordered_map<std::string, std::string> command_aliases;
using cat_call_t = std::pair<category*, const std::pair<CommandCallback, bool>*>;
/// Retrieve category and callback from a command name, including alias mapping. Warns on
/// invalid commands and returns nullptrs. The command name will be updated in place if it is
/// aliased to another command.
std::pair<category*, const std::pair<CommandCallback, bool>*> get_command(std::string& command);
cat_call_t get_command(std::string& command);
/// Checks a peer's authentication level. Returns true if allowed, warns and returns false if
/// not.
bool proxy_check_auth(size_t conn_index, bool outgoing, const peer_info& peer,
const std::string& command, const category& cat, zmq::message_t& msg);
zmq::message_t& command, const cat_call_t& cat_call, std::vector<zmq::message_t>& data);
/// Set of active service nodes.
pubkey_set active_service_nodes;
/// Resets or updates the stored set of active SN pubkeys
void proxy_set_active_sns(string_view data);
void proxy_set_active_sns(pubkey_set pubkeys);
void proxy_update_active_sns(bt_list_consumer data);
void proxy_update_active_sns(pubkey_set added, pubkey_set removed);
void proxy_update_active_sns_clean(pubkey_set added, pubkey_set removed);
/// Details for a pending command; such a command already has authenticated access and is just
/// waiting for a thread to become available to handle it.
@ -775,8 +790,15 @@ public:
* Finish starting up: binds to the bind locations given in the constructor and launches the
* proxy thread to handle message dispatching between remote nodes and worker threads.
*
* You will need to call `add_category` and `add_command` to register commands before calling
* `start()`; once start() is called commands cannot be changed.
* Things you want to do before calling this:
* - Use `add_category`/`add_command` to set up any commands remote connections can invoke.
* - If any commands require SN authentication, specify a list of currently active service node
* pubkeys via `set_active_sns()` (and make sure this gets updated when things change by
* another `set_active_sns()` or a `update_active_sns()` call). It *is* possible to make the
* initial call after calling `start()`, but that creates a window during which incoming
* remote SN connections will be erroneously treated as non-SN connections.
* - If this LMQ instance should accept incoming connections, set up any listening ports via
* `listen_curve()` and/or `listen_plain()`.
*/
void start();
@ -789,10 +811,10 @@ public:
* such as: "tcp://\*:4567" or "tcp://1.2.3.4:5678".
*
* @param allow_connection function to call to determine whether to allow the connection and, if
* so, the authentication level it receives. If omitted the default returns non-service node,
* AuthLevel::none access.
* so, the authentication level it receives. If omitted the default returns AuthLevel::none
* access.
*/
void listen_curve(std::string bind, AllowFunc allow_connection = [](auto, auto) { return Allow{AuthLevel::none, false}; });
void listen_curve(std::string bind, AllowFunc allow_connection = [](auto, auto, auto) { return AuthLevel::none; });
/** Start listening on the given bind address in unauthenticated plain text mode. Incoming
* connections can come from anywhere. `allow_connection` is invoked for any incoming
@ -803,10 +825,10 @@ public:
* such as: "tcp://\*:4567" or "tcp://1.2.3.4:5678".
*
* @param allow_connection function to call to determine whether to allow the connection and, if
* so, the authentication level it receives. If omitted the default returns non-service node,
* AuthLevel::none access.
* so, the authentication level it receives. If omitted the default returns AuthLevel::none
* access.
*/
void listen_plain(std::string bind, AllowFunc allow_connection = [](auto, auto) { return Allow{AuthLevel::none, false}; });
void listen_plain(std::string bind, AllowFunc allow_connection = [](auto, auto, auto) { return AuthLevel::none; });
/**
* Try to initiate a connection to the given SN in anticipation of needing a connection in the
@ -939,9 +961,23 @@ public:
* @param to - the pubkey string or ConnectionID to send this request to
* @param cmd - the command name
* @param callback - the callback to invoke when we get a reply. Called with a true value and
* the data strings when a reply is received, or false and an empty vector of data parts if we
* get no reply in the timeout interval.
* the data strings when a reply is received, or false with error string(s) indicating the
* failure reason upon failure or timeout.
* @param opts - anything else (i.e. strings, send_options) is forwarded to send().
*
* Possible error data values:
* - ["TIMEOUT"] - we got no reply within the timeout window
* - ["UNKNOWNCOMMAND"] - the remote did not recognize the given request command
* - ["NO_REPLY_TAG"] - the invoked command is a request command but no reply tag was included
* - ["FORBIDDEN"] - the command requires an authorization level (e.g. Basic or Admin) that we
* do not have.
* - ["FORBIDDEN_SN"] - the command requires service node authentication, but the remote did not
* recognize us as a service node. You *may* want to retry the request a limited number of
* times (but do not retry indefinitely as that can be an infinite loop!) because this is
* typically also followed by a disconnection; a retried message would reconnect and
* reauthenticate which *may* result in picking up the SN authentication.
* - ["NOT_A_SERVICE_NODE"] - this command is only invokable on service nodes, and the remote is
* not running as a service node.
*/
template <typename... T>
void request(ConnectionID to, string_view cmd, ReplyCallback callback, const T&... opts);
@ -951,6 +987,41 @@ public:
const std::string& get_pubkey() const { return pubkey; }
const std::string& get_privkey() const { return privkey; }
/** Updates (or initially sets) LokiMQ's list of service node pubkeys with the given list.
*
* This has two main effects:
*
* - All commands processed after the update will have SN status determined by the new list.
* - All outgoing connections to service nodes that are no longer on the list will be closed.
* This includes both explicit connections (established by `connect_sn()`) and implicit ones
* (established by sending to a SN that wasn't connected).
*
* As this update is potentially quite heavy it is recommended that this be called only when
* necessary--i.e. when the list has changed (or potentially changed), but *not* on a short
* periodic timer.
*
* This method may (and should!) be called before start() to load an initial set of SNs.
*
* Once a full list has been set, updates on changes can either call this again with the new
* list, or use the more efficient update_active_sns() call if incremental results are
* available.
*/
void set_active_sns(pubkey_set pubkeys);
/** Updates the list of active pubkeys by adding or removing the given pubkeys from the existing
* list. This is more efficient when the incremental information is already available; if it
* isn't, simply call set_active_sns with a new list to have LokiMQ figure out what was added or
* removed.
*
* \param added new pubkeys that were added since the last set_active_sns or update_active_sns
* call.
*
* \param removed pubkeys that were removed from active SN status since the last call. If a
* pubkey is in both `added` and `removed` for some reason then its presence in `removed` will
* be ignored.
*/
void update_active_sns(pubkey_set added, pubkey_set removed);
/**
* Batches a set of jobs to be executed by workers, optionally followed by a completion function.
*
@ -1121,7 +1192,9 @@ namespace detail {
/// containing the pointer to be serialized to pass (via lokimq queues) from one thread to another.
/// Must be matched with a deserializer_pointer on the other side to reconstitute the object and
/// destroy the intermediate pointer.
template <typename T> uintptr_t serialize_object(T&& obj) {
template <typename T>
uintptr_t serialize_object(T&& obj) {
static_assert(std::is_rvalue_reference<decltype(obj)>::value, "serialize_object must be given an rvalue reference");
auto* ptr = new T{std::forward<T>(obj)};
return reinterpret_cast<uintptr_t>(ptr);
}
@ -1192,9 +1265,8 @@ inline void apply_send_option(bt_list&, bt_dict& control_data, send_option::queu
control_data["send_full_q"] = serialize_object(std::move(f.callback));
}
/// Extracts a pubkey, SN status, and auth level from a zmq message received on a *listening*
/// socket.
std::tuple<std::string, bool, AuthLevel> extract_metadata(zmq::message_t& msg);
/// Extracts a pubkey and auth level from a zmq message received on a *listening* socket.
std::pair<std::string, AuthLevel> extract_metadata(zmq::message_t& msg);
template <typename... T>
bt_dict build_send(ConnectionID to, string_view cmd, T&&... opts) {

View file

@ -178,7 +178,7 @@ void LokiMQ::proxy_send(bt_dict_consumer data) {
std::chrono::steady_clock::now() + request_timeout, std::move(request_callback) }});
} else {
LMQ_LOG(debug, "Could not send request, scheduling request callback failure");
job([callback = std::move(request_callback)] { callback(false, {}); });
job([callback = std::move(request_callback)] { callback(false, {{"TIMEOUT"s}}); });
}
}
if (!sent) {
@ -237,31 +237,38 @@ void LokiMQ::proxy_reply(bt_dict_consumer data) {
}
void LokiMQ::proxy_control_message(std::vector<zmq::message_t>& parts) {
// We throw an uncaught exception here because we only generate control messages internally in
// lokimq code: if one of these condition fail it's a lokimq bug.
if (parts.size() < 2)
throw std::logic_error("Expected 2-3 message parts for a proxy control message");
throw std::logic_error("LokiMQ bug: Expected 2-3 message parts for a proxy control message");
auto route = view(parts[0]), cmd = view(parts[1]);
LMQ_TRACE("control message: ", cmd);
if (parts.size() == 3) {
LMQ_TRACE("...: ", parts[2]);
auto data = view(parts[2]);
if (cmd == "SEND") {
LMQ_TRACE("proxying message");
return proxy_send(view(parts[2]));
return proxy_send(data);
} else if (cmd == "REPLY") {
LMQ_TRACE("proxying reply to non-SN incoming message");
return proxy_reply(view(parts[2]));
return proxy_reply(data);
} else if (cmd == "BATCH") {
LMQ_TRACE("proxy batch jobs");
auto ptrval = bt_deserialize<uintptr_t>(view(parts[2]));
auto ptrval = bt_deserialize<uintptr_t>(data);
return proxy_batch(reinterpret_cast<detail::Batch*>(ptrval));
} else if (cmd == "SET_SNS") {
return proxy_set_active_sns(data);
} else if (cmd == "UPDATE_SNS") {
return proxy_update_active_sns(data);
} else if (cmd == "CONNECT_SN") {
proxy_connect_sn(view(parts[2]));
proxy_connect_sn(data);
return;
} else if (cmd == "CONNECT_REMOTE") {
return proxy_connect_remote(view(parts[2]));
return proxy_connect_remote(data);
} else if (cmd == "DISCONNECT") {
return proxy_disconnect(view(parts[2]));
return proxy_disconnect(data);
} else if (cmd == "TIMER") {
return proxy_timer(view(parts[2]));
return proxy_timer(data);
}
} else if (parts.size() == 2) {
if (cmd == "START") {
@ -279,8 +286,8 @@ void LokiMQ::proxy_control_message(std::vector<zmq::message_t>& parts) {
return;
}
}
throw std::runtime_error("Proxy received invalid control command: " + std::string{cmd} +
" (" + std::to_string(parts.size()) + ")");
throw std::runtime_error("LokiMQ bug: Proxy received invalid control command: " +
std::string{cmd} + " (" + std::to_string(parts.size()) + ")");
}
void LokiMQ::proxy_loop() {
@ -455,26 +462,31 @@ void LokiMQ::proxy_loop() {
}
}
static bool is_error_response(string_view cmd) {
return cmd == "FORBIDDEN" || cmd == "FORBIDDEN_SN" || cmd == "NOT_A_SERVICE_NODE" || cmd == "UNKNOWNCOMMAND" || cmd == "NO_REPLY_TAG";
}
// Return true if we recognized/handled the builtin command (even if we reject it for whatever
// reason)
bool LokiMQ::proxy_handle_builtin(size_t conn_index, std::vector<zmq::message_t>& parts) {
bool outgoing = connections[conn_index].getsockopt<int>(ZMQ_TYPE) == ZMQ_DEALER;
// Doubling as a bool and an offset:
size_t incoming = connections[conn_index].getsockopt<int>(ZMQ_TYPE) == ZMQ_ROUTER;
string_view route, cmd;
if (parts.size() < (outgoing ? 1 : 2)) {
if (parts.size() < 1 + incoming) {
LMQ_LOG(warn, "Received empty message; ignoring");
return true;
}
if (outgoing) {
cmd = view(parts[0]);
} else {
if (incoming) {
route = view(parts[0]);
cmd = view(parts[1]);
} else {
cmd = view(parts[0]);
}
LMQ_TRACE("Checking for builtins: '", cmd, "' from ", peer_address(parts.back()));
if (cmd == "REPLY") {
size_t tag_pos = (outgoing ? 1 : 2);
size_t tag_pos = 1 + incoming;
if (parts.size() <= tag_pos) {
LMQ_LOG(warn, "Received REPLY without a reply tag; ignoring");
return true;
@ -482,7 +494,7 @@ bool LokiMQ::proxy_handle_builtin(size_t conn_index, std::vector<zmq::message_t>
std::string reply_tag{view(parts[tag_pos])};
auto it = pending_requests.find(reply_tag);
if (it != pending_requests.end()) {
LMQ_LOG(debug, "Received REPLY for pending command", to_hex(reply_tag), "; scheduling callback");
LMQ_LOG(debug, "Received REPLY for pending command ", to_hex(reply_tag), "; scheduling callback");
std::vector<std::string> data;
data.reserve(parts.size() - (tag_pos + 1));
for (auto it = parts.begin() + (tag_pos + 1); it != parts.end(); ++it)
@ -496,7 +508,7 @@ bool LokiMQ::proxy_handle_builtin(size_t conn_index, std::vector<zmq::message_t>
}
return true;
} else if (cmd == "HI") {
if (outgoing) {
if (!incoming) {
LMQ_LOG(warn, "Got invalid 'HI' message on an outgoing connection; ignoring");
return true;
}
@ -506,7 +518,7 @@ bool LokiMQ::proxy_handle_builtin(size_t conn_index, std::vector<zmq::message_t>
} catch (const std::exception &e) { LMQ_LOG(warn, "Couldn't reply with HELLO: ", e.what()); }
return true;
} else if (cmd == "HELLO") {
if (!outgoing) {
if (incoming) {
LMQ_LOG(warn, "Got invalid 'HELLO' message on an incoming connection; ignoring");
return true;
}
@ -531,22 +543,50 @@ bool LokiMQ::proxy_handle_builtin(size_t conn_index, std::vector<zmq::message_t>
pending_connects.erase(it);
return true;
} else if (cmd == "BYE") {
if (outgoing) {
std::string pk;
bool sn;
AuthLevel a;
std::tie(pk, sn, a) = detail::extract_metadata(parts.back());
ConnectionID conn = sn ? ConnectionID{std::move(pk)} : conn_index_to_id[conn_index];
LMQ_LOG(debug, "BYE command received; disconnecting from ", conn);
proxy_disconnect(conn, 1s);
if (!incoming) {
LMQ_LOG(debug, "BYE command received; disconnecting from ", peer_address(parts.back()));
proxy_close_connection(conn_index, 0s);
} else {
LMQ_LOG(warn, "Got invalid 'BYE' command on an incoming socket; ignoring");
}
return true;
}
else if (cmd == "FORBIDDEN" || cmd == "NOT_A_SERVICE_NODE") {
return true; // FIXME - ignore these? Log?
else if (is_error_response(cmd)) {
// These messages (FORBIDDEN, UNKNOWNCOMMAND, etc.) are sent in response to us trying to
// invoke something that doesn't exist or we don't have permission to access. These have
// two forms (the latter is only sent by remotes running 1.0.6+).
// - ["XXX", "whatever.command"]
// - ["XXX", "REPLY", replytag]
// (ignoring the routing prefix on incoming commands).
// For the former, we log; for the latter we trigger the reply callback with a failure
if (parts.size() == (2 + incoming) && is_error_response(view(parts[1 + incoming]))) {
// Something like ["UNKNOWNCOMMAND", "FORBIDDEN_SN"] which can happen because the remote
// is running an older version that didn't understand the FORBIDDEN_SN (or whatever)
// error reply that we sent them. We just ignore it because anything else could trigger
// an infinite cycle.
LMQ_LOG(debug, "Received [", cmd, ",", view(parts[1 + incoming]), "]; remote is probably an older lokimq. Ignoring.");
return true;
}
if (parts.size() == (3 + incoming) && view(parts[1 + incoming]) == "REPLY") {
std::string reply_tag{view(parts[2 + incoming])};
auto it = pending_requests.find(reply_tag);
if (it != pending_requests.end()) {
LMQ_LOG(debug, "Received ", cmd, " REPLY for pending command ", to_hex(reply_tag), "; scheduling failure callback");
proxy_schedule_reply_job([callback=std::move(it->second.second), cmd=std::string{cmd}] {
callback(false, {{std::move(cmd)}});
});
pending_requests.erase(it);
} else {
LMQ_LOG(warn, "Received REPLY with unknown or already handled reply tag (", to_hex(reply_tag), "); ignoring");
}
} else {
LMQ_LOG(warn, "Received ", cmd, ':', (parts.size() > 1 + incoming ? view(parts[1 + incoming]) : "(unknown command)"_sv),
" from ", peer_address(parts.back()));
}
return true;
}
return false;
}

View file

@ -201,22 +201,24 @@ void LokiMQ::proxy_to_worker(size_t conn_index, std::vector<zmq::message_t>& par
}
peer = &it->second;
} else {
std::tie(tmp_peer.pubkey, tmp_peer.service_node, tmp_peer.auth_level) = detail::extract_metadata(parts.back());
std::tie(tmp_peer.pubkey, tmp_peer.auth_level) = detail::extract_metadata(parts.back());
tmp_peer.service_node = tmp_peer.pubkey.size() == 32 && active_service_nodes.count(tmp_peer.pubkey);
if (tmp_peer.service_node) {
// It's a service node so we should have a peer_info entry; see if we can find one with
// the same route, and if not, add one.
auto pr = peers.equal_range(tmp_peer.pubkey);
for (auto it = pr.first; it != pr.second; ++it) {
if (it->second.route == tmp_peer.route) {
if (it->second.conn_index == tmp_peer.conn_index && it->second.route == tmp_peer.route) {
peer = &it->second;
// Upgrade permissions in case we have something higher on the socket
peer->service_node |= tmp_peer.service_node;
if (tmp_peer.auth_level > peer->auth_level)
peer->auth_level = tmp_peer.auth_level;
// Update the stored auth level just in case the peer reconnected
peer->auth_level = tmp_peer.auth_level;
break;
}
}
if (!peer) {
// We don't have a record: this is either a new SN connection or a new message on a
// connection that recently gained SN status.
peer = &peers.emplace(ConnectionID{tmp_peer.pubkey}, std::move(tmp_peer))->second;
}
} else {
@ -228,23 +230,6 @@ void LokiMQ::proxy_to_worker(size_t conn_index, std::vector<zmq::message_t>& par
size_t command_part_index = outgoing ? 0 : 1;
std::string command = parts[command_part_index].to_string();
auto cat_call = get_command(command);
if (!cat_call.first) {
LMQ_LOG(warn, "Invalid command '", command, "' sent by remote [", to_hex(peer->pubkey), "]/", peer_address(parts.back()));
try {
if (outgoing)
send_direct_message(connections[conn_index], "UNKNOWNCOMMAND");
else
send_routed_message(connections[conn_index], peer->route, "UNKNOWNCOMMAND");
} catch (const zmq::error_t&) { /* can't send: possibly already disconnected. Ignore. */ }
return;
}
auto& category = *cat_call.first;
if (!proxy_check_auth(conn_index, outgoing, *peer, command, category, parts.back()))
return;
// Steal any data message parts
size_t data_part_index = command_part_index + 1;
@ -253,6 +238,14 @@ void LokiMQ::proxy_to_worker(size_t conn_index, std::vector<zmq::message_t>& par
for (auto it = parts.begin() + data_part_index; it != parts.end(); ++it)
data_parts.push_back(std::move(*it));
auto cat_call = get_command(command);
// Check that command is valid, that we have permission, etc.
if (!proxy_check_auth(conn_index, outgoing, *peer, parts[command_part_index], cat_call, data_parts))
return;
auto& category = *cat_call.first;
if (category.active_threads >= category.reserved_threads && active_workers() >= general_workers) {
// No free reserved or general spots, try to queue it for later
if (category.max_queue >= 0 && category.queued >= category.max_queue) {

View file

@ -6,6 +6,7 @@ set(LMQ_TEST_SRC
test_batch.cpp
test_connect.cpp
test_commands.cpp
test_failures.cpp
test_requests.cpp
test_string_view.cpp
)

View file

@ -6,6 +6,25 @@ using namespace lokimq;
static auto startup = std::chrono::steady_clock::now();
/// Waits up to 100ms for something to happen.
template <typename Func>
inline void wait_for(Func f) {
for (int i = 0; i < 10; i++) {
if (f())
break;
std::this_thread::sleep_for(10ms);
}
}
/// Waits on an atomic bool for up to 100ms for an initial connection, which is more than enough
/// time for an initial connection + request.
inline void wait_for_conn(std::atomic<bool> &c) {
wait_for([&c] { return c.load(); });
}
/// Waits enough time for us to receive a reply from a localhost remote.
inline void reply_sleep() { std::this_thread::sleep_for(10ms); }
// Catch2 macros aren't thread safe, so guard with a mutex
inline std::unique_lock<std::mutex> catch_lock() {
static std::mutex mutex;

View file

@ -1,5 +1,4 @@
#include "common.h"
#include <future>
#include <lokimq/hex.h>
#include <map>
#include <set>
@ -12,10 +11,10 @@ TEST_CASE("basic commands", "[commands]") {
"", "", // generate ephemeral keys
false, // not a service node
[](auto) { return ""; },
get_logger("")
get_logger(""),
LogLevel::trace
};
server.log_level(LogLevel::trace);
server.listen_curve(listen, [](auto /*ip*/, auto /*pk*/) { return Allow{AuthLevel::none, false}; });
server.listen_curve(listen);
std::atomic<int> hellos{0}, his{0};
@ -32,41 +31,33 @@ TEST_CASE("basic commands", "[commands]") {
server.start();
LokiMQ client{
get_logger("")
};
client.log_level(LogLevel::trace);
LokiMQ client{get_logger(""), LogLevel::trace};
client.add_category("public", Access{AuthLevel::none});
client.add_command("public", "hi", [&](auto&) { his++; });
client.start();
std::atomic<bool> connected{false}, failed{false};
std::atomic<bool> got{false};
bool success = false, failed = false;
std::string pubkey;
auto c = client.connect_remote(listen,
[&](auto conn) { pubkey = conn.pubkey(); connected = true; },
[&](auto conn, string_view) { failed = true; },
[&](auto conn) { pubkey = conn.pubkey(); success = true; got = true; },
[&](auto conn, string_view) { failed = true; got = true; },
server.get_pubkey());
int i;
for (i = 0; i < 5; i++) {
if (connected.load())
break;
std::this_thread::sleep_for(50ms);
}
wait_for_conn(got);
{
auto lock = catch_lock();
REQUIRE( connected.load() );
REQUIRE( i <= 1 ); // should be fast
REQUIRE( !failed.load() );
REQUIRE( got );
REQUIRE( success );
REQUIRE_FALSE( failed );
REQUIRE( to_hex(pubkey) == to_hex(server.get_pubkey()) );
}
client.send(c, "public.hello");
client.send(c, "public.client.pubkey");
std::this_thread::sleep_for(50ms);
reply_sleep();
{
auto lock = catch_lock();
REQUIRE( hellos == 1 );
@ -77,7 +68,7 @@ TEST_CASE("basic commands", "[commands]") {
for (int i = 0; i < 50; i++)
client.send(c, "public.hello");
std::this_thread::sleep_for(100ms);
wait_for([&] { return his == 26; });
{
auto lock = catch_lock();
REQUIRE( hellos == 51 );
@ -91,10 +82,10 @@ TEST_CASE("outgoing auth level", "[commands][auth]") {
"", "", // generate ephemeral keys
false, // not a service node
[](auto) { return ""; },
get_logger("")
get_logger(""),
LogLevel::trace
};
server.log_level(LogLevel::trace);
server.listen_curve(listen, [](auto /*ip*/, auto /*pk*/) { return Allow{AuthLevel::none, false}; });
server.listen_curve(listen);
std::atomic<int> hellos{0};
@ -103,10 +94,7 @@ TEST_CASE("outgoing auth level", "[commands][auth]") {
server.start();
LokiMQ client{
get_logger("")
};
client.log_level(LogLevel::trace);
LokiMQ client{get_logger(""), LogLevel::trace};
std::atomic<int> public_hi{0}, basic_hi{0}, admin_hi{0};
client.add_category("public", Access{AuthLevel::none});
@ -124,8 +112,7 @@ TEST_CASE("outgoing auth level", "[commands][auth]") {
auto admin_c = client.connect_remote(listen, [](...) {}, [](...) {}, server.get_pubkey(), AuthLevel::admin);
client.send(public_c, "public.reflect", "public.hi");
std::this_thread::sleep_for(50ms);
wait_for([&] { return public_hi == 1; });
{
auto lock = catch_lock();
REQUIRE( public_hi == 1 );
@ -138,7 +125,7 @@ TEST_CASE("outgoing auth level", "[commands][auth]") {
client.send(admin_c, "public.reflect", "admin.hi");
client.send(basic_c, "public.reflect", "basic.hi");
std::this_thread::sleep_for(50ms);
wait_for([&] { return public_hi == 2; });
{
auto lock = catch_lock();
REQUIRE( admin_hi == 3 );
@ -160,7 +147,7 @@ TEST_CASE("outgoing auth level", "[commands][auth]") {
client.send(admin_c, "public.reflect", "basic.hi");
client.send(admin_c, "public.reflect", "public.hi");
std::this_thread::sleep_for(50ms);
wait_for([&] { return public_hi == 3; });
auto lock = catch_lock();
REQUIRE( admin_hi == 1 );
REQUIRE( basic_hi == 2 );
@ -176,10 +163,10 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]")
"", "", // generate ephemeral keys
false, // not a service node
[](auto) { return ""; },
get_logger("")
get_logger(""),
LogLevel::trace
};
server.log_level(LogLevel::trace);
server.listen_curve(listen, [](auto /*ip*/, auto /*pk*/) { return Allow{AuthLevel::none, false}; });
server.listen_curve(listen);
std::vector<std::pair<ConnectionID, std::string>> subscribers;
ConnectionID backdoor;
@ -221,8 +208,7 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]")
auto nsa_c = nsa.connect_remote(listen, connect_success, connect_failure, server.get_pubkey(), AuthLevel::admin);
nsa.send(nsa_c, "hey google.install backdoor");
std::this_thread::sleep_for(50ms);
wait_for([&] { auto lock = catch_lock(); return (bool) backdoor; });
{
auto l = catch_lock();
REQUIRE( backdoor );
@ -243,9 +229,10 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]")
std::map<int, std::set<std::string>> google_knows;
int things_remembered{0};
for (int i = 0; i < 5; i++) {
clients.push_back(std::make_unique<LokiMQ>(get_logger("C" + std::to_string(i) + "» ")));
clients.push_back(std::make_unique<LokiMQ>(
get_logger("C" + std::to_string(i) + "» "), LogLevel::trace
));
auto& c = clients.back();
c->log_level(LogLevel::trace);
c->add_category("personal", Access{AuthLevel::basic});
c->add_command("personal", "detail", [&,i](Message& m) {
auto l = catch_lock();
@ -265,7 +252,7 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]")
},
personal_detail);
}
std::this_thread::sleep_for(50ms);
wait_for([&] { auto lock = catch_lock(); return things_remembered == all_the_things.size(); });
{
auto l = catch_lock();
REQUIRE( things_remembered == all_the_things.size() );
@ -273,7 +260,7 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]")
}
clients[0]->send(conns[0], "hey google.recall");
std::this_thread::sleep_for(50ms);
reply_sleep();
{
auto l = catch_lock();
REQUIRE( google_knows == personal_details );
@ -286,10 +273,10 @@ TEST_CASE("send failure callbacks", "[commands][queue_full]") {
"", "", // generate ephemeral keys
false, // not a service node
[](auto) { return ""; },
get_logger("")
get_logger(""),
LogLevel::debug // This test traces so much that it takes 2.5-3s of CPU time at trace level, so don't do that.
};
server.log_level(LogLevel::debug);
server.listen_plain(listen, [](auto /*ip*/, auto /*pk*/) { return Allow{AuthLevel::none, false}; });
server.listen_plain(listen);
std::atomic<int> send_attempts{0};
std::atomic<int> send_failures{0};
@ -331,7 +318,7 @@ TEST_CASE("send failure callbacks", "[commands][queue_full]") {
for (i = 0; i < 20; i++) {
if (send_attempts.load() >= 500)
break;
std::this_thread::sleep_for(25ms);
std::this_thread::sleep_for(10ms);
}
{
auto lock = catch_lock();
@ -354,20 +341,20 @@ TEST_CASE("send failure callbacks", "[commands][queue_full]") {
client.send(zmq::message_t{"x.x", 3}, zmq::send_flags::none);
expected_attempts += 500;
if (i >= 4) {
std::this_thread::sleep_for(25ms);
if (send_failures.load() > 0)
break;
std::this_thread::sleep_for(25ms);
}
}
for (i = 0; i < 10; i++) {
for (i = 0; i < 20; i++) {
if (send_attempts.load() >= expected_attempts)
break;
std::this_thread::sleep_for(25ms);
std::this_thread::sleep_for(10ms);
}
{
auto lock = catch_lock();
REQUIRE( i <= 8 );
REQUIRE( i < 20 );
REQUIRE( send_attempts.load() == expected_attempts );
REQUIRE( send_failures.load() > 0 );
}

View file

@ -1,5 +1,4 @@
#include "common.h"
#include <future>
#include <lokimq/hex.h>
extern "C" {
#include <sodium.h>
@ -12,46 +11,42 @@ TEST_CASE("connections with curve authentication", "[curve][connect]") {
"", "", // generate ephemeral keys
false, // not a service node
[](auto) { return ""; },
get_logger("")
get_logger(""),
LogLevel::trace
};
server.log_level(LogLevel::trace);
server.listen_curve(listen, [](auto /*ip*/, auto /*pk*/) { return Allow{AuthLevel::none, false}; });
server.listen_curve(listen);
server.add_category("public", Access{AuthLevel::none});
server.add_request_command("public", "hello", [&](Message& m) { m.send_reply("hi"); });
server.start();
LokiMQ client{get_logger("")};
client.log_level(LogLevel::trace);
LokiMQ client{get_logger(""), LogLevel::trace};
client.start();
auto pubkey = server.get_pubkey();
std::atomic<int> connected{0};
std::atomic<bool> got{false};
bool success = false;
auto server_conn = client.connect_remote(listen,
[&](auto conn) { connected = 1; },
[&](auto conn, string_view reason) { auto lock = catch_lock(); INFO("connection failed: " << reason); },
[&](auto conn) { success = true; got = true; },
[&](auto conn, string_view reason) { auto lock = catch_lock(); INFO("connection failed: " << reason); got = true; },
pubkey);
int i;
for (i = 0; i < 5; i++) {
if (connected.load())
break;
std::this_thread::sleep_for(50ms);
}
wait_for_conn(got);
{
auto lock = catch_lock();
REQUIRE( i <= 1 );
REQUIRE( connected.load() );
REQUIRE( got );
REQUIRE( success );
}
bool success = false;
success = false;
std::vector<std::string> parts;
client.request(server_conn, "public.hello", [&](auto success_, auto parts_) { success = success_; parts = parts_; });
std::this_thread::sleep_for(50ms);
auto lock = catch_lock();
REQUIRE( success );
reply_sleep();
{
auto lock = catch_lock();
REQUIRE( success );
}
}
TEST_CASE("self-connection SN optimization", "[connect][self]") {
@ -63,10 +58,16 @@ TEST_CASE("self-connection SN optimization", "[connect][self]") {
pubkey, privkey,
true,
[&](auto pk) { if (pk == pubkey) return "tcp://127.0.0.1:5544"; else return ""; },
get_logger("")
get_logger(""),
LogLevel::trace
};
sn.listen_curve("tcp://127.0.0.1:5544", [&](auto ip, auto pk) { REQUIRE(ip == "127.0.0.1"); return Allow{AuthLevel::none, pk == pubkey}; });
sn.listen_curve("tcp://127.0.0.1:5544", [&](auto ip, auto pk, auto sn) {
auto lock = catch_lock();
REQUIRE(ip == "127.0.0.1");
REQUIRE(sn == (pk == pubkey));
return AuthLevel::none;
});
sn.add_category("a", Access{AuthLevel::none});
bool invoked = false;
sn.add_command("a", "b", [&](const Message& m) {
@ -77,57 +78,54 @@ TEST_CASE("self-connection SN optimization", "[connect][self]") {
REQUIRE(!m.data.empty());
REQUIRE(m.data[0] == "my data");
});
sn.log_level(LogLevel::trace);
sn.set_active_sns({{pubkey}});
sn.start();
std::this_thread::sleep_for(50ms);
sn.send(pubkey, "a.b", "my data");
std::this_thread::sleep_for(50ms);
auto lock = catch_lock();
REQUIRE(invoked);
reply_sleep();
{
auto lock = catch_lock();
REQUIRE(invoked);
}
}
TEST_CASE("plain-text connections", "[plaintext][connect]") {
std::string listen = "tcp://127.0.0.1:4455";
LokiMQ server{get_logger("")};
server.log_level(LogLevel::trace);
LokiMQ server{get_logger(""), LogLevel::trace};
server.add_category("public", Access{AuthLevel::none});
server.add_request_command("public", "hello", [&](Message& m) { m.send_reply("hi"); });
server.listen_plain(listen, [](auto /*ip*/, auto /*pk*/) { return Allow{AuthLevel::none, false}; });
server.listen_plain(listen);
server.start();
LokiMQ client{get_logger("")};
client.log_level(LogLevel::trace);
LokiMQ client{get_logger(""), LogLevel::trace};
client.start();
std::atomic<int> connected{0};
std::atomic<bool> got{false};
bool success = false;
auto c = client.connect_remote(listen,
[&](auto conn) { connected = 1; },
[&](auto conn, string_view reason) { auto lock = catch_lock(); INFO("connection failed: " << reason); }
[&](auto conn) { success = true; got = true; },
[&](auto conn, string_view reason) { auto lock = catch_lock(); INFO("connection failed: " << reason); got = true; }
);
int i;
for (i = 0; i < 5; i++) {
if (connected.load())
break;
std::this_thread::sleep_for(50ms);
}
wait_for_conn(got);
{
auto lock = catch_lock();
REQUIRE( i <= 1 );
REQUIRE( connected.load() );
REQUIRE( got );
REQUIRE( success );
}
bool success = false;
success = false;
std::vector<std::string> parts;
client.request(c, "public.hello", [&](auto success_, auto parts_) { success = success_; parts = parts_; });
std::this_thread::sleep_for(50ms);
auto lock = catch_lock();
REQUIRE( success );
reply_sleep();
{
auto lock = catch_lock();
REQUIRE( success );
}
}
TEST_CASE("SN disconnections", "[connect][disconnect]") {
@ -147,28 +145,146 @@ TEST_CASE("SN disconnections", "[connect][disconnect]") {
lmq.push_back(std::make_unique<LokiMQ>(
pubkey[i], privkey[i], true,
[conn](auto pk) { auto it = conn.find((std::string) pk); if (it != conn.end()) return it->second; return ""s; },
get_logger("S" + std::to_string(i) + "» ")
get_logger("S" + std::to_string(i) + "» "),
LogLevel::trace
));
auto& server = *lmq.back();
server.log_level(LogLevel::debug);
server.listen_curve(conn[pubkey[i]], [](auto /*ip*/, auto /*pk*/) { return Allow{AuthLevel::none, true}; });
server.listen_curve(conn[pubkey[i]]);
server.add_category("sn", Access{AuthLevel::none, true})
.add_command("hi", [&](Message& m) { his++; });
server.set_active_sns({pubkey.begin(), pubkey.end()});
server.start();
}
std::this_thread::sleep_for(50ms);
lmq[0]->send(pubkey[1], "sn.hi");
lmq[0]->send(pubkey[2], "sn.hi");
std::this_thread::sleep_for(50ms);
lmq[2]->send(pubkey[0], "sn.hi");
lmq[2]->send(pubkey[1], "sn.hi");
lmq[1]->send(pubkey[0], "BYE");
std::this_thread::sleep_for(50ms);
lmq[0]->send(pubkey[2], "sn.hi");
std::this_thread::sleep_for(50ms);
auto lock = catch_lock();
REQUIRE(his == 5);
}
TEST_CASE("SN auth checks", "[sandwich][auth]") {
// When a remote connects, we check its authentication level; if at the time of connection it
// isn't recognized as a SN but tries to invoke a SN command it'll be told to disconnect; if it
// tries to send again it should reconnect and reauthenticate. This test is meant to test this
// pattern where the reconnection/reauthentication now authenticates it as a SN.
std::string listen = "tcp://127.0.0.1:4455";
std::string pubkey, privkey;
pubkey.resize(crypto_box_PUBLICKEYBYTES);
privkey.resize(crypto_box_SECRETKEYBYTES);
crypto_box_keypair(reinterpret_cast<unsigned char*>(&pubkey[0]), reinterpret_cast<unsigned char*>(&privkey[0]));
LokiMQ server{
pubkey, privkey,
true, // service node
[](auto) { return ""; },
get_logger(""),
LogLevel::trace
};
std::atomic<bool> incoming_is_sn{false};
server.listen_curve(listen);
server.add_category("public", Access{AuthLevel::none})
.add_request_command("hello", [&](Message& m) { m.send_reply("hi"); })
.add_request_command("sudo", [&](Message& m) {
server.update_active_sns({{m.conn.pubkey()}}, {});
m.send_reply("making sandwiches");
})
.add_request_command("nosudo", [&](Message& m) {
// Send the reply *first* because if we do it the other way we'll have just removed
// ourselves from the list of SNs and thus would try to open an outbound connection
// to deliver it since it's still queued as a message to a SN.
m.send_reply("make them yourself");
server.update_active_sns({}, {{m.conn.pubkey()}});
});
server.add_category("sandwich", Access{AuthLevel::none, true})
.add_request_command("make", [&](Message& m) { m.send_reply("okay"); });
server.start();
LokiMQ client{
"", "", false,
[&](auto remote_pk) { if (remote_pk == pubkey) return listen; return ""s; },
get_logger(""), LogLevel::trace};
client.start();
std::atomic<bool> got{false};
bool success;
client.request(pubkey, "public.hello", [&](auto success_, auto) { success = success_; got = true; });
wait_for_conn(got);
{
auto lock = catch_lock();
REQUIRE( got );
REQUIRE( success );
}
got = false;
using dvec = std::vector<std::string>;
dvec data;
client.request(pubkey, "sandwich.make", [&](auto success_, auto data_) {
success = success_;
data = std::move(data_);
got = true;
});
reply_sleep();
{
auto lock = catch_lock();
REQUIRE( got );
REQUIRE_FALSE( success );
REQUIRE( data == dvec{{"FORBIDDEN_SN"}} );
}
// Somebody set up us the bomb. Main sudo turn on.
got = false;
client.request(pubkey, "public.sudo", [&](auto success_, auto data_) { success = success_; data = data_; got = true; });
reply_sleep();
{
auto lock = catch_lock();
REQUIRE( got );
REQUIRE( success );
REQUIRE( data == dvec{{"making sandwiches"}} );
}
got = false;
client.request(pubkey, "sandwich.make", [&](auto success_, auto data_) {
success = success_;
data = std::move(data_);
got = true;
});
reply_sleep();
{
auto lock = catch_lock();
REQUIRE( got );
REQUIRE( success );
REQUIRE( data == dvec{{"okay"}} );
}
// Take off every 'SUDO', You [not] know what you doing
got = false;
client.request(pubkey, "public.nosudo", [&](auto success_, auto data_) { success = success_; data = data_; got = true; });
reply_sleep();
{
auto lock = catch_lock();
REQUIRE( got );
REQUIRE( success );
REQUIRE( data == dvec{{"make them yourself"}} );
}
got = false;
client.request(pubkey, "sandwich.make", [&](auto success_, auto data_) {
success = success_;
data = std::move(data_);
got = true;
});
reply_sleep();
{
auto lock = catch_lock();
REQUIRE( got );
REQUIRE_FALSE( success );
REQUIRE( data == dvec{{"FORBIDDEN_SN"}} );
}
}

309
tests/test_failures.cpp Normal file
View file

@ -0,0 +1,309 @@
#include "common.h"
#include <lokimq/hex.h>
#include <map>
#include <set>
using namespace lokimq;
TEST_CASE("failure responses - UNKNOWNCOMMAND", "[failure][UNKNOWNCOMMAND]") {
std::string listen = "tcp://127.0.0.1:4567";
LokiMQ server{
"", "", // generate ephemeral keys
false, // not a service node
[](auto) { return ""; },
get_logger(""),
LogLevel::trace
};
server.listen_plain(listen);
server.start();
// Use a raw socket here because I want to see the raw commands coming on the wire
zmq::context_t client_ctx;
zmq::socket_t client{client_ctx, zmq::socket_type::dealer};
client.connect(listen);
// Handshake: we send HI, they reply HELLO.
client.send(zmq::message_t{"HI", 2}, zmq::send_flags::none);
{
zmq::message_t hello;
client.recv(hello);
auto lock = catch_lock();
REQUIRE( hello.to_string() == "HELLO" );
REQUIRE_FALSE( hello.more() );
}
client.send(zmq::message_t{"a.a", 3}, zmq::send_flags::none);
zmq::message_t resp;
client.recv(resp);
auto lock = catch_lock();
REQUIRE( resp.to_string() == "UNKNOWNCOMMAND" );
REQUIRE( resp.more() );
client.recv(resp);
REQUIRE( resp.to_string() == "a.a" );
REQUIRE_FALSE( resp.more() );
}
TEST_CASE("failure responses - NO_REPLY_TAG", "[failure][NO_REPLY_TAG]") {
std::string listen = "tcp://127.0.0.1:4567";
LokiMQ server{
"", "", // generate ephemeral keys
false, // not a service node
[](auto) { return ""; },
get_logger(""),
LogLevel::trace
};
server.listen_plain(listen);
server.add_category("x", AuthLevel::none)
.add_request_command("r", [] (auto& m) { m.send_reply("a"); });
server.start();
// Use a raw socket here because I want to see the raw commands coming on the wire
zmq::context_t client_ctx;
zmq::socket_t client{client_ctx, zmq::socket_type::dealer};
client.connect(listen);
// Handshake: we send HI, they reply HELLO.
client.send(zmq::message_t{"HI", 2}, zmq::send_flags::none);
{
zmq::message_t hello;
client.recv(hello);
auto lock = catch_lock();
REQUIRE( hello.to_string() == "HELLO" );
REQUIRE_FALSE( hello.more() );
}
client.send(zmq::message_t{"x.r", 3}, zmq::send_flags::none);
zmq::message_t resp;
client.recv(resp);
{
auto lock = catch_lock();
REQUIRE( resp.to_string() == "NO_REPLY_TAG" );
REQUIRE( resp.more() );
client.recv(resp);
REQUIRE( resp.to_string() == "x.r" );
REQUIRE_FALSE( resp.more() );
}
client.send(zmq::message_t{"x.r", 3}, zmq::send_flags::sndmore);
client.send(zmq::message_t{"foo", 3}, zmq::send_flags::none);
client.recv(resp);
{
auto lock = catch_lock();
REQUIRE( resp.to_string() == "REPLY" );
REQUIRE( resp.more() );
client.recv(resp);
REQUIRE( resp.to_string() == "foo" );
REQUIRE( resp.more() );
client.recv(resp);
REQUIRE( resp.to_string() == "a" );
REQUIRE_FALSE( resp.more() );
}
}
TEST_CASE("failure responses - FORBIDDEN", "[failure][FORBIDDEN]") {
std::string listen = "tcp://127.0.0.1:4567";
LokiMQ server{
"", "", // generate ephemeral keys
false, // not a service node
[](auto) { return ""; },
get_logger(""),
LogLevel::trace
};
server.listen_plain(listen, [](auto, auto, auto) {
static int count = 0;
++count;
return count == 1 ? AuthLevel::none : count == 2 ? AuthLevel::basic : AuthLevel::admin;
});
server.add_category("x", AuthLevel::basic)
.add_command("x", [] (auto& m) { m.send_back("a"); });
server.add_category("y", AuthLevel::admin)
.add_command("x", [] (auto& m) { m.send_back("b"); });
server.start();
zmq::context_t client_ctx;
std::array<zmq::socket_t, 3> clients;
// Client 0 should get none auth level, client 1 should get basic, client 2 should get admin
for (auto& client : clients) {
client = {client_ctx, zmq::socket_type::dealer};
client.connect(listen);
// Handshake: we send HI, they reply HELLO.
client.send(zmq::message_t{"HI", 2}, zmq::send_flags::none);
{
zmq::message_t hello;
client.recv(hello);
auto lock = catch_lock();
REQUIRE( hello.to_string() == "HELLO" );
REQUIRE_FALSE( hello.more() );
}
}
for (auto& c : clients)
c.send(zmq::message_t{"x.x", 3}, zmq::send_flags::none);
zmq::message_t resp;
clients[0].recv(resp);
{
auto lock = catch_lock();
REQUIRE( resp.to_string() == "FORBIDDEN" );
REQUIRE( resp.more() );
clients[0].recv(resp);
REQUIRE( resp.to_string() == "x.x" );
REQUIRE_FALSE( resp.more() );
}
for (int i : {1, 2}) {
clients[i].recv(resp);
auto lock = catch_lock();
REQUIRE( resp.to_string() == "a" );
REQUIRE_FALSE( resp.more() );
}
for (auto& c : clients)
c.send(zmq::message_t{"y.x", 3}, zmq::send_flags::none);
for (int i : {0, 1}) {
clients[i].recv(resp);
auto lock = catch_lock();
REQUIRE( resp.to_string() == "FORBIDDEN" );
REQUIRE( resp.more() );
clients[i].recv(resp);
REQUIRE( resp.to_string() == "y.x" );
REQUIRE_FALSE( resp.more() );
}
clients[2].recv(resp);
{
auto lock = catch_lock();
REQUIRE( resp.to_string() == "b" );
REQUIRE_FALSE( resp.more() );
}
}
TEST_CASE("failure responses - NOT_A_SERVICE_NODE", "[failure][NOT_A_SERVICE_NODE]") {
std::string listen = "tcp://127.0.0.1:4567";
LokiMQ server{
"", "", // generate ephemeral keys
false, // not a service node
[](auto) { return ""; },
get_logger(""),
LogLevel::trace
};
server.listen_plain(listen, [](auto, auto, auto) {
static int count = 0;
++count;
return count == 1 ? AuthLevel::none : count == 2 ? AuthLevel::basic : AuthLevel::admin;
});
server.add_category("x", Access{AuthLevel::none, false, true})
.add_command("x", [] (auto&) {})
.add_request_command("r", [] (auto& m) { m.send_reply(); })
;
server.start();
zmq::context_t client_ctx;
zmq::socket_t client{client_ctx, zmq::socket_type::dealer};
client.connect(listen);
// Handshake: we send HI, they reply HELLO.
client.send(zmq::message_t{"HI", 2}, zmq::send_flags::none);
{
zmq::message_t hello;
client.recv(hello);
auto lock = catch_lock();
REQUIRE( hello.to_string() == "HELLO" );
REQUIRE_FALSE( hello.more() );
}
client.send(zmq::message_t{"x.x", 3}, zmq::send_flags::none);
zmq::message_t resp;
client.recv(resp);
{
auto lock = catch_lock();
REQUIRE( resp.to_string() == "NOT_A_SERVICE_NODE" );
REQUIRE( resp.more() );
client.recv(resp);
REQUIRE( resp.to_string() == "x.x" );
REQUIRE_FALSE( resp.more() );
}
client.send(zmq::message_t{"x.r", 3}, zmq::send_flags::sndmore);
client.send(zmq::message_t{"xyz123", 6}, zmq::send_flags::none); // reply tag
client.recv(resp);
{
auto lock = catch_lock();
REQUIRE( resp.to_string() == "NOT_A_SERVICE_NODE" );
REQUIRE( resp.more() );
client.recv(resp);
REQUIRE( resp.to_string() == "REPLY" );
REQUIRE( resp.more() );
client.recv(resp);
REQUIRE( resp.to_string() == "xyz123" );
REQUIRE_FALSE( resp.more() );
}
}
TEST_CASE("failure responses - FORBIDDEN_SN", "[failure][FORBIDDEN_SN]") {
std::string listen = "tcp://127.0.0.1:4567";
LokiMQ server{
"", "", // generate ephemeral keys
false, // not a service node
[](auto) { return ""; },
get_logger(""),
LogLevel::trace
};
server.listen_plain(listen, [](auto, auto, auto) {
static int count = 0;
++count;
return count == 1 ? AuthLevel::none : count == 2 ? AuthLevel::basic : AuthLevel::admin;
});
server.add_category("x", Access{AuthLevel::none, true, false})
.add_command("x", [] (auto&) {})
.add_request_command("r", [] (auto& m) { m.send_reply(); })
;
server.start();
zmq::context_t client_ctx;
zmq::socket_t client{client_ctx, zmq::socket_type::dealer};
client.connect(listen);
// Handshake: we send HI, they reply HELLO.
client.send(zmq::message_t{"HI", 2}, zmq::send_flags::none);
{
zmq::message_t hello;
client.recv(hello);
auto lock = catch_lock();
REQUIRE( hello.to_string() == "HELLO" );
REQUIRE_FALSE( hello.more() );
}
client.send(zmq::message_t{"x.x", 3}, zmq::send_flags::none);
zmq::message_t resp;
client.recv(resp);
{
auto lock = catch_lock();
REQUIRE( resp.to_string() == "FORBIDDEN_SN" );
REQUIRE( resp.more() );
client.recv(resp);
REQUIRE( resp.to_string() == "x.x" );
REQUIRE_FALSE( resp.more() );
}
client.send(zmq::message_t{"x.r", 3}, zmq::send_flags::sndmore);
client.send(zmq::message_t{"xyz123", 6}, zmq::send_flags::none); // reply tag
client.recv(resp);
{
auto lock = catch_lock();
REQUIRE( resp.to_string() == "FORBIDDEN_SN" );
REQUIRE( resp.more() );
client.recv(resp);
REQUIRE( resp.to_string() == "REPLY" );
REQUIRE( resp.more() );
client.recv(resp);
REQUIRE( resp.to_string() == "xyz123" );
REQUIRE_FALSE( resp.more() );
}
}

View file

@ -1,5 +1,4 @@
#include "common.h"
#include <future>
#include <lokimq/hex.h>
using namespace lokimq;
@ -36,17 +35,11 @@ TEST_CASE("basic requests", "[requests]") {
[&](auto, auto) { failed = true; },
server.get_pubkey());
int i;
for (i = 0; i < 5; i++) {
if (connected.load())
break;
std::this_thread::sleep_for(50ms);
}
wait_for([&] { return connected || failed; });
{
auto lock = catch_lock();
REQUIRE( connected.load() );
REQUIRE( !failed.load() );
REQUIRE( i <= 1 );
REQUIRE( connected );
REQUIRE_FALSE( failed );
REQUIRE( to_hex(pubkey) == to_hex(server.get_pubkey()) );
}
@ -59,11 +52,13 @@ TEST_CASE("basic requests", "[requests]") {
data = std::move(data_);
});
std::this_thread::sleep_for(50ms);
auto lock = catch_lock();
REQUIRE( got_reply.load() );
REQUIRE( success );
REQUIRE( data == std::vector<std::string>{{"123"}} );
reply_sleep();
{
auto lock = catch_lock();
REQUIRE( got_reply.load() );
REQUIRE( success );
REQUIRE( data == std::vector<std::string>{{"123"}} );
}
}
TEST_CASE("request from server to client", "[requests]") {
@ -161,15 +156,10 @@ TEST_CASE("request timeouts", "[requests][timeout]") {
[&](auto, auto) { failed = true; },
server.get_pubkey());
int i;
for (i = 0; i < 5; i++) {
if (connected.load())
break;
std::this_thread::sleep_for(50ms);
}
REQUIRE( connected.load() );
REQUIRE( !failed.load() );
REQUIRE( i <= 1 );
wait_for([&] { return connected || failed; });
REQUIRE( connected );
REQUIRE_FALSE( failed );
REQUIRE( to_hex(pubkey) == to_hex(server.get_pubkey()) );
std::atomic<bool> got_triggered{false};
@ -180,7 +170,7 @@ TEST_CASE("request timeouts", "[requests][timeout]") {
success = ok;
data = std::move(data_);
},
lokimq::send_option::request_timeout{30ms}
lokimq::send_option::request_timeout{20ms}
);
std::atomic<bool> got_triggered2{false};
@ -192,10 +182,10 @@ TEST_CASE("request timeouts", "[requests][timeout]") {
lokimq::send_option::request_timeout{100ms}
);
std::this_thread::sleep_for(50ms);
REQUIRE( got_triggered.load() );
std::this_thread::sleep_for(30ms);
REQUIRE( got_triggered );
REQUIRE_FALSE( got_triggered2 );
REQUIRE_FALSE( success );
REQUIRE( data.size() == 0 );
REQUIRE( data == std::vector<std::string>{{"TIMEOUT"}} );
REQUIRE_FALSE( got_triggered2.load() );
}