From 5dd7c122196df98eb5284d3f7befc284c6068fdb Mon Sep 17 00:00:00 2001 From: Jason Rhinelander Date: Wed, 23 Jun 2021 10:51:08 -0300 Subject: [PATCH 1/3] Add support for listening after startup This commit adds support for listening on new ports after startup. This will make things easier in storage server, in particular, where we want to delay listening on public ports until we have an established connection and initial block status update from oxend. --- oxenmq/auth.cpp | 12 ++++---- oxenmq/connections.cpp | 17 +++++++++- oxenmq/oxenmq.cpp | 32 +++++++++---------- oxenmq/oxenmq.h | 46 ++++++++++++++++++++------- oxenmq/proxy.cpp | 70 ++++++++++++++++++++++++++---------------- tests/test_connect.cpp | 58 ++++++++++++++++++++++++++++++++++ 6 files changed, 173 insertions(+), 62 deletions(-) diff --git a/oxenmq/auth.cpp b/oxenmq/auth.cpp index dfe1f6c..2be7f48 100644 --- a/oxenmq/auth.cpp +++ b/oxenmq/auth.cpp @@ -256,14 +256,14 @@ void OxenMQ::process_zap_requests() { LMQ_LOG(error, "Bad ZAP authentication request: invalid auth domain '", auth_domain, "'"); status_code = "400"; status_text = "Unknown authentication domain: " + std::string{auth_domain}; - } else if (bind[bind_id].second.curve + } else if (bind[bind_id].curve ? !(frames.size() == 7 && view(frames[5]) == "CURVE") : !(frames.size() == 6 && view(frames[5]) == "NULL")) { LMQ_LOG(error, "Bad ZAP authentication request: invalid ", - bind[bind_id].second.curve ? "CURVE" : "NULL", " authentication request"); + bind[bind_id].curve ? 
"CURVE" : "NULL", " authentication request"); status_code = "500"; status_text = "Invalid authentication request mechanism"; - } else if (bind[bind_id].second.curve && frames[6].size() != 32) { + } else if (bind[bind_id].curve && frames[6].size() != 32) { LMQ_LOG(error, "Bad ZAP authentication request: invalid request pubkey"); status_code = "500"; status_text = "Invalid public key size for CURVE authentication"; @@ -271,13 +271,13 @@ void OxenMQ::process_zap_requests() { auto ip = view(frames[3]); std::string_view pubkey; bool sn = false; - if (bind[bind_id].second.curve) { + if (bind[bind_id].curve) { pubkey = view(frames[6]); sn = active_service_nodes.count(std::string{pubkey}); } - auto auth = bind[bind_id].second.allow(ip, pubkey, sn); + auto auth = bind[bind_id].allow(ip, pubkey, sn); auto& user_id = response_vals[4]; - if (bind[bind_id].second.curve) { + if (bind[bind_id].curve) { user_id.reserve(64); to_hex(pubkey.begin(), pubkey.end(), std::back_inserter(user_id)); } diff --git a/oxenmq/connections.cpp b/oxenmq/connections.cpp index e87d71c..a84b665 100644 --- a/oxenmq/connections.cpp +++ b/oxenmq/connections.cpp @@ -67,6 +67,21 @@ void OxenMQ::setup_outgoing_socket(zmq::socket_t& socket, std::string_view remot // else let ZMQ pick a random one } + +void OxenMQ::setup_incoming_socket(zmq::socket_t& listener, bool curve, std::string_view pubkey, std::string_view privkey, size_t bind_index) { + + setup_external_socket(listener); + + listener.set(zmq::sockopt::zap_domain, bt_serialize(bind_index)); + if (curve) { + listener.set(zmq::sockopt::curve_server, true); + listener.set(zmq::sockopt::curve_publickey, pubkey); + listener.set(zmq::sockopt::curve_secretkey, privkey); + } + listener.set(zmq::sockopt::router_handover, true); + listener.set(zmq::sockopt::router_mandatory, true); +} + // Deprecated versions: ConnectionID OxenMQ::connect_remote(std::string_view remote, ConnectSuccess on_connect, ConnectFailure on_failure, AuthLevel auth_level, 
std::chrono::milliseconds timeout) { @@ -218,7 +233,7 @@ void OxenMQ::proxy_close_connection(size_t index, std::chrono::milliseconds ling update_connection_indices(pending_connects, index, [](auto& pc) -> size_t& { return std::get(pc); }); update_connection_indices(bind, index, - [](auto& b) -> size_t& { return b.second.index; }); + [](auto& b) -> size_t& { return b.index; }); update_connection_indices(incoming_conn_index, index, [](auto& oci) -> size_t& { return oci.second; }); assert(index < conn_index_to_id.size()); diff --git a/oxenmq/oxenmq.cpp b/oxenmq/oxenmq.cpp index 7747c40..c30095b 100644 --- a/oxenmq/oxenmq.cpp +++ b/oxenmq/oxenmq.cpp @@ -232,12 +232,6 @@ void OxenMQ::start() { if (proxy_thread.joinable()) throw std::logic_error("Cannot call start() multiple times!"); - // If we're not binding to anything then we don't listen, i.e. we can only establish outbound - // connections. Don't allow this if we are in service_node mode because, if we aren't - // listening, we are useless as a service node. - if (bind.empty() && local_service_node) - throw std::invalid_argument{"Cannot create a service node listener with no address(es) to bind"}; - LMQ_LOG(info, "Initializing OxenMQ ", bind.empty() ? "remote-only" : "listener", " with pubkey ", to_hex(pubkey)); int zmq_socket_limit = context.get(zmq::ctxopt::socket_limit); @@ -267,20 +261,22 @@ void OxenMQ::start() { LMQ_LOG(debug, "Proxy thread is ready"); } -void OxenMQ::listen_curve(std::string bind_addr, AllowFunc allow_connection) { - // TODO: there's no particular reason we can't start listening after starting up; just needs to - // be implemented. (But if we can start we'll probably also want to be able to stop, so it's - // more than just binding that needs implementing). 
- check_not_started(proxy_thread, "start listening"); - - bind.emplace_back(std::move(bind_addr), bind_data{true, std::move(allow_connection)}); +void OxenMQ::listen_curve(std::string bind_addr, AllowFunc allow_connection, std::function on_bind) { + if (!allow_connection) allow_connection = [](auto, auto, auto) { return AuthLevel::none; }; + bind_data d{std::move(bind_addr), true, std::move(allow_connection), std::move(on_bind)}; + if (proxy_thread.joinable()) + detail::send_control(get_control_socket(), "BIND", bt_serialize(detail::serialize_object(std::move(d)))); + else + bind.push_back(std::move(d)); } -void OxenMQ::listen_plain(std::string bind_addr, AllowFunc allow_connection) { - // TODO: As above. - check_not_started(proxy_thread, "start listening"); - - bind.emplace_back(std::move(bind_addr), bind_data{false, std::move(allow_connection)}); +void OxenMQ::listen_plain(std::string bind_addr, AllowFunc allow_connection, std::function on_bind) { + if (!allow_connection) allow_connection = [](auto, auto, auto) { return AuthLevel::none; }; + bind_data d{std::move(bind_addr), false, std::move(allow_connection), std::move(on_bind)}; + if (proxy_thread.joinable()) + detail::send_control(get_control_socket(), "BIND", bt_serialize(detail::serialize_object(std::move(d)))); + else + bind.push_back(std::move(d)); } diff --git a/oxenmq/oxenmq.h b/oxenmq/oxenmq.h index d23b7b7..2db5f17 100644 --- a/oxenmq/oxenmq.h +++ b/oxenmq/oxenmq.h @@ -320,15 +320,17 @@ private: zmq::socket_t zap_auth{context, zmq::socket_type::rep}; struct bind_data { + std::string address; bool curve; size_t index; AllowFunc allow; - bind_data(bool curve, AllowFunc allow) - : curve{curve}, index{0}, allow{std::move(allow)} {} + std::function on_bind; + bind_data(std::string addr, bool curve, AllowFunc allow, std::function on_bind) + : address{std::move(addr)}, curve{curve}, index{0}, allow{std::move(allow)}, on_bind{std::move(on_bind)} {} }; /// Addresses on which we are listening (or, before 
start(), on which we will listen). - std::vector> bind; + std::vector bind; /// Info about a peer's established connection with us. Note that "established" means both /// connected and authenticated. Note that we only store peer info data for SN connections (in @@ -507,6 +509,9 @@ private: /// gets called after all works have done so. void proxy_quit(); + /// proxy handler for binding to addresses given via listen_*(). + bool proxy_bind(bind_data& bind, size_t index); + // Common setup code for setting up an external (incoming or outgoing) socket. void setup_external_socket(zmq::socket_t& socket); @@ -516,6 +521,9 @@ private: // either accepting curve connections, or not accepting curve). void setup_outgoing_socket(zmq::socket_t& socket, std::string_view remote_pubkey, bool use_ephemeral_routing_id); + /// Sets the various properties on a listening socket prior to binding. + void setup_incoming_socket(zmq::socket_t& socket, bool curve, std::string_view pubkey, std::string_view privkey, size_t bind_index); + /// Common connection implementation used by proxy_connect/proxy_send. Returns the socket and, /// if a routing prefix is needed, the required prefix (or an empty string if not needed). For /// an optional connect that fails (or some other connection failure), returns nullptr for the @@ -971,14 +979,26 @@ public: * will be encrypted. `allow_connection` is invoked for any incoming connections on this * address to determine the incoming remote's access and authentication level. * + * If called before `start()` then the given bind address is mandatory and start() will throw if + * the bind fails. If called after `start()` then the bind may fail (in which case the callback + * will be used to notify of the failure). + * + * @param bind address - can be any string zmq supports; typically a tcp IP/port combination * such as: "tcp://\*:4567" or "tcp://1.2.3.4:5678".
* * @param allow_connection function to call to determine whether to allow the connection and, if - * so, the authentication level it receives. If omitted the default returns AuthLevel::none - * access. + * so, the authentication level it receives. If omitted (or null) the default returns + * AuthLevel::none access for all connections. + * + * @param on_bind function to call when the port has been successfully opened or failed to + * open. For addresses set up before .start() this will be called during `start()` itself; for + * post-start listens this will be called from the proxy thread when it opens the new port. + * Note that this function is called directly from the proxy thread and so should be fast + * and non-blocking. */ - void listen_curve(std::string bind, AllowFunc allow_connection = [](auto, auto, auto) { return AuthLevel::none; }); + void listen_curve(std::string bind, + AllowFunc allow_connection = nullptr, + std::function on_bind = nullptr); /** Start listening on the given bind address in unauthenticated plain text mode. Incoming * connections can come from anywhere. `allow_connection` is invoked for any incoming @@ -989,10 +1009,14 @@ public: * such as: "tcp://\*:4567" or "tcp://1.2.3.4:5678". * * @param allow_connection function to call to determine whether to allow the connection and, if - * so, the authentication level it receives. If omitted the default returns AuthLevel::none - * access. + * so, the authentication level it receives. If omitted (or null) the default returns + * AuthLevel::none access for all connections. + * + * @param on_bind called after binding with the result; see `listen_curve` for details.
*/ - void listen_plain(std::string bind, AllowFunc allow_connection = [](auto, auto, auto) { return AuthLevel::none; }); + void listen_plain(std::string bind, + AllowFunc allow_connection = nullptr, + std::function on_bind = nullptr); /** * Try to initiate a connection to the given SN in anticipation of needing a connection in the @@ -1446,8 +1470,8 @@ namespace connect_option { /// Typically use: `connect_options::ephemeral_routing_id{}` or `connect_options::ephemeral_routing_id{false}`. struct ephemeral_routing_id { bool use_ephemeral_routing_id = true; - // Constructor; default construction gives you pubkey routing, but the bool parameter can be - // specified as false to explicitly disable the pubkey routing flag. + // Constructor; default construction gives you ephemeral routing id, but the bool parameter can + // be specified as false to use pubkey routing flag. explicit ephemeral_routing_id(bool use = true) : use_ephemeral_routing_id{use} {} }; diff --git a/oxenmq/proxy.cpp b/oxenmq/proxy.cpp index 1778599..6208aab 100644 --- a/oxenmq/proxy.cpp +++ b/oxenmq/proxy.cpp @@ -293,6 +293,11 @@ void OxenMQ::proxy_control_message(std::vector& parts) { return proxy_timer(data); } else if (cmd == "TIMER_DEL") { return proxy_timer_del(bt_deserialize(data)); + } else if (cmd == "BIND") { + auto b = detail::deserialize_object(bt_deserialize(data)); + if (proxy_bind(b, bind.size())) + bind.push_back(std::move(b)); + return; } } else if (parts.size() == 2) { if (cmd == "START") { @@ -317,6 +322,38 @@ void OxenMQ::proxy_control_message(std::vector& parts) { std::string{cmd} + " (" + std::to_string(parts.size()) + ")"); } +bool OxenMQ::proxy_bind(bind_data& b, size_t index) { + zmq::socket_t listener{context, zmq::socket_type::router}; + setup_incoming_socket(listener, b.curve, pubkey, privkey, index); + + bool good = true; + try { + listener.bind(b.address); + } catch (const zmq::error_t&) { + good = false; + } + if (b.on_bind) { + b.on_bind(good); + b.on_bind = nullptr; + 
} + if (!good) { + LMQ_LOG(warn, "OxenMQ failed to listen on ", b.address); + return false; + } + + LMQ_LOG(info, "OxenMQ listening on ", b.address); + + connections.push_back(std::move(listener)); + auto conn_id = next_conn_id++; + conn_index_to_id.push_back(conn_id); + incoming_conn_index[conn_id] = connections.size() - 1; + b.index = connections.size() - 1; + + pollitems_stale = true; + + return true; +} + void OxenMQ::proxy_loop() { #if defined(__linux__) || defined(__sun) || defined(__MINGW32__) @@ -364,27 +401,10 @@ void OxenMQ::proxy_loop() { #endif for (size_t i = 0; i < bind.size(); i++) { - auto& b = bind[i].second; - zmq::socket_t listener{context, zmq::socket_type::router}; - - setup_external_socket(listener); - listener.set(zmq::sockopt::zap_domain, bt_serialize(i)); - if (b.curve) { - listener.set(zmq::sockopt::curve_server, true); - listener.set(zmq::sockopt::curve_publickey, pubkey); - listener.set(zmq::sockopt::curve_secretkey, privkey); + if (!proxy_bind(bind[i], i)) { + LMQ_LOG(warn, "OxenMQ failed to listen on ", bind[i].address); + throw zmq::error_t{}; } - listener.set(zmq::sockopt::router_handover, true); - listener.set(zmq::sockopt::router_mandatory, true); - - listener.bind(bind[i].first); - LMQ_LOG(info, "OxenMQ listening on ", bind[i].first); - - connections.push_back(std::move(listener)); - auto conn_id = next_conn_id++; - conn_index_to_id.push_back(conn_id); - incoming_conn_index[conn_id] = connections.size() - 1; - b.index = connections.size() - 1; } #ifndef _WIN32 @@ -393,13 +413,11 @@ void OxenMQ::proxy_loop() { // set socket gid / uid if it is provided if (SOCKET_GID != -1 or SOCKET_UID != -1) { - for(size_t i = 0; i < bind.size(); i++) { - const address addr(bind[i].first); - if(addr.ipc()) { - if(chown(addr.socket.c_str(), SOCKET_UID, SOCKET_GID) == -1) { + for (auto& b : bind) { + const address addr(b.address); + if (addr.ipc()) + if (chown(addr.socket.c_str(), SOCKET_UID, SOCKET_GID) == -1) throw std::runtime_error("cannot set 
group on " + addr.socket + ": " + strerror(errno)); - } - } } } #endif diff --git a/tests/test_connect.cpp b/tests/test_connect.cpp index 2c7f10e..f4c2d07 100644 --- a/tests/test_connect.cpp +++ b/tests/test_connect.cpp @@ -129,6 +129,64 @@ TEST_CASE("plain-text connections", "[plaintext][connect]") { } } +TEST_CASE("post-start listening", "[connect][listen]") { + OxenMQ server{get_logger("S» "), LogLevel::trace}; + server.add_category("x", AuthLevel::none) + .add_request_command("y", [&](Message& m) { m.send_reply("hi", m.data[0]); }); + server.start(); + std::atomic listens; + auto listen_curve = random_localhost(); + server.listen_curve(listen_curve, nullptr, [&](bool success) { if (success) listens++; }); + auto listen_plain = random_localhost(); + server.listen_plain(listen_plain, nullptr, [&](bool success) { if (success) listens += 10; }); + + wait_for([&] { return listens.load() >= 11; }); + { + auto lock = catch_lock(); + REQUIRE( listens == 11 ); + } + + // This should fail since we're already listening on it: + server.listen_curve(listen_plain, nullptr, [&](bool success) { if (!success) listens++; }); + + wait_for([&] { return listens.load() >= 12; }); + { + auto lock = catch_lock(); + REQUIRE( listens == 12 ); + } + + + OxenMQ client{get_logger("C1» "), LogLevel::trace}; + client.start(); + std::atomic conns = 0; + auto c1 = client.connect_remote(address{listen_curve, server.get_pubkey()}, + [&](auto) { conns++; }, + [&](auto, auto why) { auto lock = catch_lock(); UNSCOPED_INFO("connection failed: " << why); }); + auto c2 = client.connect_remote(address{listen_plain}, + [&](auto) { conns += 10; }, + [&](auto, auto why) { auto lock = catch_lock(); UNSCOPED_INFO("connection failed: " << why); }); + + + wait_for([&] { return conns.load() >= 11; }); + { + auto lock = catch_lock(); + REQUIRE( conns == 11 ); + } + + std::atomic replies = 0; + std::string reply1, reply2; + client.request(c1, "x.y", [&](auto success, auto parts) { replies++; for (auto& p : 
parts) reply1 += p; }, " world"); + client.request(c2, "x.y", [&](auto success, auto parts) { replies += 10; for (auto& p : parts) reply2 += p; }, " cat"); + + wait_for([&] { return replies.load() >= 11; }); + { + auto lock = catch_lock(); + REQUIRE( replies == 11 ); + REQUIRE( reply1 == "hi world" ); + REQUIRE( reply2 == "hi cat" ); + } +} + TEST_CASE("unique connection IDs", "[connect][id]") { std::string listen = random_localhost(); OxenMQ server{get_logger("S» "), LogLevel::trace}; From a0642a894e63b4bc28305ec3d9f7fa58887c3312 Mon Sep 17 00:00:00 2001 From: Jason Rhinelander Date: Wed, 23 Jun 2021 10:55:47 -0300 Subject: [PATCH 2/3] Miscellaneous small test suite fixes/improvements - Allow up to 200ms (instead of 100ms) for the things we are waiting on to become available, to prevent occasional spurious failures. - Add unscoped info for how long we waited. - Avoid calling into oxenmq with the catch lock held in the "hey google" tests (because this will deadlock if the oxenmq call invokes any logging). - Replace an old std::cerr logger with the updated catch2 logger. --- tests/common.h | 20 ++++++++++++-------- tests/test_commands.cpp | 32 ++++++++++++++++++++------------ 2 files changed, 32 insertions(+), 20 deletions(-) diff --git a/tests/common.h b/tests/common.h index 6dda8f6..2798675 100644 --- a/tests/common.h +++ b/tests/common.h @@ -1,6 +1,7 @@ #pragma once #include "oxenmq/oxenmq.h" #include +#include using namespace oxenmq; @@ -16,14 +17,23 @@ inline std::string random_localhost() { } -/// Waits up to 100ms for something to happen. +// Catch2 macros aren't thread safe, so guard with a mutex +inline std::unique_lock catch_lock() { + static std::mutex mutex; + return std::unique_lock{mutex}; +} + +/// Waits up to 200ms for something to happen. 
template inline void wait_for(Func f) { - for (int i = 0; i < 10; i++) { + auto start = std::chrono::steady_clock::now(); + for (int i = 0; i < 20; i++) { if (f()) break; std::this_thread::sleep_for(10ms); } + auto lock = catch_lock(); + UNSCOPED_INFO("done waiting after " << (std::chrono::steady_clock::now() - start).count() << "ns"); } /// Waits on an atomic bool for up to 100ms for an initial connection, which is more than enough @@ -35,12 +45,6 @@ inline void wait_for_conn(std::atomic &c) { /// Waits enough time for us to receive a reply from a localhost remote. inline void reply_sleep() { std::this_thread::sleep_for(10ms); } -// Catch2 macros aren't thread safe, so guard with a mutex -inline std::unique_lock catch_lock() { - static std::mutex mutex; - return std::unique_lock{mutex}; -} - inline OxenMQ::Logger get_logger(std::string prefix = "") { std::string me = "tests/common.h"; std::string strip = __FILE__; diff --git a/tests/test_commands.cpp b/tests/test_commands.cpp index b0c4e24..606fed3 100644 --- a/tests/test_commands.cpp +++ b/tests/test_commands.cpp @@ -173,18 +173,26 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]") server.add_category("hey google", Access{AuthLevel::none}); server.add_request_command("hey google", "remember", [&](Message& m) { - auto l = catch_lock(); - subscribers.emplace_back(m.conn, std::string{m.data[0]}); + bool bd; + { + auto l = catch_lock(); + subscribers.emplace_back(m.conn, std::string{m.data[0]}); + bd = (bool) backdoor; + } m.send_reply("Okay, I'll remember that."); - if (backdoor) + if (bd) m.oxenmq.send(backdoor, "backdoor.data", m.data[0]); }); server.add_command("hey google", "recall", [&](Message& m) { - auto l = catch_lock(); - for (auto& s : subscribers) { - server.send(s.first, "personal.detail", s.second); + decltype(subscribers) subs; + { + auto l = catch_lock(); + subs = subscribers; } + + for (auto& s : subs) + server.send(s.first, "personal.detail", s.second); }); 
server.add_command("hey google", "install backdoor", [&](Message& m) { auto l = catch_lock(); @@ -363,7 +371,7 @@ TEST_CASE("send failure callbacks", "[commands][queue_full]") { } } -TEST_CASE("data parts", "[send][data_parts]") { +TEST_CASE("data parts", "[commands][send][data_parts]") { std::string listen = random_localhost(); OxenMQ server{ "", "", // generate ephemeral keys @@ -446,7 +454,7 @@ TEST_CASE("data parts", "[send][data_parts]") { } } -TEST_CASE("deferred replies", "[send][deferred]") { +TEST_CASE("deferred replies", "[commands][send][deferred]") { std::string listen = random_localhost(); OxenMQ server{ "", "", // generate ephemeral keys @@ -461,9 +469,9 @@ TEST_CASE("deferred replies", "[send][deferred]") { server.add_request_command("public", "echo", [&](Message& m) { std::string msg = m.data.empty() ? ""s : std::string{m.data.front()}; std::thread t{[send=m.send_later(), msg=std::move(msg)] { - { auto lock = catch_lock(); INFO("sleeping"); } + { auto lock = catch_lock(); UNSCOPED_INFO("sleeping"); } std::this_thread::sleep_for(50ms); - { auto lock = catch_lock(); INFO("sending"); } + { auto lock = catch_lock(); UNSCOPED_INFO("sending"); } send.reply(msg); }}; t.detach(); @@ -472,8 +480,8 @@ TEST_CASE("deferred replies", "[send][deferred]") { server.start(); OxenMQ client( - [](LogLevel, const char* file, int line, std::string msg) { std::cerr << file << ":" << line << " --C-- " << msg << "\n"; } - ); + get_logger("C» "), + LogLevel::trace); //client.log_level(LogLevel::trace); client.start(); From 45db87f712eea8cfe875c76ae33eff5dc048338b Mon Sep 17 00:00:00 2001 From: Jason Rhinelander Date: Wed, 23 Jun 2021 13:48:41 -0300 Subject: [PATCH 3/3] Fix uninitialized value in post-start listen test --- tests/test_connect.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_connect.cpp b/tests/test_connect.cpp index f4c2d07..723f184 100644 --- a/tests/test_connect.cpp +++ b/tests/test_connect.cpp @@ -134,7 +134,7 @@ 
TEST_CASE("post-start listening", "[connect][listen]") { server.add_category("x", AuthLevel::none) .add_request_command("y", [&](Message& m) { m.send_reply("hi", m.data[0]); }); server.start(); - std::atomic listens; + std::atomic listens = 0; auto listen_curve = random_localhost(); server.listen_curve(listen_curve, nullptr, [&](bool success) { if (success) listens++; }); auto listen_plain = random_localhost();