1.1.0: invocation-time SN auth; failure responses

This replaces the recognition of SN status to be checked per-command
invocation rather than on connection.  As this breaks the API quite
substantially, though doesn't really affect the functionality, it seems
suitable to bump the minor version.

This requires a fundamental shift in how the calling application tells
LokiMQ about service nodes: rather than using a callback invoked on
connection, the application now has to call set_active_sns() (or the
more efficient update_active_sns(), if changes are readily available) to
update the list whenever it changes.  LokiMQ then keeps this list
internally and uses it when determining whether to invoke.

This release also brings better request responses on errors: when a
request fails, the data argument will now be set to the failure reason,
one of:

- TIMEOUT
- UNKNOWNCOMMAND
- NOT_A_SERVICE_NODE (the remote isn't running in SN mode)
- FORBIDDEN (auth level denies the request)
- FORBIDDEN_SN (SN required and the remote doesn't see us as a SN)

Some of these (UNKNOWNCOMMAND, NOT_A_SERVICE_NODE, FORBIDDEN) were
already sent by remotes, but there was no connection to a request and so
they would log a warning, but the request would have to time out.

These errors (minus TIMEOUT, plus NO_REPLY_TAG signalling that a command
is a request but didn't include a reply tag) are also sent in response
to regular commands, but they simply result in a log warning showing the
error type and the command that caused the failure when received.
This commit is contained in:
Jason Rhinelander 2020-04-12 19:57:19 -03:00
parent fb3bf9bd1f
commit 3b86eb1341
16 changed files with 978 additions and 326 deletions

View File

@ -5,8 +5,8 @@ project(liblokimq CXX)
include(GNUInstallDirs)
set(LOKIMQ_VERSION_MAJOR 1)
set(LOKIMQ_VERSION_MINOR 0)
set(LOKIMQ_VERSION_PATCH 5)
set(LOKIMQ_VERSION_MINOR 1)
set(LOKIMQ_VERSION_PATCH 0)
set(LOKIMQ_VERSION "${LOKIMQ_VERSION_MAJOR}.${LOKIMQ_VERSION_MINOR}.${LOKIMQ_VERSION_PATCH}")
message(STATUS "lokimq v${LOKIMQ_VERSION}")

View File

@ -30,45 +30,166 @@ std::string zmtp_metadata(string_view key, string_view value) {
bool LokiMQ::proxy_check_auth(size_t conn_index, bool outgoing, const peer_info& peer,
const std::string& command, const category& cat, zmq::message_t& msg) {
zmq::message_t& cmd, const cat_call_t& cat_call, std::vector<zmq::message_t>& data) {
auto command = view(cmd);
std::string reply;
if (peer.auth_level < cat.access.auth) {
LMQ_LOG(warn, "Access denied to ", command, " for peer [", to_hex(peer.pubkey), "]/", peer_address(msg),
": peer auth level ", peer.auth_level, " < ", cat.access.auth);
if (!cat_call.first) {
LMQ_LOG(warn, "Invalid command '", command, "' sent by remote [", to_hex(peer.pubkey), "]/", peer_address(cmd));
reply = "UNKNOWNCOMMAND";
} else if (peer.auth_level < cat_call.first->access.auth) {
LMQ_LOG(warn, "Access denied to ", command, " for peer [", to_hex(peer.pubkey), "]/", peer_address(cmd),
": peer auth level ", peer.auth_level, " < ", cat_call.first->access.auth);
reply = "FORBIDDEN";
}
else if (cat.access.local_sn && !local_service_node) {
LMQ_LOG(warn, "Access denied to ", command, " for peer [", to_hex(peer.pubkey), "]/", peer_address(msg),
} else if (cat_call.first->access.local_sn && !local_service_node) {
LMQ_LOG(warn, "Access denied to ", command, " for peer [", to_hex(peer.pubkey), "]/", peer_address(cmd),
": that command is only available when this LokiMQ is running in service node mode");
reply = "NOT_A_SERVICE_NODE";
}
else if (cat.access.remote_sn && !peer.service_node) {
LMQ_LOG(warn, "Access denied to ", command, " for peer [", to_hex(peer.pubkey), "]/", peer_address(msg),
} else if (cat_call.first->access.remote_sn && !peer.service_node) {
LMQ_LOG(warn, "Access denied to ", command, " for peer [", to_hex(peer.pubkey), "]/", peer_address(cmd),
": remote is not recognized as a service node");
// Disconnect: we don't think the remote is a SN, but it issued a command only SNs should be
// issuing. Drop the connection; if the remote has something important to relay it will
// reconnect, at which point we will reassess the SN status on the new incoming connection.
if (outgoing) {
proxy_disconnect(peer.service_node ? ConnectionID{peer.pubkey} : conn_index_to_id[conn_index], 1s);
return false;
}
else
reply = "BYE";
reply = "FORBIDDEN_SN";
} else if (cat_call.second->second /*is_request*/ && data.empty()) {
LMQ_LOG(warn, "Received an invalid request for '", command, "' with no reply tag from remote [",
to_hex(peer.pubkey), "]/", peer_address(cmd));
reply = "NO_REPLY_TAG";
} else {
return true;
}
if (reply.empty())
return true;
std::vector<zmq::message_t> msgs;
msgs.reserve(4);
if (!outgoing)
msgs.push_back(create_message(peer.route));
msgs.push_back(create_message(reply));
if (cat_call.second && cat_call.second->second /*request command*/ && !data.empty()) {
msgs.push_back(create_message("REPLY"_sv));
msgs.push_back(create_message(view(data.front()))); // reply tag
} else {
msgs.push_back(create_message(view(cmd)));
}
try {
if (outgoing)
send_direct_message(connections[conn_index], std::move(reply), command);
else
send_routed_message(connections[conn_index], peer.route, std::move(reply), command);
} catch (const zmq::error_t&) { /* can't send: possibly already disconnected. Ignore. */ }
send_message_parts(connections[conn_index], msgs);
} catch (const zmq::error_t& err) {
/* can't send: possibly already disconnected. Ignore. */
LMQ_LOG(debug, "Couldn't send auth failure message ", reply, " to peer [", to_hex(peer.pubkey), "]/", peer_address(cmd), ": ", err.what());
}
return false;
}
void LokiMQ::set_active_sns(pubkey_set pubkeys) {
if (proxy_thread.joinable()) {
auto data = bt_serialize(detail::serialize_object(std::move(pubkeys)));
detail::send_control(get_control_socket(), "SET_SNS", data);
} else {
proxy_set_active_sns(std::move(pubkeys));
}
}
void LokiMQ::proxy_set_active_sns(string_view data) {
proxy_set_active_sns(detail::deserialize_object<pubkey_set>(bt_deserialize<uintptr_t>(data)));
}
void LokiMQ::proxy_set_active_sns(pubkey_set pubkeys) {
pubkey_set added, removed;
for (auto it = pubkeys.begin(); it != pubkeys.end(); ) {
auto& pk = *it;
if (pk.size() != 32) {
LMQ_LOG(warn, "Invalid private key of length ", pk.size(), " (", to_hex(pk), ") passed to set_active_sns");
it = pubkeys.erase(it);
continue;
}
if (!active_service_nodes.count(pk))
added.insert(std::move(pk));
++it;
}
if (added.empty() && active_service_nodes.size() == pubkeys.size()) {
LMQ_LOG(debug, "set_active_sns(): new set of SNs is unchanged, skipping update");
return;
}
for (const auto& pk : active_service_nodes) {
if (!pubkeys.count(pk))
removed.insert(pk);
if (active_service_nodes.size() + added.size() - removed.size() == pubkeys.size())
break;
}
proxy_update_active_sns_clean(std::move(added), std::move(removed));
}
void LokiMQ::update_active_sns(pubkey_set added, pubkey_set removed) {
LMQ_LOG(info, "uh, ", added.size());
if (proxy_thread.joinable()) {
std::array<uintptr_t, 2> data;
data[0] = detail::serialize_object(std::move(added));
data[1] = detail::serialize_object(std::move(removed));
detail::send_control(get_control_socket(), "UPDATE_SNS", bt_serialize(data));
} else {
proxy_update_active_sns(std::move(added), std::move(removed));
}
}
void LokiMQ::proxy_update_active_sns(bt_list_consumer data) {
auto added = detail::deserialize_object<pubkey_set>(data.consume_integer<uintptr_t>());
auto remed = detail::deserialize_object<pubkey_set>(data.consume_integer<uintptr_t>());
proxy_update_active_sns(std::move(added), std::move(remed));
}
void LokiMQ::proxy_update_active_sns(pubkey_set added, pubkey_set removed) {
// We take a caller-provided set of added/removed then filter out any junk (bad pks, conflicting
// values, pubkeys that already(added) or do not(removed) exist), then pass the purified lists
// to the _clean version.
LMQ_LOG(info, "uh, ", added.size(), ", ", removed.size());
for (auto it = removed.begin(); it != removed.end(); ) {
const auto& pk = *it;
if (pk.size() != 32) {
LMQ_LOG(warn, "Invalid private key of length ", pk.size(), " (", to_hex(pk), ") passed to update_active_sns (removed)");
it = removed.erase(it);
} else if (!active_service_nodes.count(pk) || added.count(pk) /* added wins if in both */) {
it = removed.erase(it);
} else {
++it;
}
}
for (auto it = added.begin(); it != added.end(); ) {
const auto& pk = *it;
if (pk.size() != 32) {
LMQ_LOG(warn, "Invalid private key of length ", pk.size(), " (", to_hex(pk), ") passed to update_active_sns (added)");
it = added.erase(it);
} else if (active_service_nodes.count(pk)) {
it = added.erase(it);
} else {
++it;
}
}
proxy_update_active_sns_clean(std::move(added), std::move(removed));
}
void LokiMQ::proxy_update_active_sns_clean(pubkey_set added, pubkey_set removed) {
LMQ_LOG(debug, "Updating SN auth status with +", added.size(), "/-", removed.size(), " pubkeys");
// For anything we remove we want close the connection to the SN (if outgoing), and remove the
// stored peer_info (incoming or outgoing).
for (const auto& pk : removed) {
ConnectionID c{pk};
active_service_nodes.erase(pk);
auto range = peers.equal_range(c);
for (auto it = range.first; it != range.second; ) {
bool outgoing = it->second.outgoing();
size_t conn_index = it->second.conn_index;
it = peers.erase(it);
if (outgoing) {
LMQ_LOG(debug, "Closing outgoing connection to ", c);
proxy_close_connection(conn_index, CLOSE_LINGER);
}
}
}
// For pubkeys we add there's nothing special to be done beyond adding them to the pubkey set
for (auto& pk : added)
active_service_nodes.insert(std::move(pk));
}
void LokiMQ::process_zap_requests() {
for (std::vector<zmq::message_t> frames; recv_message_parts(zap_auth, frames, zmq::recv_flags::dontwait); frames.clear()) {
#ifndef NDEBUG
@ -147,31 +268,32 @@ void LokiMQ::process_zap_requests() {
} else {
auto ip = view(frames[3]);
string_view pubkey;
if (bind[bind_id].second.curve)
bool sn = false;
if (bind[bind_id].second.curve) {
pubkey = view(frames[6]);
auto result = bind[bind_id].second.allow(ip, pubkey);
bool sn = result.remote_sn;
sn = active_service_nodes.count(std::string{pubkey});
}
auto auth = bind[bind_id].second.allow(ip, pubkey, sn);
auto& user_id = response_vals[4];
if (bind[bind_id].second.curve) {
user_id.reserve(64);
to_hex(pubkey.begin(), pubkey.end(), std::back_inserter(user_id));
}
if (result.auth <= AuthLevel::denied || result.auth > AuthLevel::admin) {
if (auth <= AuthLevel::denied || auth > AuthLevel::admin) {
LMQ_LOG(info, "Access denied for incoming ", view(frames[5]), (sn ? " service node" : " client"),
" connection from ", !user_id.empty() ? user_id + " at " : ""s, ip,
" with initial auth level ", result.auth);
" with initial auth level ", auth);
status_code = "400";
status_text = "Access denied";
user_id.clear();
} else {
LMQ_LOG(debug, "Accepted incoming ", view(frames[5]), (sn ? " service node" : " client"),
" connection with authentication level ", result.auth,
" connection with authentication level ", auth,
" from ", !user_id.empty() ? user_id + " at " : ""s, ip);
auto& metadata = response_vals[5];
metadata += zmtp_metadata("X-SN", result.remote_sn ? "1" : "0");
metadata += zmtp_metadata("X-AuthLevel", to_string(result.auth));
metadata += zmtp_metadata("X-AuthLevel", to_string(auth));
status_code = "200";
status_text = "";

View File

@ -21,13 +21,12 @@ struct Access {
bool remote_sn = false;
/// If true the category requires that the local node is a SN
bool local_sn = false;
};
/// Return type of the AllowFunc: this determines whether we allow the connection at all, and if so,
/// sets the initial authentication level and tells LokiMQ whether the other end is an active SN.
struct Allow {
AuthLevel auth = AuthLevel::none;
bool remote_sn = false;
/// Constructor. Intentionally allows implicit conversion from an AuthLevel so that an
/// AuthLevel can be passed anywhere an Access is required (the resulting Access will have both
/// remote and local sn set to false).
Access(AuthLevel auth, bool remote_sn = false, bool local_sn = false)
: auth{auth}, remote_sn{remote_sn}, local_sn{local_sn} {}
};
}

View File

@ -53,6 +53,47 @@ void LokiMQ::setup_outgoing_socket(zmq::socket_t& socket, string_view remote_pub
// else let ZMQ pick a random one
}
ConnectionID LokiMQ::connect_sn(string_view pubkey, std::chrono::milliseconds keep_alive, string_view hint) {
if (!proxy_thread.joinable())
throw std::logic_error("Cannot call connect_sn() before calling `start()`");
detail::send_control(get_control_socket(), "CONNECT_SN", bt_serialize<bt_dict>({{"pubkey",pubkey}, {"keep_alive",keep_alive.count()}, {"hint",hint}}));
return pubkey;
}
ConnectionID LokiMQ::connect_remote(string_view remote, ConnectSuccess on_connect, ConnectFailure on_failure,
string_view pubkey, AuthLevel auth_level, std::chrono::milliseconds timeout) {
if (!proxy_thread.joinable())
throw std::logic_error("Cannot call connect_remote() before calling `start()`");
if (remote.size() < 7 || !(remote.substr(0, 6) == "tcp://" || remote.substr(0, 6) == "ipc://" /* unix domain sockets */))
throw std::runtime_error("Invalid connect_remote: remote address '" + std::string{remote} + "' is not a valid or supported zmq connect string");
auto id = next_conn_id++;
LMQ_TRACE("telling proxy to connect to ", remote, ", id ", id,
pubkey.empty() ? "using NULL auth" : ", using CURVE with remote pubkey [" + to_hex(pubkey) + "]");
detail::send_control(get_control_socket(), "CONNECT_REMOTE", bt_serialize<bt_dict>({
{"auth_level", static_cast<std::underlying_type_t<AuthLevel>>(auth_level)},
{"conn_id", id},
{"connect", detail::serialize_object(std::move(on_connect))},
{"failure", detail::serialize_object(std::move(on_failure))},
{"pubkey", pubkey},
{"remote", remote},
{"timeout", timeout.count()},
}));
return id;
}
void LokiMQ::disconnect(ConnectionID id, std::chrono::milliseconds linger) {
detail::send_control(get_control_socket(), "DISCONNECT", bt_serialize<bt_dict>({
{"conn_id", id.id},
{"linger_ms", linger.count()},
{"pubkey", id.pk},
}));
}
std::pair<zmq::socket_t *, std::string>
LokiMQ::proxy_connect_sn(string_view remote, string_view connect_hint, bool optional, bool incoming_only, bool outgoing_only, std::chrono::milliseconds keep_alive) {
ConnectionID remote_cid{remote};
@ -166,9 +207,9 @@ void update_connection_indices(Container& c, size_t index, AccessIndex get_index
}
}
/// Closes outgoing connections and removes all references. Note that this will invalidate
/// iterators on the various connection containers - if you don't want that, delete it first so that
/// the container won't contain the element being deleted.
/// Closes outgoing connections and removes all references. Note that this will call `erase()`
/// which can invalidate iterators on the various connection containers - if you don't want that,
/// delete it first so that the container won't contain the element being deleted.
void LokiMQ::proxy_close_connection(size_t index, std::chrono::milliseconds linger) {
connections[index].setsockopt<int>(ZMQ_LINGER, linger > 0ms ? linger.count() : 0);
pollitems_stale = true;
@ -197,6 +238,7 @@ void LokiMQ::proxy_expire_idle_peers() {
continue;
}
LMQ_LOG(debug, "Closing outgoing connection to ", it->first, ": idle timeout reached");
++it; // The below is going to delete our current element
proxy_close_connection(info.conn_index, CLOSE_LINGER);
} else {
++it;
@ -238,7 +280,7 @@ void LokiMQ::proxy_conn_cleanup() {
auto& callback = it->second;
if (callback.first < now) {
LMQ_LOG(debug, "pending request ", to_hex(it->first), " expired, invoking callback with failure status and removing");
job([callback = std::move(callback.second)] { callback(false, {}); });
job([callback = std::move(callback.second)] { callback(false, {{"TIMEOUT"s}}); });
it = pending_requests.erase(it);
} else {
++it;
@ -262,14 +304,10 @@ void LokiMQ::proxy_connect_remote(bt_dict_consumer data) {
if (data.skip_until("conn_id"))
conn_id = data.consume_integer<long long>();
if (data.skip_until("connect")) {
auto* ptr = reinterpret_cast<ConnectSuccess*>(data.consume_integer<uintptr_t>());
on_connect = std::move(*ptr);
delete ptr;
on_connect = detail::deserialize_object<ConnectSuccess>(data.consume_integer<uintptr_t>());
}
if (data.skip_until("failure")) {
auto* ptr = reinterpret_cast<ConnectFailure*>(data.consume_integer<uintptr_t>());
on_failure = std::move(*ptr);
delete ptr;
on_failure = detail::deserialize_object<ConnectFailure>(data.consume_integer<uintptr_t>());
}
if (data.skip_until("pubkey")) {
remote_pubkey = data.consume_string();

View File

@ -1,5 +1,6 @@
#pragma once
#include "string_view.h"
#include <cstring>
namespace lokimq {
@ -71,11 +72,24 @@ private:
friend std::ostream& operator<<(std::ostream& o, const ConnectionID& conn);
};
/// Simple hash implementation for a string that is *already* a hash-like value (such as a pubkey).
/// Falls back to std::hash<std::string> if given a string smaller than a size_t.
struct already_hashed {
size_t operator()(const std::string& s) const {
if (s.size() < sizeof(size_t))
return std::hash<std::string>{}(s);
size_t hash;
std::memcpy(&hash, &s[0], sizeof(hash));
return hash;
}
};
} // namespace lokimq
namespace std {
template <> struct hash<lokimq::ConnectionID> {
size_t operator()(const lokimq::ConnectionID &c) const {
return c.sn() ? std::hash<std::string>{}(c.pk) :
return c.sn() ? lokimq::already_hashed{}(c.pk) :
std::hash<long long>{}(c.id);
}
};

View File

@ -27,7 +27,7 @@ extern "C" inline void message_buffer_destroy(void*, void* hint) {
}
/// Creates a message without needing to reallocate the provided string data
inline zmq::message_t create_message(std::string &&data) {
inline zmq::message_t create_message(std::string&& data) {
auto *buffer = new std::string(std::move(data));
return zmq::message_t{&(*buffer)[0], buffer->size(), message_buffer_destroy, buffer};
};
@ -38,7 +38,7 @@ inline zmq::message_t create_message(string_view data) {
}
template <typename It>
bool send_message_parts(zmq::socket_t &sock, It begin, It end) {
bool send_message_parts(zmq::socket_t& sock, It begin, It end) {
while (begin != end) {
zmq::message_t &msg = *begin++;
if (!sock.send(msg, begin == end ? zmq::send_flags::dontwait : zmq::send_flags::dontwait | zmq::send_flags::sndmore))
@ -48,7 +48,7 @@ bool send_message_parts(zmq::socket_t &sock, It begin, It end) {
}
template <typename Container>
bool send_message_parts(zmq::socket_t &sock, Container &&c) {
bool send_message_parts(zmq::socket_t& sock, Container&& c) {
return send_message_parts(sock, c.begin(), c.end());
}
@ -56,7 +56,7 @@ bool send_message_parts(zmq::socket_t &sock, Container &&c) {
/// the msg frame will be an empty message; if `data` is empty then the data frame will be omitted.
/// `flags` is passed through to zmq: typically given `zmq::send_flags::dontwait` to throw rather
/// than block if a message can't be queued.
inline bool send_routed_message(zmq::socket_t &socket, std::string route, std::string msg = {}, std::string data = {}) {
inline bool send_routed_message(zmq::socket_t& socket, std::string route, std::string msg = {}, std::string data = {}) {
assert(!route.empty());
std::array<zmq::message_t, 3> msgs{{create_message(std::move(route))}};
if (!msg.empty())
@ -68,7 +68,7 @@ inline bool send_routed_message(zmq::socket_t &socket, std::string route, std::s
// Sends some stuff to a socket directly. If dontwait is true then we throw instead of blocking if
// the message cannot be accepted by zmq (i.e. because the outgoing buffer is full).
inline bool send_direct_message(zmq::socket_t &socket, std::string msg, std::string data = {}) {
inline bool send_direct_message(zmq::socket_t& socket, std::string msg, std::string data = {}) {
std::array<zmq::message_t, 2> msgs{{create_message(std::move(msg))}};
if (!data.empty())
msgs[1] = create_message(std::move(data));
@ -77,7 +77,7 @@ inline bool send_direct_message(zmq::socket_t &socket, std::string msg, std::str
// Receive all the parts of a single message from the given socket. Returns true if a message was
// received, false if called with flags=zmq::recv_flags::dontwait and no message was available.
inline bool recv_message_parts(zmq::socket_t &sock, std::vector<zmq::message_t>& parts, const zmq::recv_flags flags = zmq::recv_flags::none) {
inline bool recv_message_parts(zmq::socket_t& sock, std::vector<zmq::message_t>& parts, const zmq::recv_flags flags = zmq::recv_flags::none) {
do {
zmq::message_t msg;
if (!sock.recv(msg, flags))

View File

@ -26,11 +26,6 @@ std::vector<std::string> as_strings(const MessageContainer& msgs) {
return result;
}
void check_started(const std::thread& proxy_thread, const std::string &verb) {
if (!proxy_thread.joinable())
throw std::logic_error("Cannot " + verb + " before calling `start()`");
}
void check_not_started(const std::thread& proxy_thread, const std::string &verb) {
if (proxy_thread.joinable())
throw std::logic_error("Cannot " + verb + " after calling `start()`");
@ -54,29 +49,20 @@ void send_control(zmq::socket_t& sock, string_view cmd, std::string data) {
}
}
/// Extracts a pubkey, SN status, and auth level from a zmq message received on a *listening*
/// socket.
std::tuple<std::string, bool, AuthLevel> extract_metadata(zmq::message_t& msg) {
auto result = std::make_tuple(""s, false, AuthLevel::none);
/// Extracts a pubkey and and auth level from a zmq message received on a *listening* socket.
std::pair<std::string, AuthLevel> extract_metadata(zmq::message_t& msg) {
auto result = std::make_pair(""s, AuthLevel::none);
try {
string_view pubkey_hex{msg.gets("User-Id")};
if (pubkey_hex.size() != 64)
throw std::logic_error("bad user-id");
assert(is_hex(pubkey_hex.begin(), pubkey_hex.end()));
auto& pubkey = std::get<std::string>(result);
pubkey.resize(32, 0);
from_hex(pubkey_hex.begin(), pubkey_hex.end(), pubkey.begin());
result.first.resize(32, 0);
from_hex(pubkey_hex.begin(), pubkey_hex.end(), result.first.begin());
} catch (...) {}
try {
string_view is_sn{msg.gets("X-SN")};
if (is_sn.size() == 1 && is_sn[0] == '1')
std::get<bool>(result) = true;
} catch (...) {}
try {
string_view auth_level{msg.gets("X-AuthLevel")};
std::get<AuthLevel>(result) = auth_from_string(auth_level);
result.second = auth_from_string(msg.gets("X-AuthLevel"));
} catch (...) {}
return result;
@ -385,46 +371,6 @@ LokiMQ::~LokiMQ() {
LMQ_LOG(info, "LokiMQ proxy thread has stopped");
}
ConnectionID LokiMQ::connect_sn(string_view pubkey, std::chrono::milliseconds keep_alive, string_view hint) {
check_started(proxy_thread, "connect");
detail::send_control(get_control_socket(), "CONNECT_SN", bt_serialize<bt_dict>({{"pubkey",pubkey}, {"keep_alive",keep_alive.count()}, {"hint",hint}}));
return pubkey;
}
ConnectionID LokiMQ::connect_remote(string_view remote, ConnectSuccess on_connect, ConnectFailure on_failure,
string_view pubkey, AuthLevel auth_level, std::chrono::milliseconds timeout) {
if (!proxy_thread.joinable())
LMQ_LOG(warn, "connect_remote() called before start(); this won't take effect until start() is called");
if (remote.size() < 7 || !(remote.substr(0, 6) == "tcp://" || remote.substr(0, 6) == "ipc://" /* unix domain sockets */))
throw std::runtime_error("Invalid connect_remote: remote address '" + std::string{remote} + "' is not a valid or supported zmq connect string");
auto id = next_conn_id++;
LMQ_TRACE("telling proxy to connect to ", remote, ", id ", id,
pubkey.empty() ? "using NULL auth" : ", using CURVE with remote pubkey [" + to_hex(pubkey) + "]");
detail::send_control(get_control_socket(), "CONNECT_REMOTE", bt_serialize<bt_dict>({
{"auth_level", static_cast<std::underlying_type_t<AuthLevel>>(auth_level)},
{"conn_id", id},
{"connect", reinterpret_cast<uintptr_t>(new ConnectSuccess{std::move(on_connect)})},
{"failure", reinterpret_cast<uintptr_t>(new ConnectFailure{std::move(on_failure)})},
{"pubkey", pubkey},
{"remote", remote},
{"timeout", timeout.count()},
}));
return id;
}
void LokiMQ::disconnect(ConnectionID id, std::chrono::milliseconds linger) {
detail::send_control(get_control_socket(), "DISCONNECT", bt_serialize<bt_dict>({
{"conn_id", id.id},
{"linger_ms", linger.count()},
{"pubkey", id.pk},
}));
}
std::ostream &operator<<(std::ostream &os, LogLevel lvl) {
os << (lvl == LogLevel::trace ? "trace" :
lvl == LogLevel::debug ? "debug" :

View File

@ -86,6 +86,14 @@ static constexpr size_t MAX_COMMAND_LENGTH = 200;
class CatHelper;
/// std::unordered_set specialization for specifying pubkeys (used, in particular, by
/// LokiMQ::set_active_sns and LokiMQ::update_active_sns); this is a std::string unordered_set that
/// also uses a specialized trivial hash function that uses part of the value itself (i.e. the
/// pubkey) directly as a hash value. (This is nice and fast for uniformly distributed values like
/// pubkeys and a terrible hash choice for anything else).
using pubkey_set = std::unordered_set<std::string, already_hashed>;
/**
* Class that handles LokiMQ listeners, connections, proxying, and workers. An application
* typically has just one instance of this class.
@ -134,16 +142,18 @@ private:
public:
/// Callback type invoked to determine whether the given new incoming connection is allowed to
/// connect to us and to set its initial authentication level.
/// connect to us and to set its authentication level.
///
/// @param ip - the ip address of the incoming connection
/// @param pubkey - the x25519 pubkey of the connecting client (32 byte string). Note that this
/// will only be non-empty for incoming connections on `listen_curve` sockets; `listen_plain`
/// sockets do not have a pubkey.
/// @param service_node - will be true if the `pubkey` is in the set of known active service
/// nodes.
///
/// @returns an `AuthLevel` enum value indicating the default auth level for the incoming
/// connection, or AuthLevel::denied if the connection should be refused.
using AllowFunc = std::function<Allow(string_view ip, string_view pubkey)>;
using AllowFunc = std::function<AuthLevel(string_view ip, string_view pubkey, bool service_node)>;
/// Callback that is invoked when we need to send a "strong" message to a SN that we aren't
/// already connected to and need to establish a connection. This callback returns the ZMQ
@ -244,7 +254,9 @@ private:
std::vector<std::pair<std::string, bind_data>> bind;
/// Info about a peer's established connection with us. Note that "established" means both
/// connected and authenticated.
/// connected and authenticated. Note that we only store peer info data for SN connections (in
/// or out), and outgoing non-SN connections. Incoming non-SN connections are handled on the
/// fly.
struct peer_info {
/// Pubkey of the remote, if this connection is a curve25519 connection; empty otherwise.
std::string pubkey;
@ -521,15 +533,26 @@ private:
/// done).
std::unordered_map<std::string, std::string> command_aliases;
using cat_call_t = std::pair<category*, const std::pair<CommandCallback, bool>*>;
/// Retrieve category and callback from a command name, including alias mapping. Warns on
/// invalid commands and returns nullptrs. The command name will be updated in place if it is
/// aliased to another command.
std::pair<category*, const std::pair<CommandCallback, bool>*> get_command(std::string& command);
cat_call_t get_command(std::string& command);
/// Checks a peer's authentication level. Returns true if allowed, warns and returns false if
/// not.
bool proxy_check_auth(size_t conn_index, bool outgoing, const peer_info& peer,
const std::string& command, const category& cat, zmq::message_t& msg);
zmq::message_t& command, const cat_call_t& cat_call, std::vector<zmq::message_t>& data);
/// Set of active service nodes.
pubkey_set active_service_nodes;
/// Resets or updates the stored set of active SN pubkeys
void proxy_set_active_sns(string_view data);
void proxy_set_active_sns(pubkey_set pubkeys);
void proxy_update_active_sns(bt_list_consumer data);
void proxy_update_active_sns(pubkey_set added, pubkey_set removed);
void proxy_update_active_sns_clean(pubkey_set added, pubkey_set removed);
/// Details for a pending command; such a command already has authenticated access and is just
/// waiting for a thread to become available to handle it.
@ -775,8 +798,15 @@ public:
* Finish starting up: binds to the bind locations given in the constructor and launches the
* proxy thread to handle message dispatching between remote nodes and worker threads.
*
* You will need to call `add_category` and `add_command` to register commands before calling
* `start()`; once start() is called commands cannot be changed.
* Things you want to do before calling this:
* - Use `add_category`/`add_command` to set up any commands remote connections can invoke.
* - If any commands require SN authentication, specify a list of currently active service node
* pubkeys via `set_active_sns()` (and make sure this gets updated when things change by
* another `set_active_sns()` or a `update_active_sns()` call). It *is* possible to make the
* initial call after calling `start()`, but that creates a window during which incoming
* remote SN connections will be erroneously treated as non-SN connections.
* - If this LMQ instance should accept incoming connections, set up any listening ports via
* `listen_curve()` and/or `listen_plain()`.
*/
void start();
@ -789,10 +819,10 @@ public:
* such as: "tcp://\*:4567" or "tcp://1.2.3.4:5678".
*
* @param allow_connection function to call to determine whether to allow the connection and, if
* so, the authentication level it receives. If omitted the default returns non-service node,
* AuthLevel::none access.
* so, the authentication level it receives. If omitted the default returns AuthLevel::none
* access.
*/
void listen_curve(std::string bind, AllowFunc allow_connection = [](auto, auto) { return Allow{AuthLevel::none, false}; });
void listen_curve(std::string bind, AllowFunc allow_connection = [](auto, auto, auto) { return AuthLevel::none; });
/** Start listening on the given bind address in unauthenticated plain text mode. Incoming
* connections can come from anywhere. `allow_connection` is invoked for any incoming
@ -803,10 +833,10 @@ public:
* such as: "tcp://\*:4567" or "tcp://1.2.3.4:5678".
*
* @param allow_connection function to call to determine whether to allow the connection and, if
* so, the authentication level it receives. If omitted the default returns non-service node,
* AuthLevel::none access.
* so, the authentication level it receives. If omitted the default returns AuthLevel::none
* access.
*/
void listen_plain(std::string bind, AllowFunc allow_connection = [](auto, auto) { return Allow{AuthLevel::none, false}; });
void listen_plain(std::string bind, AllowFunc allow_connection = [](auto, auto, auto) { return AuthLevel::none; });
/**
* Try to initiate a connection to the given SN in anticipation of needing a connection in the
@ -939,9 +969,23 @@ public:
* @param to - the pubkey string or ConnectionID to send this request to
* @param cmd - the command name
* @param callback - the callback to invoke when we get a reply. Called with a true value and
* the data strings when a reply is received, or false and an empty vector of data parts if we
* get no reply in the timeout interval.
* the data strings when a reply is received, or false with error string(s) indicating the
* failure reason upon failure or timeout.
* @param opts - anything else (i.e. strings, send_options) is forwarded to send().
*
* Possible error data values:
* - ["TIMEOUT"] - we got no reply within the timeout window
* - ["UNKNOWNCOMMAND"] - the remote did not recognize the given request command
* - ["NO_REPLY_TAG"] - the invoked command is a request command but no reply tag was included
* - ["FORBIDDEN"] - the command requires an authorization level (e.g. Basic or Admin) that we
* do not have.
* - ["FORBIDDEN_SN"] - the command requires service node authentication, but the remote did not
* recognize us as a service node. You *may* want to retry the request a limited number of
* times (but do not retry indefinitely as that can be an infinite loop!) because this is
* typically also followed by a disconnection; a retried message would reconnect and
* reauthenticate which *may* result in picking up the SN authentication.
* - ["NOT_A_SERVICE_NODE"] - this command is only invokable on service nodes, and the remote is
* not running as a service node.
*/
template <typename... T>
void request(ConnectionID to, string_view cmd, ReplyCallback callback, const T&... opts);
@ -951,6 +995,41 @@ public:
const std::string& get_pubkey() const { return pubkey; }
const std::string& get_privkey() const { return privkey; }
/** Updates (or initially sets) LokiMQ's list of service node pubkeys with the given list.
*
* This has two main effects:
*
* - All commands processed after the update will have SN status determined by the new list.
* - All outgoing connections to service nodes that are no longer on the list will be closed.
* This includes both explicit connections (established by `connect_sn()`) and implicit ones
* (established by sending to a SN that wasn't connected).
*
* As this update is potentially quite heavy it is recommended that this be called only when
* necessary--i.e. when the list has changed (or potentially changed), but *not* on a short
* periodic timer.
*
* This method may (and should!) be called before start() to load an initial set of SNs.
*
* Once a full list has been set, updates on changes can either call this again with the new
* list, or use the more efficient update_active_sns() call if incremental results are
* available.
*/
void set_active_sns(pubkey_set pubkeys);
/** Updates the list of active pubkeys by adding or removing the given pubkeys from the existing
* list. This is more efficient when the incremental information is already available; if it
* isn't, simply call set_active_sns with a new list to have LokiMQ figure out what was added or
* removed.
*
* \param added new pubkeys that were added since the last set_active_sns or update_active_sns
* call.
*
* \param removed pubkeys that were removed from active SN status since the last call. If a
* pubkey is in both `added` and `removed` for some reason then its presence in `removed` will
* be ignored.
*/
void update_active_sns(pubkey_set added, pubkey_set removed);
/**
* Batches a set of jobs to be executed by workers, optionally followed by a completion function.
*
@ -1121,7 +1200,9 @@ namespace detail {
/// containing the pointer to be serialized to pass (via lokimq queues) from one thread to another.
/// Must be matched with a deserializer_pointer on the other side to reconstitute the object and
/// destroy the intermediate pointer.
template <typename T> uintptr_t serialize_object(T&& obj) {
template <typename T>
uintptr_t serialize_object(T&& obj) {
static_assert(std::is_rvalue_reference<decltype(obj)>::value, "serialize_object must be given an rvalue reference");
auto* ptr = new T{std::forward<T>(obj)};
return reinterpret_cast<uintptr_t>(ptr);
}
@ -1192,9 +1273,8 @@ inline void apply_send_option(bt_list&, bt_dict& control_data, send_option::queu
control_data["send_full_q"] = serialize_object(std::move(f.callback));
}
/// Extracts a pubkey, SN status, and auth level from a zmq message received on a *listening*
/// socket.
std::tuple<std::string, bool, AuthLevel> extract_metadata(zmq::message_t& msg);
/// Extracts a pubkey and auth level from a zmq message received on a *listening* socket.
std::pair<std::string, AuthLevel> extract_metadata(zmq::message_t& msg);
template <typename... T>
bt_dict build_send(ConnectionID to, string_view cmd, T&&... opts) {

View File

@ -178,7 +178,7 @@ void LokiMQ::proxy_send(bt_dict_consumer data) {
std::chrono::steady_clock::now() + request_timeout, std::move(request_callback) }});
} else {
LMQ_LOG(debug, "Could not send request, scheduling request callback failure");
job([callback = std::move(request_callback)] { callback(false, {}); });
job([callback = std::move(request_callback)] { callback(false, {{"TIMEOUT"s}}); });
}
}
if (!sent) {
@ -237,31 +237,36 @@ void LokiMQ::proxy_reply(bt_dict_consumer data) {
}
void LokiMQ::proxy_control_message(std::vector<zmq::message_t>& parts) {
// We throw an uncaught exception here because we only generate control messages internally in
// lokimq code: if one of these condition fail it's a lokimq bug.
if (parts.size() < 2)
throw std::logic_error("Expected 2-3 message parts for a proxy control message");
throw std::logic_error("LokiMQ bug: Expected 2-3 message parts for a proxy control message");
auto route = view(parts[0]), cmd = view(parts[1]);
LMQ_TRACE("control message: ", cmd);
if (parts.size() == 3) {
LMQ_TRACE("...: ", parts[2]);
auto data = view(parts[2]);
if (cmd == "SEND") {
LMQ_TRACE("proxying message");
return proxy_send(view(parts[2]));
return proxy_send(data);
} else if (cmd == "REPLY") {
LMQ_TRACE("proxying reply to non-SN incoming message");
return proxy_reply(view(parts[2]));
return proxy_reply(data);
} else if (cmd == "BATCH") {
LMQ_TRACE("proxy batch jobs");
auto ptrval = bt_deserialize<uintptr_t>(view(parts[2]));
auto ptrval = bt_deserialize<uintptr_t>(data);
return proxy_batch(reinterpret_cast<detail::Batch*>(ptrval));
} else if (cmd == "UPDATE_SNS") {
return proxy_update_active_sns(data);
} else if (cmd == "CONNECT_SN") {
proxy_connect_sn(view(parts[2]));
proxy_connect_sn(data);
return;
} else if (cmd == "CONNECT_REMOTE") {
return proxy_connect_remote(view(parts[2]));
return proxy_connect_remote(data);
} else if (cmd == "DISCONNECT") {
return proxy_disconnect(view(parts[2]));
return proxy_disconnect(data);
} else if (cmd == "TIMER") {
return proxy_timer(view(parts[2]));
return proxy_timer(data);
}
} else if (parts.size() == 2) {
if (cmd == "START") {
@ -279,8 +284,8 @@ void LokiMQ::proxy_control_message(std::vector<zmq::message_t>& parts) {
return;
}
}
throw std::runtime_error("Proxy received invalid control command: " + std::string{cmd} +
" (" + std::to_string(parts.size()) + ")");
throw std::runtime_error("LokiMQ bug: Proxy received invalid control command: " +
std::string{cmd} + " (" + std::to_string(parts.size()) + ")");
}
void LokiMQ::proxy_loop() {
@ -455,26 +460,31 @@ void LokiMQ::proxy_loop() {
}
}
static bool is_error_response(string_view cmd) {
return cmd == "FORBIDDEN" || cmd == "FORBIDDEN_SN" || cmd == "NOT_A_SERVICE_NODE" || cmd == "UNKNOWNCOMMAND" || cmd == "NO_REPLY_TAG";
}
// Return true if we recognized/handled the builtin command (even if we reject it for whatever
// reason)
bool LokiMQ::proxy_handle_builtin(size_t conn_index, std::vector<zmq::message_t>& parts) {
bool outgoing = connections[conn_index].getsockopt<int>(ZMQ_TYPE) == ZMQ_DEALER;
// Doubling as a bool and an offset:
size_t incoming = connections[conn_index].getsockopt<int>(ZMQ_TYPE) == ZMQ_ROUTER;
string_view route, cmd;
if (parts.size() < (outgoing ? 1 : 2)) {
if (parts.size() < 1 + incoming) {
LMQ_LOG(warn, "Received empty message; ignoring");
return true;
}
if (outgoing) {
cmd = view(parts[0]);
} else {
if (incoming) {
route = view(parts[0]);
cmd = view(parts[1]);
} else {
cmd = view(parts[0]);
}
LMQ_TRACE("Checking for builtins: '", cmd, "' from ", peer_address(parts.back()));
if (cmd == "REPLY") {
size_t tag_pos = (outgoing ? 1 : 2);
size_t tag_pos = 1 + incoming;
if (parts.size() <= tag_pos) {
LMQ_LOG(warn, "Received REPLY without a reply tag; ignoring");
return true;
@ -482,7 +492,7 @@ bool LokiMQ::proxy_handle_builtin(size_t conn_index, std::vector<zmq::message_t>
std::string reply_tag{view(parts[tag_pos])};
auto it = pending_requests.find(reply_tag);
if (it != pending_requests.end()) {
LMQ_LOG(debug, "Received REPLY for pending command", to_hex(reply_tag), "; scheduling callback");
LMQ_LOG(debug, "Received REPLY for pending command ", to_hex(reply_tag), "; scheduling callback");
std::vector<std::string> data;
data.reserve(parts.size() - (tag_pos + 1));
for (auto it = parts.begin() + (tag_pos + 1); it != parts.end(); ++it)
@ -496,7 +506,7 @@ bool LokiMQ::proxy_handle_builtin(size_t conn_index, std::vector<zmq::message_t>
}
return true;
} else if (cmd == "HI") {
if (outgoing) {
if (!incoming) {
LMQ_LOG(warn, "Got invalid 'HI' message on an outgoing connection; ignoring");
return true;
}
@ -506,7 +516,7 @@ bool LokiMQ::proxy_handle_builtin(size_t conn_index, std::vector<zmq::message_t>
} catch (const std::exception &e) { LMQ_LOG(warn, "Couldn't reply with HELLO: ", e.what()); }
return true;
} else if (cmd == "HELLO") {
if (!outgoing) {
if (incoming) {
LMQ_LOG(warn, "Got invalid 'HELLO' message on an incoming connection; ignoring");
return true;
}
@ -531,22 +541,50 @@ bool LokiMQ::proxy_handle_builtin(size_t conn_index, std::vector<zmq::message_t>
pending_connects.erase(it);
return true;
} else if (cmd == "BYE") {
if (outgoing) {
std::string pk;
bool sn;
AuthLevel a;
std::tie(pk, sn, a) = detail::extract_metadata(parts.back());
ConnectionID conn = sn ? ConnectionID{std::move(pk)} : conn_index_to_id[conn_index];
LMQ_LOG(debug, "BYE command received; disconnecting from ", conn);
proxy_disconnect(conn, 1s);
if (!incoming) {
LMQ_LOG(debug, "BYE command received; disconnecting from ", peer_address(parts.back()));
proxy_close_connection(conn_index, 0s);
} else {
LMQ_LOG(warn, "Got invalid 'BYE' command on an incoming socket; ignoring");
}
return true;
}
else if (cmd == "FORBIDDEN" || cmd == "NOT_A_SERVICE_NODE") {
return true; // FIXME - ignore these? Log?
else if (is_error_response(cmd)) {
// These messages (FORBIDDEN, UNKNOWNCOMMAND, etc.) are sent in response to us trying to
// invoke something that doesn't exist or we don't have permission to access. These have
// two forms (the latter is only sent by remotes running 1.0.6+).
// - ["XXX", "whatever.command"]
// - ["XXX", "REPLY", replytag]
// (ignoring the routing prefix on incoming commands).
// For the former, we log; for the latter we trigger the reply callback with a failure
if (parts.size() == (2 + incoming) && is_error_response(view(parts[1 + incoming]))) {
// Something like ["UNKNOWNCOMMAND", "FORBIDDEN_SN"] which can happen because the remote
// is running an older version that didn't understand the FORBIDDEN_SN (or whatever)
// error reply that we sent them. We just ignore it because anything else could trigger
// an infinite cycle.
LMQ_LOG(debug, "Received [", cmd, ",", view(parts[1 + incoming]), "]; remote is probably an older lokimq. Ignoring.");
return true;
}
if (parts.size() == (3 + incoming) && view(parts[1 + incoming]) == "REPLY") {
std::string reply_tag{view(parts[2 + incoming])};
auto it = pending_requests.find(reply_tag);
if (it != pending_requests.end()) {
LMQ_LOG(debug, "Received ", cmd, " REPLY for pending command ", to_hex(reply_tag), "; scheduling failure callback");
proxy_schedule_reply_job([callback=std::move(it->second.second), cmd=std::string{cmd}] {
callback(false, {{std::move(cmd)}});
});
pending_requests.erase(it);
} else {
LMQ_LOG(warn, "Received REPLY with unknown or already handled reply tag (", to_hex(reply_tag), "); ignoring");
}
} else {
LMQ_LOG(warn, "Received ", cmd, ':', (parts.size() > 1 + incoming ? view(parts[1 + incoming]) : "(unknown command)"_sv),
" from ", peer_address(parts.back()));
}
return true;
}
return false;
}

View File

@ -201,22 +201,24 @@ void LokiMQ::proxy_to_worker(size_t conn_index, std::vector<zmq::message_t>& par
}
peer = &it->second;
} else {
std::tie(tmp_peer.pubkey, tmp_peer.service_node, tmp_peer.auth_level) = detail::extract_metadata(parts.back());
std::tie(tmp_peer.pubkey, tmp_peer.auth_level) = detail::extract_metadata(parts.back());
tmp_peer.service_node = tmp_peer.pubkey.size() == 32 && active_service_nodes.count(tmp_peer.pubkey);
if (tmp_peer.service_node) {
// It's a service node so we should have a peer_info entry; see if we can find one with
// the same route, and if not, add one.
auto pr = peers.equal_range(tmp_peer.pubkey);
for (auto it = pr.first; it != pr.second; ++it) {
if (it->second.route == tmp_peer.route) {
if (it->second.conn_index == tmp_peer.conn_index && it->second.route == tmp_peer.route) {
peer = &it->second;
// Upgrade permissions in case we have something higher on the socket
peer->service_node |= tmp_peer.service_node;
if (tmp_peer.auth_level > peer->auth_level)
peer->auth_level = tmp_peer.auth_level;
// Update the stored auth level just in case the peer reconnected
peer->auth_level = tmp_peer.auth_level;
break;
}
}
if (!peer) {
// We don't have a record: this is either a new SN connection or a new message on a
// connection that recently gained SN status.
peer = &peers.emplace(ConnectionID{tmp_peer.pubkey}, std::move(tmp_peer))->second;
}
} else {
@ -228,23 +230,6 @@ void LokiMQ::proxy_to_worker(size_t conn_index, std::vector<zmq::message_t>& par
size_t command_part_index = outgoing ? 0 : 1;
std::string command = parts[command_part_index].to_string();
auto cat_call = get_command(command);
if (!cat_call.first) {
LMQ_LOG(warn, "Invalid command '", command, "' sent by remote [", to_hex(peer->pubkey), "]/", peer_address(parts.back()));
try {
if (outgoing)
send_direct_message(connections[conn_index], "UNKNOWNCOMMAND");
else
send_routed_message(connections[conn_index], peer->route, "UNKNOWNCOMMAND");
} catch (const zmq::error_t&) { /* can't send: possibly already disconnected. Ignore. */ }
return;
}
auto& category = *cat_call.first;
if (!proxy_check_auth(conn_index, outgoing, *peer, command, category, parts.back()))
return;
// Steal any data message parts
size_t data_part_index = command_part_index + 1;
@ -253,6 +238,14 @@ void LokiMQ::proxy_to_worker(size_t conn_index, std::vector<zmq::message_t>& par
for (auto it = parts.begin() + data_part_index; it != parts.end(); ++it)
data_parts.push_back(std::move(*it));
auto cat_call = get_command(command);
// Check that command is valid, that we have permission, etc.
if (!proxy_check_auth(conn_index, outgoing, *peer, parts[command_part_index], cat_call, data_parts))
return;
auto& category = *cat_call.first;
if (category.active_threads >= category.reserved_threads && active_workers() >= general_workers) {
// No free reserved or general spots, try to queue it for later
if (category.max_queue >= 0 && category.queued >= category.max_queue) {

View File

@ -6,6 +6,7 @@ set(LMQ_TEST_SRC
test_batch.cpp
test_connect.cpp
test_commands.cpp
test_failures.cpp
test_requests.cpp
test_string_view.cpp
)

View File

@ -6,6 +6,25 @@ using namespace lokimq;
static auto startup = std::chrono::steady_clock::now();
/// Waits up to 100ms for something to happen.
template <typename Func>
inline void wait_for(Func f) {
for (int i = 0; i < 10; i++) {
if (f())
break;
std::this_thread::sleep_for(10ms);
}
}
/// Waits on an atomic bool for up to 100ms for an initial connection, which is more than enough
/// time for an initial connection + request.
inline void wait_for_conn(std::atomic<bool> &c) {
wait_for([&c] { return c.load(); });
}
/// Waits enough time for us to receive a reply from a localhost remote.
inline void reply_sleep() { std::this_thread::sleep_for(10ms); }
// Catch2 macros aren't thread safe, so guard with a mutex
inline std::unique_lock<std::mutex> catch_lock() {
static std::mutex mutex;

View File

@ -1,5 +1,4 @@
#include "common.h"
#include <future>
#include <lokimq/hex.h>
#include <map>
#include <set>
@ -12,10 +11,10 @@ TEST_CASE("basic commands", "[commands]") {
"", "", // generate ephemeral keys
false, // not a service node
[](auto) { return ""; },
get_logger("")
get_logger(""),
LogLevel::trace
};
server.log_level(LogLevel::trace);
server.listen_curve(listen, [](auto /*ip*/, auto /*pk*/) { return Allow{AuthLevel::none, false}; });
server.listen_curve(listen);
std::atomic<int> hellos{0}, his{0};
@ -32,41 +31,33 @@ TEST_CASE("basic commands", "[commands]") {
server.start();
LokiMQ client{
get_logger("")
};
client.log_level(LogLevel::trace);
LokiMQ client{get_logger(""), LogLevel::trace};
client.add_category("public", Access{AuthLevel::none});
client.add_command("public", "hi", [&](auto&) { his++; });
client.start();
std::atomic<bool> connected{false}, failed{false};
std::atomic<bool> got{false};
bool success = false, failed = false;
std::string pubkey;
auto c = client.connect_remote(listen,
[&](auto conn) { pubkey = conn.pubkey(); connected = true; },
[&](auto conn, string_view) { failed = true; },
[&](auto conn) { pubkey = conn.pubkey(); success = true; got = true; },
[&](auto conn, string_view) { failed = true; got = true; },
server.get_pubkey());
int i;
for (i = 0; i < 5; i++) {
if (connected.load())
break;
std::this_thread::sleep_for(50ms);
}
wait_for_conn(got);
{
auto lock = catch_lock();
REQUIRE( connected.load() );
REQUIRE( i <= 1 ); // should be fast
REQUIRE( !failed.load() );
REQUIRE( got );
REQUIRE( success );
REQUIRE_FALSE( failed );
REQUIRE( to_hex(pubkey) == to_hex(server.get_pubkey()) );
}
client.send(c, "public.hello");
client.send(c, "public.client.pubkey");
std::this_thread::sleep_for(50ms);
reply_sleep();
{
auto lock = catch_lock();
REQUIRE( hellos == 1 );
@ -77,7 +68,7 @@ TEST_CASE("basic commands", "[commands]") {
for (int i = 0; i < 50; i++)
client.send(c, "public.hello");
std::this_thread::sleep_for(100ms);
wait_for([&] { return his == 26; });
{
auto lock = catch_lock();
REQUIRE( hellos == 51 );
@ -91,10 +82,10 @@ TEST_CASE("outgoing auth level", "[commands][auth]") {
"", "", // generate ephemeral keys
false, // not a service node
[](auto) { return ""; },
get_logger("")
get_logger(""),
LogLevel::trace
};
server.log_level(LogLevel::trace);
server.listen_curve(listen, [](auto /*ip*/, auto /*pk*/) { return Allow{AuthLevel::none, false}; });
server.listen_curve(listen);
std::atomic<int> hellos{0};
@ -103,10 +94,7 @@ TEST_CASE("outgoing auth level", "[commands][auth]") {
server.start();
LokiMQ client{
get_logger("")
};
client.log_level(LogLevel::trace);
LokiMQ client{get_logger(""), LogLevel::trace};
std::atomic<int> public_hi{0}, basic_hi{0}, admin_hi{0};
client.add_category("public", Access{AuthLevel::none});
@ -124,8 +112,7 @@ TEST_CASE("outgoing auth level", "[commands][auth]") {
auto admin_c = client.connect_remote(listen, [](...) {}, [](...) {}, server.get_pubkey(), AuthLevel::admin);
client.send(public_c, "public.reflect", "public.hi");
std::this_thread::sleep_for(50ms);
wait_for([&] { return public_hi == 1; });
{
auto lock = catch_lock();
REQUIRE( public_hi == 1 );
@ -138,7 +125,7 @@ TEST_CASE("outgoing auth level", "[commands][auth]") {
client.send(admin_c, "public.reflect", "admin.hi");
client.send(basic_c, "public.reflect", "basic.hi");
std::this_thread::sleep_for(50ms);
wait_for([&] { return public_hi == 2; });
{
auto lock = catch_lock();
REQUIRE( admin_hi == 3 );
@ -160,7 +147,7 @@ TEST_CASE("outgoing auth level", "[commands][auth]") {
client.send(admin_c, "public.reflect", "basic.hi");
client.send(admin_c, "public.reflect", "public.hi");
std::this_thread::sleep_for(50ms);
wait_for([&] { return public_hi == 3; });
auto lock = catch_lock();
REQUIRE( admin_hi == 1 );
REQUIRE( basic_hi == 2 );
@ -176,10 +163,10 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]")
"", "", // generate ephemeral keys
false, // not a service node
[](auto) { return ""; },
get_logger("")
get_logger(""),
LogLevel::trace
};
server.log_level(LogLevel::trace);
server.listen_curve(listen, [](auto /*ip*/, auto /*pk*/) { return Allow{AuthLevel::none, false}; });
server.listen_curve(listen);
std::vector<std::pair<ConnectionID, std::string>> subscribers;
ConnectionID backdoor;
@ -221,8 +208,7 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]")
auto nsa_c = nsa.connect_remote(listen, connect_success, connect_failure, server.get_pubkey(), AuthLevel::admin);
nsa.send(nsa_c, "hey google.install backdoor");
std::this_thread::sleep_for(50ms);
wait_for([&] { auto lock = catch_lock(); return (bool) backdoor; });
{
auto l = catch_lock();
REQUIRE( backdoor );
@ -243,9 +229,10 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]")
std::map<int, std::set<std::string>> google_knows;
int things_remembered{0};
for (int i = 0; i < 5; i++) {
clients.push_back(std::make_unique<LokiMQ>(get_logger("C" + std::to_string(i) + "» ")));
clients.push_back(std::make_unique<LokiMQ>(
get_logger("C" + std::to_string(i) + "» "), LogLevel::trace
));
auto& c = clients.back();
c->log_level(LogLevel::trace);
c->add_category("personal", Access{AuthLevel::basic});
c->add_command("personal", "detail", [&,i](Message& m) {
auto l = catch_lock();
@ -265,7 +252,7 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]")
},
personal_detail);
}
std::this_thread::sleep_for(50ms);
wait_for([&] { auto lock = catch_lock(); return things_remembered == all_the_things.size(); });
{
auto l = catch_lock();
REQUIRE( things_remembered == all_the_things.size() );
@ -273,7 +260,7 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]")
}
clients[0]->send(conns[0], "hey google.recall");
std::this_thread::sleep_for(50ms);
reply_sleep();
{
auto l = catch_lock();
REQUIRE( google_knows == personal_details );
@ -286,10 +273,10 @@ TEST_CASE("send failure callbacks", "[commands][queue_full]") {
"", "", // generate ephemeral keys
false, // not a service node
[](auto) { return ""; },
get_logger("")
get_logger(""),
LogLevel::debug // This test traces so much that it takes 2.5-3s of CPU time at trace level, so don't do that.
};
server.log_level(LogLevel::debug);
server.listen_plain(listen, [](auto /*ip*/, auto /*pk*/) { return Allow{AuthLevel::none, false}; });
server.listen_plain(listen);
std::atomic<int> send_attempts{0};
std::atomic<int> send_failures{0};
@ -331,7 +318,7 @@ TEST_CASE("send failure callbacks", "[commands][queue_full]") {
for (i = 0; i < 20; i++) {
if (send_attempts.load() >= 500)
break;
std::this_thread::sleep_for(25ms);
std::this_thread::sleep_for(10ms);
}
{
auto lock = catch_lock();
@ -354,20 +341,20 @@ TEST_CASE("send failure callbacks", "[commands][queue_full]") {
client.send(zmq::message_t{"x.x", 3}, zmq::send_flags::none);
expected_attempts += 500;
if (i >= 4) {
std::this_thread::sleep_for(25ms);
if (send_failures.load() > 0)
break;
std::this_thread::sleep_for(25ms);
}
}
for (i = 0; i < 10; i++) {
for (i = 0; i < 20; i++) {
if (send_attempts.load() >= expected_attempts)
break;
std::this_thread::sleep_for(25ms);
std::this_thread::sleep_for(10ms);
}
{
auto lock = catch_lock();
REQUIRE( i <= 8 );
REQUIRE( i < 20 );
REQUIRE( send_attempts.load() == expected_attempts );
REQUIRE( send_failures.load() > 0 );
}

View File

@ -1,5 +1,4 @@
#include "common.h"
#include <future>
#include <lokimq/hex.h>
extern "C" {
#include <sodium.h>
@ -12,46 +11,42 @@ TEST_CASE("connections with curve authentication", "[curve][connect]") {
"", "", // generate ephemeral keys
false, // not a service node
[](auto) { return ""; },
get_logger("")
get_logger(""),
LogLevel::trace
};
server.log_level(LogLevel::trace);
server.listen_curve(listen, [](auto /*ip*/, auto /*pk*/) { return Allow{AuthLevel::none, false}; });
server.listen_curve(listen);
server.add_category("public", Access{AuthLevel::none});
server.add_request_command("public", "hello", [&](Message& m) { m.send_reply("hi"); });
server.start();
LokiMQ client{get_logger("")};
client.log_level(LogLevel::trace);
LokiMQ client{get_logger(""), LogLevel::trace};
client.start();
auto pubkey = server.get_pubkey();
std::atomic<int> connected{0};
std::atomic<bool> got{false};
bool success = false;
auto server_conn = client.connect_remote(listen,
[&](auto conn) { connected = 1; },
[&](auto conn, string_view reason) { auto lock = catch_lock(); INFO("connection failed: " << reason); },
[&](auto conn) { success = true; got = true; },
[&](auto conn, string_view reason) { auto lock = catch_lock(); INFO("connection failed: " << reason); got = true; },
pubkey);
int i;
for (i = 0; i < 5; i++) {
if (connected.load())
break;
std::this_thread::sleep_for(50ms);
}
wait_for_conn(got);
{
auto lock = catch_lock();
REQUIRE( i <= 1 );
REQUIRE( connected.load() );
REQUIRE( got );
REQUIRE( success );
}
bool success = false;
success = false;
std::vector<std::string> parts;
client.request(server_conn, "public.hello", [&](auto success_, auto parts_) { success = success_; parts = parts_; });
std::this_thread::sleep_for(50ms);
auto lock = catch_lock();
REQUIRE( success );
reply_sleep();
{
auto lock = catch_lock();
REQUIRE( success );
}
}
TEST_CASE("self-connection SN optimization", "[connect][self]") {
@ -63,10 +58,16 @@ TEST_CASE("self-connection SN optimization", "[connect][self]") {
pubkey, privkey,
true,
[&](auto pk) { if (pk == pubkey) return "tcp://127.0.0.1:5544"; else return ""; },
get_logger("")
get_logger(""),
LogLevel::trace
};
sn.listen_curve("tcp://127.0.0.1:5544", [&](auto ip, auto pk) { REQUIRE(ip == "127.0.0.1"); return Allow{AuthLevel::none, pk == pubkey}; });
sn.listen_curve("tcp://127.0.0.1:5544", [&](auto ip, auto pk, auto sn) {
auto lock = catch_lock();
REQUIRE(ip == "127.0.0.1");
REQUIRE(sn == (pk == pubkey));
return AuthLevel::none;
});
sn.add_category("a", Access{AuthLevel::none});
bool invoked = false;
sn.add_command("a", "b", [&](const Message& m) {
@ -77,57 +78,54 @@ TEST_CASE("self-connection SN optimization", "[connect][self]") {
REQUIRE(!m.data.empty());
REQUIRE(m.data[0] == "my data");
});
sn.log_level(LogLevel::trace);
sn.set_active_sns({{pubkey}});
sn.start();
std::this_thread::sleep_for(50ms);
sn.send(pubkey, "a.b", "my data");
std::this_thread::sleep_for(50ms);
auto lock = catch_lock();
REQUIRE(invoked);
reply_sleep();
{
auto lock = catch_lock();
REQUIRE(invoked);
}
}
TEST_CASE("plain-text connections", "[plaintext][connect]") {
std::string listen = "tcp://127.0.0.1:4455";
LokiMQ server{get_logger("")};
server.log_level(LogLevel::trace);
LokiMQ server{get_logger(""), LogLevel::trace};
server.add_category("public", Access{AuthLevel::none});
server.add_request_command("public", "hello", [&](Message& m) { m.send_reply("hi"); });
server.listen_plain(listen, [](auto /*ip*/, auto /*pk*/) { return Allow{AuthLevel::none, false}; });
server.listen_plain(listen);
server.start();
LokiMQ client{get_logger("")};
client.log_level(LogLevel::trace);
LokiMQ client{get_logger(""), LogLevel::trace};
client.start();
std::atomic<int> connected{0};
std::atomic<bool> got{false};
bool success = false;
auto c = client.connect_remote(listen,
[&](auto conn) { connected = 1; },
[&](auto conn, string_view reason) { auto lock = catch_lock(); INFO("connection failed: " << reason); }
[&](auto conn) { success = true; got = true; },
[&](auto conn, string_view reason) { auto lock = catch_lock(); INFO("connection failed: " << reason); got = true; }
);
int i;
for (i = 0; i < 5; i++) {
if (connected.load())
break;
std::this_thread::sleep_for(50ms);
}
wait_for_conn(got);
{
auto lock = catch_lock();
REQUIRE( i <= 1 );
REQUIRE( connected.load() );
REQUIRE( got );
REQUIRE( success );
}
bool success = false;
success = false;
std::vector<std::string> parts;
client.request(c, "public.hello", [&](auto success_, auto parts_) { success = success_; parts = parts_; });
std::this_thread::sleep_for(50ms);
auto lock = catch_lock();
REQUIRE( success );
reply_sleep();
{
auto lock = catch_lock();
REQUIRE( success );
}
}
TEST_CASE("SN disconnections", "[connect][disconnect]") {
@ -147,28 +145,146 @@ TEST_CASE("SN disconnections", "[connect][disconnect]") {
lmq.push_back(std::make_unique<LokiMQ>(
pubkey[i], privkey[i], true,
[conn](auto pk) { auto it = conn.find((std::string) pk); if (it != conn.end()) return it->second; return ""s; },
get_logger("S" + std::to_string(i) + "» ")
get_logger("S" + std::to_string(i) + "» "),
LogLevel::trace
));
auto& server = *lmq.back();
server.log_level(LogLevel::debug);
server.listen_curve(conn[pubkey[i]], [](auto /*ip*/, auto /*pk*/) { return Allow{AuthLevel::none, true}; });
server.listen_curve(conn[pubkey[i]]);
server.add_category("sn", Access{AuthLevel::none, true})
.add_command("hi", [&](Message& m) { his++; });
server.set_active_sns({pubkey.begin(), pubkey.end()});
server.start();
}
std::this_thread::sleep_for(50ms);
lmq[0]->send(pubkey[1], "sn.hi");
lmq[0]->send(pubkey[2], "sn.hi");
std::this_thread::sleep_for(50ms);
lmq[2]->send(pubkey[0], "sn.hi");
lmq[2]->send(pubkey[1], "sn.hi");
lmq[1]->send(pubkey[0], "BYE");
std::this_thread::sleep_for(50ms);
lmq[0]->send(pubkey[2], "sn.hi");
std::this_thread::sleep_for(50ms);
auto lock = catch_lock();
REQUIRE(his == 5);
}
TEST_CASE("SN auth checks", "[sandwich][auth]") {
// When a remote connects, we check its authentication level; if at the time of connection it
// isn't recognized as a SN but tries to invoke a SN command it'll be told to disconnect; if it
// tries to send again it should reconnect and reauthenticate. This test is meant to test this
// pattern where the reconnection/reauthentication now authenticates it as a SN.
std::string listen = "tcp://127.0.0.1:4455";
std::string pubkey, privkey;
pubkey.resize(crypto_box_PUBLICKEYBYTES);
privkey.resize(crypto_box_SECRETKEYBYTES);
crypto_box_keypair(reinterpret_cast<unsigned char*>(&pubkey[0]), reinterpret_cast<unsigned char*>(&privkey[0]));
LokiMQ server{
pubkey, privkey,
true, // service node
[](auto) { return ""; },
get_logger(""),
LogLevel::trace
};
std::atomic<bool> incoming_is_sn{false};
server.listen_curve(listen);
server.add_category("public", Access{AuthLevel::none})
.add_request_command("hello", [&](Message& m) { m.send_reply("hi"); })
.add_request_command("sudo", [&](Message& m) {
server.update_active_sns({{m.conn.pubkey()}}, {});
m.send_reply("making sandwiches");
})
.add_request_command("nosudo", [&](Message& m) {
// Send the reply *first* because if we do it the other way we'll have just removed
// ourselves from the list of SNs and thus would try to open an outbound connection
// to deliver it since it's still queued as a message to a SN.
m.send_reply("make them yourself");
server.update_active_sns({}, {{m.conn.pubkey()}});
});
server.add_category("sandwich", Access{AuthLevel::none, true})
.add_request_command("make", [&](Message& m) { m.send_reply("okay"); });
server.start();
LokiMQ client{
"", "", false,
[&](auto remote_pk) { if (remote_pk == pubkey) return listen; return ""s; },
get_logger(""), LogLevel::trace};
client.start();
std::atomic<bool> got{false};
bool success;
client.request(pubkey, "public.hello", [&](auto success_, auto) { success = success_; got = true; });
wait_for_conn(got);
{
auto lock = catch_lock();
REQUIRE( got );
REQUIRE( success );
}
got = false;
using dvec = std::vector<std::string>;
dvec data;
client.request(pubkey, "sandwich.make", [&](auto success_, auto data_) {
success = success_;
data = std::move(data_);
got = true;
});
reply_sleep();
{
auto lock = catch_lock();
REQUIRE( got );
REQUIRE_FALSE( success );
REQUIRE( data == dvec{{"FORBIDDEN_SN"}} );
}
// Somebody set up us the bomb. Main sudo turn on.
got = false;
client.request(pubkey, "public.sudo", [&](auto success_, auto data_) { success = success_; data = data_; got = true; });
reply_sleep();
{
auto lock = catch_lock();
REQUIRE( got );
REQUIRE( success );
REQUIRE( data == dvec{{"making sandwiches"}} );
}
got = false;
client.request(pubkey, "sandwich.make", [&](auto success_, auto data_) {
success = success_;
data = std::move(data_);
got = true;
});
reply_sleep();
{
auto lock = catch_lock();
REQUIRE( got );
REQUIRE( success );
REQUIRE( data == dvec{{"okay"}} );
}
// Take off every 'SUDO', You [not] know what you doing
got = false;
client.request(pubkey, "public.nosudo", [&](auto success_, auto data_) { success = success_; data = data_; got = true; });
reply_sleep();
{
auto lock = catch_lock();
REQUIRE( got );
REQUIRE( success );
REQUIRE( data == dvec{{"make them yourself"}} );
}
got = false;
client.request(pubkey, "sandwich.make", [&](auto success_, auto data_) {
success = success_;
data = std::move(data_);
got = true;
});
reply_sleep();
{
auto lock = catch_lock();
REQUIRE( got );
REQUIRE_FALSE( success );
REQUIRE( data == dvec{{"FORBIDDEN_SN"}} );
}
}

309
tests/test_failures.cpp Normal file
View File

@ -0,0 +1,309 @@
#include "common.h"
#include <lokimq/hex.h>
#include <map>
#include <set>
using namespace lokimq;
TEST_CASE("failure responses - UNKNOWNCOMMAND", "[failure][UNKNOWNCOMMAND]") {
std::string listen = "tcp://127.0.0.1:4567";
LokiMQ server{
"", "", // generate ephemeral keys
false, // not a service node
[](auto) { return ""; },
get_logger(""),
LogLevel::trace
};
server.listen_plain(listen);
server.start();
// Use a raw socket here because I want to see the raw commands coming on the wire
zmq::context_t client_ctx;
zmq::socket_t client{client_ctx, zmq::socket_type::dealer};
client.connect(listen);
// Handshake: we send HI, they reply HELLO.
client.send(zmq::message_t{"HI", 2}, zmq::send_flags::none);
{
zmq::message_t hello;
client.recv(hello);
auto lock = catch_lock();
REQUIRE( hello.to_string() == "HELLO" );
REQUIRE_FALSE( hello.more() );
}
client.send(zmq::message_t{"a.a", 3}, zmq::send_flags::none);
zmq::message_t resp;
client.recv(resp);
auto lock = catch_lock();
REQUIRE( resp.to_string() == "UNKNOWNCOMMAND" );
REQUIRE( resp.more() );
client.recv(resp);
REQUIRE( resp.to_string() == "a.a" );
REQUIRE_FALSE( resp.more() );
}
TEST_CASE("failure responses - NO_REPLY_TAG", "[failure][NO_REPLY_TAG]") {
std::string listen = "tcp://127.0.0.1:4567";
LokiMQ server{
"", "", // generate ephemeral keys
false, // not a service node
[](auto) { return ""; },
get_logger(""),
LogLevel::trace
};
server.listen_plain(listen);
server.add_category("x", AuthLevel::none)
.add_request_command("r", [] (auto& m) { m.send_reply("a"); });
server.start();
// Use a raw socket here because I want to see the raw commands coming on the wire
zmq::context_t client_ctx;
zmq::socket_t client{client_ctx, zmq::socket_type::dealer};
client.connect(listen);
// Handshake: we send HI, they reply HELLO.
client.send(zmq::message_t{"HI", 2}, zmq::send_flags::none);
{
zmq::message_t hello;
client.recv(hello);
auto lock = catch_lock();
REQUIRE( hello.to_string() == "HELLO" );
REQUIRE_FALSE( hello.more() );
}
client.send(zmq::message_t{"x.r", 3}, zmq::send_flags::none);
zmq::message_t resp;
client.recv(resp);
{
auto lock = catch_lock();
REQUIRE( resp.to_string() == "NO_REPLY_TAG" );
REQUIRE( resp.more() );
client.recv(resp);
REQUIRE( resp.to_string() == "x.r" );
REQUIRE_FALSE( resp.more() );
}
client.send(zmq::message_t{"x.r", 3}, zmq::send_flags::sndmore);
client.send(zmq::message_t{"foo", 3}, zmq::send_flags::none);
client.recv(resp);
{
auto lock = catch_lock();
REQUIRE( resp.to_string() == "REPLY" );
REQUIRE( resp.more() );
client.recv(resp);
REQUIRE( resp.to_string() == "foo" );
REQUIRE( resp.more() );
client.recv(resp);
REQUIRE( resp.to_string() == "a" );
REQUIRE_FALSE( resp.more() );
}
}
TEST_CASE("failure responses - FORBIDDEN", "[failure][FORBIDDEN]") {
std::string listen = "tcp://127.0.0.1:4567";
LokiMQ server{
"", "", // generate ephemeral keys
false, // not a service node
[](auto) { return ""; },
get_logger(""),
LogLevel::trace
};
server.listen_plain(listen, [](auto, auto, auto) {
static int count = 0;
++count;
return count == 1 ? AuthLevel::none : count == 2 ? AuthLevel::basic : AuthLevel::admin;
});
server.add_category("x", AuthLevel::basic)
.add_command("x", [] (auto& m) { m.send_back("a"); });
server.add_category("y", AuthLevel::admin)
.add_command("x", [] (auto& m) { m.send_back("b"); });
server.start();
zmq::context_t client_ctx;
std::array<zmq::socket_t, 3> clients;
// Client 0 should get none auth level, client 1 should get basic, client 2 should get admin
for (auto& client : clients) {
client = {client_ctx, zmq::socket_type::dealer};
client.connect(listen);
// Handshake: we send HI, they reply HELLO.
client.send(zmq::message_t{"HI", 2}, zmq::send_flags::none);
{
zmq::message_t hello;
client.recv(hello);
auto lock = catch_lock();
REQUIRE( hello.to_string() == "HELLO" );
REQUIRE_FALSE( hello.more() );
}
}
for (auto& c : clients)
c.send(zmq::message_t{"x.x", 3}, zmq::send_flags::none);
zmq::message_t resp;
clients[0].recv(resp);
{
auto lock = catch_lock();
REQUIRE( resp.to_string() == "FORBIDDEN" );
REQUIRE( resp.more() );
clients[0].recv(resp);
REQUIRE( resp.to_string() == "x.x" );
REQUIRE_FALSE( resp.more() );
}
for (int i : {1, 2}) {
clients[i].recv(resp);
auto lock = catch_lock();
REQUIRE( resp.to_string() == "a" );
REQUIRE_FALSE( resp.more() );
}
for (auto& c : clients)
c.send(zmq::message_t{"y.x", 3}, zmq::send_flags::none);
for (int i : {0, 1}) {
clients[i].recv(resp);
auto lock = catch_lock();
REQUIRE( resp.to_string() == "FORBIDDEN" );
REQUIRE( resp.more() );
clients[i].recv(resp);
REQUIRE( resp.to_string() == "y.x" );
REQUIRE_FALSE( resp.more() );
}
clients[2].recv(resp);
{
auto lock = catch_lock();
REQUIRE( resp.to_string() == "b" );
REQUIRE_FALSE( resp.more() );
}
}
TEST_CASE("failure responses - NOT_A_SERVICE_NODE", "[failure][NOT_A_SERVICE_NODE]") {
std::string listen = "tcp://127.0.0.1:4567";
LokiMQ server{
"", "", // generate ephemeral keys
false, // not a service node
[](auto) { return ""; },
get_logger(""),
LogLevel::trace
};
server.listen_plain(listen, [](auto, auto, auto) {
static int count = 0;
++count;
return count == 1 ? AuthLevel::none : count == 2 ? AuthLevel::basic : AuthLevel::admin;
});
server.add_category("x", Access{AuthLevel::none, false, true})
.add_command("x", [] (auto&) {})
.add_request_command("r", [] (auto& m) { m.send_reply(); })
;
server.start();
zmq::context_t client_ctx;
zmq::socket_t client{client_ctx, zmq::socket_type::dealer};
client.connect(listen);
// Handshake: we send HI, they reply HELLO.
client.send(zmq::message_t{"HI", 2}, zmq::send_flags::none);
{
zmq::message_t hello;
client.recv(hello);
auto lock = catch_lock();
REQUIRE( hello.to_string() == "HELLO" );
REQUIRE_FALSE( hello.more() );
}
client.send(zmq::message_t{"x.x", 3}, zmq::send_flags::none);
zmq::message_t resp;
client.recv(resp);
{
auto lock = catch_lock();
REQUIRE( resp.to_string() == "NOT_A_SERVICE_NODE" );
REQUIRE( resp.more() );
client.recv(resp);
REQUIRE( resp.to_string() == "x.x" );
REQUIRE_FALSE( resp.more() );
}
client.send(zmq::message_t{"x.r", 3}, zmq::send_flags::sndmore);
client.send(zmq::message_t{"xyz123", 6}, zmq::send_flags::none); // reply tag
client.recv(resp);
{
auto lock = catch_lock();
REQUIRE( resp.to_string() == "NOT_A_SERVICE_NODE" );
REQUIRE( resp.more() );
client.recv(resp);
REQUIRE( resp.to_string() == "REPLY" );
REQUIRE( resp.more() );
client.recv(resp);
REQUIRE( resp.to_string() == "xyz123" );
REQUIRE_FALSE( resp.more() );
}
}
TEST_CASE("failure responses - FORBIDDEN_SN", "[failure][FORBIDDEN_SN]") {
std::string listen = "tcp://127.0.0.1:4567";
LokiMQ server{
"", "", // generate ephemeral keys
false, // not a service node
[](auto) { return ""; },
get_logger(""),
LogLevel::trace
};
server.listen_plain(listen, [](auto, auto, auto) {
static int count = 0;
++count;
return count == 1 ? AuthLevel::none : count == 2 ? AuthLevel::basic : AuthLevel::admin;
});
server.add_category("x", Access{AuthLevel::none, true, false})
.add_command("x", [] (auto&) {})
.add_request_command("r", [] (auto& m) { m.send_reply(); })
;
server.start();
zmq::context_t client_ctx;
zmq::socket_t client{client_ctx, zmq::socket_type::dealer};
client.connect(listen);
// Handshake: we send HI, they reply HELLO.
client.send(zmq::message_t{"HI", 2}, zmq::send_flags::none);
{
zmq::message_t hello;
client.recv(hello);
auto lock = catch_lock();
REQUIRE( hello.to_string() == "HELLO" );
REQUIRE_FALSE( hello.more() );
}
client.send(zmq::message_t{"x.x", 3}, zmq::send_flags::none);
zmq::message_t resp;
client.recv(resp);
{
auto lock = catch_lock();
REQUIRE( resp.to_string() == "FORBIDDEN_SN" );
REQUIRE( resp.more() );
client.recv(resp);
REQUIRE( resp.to_string() == "x.x" );
REQUIRE_FALSE( resp.more() );
}
client.send(zmq::message_t{"x.r", 3}, zmq::send_flags::sndmore);
client.send(zmq::message_t{"xyz123", 6}, zmq::send_flags::none); // reply tag
client.recv(resp);
{
auto lock = catch_lock();
REQUIRE( resp.to_string() == "FORBIDDEN_SN" );
REQUIRE( resp.more() );
client.recv(resp);
REQUIRE( resp.to_string() == "REPLY" );
REQUIRE( resp.more() );
client.recv(resp);
REQUIRE( resp.to_string() == "xyz123" );
REQUIRE_FALSE( resp.more() );
}
}

View File

@ -1,5 +1,4 @@
#include "common.h"
#include <future>
#include <lokimq/hex.h>
using namespace lokimq;
@ -36,17 +35,11 @@ TEST_CASE("basic requests", "[requests]") {
[&](auto, auto) { failed = true; },
server.get_pubkey());
int i;
for (i = 0; i < 5; i++) {
if (connected.load())
break;
std::this_thread::sleep_for(50ms);
}
wait_for([&] { return connected || failed; });
{
auto lock = catch_lock();
REQUIRE( connected.load() );
REQUIRE( !failed.load() );
REQUIRE( i <= 1 );
REQUIRE( connected );
REQUIRE_FALSE( failed );
REQUIRE( to_hex(pubkey) == to_hex(server.get_pubkey()) );
}
@ -59,11 +52,13 @@ TEST_CASE("basic requests", "[requests]") {
data = std::move(data_);
});
std::this_thread::sleep_for(50ms);
auto lock = catch_lock();
REQUIRE( got_reply.load() );
REQUIRE( success );
REQUIRE( data == std::vector<std::string>{{"123"}} );
reply_sleep();
{
auto lock = catch_lock();
REQUIRE( got_reply.load() );
REQUIRE( success );
REQUIRE( data == std::vector<std::string>{{"123"}} );
}
}
TEST_CASE("request from server to client", "[requests]") {
@ -161,15 +156,10 @@ TEST_CASE("request timeouts", "[requests][timeout]") {
[&](auto, auto) { failed = true; },
server.get_pubkey());
int i;
for (i = 0; i < 5; i++) {
if (connected.load())
break;
std::this_thread::sleep_for(50ms);
}
REQUIRE( connected.load() );
REQUIRE( !failed.load() );
REQUIRE( i <= 1 );
wait_for([&] { return connected || failed; });
REQUIRE( connected );
REQUIRE_FALSE( failed );
REQUIRE( to_hex(pubkey) == to_hex(server.get_pubkey()) );
std::atomic<bool> got_triggered{false};
@ -180,7 +170,7 @@ TEST_CASE("request timeouts", "[requests][timeout]") {
success = ok;
data = std::move(data_);
},
lokimq::send_option::request_timeout{30ms}
lokimq::send_option::request_timeout{20ms}
);
std::atomic<bool> got_triggered2{false};
@ -192,10 +182,10 @@ TEST_CASE("request timeouts", "[requests][timeout]") {
lokimq::send_option::request_timeout{100ms}
);
std::this_thread::sleep_for(50ms);
REQUIRE( got_triggered.load() );
std::this_thread::sleep_for(30ms);
REQUIRE( got_triggered );
REQUIRE_FALSE( got_triggered2 );
REQUIRE_FALSE( success );
REQUIRE( data.size() == 0 );
REQUIRE( data == std::vector<std::string>{{"TIMEOUT"}} );
REQUIRE_FALSE( got_triggered2.load() );
}