diff --git a/CMakeLists.txt b/CMakeLists.txt index 3f93265..611ed6e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -5,8 +5,8 @@ project(liblokimq CXX) include(GNUInstallDirs) set(LOKIMQ_VERSION_MAJOR 1) -set(LOKIMQ_VERSION_MINOR 0) -set(LOKIMQ_VERSION_PATCH 5) +set(LOKIMQ_VERSION_MINOR 1) +set(LOKIMQ_VERSION_PATCH 0) set(LOKIMQ_VERSION "${LOKIMQ_VERSION_MAJOR}.${LOKIMQ_VERSION_MINOR}.${LOKIMQ_VERSION_PATCH}") message(STATUS "lokimq v${LOKIMQ_VERSION}") diff --git a/lokimq/auth.cpp b/lokimq/auth.cpp index f4c2bed..1dba1c8 100644 --- a/lokimq/auth.cpp +++ b/lokimq/auth.cpp @@ -30,45 +30,166 @@ std::string zmtp_metadata(string_view key, string_view value) { bool LokiMQ::proxy_check_auth(size_t conn_index, bool outgoing, const peer_info& peer, - const std::string& command, const category& cat, zmq::message_t& msg) { + zmq::message_t& cmd, const cat_call_t& cat_call, std::vector& data) { + auto command = view(cmd); std::string reply; - if (peer.auth_level < cat.access.auth) { - LMQ_LOG(warn, "Access denied to ", command, " for peer [", to_hex(peer.pubkey), "]/", peer_address(msg), - ": peer auth level ", peer.auth_level, " < ", cat.access.auth); + + if (!cat_call.first) { + LMQ_LOG(warn, "Invalid command '", command, "' sent by remote [", to_hex(peer.pubkey), "]/", peer_address(cmd)); + reply = "UNKNOWNCOMMAND"; + } else if (peer.auth_level < cat_call.first->access.auth) { + LMQ_LOG(warn, "Access denied to ", command, " for peer [", to_hex(peer.pubkey), "]/", peer_address(cmd), + ": peer auth level ", peer.auth_level, " < ", cat_call.first->access.auth); reply = "FORBIDDEN"; - } - else if (cat.access.local_sn && !local_service_node) { - LMQ_LOG(warn, "Access denied to ", command, " for peer [", to_hex(peer.pubkey), "]/", peer_address(msg), + } else if (cat_call.first->access.local_sn && !local_service_node) { + LMQ_LOG(warn, "Access denied to ", command, " for peer [", to_hex(peer.pubkey), "]/", peer_address(cmd), ": that command is only available when this LokiMQ 
is running in service node mode"); reply = "NOT_A_SERVICE_NODE"; - } - else if (cat.access.remote_sn && !peer.service_node) { - LMQ_LOG(warn, "Access denied to ", command, " for peer [", to_hex(peer.pubkey), "]/", peer_address(msg), + } else if (cat_call.first->access.remote_sn && !peer.service_node) { + LMQ_LOG(warn, "Access denied to ", command, " for peer [", to_hex(peer.pubkey), "]/", peer_address(cmd), ": remote is not recognized as a service node"); - // Disconnect: we don't think the remote is a SN, but it issued a command only SNs should be - // issuing. Drop the connection; if the remote has something important to relay it will - // reconnect, at which point we will reassess the SN status on the new incoming connection. - if (outgoing) { - proxy_disconnect(peer.service_node ? ConnectionID{peer.pubkey} : conn_index_to_id[conn_index], 1s); - return false; - } - else - reply = "BYE"; + reply = "FORBIDDEN_SN"; + } else if (cat_call.second->second /*is_request*/ && data.empty()) { + LMQ_LOG(warn, "Received an invalid request for '", command, "' with no reply tag from remote [", + to_hex(peer.pubkey), "]/", peer_address(cmd)); + reply = "NO_REPLY_TAG"; + } else { + return true; } - if (reply.empty()) - return true; + std::vector msgs; + msgs.reserve(4); + if (!outgoing) + msgs.push_back(create_message(peer.route)); + msgs.push_back(create_message(reply)); + if (cat_call.second && cat_call.second->second /*request command*/ && !data.empty()) { + msgs.push_back(create_message("REPLY"_sv)); + msgs.push_back(create_message(view(data.front()))); // reply tag + } else { + msgs.push_back(create_message(view(cmd))); + } try { - if (outgoing) - send_direct_message(connections[conn_index], std::move(reply), command); - else - send_routed_message(connections[conn_index], peer.route, std::move(reply), command); - } catch (const zmq::error_t&) { /* can't send: possibly already disconnected. Ignore. 
*/ } + send_message_parts(connections[conn_index], msgs); + } catch (const zmq::error_t& err) { + /* can't send: possibly already disconnected. Ignore. */ + LMQ_LOG(debug, "Couldn't send auth failure message ", reply, " to peer [", to_hex(peer.pubkey), "]/", peer_address(cmd), ": ", err.what()); + } return false; } +void LokiMQ::set_active_sns(pubkey_set pubkeys) { + if (proxy_thread.joinable()) { + auto data = bt_serialize(detail::serialize_object(std::move(pubkeys))); + detail::send_control(get_control_socket(), "SET_SNS", data); + } else { + proxy_set_active_sns(std::move(pubkeys)); + } +} +void LokiMQ::proxy_set_active_sns(string_view data) { + proxy_set_active_sns(detail::deserialize_object(bt_deserialize(data))); +} +void LokiMQ::proxy_set_active_sns(pubkey_set pubkeys) { + pubkey_set added, removed; + for (auto it = pubkeys.begin(); it != pubkeys.end(); ) { + auto& pk = *it; + if (pk.size() != 32) { + LMQ_LOG(warn, "Invalid private key of length ", pk.size(), " (", to_hex(pk), ") passed to set_active_sns"); + it = pubkeys.erase(it); + continue; + } + if (!active_service_nodes.count(pk)) + added.insert(std::move(pk)); + ++it; + } + if (added.empty() && active_service_nodes.size() == pubkeys.size()) { + LMQ_LOG(debug, "set_active_sns(): new set of SNs is unchanged, skipping update"); + return; + } + for (const auto& pk : active_service_nodes) { + if (!pubkeys.count(pk)) + removed.insert(pk); + if (active_service_nodes.size() + added.size() - removed.size() == pubkeys.size()) + break; + } + proxy_update_active_sns_clean(std::move(added), std::move(removed)); +} + +void LokiMQ::update_active_sns(pubkey_set added, pubkey_set removed) { + LMQ_LOG(info, "uh, ", added.size()); + if (proxy_thread.joinable()) { + std::array data; + data[0] = detail::serialize_object(std::move(added)); + data[1] = detail::serialize_object(std::move(removed)); + detail::send_control(get_control_socket(), "UPDATE_SNS", bt_serialize(data)); + } else { + 
proxy_update_active_sns(std::move(added), std::move(removed)); + } +} +void LokiMQ::proxy_update_active_sns(bt_list_consumer data) { + auto added = detail::deserialize_object(data.consume_integer()); + auto remed = detail::deserialize_object(data.consume_integer()); + proxy_update_active_sns(std::move(added), std::move(remed)); +} +void LokiMQ::proxy_update_active_sns(pubkey_set added, pubkey_set removed) { + // We take a caller-provided set of added/removed then filter out any junk (bad pks, conflicting + // values, pubkeys that already (added) or do not (removed) exist), then pass the purified lists + // to the _clean version. + + LMQ_LOG(info, "uh, ", added.size(), ", ", removed.size()); + for (auto it = removed.begin(); it != removed.end(); ) { + const auto& pk = *it; + if (pk.size() != 32) { + LMQ_LOG(warn, "Invalid private key of length ", pk.size(), " (", to_hex(pk), ") passed to update_active_sns (removed)"); + it = removed.erase(it); + } else if (!active_service_nodes.count(pk) || added.count(pk) /* added wins if in both */) { + it = removed.erase(it); + } else { + ++it; + } + } + + for (auto it = added.begin(); it != added.end(); ) { + const auto& pk = *it; + if (pk.size() != 32) { + LMQ_LOG(warn, "Invalid private key of length ", pk.size(), " (", to_hex(pk), ") passed to update_active_sns (added)"); + it = added.erase(it); + } else if (active_service_nodes.count(pk)) { + it = added.erase(it); + } else { + ++it; + } + } + + proxy_update_active_sns_clean(std::move(added), std::move(removed)); +} + +void LokiMQ::proxy_update_active_sns_clean(pubkey_set added, pubkey_set removed) { + LMQ_LOG(debug, "Updating SN auth status with +", added.size(), "/-", removed.size(), " pubkeys"); + + // For anything we remove we want to close the connection to the SN (if outgoing), and remove the + // stored peer_info (incoming or outgoing). 
+ for (const auto& pk : removed) { + ConnectionID c{pk}; + active_service_nodes.erase(pk); + auto range = peers.equal_range(c); + for (auto it = range.first; it != range.second; ) { + bool outgoing = it->second.outgoing(); + size_t conn_index = it->second.conn_index; + it = peers.erase(it); + if (outgoing) { + LMQ_LOG(debug, "Closing outgoing connection to ", c); + proxy_close_connection(conn_index, CLOSE_LINGER); + } + } + } + + // For pubkeys we add there's nothing special to be done beyond adding them to the pubkey set + for (auto& pk : added) + active_service_nodes.insert(std::move(pk)); +} + void LokiMQ::process_zap_requests() { for (std::vector frames; recv_message_parts(zap_auth, frames, zmq::recv_flags::dontwait); frames.clear()) { #ifndef NDEBUG @@ -147,31 +268,32 @@ void LokiMQ::process_zap_requests() { } else { auto ip = view(frames[3]); string_view pubkey; - if (bind[bind_id].second.curve) + bool sn = false; + if (bind[bind_id].second.curve) { pubkey = view(frames[6]); - auto result = bind[bind_id].second.allow(ip, pubkey); - bool sn = result.remote_sn; + sn = active_service_nodes.count(std::string{pubkey}); + } + auto auth = bind[bind_id].second.allow(ip, pubkey, sn); auto& user_id = response_vals[4]; if (bind[bind_id].second.curve) { user_id.reserve(64); to_hex(pubkey.begin(), pubkey.end(), std::back_inserter(user_id)); } - if (result.auth <= AuthLevel::denied || result.auth > AuthLevel::admin) { + if (auth <= AuthLevel::denied || auth > AuthLevel::admin) { LMQ_LOG(info, "Access denied for incoming ", view(frames[5]), (sn ? " service node" : " client"), " connection from ", !user_id.empty() ? user_id + " at " : ""s, ip, - " with initial auth level ", result.auth); + " with initial auth level ", auth); status_code = "400"; status_text = "Access denied"; user_id.clear(); } else { LMQ_LOG(debug, "Accepted incoming ", view(frames[5]), (sn ? 
" service node" : " client"), - " connection with authentication level ", result.auth, + " connection with authentication level ", auth, " from ", !user_id.empty() ? user_id + " at " : ""s, ip); auto& metadata = response_vals[5]; - metadata += zmtp_metadata("X-SN", result.remote_sn ? "1" : "0"); - metadata += zmtp_metadata("X-AuthLevel", to_string(result.auth)); + metadata += zmtp_metadata("X-AuthLevel", to_string(auth)); status_code = "200"; status_text = ""; diff --git a/lokimq/auth.h b/lokimq/auth.h index f5ee196..c47a5b8 100644 --- a/lokimq/auth.h +++ b/lokimq/auth.h @@ -21,13 +21,12 @@ struct Access { bool remote_sn = false; /// If true the category requires that the local node is a SN bool local_sn = false; -}; -/// Return type of the AllowFunc: this determines whether we allow the connection at all, and if so, -/// sets the initial authentication level and tells LokiMQ whether the other end is an active SN. -struct Allow { - AuthLevel auth = AuthLevel::none; - bool remote_sn = false; + /// Constructor. Intentionally allows implicit conversion from an AuthLevel so that an + /// AuthLevel can be passed anywhere an Access is required (the resulting Access will have both + /// remote and local sn set to false). 
+ Access(AuthLevel auth, bool remote_sn = false, bool local_sn = false) + : auth{auth}, remote_sn{remote_sn}, local_sn{local_sn} {} }; } diff --git a/lokimq/connections.cpp b/lokimq/connections.cpp index c3a9751..2a7f112 100644 --- a/lokimq/connections.cpp +++ b/lokimq/connections.cpp @@ -53,6 +53,47 @@ void LokiMQ::setup_outgoing_socket(zmq::socket_t& socket, string_view remote_pub // else let ZMQ pick a random one } +ConnectionID LokiMQ::connect_sn(string_view pubkey, std::chrono::milliseconds keep_alive, string_view hint) { + if (!proxy_thread.joinable()) + throw std::logic_error("Cannot call connect_sn() before calling `start()`"); + + detail::send_control(get_control_socket(), "CONNECT_SN", bt_serialize({{"pubkey",pubkey}, {"keep_alive",keep_alive.count()}, {"hint",hint}})); + + return pubkey; +} + +ConnectionID LokiMQ::connect_remote(string_view remote, ConnectSuccess on_connect, ConnectFailure on_failure, + string_view pubkey, AuthLevel auth_level, std::chrono::milliseconds timeout) { + if (!proxy_thread.joinable()) + throw std::logic_error("Cannot call connect_remote() before calling `start()`"); + + if (remote.size() < 7 || !(remote.substr(0, 6) == "tcp://" || remote.substr(0, 6) == "ipc://" /* unix domain sockets */)) + throw std::runtime_error("Invalid connect_remote: remote address '" + std::string{remote} + "' is not a valid or supported zmq connect string"); + + auto id = next_conn_id++; + LMQ_TRACE("telling proxy to connect to ", remote, ", id ", id, + pubkey.empty() ? 
"using NULL auth" : ", using CURVE with remote pubkey [" + to_hex(pubkey) + "]"); + detail::send_control(get_control_socket(), "CONNECT_REMOTE", bt_serialize({ + {"auth_level", static_cast>(auth_level)}, + {"conn_id", id}, + {"connect", detail::serialize_object(std::move(on_connect))}, + {"failure", detail::serialize_object(std::move(on_failure))}, + {"pubkey", pubkey}, + {"remote", remote}, + {"timeout", timeout.count()}, + })); + + return id; +} + +void LokiMQ::disconnect(ConnectionID id, std::chrono::milliseconds linger) { + detail::send_control(get_control_socket(), "DISCONNECT", bt_serialize({ + {"conn_id", id.id}, + {"linger_ms", linger.count()}, + {"pubkey", id.pk}, + })); +} + std::pair LokiMQ::proxy_connect_sn(string_view remote, string_view connect_hint, bool optional, bool incoming_only, bool outgoing_only, std::chrono::milliseconds keep_alive) { ConnectionID remote_cid{remote}; @@ -166,9 +207,9 @@ void update_connection_indices(Container& c, size_t index, AccessIndex get_index } } -/// Closes outgoing connections and removes all references. Note that this will invalidate -/// iterators on the various connection containers - if you don't want that, delete it first so that -/// the container won't contain the element being deleted. +/// Closes outgoing connections and removes all references. Note that this will call `erase()` +/// which can invalidate iterators on the various connection containers - if you don't want that, +/// delete it first so that the container won't contain the element being deleted. void LokiMQ::proxy_close_connection(size_t index, std::chrono::milliseconds linger) { connections[index].setsockopt(ZMQ_LINGER, linger > 0ms ? 
linger.count() : 0); pollitems_stale = true; @@ -197,6 +238,7 @@ void LokiMQ::proxy_expire_idle_peers() { continue; } LMQ_LOG(debug, "Closing outgoing connection to ", it->first, ": idle timeout reached"); + ++it; // The below is going to delete our current element proxy_close_connection(info.conn_index, CLOSE_LINGER); } else { ++it; @@ -238,7 +280,7 @@ void LokiMQ::proxy_conn_cleanup() { auto& callback = it->second; if (callback.first < now) { LMQ_LOG(debug, "pending request ", to_hex(it->first), " expired, invoking callback with failure status and removing"); - job([callback = std::move(callback.second)] { callback(false, {}); }); + job([callback = std::move(callback.second)] { callback(false, {{"TIMEOUT"s}}); }); it = pending_requests.erase(it); } else { ++it; @@ -262,14 +304,10 @@ void LokiMQ::proxy_connect_remote(bt_dict_consumer data) { if (data.skip_until("conn_id")) conn_id = data.consume_integer(); if (data.skip_until("connect")) { - auto* ptr = reinterpret_cast(data.consume_integer()); - on_connect = std::move(*ptr); - delete ptr; + on_connect = detail::deserialize_object(data.consume_integer()); } if (data.skip_until("failure")) { - auto* ptr = reinterpret_cast(data.consume_integer()); - on_failure = std::move(*ptr); - delete ptr; + on_failure = detail::deserialize_object(data.consume_integer()); } if (data.skip_until("pubkey")) { remote_pubkey = data.consume_string(); diff --git a/lokimq/connections.h b/lokimq/connections.h index d1373e6..680094b 100644 --- a/lokimq/connections.h +++ b/lokimq/connections.h @@ -1,5 +1,6 @@ #pragma once #include "string_view.h" +#include namespace lokimq { @@ -71,11 +72,24 @@ private: friend std::ostream& operator<<(std::ostream& o, const ConnectionID& conn); }; +/// Simple hash implementation for a string that is *already* a hash-like value (such as a pubkey). +/// Falls back to std::hash if given a string smaller than a size_t. 
+struct already_hashed { + size_t operator()(const std::string& s) const { + if (s.size() < sizeof(size_t)) + return std::hash{}(s); + size_t hash; + std::memcpy(&hash, &s[0], sizeof(hash)); + return hash; + } +}; + + } // namespace lokimq namespace std { template <> struct hash { size_t operator()(const lokimq::ConnectionID &c) const { - return c.sn() ? std::hash{}(c.pk) : + return c.sn() ? lokimq::already_hashed{}(c.pk) : std::hash{}(c.id); } }; diff --git a/lokimq/lokimq-internal.h b/lokimq/lokimq-internal.h index 28a598c..a054ea7 100644 --- a/lokimq/lokimq-internal.h +++ b/lokimq/lokimq-internal.h @@ -27,7 +27,7 @@ extern "C" inline void message_buffer_destroy(void*, void* hint) { } /// Creates a message without needing to reallocate the provided string data -inline zmq::message_t create_message(std::string &&data) { +inline zmq::message_t create_message(std::string&& data) { auto *buffer = new std::string(std::move(data)); return zmq::message_t{&(*buffer)[0], buffer->size(), message_buffer_destroy, buffer}; }; @@ -38,7 +38,7 @@ inline zmq::message_t create_message(string_view data) { } template -bool send_message_parts(zmq::socket_t &sock, It begin, It end) { +bool send_message_parts(zmq::socket_t& sock, It begin, It end) { while (begin != end) { zmq::message_t &msg = *begin++; if (!sock.send(msg, begin == end ? zmq::send_flags::dontwait : zmq::send_flags::dontwait | zmq::send_flags::sndmore)) @@ -48,7 +48,7 @@ bool send_message_parts(zmq::socket_t &sock, It begin, It end) { } template -bool send_message_parts(zmq::socket_t &sock, Container &&c) { +bool send_message_parts(zmq::socket_t& sock, Container&& c) { return send_message_parts(sock, c.begin(), c.end()); } @@ -56,7 +56,7 @@ bool send_message_parts(zmq::socket_t &sock, Container &&c) { /// the msg frame will be an empty message; if `data` is empty then the data frame will be omitted. 
/// `flags` is passed through to zmq: typically given `zmq::send_flags::dontwait` to throw rather /// than block if a message can't be queued. -inline bool send_routed_message(zmq::socket_t &socket, std::string route, std::string msg = {}, std::string data = {}) { +inline bool send_routed_message(zmq::socket_t& socket, std::string route, std::string msg = {}, std::string data = {}) { assert(!route.empty()); std::array msgs{{create_message(std::move(route))}}; if (!msg.empty()) @@ -68,7 +68,7 @@ inline bool send_routed_message(zmq::socket_t &socket, std::string route, std::s // Sends some stuff to a socket directly. If dontwait is true then we throw instead of blocking if // the message cannot be accepted by zmq (i.e. because the outgoing buffer is full). -inline bool send_direct_message(zmq::socket_t &socket, std::string msg, std::string data = {}) { +inline bool send_direct_message(zmq::socket_t& socket, std::string msg, std::string data = {}) { std::array msgs{{create_message(std::move(msg))}}; if (!data.empty()) msgs[1] = create_message(std::move(data)); @@ -77,7 +77,7 @@ inline bool send_direct_message(zmq::socket_t &socket, std::string msg, std::str // Receive all the parts of a single message from the given socket. Returns true if a message was // received, false if called with flags=zmq::recv_flags::dontwait and no message was available. 
-inline bool recv_message_parts(zmq::socket_t &sock, std::vector& parts, const zmq::recv_flags flags = zmq::recv_flags::none) { +inline bool recv_message_parts(zmq::socket_t& sock, std::vector& parts, const zmq::recv_flags flags = zmq::recv_flags::none) { do { zmq::message_t msg; if (!sock.recv(msg, flags)) diff --git a/lokimq/lokimq.cpp b/lokimq/lokimq.cpp index 101d6d2..930f7c1 100644 --- a/lokimq/lokimq.cpp +++ b/lokimq/lokimq.cpp @@ -26,11 +26,6 @@ std::vector as_strings(const MessageContainer& msgs) { return result; } -void check_started(const std::thread& proxy_thread, const std::string &verb) { - if (!proxy_thread.joinable()) - throw std::logic_error("Cannot " + verb + " before calling `start()`"); -} - void check_not_started(const std::thread& proxy_thread, const std::string &verb) { if (proxy_thread.joinable()) throw std::logic_error("Cannot " + verb + " after calling `start()`"); @@ -54,29 +49,20 @@ void send_control(zmq::socket_t& sock, string_view cmd, std::string data) { } } -/// Extracts a pubkey, SN status, and auth level from a zmq message received on a *listening* -/// socket. -std::tuple extract_metadata(zmq::message_t& msg) { - auto result = std::make_tuple(""s, false, AuthLevel::none); +/// Extracts a pubkey and auth level from a zmq message received on a *listening* socket. +std::pair extract_metadata(zmq::message_t& msg) { + auto result = std::make_pair(""s, AuthLevel::none); try { string_view pubkey_hex{msg.gets("User-Id")}; if (pubkey_hex.size() != 64) throw std::logic_error("bad user-id"); assert(is_hex(pubkey_hex.begin(), pubkey_hex.end())); - auto& pubkey = std::get(result); - pubkey.resize(32, 0); - from_hex(pubkey_hex.begin(), pubkey_hex.end(), pubkey.begin()); + result.first.resize(32, 0); + from_hex(pubkey_hex.begin(), pubkey_hex.end(), result.first.begin()); } catch (...) {} try { - string_view is_sn{msg.gets("X-SN")}; - if (is_sn.size() == 1 && is_sn[0] == '1') - std::get(result) = true; - } catch (...) 
{} - - try { - string_view auth_level{msg.gets("X-AuthLevel")}; - std::get(result) = auth_from_string(auth_level); + result.second = auth_from_string(msg.gets("X-AuthLevel")); } catch (...) {} return result; @@ -385,46 +371,6 @@ LokiMQ::~LokiMQ() { LMQ_LOG(info, "LokiMQ proxy thread has stopped"); } -ConnectionID LokiMQ::connect_sn(string_view pubkey, std::chrono::milliseconds keep_alive, string_view hint) { - check_started(proxy_thread, "connect"); - - detail::send_control(get_control_socket(), "CONNECT_SN", bt_serialize({{"pubkey",pubkey}, {"keep_alive",keep_alive.count()}, {"hint",hint}})); - - return pubkey; -} - -ConnectionID LokiMQ::connect_remote(string_view remote, ConnectSuccess on_connect, ConnectFailure on_failure, - string_view pubkey, AuthLevel auth_level, std::chrono::milliseconds timeout) { - if (!proxy_thread.joinable()) - LMQ_LOG(warn, "connect_remote() called before start(); this won't take effect until start() is called"); - - if (remote.size() < 7 || !(remote.substr(0, 6) == "tcp://" || remote.substr(0, 6) == "ipc://" /* unix domain sockets */)) - throw std::runtime_error("Invalid connect_remote: remote address '" + std::string{remote} + "' is not a valid or supported zmq connect string"); - - auto id = next_conn_id++; - LMQ_TRACE("telling proxy to connect to ", remote, ", id ", id, - pubkey.empty() ? 
"using NULL auth" : ", using CURVE with remote pubkey [" + to_hex(pubkey) + "]"); - detail::send_control(get_control_socket(), "CONNECT_REMOTE", bt_serialize({ - {"auth_level", static_cast>(auth_level)}, - {"conn_id", id}, - {"connect", reinterpret_cast(new ConnectSuccess{std::move(on_connect)})}, - {"failure", reinterpret_cast(new ConnectFailure{std::move(on_failure)})}, - {"pubkey", pubkey}, - {"remote", remote}, - {"timeout", timeout.count()}, - })); - - return id; -} - -void LokiMQ::disconnect(ConnectionID id, std::chrono::milliseconds linger) { - detail::send_control(get_control_socket(), "DISCONNECT", bt_serialize({ - {"conn_id", id.id}, - {"linger_ms", linger.count()}, - {"pubkey", id.pk}, - })); -} - std::ostream &operator<<(std::ostream &os, LogLevel lvl) { os << (lvl == LogLevel::trace ? "trace" : lvl == LogLevel::debug ? "debug" : diff --git a/lokimq/lokimq.h b/lokimq/lokimq.h index 08ff3fb..2890246 100644 --- a/lokimq/lokimq.h +++ b/lokimq/lokimq.h @@ -86,6 +86,14 @@ static constexpr size_t MAX_COMMAND_LENGTH = 200; class CatHelper; +/// std::unordered_set specialization for specifying pubkeys (used, in particular, by +/// LokiMQ::set_active_sns and LokiMQ::update_active_sns); this is a std::string unordered_set that +/// also uses a specialized trivial hash function that uses part of the value itself (i.e. the +/// pubkey) directly as a hash value. (This is nice and fast for uniformly distributed values like +/// pubkeys and a terrible hash choice for anything else). +using pubkey_set = std::unordered_set; + + /** * Class that handles LokiMQ listeners, connections, proxying, and workers. An application * typically has just one instance of this class. @@ -134,16 +142,18 @@ private: public: /// Callback type invoked to determine whether the given new incoming connection is allowed to - /// connect to us and to set its initial authentication level. + /// connect to us and to set its authentication level. 
/// /// @param ip - the ip address of the incoming connection /// @param pubkey - the x25519 pubkey of the connecting client (32 byte string). Note that this /// will only be non-empty for incoming connections on `listen_curve` sockets; `listen_plain` /// sockets do not have a pubkey. + /// @param service_node - will be true if the `pubkey` is in the set of known active service + /// nodes. /// /// @returns an `AuthLevel` enum value indicating the default auth level for the incoming /// connection, or AuthLevel::denied if the connection should be refused. - using AllowFunc = std::function; + using AllowFunc = std::function; /// Callback that is invoked when we need to send a "strong" message to a SN that we aren't /// already connected to and need to establish a connection. This callback returns the ZMQ @@ -244,7 +254,9 @@ private: std::vector> bind; /// Info about a peer's established connection with us. Note that "established" means both - /// connected and authenticated. + /// connected and authenticated. Note that we only store peer info data for SN connections (in + /// or out), and outgoing non-SN connections. Incoming non-SN connections are handled on the + /// fly. struct peer_info { /// Pubkey of the remote, if this connection is a curve25519 connection; empty otherwise. std::string pubkey; @@ -521,15 +533,26 @@ private: /// done). std::unordered_map command_aliases; + using cat_call_t = std::pair*>; /// Retrieve category and callback from a command name, including alias mapping. Warns on /// invalid commands and returns nullptrs. The command name will be updated in place if it is /// aliased to another command. - std::pair*> get_command(std::string& command); + cat_call_t get_command(std::string& command); /// Checks a peer's authentication level. Returns true if allowed, warns and returns false if /// not. 
bool proxy_check_auth(size_t conn_index, bool outgoing, const peer_info& peer, - const std::string& command, const category& cat, zmq::message_t& msg); + zmq::message_t& command, const cat_call_t& cat_call, std::vector& data); + + /// Set of active service nodes. + pubkey_set active_service_nodes; + + /// Resets or updates the stored set of active SN pubkeys + void proxy_set_active_sns(string_view data); + void proxy_set_active_sns(pubkey_set pubkeys); + void proxy_update_active_sns(bt_list_consumer data); + void proxy_update_active_sns(pubkey_set added, pubkey_set removed); + void proxy_update_active_sns_clean(pubkey_set added, pubkey_set removed); /// Details for a pending command; such a command already has authenticated access and is just /// waiting for a thread to become available to handle it. @@ -775,8 +798,15 @@ public: * Finish starting up: binds to the bind locations given in the constructor and launches the * proxy thread to handle message dispatching between remote nodes and worker threads. * - * You will need to call `add_category` and `add_command` to register commands before calling - * `start()`; once start() is called commands cannot be changed. + * Things you want to do before calling this: + * - Use `add_category`/`add_command` to set up any commands remote connections can invoke. + * - If any commands require SN authentication, specify a list of currently active service node + * pubkeys via `set_active_sns()` (and make sure this gets updated when things change by + * another `set_active_sns()` or a `update_active_sns()` call). It *is* possible to make the + * initial call after calling `start()`, but that creates a window during which incoming + * remote SN connections will be erroneously treated as non-SN connections. + * - If this LMQ instance should accept incoming connections, set up any listening ports via + * `listen_curve()` and/or `listen_plain()`. 
*/ void start(); @@ -789,10 +819,10 @@ public: * such as: "tcp://\*:4567" or "tcp://1.2.3.4:5678". * * @param allow_connection function to call to determine whether to allow the connection and, if - * so, the authentication level it receives. If omitted the default returns non-service node, - * AuthLevel::none access. + * so, the authentication level it receives. If omitted the default returns AuthLevel::none + * access. */ - void listen_curve(std::string bind, AllowFunc allow_connection = [](auto, auto) { return Allow{AuthLevel::none, false}; }); + void listen_curve(std::string bind, AllowFunc allow_connection = [](auto, auto, auto) { return AuthLevel::none; }); /** Start listening on the given bind address in unauthenticated plain text mode. Incoming * connections can come from anywhere. `allow_connection` is invoked for any incoming @@ -803,10 +833,10 @@ public: * such as: "tcp://\*:4567" or "tcp://1.2.3.4:5678". * * @param allow_connection function to call to determine whether to allow the connection and, if - * so, the authentication level it receives. If omitted the default returns non-service node, - * AuthLevel::none access. + * so, the authentication level it receives. If omitted the default returns AuthLevel::none + * access. */ - void listen_plain(std::string bind, AllowFunc allow_connection = [](auto, auto) { return Allow{AuthLevel::none, false}; }); + void listen_plain(std::string bind, AllowFunc allow_connection = [](auto, auto, auto) { return AuthLevel::none; }); /** * Try to initiate a connection to the given SN in anticipation of needing a connection in the @@ -939,9 +969,23 @@ public: * @param to - the pubkey string or ConnectionID to send this request to * @param cmd - the command name * @param callback - the callback to invoke when we get a reply. Called with a true value and - * the data strings when a reply is received, or false and an empty vector of data parts if we - * get no reply in the timeout interval. 
+ * the data strings when a reply is received, or false with error string(s) indicating the + * failure reason upon failure or timeout. * @param opts - anything else (i.e. strings, send_options) is forwarded to send(). + * + * Possible error data values: + * - ["TIMEOUT"] - we got no reply within the timeout window + * - ["UNKNOWNCOMMAND"] - the remote did not recognize the given request command + * - ["NO_REPLY_TAG"] - the invoked command is a request command but no reply tag was included + * - ["FORBIDDEN"] - the command requires an authorization level (e.g. Basic or Admin) that we + * do not have. + * - ["FORBIDDEN_SN"] - the command requires service node authentication, but the remote did not + * recognize us as a service node. You *may* want to retry the request a limited number of + * times (but do not retry indefinitely as that can be an infinite loop!) because this is + * typically also followed by a disconnection; a retried message would reconnect and + * reauthenticate which *may* result in picking up the SN authentication. + * - ["NOT_A_SERVICE_NODE"] - this command is only invokable on service nodes, and the remote is + * not running as a service node. */ template void request(ConnectionID to, string_view cmd, ReplyCallback callback, const T&... opts); @@ -951,6 +995,41 @@ public: const std::string& get_pubkey() const { return pubkey; } const std::string& get_privkey() const { return privkey; } + /** Updates (or initially sets) LokiMQ's list of service node pubkeys with the given list. + * + * This has two main effects: + * + * - All commands processed after the update will have SN status determined by the new list. + * - All outgoing connections to service nodes that are no longer on the list will be closed. + * This includes both explicit connections (established by `connect_sn()`) and implicit ones + * (established by sending to a SN that wasn't connected). 
+ * + * As this update is potentially quite heavy it is recommended that this be called only when + * necessary--i.e. when the list has changed (or potentially changed), but *not* on a short + * periodic timer. + * + * This method may (and should!) be called before start() to load an initial set of SNs. + * + * Once a full list has been set, updates on changes can either call this again with the new + * list, or use the more efficient update_active_sns() call if incremental results are + * available. + */ + void set_active_sns(pubkey_set pubkeys); + + /** Updates the list of active pubkeys by adding or removing the given pubkeys from the existing + * list. This is more efficient when the incremental information is already available; if it + * isn't, simply call set_active_sns with a new list to have LokiMQ figure out what was added or + * removed. + * + * \param added new pubkeys that were added since the last set_active_sns or update_active_sns + * call. + * + * \param removed pubkeys that were removed from active SN status since the last call. If a + * pubkey is in both `added` and `removed` for some reason then its presence in `removed` will + * be ignored. + */ + void update_active_sns(pubkey_set added, pubkey_set removed); + /** * Batches a set of jobs to be executed by workers, optionally followed by a completion function. * @@ -1121,7 +1200,9 @@ namespace detail { /// containing the pointer to be serialized to pass (via lokimq queues) from one thread to another. /// Must be matched with a deserializer_pointer on the other side to reconstitute the object and /// destroy the intermediate pointer. 
-template uintptr_t serialize_object(T&& obj) { +template +uintptr_t serialize_object(T&& obj) { + static_assert(std::is_rvalue_reference::value, "serialize_object must be given an rvalue reference"); auto* ptr = new T{std::forward(obj)}; return reinterpret_cast(ptr); } @@ -1192,9 +1273,8 @@ inline void apply_send_option(bt_list&, bt_dict& control_data, send_option::queu control_data["send_full_q"] = serialize_object(std::move(f.callback)); } -/// Extracts a pubkey, SN status, and auth level from a zmq message received on a *listening* -/// socket. -std::tuple extract_metadata(zmq::message_t& msg); +/// Extracts a pubkey and auth level from a zmq message received on a *listening* socket. +std::pair extract_metadata(zmq::message_t& msg); template bt_dict build_send(ConnectionID to, string_view cmd, T&&... opts) { diff --git a/lokimq/proxy.cpp b/lokimq/proxy.cpp index a627a53..bd9c0f6 100644 --- a/lokimq/proxy.cpp +++ b/lokimq/proxy.cpp @@ -178,7 +178,7 @@ void LokiMQ::proxy_send(bt_dict_consumer data) { std::chrono::steady_clock::now() + request_timeout, std::move(request_callback) }}); } else { LMQ_LOG(debug, "Could not send request, scheduling request callback failure"); - job([callback = std::move(request_callback)] { callback(false, {}); }); + job([callback = std::move(request_callback)] { callback(false, {{"TIMEOUT"s}}); }); } } if (!sent) { @@ -237,31 +237,36 @@ void LokiMQ::proxy_reply(bt_dict_consumer data) { } void LokiMQ::proxy_control_message(std::vector& parts) { + // We throw an uncaught exception here because we only generate control messages internally in + // lokimq code: if one of these condition fail it's a lokimq bug. 
if (parts.size() < 2) - throw std::logic_error("Expected 2-3 message parts for a proxy control message"); + throw std::logic_error("LokiMQ bug: Expected 2-3 message parts for a proxy control message"); auto route = view(parts[0]), cmd = view(parts[1]); LMQ_TRACE("control message: ", cmd); if (parts.size() == 3) { LMQ_TRACE("...: ", parts[2]); + auto data = view(parts[2]); if (cmd == "SEND") { LMQ_TRACE("proxying message"); - return proxy_send(view(parts[2])); + return proxy_send(data); } else if (cmd == "REPLY") { LMQ_TRACE("proxying reply to non-SN incoming message"); - return proxy_reply(view(parts[2])); + return proxy_reply(data); } else if (cmd == "BATCH") { LMQ_TRACE("proxy batch jobs"); - auto ptrval = bt_deserialize(view(parts[2])); + auto ptrval = bt_deserialize(data); return proxy_batch(reinterpret_cast(ptrval)); + } else if (cmd == "UPDATE_SNS") { + return proxy_update_active_sns(data); } else if (cmd == "CONNECT_SN") { - proxy_connect_sn(view(parts[2])); + proxy_connect_sn(data); return; } else if (cmd == "CONNECT_REMOTE") { - return proxy_connect_remote(view(parts[2])); + return proxy_connect_remote(data); } else if (cmd == "DISCONNECT") { - return proxy_disconnect(view(parts[2])); + return proxy_disconnect(data); } else if (cmd == "TIMER") { - return proxy_timer(view(parts[2])); + return proxy_timer(data); } } else if (parts.size() == 2) { if (cmd == "START") { @@ -279,8 +284,8 @@ void LokiMQ::proxy_control_message(std::vector& parts) { return; } } - throw std::runtime_error("Proxy received invalid control command: " + std::string{cmd} + - " (" + std::to_string(parts.size()) + ")"); + throw std::runtime_error("LokiMQ bug: Proxy received invalid control command: " + + std::string{cmd} + " (" + std::to_string(parts.size()) + ")"); } void LokiMQ::proxy_loop() { @@ -455,26 +460,31 @@ void LokiMQ::proxy_loop() { } } +static bool is_error_response(string_view cmd) { + return cmd == "FORBIDDEN" || cmd == "FORBIDDEN_SN" || cmd == "NOT_A_SERVICE_NODE" || cmd == 
"UNKNOWNCOMMAND" || cmd == "NO_REPLY_TAG"; +} + // Return true if we recognized/handled the builtin command (even if we reject it for whatever // reason) bool LokiMQ::proxy_handle_builtin(size_t conn_index, std::vector& parts) { - bool outgoing = connections[conn_index].getsockopt(ZMQ_TYPE) == ZMQ_DEALER; + // Doubling as a bool and an offset: + size_t incoming = connections[conn_index].getsockopt(ZMQ_TYPE) == ZMQ_ROUTER; string_view route, cmd; - if (parts.size() < (outgoing ? 1 : 2)) { + if (parts.size() < 1 + incoming) { LMQ_LOG(warn, "Received empty message; ignoring"); return true; } - if (outgoing) { - cmd = view(parts[0]); - } else { + if (incoming) { route = view(parts[0]); cmd = view(parts[1]); + } else { + cmd = view(parts[0]); } LMQ_TRACE("Checking for builtins: '", cmd, "' from ", peer_address(parts.back())); if (cmd == "REPLY") { - size_t tag_pos = (outgoing ? 1 : 2); + size_t tag_pos = 1 + incoming; if (parts.size() <= tag_pos) { LMQ_LOG(warn, "Received REPLY without a reply tag; ignoring"); return true; @@ -482,7 +492,7 @@ bool LokiMQ::proxy_handle_builtin(size_t conn_index, std::vector std::string reply_tag{view(parts[tag_pos])}; auto it = pending_requests.find(reply_tag); if (it != pending_requests.end()) { - LMQ_LOG(debug, "Received REPLY for pending command", to_hex(reply_tag), "; scheduling callback"); + LMQ_LOG(debug, "Received REPLY for pending command ", to_hex(reply_tag), "; scheduling callback"); std::vector data; data.reserve(parts.size() - (tag_pos + 1)); for (auto it = parts.begin() + (tag_pos + 1); it != parts.end(); ++it) @@ -496,7 +506,7 @@ bool LokiMQ::proxy_handle_builtin(size_t conn_index, std::vector } return true; } else if (cmd == "HI") { - if (outgoing) { + if (!incoming) { LMQ_LOG(warn, "Got invalid 'HI' message on an outgoing connection; ignoring"); return true; } @@ -506,7 +516,7 @@ bool LokiMQ::proxy_handle_builtin(size_t conn_index, std::vector } catch (const std::exception &e) { LMQ_LOG(warn, "Couldn't reply with HELLO: 
", e.what()); } return true; } else if (cmd == "HELLO") { - if (!outgoing) { + if (incoming) { LMQ_LOG(warn, "Got invalid 'HELLO' message on an incoming connection; ignoring"); return true; } @@ -531,22 +541,50 @@ bool LokiMQ::proxy_handle_builtin(size_t conn_index, std::vector pending_connects.erase(it); return true; } else if (cmd == "BYE") { - if (outgoing) { - std::string pk; - bool sn; - AuthLevel a; - std::tie(pk, sn, a) = detail::extract_metadata(parts.back()); - ConnectionID conn = sn ? ConnectionID{std::move(pk)} : conn_index_to_id[conn_index]; - LMQ_LOG(debug, "BYE command received; disconnecting from ", conn); - proxy_disconnect(conn, 1s); + if (!incoming) { + LMQ_LOG(debug, "BYE command received; disconnecting from ", peer_address(parts.back())); + proxy_close_connection(conn_index, 0s); } else { LMQ_LOG(warn, "Got invalid 'BYE' command on an incoming socket; ignoring"); } return true; } - else if (cmd == "FORBIDDEN" || cmd == "NOT_A_SERVICE_NODE") { - return true; // FIXME - ignore these? Log? + else if (is_error_response(cmd)) { + // These messages (FORBIDDEN, UNKNOWNCOMMAND, etc.) are sent in response to us trying to + // invoke something that doesn't exist or we don't have permission to access. These have + // two forms (the latter is only sent by remotes running 1.0.6+). + // - ["XXX", "whatever.command"] + // - ["XXX", "REPLY", replytag] + // (ignoring the routing prefix on incoming commands). + // For the former, we log; for the latter we trigger the reply callback with a failure + + if (parts.size() == (2 + incoming) && is_error_response(view(parts[1 + incoming]))) { + // Something like ["UNKNOWNCOMMAND", "FORBIDDEN_SN"] which can happen because the remote + // is running an older version that didn't understand the FORBIDDEN_SN (or whatever) + // error reply that we sent them. We just ignore it because anything else could trigger + // an infinite cycle. 
+ LMQ_LOG(debug, "Received [", cmd, ",", view(parts[1 + incoming]), "]; remote is probably an older lokimq. Ignoring."); + return true; + } + + if (parts.size() == (3 + incoming) && view(parts[1 + incoming]) == "REPLY") { + std::string reply_tag{view(parts[2 + incoming])}; + auto it = pending_requests.find(reply_tag); + if (it != pending_requests.end()) { + LMQ_LOG(debug, "Received ", cmd, " REPLY for pending command ", to_hex(reply_tag), "; scheduling failure callback"); + proxy_schedule_reply_job([callback=std::move(it->second.second), cmd=std::string{cmd}] { + callback(false, {{std::move(cmd)}}); + }); + pending_requests.erase(it); + } else { + LMQ_LOG(warn, "Received REPLY with unknown or already handled reply tag (", to_hex(reply_tag), "); ignoring"); + } + } else { + LMQ_LOG(warn, "Received ", cmd, ':', (parts.size() > 1 + incoming ? view(parts[1 + incoming]) : "(unknown command)"_sv), + " from ", peer_address(parts.back())); + } + return true; } return false; } diff --git a/lokimq/worker.cpp b/lokimq/worker.cpp index b6e3b15..3e91fa8 100644 --- a/lokimq/worker.cpp +++ b/lokimq/worker.cpp @@ -201,22 +201,24 @@ void LokiMQ::proxy_to_worker(size_t conn_index, std::vector& par } peer = &it->second; } else { - std::tie(tmp_peer.pubkey, tmp_peer.service_node, tmp_peer.auth_level) = detail::extract_metadata(parts.back()); + std::tie(tmp_peer.pubkey, tmp_peer.auth_level) = detail::extract_metadata(parts.back()); + tmp_peer.service_node = tmp_peer.pubkey.size() == 32 && active_service_nodes.count(tmp_peer.pubkey); + if (tmp_peer.service_node) { // It's a service node so we should have a peer_info entry; see if we can find one with // the same route, and if not, add one. 
auto pr = peers.equal_range(tmp_peer.pubkey); for (auto it = pr.first; it != pr.second; ++it) { - if (it->second.route == tmp_peer.route) { + if (it->second.conn_index == tmp_peer.conn_index && it->second.route == tmp_peer.route) { peer = &it->second; - // Upgrade permissions in case we have something higher on the socket - peer->service_node |= tmp_peer.service_node; - if (tmp_peer.auth_level > peer->auth_level) - peer->auth_level = tmp_peer.auth_level; + // Update the stored auth level just in case the peer reconnected + peer->auth_level = tmp_peer.auth_level; break; } } if (!peer) { + // We don't have a record: this is either a new SN connection or a new message on a + // connection that recently gained SN status. peer = &peers.emplace(ConnectionID{tmp_peer.pubkey}, std::move(tmp_peer))->second; } } else { @@ -228,23 +230,6 @@ void LokiMQ::proxy_to_worker(size_t conn_index, std::vector& par size_t command_part_index = outgoing ? 0 : 1; std::string command = parts[command_part_index].to_string(); - auto cat_call = get_command(command); - - if (!cat_call.first) { - LMQ_LOG(warn, "Invalid command '", command, "' sent by remote [", to_hex(peer->pubkey), "]/", peer_address(parts.back())); - try { - if (outgoing) - send_direct_message(connections[conn_index], "UNKNOWNCOMMAND"); - else - send_routed_message(connections[conn_index], peer->route, "UNKNOWNCOMMAND"); - } catch (const zmq::error_t&) { /* can't send: possibly already disconnected. Ignore. 
*/ } - return; - } - - auto& category = *cat_call.first; - - if (!proxy_check_auth(conn_index, outgoing, *peer, command, category, parts.back())) - return; // Steal any data message parts size_t data_part_index = command_part_index + 1; @@ -253,6 +238,14 @@ void LokiMQ::proxy_to_worker(size_t conn_index, std::vector& par for (auto it = parts.begin() + data_part_index; it != parts.end(); ++it) data_parts.push_back(std::move(*it)); + auto cat_call = get_command(command); + + // Check that command is valid, that we have permission, etc. + if (!proxy_check_auth(conn_index, outgoing, *peer, parts[command_part_index], cat_call, data_parts)) + return; + + auto& category = *cat_call.first; + if (category.active_threads >= category.reserved_threads && active_workers() >= general_workers) { // No free reserved or general spots, try to queue it for later if (category.max_queue >= 0 && category.queued >= category.max_queue) { diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 92dccf5..36997a7 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -6,6 +6,7 @@ set(LMQ_TEST_SRC test_batch.cpp test_connect.cpp test_commands.cpp + test_failures.cpp test_requests.cpp test_string_view.cpp ) diff --git a/tests/common.h b/tests/common.h index 84dc1be..adef968 100644 --- a/tests/common.h +++ b/tests/common.h @@ -6,6 +6,25 @@ using namespace lokimq; static auto startup = std::chrono::steady_clock::now(); +/// Waits up to 100ms for something to happen. +template +inline void wait_for(Func f) { + for (int i = 0; i < 10; i++) { + if (f()) + break; + std::this_thread::sleep_for(10ms); + } +} + +/// Waits on an atomic bool for up to 100ms for an initial connection, which is more than enough +/// time for an initial connection + request. +inline void wait_for_conn(std::atomic &c) { + wait_for([&c] { return c.load(); }); +} + +/// Waits enough time for us to receive a reply from a localhost remote. 
+inline void reply_sleep() { std::this_thread::sleep_for(10ms); } + // Catch2 macros aren't thread safe, so guard with a mutex inline std::unique_lock catch_lock() { static std::mutex mutex; diff --git a/tests/test_commands.cpp b/tests/test_commands.cpp index 44527eb..e185cdc 100644 --- a/tests/test_commands.cpp +++ b/tests/test_commands.cpp @@ -1,5 +1,4 @@ #include "common.h" -#include #include #include #include @@ -12,10 +11,10 @@ TEST_CASE("basic commands", "[commands]") { "", "", // generate ephemeral keys false, // not a service node [](auto) { return ""; }, - get_logger("S» ") + get_logger("S» "), + LogLevel::trace }; - server.log_level(LogLevel::trace); - server.listen_curve(listen, [](auto /*ip*/, auto /*pk*/) { return Allow{AuthLevel::none, false}; }); + server.listen_curve(listen); std::atomic hellos{0}, his{0}; @@ -32,41 +31,33 @@ TEST_CASE("basic commands", "[commands]") { server.start(); - LokiMQ client{ - get_logger("C» ") - }; - client.log_level(LogLevel::trace); + LokiMQ client{get_logger("C» "), LogLevel::trace}; client.add_category("public", Access{AuthLevel::none}); client.add_command("public", "hi", [&](auto&) { his++; }); client.start(); - std::atomic connected{false}, failed{false}; + std::atomic got{false}; + bool success = false, failed = false; std::string pubkey; auto c = client.connect_remote(listen, - [&](auto conn) { pubkey = conn.pubkey(); connected = true; }, - [&](auto conn, string_view) { failed = true; }, + [&](auto conn) { pubkey = conn.pubkey(); success = true; got = true; }, + [&](auto conn, string_view) { failed = true; got = true; }, server.get_pubkey()); - int i; - for (i = 0; i < 5; i++) { - if (connected.load()) - break; - std::this_thread::sleep_for(50ms); - } + wait_for_conn(got); { auto lock = catch_lock(); - REQUIRE( connected.load() ); - REQUIRE( i <= 1 ); // should be fast - REQUIRE( !failed.load() ); + REQUIRE( got ); + REQUIRE( success ); + REQUIRE_FALSE( failed ); REQUIRE( to_hex(pubkey) == 
to_hex(server.get_pubkey()) ); } client.send(c, "public.hello"); client.send(c, "public.client.pubkey"); - - std::this_thread::sleep_for(50ms); + reply_sleep(); { auto lock = catch_lock(); REQUIRE( hellos == 1 ); @@ -77,7 +68,7 @@ TEST_CASE("basic commands", "[commands]") { for (int i = 0; i < 50; i++) client.send(c, "public.hello"); - std::this_thread::sleep_for(100ms); + wait_for([&] { return his == 26; }); { auto lock = catch_lock(); REQUIRE( hellos == 51 ); @@ -91,10 +82,10 @@ TEST_CASE("outgoing auth level", "[commands][auth]") { "", "", // generate ephemeral keys false, // not a service node [](auto) { return ""; }, - get_logger("S» ") + get_logger("S» "), + LogLevel::trace }; - server.log_level(LogLevel::trace); - server.listen_curve(listen, [](auto /*ip*/, auto /*pk*/) { return Allow{AuthLevel::none, false}; }); + server.listen_curve(listen); std::atomic hellos{0}; @@ -103,10 +94,7 @@ TEST_CASE("outgoing auth level", "[commands][auth]") { server.start(); - LokiMQ client{ - get_logger("C» ") - }; - client.log_level(LogLevel::trace); + LokiMQ client{get_logger("C» "), LogLevel::trace}; std::atomic public_hi{0}, basic_hi{0}, admin_hi{0}; client.add_category("public", Access{AuthLevel::none}); @@ -124,8 +112,7 @@ TEST_CASE("outgoing auth level", "[commands][auth]") { auto admin_c = client.connect_remote(listen, [](...) {}, [](...) 
{}, server.get_pubkey(), AuthLevel::admin); client.send(public_c, "public.reflect", "public.hi"); - std::this_thread::sleep_for(50ms); - + wait_for([&] { return public_hi == 1; }); { auto lock = catch_lock(); REQUIRE( public_hi == 1 ); @@ -138,7 +125,7 @@ TEST_CASE("outgoing auth level", "[commands][auth]") { client.send(admin_c, "public.reflect", "admin.hi"); client.send(basic_c, "public.reflect", "basic.hi"); - std::this_thread::sleep_for(50ms); + wait_for([&] { return public_hi == 2; }); { auto lock = catch_lock(); REQUIRE( admin_hi == 3 ); @@ -160,7 +147,7 @@ TEST_CASE("outgoing auth level", "[commands][auth]") { client.send(admin_c, "public.reflect", "basic.hi"); client.send(admin_c, "public.reflect", "public.hi"); - std::this_thread::sleep_for(50ms); + wait_for([&] { return public_hi == 3; }); auto lock = catch_lock(); REQUIRE( admin_hi == 1 ); REQUIRE( basic_hi == 2 ); @@ -176,10 +163,10 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]") "", "", // generate ephemeral keys false, // not a service node [](auto) { return ""; }, - get_logger("S» ") + get_logger("S» "), + LogLevel::trace }; - server.log_level(LogLevel::trace); - server.listen_curve(listen, [](auto /*ip*/, auto /*pk*/) { return Allow{AuthLevel::none, false}; }); + server.listen_curve(listen); std::vector> subscribers; ConnectionID backdoor; @@ -221,8 +208,7 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]") auto nsa_c = nsa.connect_remote(listen, connect_success, connect_failure, server.get_pubkey(), AuthLevel::admin); nsa.send(nsa_c, "hey google.install backdoor"); - std::this_thread::sleep_for(50ms); - + wait_for([&] { auto lock = catch_lock(); return (bool) backdoor; }); { auto l = catch_lock(); REQUIRE( backdoor ); @@ -243,9 +229,10 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]") std::map> google_knows; int things_remembered{0}; for (int i = 0; i < 5; i++) { - 
clients.push_back(std::make_unique(get_logger("C" + std::to_string(i) + "» "))); + clients.push_back(std::make_unique( + get_logger("C" + std::to_string(i) + "» "), LogLevel::trace + )); auto& c = clients.back(); - c->log_level(LogLevel::trace); c->add_category("personal", Access{AuthLevel::basic}); c->add_command("personal", "detail", [&,i](Message& m) { auto l = catch_lock(); @@ -265,7 +252,7 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]") }, personal_detail); } - std::this_thread::sleep_for(50ms); + wait_for([&] { auto lock = catch_lock(); return things_remembered == all_the_things.size(); }); { auto l = catch_lock(); REQUIRE( things_remembered == all_the_things.size() ); @@ -273,7 +260,7 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]") } clients[0]->send(conns[0], "hey google.recall"); - std::this_thread::sleep_for(50ms); + reply_sleep(); { auto l = catch_lock(); REQUIRE( google_knows == personal_details ); @@ -286,10 +273,10 @@ TEST_CASE("send failure callbacks", "[commands][queue_full]") { "", "", // generate ephemeral keys false, // not a service node [](auto) { return ""; }, - get_logger("S» ") + get_logger("S» "), + LogLevel::debug // This test traces so much that it takes 2.5-3s of CPU time at trace level, so don't do that. 
}; - server.log_level(LogLevel::debug); - server.listen_plain(listen, [](auto /*ip*/, auto /*pk*/) { return Allow{AuthLevel::none, false}; }); + server.listen_plain(listen); std::atomic send_attempts{0}; std::atomic send_failures{0}; @@ -331,7 +318,7 @@ TEST_CASE("send failure callbacks", "[commands][queue_full]") { for (i = 0; i < 20; i++) { if (send_attempts.load() >= 500) break; - std::this_thread::sleep_for(25ms); + std::this_thread::sleep_for(10ms); } { auto lock = catch_lock(); @@ -354,20 +341,20 @@ TEST_CASE("send failure callbacks", "[commands][queue_full]") { client.send(zmq::message_t{"x.x", 3}, zmq::send_flags::none); expected_attempts += 500; if (i >= 4) { - std::this_thread::sleep_for(25ms); if (send_failures.load() > 0) break; + std::this_thread::sleep_for(25ms); } } - for (i = 0; i < 10; i++) { + for (i = 0; i < 20; i++) { if (send_attempts.load() >= expected_attempts) break; - std::this_thread::sleep_for(25ms); + std::this_thread::sleep_for(10ms); } { auto lock = catch_lock(); - REQUIRE( i <= 8 ); + REQUIRE( i < 20 ); REQUIRE( send_attempts.load() == expected_attempts ); REQUIRE( send_failures.load() > 0 ); } diff --git a/tests/test_connect.cpp b/tests/test_connect.cpp index 6c1b5c9..b0c5fdf 100644 --- a/tests/test_connect.cpp +++ b/tests/test_connect.cpp @@ -1,5 +1,4 @@ #include "common.h" -#include #include extern "C" { #include @@ -12,46 +11,42 @@ TEST_CASE("connections with curve authentication", "[curve][connect]") { "", "", // generate ephemeral keys false, // not a service node [](auto) { return ""; }, - get_logger("S» ") + get_logger("S» "), + LogLevel::trace }; - server.log_level(LogLevel::trace); - server.listen_curve(listen, [](auto /*ip*/, auto /*pk*/) { return Allow{AuthLevel::none, false}; }); + server.listen_curve(listen); server.add_category("public", Access{AuthLevel::none}); server.add_request_command("public", "hello", [&](Message& m) { m.send_reply("hi"); }); server.start(); - LokiMQ client{get_logger("C» ")}; - 
client.log_level(LogLevel::trace); + LokiMQ client{get_logger("C» "), LogLevel::trace}; client.start(); auto pubkey = server.get_pubkey(); - std::atomic connected{0}; + std::atomic got{false}; + bool success = false; auto server_conn = client.connect_remote(listen, - [&](auto conn) { connected = 1; }, - [&](auto conn, string_view reason) { auto lock = catch_lock(); INFO("connection failed: " << reason); }, + [&](auto conn) { success = true; got = true; }, + [&](auto conn, string_view reason) { auto lock = catch_lock(); INFO("connection failed: " << reason); got = true; }, pubkey); - int i; - for (i = 0; i < 5; i++) { - if (connected.load()) - break; - std::this_thread::sleep_for(50ms); - } + wait_for_conn(got); { auto lock = catch_lock(); - REQUIRE( i <= 1 ); - REQUIRE( connected.load() ); + REQUIRE( got ); + REQUIRE( success ); } - bool success = false; + success = false; std::vector parts; client.request(server_conn, "public.hello", [&](auto success_, auto parts_) { success = success_; parts = parts_; }); - std::this_thread::sleep_for(50ms); - auto lock = catch_lock(); - REQUIRE( success ); - + reply_sleep(); + { + auto lock = catch_lock(); + REQUIRE( success ); + } } TEST_CASE("self-connection SN optimization", "[connect][self]") { @@ -63,10 +58,16 @@ TEST_CASE("self-connection SN optimization", "[connect][self]") { pubkey, privkey, true, [&](auto pk) { if (pk == pubkey) return "tcp://127.0.0.1:5544"; else return ""; }, - get_logger("S» ") + get_logger("S» "), + LogLevel::trace }; - sn.listen_curve("tcp://127.0.0.1:5544", [&](auto ip, auto pk) { REQUIRE(ip == "127.0.0.1"); return Allow{AuthLevel::none, pk == pubkey}; }); + sn.listen_curve("tcp://127.0.0.1:5544", [&](auto ip, auto pk, auto sn) { + auto lock = catch_lock(); + REQUIRE(ip == "127.0.0.1"); + REQUIRE(sn == (pk == pubkey)); + return AuthLevel::none; + }); sn.add_category("a", Access{AuthLevel::none}); bool invoked = false; sn.add_command("a", "b", [&](const Message& m) { @@ -77,57 +78,54 @@ 
TEST_CASE("self-connection SN optimization", "[connect][self]") { REQUIRE(!m.data.empty()); REQUIRE(m.data[0] == "my data"); }); - sn.log_level(LogLevel::trace); + sn.set_active_sns({{pubkey}}); sn.start(); - std::this_thread::sleep_for(50ms); sn.send(pubkey, "a.b", "my data"); - std::this_thread::sleep_for(50ms); - auto lock = catch_lock(); - REQUIRE(invoked); + reply_sleep(); + { + auto lock = catch_lock(); + REQUIRE(invoked); + } } TEST_CASE("plain-text connections", "[plaintext][connect]") { std::string listen = "tcp://127.0.0.1:4455"; - LokiMQ server{get_logger("S» ")}; - server.log_level(LogLevel::trace); + LokiMQ server{get_logger("S» "), LogLevel::trace}; server.add_category("public", Access{AuthLevel::none}); server.add_request_command("public", "hello", [&](Message& m) { m.send_reply("hi"); }); - server.listen_plain(listen, [](auto /*ip*/, auto /*pk*/) { return Allow{AuthLevel::none, false}; }); + server.listen_plain(listen); server.start(); - LokiMQ client{get_logger("C» ")}; - client.log_level(LogLevel::trace); + LokiMQ client{get_logger("C» "), LogLevel::trace}; client.start(); - std::atomic connected{0}; + std::atomic got{false}; + bool success = false; auto c = client.connect_remote(listen, - [&](auto conn) { connected = 1; }, - [&](auto conn, string_view reason) { auto lock = catch_lock(); INFO("connection failed: " << reason); } + [&](auto conn) { success = true; got = true; }, + [&](auto conn, string_view reason) { auto lock = catch_lock(); INFO("connection failed: " << reason); got = true; } ); - int i; - for (i = 0; i < 5; i++) { - if (connected.load()) - break; - std::this_thread::sleep_for(50ms); - } + wait_for_conn(got); { auto lock = catch_lock(); - REQUIRE( i <= 1 ); - REQUIRE( connected.load() ); + REQUIRE( got ); + REQUIRE( success ); } - bool success = false; + success = false; std::vector parts; client.request(c, "public.hello", [&](auto success_, auto parts_) { success = success_; parts = parts_; }); - 
std::this_thread::sleep_for(50ms); - auto lock = catch_lock(); - REQUIRE( success ); + reply_sleep(); + { + auto lock = catch_lock(); + REQUIRE( success ); + } } TEST_CASE("SN disconnections", "[connect][disconnect]") { @@ -147,28 +145,146 @@ TEST_CASE("SN disconnections", "[connect][disconnect]") { lmq.push_back(std::make_unique( pubkey[i], privkey[i], true, [conn](auto pk) { auto it = conn.find((std::string) pk); if (it != conn.end()) return it->second; return ""s; }, - get_logger("S" + std::to_string(i) + "» ") + get_logger("S" + std::to_string(i) + "» "), + LogLevel::trace )); auto& server = *lmq.back(); - server.log_level(LogLevel::debug); - server.listen_curve(conn[pubkey[i]], [](auto /*ip*/, auto /*pk*/) { return Allow{AuthLevel::none, true}; }); + server.listen_curve(conn[pubkey[i]]); server.add_category("sn", Access{AuthLevel::none, true}) .add_command("hi", [&](Message& m) { his++; }); + server.set_active_sns({pubkey.begin(), pubkey.end()}); server.start(); } - std::this_thread::sleep_for(50ms); lmq[0]->send(pubkey[1], "sn.hi"); lmq[0]->send(pubkey[2], "sn.hi"); - std::this_thread::sleep_for(50ms); lmq[2]->send(pubkey[0], "sn.hi"); lmq[2]->send(pubkey[1], "sn.hi"); lmq[1]->send(pubkey[0], "BYE"); - std::this_thread::sleep_for(50ms); lmq[0]->send(pubkey[2], "sn.hi"); std::this_thread::sleep_for(50ms); auto lock = catch_lock(); REQUIRE(his == 5); } + +TEST_CASE("SN auth checks", "[sandwich][auth]") { + // When a remote connects, we check its authentication level; if at the time of connection it + // isn't recognized as a SN but tries to invoke a SN command it'll be told to disconnect; if it + // tries to send again it should reconnect and reauthenticate. This test is meant to test this + // pattern where the reconnection/reauthentication now authenticates it as a SN. 
+ std::string listen = "tcp://127.0.0.1:4455"; + std::string pubkey, privkey; + pubkey.resize(crypto_box_PUBLICKEYBYTES); + privkey.resize(crypto_box_SECRETKEYBYTES); + crypto_box_keypair(reinterpret_cast(&pubkey[0]), reinterpret_cast(&privkey[0])); + LokiMQ server{ + pubkey, privkey, + true, // service node + [](auto) { return ""; }, + get_logger("A» "), + LogLevel::trace + }; + + std::atomic incoming_is_sn{false}; + server.listen_curve(listen); + server.add_category("public", Access{AuthLevel::none}) + .add_request_command("hello", [&](Message& m) { m.send_reply("hi"); }) + .add_request_command("sudo", [&](Message& m) { + server.update_active_sns({{m.conn.pubkey()}}, {}); + m.send_reply("making sandwiches"); + }) + .add_request_command("nosudo", [&](Message& m) { + // Send the reply *first* because if we do it the other way we'll have just removed + // ourselves from the list of SNs and thus would try to open an outbound connection + // to deliver it since it's still queued as a message to a SN. 
+ m.send_reply("make them yourself"); + server.update_active_sns({}, {{m.conn.pubkey()}}); + }); + server.add_category("sandwich", Access{AuthLevel::none, true}) + .add_request_command("make", [&](Message& m) { m.send_reply("okay"); }); + server.start(); + + LokiMQ client{ + "", "", false, + [&](auto remote_pk) { if (remote_pk == pubkey) return listen; return ""s; }, + get_logger("B» "), LogLevel::trace}; + client.start(); + + std::atomic got{false}; + bool success; + client.request(pubkey, "public.hello", [&](auto success_, auto) { success = success_; got = true; }); + wait_for_conn(got); + { + auto lock = catch_lock(); + REQUIRE( got ); + REQUIRE( success ); + } + + got = false; + using dvec = std::vector; + dvec data; + client.request(pubkey, "sandwich.make", [&](auto success_, auto data_) { + success = success_; + data = std::move(data_); + got = true; + }); + reply_sleep(); + { + auto lock = catch_lock(); + REQUIRE( got ); + REQUIRE_FALSE( success ); + REQUIRE( data == dvec{{"FORBIDDEN_SN"}} ); + } + + // Somebody set up us the bomb. Main sudo turn on. 
+ got = false; + client.request(pubkey, "public.sudo", [&](auto success_, auto data_) { success = success_; data = data_; got = true; }); + reply_sleep(); + { + auto lock = catch_lock(); + REQUIRE( got ); + REQUIRE( success ); + REQUIRE( data == dvec{{"making sandwiches"}} ); + } + + got = false; + client.request(pubkey, "sandwich.make", [&](auto success_, auto data_) { + success = success_; + data = std::move(data_); + got = true; + }); + reply_sleep(); + { + auto lock = catch_lock(); + REQUIRE( got ); + REQUIRE( success ); + REQUIRE( data == dvec{{"okay"}} ); + } + + // Take off every 'SUDO', You [not] know what you doing + got = false; + client.request(pubkey, "public.nosudo", [&](auto success_, auto data_) { success = success_; data = data_; got = true; }); + reply_sleep(); + { + auto lock = catch_lock(); + REQUIRE( got ); + REQUIRE( success ); + REQUIRE( data == dvec{{"make them yourself"}} ); + } + + got = false; + client.request(pubkey, "sandwich.make", [&](auto success_, auto data_) { + success = success_; + data = std::move(data_); + got = true; + }); + reply_sleep(); + { + auto lock = catch_lock(); + REQUIRE( got ); + REQUIRE_FALSE( success ); + REQUIRE( data == dvec{{"FORBIDDEN_SN"}} ); + } +} diff --git a/tests/test_failures.cpp b/tests/test_failures.cpp new file mode 100644 index 0000000..9aeb7c4 --- /dev/null +++ b/tests/test_failures.cpp @@ -0,0 +1,309 @@ +#include "common.h" +#include +#include +#include + +using namespace lokimq; + +TEST_CASE("failure responses - UNKNOWNCOMMAND", "[failure][UNKNOWNCOMMAND]") { + std::string listen = "tcp://127.0.0.1:4567"; + LokiMQ server{ + "", "", // generate ephemeral keys + false, // not a service node + [](auto) { return ""; }, + get_logger("S» "), + LogLevel::trace + }; + server.listen_plain(listen); + server.start(); + + // Use a raw socket here because I want to see the raw commands coming on the wire + zmq::context_t client_ctx; + zmq::socket_t client{client_ctx, zmq::socket_type::dealer}; + 
client.connect(listen); + // Handshake: we send HI, they reply HELLO. + client.send(zmq::message_t{"HI", 2}, zmq::send_flags::none); + { + zmq::message_t hello; + client.recv(hello); + + auto lock = catch_lock(); + REQUIRE( hello.to_string() == "HELLO" ); + REQUIRE_FALSE( hello.more() ); + } + + client.send(zmq::message_t{"a.a", 3}, zmq::send_flags::none); + zmq::message_t resp; + client.recv(resp); + + auto lock = catch_lock(); + REQUIRE( resp.to_string() == "UNKNOWNCOMMAND" ); + REQUIRE( resp.more() ); + client.recv(resp); + REQUIRE( resp.to_string() == "a.a" ); + REQUIRE_FALSE( resp.more() ); +} + +TEST_CASE("failure responses - NO_REPLY_TAG", "[failure][NO_REPLY_TAG]") { + std::string listen = "tcp://127.0.0.1:4567"; + LokiMQ server{ + "", "", // generate ephemeral keys + false, // not a service node + [](auto) { return ""; }, + get_logger("S» "), + LogLevel::trace + }; + server.listen_plain(listen); + server.add_category("x", AuthLevel::none) + .add_request_command("r", [] (auto& m) { m.send_reply("a"); }); + server.start(); + + // Use a raw socket here because I want to see the raw commands coming on the wire + zmq::context_t client_ctx; + zmq::socket_t client{client_ctx, zmq::socket_type::dealer}; + client.connect(listen); + // Handshake: we send HI, they reply HELLO. 
+ client.send(zmq::message_t{"HI", 2}, zmq::send_flags::none); + { + zmq::message_t hello; + client.recv(hello); + + auto lock = catch_lock(); + REQUIRE( hello.to_string() == "HELLO" ); + REQUIRE_FALSE( hello.more() ); + } + + client.send(zmq::message_t{"x.r", 3}, zmq::send_flags::none); + zmq::message_t resp; + client.recv(resp); + + { + auto lock = catch_lock(); + REQUIRE( resp.to_string() == "NO_REPLY_TAG" ); + REQUIRE( resp.more() ); + client.recv(resp); + REQUIRE( resp.to_string() == "x.r" ); + REQUIRE_FALSE( resp.more() ); + } + + client.send(zmq::message_t{"x.r", 3}, zmq::send_flags::sndmore); + client.send(zmq::message_t{"foo", 3}, zmq::send_flags::none); + client.recv(resp); + { + auto lock = catch_lock(); + REQUIRE( resp.to_string() == "REPLY" ); + REQUIRE( resp.more() ); + client.recv(resp); + REQUIRE( resp.to_string() == "foo" ); + REQUIRE( resp.more() ); + client.recv(resp); + REQUIRE( resp.to_string() == "a" ); + REQUIRE_FALSE( resp.more() ); + } +} + +TEST_CASE("failure responses - FORBIDDEN", "[failure][FORBIDDEN]") { + std::string listen = "tcp://127.0.0.1:4567"; + LokiMQ server{ + "", "", // generate ephemeral keys + false, // not a service node + [](auto) { return ""; }, + get_logger("S» "), + LogLevel::trace + }; + server.listen_plain(listen, [](auto, auto, auto) { + static int count = 0; + ++count; + return count == 1 ? AuthLevel::none : count == 2 ? AuthLevel::basic : AuthLevel::admin; + }); + server.add_category("x", AuthLevel::basic) + .add_command("x", [] (auto& m) { m.send_back("a"); }); + server.add_category("y", AuthLevel::admin) + .add_command("x", [] (auto& m) { m.send_back("b"); }); + server.start(); + + zmq::context_t client_ctx; + std::array clients; + // Client 0 should get none auth level, client 1 should get basic, client 2 should get admin + for (auto& client : clients) { + client = {client_ctx, zmq::socket_type::dealer}; + client.connect(listen); + // Handshake: we send HI, they reply HELLO. 
+ client.send(zmq::message_t{"HI", 2}, zmq::send_flags::none); + { + zmq::message_t hello; + client.recv(hello); + + auto lock = catch_lock(); + REQUIRE( hello.to_string() == "HELLO" ); + REQUIRE_FALSE( hello.more() ); + } + } + + for (auto& c : clients) + c.send(zmq::message_t{"x.x", 3}, zmq::send_flags::none); + + zmq::message_t resp; + clients[0].recv(resp); + { + auto lock = catch_lock(); + REQUIRE( resp.to_string() == "FORBIDDEN" ); + REQUIRE( resp.more() ); + clients[0].recv(resp); + REQUIRE( resp.to_string() == "x.x" ); + REQUIRE_FALSE( resp.more() ); + } + for (int i : {1, 2}) { + clients[i].recv(resp); + auto lock = catch_lock(); + REQUIRE( resp.to_string() == "a" ); + REQUIRE_FALSE( resp.more() ); + } + + for (auto& c : clients) + c.send(zmq::message_t{"y.x", 3}, zmq::send_flags::none); + + for (int i : {0, 1}) { + clients[i].recv(resp); + auto lock = catch_lock(); + REQUIRE( resp.to_string() == "FORBIDDEN" ); + REQUIRE( resp.more() ); + clients[i].recv(resp); + REQUIRE( resp.to_string() == "y.x" ); + REQUIRE_FALSE( resp.more() ); + } + clients[2].recv(resp); + { + auto lock = catch_lock(); + REQUIRE( resp.to_string() == "b" ); + REQUIRE_FALSE( resp.more() ); + } +} + +TEST_CASE("failure responses - NOT_A_SERVICE_NODE", "[failure][NOT_A_SERVICE_NODE]") { + std::string listen = "tcp://127.0.0.1:4567"; + LokiMQ server{ + "", "", // generate ephemeral keys + false, // not a service node + [](auto) { return ""; }, + get_logger("S» "), + LogLevel::trace + }; + server.listen_plain(listen, [](auto, auto, auto) { + static int count = 0; + ++count; + return count == 1 ? AuthLevel::none : count == 2 ? 
AuthLevel::basic : AuthLevel::admin; + }); + server.add_category("x", Access{AuthLevel::none, false, true}) + .add_command("x", [] (auto&) {}) + .add_request_command("r", [] (auto& m) { m.send_reply(); }) + ; + server.start(); + + zmq::context_t client_ctx; + zmq::socket_t client{client_ctx, zmq::socket_type::dealer}; + client.connect(listen); + // Handshake: we send HI, they reply HELLO. + client.send(zmq::message_t{"HI", 2}, zmq::send_flags::none); + { + zmq::message_t hello; + client.recv(hello); + + auto lock = catch_lock(); + REQUIRE( hello.to_string() == "HELLO" ); + REQUIRE_FALSE( hello.more() ); + } + + client.send(zmq::message_t{"x.x", 3}, zmq::send_flags::none); + + zmq::message_t resp; + client.recv(resp); + { + auto lock = catch_lock(); + REQUIRE( resp.to_string() == "NOT_A_SERVICE_NODE" ); + REQUIRE( resp.more() ); + client.recv(resp); + REQUIRE( resp.to_string() == "x.x" ); + REQUIRE_FALSE( resp.more() ); + } + + client.send(zmq::message_t{"x.r", 3}, zmq::send_flags::sndmore); + client.send(zmq::message_t{"xyz123", 6}, zmq::send_flags::none); // reply tag + + client.recv(resp); + { + auto lock = catch_lock(); + REQUIRE( resp.to_string() == "NOT_A_SERVICE_NODE" ); + REQUIRE( resp.more() ); + client.recv(resp); + REQUIRE( resp.to_string() == "REPLY" ); + REQUIRE( resp.more() ); + client.recv(resp); + REQUIRE( resp.to_string() == "xyz123" ); + REQUIRE_FALSE( resp.more() ); + } +} + +TEST_CASE("failure responses - FORBIDDEN_SN", "[failure][FORBIDDEN_SN]") { + std::string listen = "tcp://127.0.0.1:4567"; + LokiMQ server{ + "", "", // generate ephemeral keys + false, // not a service node + [](auto) { return ""; }, + get_logger("S» "), + LogLevel::trace + }; + server.listen_plain(listen, [](auto, auto, auto) { + static int count = 0; + ++count; + return count == 1 ? AuthLevel::none : count == 2 ? 
AuthLevel::basic : AuthLevel::admin; + }); + server.add_category("x", Access{AuthLevel::none, true, false}) + .add_command("x", [] (auto&) {}) + .add_request_command("r", [] (auto& m) { m.send_reply(); }) + ; + server.start(); + + zmq::context_t client_ctx; + zmq::socket_t client{client_ctx, zmq::socket_type::dealer}; + client.connect(listen); + // Handshake: we send HI, they reply HELLO. + client.send(zmq::message_t{"HI", 2}, zmq::send_flags::none); + { + zmq::message_t hello; + client.recv(hello); + + auto lock = catch_lock(); + REQUIRE( hello.to_string() == "HELLO" ); + REQUIRE_FALSE( hello.more() ); + } + + client.send(zmq::message_t{"x.x", 3}, zmq::send_flags::none); + + zmq::message_t resp; + client.recv(resp); + { + auto lock = catch_lock(); + REQUIRE( resp.to_string() == "FORBIDDEN_SN" ); + REQUIRE( resp.more() ); + client.recv(resp); + REQUIRE( resp.to_string() == "x.x" ); + REQUIRE_FALSE( resp.more() ); + } + + client.send(zmq::message_t{"x.r", 3}, zmq::send_flags::sndmore); + client.send(zmq::message_t{"xyz123", 6}, zmq::send_flags::none); // reply tag + + client.recv(resp); + { + auto lock = catch_lock(); + REQUIRE( resp.to_string() == "FORBIDDEN_SN" ); + REQUIRE( resp.more() ); + client.recv(resp); + REQUIRE( resp.to_string() == "REPLY" ); + REQUIRE( resp.more() ); + client.recv(resp); + REQUIRE( resp.to_string() == "xyz123" ); + REQUIRE_FALSE( resp.more() ); + } +} diff --git a/tests/test_requests.cpp b/tests/test_requests.cpp index 53a4c68..46b2dae 100644 --- a/tests/test_requests.cpp +++ b/tests/test_requests.cpp @@ -1,5 +1,4 @@ #include "common.h" -#include #include using namespace lokimq; @@ -36,17 +35,11 @@ TEST_CASE("basic requests", "[requests]") { [&](auto, auto) { failed = true; }, server.get_pubkey()); - int i; - for (i = 0; i < 5; i++) { - if (connected.load()) - break; - std::this_thread::sleep_for(50ms); - } + wait_for([&] { return connected || failed; }); { auto lock = catch_lock(); - REQUIRE( connected.load() ); - REQUIRE( 
!failed.load() ); - REQUIRE( i <= 1 ); + REQUIRE( connected ); + REQUIRE_FALSE( failed ); REQUIRE( to_hex(pubkey) == to_hex(server.get_pubkey()) ); } @@ -59,11 +52,13 @@ TEST_CASE("basic requests", "[requests]") { data = std::move(data_); }); - std::this_thread::sleep_for(50ms); - auto lock = catch_lock(); - REQUIRE( got_reply.load() ); - REQUIRE( success ); - REQUIRE( data == std::vector{{"123"}} ); + reply_sleep(); + { + auto lock = catch_lock(); + REQUIRE( got_reply.load() ); + REQUIRE( success ); + REQUIRE( data == std::vector{{"123"}} ); + } } TEST_CASE("request from server to client", "[requests]") { @@ -161,15 +156,10 @@ TEST_CASE("request timeouts", "[requests][timeout]") { [&](auto, auto) { failed = true; }, server.get_pubkey()); - int i; - for (i = 0; i < 5; i++) { - if (connected.load()) - break; - std::this_thread::sleep_for(50ms); - } - REQUIRE( connected.load() ); - REQUIRE( !failed.load() ); - REQUIRE( i <= 1 ); + wait_for([&] { return connected || failed; }); + + REQUIRE( connected ); + REQUIRE_FALSE( failed ); REQUIRE( to_hex(pubkey) == to_hex(server.get_pubkey()) ); std::atomic got_triggered{false}; @@ -180,7 +170,7 @@ TEST_CASE("request timeouts", "[requests][timeout]") { success = ok; data = std::move(data_); }, - lokimq::send_option::request_timeout{30ms} + lokimq::send_option::request_timeout{20ms} ); std::atomic got_triggered2{false}; @@ -192,10 +182,10 @@ TEST_CASE("request timeouts", "[requests][timeout]") { lokimq::send_option::request_timeout{100ms} ); - std::this_thread::sleep_for(50ms); - REQUIRE( got_triggered.load() ); + std::this_thread::sleep_for(30ms); + REQUIRE( got_triggered ); + REQUIRE_FALSE( got_triggered2 ); REQUIRE_FALSE( success ); - REQUIRE( data.size() == 0 ); + REQUIRE( data == std::vector{{"TIMEOUT"}} ); - REQUIRE_FALSE( got_triggered2.load() ); }