mirror of
https://github.com/oxen-io/oxen-mq.git
synced 2023-12-13 21:00:31 +01:00
Merge remote-tracking branch 'origin/master' into ubuntu/bionic
This commit is contained in:
commit
1662279a20
|
@ -5,8 +5,8 @@ project(liblokimq CXX)
|
|||
include(GNUInstallDirs)
|
||||
|
||||
set(LOKIMQ_VERSION_MAJOR 1)
|
||||
set(LOKIMQ_VERSION_MINOR 0)
|
||||
set(LOKIMQ_VERSION_PATCH 5)
|
||||
set(LOKIMQ_VERSION_MINOR 1)
|
||||
set(LOKIMQ_VERSION_PATCH 0)
|
||||
set(LOKIMQ_VERSION "${LOKIMQ_VERSION_MAJOR}.${LOKIMQ_VERSION_MINOR}.${LOKIMQ_VERSION_PATCH}")
|
||||
message(STATUS "lokimq v${LOKIMQ_VERSION}")
|
||||
|
||||
|
|
190
lokimq/auth.cpp
190
lokimq/auth.cpp
|
@ -30,45 +30,166 @@ std::string zmtp_metadata(string_view key, string_view value) {
|
|||
|
||||
|
||||
bool LokiMQ::proxy_check_auth(size_t conn_index, bool outgoing, const peer_info& peer,
|
||||
const std::string& command, const category& cat, zmq::message_t& msg) {
|
||||
zmq::message_t& cmd, const cat_call_t& cat_call, std::vector<zmq::message_t>& data) {
|
||||
auto command = view(cmd);
|
||||
std::string reply;
|
||||
if (peer.auth_level < cat.access.auth) {
|
||||
LMQ_LOG(warn, "Access denied to ", command, " for peer [", to_hex(peer.pubkey), "]/", peer_address(msg),
|
||||
": peer auth level ", peer.auth_level, " < ", cat.access.auth);
|
||||
|
||||
if (!cat_call.first) {
|
||||
LMQ_LOG(warn, "Invalid command '", command, "' sent by remote [", to_hex(peer.pubkey), "]/", peer_address(cmd));
|
||||
reply = "UNKNOWNCOMMAND";
|
||||
} else if (peer.auth_level < cat_call.first->access.auth) {
|
||||
LMQ_LOG(warn, "Access denied to ", command, " for peer [", to_hex(peer.pubkey), "]/", peer_address(cmd),
|
||||
": peer auth level ", peer.auth_level, " < ", cat_call.first->access.auth);
|
||||
reply = "FORBIDDEN";
|
||||
}
|
||||
else if (cat.access.local_sn && !local_service_node) {
|
||||
LMQ_LOG(warn, "Access denied to ", command, " for peer [", to_hex(peer.pubkey), "]/", peer_address(msg),
|
||||
} else if (cat_call.first->access.local_sn && !local_service_node) {
|
||||
LMQ_LOG(warn, "Access denied to ", command, " for peer [", to_hex(peer.pubkey), "]/", peer_address(cmd),
|
||||
": that command is only available when this LokiMQ is running in service node mode");
|
||||
reply = "NOT_A_SERVICE_NODE";
|
||||
}
|
||||
else if (cat.access.remote_sn && !peer.service_node) {
|
||||
LMQ_LOG(warn, "Access denied to ", command, " for peer [", to_hex(peer.pubkey), "]/", peer_address(msg),
|
||||
} else if (cat_call.first->access.remote_sn && !peer.service_node) {
|
||||
LMQ_LOG(warn, "Access denied to ", command, " for peer [", to_hex(peer.pubkey), "]/", peer_address(cmd),
|
||||
": remote is not recognized as a service node");
|
||||
// Disconnect: we don't think the remote is a SN, but it issued a command only SNs should be
|
||||
// issuing. Drop the connection; if the remote has something important to relay it will
|
||||
// reconnect, at which point we will reassess the SN status on the new incoming connection.
|
||||
if (outgoing) {
|
||||
proxy_disconnect(peer.service_node ? ConnectionID{peer.pubkey} : conn_index_to_id[conn_index], 1s);
|
||||
return false;
|
||||
}
|
||||
else
|
||||
reply = "BYE";
|
||||
reply = "FORBIDDEN_SN";
|
||||
} else if (cat_call.second->second /*is_request*/ && data.empty()) {
|
||||
LMQ_LOG(warn, "Received an invalid request for '", command, "' with no reply tag from remote [",
|
||||
to_hex(peer.pubkey), "]/", peer_address(cmd));
|
||||
reply = "NO_REPLY_TAG";
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (reply.empty())
|
||||
return true;
|
||||
std::vector<zmq::message_t> msgs;
|
||||
msgs.reserve(4);
|
||||
if (!outgoing)
|
||||
msgs.push_back(create_message(peer.route));
|
||||
msgs.push_back(create_message(reply));
|
||||
if (cat_call.second && cat_call.second->second /*request command*/ && !data.empty()) {
|
||||
msgs.push_back(create_message("REPLY"_sv));
|
||||
msgs.push_back(create_message(view(data.front()))); // reply tag
|
||||
} else {
|
||||
msgs.push_back(create_message(view(cmd)));
|
||||
}
|
||||
|
||||
try {
|
||||
if (outgoing)
|
||||
send_direct_message(connections[conn_index], std::move(reply), command);
|
||||
else
|
||||
send_routed_message(connections[conn_index], peer.route, std::move(reply), command);
|
||||
} catch (const zmq::error_t&) { /* can't send: possibly already disconnected. Ignore. */ }
|
||||
send_message_parts(connections[conn_index], msgs);
|
||||
} catch (const zmq::error_t& err) {
|
||||
/* can't send: possibly already disconnected. Ignore. */
|
||||
LMQ_LOG(debug, "Couldn't send auth failure message ", reply, " to peer [", to_hex(peer.pubkey), "]/", peer_address(cmd), ": ", err.what());
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
void LokiMQ::set_active_sns(pubkey_set pubkeys) {
|
||||
if (proxy_thread.joinable()) {
|
||||
auto data = bt_serialize(detail::serialize_object(std::move(pubkeys)));
|
||||
detail::send_control(get_control_socket(), "SET_SNS", data);
|
||||
} else {
|
||||
proxy_set_active_sns(std::move(pubkeys));
|
||||
}
|
||||
}
|
||||
void LokiMQ::proxy_set_active_sns(string_view data) {
|
||||
proxy_set_active_sns(detail::deserialize_object<pubkey_set>(bt_deserialize<uintptr_t>(data)));
|
||||
}
|
||||
void LokiMQ::proxy_set_active_sns(pubkey_set pubkeys) {
|
||||
pubkey_set added, removed;
|
||||
for (auto it = pubkeys.begin(); it != pubkeys.end(); ) {
|
||||
auto& pk = *it;
|
||||
if (pk.size() != 32) {
|
||||
LMQ_LOG(warn, "Invalid private key of length ", pk.size(), " (", to_hex(pk), ") passed to set_active_sns");
|
||||
it = pubkeys.erase(it);
|
||||
continue;
|
||||
}
|
||||
if (!active_service_nodes.count(pk))
|
||||
added.insert(std::move(pk));
|
||||
++it;
|
||||
}
|
||||
if (added.empty() && active_service_nodes.size() == pubkeys.size()) {
|
||||
LMQ_LOG(debug, "set_active_sns(): new set of SNs is unchanged, skipping update");
|
||||
return;
|
||||
}
|
||||
for (const auto& pk : active_service_nodes) {
|
||||
if (!pubkeys.count(pk))
|
||||
removed.insert(pk);
|
||||
if (active_service_nodes.size() + added.size() - removed.size() == pubkeys.size())
|
||||
break;
|
||||
}
|
||||
proxy_update_active_sns_clean(std::move(added), std::move(removed));
|
||||
}
|
||||
|
||||
void LokiMQ::update_active_sns(pubkey_set added, pubkey_set removed) {
|
||||
LMQ_LOG(info, "uh, ", added.size());
|
||||
if (proxy_thread.joinable()) {
|
||||
std::array<uintptr_t, 2> data;
|
||||
data[0] = detail::serialize_object(std::move(added));
|
||||
data[1] = detail::serialize_object(std::move(removed));
|
||||
detail::send_control(get_control_socket(), "UPDATE_SNS", bt_serialize(data));
|
||||
} else {
|
||||
proxy_update_active_sns(std::move(added), std::move(removed));
|
||||
}
|
||||
}
|
||||
void LokiMQ::proxy_update_active_sns(bt_list_consumer data) {
|
||||
auto added = detail::deserialize_object<pubkey_set>(data.consume_integer<uintptr_t>());
|
||||
auto remed = detail::deserialize_object<pubkey_set>(data.consume_integer<uintptr_t>());
|
||||
proxy_update_active_sns(std::move(added), std::move(remed));
|
||||
}
|
||||
void LokiMQ::proxy_update_active_sns(pubkey_set added, pubkey_set removed) {
|
||||
// We take a caller-provided set of added/removed then filter out any junk (bad pks, conflicting
|
||||
// values, pubkeys that already(added) or do not(removed) exist), then pass the purified lists
|
||||
// to the _clean version.
|
||||
|
||||
LMQ_LOG(info, "uh, ", added.size(), ", ", removed.size());
|
||||
for (auto it = removed.begin(); it != removed.end(); ) {
|
||||
const auto& pk = *it;
|
||||
if (pk.size() != 32) {
|
||||
LMQ_LOG(warn, "Invalid private key of length ", pk.size(), " (", to_hex(pk), ") passed to update_active_sns (removed)");
|
||||
it = removed.erase(it);
|
||||
} else if (!active_service_nodes.count(pk) || added.count(pk) /* added wins if in both */) {
|
||||
it = removed.erase(it);
|
||||
} else {
|
||||
++it;
|
||||
}
|
||||
}
|
||||
|
||||
for (auto it = added.begin(); it != added.end(); ) {
|
||||
const auto& pk = *it;
|
||||
if (pk.size() != 32) {
|
||||
LMQ_LOG(warn, "Invalid private key of length ", pk.size(), " (", to_hex(pk), ") passed to update_active_sns (added)");
|
||||
it = added.erase(it);
|
||||
} else if (active_service_nodes.count(pk)) {
|
||||
it = added.erase(it);
|
||||
} else {
|
||||
++it;
|
||||
}
|
||||
}
|
||||
|
||||
proxy_update_active_sns_clean(std::move(added), std::move(removed));
|
||||
}
|
||||
|
||||
void LokiMQ::proxy_update_active_sns_clean(pubkey_set added, pubkey_set removed) {
|
||||
LMQ_LOG(debug, "Updating SN auth status with +", added.size(), "/-", removed.size(), " pubkeys");
|
||||
|
||||
// For anything we remove we want close the connection to the SN (if outgoing), and remove the
|
||||
// stored peer_info (incoming or outgoing).
|
||||
for (const auto& pk : removed) {
|
||||
ConnectionID c{pk};
|
||||
active_service_nodes.erase(pk);
|
||||
auto range = peers.equal_range(c);
|
||||
for (auto it = range.first; it != range.second; ) {
|
||||
bool outgoing = it->second.outgoing();
|
||||
size_t conn_index = it->second.conn_index;
|
||||
it = peers.erase(it);
|
||||
if (outgoing) {
|
||||
LMQ_LOG(debug, "Closing outgoing connection to ", c);
|
||||
proxy_close_connection(conn_index, CLOSE_LINGER);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// For pubkeys we add there's nothing special to be done beyond adding them to the pubkey set
|
||||
for (auto& pk : added)
|
||||
active_service_nodes.insert(std::move(pk));
|
||||
}
|
||||
|
||||
void LokiMQ::process_zap_requests() {
|
||||
for (std::vector<zmq::message_t> frames; recv_message_parts(zap_auth, frames, zmq::recv_flags::dontwait); frames.clear()) {
|
||||
#ifndef NDEBUG
|
||||
|
@ -147,31 +268,32 @@ void LokiMQ::process_zap_requests() {
|
|||
} else {
|
||||
auto ip = view(frames[3]);
|
||||
string_view pubkey;
|
||||
if (bind[bind_id].second.curve)
|
||||
bool sn = false;
|
||||
if (bind[bind_id].second.curve) {
|
||||
pubkey = view(frames[6]);
|
||||
auto result = bind[bind_id].second.allow(ip, pubkey);
|
||||
bool sn = result.remote_sn;
|
||||
sn = active_service_nodes.count(std::string{pubkey});
|
||||
}
|
||||
auto auth = bind[bind_id].second.allow(ip, pubkey, sn);
|
||||
auto& user_id = response_vals[4];
|
||||
if (bind[bind_id].second.curve) {
|
||||
user_id.reserve(64);
|
||||
to_hex(pubkey.begin(), pubkey.end(), std::back_inserter(user_id));
|
||||
}
|
||||
|
||||
if (result.auth <= AuthLevel::denied || result.auth > AuthLevel::admin) {
|
||||
if (auth <= AuthLevel::denied || auth > AuthLevel::admin) {
|
||||
LMQ_LOG(info, "Access denied for incoming ", view(frames[5]), (sn ? " service node" : " client"),
|
||||
" connection from ", !user_id.empty() ? user_id + " at " : ""s, ip,
|
||||
" with initial auth level ", result.auth);
|
||||
" with initial auth level ", auth);
|
||||
status_code = "400";
|
||||
status_text = "Access denied";
|
||||
user_id.clear();
|
||||
} else {
|
||||
LMQ_LOG(debug, "Accepted incoming ", view(frames[5]), (sn ? " service node" : " client"),
|
||||
" connection with authentication level ", result.auth,
|
||||
" connection with authentication level ", auth,
|
||||
" from ", !user_id.empty() ? user_id + " at " : ""s, ip);
|
||||
|
||||
auto& metadata = response_vals[5];
|
||||
metadata += zmtp_metadata("X-SN", result.remote_sn ? "1" : "0");
|
||||
metadata += zmtp_metadata("X-AuthLevel", to_string(result.auth));
|
||||
metadata += zmtp_metadata("X-AuthLevel", to_string(auth));
|
||||
|
||||
status_code = "200";
|
||||
status_text = "";
|
||||
|
|
|
@ -1,5 +1,8 @@
|
|||
#pragma once
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <cstring>
|
||||
#include <unordered_set>
|
||||
|
||||
namespace lokimq {
|
||||
|
||||
|
@ -21,13 +24,32 @@ struct Access {
|
|||
bool remote_sn = false;
|
||||
/// If true the category requires that the local node is a SN
|
||||
bool local_sn = false;
|
||||
|
||||
/// Constructor. Intentionally allows implicit conversion from an AuthLevel so that an
|
||||
/// AuthLevel can be passed anywhere an Access is required (the resulting Access will have both
|
||||
/// remote and local sn set to false).
|
||||
Access(AuthLevel auth, bool remote_sn = false, bool local_sn = false)
|
||||
: auth{auth}, remote_sn{remote_sn}, local_sn{local_sn} {}
|
||||
};
|
||||
|
||||
/// Return type of the AllowFunc: this determines whether we allow the connection at all, and if so,
|
||||
/// sets the initial authentication level and tells LokiMQ whether the other end is an active SN.
|
||||
struct Allow {
|
||||
AuthLevel auth = AuthLevel::none;
|
||||
bool remote_sn = false;
|
||||
/// Simple hash implementation for a string that is *already* a hash-like value (such as a pubkey).
|
||||
/// Falls back to std::hash<std::string> if given a string smaller than a size_t.
|
||||
struct already_hashed {
|
||||
size_t operator()(const std::string& s) const {
|
||||
if (s.size() < sizeof(size_t))
|
||||
return std::hash<std::string>{}(s);
|
||||
size_t hash;
|
||||
std::memcpy(&hash, &s[0], sizeof(hash));
|
||||
return hash;
|
||||
}
|
||||
};
|
||||
|
||||
/// std::unordered_set specialization for specifying pubkeys (used, in particular, by
|
||||
/// LokiMQ::set_active_sns and LokiMQ::update_active_sns); this is a std::string unordered_set that
|
||||
/// also uses a specialized trivial hash function that uses part of the value itself (i.e. the
|
||||
/// pubkey) directly as a hash value. (This is nice and fast for uniformly distributed values like
|
||||
/// pubkeys and a terrible hash choice for anything else).
|
||||
using pubkey_set = std::unordered_set<std::string, already_hashed>;
|
||||
|
||||
|
||||
}
|
||||
|
|
|
@ -53,6 +53,47 @@ void LokiMQ::setup_outgoing_socket(zmq::socket_t& socket, string_view remote_pub
|
|||
// else let ZMQ pick a random one
|
||||
}
|
||||
|
||||
ConnectionID LokiMQ::connect_sn(string_view pubkey, std::chrono::milliseconds keep_alive, string_view hint) {
|
||||
if (!proxy_thread.joinable())
|
||||
throw std::logic_error("Cannot call connect_sn() before calling `start()`");
|
||||
|
||||
detail::send_control(get_control_socket(), "CONNECT_SN", bt_serialize<bt_dict>({{"pubkey",pubkey}, {"keep_alive",keep_alive.count()}, {"hint",hint}}));
|
||||
|
||||
return pubkey;
|
||||
}
|
||||
|
||||
ConnectionID LokiMQ::connect_remote(string_view remote, ConnectSuccess on_connect, ConnectFailure on_failure,
|
||||
string_view pubkey, AuthLevel auth_level, std::chrono::milliseconds timeout) {
|
||||
if (!proxy_thread.joinable())
|
||||
throw std::logic_error("Cannot call connect_remote() before calling `start()`");
|
||||
|
||||
if (remote.size() < 7 || !(remote.substr(0, 6) == "tcp://" || remote.substr(0, 6) == "ipc://" /* unix domain sockets */))
|
||||
throw std::runtime_error("Invalid connect_remote: remote address '" + std::string{remote} + "' is not a valid or supported zmq connect string");
|
||||
|
||||
auto id = next_conn_id++;
|
||||
LMQ_TRACE("telling proxy to connect to ", remote, ", id ", id,
|
||||
pubkey.empty() ? "using NULL auth" : ", using CURVE with remote pubkey [" + to_hex(pubkey) + "]");
|
||||
detail::send_control(get_control_socket(), "CONNECT_REMOTE", bt_serialize<bt_dict>({
|
||||
{"auth_level", static_cast<std::underlying_type_t<AuthLevel>>(auth_level)},
|
||||
{"conn_id", id},
|
||||
{"connect", detail::serialize_object(std::move(on_connect))},
|
||||
{"failure", detail::serialize_object(std::move(on_failure))},
|
||||
{"pubkey", pubkey},
|
||||
{"remote", remote},
|
||||
{"timeout", timeout.count()},
|
||||
}));
|
||||
|
||||
return id;
|
||||
}
|
||||
|
||||
void LokiMQ::disconnect(ConnectionID id, std::chrono::milliseconds linger) {
|
||||
detail::send_control(get_control_socket(), "DISCONNECT", bt_serialize<bt_dict>({
|
||||
{"conn_id", id.id},
|
||||
{"linger_ms", linger.count()},
|
||||
{"pubkey", id.pk},
|
||||
}));
|
||||
}
|
||||
|
||||
std::pair<zmq::socket_t *, std::string>
|
||||
LokiMQ::proxy_connect_sn(string_view remote, string_view connect_hint, bool optional, bool incoming_only, bool outgoing_only, std::chrono::milliseconds keep_alive) {
|
||||
ConnectionID remote_cid{remote};
|
||||
|
@ -166,9 +207,9 @@ void update_connection_indices(Container& c, size_t index, AccessIndex get_index
|
|||
}
|
||||
}
|
||||
|
||||
/// Closes outgoing connections and removes all references. Note that this will invalidate
|
||||
/// iterators on the various connection containers - if you don't want that, delete it first so that
|
||||
/// the container won't contain the element being deleted.
|
||||
/// Closes outgoing connections and removes all references. Note that this will call `erase()`
|
||||
/// which can invalidate iterators on the various connection containers - if you don't want that,
|
||||
/// delete it first so that the container won't contain the element being deleted.
|
||||
void LokiMQ::proxy_close_connection(size_t index, std::chrono::milliseconds linger) {
|
||||
connections[index].setsockopt<int>(ZMQ_LINGER, linger > 0ms ? linger.count() : 0);
|
||||
pollitems_stale = true;
|
||||
|
@ -197,6 +238,7 @@ void LokiMQ::proxy_expire_idle_peers() {
|
|||
continue;
|
||||
}
|
||||
LMQ_LOG(debug, "Closing outgoing connection to ", it->first, ": idle timeout reached");
|
||||
++it; // The below is going to delete our current element
|
||||
proxy_close_connection(info.conn_index, CLOSE_LINGER);
|
||||
} else {
|
||||
++it;
|
||||
|
@ -238,7 +280,7 @@ void LokiMQ::proxy_conn_cleanup() {
|
|||
auto& callback = it->second;
|
||||
if (callback.first < now) {
|
||||
LMQ_LOG(debug, "pending request ", to_hex(it->first), " expired, invoking callback with failure status and removing");
|
||||
job([callback = std::move(callback.second)] { callback(false, {}); });
|
||||
job([callback = std::move(callback.second)] { callback(false, {{"TIMEOUT"s}}); });
|
||||
it = pending_requests.erase(it);
|
||||
} else {
|
||||
++it;
|
||||
|
@ -262,14 +304,10 @@ void LokiMQ::proxy_connect_remote(bt_dict_consumer data) {
|
|||
if (data.skip_until("conn_id"))
|
||||
conn_id = data.consume_integer<long long>();
|
||||
if (data.skip_until("connect")) {
|
||||
auto* ptr = reinterpret_cast<ConnectSuccess*>(data.consume_integer<uintptr_t>());
|
||||
on_connect = std::move(*ptr);
|
||||
delete ptr;
|
||||
on_connect = detail::deserialize_object<ConnectSuccess>(data.consume_integer<uintptr_t>());
|
||||
}
|
||||
if (data.skip_until("failure")) {
|
||||
auto* ptr = reinterpret_cast<ConnectFailure*>(data.consume_integer<uintptr_t>());
|
||||
on_failure = std::move(*ptr);
|
||||
delete ptr;
|
||||
on_failure = detail::deserialize_object<ConnectFailure>(data.consume_integer<uintptr_t>());
|
||||
}
|
||||
if (data.skip_until("pubkey")) {
|
||||
remote_pubkey = data.consume_string();
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
#pragma once
|
||||
#include "auth.h"
|
||||
#include "string_view.h"
|
||||
|
||||
namespace lokimq {
|
||||
|
@ -75,7 +76,7 @@ private:
|
|||
namespace std {
|
||||
template <> struct hash<lokimq::ConnectionID> {
|
||||
size_t operator()(const lokimq::ConnectionID &c) const {
|
||||
return c.sn() ? std::hash<std::string>{}(c.pk) :
|
||||
return c.sn() ? lokimq::already_hashed{}(c.pk) :
|
||||
std::hash<long long>{}(c.id);
|
||||
}
|
||||
};
|
||||
|
|
|
@ -27,7 +27,7 @@ extern "C" inline void message_buffer_destroy(void*, void* hint) {
|
|||
}
|
||||
|
||||
/// Creates a message without needing to reallocate the provided string data
|
||||
inline zmq::message_t create_message(std::string &&data) {
|
||||
inline zmq::message_t create_message(std::string&& data) {
|
||||
auto *buffer = new std::string(std::move(data));
|
||||
return zmq::message_t{&(*buffer)[0], buffer->size(), message_buffer_destroy, buffer};
|
||||
};
|
||||
|
@ -38,7 +38,7 @@ inline zmq::message_t create_message(string_view data) {
|
|||
}
|
||||
|
||||
template <typename It>
|
||||
bool send_message_parts(zmq::socket_t &sock, It begin, It end) {
|
||||
bool send_message_parts(zmq::socket_t& sock, It begin, It end) {
|
||||
while (begin != end) {
|
||||
zmq::message_t &msg = *begin++;
|
||||
if (!sock.send(msg, begin == end ? zmq::send_flags::dontwait : zmq::send_flags::dontwait | zmq::send_flags::sndmore))
|
||||
|
@ -48,7 +48,7 @@ bool send_message_parts(zmq::socket_t &sock, It begin, It end) {
|
|||
}
|
||||
|
||||
template <typename Container>
|
||||
bool send_message_parts(zmq::socket_t &sock, Container &&c) {
|
||||
bool send_message_parts(zmq::socket_t& sock, Container&& c) {
|
||||
return send_message_parts(sock, c.begin(), c.end());
|
||||
}
|
||||
|
||||
|
@ -56,7 +56,7 @@ bool send_message_parts(zmq::socket_t &sock, Container &&c) {
|
|||
/// the msg frame will be an empty message; if `data` is empty then the data frame will be omitted.
|
||||
/// `flags` is passed through to zmq: typically given `zmq::send_flags::dontwait` to throw rather
|
||||
/// than block if a message can't be queued.
|
||||
inline bool send_routed_message(zmq::socket_t &socket, std::string route, std::string msg = {}, std::string data = {}) {
|
||||
inline bool send_routed_message(zmq::socket_t& socket, std::string route, std::string msg = {}, std::string data = {}) {
|
||||
assert(!route.empty());
|
||||
std::array<zmq::message_t, 3> msgs{{create_message(std::move(route))}};
|
||||
if (!msg.empty())
|
||||
|
@ -68,7 +68,7 @@ inline bool send_routed_message(zmq::socket_t &socket, std::string route, std::s
|
|||
|
||||
// Sends some stuff to a socket directly. If dontwait is true then we throw instead of blocking if
|
||||
// the message cannot be accepted by zmq (i.e. because the outgoing buffer is full).
|
||||
inline bool send_direct_message(zmq::socket_t &socket, std::string msg, std::string data = {}) {
|
||||
inline bool send_direct_message(zmq::socket_t& socket, std::string msg, std::string data = {}) {
|
||||
std::array<zmq::message_t, 2> msgs{{create_message(std::move(msg))}};
|
||||
if (!data.empty())
|
||||
msgs[1] = create_message(std::move(data));
|
||||
|
@ -77,7 +77,7 @@ inline bool send_direct_message(zmq::socket_t &socket, std::string msg, std::str
|
|||
|
||||
// Receive all the parts of a single message from the given socket. Returns true if a message was
|
||||
// received, false if called with flags=zmq::recv_flags::dontwait and no message was available.
|
||||
inline bool recv_message_parts(zmq::socket_t &sock, std::vector<zmq::message_t>& parts, const zmq::recv_flags flags = zmq::recv_flags::none) {
|
||||
inline bool recv_message_parts(zmq::socket_t& sock, std::vector<zmq::message_t>& parts, const zmq::recv_flags flags = zmq::recv_flags::none) {
|
||||
do {
|
||||
zmq::message_t msg;
|
||||
if (!sock.recv(msg, flags))
|
||||
|
|
|
@ -26,11 +26,6 @@ std::vector<std::string> as_strings(const MessageContainer& msgs) {
|
|||
return result;
|
||||
}
|
||||
|
||||
void check_started(const std::thread& proxy_thread, const std::string &verb) {
|
||||
if (!proxy_thread.joinable())
|
||||
throw std::logic_error("Cannot " + verb + " before calling `start()`");
|
||||
}
|
||||
|
||||
void check_not_started(const std::thread& proxy_thread, const std::string &verb) {
|
||||
if (proxy_thread.joinable())
|
||||
throw std::logic_error("Cannot " + verb + " after calling `start()`");
|
||||
|
@ -54,29 +49,20 @@ void send_control(zmq::socket_t& sock, string_view cmd, std::string data) {
|
|||
}
|
||||
}
|
||||
|
||||
/// Extracts a pubkey, SN status, and auth level from a zmq message received on a *listening*
|
||||
/// socket.
|
||||
std::tuple<std::string, bool, AuthLevel> extract_metadata(zmq::message_t& msg) {
|
||||
auto result = std::make_tuple(""s, false, AuthLevel::none);
|
||||
/// Extracts a pubkey and and auth level from a zmq message received on a *listening* socket.
|
||||
std::pair<std::string, AuthLevel> extract_metadata(zmq::message_t& msg) {
|
||||
auto result = std::make_pair(""s, AuthLevel::none);
|
||||
try {
|
||||
string_view pubkey_hex{msg.gets("User-Id")};
|
||||
if (pubkey_hex.size() != 64)
|
||||
throw std::logic_error("bad user-id");
|
||||
assert(is_hex(pubkey_hex.begin(), pubkey_hex.end()));
|
||||
auto& pubkey = std::get<std::string>(result);
|
||||
pubkey.resize(32, 0);
|
||||
from_hex(pubkey_hex.begin(), pubkey_hex.end(), pubkey.begin());
|
||||
result.first.resize(32, 0);
|
||||
from_hex(pubkey_hex.begin(), pubkey_hex.end(), result.first.begin());
|
||||
} catch (...) {}
|
||||
|
||||
try {
|
||||
string_view is_sn{msg.gets("X-SN")};
|
||||
if (is_sn.size() == 1 && is_sn[0] == '1')
|
||||
std::get<bool>(result) = true;
|
||||
} catch (...) {}
|
||||
|
||||
try {
|
||||
string_view auth_level{msg.gets("X-AuthLevel")};
|
||||
std::get<AuthLevel>(result) = auth_from_string(auth_level);
|
||||
result.second = auth_from_string(msg.gets("X-AuthLevel"));
|
||||
} catch (...) {}
|
||||
|
||||
return result;
|
||||
|
@ -385,46 +371,6 @@ LokiMQ::~LokiMQ() {
|
|||
LMQ_LOG(info, "LokiMQ proxy thread has stopped");
|
||||
}
|
||||
|
||||
ConnectionID LokiMQ::connect_sn(string_view pubkey, std::chrono::milliseconds keep_alive, string_view hint) {
|
||||
check_started(proxy_thread, "connect");
|
||||
|
||||
detail::send_control(get_control_socket(), "CONNECT_SN", bt_serialize<bt_dict>({{"pubkey",pubkey}, {"keep_alive",keep_alive.count()}, {"hint",hint}}));
|
||||
|
||||
return pubkey;
|
||||
}
|
||||
|
||||
ConnectionID LokiMQ::connect_remote(string_view remote, ConnectSuccess on_connect, ConnectFailure on_failure,
|
||||
string_view pubkey, AuthLevel auth_level, std::chrono::milliseconds timeout) {
|
||||
if (!proxy_thread.joinable())
|
||||
LMQ_LOG(warn, "connect_remote() called before start(); this won't take effect until start() is called");
|
||||
|
||||
if (remote.size() < 7 || !(remote.substr(0, 6) == "tcp://" || remote.substr(0, 6) == "ipc://" /* unix domain sockets */))
|
||||
throw std::runtime_error("Invalid connect_remote: remote address '" + std::string{remote} + "' is not a valid or supported zmq connect string");
|
||||
|
||||
auto id = next_conn_id++;
|
||||
LMQ_TRACE("telling proxy to connect to ", remote, ", id ", id,
|
||||
pubkey.empty() ? "using NULL auth" : ", using CURVE with remote pubkey [" + to_hex(pubkey) + "]");
|
||||
detail::send_control(get_control_socket(), "CONNECT_REMOTE", bt_serialize<bt_dict>({
|
||||
{"auth_level", static_cast<std::underlying_type_t<AuthLevel>>(auth_level)},
|
||||
{"conn_id", id},
|
||||
{"connect", reinterpret_cast<uintptr_t>(new ConnectSuccess{std::move(on_connect)})},
|
||||
{"failure", reinterpret_cast<uintptr_t>(new ConnectFailure{std::move(on_failure)})},
|
||||
{"pubkey", pubkey},
|
||||
{"remote", remote},
|
||||
{"timeout", timeout.count()},
|
||||
}));
|
||||
|
||||
return id;
|
||||
}
|
||||
|
||||
void LokiMQ::disconnect(ConnectionID id, std::chrono::milliseconds linger) {
|
||||
detail::send_control(get_control_socket(), "DISCONNECT", bt_serialize<bt_dict>({
|
||||
{"conn_id", id.id},
|
||||
{"linger_ms", linger.count()},
|
||||
{"pubkey", id.pk},
|
||||
}));
|
||||
}
|
||||
|
||||
std::ostream &operator<<(std::ostream &os, LogLevel lvl) {
|
||||
os << (lvl == LogLevel::trace ? "trace" :
|
||||
lvl == LogLevel::debug ? "debug" :
|
||||
|
|
110
lokimq/lokimq.h
110
lokimq/lokimq.h
|
@ -134,16 +134,18 @@ private:
|
|||
public:
|
||||
|
||||
/// Callback type invoked to determine whether the given new incoming connection is allowed to
|
||||
/// connect to us and to set its initial authentication level.
|
||||
/// connect to us and to set its authentication level.
|
||||
///
|
||||
/// @param ip - the ip address of the incoming connection
|
||||
/// @param pubkey - the x25519 pubkey of the connecting client (32 byte string). Note that this
|
||||
/// will only be non-empty for incoming connections on `listen_curve` sockets; `listen_plain`
|
||||
/// sockets do not have a pubkey.
|
||||
/// @param service_node - will be true if the `pubkey` is in the set of known active service
|
||||
/// nodes.
|
||||
///
|
||||
/// @returns an `AuthLevel` enum value indicating the default auth level for the incoming
|
||||
/// connection, or AuthLevel::denied if the connection should be refused.
|
||||
using AllowFunc = std::function<Allow(string_view ip, string_view pubkey)>;
|
||||
using AllowFunc = std::function<AuthLevel(string_view ip, string_view pubkey, bool service_node)>;
|
||||
|
||||
/// Callback that is invoked when we need to send a "strong" message to a SN that we aren't
|
||||
/// already connected to and need to establish a connection. This callback returns the ZMQ
|
||||
|
@ -244,7 +246,9 @@ private:
|
|||
std::vector<std::pair<std::string, bind_data>> bind;
|
||||
|
||||
/// Info about a peer's established connection with us. Note that "established" means both
|
||||
/// connected and authenticated.
|
||||
/// connected and authenticated. Note that we only store peer info data for SN connections (in
|
||||
/// or out), and outgoing non-SN connections. Incoming non-SN connections are handled on the
|
||||
/// fly.
|
||||
struct peer_info {
|
||||
/// Pubkey of the remote, if this connection is a curve25519 connection; empty otherwise.
|
||||
std::string pubkey;
|
||||
|
@ -521,15 +525,26 @@ private:
|
|||
/// done).
|
||||
std::unordered_map<std::string, std::string> command_aliases;
|
||||
|
||||
using cat_call_t = std::pair<category*, const std::pair<CommandCallback, bool>*>;
|
||||
/// Retrieve category and callback from a command name, including alias mapping. Warns on
|
||||
/// invalid commands and returns nullptrs. The command name will be updated in place if it is
|
||||
/// aliased to another command.
|
||||
std::pair<category*, const std::pair<CommandCallback, bool>*> get_command(std::string& command);
|
||||
cat_call_t get_command(std::string& command);
|
||||
|
||||
/// Checks a peer's authentication level. Returns true if allowed, warns and returns false if
|
||||
/// not.
|
||||
bool proxy_check_auth(size_t conn_index, bool outgoing, const peer_info& peer,
|
||||
const std::string& command, const category& cat, zmq::message_t& msg);
|
||||
zmq::message_t& command, const cat_call_t& cat_call, std::vector<zmq::message_t>& data);
|
||||
|
||||
/// Set of active service nodes.
|
||||
pubkey_set active_service_nodes;
|
||||
|
||||
/// Resets or updates the stored set of active SN pubkeys
|
||||
void proxy_set_active_sns(string_view data);
|
||||
void proxy_set_active_sns(pubkey_set pubkeys);
|
||||
void proxy_update_active_sns(bt_list_consumer data);
|
||||
void proxy_update_active_sns(pubkey_set added, pubkey_set removed);
|
||||
void proxy_update_active_sns_clean(pubkey_set added, pubkey_set removed);
|
||||
|
||||
/// Details for a pending command; such a command already has authenticated access and is just
|
||||
/// waiting for a thread to become available to handle it.
|
||||
|
@ -775,8 +790,15 @@ public:
|
|||
* Finish starting up: binds to the bind locations given in the constructor and launches the
|
||||
* proxy thread to handle message dispatching between remote nodes and worker threads.
|
||||
*
|
||||
* You will need to call `add_category` and `add_command` to register commands before calling
|
||||
* `start()`; once start() is called commands cannot be changed.
|
||||
* Things you want to do before calling this:
|
||||
* - Use `add_category`/`add_command` to set up any commands remote connections can invoke.
|
||||
* - If any commands require SN authentication, specify a list of currently active service node
|
||||
* pubkeys via `set_active_sns()` (and make sure this gets updated when things change by
|
||||
* another `set_active_sns()` or a `update_active_sns()` call). It *is* possible to make the
|
||||
* initial call after calling `start()`, but that creates a window during which incoming
|
||||
* remote SN connections will be erroneously treated as non-SN connections.
|
||||
* - If this LMQ instance should accept incoming connections, set up any listening ports via
|
||||
* `listen_curve()` and/or `listen_plain()`.
|
||||
*/
|
||||
void start();
|
||||
|
||||
|
@ -789,10 +811,10 @@ public:
|
|||
* such as: "tcp://\*:4567" or "tcp://1.2.3.4:5678".
|
||||
*
|
||||
* @param allow_connection function to call to determine whether to allow the connection and, if
|
||||
* so, the authentication level it receives. If omitted the default returns non-service node,
|
||||
* AuthLevel::none access.
|
||||
* so, the authentication level it receives. If omitted the default returns AuthLevel::none
|
||||
* access.
|
||||
*/
|
||||
void listen_curve(std::string bind, AllowFunc allow_connection = [](auto, auto) { return Allow{AuthLevel::none, false}; });
|
||||
void listen_curve(std::string bind, AllowFunc allow_connection = [](auto, auto, auto) { return AuthLevel::none; });
|
||||
|
||||
/** Start listening on the given bind address in unauthenticated plain text mode. Incoming
|
||||
* connections can come from anywhere. `allow_connection` is invoked for any incoming
|
||||
|
@ -803,10 +825,10 @@ public:
|
|||
* such as: "tcp://\*:4567" or "tcp://1.2.3.4:5678".
|
||||
*
|
||||
* @param allow_connection function to call to determine whether to allow the connection and, if
|
||||
* so, the authentication level it receives. If omitted the default returns non-service node,
|
||||
* AuthLevel::none access.
|
||||
* so, the authentication level it receives. If omitted the default returns AuthLevel::none
|
||||
* access.
|
||||
*/
|
||||
void listen_plain(std::string bind, AllowFunc allow_connection = [](auto, auto) { return Allow{AuthLevel::none, false}; });
|
||||
void listen_plain(std::string bind, AllowFunc allow_connection = [](auto, auto, auto) { return AuthLevel::none; });
|
||||
|
||||
/**
|
||||
* Try to initiate a connection to the given SN in anticipation of needing a connection in the
|
||||
|
@ -939,9 +961,23 @@ public:
|
|||
* @param to - the pubkey string or ConnectionID to send this request to
|
||||
* @param cmd - the command name
|
||||
* @param callback - the callback to invoke when we get a reply. Called with a true value and
|
||||
* the data strings when a reply is received, or false and an empty vector of data parts if we
|
||||
* get no reply in the timeout interval.
|
||||
* the data strings when a reply is received, or false with error string(s) indicating the
|
||||
* failure reason upon failure or timeout.
|
||||
* @param opts - anything else (i.e. strings, send_options) is forwarded to send().
|
||||
*
|
||||
* Possible error data values:
|
||||
* - ["TIMEOUT"] - we got no reply within the timeout window
|
||||
* - ["UNKNOWNCOMMAND"] - the remote did not recognize the given request command
|
||||
* - ["NO_REPLY_TAG"] - the invoked command is a request command but no reply tag was included
|
||||
* - ["FORBIDDEN"] - the command requires an authorization level (e.g. Basic or Admin) that we
|
||||
* do not have.
|
||||
* - ["FORBIDDEN_SN"] - the command requires service node authentication, but the remote did not
|
||||
* recognize us as a service node. You *may* want to retry the request a limited number of
|
||||
* times (but do not retry indefinitely as that can be an infinite loop!) because this is
|
||||
* typically also followed by a disconnection; a retried message would reconnect and
|
||||
* reauthenticate which *may* result in picking up the SN authentication.
|
||||
* - ["NOT_A_SERVICE_NODE"] - this command is only invokable on service nodes, and the remote is
|
||||
* not running as a service node.
|
||||
*/
|
||||
template <typename... T>
|
||||
void request(ConnectionID to, string_view cmd, ReplyCallback callback, const T&... opts);
|
||||
|
@ -951,6 +987,41 @@ public:
|
|||
const std::string& get_pubkey() const { return pubkey; }
|
||||
const std::string& get_privkey() const { return privkey; }
|
||||
|
||||
/** Updates (or initially sets) LokiMQ's list of service node pubkeys with the given list.
|
||||
*
|
||||
* This has two main effects:
|
||||
*
|
||||
* - All commands processed after the update will have SN status determined by the new list.
|
||||
* - All outgoing connections to service nodes that are no longer on the list will be closed.
|
||||
* This includes both explicit connections (established by `connect_sn()`) and implicit ones
|
||||
* (established by sending to a SN that wasn't connected).
|
||||
*
|
||||
* As this update is potentially quite heavy it is recommended that this be called only when
|
||||
* necessary--i.e. when the list has changed (or potentially changed), but *not* on a short
|
||||
* periodic timer.
|
||||
*
|
||||
* This method may (and should!) be called before start() to load an initial set of SNs.
|
||||
*
|
||||
* Once a full list has been set, updates on changes can either call this again with the new
|
||||
* list, or use the more efficient update_active_sns() call if incremental results are
|
||||
* available.
|
||||
*/
|
||||
void set_active_sns(pubkey_set pubkeys);
|
||||
|
||||
/** Updates the list of active pubkeys by adding or removing the given pubkeys from the existing
|
||||
* list. This is more efficient when the incremental information is already available; if it
|
||||
* isn't, simply call set_active_sns with a new list to have LokiMQ figure out what was added or
|
||||
* removed.
|
||||
*
|
||||
* \param added new pubkeys that were added since the last set_active_sns or update_active_sns
|
||||
* call.
|
||||
*
|
||||
* \param removed pubkeys that were removed from active SN status since the last call. If a
|
||||
* pubkey is in both `added` and `removed` for some reason then its presence in `removed` will
|
||||
* be ignored.
|
||||
*/
|
||||
void update_active_sns(pubkey_set added, pubkey_set removed);
|
||||
|
||||
/**
|
||||
* Batches a set of jobs to be executed by workers, optionally followed by a completion function.
|
||||
*
|
||||
|
@ -1121,7 +1192,9 @@ namespace detail {
|
|||
/// containing the pointer to be serialized to pass (via lokimq queues) from one thread to another.
|
||||
/// Must be matched with a deserializer_pointer on the other side to reconstitute the object and
|
||||
/// destroy the intermediate pointer.
|
||||
template <typename T> uintptr_t serialize_object(T&& obj) {
|
||||
template <typename T>
|
||||
uintptr_t serialize_object(T&& obj) {
|
||||
static_assert(std::is_rvalue_reference<decltype(obj)>::value, "serialize_object must be given an rvalue reference");
|
||||
auto* ptr = new T{std::forward<T>(obj)};
|
||||
return reinterpret_cast<uintptr_t>(ptr);
|
||||
}
|
||||
|
@ -1192,9 +1265,8 @@ inline void apply_send_option(bt_list&, bt_dict& control_data, send_option::queu
|
|||
control_data["send_full_q"] = serialize_object(std::move(f.callback));
|
||||
}
|
||||
|
||||
/// Extracts a pubkey, SN status, and auth level from a zmq message received on a *listening*
|
||||
/// socket.
|
||||
std::tuple<std::string, bool, AuthLevel> extract_metadata(zmq::message_t& msg);
|
||||
/// Extracts a pubkey and auth level from a zmq message received on a *listening* socket.
|
||||
std::pair<std::string, AuthLevel> extract_metadata(zmq::message_t& msg);
|
||||
|
||||
template <typename... T>
|
||||
bt_dict build_send(ConnectionID to, string_view cmd, T&&... opts) {
|
||||
|
|
100
lokimq/proxy.cpp
100
lokimq/proxy.cpp
|
@ -178,7 +178,7 @@ void LokiMQ::proxy_send(bt_dict_consumer data) {
|
|||
std::chrono::steady_clock::now() + request_timeout, std::move(request_callback) }});
|
||||
} else {
|
||||
LMQ_LOG(debug, "Could not send request, scheduling request callback failure");
|
||||
job([callback = std::move(request_callback)] { callback(false, {}); });
|
||||
job([callback = std::move(request_callback)] { callback(false, {{"TIMEOUT"s}}); });
|
||||
}
|
||||
}
|
||||
if (!sent) {
|
||||
|
@ -237,31 +237,38 @@ void LokiMQ::proxy_reply(bt_dict_consumer data) {
|
|||
}
|
||||
|
||||
void LokiMQ::proxy_control_message(std::vector<zmq::message_t>& parts) {
|
||||
// We throw an uncaught exception here because we only generate control messages internally in
|
||||
// lokimq code: if one of these condition fail it's a lokimq bug.
|
||||
if (parts.size() < 2)
|
||||
throw std::logic_error("Expected 2-3 message parts for a proxy control message");
|
||||
throw std::logic_error("LokiMQ bug: Expected 2-3 message parts for a proxy control message");
|
||||
auto route = view(parts[0]), cmd = view(parts[1]);
|
||||
LMQ_TRACE("control message: ", cmd);
|
||||
if (parts.size() == 3) {
|
||||
LMQ_TRACE("...: ", parts[2]);
|
||||
auto data = view(parts[2]);
|
||||
if (cmd == "SEND") {
|
||||
LMQ_TRACE("proxying message");
|
||||
return proxy_send(view(parts[2]));
|
||||
return proxy_send(data);
|
||||
} else if (cmd == "REPLY") {
|
||||
LMQ_TRACE("proxying reply to non-SN incoming message");
|
||||
return proxy_reply(view(parts[2]));
|
||||
return proxy_reply(data);
|
||||
} else if (cmd == "BATCH") {
|
||||
LMQ_TRACE("proxy batch jobs");
|
||||
auto ptrval = bt_deserialize<uintptr_t>(view(parts[2]));
|
||||
auto ptrval = bt_deserialize<uintptr_t>(data);
|
||||
return proxy_batch(reinterpret_cast<detail::Batch*>(ptrval));
|
||||
} else if (cmd == "SET_SNS") {
|
||||
return proxy_set_active_sns(data);
|
||||
} else if (cmd == "UPDATE_SNS") {
|
||||
return proxy_update_active_sns(data);
|
||||
} else if (cmd == "CONNECT_SN") {
|
||||
proxy_connect_sn(view(parts[2]));
|
||||
proxy_connect_sn(data);
|
||||
return;
|
||||
} else if (cmd == "CONNECT_REMOTE") {
|
||||
return proxy_connect_remote(view(parts[2]));
|
||||
return proxy_connect_remote(data);
|
||||
} else if (cmd == "DISCONNECT") {
|
||||
return proxy_disconnect(view(parts[2]));
|
||||
return proxy_disconnect(data);
|
||||
} else if (cmd == "TIMER") {
|
||||
return proxy_timer(view(parts[2]));
|
||||
return proxy_timer(data);
|
||||
}
|
||||
} else if (parts.size() == 2) {
|
||||
if (cmd == "START") {
|
||||
|
@ -279,8 +286,8 @@ void LokiMQ::proxy_control_message(std::vector<zmq::message_t>& parts) {
|
|||
return;
|
||||
}
|
||||
}
|
||||
throw std::runtime_error("Proxy received invalid control command: " + std::string{cmd} +
|
||||
" (" + std::to_string(parts.size()) + ")");
|
||||
throw std::runtime_error("LokiMQ bug: Proxy received invalid control command: " +
|
||||
std::string{cmd} + " (" + std::to_string(parts.size()) + ")");
|
||||
}
|
||||
|
||||
void LokiMQ::proxy_loop() {
|
||||
|
@ -455,26 +462,31 @@ void LokiMQ::proxy_loop() {
|
|||
}
|
||||
}
|
||||
|
||||
static bool is_error_response(string_view cmd) {
|
||||
return cmd == "FORBIDDEN" || cmd == "FORBIDDEN_SN" || cmd == "NOT_A_SERVICE_NODE" || cmd == "UNKNOWNCOMMAND" || cmd == "NO_REPLY_TAG";
|
||||
}
|
||||
|
||||
// Return true if we recognized/handled the builtin command (even if we reject it for whatever
|
||||
// reason)
|
||||
bool LokiMQ::proxy_handle_builtin(size_t conn_index, std::vector<zmq::message_t>& parts) {
|
||||
bool outgoing = connections[conn_index].getsockopt<int>(ZMQ_TYPE) == ZMQ_DEALER;
|
||||
// Doubling as a bool and an offset:
|
||||
size_t incoming = connections[conn_index].getsockopt<int>(ZMQ_TYPE) == ZMQ_ROUTER;
|
||||
|
||||
string_view route, cmd;
|
||||
if (parts.size() < (outgoing ? 1 : 2)) {
|
||||
if (parts.size() < 1 + incoming) {
|
||||
LMQ_LOG(warn, "Received empty message; ignoring");
|
||||
return true;
|
||||
}
|
||||
if (outgoing) {
|
||||
cmd = view(parts[0]);
|
||||
} else {
|
||||
if (incoming) {
|
||||
route = view(parts[0]);
|
||||
cmd = view(parts[1]);
|
||||
} else {
|
||||
cmd = view(parts[0]);
|
||||
}
|
||||
LMQ_TRACE("Checking for builtins: '", cmd, "' from ", peer_address(parts.back()));
|
||||
|
||||
if (cmd == "REPLY") {
|
||||
size_t tag_pos = (outgoing ? 1 : 2);
|
||||
size_t tag_pos = 1 + incoming;
|
||||
if (parts.size() <= tag_pos) {
|
||||
LMQ_LOG(warn, "Received REPLY without a reply tag; ignoring");
|
||||
return true;
|
||||
|
@ -482,7 +494,7 @@ bool LokiMQ::proxy_handle_builtin(size_t conn_index, std::vector<zmq::message_t>
|
|||
std::string reply_tag{view(parts[tag_pos])};
|
||||
auto it = pending_requests.find(reply_tag);
|
||||
if (it != pending_requests.end()) {
|
||||
LMQ_LOG(debug, "Received REPLY for pending command", to_hex(reply_tag), "; scheduling callback");
|
||||
LMQ_LOG(debug, "Received REPLY for pending command ", to_hex(reply_tag), "; scheduling callback");
|
||||
std::vector<std::string> data;
|
||||
data.reserve(parts.size() - (tag_pos + 1));
|
||||
for (auto it = parts.begin() + (tag_pos + 1); it != parts.end(); ++it)
|
||||
|
@ -496,7 +508,7 @@ bool LokiMQ::proxy_handle_builtin(size_t conn_index, std::vector<zmq::message_t>
|
|||
}
|
||||
return true;
|
||||
} else if (cmd == "HI") {
|
||||
if (outgoing) {
|
||||
if (!incoming) {
|
||||
LMQ_LOG(warn, "Got invalid 'HI' message on an outgoing connection; ignoring");
|
||||
return true;
|
||||
}
|
||||
|
@ -506,7 +518,7 @@ bool LokiMQ::proxy_handle_builtin(size_t conn_index, std::vector<zmq::message_t>
|
|||
} catch (const std::exception &e) { LMQ_LOG(warn, "Couldn't reply with HELLO: ", e.what()); }
|
||||
return true;
|
||||
} else if (cmd == "HELLO") {
|
||||
if (!outgoing) {
|
||||
if (incoming) {
|
||||
LMQ_LOG(warn, "Got invalid 'HELLO' message on an incoming connection; ignoring");
|
||||
return true;
|
||||
}
|
||||
|
@ -531,22 +543,50 @@ bool LokiMQ::proxy_handle_builtin(size_t conn_index, std::vector<zmq::message_t>
|
|||
pending_connects.erase(it);
|
||||
return true;
|
||||
} else if (cmd == "BYE") {
|
||||
if (outgoing) {
|
||||
std::string pk;
|
||||
bool sn;
|
||||
AuthLevel a;
|
||||
std::tie(pk, sn, a) = detail::extract_metadata(parts.back());
|
||||
ConnectionID conn = sn ? ConnectionID{std::move(pk)} : conn_index_to_id[conn_index];
|
||||
LMQ_LOG(debug, "BYE command received; disconnecting from ", conn);
|
||||
proxy_disconnect(conn, 1s);
|
||||
if (!incoming) {
|
||||
LMQ_LOG(debug, "BYE command received; disconnecting from ", peer_address(parts.back()));
|
||||
proxy_close_connection(conn_index, 0s);
|
||||
} else {
|
||||
LMQ_LOG(warn, "Got invalid 'BYE' command on an incoming socket; ignoring");
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
else if (cmd == "FORBIDDEN" || cmd == "NOT_A_SERVICE_NODE") {
|
||||
return true; // FIXME - ignore these? Log?
|
||||
else if (is_error_response(cmd)) {
|
||||
// These messages (FORBIDDEN, UNKNOWNCOMMAND, etc.) are sent in response to us trying to
|
||||
// invoke something that doesn't exist or we don't have permission to access. These have
|
||||
// two forms (the latter is only sent by remotes running 1.0.6+).
|
||||
// - ["XXX", "whatever.command"]
|
||||
// - ["XXX", "REPLY", replytag]
|
||||
// (ignoring the routing prefix on incoming commands).
|
||||
// For the former, we log; for the latter we trigger the reply callback with a failure
|
||||
|
||||
if (parts.size() == (2 + incoming) && is_error_response(view(parts[1 + incoming]))) {
|
||||
// Something like ["UNKNOWNCOMMAND", "FORBIDDEN_SN"] which can happen because the remote
|
||||
// is running an older version that didn't understand the FORBIDDEN_SN (or whatever)
|
||||
// error reply that we sent them. We just ignore it because anything else could trigger
|
||||
// an infinite cycle.
|
||||
LMQ_LOG(debug, "Received [", cmd, ",", view(parts[1 + incoming]), "]; remote is probably an older lokimq. Ignoring.");
|
||||
return true;
|
||||
}
|
||||
|
||||
if (parts.size() == (3 + incoming) && view(parts[1 + incoming]) == "REPLY") {
|
||||
std::string reply_tag{view(parts[2 + incoming])};
|
||||
auto it = pending_requests.find(reply_tag);
|
||||
if (it != pending_requests.end()) {
|
||||
LMQ_LOG(debug, "Received ", cmd, " REPLY for pending command ", to_hex(reply_tag), "; scheduling failure callback");
|
||||
proxy_schedule_reply_job([callback=std::move(it->second.second), cmd=std::string{cmd}] {
|
||||
callback(false, {{std::move(cmd)}});
|
||||
});
|
||||
pending_requests.erase(it);
|
||||
} else {
|
||||
LMQ_LOG(warn, "Received REPLY with unknown or already handled reply tag (", to_hex(reply_tag), "); ignoring");
|
||||
}
|
||||
} else {
|
||||
LMQ_LOG(warn, "Received ", cmd, ':', (parts.size() > 1 + incoming ? view(parts[1 + incoming]) : "(unknown command)"_sv),
|
||||
" from ", peer_address(parts.back()));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -201,22 +201,24 @@ void LokiMQ::proxy_to_worker(size_t conn_index, std::vector<zmq::message_t>& par
|
|||
}
|
||||
peer = &it->second;
|
||||
} else {
|
||||
std::tie(tmp_peer.pubkey, tmp_peer.service_node, tmp_peer.auth_level) = detail::extract_metadata(parts.back());
|
||||
std::tie(tmp_peer.pubkey, tmp_peer.auth_level) = detail::extract_metadata(parts.back());
|
||||
tmp_peer.service_node = tmp_peer.pubkey.size() == 32 && active_service_nodes.count(tmp_peer.pubkey);
|
||||
|
||||
if (tmp_peer.service_node) {
|
||||
// It's a service node so we should have a peer_info entry; see if we can find one with
|
||||
// the same route, and if not, add one.
|
||||
auto pr = peers.equal_range(tmp_peer.pubkey);
|
||||
for (auto it = pr.first; it != pr.second; ++it) {
|
||||
if (it->second.route == tmp_peer.route) {
|
||||
if (it->second.conn_index == tmp_peer.conn_index && it->second.route == tmp_peer.route) {
|
||||
peer = &it->second;
|
||||
// Upgrade permissions in case we have something higher on the socket
|
||||
peer->service_node |= tmp_peer.service_node;
|
||||
if (tmp_peer.auth_level > peer->auth_level)
|
||||
peer->auth_level = tmp_peer.auth_level;
|
||||
// Update the stored auth level just in case the peer reconnected
|
||||
peer->auth_level = tmp_peer.auth_level;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!peer) {
|
||||
// We don't have a record: this is either a new SN connection or a new message on a
|
||||
// connection that recently gained SN status.
|
||||
peer = &peers.emplace(ConnectionID{tmp_peer.pubkey}, std::move(tmp_peer))->second;
|
||||
}
|
||||
} else {
|
||||
|
@ -228,23 +230,6 @@ void LokiMQ::proxy_to_worker(size_t conn_index, std::vector<zmq::message_t>& par
|
|||
|
||||
size_t command_part_index = outgoing ? 0 : 1;
|
||||
std::string command = parts[command_part_index].to_string();
|
||||
auto cat_call = get_command(command);
|
||||
|
||||
if (!cat_call.first) {
|
||||
LMQ_LOG(warn, "Invalid command '", command, "' sent by remote [", to_hex(peer->pubkey), "]/", peer_address(parts.back()));
|
||||
try {
|
||||
if (outgoing)
|
||||
send_direct_message(connections[conn_index], "UNKNOWNCOMMAND");
|
||||
else
|
||||
send_routed_message(connections[conn_index], peer->route, "UNKNOWNCOMMAND");
|
||||
} catch (const zmq::error_t&) { /* can't send: possibly already disconnected. Ignore. */ }
|
||||
return;
|
||||
}
|
||||
|
||||
auto& category = *cat_call.first;
|
||||
|
||||
if (!proxy_check_auth(conn_index, outgoing, *peer, command, category, parts.back()))
|
||||
return;
|
||||
|
||||
// Steal any data message parts
|
||||
size_t data_part_index = command_part_index + 1;
|
||||
|
@ -253,6 +238,14 @@ void LokiMQ::proxy_to_worker(size_t conn_index, std::vector<zmq::message_t>& par
|
|||
for (auto it = parts.begin() + data_part_index; it != parts.end(); ++it)
|
||||
data_parts.push_back(std::move(*it));
|
||||
|
||||
auto cat_call = get_command(command);
|
||||
|
||||
// Check that command is valid, that we have permission, etc.
|
||||
if (!proxy_check_auth(conn_index, outgoing, *peer, parts[command_part_index], cat_call, data_parts))
|
||||
return;
|
||||
|
||||
auto& category = *cat_call.first;
|
||||
|
||||
if (category.active_threads >= category.reserved_threads && active_workers() >= general_workers) {
|
||||
// No free reserved or general spots, try to queue it for later
|
||||
if (category.max_queue >= 0 && category.queued >= category.max_queue) {
|
||||
|
|
|
@ -6,6 +6,7 @@ set(LMQ_TEST_SRC
|
|||
test_batch.cpp
|
||||
test_connect.cpp
|
||||
test_commands.cpp
|
||||
test_failures.cpp
|
||||
test_requests.cpp
|
||||
test_string_view.cpp
|
||||
)
|
||||
|
|
|
@ -6,6 +6,25 @@ using namespace lokimq;
|
|||
|
||||
static auto startup = std::chrono::steady_clock::now();
|
||||
|
||||
/// Waits up to 100ms for something to happen.
|
||||
template <typename Func>
|
||||
inline void wait_for(Func f) {
|
||||
for (int i = 0; i < 10; i++) {
|
||||
if (f())
|
||||
break;
|
||||
std::this_thread::sleep_for(10ms);
|
||||
}
|
||||
}
|
||||
|
||||
/// Waits on an atomic bool for up to 100ms for an initial connection, which is more than enough
|
||||
/// time for an initial connection + request.
|
||||
inline void wait_for_conn(std::atomic<bool> &c) {
|
||||
wait_for([&c] { return c.load(); });
|
||||
}
|
||||
|
||||
/// Waits enough time for us to receive a reply from a localhost remote.
|
||||
inline void reply_sleep() { std::this_thread::sleep_for(10ms); }
|
||||
|
||||
// Catch2 macros aren't thread safe, so guard with a mutex
|
||||
inline std::unique_lock<std::mutex> catch_lock() {
|
||||
static std::mutex mutex;
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
#include "common.h"
|
||||
#include <future>
|
||||
#include <lokimq/hex.h>
|
||||
#include <map>
|
||||
#include <set>
|
||||
|
@ -12,10 +11,10 @@ TEST_CASE("basic commands", "[commands]") {
|
|||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
get_logger("S» ")
|
||||
get_logger("S» "),
|
||||
LogLevel::trace
|
||||
};
|
||||
server.log_level(LogLevel::trace);
|
||||
server.listen_curve(listen, [](auto /*ip*/, auto /*pk*/) { return Allow{AuthLevel::none, false}; });
|
||||
server.listen_curve(listen);
|
||||
|
||||
std::atomic<int> hellos{0}, his{0};
|
||||
|
||||
|
@ -32,41 +31,33 @@ TEST_CASE("basic commands", "[commands]") {
|
|||
|
||||
server.start();
|
||||
|
||||
LokiMQ client{
|
||||
get_logger("C» ")
|
||||
};
|
||||
client.log_level(LogLevel::trace);
|
||||
LokiMQ client{get_logger("C» "), LogLevel::trace};
|
||||
|
||||
client.add_category("public", Access{AuthLevel::none});
|
||||
client.add_command("public", "hi", [&](auto&) { his++; });
|
||||
client.start();
|
||||
|
||||
std::atomic<bool> connected{false}, failed{false};
|
||||
std::atomic<bool> got{false};
|
||||
bool success = false, failed = false;
|
||||
std::string pubkey;
|
||||
|
||||
auto c = client.connect_remote(listen,
|
||||
[&](auto conn) { pubkey = conn.pubkey(); connected = true; },
|
||||
[&](auto conn, string_view) { failed = true; },
|
||||
[&](auto conn) { pubkey = conn.pubkey(); success = true; got = true; },
|
||||
[&](auto conn, string_view) { failed = true; got = true; },
|
||||
server.get_pubkey());
|
||||
|
||||
int i;
|
||||
for (i = 0; i < 5; i++) {
|
||||
if (connected.load())
|
||||
break;
|
||||
std::this_thread::sleep_for(50ms);
|
||||
}
|
||||
wait_for_conn(got);
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( connected.load() );
|
||||
REQUIRE( i <= 1 ); // should be fast
|
||||
REQUIRE( !failed.load() );
|
||||
REQUIRE( got );
|
||||
REQUIRE( success );
|
||||
REQUIRE_FALSE( failed );
|
||||
REQUIRE( to_hex(pubkey) == to_hex(server.get_pubkey()) );
|
||||
}
|
||||
|
||||
client.send(c, "public.hello");
|
||||
client.send(c, "public.client.pubkey");
|
||||
|
||||
std::this_thread::sleep_for(50ms);
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( hellos == 1 );
|
||||
|
@ -77,7 +68,7 @@ TEST_CASE("basic commands", "[commands]") {
|
|||
for (int i = 0; i < 50; i++)
|
||||
client.send(c, "public.hello");
|
||||
|
||||
std::this_thread::sleep_for(100ms);
|
||||
wait_for([&] { return his == 26; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( hellos == 51 );
|
||||
|
@ -91,10 +82,10 @@ TEST_CASE("outgoing auth level", "[commands][auth]") {
|
|||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
get_logger("S» ")
|
||||
get_logger("S» "),
|
||||
LogLevel::trace
|
||||
};
|
||||
server.log_level(LogLevel::trace);
|
||||
server.listen_curve(listen, [](auto /*ip*/, auto /*pk*/) { return Allow{AuthLevel::none, false}; });
|
||||
server.listen_curve(listen);
|
||||
|
||||
std::atomic<int> hellos{0};
|
||||
|
||||
|
@ -103,10 +94,7 @@ TEST_CASE("outgoing auth level", "[commands][auth]") {
|
|||
|
||||
server.start();
|
||||
|
||||
LokiMQ client{
|
||||
get_logger("C» ")
|
||||
};
|
||||
client.log_level(LogLevel::trace);
|
||||
LokiMQ client{get_logger("C» "), LogLevel::trace};
|
||||
|
||||
std::atomic<int> public_hi{0}, basic_hi{0}, admin_hi{0};
|
||||
client.add_category("public", Access{AuthLevel::none});
|
||||
|
@ -124,8 +112,7 @@ TEST_CASE("outgoing auth level", "[commands][auth]") {
|
|||
auto admin_c = client.connect_remote(listen, [](...) {}, [](...) {}, server.get_pubkey(), AuthLevel::admin);
|
||||
|
||||
client.send(public_c, "public.reflect", "public.hi");
|
||||
std::this_thread::sleep_for(50ms);
|
||||
|
||||
wait_for([&] { return public_hi == 1; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( public_hi == 1 );
|
||||
|
@ -138,7 +125,7 @@ TEST_CASE("outgoing auth level", "[commands][auth]") {
|
|||
client.send(admin_c, "public.reflect", "admin.hi");
|
||||
client.send(basic_c, "public.reflect", "basic.hi");
|
||||
|
||||
std::this_thread::sleep_for(50ms);
|
||||
wait_for([&] { return public_hi == 2; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( admin_hi == 3 );
|
||||
|
@ -160,7 +147,7 @@ TEST_CASE("outgoing auth level", "[commands][auth]") {
|
|||
client.send(admin_c, "public.reflect", "basic.hi");
|
||||
client.send(admin_c, "public.reflect", "public.hi");
|
||||
|
||||
std::this_thread::sleep_for(50ms);
|
||||
wait_for([&] { return public_hi == 3; });
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( admin_hi == 1 );
|
||||
REQUIRE( basic_hi == 2 );
|
||||
|
@ -176,10 +163,10 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]")
|
|||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
get_logger("S» ")
|
||||
get_logger("S» "),
|
||||
LogLevel::trace
|
||||
};
|
||||
server.log_level(LogLevel::trace);
|
||||
server.listen_curve(listen, [](auto /*ip*/, auto /*pk*/) { return Allow{AuthLevel::none, false}; });
|
||||
server.listen_curve(listen);
|
||||
|
||||
std::vector<std::pair<ConnectionID, std::string>> subscribers;
|
||||
ConnectionID backdoor;
|
||||
|
@ -221,8 +208,7 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]")
|
|||
auto nsa_c = nsa.connect_remote(listen, connect_success, connect_failure, server.get_pubkey(), AuthLevel::admin);
|
||||
nsa.send(nsa_c, "hey google.install backdoor");
|
||||
|
||||
std::this_thread::sleep_for(50ms);
|
||||
|
||||
wait_for([&] { auto lock = catch_lock(); return (bool) backdoor; });
|
||||
{
|
||||
auto l = catch_lock();
|
||||
REQUIRE( backdoor );
|
||||
|
@ -243,9 +229,10 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]")
|
|||
std::map<int, std::set<std::string>> google_knows;
|
||||
int things_remembered{0};
|
||||
for (int i = 0; i < 5; i++) {
|
||||
clients.push_back(std::make_unique<LokiMQ>(get_logger("C" + std::to_string(i) + "» ")));
|
||||
clients.push_back(std::make_unique<LokiMQ>(
|
||||
get_logger("C" + std::to_string(i) + "» "), LogLevel::trace
|
||||
));
|
||||
auto& c = clients.back();
|
||||
c->log_level(LogLevel::trace);
|
||||
c->add_category("personal", Access{AuthLevel::basic});
|
||||
c->add_command("personal", "detail", [&,i](Message& m) {
|
||||
auto l = catch_lock();
|
||||
|
@ -265,7 +252,7 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]")
|
|||
},
|
||||
personal_detail);
|
||||
}
|
||||
std::this_thread::sleep_for(50ms);
|
||||
wait_for([&] { auto lock = catch_lock(); return things_remembered == all_the_things.size(); });
|
||||
{
|
||||
auto l = catch_lock();
|
||||
REQUIRE( things_remembered == all_the_things.size() );
|
||||
|
@ -273,7 +260,7 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]")
|
|||
}
|
||||
|
||||
clients[0]->send(conns[0], "hey google.recall");
|
||||
std::this_thread::sleep_for(50ms);
|
||||
reply_sleep();
|
||||
{
|
||||
auto l = catch_lock();
|
||||
REQUIRE( google_knows == personal_details );
|
||||
|
@ -286,10 +273,10 @@ TEST_CASE("send failure callbacks", "[commands][queue_full]") {
|
|||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
get_logger("S» ")
|
||||
get_logger("S» "),
|
||||
LogLevel::debug // This test traces so much that it takes 2.5-3s of CPU time at trace level, so don't do that.
|
||||
};
|
||||
server.log_level(LogLevel::debug);
|
||||
server.listen_plain(listen, [](auto /*ip*/, auto /*pk*/) { return Allow{AuthLevel::none, false}; });
|
||||
server.listen_plain(listen);
|
||||
|
||||
std::atomic<int> send_attempts{0};
|
||||
std::atomic<int> send_failures{0};
|
||||
|
@ -331,7 +318,7 @@ TEST_CASE("send failure callbacks", "[commands][queue_full]") {
|
|||
for (i = 0; i < 20; i++) {
|
||||
if (send_attempts.load() >= 500)
|
||||
break;
|
||||
std::this_thread::sleep_for(25ms);
|
||||
std::this_thread::sleep_for(10ms);
|
||||
}
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
|
@ -354,20 +341,20 @@ TEST_CASE("send failure callbacks", "[commands][queue_full]") {
|
|||
client.send(zmq::message_t{"x.x", 3}, zmq::send_flags::none);
|
||||
expected_attempts += 500;
|
||||
if (i >= 4) {
|
||||
std::this_thread::sleep_for(25ms);
|
||||
if (send_failures.load() > 0)
|
||||
break;
|
||||
std::this_thread::sleep_for(25ms);
|
||||
}
|
||||
}
|
||||
|
||||
for (i = 0; i < 10; i++) {
|
||||
for (i = 0; i < 20; i++) {
|
||||
if (send_attempts.load() >= expected_attempts)
|
||||
break;
|
||||
std::this_thread::sleep_for(25ms);
|
||||
std::this_thread::sleep_for(10ms);
|
||||
}
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( i <= 8 );
|
||||
REQUIRE( i < 20 );
|
||||
REQUIRE( send_attempts.load() == expected_attempts );
|
||||
REQUIRE( send_failures.load() > 0 );
|
||||
}
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
#include "common.h"
|
||||
#include <future>
|
||||
#include <lokimq/hex.h>
|
||||
extern "C" {
|
||||
#include <sodium.h>
|
||||
|
@ -12,46 +11,42 @@ TEST_CASE("connections with curve authentication", "[curve][connect]") {
|
|||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
get_logger("S» ")
|
||||
get_logger("S» "),
|
||||
LogLevel::trace
|
||||
};
|
||||
server.log_level(LogLevel::trace);
|
||||
|
||||
server.listen_curve(listen, [](auto /*ip*/, auto /*pk*/) { return Allow{AuthLevel::none, false}; });
|
||||
server.listen_curve(listen);
|
||||
server.add_category("public", Access{AuthLevel::none});
|
||||
server.add_request_command("public", "hello", [&](Message& m) { m.send_reply("hi"); });
|
||||
server.start();
|
||||
|
||||
LokiMQ client{get_logger("C» ")};
|
||||
client.log_level(LogLevel::trace);
|
||||
LokiMQ client{get_logger("C» "), LogLevel::trace};
|
||||
|
||||
client.start();
|
||||
|
||||
auto pubkey = server.get_pubkey();
|
||||
std::atomic<int> connected{0};
|
||||
std::atomic<bool> got{false};
|
||||
bool success = false;
|
||||
auto server_conn = client.connect_remote(listen,
|
||||
[&](auto conn) { connected = 1; },
|
||||
[&](auto conn, string_view reason) { auto lock = catch_lock(); INFO("connection failed: " << reason); },
|
||||
[&](auto conn) { success = true; got = true; },
|
||||
[&](auto conn, string_view reason) { auto lock = catch_lock(); INFO("connection failed: " << reason); got = true; },
|
||||
pubkey);
|
||||
|
||||
int i;
|
||||
for (i = 0; i < 5; i++) {
|
||||
if (connected.load())
|
||||
break;
|
||||
std::this_thread::sleep_for(50ms);
|
||||
}
|
||||
wait_for_conn(got);
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( i <= 1 );
|
||||
REQUIRE( connected.load() );
|
||||
REQUIRE( got );
|
||||
REQUIRE( success );
|
||||
}
|
||||
|
||||
bool success = false;
|
||||
success = false;
|
||||
std::vector<std::string> parts;
|
||||
client.request(server_conn, "public.hello", [&](auto success_, auto parts_) { success = success_; parts = parts_; });
|
||||
std::this_thread::sleep_for(50ms);
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( success );
|
||||
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( success );
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("self-connection SN optimization", "[connect][self]") {
|
||||
|
@ -63,10 +58,16 @@ TEST_CASE("self-connection SN optimization", "[connect][self]") {
|
|||
pubkey, privkey,
|
||||
true,
|
||||
[&](auto pk) { if (pk == pubkey) return "tcp://127.0.0.1:5544"; else return ""; },
|
||||
get_logger("S» ")
|
||||
get_logger("S» "),
|
||||
LogLevel::trace
|
||||
};
|
||||
|
||||
sn.listen_curve("tcp://127.0.0.1:5544", [&](auto ip, auto pk) { REQUIRE(ip == "127.0.0.1"); return Allow{AuthLevel::none, pk == pubkey}; });
|
||||
sn.listen_curve("tcp://127.0.0.1:5544", [&](auto ip, auto pk, auto sn) {
|
||||
auto lock = catch_lock();
|
||||
REQUIRE(ip == "127.0.0.1");
|
||||
REQUIRE(sn == (pk == pubkey));
|
||||
return AuthLevel::none;
|
||||
});
|
||||
sn.add_category("a", Access{AuthLevel::none});
|
||||
bool invoked = false;
|
||||
sn.add_command("a", "b", [&](const Message& m) {
|
||||
|
@ -77,57 +78,54 @@ TEST_CASE("self-connection SN optimization", "[connect][self]") {
|
|||
REQUIRE(!m.data.empty());
|
||||
REQUIRE(m.data[0] == "my data");
|
||||
});
|
||||
sn.log_level(LogLevel::trace);
|
||||
sn.set_active_sns({{pubkey}});
|
||||
|
||||
sn.start();
|
||||
std::this_thread::sleep_for(50ms);
|
||||
sn.send(pubkey, "a.b", "my data");
|
||||
std::this_thread::sleep_for(50ms);
|
||||
auto lock = catch_lock();
|
||||
REQUIRE(invoked);
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE(invoked);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("plain-text connections", "[plaintext][connect]") {
|
||||
std::string listen = "tcp://127.0.0.1:4455";
|
||||
LokiMQ server{get_logger("S» ")};
|
||||
server.log_level(LogLevel::trace);
|
||||
LokiMQ server{get_logger("S» "), LogLevel::trace};
|
||||
|
||||
server.add_category("public", Access{AuthLevel::none});
|
||||
server.add_request_command("public", "hello", [&](Message& m) { m.send_reply("hi"); });
|
||||
|
||||
server.listen_plain(listen, [](auto /*ip*/, auto /*pk*/) { return Allow{AuthLevel::none, false}; });
|
||||
server.listen_plain(listen);
|
||||
|
||||
server.start();
|
||||
|
||||
LokiMQ client{get_logger("C» ")};
|
||||
client.log_level(LogLevel::trace);
|
||||
LokiMQ client{get_logger("C» "), LogLevel::trace};
|
||||
|
||||
client.start();
|
||||
|
||||
std::atomic<int> connected{0};
|
||||
std::atomic<bool> got{false};
|
||||
bool success = false;
|
||||
auto c = client.connect_remote(listen,
|
||||
[&](auto conn) { connected = 1; },
|
||||
[&](auto conn, string_view reason) { auto lock = catch_lock(); INFO("connection failed: " << reason); }
|
||||
[&](auto conn) { success = true; got = true; },
|
||||
[&](auto conn, string_view reason) { auto lock = catch_lock(); INFO("connection failed: " << reason); got = true; }
|
||||
);
|
||||
|
||||
int i;
|
||||
for (i = 0; i < 5; i++) {
|
||||
if (connected.load())
|
||||
break;
|
||||
std::this_thread::sleep_for(50ms);
|
||||
}
|
||||
wait_for_conn(got);
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( i <= 1 );
|
||||
REQUIRE( connected.load() );
|
||||
REQUIRE( got );
|
||||
REQUIRE( success );
|
||||
}
|
||||
|
||||
bool success = false;
|
||||
success = false;
|
||||
std::vector<std::string> parts;
|
||||
client.request(c, "public.hello", [&](auto success_, auto parts_) { success = success_; parts = parts_; });
|
||||
std::this_thread::sleep_for(50ms);
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( success );
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( success );
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("SN disconnections", "[connect][disconnect]") {
|
||||
|
@ -147,28 +145,146 @@ TEST_CASE("SN disconnections", "[connect][disconnect]") {
|
|||
lmq.push_back(std::make_unique<LokiMQ>(
|
||||
pubkey[i], privkey[i], true,
|
||||
[conn](auto pk) { auto it = conn.find((std::string) pk); if (it != conn.end()) return it->second; return ""s; },
|
||||
get_logger("S" + std::to_string(i) + "» ")
|
||||
get_logger("S" + std::to_string(i) + "» "),
|
||||
LogLevel::trace
|
||||
));
|
||||
auto& server = *lmq.back();
|
||||
server.log_level(LogLevel::debug);
|
||||
|
||||
server.listen_curve(conn[pubkey[i]], [](auto /*ip*/, auto /*pk*/) { return Allow{AuthLevel::none, true}; });
|
||||
server.listen_curve(conn[pubkey[i]]);
|
||||
server.add_category("sn", Access{AuthLevel::none, true})
|
||||
.add_command("hi", [&](Message& m) { his++; });
|
||||
server.set_active_sns({pubkey.begin(), pubkey.end()});
|
||||
server.start();
|
||||
}
|
||||
std::this_thread::sleep_for(50ms);
|
||||
|
||||
lmq[0]->send(pubkey[1], "sn.hi");
|
||||
lmq[0]->send(pubkey[2], "sn.hi");
|
||||
std::this_thread::sleep_for(50ms);
|
||||
lmq[2]->send(pubkey[0], "sn.hi");
|
||||
lmq[2]->send(pubkey[1], "sn.hi");
|
||||
lmq[1]->send(pubkey[0], "BYE");
|
||||
std::this_thread::sleep_for(50ms);
|
||||
lmq[0]->send(pubkey[2], "sn.hi");
|
||||
std::this_thread::sleep_for(50ms);
|
||||
|
||||
auto lock = catch_lock();
|
||||
REQUIRE(his == 5);
|
||||
}
|
||||
|
||||
TEST_CASE("SN auth checks", "[sandwich][auth]") {
|
||||
// When a remote connects, we check its authentication level; if at the time of connection it
|
||||
// isn't recognized as a SN but tries to invoke a SN command it'll be told to disconnect; if it
|
||||
// tries to send again it should reconnect and reauthenticate. This test is meant to test this
|
||||
// pattern where the reconnection/reauthentication now authenticates it as a SN.
|
||||
std::string listen = "tcp://127.0.0.1:4455";
|
||||
std::string pubkey, privkey;
|
||||
pubkey.resize(crypto_box_PUBLICKEYBYTES);
|
||||
privkey.resize(crypto_box_SECRETKEYBYTES);
|
||||
crypto_box_keypair(reinterpret_cast<unsigned char*>(&pubkey[0]), reinterpret_cast<unsigned char*>(&privkey[0]));
|
||||
LokiMQ server{
|
||||
pubkey, privkey,
|
||||
true, // service node
|
||||
[](auto) { return ""; },
|
||||
get_logger("A» "),
|
||||
LogLevel::trace
|
||||
};
|
||||
|
||||
std::atomic<bool> incoming_is_sn{false};
|
||||
server.listen_curve(listen);
|
||||
server.add_category("public", Access{AuthLevel::none})
|
||||
.add_request_command("hello", [&](Message& m) { m.send_reply("hi"); })
|
||||
.add_request_command("sudo", [&](Message& m) {
|
||||
server.update_active_sns({{m.conn.pubkey()}}, {});
|
||||
m.send_reply("making sandwiches");
|
||||
})
|
||||
.add_request_command("nosudo", [&](Message& m) {
|
||||
// Send the reply *first* because if we do it the other way we'll have just removed
|
||||
// ourselves from the list of SNs and thus would try to open an outbound connection
|
||||
// to deliver it since it's still queued as a message to a SN.
|
||||
m.send_reply("make them yourself");
|
||||
server.update_active_sns({}, {{m.conn.pubkey()}});
|
||||
});
|
||||
server.add_category("sandwich", Access{AuthLevel::none, true})
|
||||
.add_request_command("make", [&](Message& m) { m.send_reply("okay"); });
|
||||
server.start();
|
||||
|
||||
LokiMQ client{
|
||||
"", "", false,
|
||||
[&](auto remote_pk) { if (remote_pk == pubkey) return listen; return ""s; },
|
||||
get_logger("B» "), LogLevel::trace};
|
||||
client.start();
|
||||
|
||||
std::atomic<bool> got{false};
|
||||
bool success;
|
||||
client.request(pubkey, "public.hello", [&](auto success_, auto) { success = success_; got = true; });
|
||||
wait_for_conn(got);
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( got );
|
||||
REQUIRE( success );
|
||||
}
|
||||
|
||||
got = false;
|
||||
using dvec = std::vector<std::string>;
|
||||
dvec data;
|
||||
client.request(pubkey, "sandwich.make", [&](auto success_, auto data_) {
|
||||
success = success_;
|
||||
data = std::move(data_);
|
||||
got = true;
|
||||
});
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( got );
|
||||
REQUIRE_FALSE( success );
|
||||
REQUIRE( data == dvec{{"FORBIDDEN_SN"}} );
|
||||
}
|
||||
|
||||
// Somebody set up us the bomb. Main sudo turn on.
|
||||
got = false;
|
||||
client.request(pubkey, "public.sudo", [&](auto success_, auto data_) { success = success_; data = data_; got = true; });
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( got );
|
||||
REQUIRE( success );
|
||||
REQUIRE( data == dvec{{"making sandwiches"}} );
|
||||
}
|
||||
|
||||
got = false;
|
||||
client.request(pubkey, "sandwich.make", [&](auto success_, auto data_) {
|
||||
success = success_;
|
||||
data = std::move(data_);
|
||||
got = true;
|
||||
});
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( got );
|
||||
REQUIRE( success );
|
||||
REQUIRE( data == dvec{{"okay"}} );
|
||||
}
|
||||
|
||||
// Take off every 'SUDO', You [not] know what you doing
|
||||
got = false;
|
||||
client.request(pubkey, "public.nosudo", [&](auto success_, auto data_) { success = success_; data = data_; got = true; });
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( got );
|
||||
REQUIRE( success );
|
||||
REQUIRE( data == dvec{{"make them yourself"}} );
|
||||
}
|
||||
|
||||
got = false;
|
||||
client.request(pubkey, "sandwich.make", [&](auto success_, auto data_) {
|
||||
success = success_;
|
||||
data = std::move(data_);
|
||||
got = true;
|
||||
});
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( got );
|
||||
REQUIRE_FALSE( success );
|
||||
REQUIRE( data == dvec{{"FORBIDDEN_SN"}} );
|
||||
}
|
||||
}
|
||||
|
|
309
tests/test_failures.cpp
Normal file
309
tests/test_failures.cpp
Normal file
|
@ -0,0 +1,309 @@
|
|||
#include "common.h"
|
||||
#include <lokimq/hex.h>
|
||||
#include <map>
|
||||
#include <set>
|
||||
|
||||
using namespace lokimq;
|
||||
|
||||
TEST_CASE("failure responses - UNKNOWNCOMMAND", "[failure][UNKNOWNCOMMAND]") {
|
||||
std::string listen = "tcp://127.0.0.1:4567";
|
||||
LokiMQ server{
|
||||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
get_logger("S» "),
|
||||
LogLevel::trace
|
||||
};
|
||||
server.listen_plain(listen);
|
||||
server.start();
|
||||
|
||||
// Use a raw socket here because I want to see the raw commands coming on the wire
|
||||
zmq::context_t client_ctx;
|
||||
zmq::socket_t client{client_ctx, zmq::socket_type::dealer};
|
||||
client.connect(listen);
|
||||
// Handshake: we send HI, they reply HELLO.
|
||||
client.send(zmq::message_t{"HI", 2}, zmq::send_flags::none);
|
||||
{
|
||||
zmq::message_t hello;
|
||||
client.recv(hello);
|
||||
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( hello.to_string() == "HELLO" );
|
||||
REQUIRE_FALSE( hello.more() );
|
||||
}
|
||||
|
||||
client.send(zmq::message_t{"a.a", 3}, zmq::send_flags::none);
|
||||
zmq::message_t resp;
|
||||
client.recv(resp);
|
||||
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( resp.to_string() == "UNKNOWNCOMMAND" );
|
||||
REQUIRE( resp.more() );
|
||||
client.recv(resp);
|
||||
REQUIRE( resp.to_string() == "a.a" );
|
||||
REQUIRE_FALSE( resp.more() );
|
||||
}
|
||||
|
||||
TEST_CASE("failure responses - NO_REPLY_TAG", "[failure][NO_REPLY_TAG]") {
|
||||
std::string listen = "tcp://127.0.0.1:4567";
|
||||
LokiMQ server{
|
||||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
get_logger("S» "),
|
||||
LogLevel::trace
|
||||
};
|
||||
server.listen_plain(listen);
|
||||
server.add_category("x", AuthLevel::none)
|
||||
.add_request_command("r", [] (auto& m) { m.send_reply("a"); });
|
||||
server.start();
|
||||
|
||||
// Use a raw socket here because I want to see the raw commands coming on the wire
|
||||
zmq::context_t client_ctx;
|
||||
zmq::socket_t client{client_ctx, zmq::socket_type::dealer};
|
||||
client.connect(listen);
|
||||
// Handshake: we send HI, they reply HELLO.
|
||||
client.send(zmq::message_t{"HI", 2}, zmq::send_flags::none);
|
||||
{
|
||||
zmq::message_t hello;
|
||||
client.recv(hello);
|
||||
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( hello.to_string() == "HELLO" );
|
||||
REQUIRE_FALSE( hello.more() );
|
||||
}
|
||||
|
||||
client.send(zmq::message_t{"x.r", 3}, zmq::send_flags::none);
|
||||
zmq::message_t resp;
|
||||
client.recv(resp);
|
||||
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( resp.to_string() == "NO_REPLY_TAG" );
|
||||
REQUIRE( resp.more() );
|
||||
client.recv(resp);
|
||||
REQUIRE( resp.to_string() == "x.r" );
|
||||
REQUIRE_FALSE( resp.more() );
|
||||
}
|
||||
|
||||
client.send(zmq::message_t{"x.r", 3}, zmq::send_flags::sndmore);
|
||||
client.send(zmq::message_t{"foo", 3}, zmq::send_flags::none);
|
||||
client.recv(resp);
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( resp.to_string() == "REPLY" );
|
||||
REQUIRE( resp.more() );
|
||||
client.recv(resp);
|
||||
REQUIRE( resp.to_string() == "foo" );
|
||||
REQUIRE( resp.more() );
|
||||
client.recv(resp);
|
||||
REQUIRE( resp.to_string() == "a" );
|
||||
REQUIRE_FALSE( resp.more() );
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("failure responses - FORBIDDEN", "[failure][FORBIDDEN]") {
|
||||
std::string listen = "tcp://127.0.0.1:4567";
|
||||
LokiMQ server{
|
||||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
get_logger("S» "),
|
||||
LogLevel::trace
|
||||
};
|
||||
server.listen_plain(listen, [](auto, auto, auto) {
|
||||
static int count = 0;
|
||||
++count;
|
||||
return count == 1 ? AuthLevel::none : count == 2 ? AuthLevel::basic : AuthLevel::admin;
|
||||
});
|
||||
server.add_category("x", AuthLevel::basic)
|
||||
.add_command("x", [] (auto& m) { m.send_back("a"); });
|
||||
server.add_category("y", AuthLevel::admin)
|
||||
.add_command("x", [] (auto& m) { m.send_back("b"); });
|
||||
server.start();
|
||||
|
||||
zmq::context_t client_ctx;
|
||||
std::array<zmq::socket_t, 3> clients;
|
||||
// Client 0 should get none auth level, client 1 should get basic, client 2 should get admin
|
||||
for (auto& client : clients) {
|
||||
client = {client_ctx, zmq::socket_type::dealer};
|
||||
client.connect(listen);
|
||||
// Handshake: we send HI, they reply HELLO.
|
||||
client.send(zmq::message_t{"HI", 2}, zmq::send_flags::none);
|
||||
{
|
||||
zmq::message_t hello;
|
||||
client.recv(hello);
|
||||
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( hello.to_string() == "HELLO" );
|
||||
REQUIRE_FALSE( hello.more() );
|
||||
}
|
||||
}
|
||||
|
||||
for (auto& c : clients)
|
||||
c.send(zmq::message_t{"x.x", 3}, zmq::send_flags::none);
|
||||
|
||||
zmq::message_t resp;
|
||||
clients[0].recv(resp);
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( resp.to_string() == "FORBIDDEN" );
|
||||
REQUIRE( resp.more() );
|
||||
clients[0].recv(resp);
|
||||
REQUIRE( resp.to_string() == "x.x" );
|
||||
REQUIRE_FALSE( resp.more() );
|
||||
}
|
||||
for (int i : {1, 2}) {
|
||||
clients[i].recv(resp);
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( resp.to_string() == "a" );
|
||||
REQUIRE_FALSE( resp.more() );
|
||||
}
|
||||
|
||||
for (auto& c : clients)
|
||||
c.send(zmq::message_t{"y.x", 3}, zmq::send_flags::none);
|
||||
|
||||
for (int i : {0, 1}) {
|
||||
clients[i].recv(resp);
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( resp.to_string() == "FORBIDDEN" );
|
||||
REQUIRE( resp.more() );
|
||||
clients[i].recv(resp);
|
||||
REQUIRE( resp.to_string() == "y.x" );
|
||||
REQUIRE_FALSE( resp.more() );
|
||||
}
|
||||
clients[2].recv(resp);
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( resp.to_string() == "b" );
|
||||
REQUIRE_FALSE( resp.more() );
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("failure responses - NOT_A_SERVICE_NODE", "[failure][NOT_A_SERVICE_NODE]") {
|
||||
std::string listen = "tcp://127.0.0.1:4567";
|
||||
LokiMQ server{
|
||||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
get_logger("S» "),
|
||||
LogLevel::trace
|
||||
};
|
||||
server.listen_plain(listen, [](auto, auto, auto) {
|
||||
static int count = 0;
|
||||
++count;
|
||||
return count == 1 ? AuthLevel::none : count == 2 ? AuthLevel::basic : AuthLevel::admin;
|
||||
});
|
||||
server.add_category("x", Access{AuthLevel::none, false, true})
|
||||
.add_command("x", [] (auto&) {})
|
||||
.add_request_command("r", [] (auto& m) { m.send_reply(); })
|
||||
;
|
||||
server.start();
|
||||
|
||||
zmq::context_t client_ctx;
|
||||
zmq::socket_t client{client_ctx, zmq::socket_type::dealer};
|
||||
client.connect(listen);
|
||||
// Handshake: we send HI, they reply HELLO.
|
||||
client.send(zmq::message_t{"HI", 2}, zmq::send_flags::none);
|
||||
{
|
||||
zmq::message_t hello;
|
||||
client.recv(hello);
|
||||
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( hello.to_string() == "HELLO" );
|
||||
REQUIRE_FALSE( hello.more() );
|
||||
}
|
||||
|
||||
client.send(zmq::message_t{"x.x", 3}, zmq::send_flags::none);
|
||||
|
||||
zmq::message_t resp;
|
||||
client.recv(resp);
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( resp.to_string() == "NOT_A_SERVICE_NODE" );
|
||||
REQUIRE( resp.more() );
|
||||
client.recv(resp);
|
||||
REQUIRE( resp.to_string() == "x.x" );
|
||||
REQUIRE_FALSE( resp.more() );
|
||||
}
|
||||
|
||||
client.send(zmq::message_t{"x.r", 3}, zmq::send_flags::sndmore);
|
||||
client.send(zmq::message_t{"xyz123", 6}, zmq::send_flags::none); // reply tag
|
||||
|
||||
client.recv(resp);
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( resp.to_string() == "NOT_A_SERVICE_NODE" );
|
||||
REQUIRE( resp.more() );
|
||||
client.recv(resp);
|
||||
REQUIRE( resp.to_string() == "REPLY" );
|
||||
REQUIRE( resp.more() );
|
||||
client.recv(resp);
|
||||
REQUIRE( resp.to_string() == "xyz123" );
|
||||
REQUIRE_FALSE( resp.more() );
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("failure responses - FORBIDDEN_SN", "[failure][FORBIDDEN_SN]") {
|
||||
std::string listen = "tcp://127.0.0.1:4567";
|
||||
LokiMQ server{
|
||||
"", "", // generate ephemeral keys
|
||||
false, // not a service node
|
||||
[](auto) { return ""; },
|
||||
get_logger("S» "),
|
||||
LogLevel::trace
|
||||
};
|
||||
server.listen_plain(listen, [](auto, auto, auto) {
|
||||
static int count = 0;
|
||||
++count;
|
||||
return count == 1 ? AuthLevel::none : count == 2 ? AuthLevel::basic : AuthLevel::admin;
|
||||
});
|
||||
server.add_category("x", Access{AuthLevel::none, true, false})
|
||||
.add_command("x", [] (auto&) {})
|
||||
.add_request_command("r", [] (auto& m) { m.send_reply(); })
|
||||
;
|
||||
server.start();
|
||||
|
||||
zmq::context_t client_ctx;
|
||||
zmq::socket_t client{client_ctx, zmq::socket_type::dealer};
|
||||
client.connect(listen);
|
||||
// Handshake: we send HI, they reply HELLO.
|
||||
client.send(zmq::message_t{"HI", 2}, zmq::send_flags::none);
|
||||
{
|
||||
zmq::message_t hello;
|
||||
client.recv(hello);
|
||||
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( hello.to_string() == "HELLO" );
|
||||
REQUIRE_FALSE( hello.more() );
|
||||
}
|
||||
|
||||
client.send(zmq::message_t{"x.x", 3}, zmq::send_flags::none);
|
||||
|
||||
zmq::message_t resp;
|
||||
client.recv(resp);
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( resp.to_string() == "FORBIDDEN_SN" );
|
||||
REQUIRE( resp.more() );
|
||||
client.recv(resp);
|
||||
REQUIRE( resp.to_string() == "x.x" );
|
||||
REQUIRE_FALSE( resp.more() );
|
||||
}
|
||||
|
||||
client.send(zmq::message_t{"x.r", 3}, zmq::send_flags::sndmore);
|
||||
client.send(zmq::message_t{"xyz123", 6}, zmq::send_flags::none); // reply tag
|
||||
|
||||
client.recv(resp);
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( resp.to_string() == "FORBIDDEN_SN" );
|
||||
REQUIRE( resp.more() );
|
||||
client.recv(resp);
|
||||
REQUIRE( resp.to_string() == "REPLY" );
|
||||
REQUIRE( resp.more() );
|
||||
client.recv(resp);
|
||||
REQUIRE( resp.to_string() == "xyz123" );
|
||||
REQUIRE_FALSE( resp.more() );
|
||||
}
|
||||
}
|
|
@ -1,5 +1,4 @@
|
|||
#include "common.h"
|
||||
#include <future>
|
||||
#include <lokimq/hex.h>
|
||||
|
||||
using namespace lokimq;
|
||||
|
@ -36,17 +35,11 @@ TEST_CASE("basic requests", "[requests]") {
|
|||
[&](auto, auto) { failed = true; },
|
||||
server.get_pubkey());
|
||||
|
||||
int i;
|
||||
for (i = 0; i < 5; i++) {
|
||||
if (connected.load())
|
||||
break;
|
||||
std::this_thread::sleep_for(50ms);
|
||||
}
|
||||
wait_for([&] { return connected || failed; });
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( connected.load() );
|
||||
REQUIRE( !failed.load() );
|
||||
REQUIRE( i <= 1 );
|
||||
REQUIRE( connected );
|
||||
REQUIRE_FALSE( failed );
|
||||
REQUIRE( to_hex(pubkey) == to_hex(server.get_pubkey()) );
|
||||
}
|
||||
|
||||
|
@ -59,11 +52,13 @@ TEST_CASE("basic requests", "[requests]") {
|
|||
data = std::move(data_);
|
||||
});
|
||||
|
||||
std::this_thread::sleep_for(50ms);
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( got_reply.load() );
|
||||
REQUIRE( success );
|
||||
REQUIRE( data == std::vector<std::string>{{"123"}} );
|
||||
reply_sleep();
|
||||
{
|
||||
auto lock = catch_lock();
|
||||
REQUIRE( got_reply.load() );
|
||||
REQUIRE( success );
|
||||
REQUIRE( data == std::vector<std::string>{{"123"}} );
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("request from server to client", "[requests]") {
|
||||
|
@ -161,15 +156,10 @@ TEST_CASE("request timeouts", "[requests][timeout]") {
|
|||
[&](auto, auto) { failed = true; },
|
||||
server.get_pubkey());
|
||||
|
||||
int i;
|
||||
for (i = 0; i < 5; i++) {
|
||||
if (connected.load())
|
||||
break;
|
||||
std::this_thread::sleep_for(50ms);
|
||||
}
|
||||
REQUIRE( connected.load() );
|
||||
REQUIRE( !failed.load() );
|
||||
REQUIRE( i <= 1 );
|
||||
wait_for([&] { return connected || failed; });
|
||||
|
||||
REQUIRE( connected );
|
||||
REQUIRE_FALSE( failed );
|
||||
REQUIRE( to_hex(pubkey) == to_hex(server.get_pubkey()) );
|
||||
|
||||
std::atomic<bool> got_triggered{false};
|
||||
|
@ -180,7 +170,7 @@ TEST_CASE("request timeouts", "[requests][timeout]") {
|
|||
success = ok;
|
||||
data = std::move(data_);
|
||||
},
|
||||
lokimq::send_option::request_timeout{30ms}
|
||||
lokimq::send_option::request_timeout{20ms}
|
||||
);
|
||||
|
||||
std::atomic<bool> got_triggered2{false};
|
||||
|
@ -192,10 +182,10 @@ TEST_CASE("request timeouts", "[requests][timeout]") {
|
|||
lokimq::send_option::request_timeout{100ms}
|
||||
);
|
||||
|
||||
std::this_thread::sleep_for(50ms);
|
||||
REQUIRE( got_triggered.load() );
|
||||
std::this_thread::sleep_for(30ms);
|
||||
REQUIRE( got_triggered );
|
||||
REQUIRE_FALSE( got_triggered2 );
|
||||
REQUIRE_FALSE( success );
|
||||
REQUIRE( data.size() == 0 );
|
||||
REQUIRE( data == std::vector<std::string>{{"TIMEOUT"}} );
|
||||
|
||||
REQUIRE_FALSE( got_triggered2.load() );
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue