diff --git a/lokimq/auth.cpp b/lokimq/auth.cpp index a7c58b3..068da56 100644 --- a/lokimq/auth.cpp +++ b/lokimq/auth.cpp @@ -61,9 +61,9 @@ bool LokiMQ::proxy_check_auth(size_t conn_index, bool outgoing, const peer_info& try { if (outgoing) - send_direct_message(connections[conn_index], std::move(reply), command, zmq::send_flags::dontwait); + send_direct_message(connections[conn_index], std::move(reply), command); else - send_routed_message(connections[conn_index], peer.route, std::move(reply), command, zmq::send_flags::dontwait); + send_routed_message(connections[conn_index], peer.route, std::move(reply), command); } catch (const zmq::error_t&) { /* can't send: possibly already disconnected. Ignore. */ } return false; diff --git a/lokimq/connections.h b/lokimq/connections.h index dafc1a9..d1373e6 100644 --- a/lokimq/connections.h +++ b/lokimq/connections.h @@ -8,7 +8,7 @@ struct ConnectionID; namespace detail { template -bt_dict build_send(ConnectionID to, string_view cmd, const T&... opts); +bt_dict build_send(ConnectionID to, string_view cmd, T&&... opts); } /// Opaque data structure representing a connection which supports ==, !=, < and std::hash. For @@ -67,7 +67,7 @@ private: friend class LokiMQ; friend struct std::hash; template - friend bt_dict detail::build_send(ConnectionID to, string_view cmd, const T&... opts); + friend bt_dict detail::build_send(ConnectionID to, string_view cmd, T&&... 
opts); friend std::ostream& operator<<(std::ostream& o, const ConnectionID& conn); }; diff --git a/lokimq/jobs.cpp b/lokimq/jobs.cpp index f15345a..981db43 100644 --- a/lokimq/jobs.cpp +++ b/lokimq/jobs.cpp @@ -94,9 +94,8 @@ void LokiMQ::_queue_timer_job(int timer_id) { void LokiMQ::add_timer(std::function job, std::chrono::milliseconds interval, bool squelch) { if (proxy_thread.joinable()) { - auto *jobptr = new std::function{std::move(job)}; detail::send_control(get_control_socket(), "TIMER", bt_serialize(bt_list{{ - reinterpret_cast(jobptr), + detail::serialize_object(std::move(job)), interval.count(), squelch}})); } else { diff --git a/lokimq/lokimq-internal.h b/lokimq/lokimq-internal.h index b5c190c..28a598c 100644 --- a/lokimq/lokimq-internal.h +++ b/lokimq/lokimq-internal.h @@ -38,41 +38,41 @@ inline zmq::message_t create_message(string_view data) { } template -void send_message_parts(zmq::socket_t &sock, It begin, It end, const zmq::send_flags flags = zmq::send_flags::none) { +bool send_message_parts(zmq::socket_t &sock, It begin, It end) { while (begin != end) { zmq::message_t &msg = *begin++; - sock.send(msg, begin == end ? flags : flags | zmq::send_flags::sndmore); + if (!sock.send(msg, begin == end ? zmq::send_flags::dontwait : zmq::send_flags::dontwait | zmq::send_flags::sndmore)) + return false; } + return true; } template -void send_message_parts(zmq::socket_t &sock, Container &&c, zmq::send_flags flags = zmq::send_flags::none) { - send_message_parts(sock, c.begin(), c.end(), flags); +bool send_message_parts(zmq::socket_t &sock, Container &&c) { + return send_message_parts(sock, c.begin(), c.end()); } /// Sends a message with an initial route. `msg` and `data` can be empty: if `msg` is empty then /// the msg frame will be an empty message; if `data` is empty then the data frame will be omitted. /// `flags` is passed through to zmq: typically given `zmq::send_flags::dontwait` to throw rather /// than block if a message can't be queued. 
-inline void send_routed_message(zmq::socket_t &socket, std::string route, std::string msg = {}, std::string data = {}, - zmq::send_flags flags = zmq::send_flags::none) { +inline bool send_routed_message(zmq::socket_t &socket, std::string route, std::string msg = {}, std::string data = {}) { assert(!route.empty()); std::array msgs{{create_message(std::move(route))}}; if (!msg.empty()) msgs[1] = create_message(std::move(msg)); if (!data.empty()) msgs[2] = create_message(std::move(data)); - send_message_parts(socket, msgs.begin(), data.empty() ? std::prev(msgs.end()) : msgs.end(), flags); + return send_message_parts(socket, msgs.begin(), data.empty() ? std::prev(msgs.end()) : msgs.end()); } // Sends some stuff to a socket directly. If dontwait is true then we throw instead of blocking if // the message cannot be accepted by zmq (i.e. because the outgoing buffer is full). -inline void send_direct_message(zmq::socket_t &socket, std::string msg, std::string data = {}, - zmq::send_flags flags = zmq::send_flags::none) { +inline bool send_direct_message(zmq::socket_t &socket, std::string msg, std::string data = {}) { std::array msgs{{create_message(std::move(msg))}}; if (!data.empty()) msgs[1] = create_message(std::move(data)); - send_message_parts(socket, msgs.begin(), data.empty() ? std::prev(msgs.end()) : msgs.end(), flags); + return send_message_parts(socket, msgs.begin(), data.empty() ? std::prev(msgs.end()) : msgs.end()); } // Receive all the parts of a single message from the given socket. Returns true if a message was diff --git a/lokimq/lokimq.h b/lokimq/lokimq.h index a2ab17a..d35eca1 100644 --- a/lokimq/lokimq.h +++ b/lokimq/lokimq.h @@ -419,9 +419,10 @@ private: // either accepting curve connections, or not accepting curve). void setup_outgoing_socket(zmq::socket_t& socket, string_view remote_pubkey = {}); - /// Common connection implementation used by proxy_connect/proxy_send. 
Returns the socket - /// and, if a routing prefix is needed, the required prefix (or an empty string if not needed). - /// For an optional connect that fail, returns nullptr for the socket. + /// Common connection implementation used by proxy_connect/proxy_send. Returns the socket and, + /// if a routing prefix is needed, the required prefix (or an empty string if not needed). For + /// an optional connect that fails (or some other connection failure), returns nullptr for the + /// socket. /// /// @param pubkey the pubkey to connect to /// @param connect_hint if we need a new connection and this is non-empty then we *may* use it @@ -1060,10 +1061,60 @@ struct request_timeout { explicit request_timeout(std::chrono::milliseconds time) : time{std::move(time)} {} }; +/// Specifies a callback to invoke if the message couldn't be queued for delivery. There are +/// generally two failure modes here: a full queue, and a send exception. This callback is invoked +/// for both; to only catch full queues see `queue_full` instead. +/// +/// A full queue means there are too many messages queued for delivery already that haven't been +/// delivered yet (i.e. because the remote is slow); this error is potentially recoverable if the +/// remote end wakes up and receives/acknowledges its messages. +/// +/// A send exception is not recoverable: it indicates some failure such as the remote having +/// disconnected or an internal send error. +/// +/// This callback can be used by a caller to log, attempt to resend, or take other appropriate +/// action. +/// +/// Note that this callback is *not* exhaustive for all possible send failures: there are failure +/// cases (such as when a message is queued but the connection fails before delivery) that do not +/// trigger this failure at all; rather this callback only signals an immediate queuing failure. 
+struct queue_failure { + using callback_t = std::function; + /// Callback; invoked with nullptr for a queue full failure, otherwise will be set to a copy of + /// the raised exception. + callback_t callback; +}; + +/// This is similar to queue_failure, but is only invoked on a (potentially recoverable) +/// full queue failure. Send failures are simply dropped. +struct queue_full { + using callback_t = std::function; + callback_t callback; +}; + } namespace detail { +/// Takes an rvalue reference, moves it into a new instance then returns a uintptr_t value +/// containing the pointer to be serialized to pass (via lokimq queues) from one thread to another. +/// Must be matched with a deserialize_object on the other side to reconstitute the object and +/// destroy the intermediate pointer. +template uintptr_t serialize_object(T&& obj) { + auto* ptr = new T{std::forward(obj)}; + return reinterpret_cast(ptr); +} + +/// Takes a uintptr_t as produced by serialize_object and the type, converts the serialized value +/// back into a pointer, moves it into a new instance (to be returned) and destroys the +/// intermediate. +template T deserialize_object(uintptr_t ptrval) { + auto* ptr = reinterpret_cast(ptrval); + T ret{std::move(*ptr)}; + delete ptr; + return ret; +} + // Sends a control message to the given socket consisting of the command plus optional dict // data (only sent if the data is non-empty). 
void send_control(zmq::socket_t& sock, string_view cmd, std::string data = {}); @@ -1105,20 +1156,28 @@ inline void apply_send_option(bt_list&, bt_dict& control_data, const send_option control_data["request_timeout"] = timeout.time.count(); } +/// `queue_failure` specialization +inline void apply_send_option(bt_list&, bt_dict& control_data, send_option::queue_failure f) { + control_data["send_fail"] = serialize_object(std::move(f.callback)); +} +/// `queue_full` specialization +inline void apply_send_option(bt_list&, bt_dict& control_data, send_option::queue_full f) { + control_data["send_full_q"] = serialize_object(std::move(f.callback)); +} /// Extracts a pubkey, SN status, and auth level from a zmq message received on a *listening* /// socket. std::tuple extract_metadata(zmq::message_t& msg); template -bt_dict build_send(ConnectionID to, string_view cmd, const T&... opts) { +bt_dict build_send(ConnectionID to, string_view cmd, T&&... opts) { bt_dict control_data; bt_list parts{{cmd}}; #ifdef __cpp_fold_expressions - (detail::apply_send_option(parts, control_data, opts),...); + (detail::apply_send_option(parts, control_data, std::forward(opts)),...); #else - (void) std::initializer_list{(detail::apply_send_option(parts, control_data, opts), 0)...}; + (void) std::initializer_list{(detail::apply_send_option(parts, control_data, std::forward(opts)), 0)...}; #endif if (to.sn()) @@ -1148,7 +1207,7 @@ void LokiMQ::request(ConnectionID to, string_view cmd, ReplyCallback callback, c const auto reply_tag = make_random_string(15); // 15 random bytes is lots and should keep us in most stl implementations' small string optimization bt_dict control_data = detail::build_send(std::move(to), cmd, reply_tag, opts...); control_data["request"] = true; - control_data["request_callback"] = reinterpret_cast(new ReplyCallback{std::move(callback)}); + control_data["request_callback"] = detail::serialize_object(std::move(callback)); control_data["request_tag"] = string_view{reply_tag}; 
detail::send_control(get_control_socket(), "SEND", bt_serialize(std::move(control_data))); } diff --git a/lokimq/proxy.cpp b/lokimq/proxy.cpp index aba52d7..9893a45 100644 --- a/lokimq/proxy.cpp +++ b/lokimq/proxy.cpp @@ -70,13 +70,7 @@ void LokiMQ::proxy_send(bt_dict_consumer data) { if (!data.skip_until("request_callback")) throw std::runtime_error("Internal error: received request without request_callback"); - // The initiator gives up ownership of the callback to us (serializing it through a - // uintptr_t), so we take the pointer, move the value out of it, then destroy the pointer we - // were given. Further down, if we are able to send the request successfully, we set up the - // pending request. - auto* cbptr = reinterpret_cast(data.consume_integer()); - request_callback = std::move(*cbptr); - delete cbptr; + request_callback = detail::deserialize_object(data.consume_integer()); if (!data.skip_until("request_tag")) throw std::runtime_error("Internal error: received request without request_name"); @@ -88,11 +82,20 @@ void LokiMQ::proxy_send(bt_dict_consumer data) { throw std::runtime_error("Internal error: Invalid proxy send command; send parts missing"); bt_list_consumer send = data.consume_list_consumer(); + send_option::queue_failure::callback_t callback_nosend; + if (data.skip_until("send_fail")) + callback_nosend = detail::deserialize_object(data.consume_integer()); + + send_option::queue_full::callback_t callback_noqueue; + if (data.skip_until("send_full_q")) + callback_noqueue = detail::deserialize_object(data.consume_integer()); + // Now figure out which socket to send to and do the actual sending. We can repeat this loop // multiple times, if we're sending to a SN, because it's possible that we have multiple // connections open to that SN (e.g. one out + one in) so if one fails we can clean up that // connection and try the next one. 
- bool retry = true, sent = false; + bool retry = true, sent = false, warned = false; + std::unique_ptr send_error; while (retry) { retry = false; zmq::socket_t *send_to; @@ -103,7 +106,7 @@ void LokiMQ::proxy_send(bt_dict_consumer data) { LMQ_LOG(debug, "Not sending: send is optional and no connection to ", to_hex(conn_id.pk), " is currently established"); else - LMQ_LOG(error, "Unable to send to ", to_hex(conn_id.pk), ": no connection address found"); + LMQ_LOG(error, "Unable to send to ", to_hex(conn_id.pk), ": no valid connection address found"); break; } send_to = sock_route.first; @@ -126,8 +129,7 @@ void LokiMQ::proxy_send(bt_dict_consumer data) { } try { - send_message_parts(*send_to, build_send_parts(send, conn_id.route)); - sent = true; + sent = send_message_parts(*send_to, build_send_parts(send, conn_id.route)); } catch (const zmq::error_t &e) { if (e.num() == EHOSTUNREACH && !conn_id.route.empty() /*= incoming conn*/) { @@ -158,6 +160,11 @@ void LokiMQ::proxy_send(bt_dict_consumer data) { } if (!retry) { LMQ_LOG(warn, "Unable to send message to ", conn_id, ": ", e.what()); + warned = true; + if (callback_nosend) { + job([callback = std::move(callback_nosend), error = e] { callback(&error); }); + callback_nosend = nullptr; + } } } } @@ -171,6 +178,14 @@ void LokiMQ::proxy_send(bt_dict_consumer data) { job([callback = std::move(request_callback)] { callback(false, {}); }); } } + if (!sent) { + if (callback_nosend) + job([callback = std::move(callback_nosend)] { callback(nullptr); }); + else if (callback_noqueue) + job(std::move(callback_noqueue)); + else if (!warned) + LMQ_LOG(warn, "Unable to send message to ", conn_id, ": sending would block"); + } } void LokiMQ::proxy_reply(bt_dict_consumer data) { @@ -456,7 +471,7 @@ bool LokiMQ::proxy_handle_builtin(size_t conn_index, std::vector route = view(parts[0]); cmd = view(parts[1]); } - LMQ_TRACE("Checking for builtins: ", cmd, " from ", peer_address(parts.back())); + LMQ_TRACE("Checking for builtins: '", 
cmd, "' from ", peer_address(parts.back())); if (cmd == "REPLY") { size_t tag_pos = (outgoing ? 1 : 2); diff --git a/lokimq/worker.cpp b/lokimq/worker.cpp index 916ca77..b6e3b15 100644 --- a/lokimq/worker.cpp +++ b/lokimq/worker.cpp @@ -234,9 +234,9 @@ void LokiMQ::proxy_to_worker(size_t conn_index, std::vector& par LMQ_LOG(warn, "Invalid command '", command, "' sent by remote [", to_hex(peer->pubkey), "]/", peer_address(parts.back())); try { if (outgoing) - send_direct_message(connections[conn_index], "UNKNOWNCOMMAND", command, zmq::send_flags::dontwait); + send_direct_message(connections[conn_index], "UNKNOWNCOMMAND"); else - send_routed_message(connections[conn_index], peer->route, "UNKNOWNCOMMAND", command, zmq::send_flags::dontwait); + send_routed_message(connections[conn_index], peer->route, "UNKNOWNCOMMAND"); } catch (const zmq::error_t&) { /* can't send: possibly already disconnected. Ignore. */ } return; } diff --git a/tests/test_commands.cpp b/tests/test_commands.cpp index 9483a47..44527eb 100644 --- a/tests/test_commands.cpp +++ b/tests/test_commands.cpp @@ -279,3 +279,96 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]") REQUIRE( google_knows == personal_details ); } } + +TEST_CASE("send failure callbacks", "[commands][queue_full]") { + std::string listen = "tcp://127.0.0.1:4567"; + LokiMQ server{ + "", "", // generate ephemeral keys + false, // not a service node + [](auto) { return ""; }, + get_logger("S» ") + }; + server.log_level(LogLevel::debug); + server.listen_plain(listen, [](auto /*ip*/, auto /*pk*/) { return Allow{AuthLevel::none, false}; }); + + std::atomic send_attempts{0}; + std::atomic send_failures{0}; + // ZMQ TCP sockets' HWM is complicated and OS dependent; sender and receiver (probably) each + // have 1000 message queues, but there is also the TCP queue to worry about which means we can + // have more queued before we fill up, so we send 4kiB of null with each message so that we + // don't get too much 
TCP queuing. + std::string junk(4096, '0'); + server.add_category("x", Access{AuthLevel::none}) + .add_command("x", [&](Message& m) { + for (int x = 0; x < 500; x++) { + ++send_attempts; + m.send_back("y.y", junk, send_option::queue_full{[&]() { ++send_failures; }}); + } + }); + + server.start(); + + // Use a raw socket here because I want to stall it by not reading from it at all, and that is + // hard with LokiMQ. + zmq::context_t client_ctx; + zmq::socket_t client{client_ctx, zmq::socket_type::dealer}; + client.connect(listen); + // Handshake: we send HI, they reply HELLO. + client.send(zmq::message_t{"HI", 2}, zmq::send_flags::none); + zmq::message_t hello; + client.recv(hello); + string_view hello_sv{hello.data(), hello.size()}; + { + auto lock = catch_lock(); + REQUIRE( hello_sv == "HELLO" ); + REQUIRE_FALSE( hello.more() ); + } + + // Tell the remote to queue up a batch of messages + client.send(zmq::message_t{"x.x", 3}, zmq::send_flags::none); + + int i; + for (i = 0; i < 20; i++) { + if (send_attempts.load() >= 500) + break; + std::this_thread::sleep_for(25ms); + } + { + auto lock = catch_lock(); + REQUIRE( i <= 4 ); // should be not too slow + // We have two buffers here: 1000 on the receiver, and 1000 on the client, which means we + // should be able to get 2000 out before we hit HWM. We should only have been sent 501 so + // far (the "HELLO" handshake + 500 "y.y" messages). + REQUIRE( send_attempts.load() == 500 ); + REQUIRE( send_failures.load() == 0 ); + } + + // Now we want to tell the server to send enough to fill the outgoing queue and start stalling. + // This is complicated as it depends on ZMQ internals *and* OS-level TCP buffers, so we really + // don't know precisely where this will start failing. + // + // In practice, I seem to reach HWM (for this test, with this amount of data being sent, on my + // Debian desktop) after 2499 messages (that is, queuing 2500 gives 1 failure). 
+ int expected_attempts = 500; + for (int i = 0; i < 10; i++) { + client.send(zmq::message_t{"x.x", 3}, zmq::send_flags::none); + expected_attempts += 500; + if (i >= 4) { + std::this_thread::sleep_for(25ms); + if (send_failures.load() > 0) + break; + } + } + + for (i = 0; i < 10; i++) { + if (send_attempts.load() >= expected_attempts) + break; + std::this_thread::sleep_for(25ms); + } + { + auto lock = catch_lock(); + REQUIRE( i <= 8 ); + REQUIRE( send_attempts.load() == expected_attempts ); + REQUIRE( send_failures.load() > 0 ); + } +}