diff --git a/oxenmq/auth.cpp b/oxenmq/auth.cpp index dfe1f6c..2be7f48 100644 --- a/oxenmq/auth.cpp +++ b/oxenmq/auth.cpp @@ -256,14 +256,14 @@ void OxenMQ::process_zap_requests() { LMQ_LOG(error, "Bad ZAP authentication request: invalid auth domain '", auth_domain, "'"); status_code = "400"; status_text = "Unknown authentication domain: " + std::string{auth_domain}; - } else if (bind[bind_id].second.curve + } else if (bind[bind_id].curve ? !(frames.size() == 7 && view(frames[5]) == "CURVE") : !(frames.size() == 6 && view(frames[5]) == "NULL")) { LMQ_LOG(error, "Bad ZAP authentication request: invalid ", - bind[bind_id].second.curve ? "CURVE" : "NULL", " authentication request"); + bind[bind_id].curve ? "CURVE" : "NULL", " authentication request"); status_code = "500"; status_text = "Invalid authentication request mechanism"; - } else if (bind[bind_id].second.curve && frames[6].size() != 32) { + } else if (bind[bind_id].curve && frames[6].size() != 32) { LMQ_LOG(error, "Bad ZAP authentication request: invalid request pubkey"); status_code = "500"; status_text = "Invalid public key size for CURVE authentication"; @@ -271,13 +271,13 @@ void OxenMQ::process_zap_requests() { auto ip = view(frames[3]); std::string_view pubkey; bool sn = false; - if (bind[bind_id].second.curve) { + if (bind[bind_id].curve) { pubkey = view(frames[6]); sn = active_service_nodes.count(std::string{pubkey}); } - auto auth = bind[bind_id].second.allow(ip, pubkey, sn); + auto auth = bind[bind_id].allow(ip, pubkey, sn); auto& user_id = response_vals[4]; - if (bind[bind_id].second.curve) { + if (bind[bind_id].curve) { user_id.reserve(64); to_hex(pubkey.begin(), pubkey.end(), std::back_inserter(user_id)); } diff --git a/oxenmq/connections.cpp b/oxenmq/connections.cpp index e87d71c..a84b665 100644 --- a/oxenmq/connections.cpp +++ b/oxenmq/connections.cpp @@ -67,6 +67,21 @@ void OxenMQ::setup_outgoing_socket(zmq::socket_t& socket, std::string_view remot // else let ZMQ pick a random one } + +void OxenMQ::setup_incoming_socket(zmq::socket_t& listener, bool curve, std::string_view pubkey, std::string_view privkey, size_t bind_index) { + + setup_external_socket(listener); + + listener.set(zmq::sockopt::zap_domain, bt_serialize(bind_index)); + if (curve) { + listener.set(zmq::sockopt::curve_server, true); + listener.set(zmq::sockopt::curve_publickey, pubkey); + listener.set(zmq::sockopt::curve_secretkey, privkey); + } + listener.set(zmq::sockopt::router_handover, true); + listener.set(zmq::sockopt::router_mandatory, true); +} + // Deprecated versions: ConnectionID OxenMQ::connect_remote(std::string_view remote, ConnectSuccess on_connect, ConnectFailure on_failure, AuthLevel auth_level, std::chrono::milliseconds timeout) { @@ -218,7 +233,7 @@ void OxenMQ::proxy_close_connection(size_t index, std::chrono::milliseconds ling update_connection_indices(pending_connects, index, [](auto& pc) -> size_t& { return std::get(pc); }); update_connection_indices(bind, index, - [](auto& b) -> size_t& { return b.second.index; }); + [](auto& b) -> size_t& { return b.index; }); update_connection_indices(incoming_conn_index, index, [](auto& oci) -> size_t& { return oci.second; }); assert(index < conn_index_to_id.size()); diff --git a/oxenmq/oxenmq.cpp b/oxenmq/oxenmq.cpp index 7747c40..c30095b 100644 --- a/oxenmq/oxenmq.cpp +++ b/oxenmq/oxenmq.cpp @@ -232,12 +232,6 @@ void OxenMQ::start() { if (proxy_thread.joinable()) throw std::logic_error("Cannot call start() multiple times!"); - // If we're not binding to anything then we don't listen, i.e. we can only establish outbound - // connections. Don't allow this if we are in service_node mode because, if we aren't - // listening, we are useless as a service node. - if (bind.empty() && local_service_node) - throw std::invalid_argument{"Cannot create a service node listener with no address(es) to bind"}; - LMQ_LOG(info, "Initializing OxenMQ ", bind.empty() ? "remote-only" : "listener", " with pubkey ", to_hex(pubkey)); int zmq_socket_limit = context.get(zmq::ctxopt::socket_limit); @@ -267,20 +261,22 @@ void OxenMQ::start() { LMQ_LOG(debug, "Proxy thread is ready"); } -void OxenMQ::listen_curve(std::string bind_addr, AllowFunc allow_connection) { - // TODO: there's no particular reason we can't start listening after starting up; just needs to - // be implemented. (But if we can start we'll probably also want to be able to stop, so it's - // more than just binding that needs implementing). - check_not_started(proxy_thread, "start listening"); - - bind.emplace_back(std::move(bind_addr), bind_data{true, std::move(allow_connection)}); +void OxenMQ::listen_curve(std::string bind_addr, AllowFunc allow_connection, std::function on_bind) { + if (!allow_connection) allow_connection = [](auto, auto, auto) { return AuthLevel::none; }; + bind_data d{std::move(bind_addr), true, std::move(allow_connection), std::move(on_bind)}; + if (proxy_thread.joinable()) + detail::send_control(get_control_socket(), "BIND", bt_serialize(detail::serialize_object(std::move(d)))); + else + bind.push_back(std::move(d)); } -void OxenMQ::listen_plain(std::string bind_addr, AllowFunc allow_connection) { - // TODO: As above. - check_not_started(proxy_thread, "start listening"); - - bind.emplace_back(std::move(bind_addr), bind_data{false, std::move(allow_connection)}); +void OxenMQ::listen_plain(std::string bind_addr, AllowFunc allow_connection, std::function on_bind) { + if (!allow_connection) allow_connection = [](auto, auto, auto) { return AuthLevel::none; }; + bind_data d{std::move(bind_addr), false, std::move(allow_connection), std::move(on_bind)}; + if (proxy_thread.joinable()) + detail::send_control(get_control_socket(), "BIND", bt_serialize(detail::serialize_object(std::move(d)))); + else + bind.push_back(std::move(d)); } diff --git a/oxenmq/oxenmq.h b/oxenmq/oxenmq.h index d23b7b7..2db5f17 100644 --- a/oxenmq/oxenmq.h +++ b/oxenmq/oxenmq.h @@ -320,15 +320,17 @@ private: zmq::socket_t zap_auth{context, zmq::socket_type::rep}; struct bind_data { + std::string address; bool curve; size_t index; AllowFunc allow; - bind_data(bool curve, AllowFunc allow) - : curve{curve}, index{0}, allow{std::move(allow)} {} + std::function on_bind; + bind_data(std::string addr, bool curve, AllowFunc allow, std::function on_bind) + : address{std::move(addr)}, curve{curve}, index{0}, allow{std::move(allow)}, on_bind{std::move(on_bind)} {} }; /// Addresses on which we are listening (or, before start(), on which we will listen). - std::vector> bind; + std::vector bind; /// Info about a peer's established connection with us. Note that "established" means both /// connected and authenticated. Note that we only store peer info data for SN connections (in @@ -507,6 +509,9 @@ private: /// gets called after all works have done so. void proxy_quit(); + /// proxy handler for binding to addresses given via listen_*(). + bool proxy_bind(bind_data& bind, size_t index); + // Common setup code for setting up an external (incoming or outgoing) socket. void setup_external_socket(zmq::socket_t& socket); @@ -516,6 +521,9 @@ private: // either accepting curve connections, or not accepting curve). void setup_outgoing_socket(zmq::socket_t& socket, std::string_view remote_pubkey, bool use_ephemeral_routing_id); + /// Sets the various properties on an listening socket prior to binding. + void setup_incoming_socket(zmq::socket_t& socket, bool curve, std::string_view pubkey, std::string_view privkey, size_t bind_index); + /// Common connection implementation used by proxy_connect/proxy_send. Returns the socket and, /// if a routing prefix is needed, the required prefix (or an empty string if not needed). For /// an optional connect that fails (or some other connection failure), returns nullptr for the @@ -971,14 +979,26 @@ public: * will be encrypted. `allow_connection` is invoked for any incoming connections on this * address to determine the incoming remote's access and authentication level. * + * If called before `start()` then the given bind address is mandatory and start() will throw if + * the bind fails. If called after `start()` then the bind may fail (in which case the callback + * will be used to notify of the failure). + * * @param bind address - can be any string zmq supports; typically a tcp IP/port combination * such as: "tcp://\*:4567" or "tcp://1.2.3.4:5678". * * @param allow_connection function to call to determine whether to allow the connection and, if - * so, the authentication level it receives. If omitted the default returns AuthLevel::none - * access. + * so, the authentication level it receives. If omitted (or null) the default returns + * AuthLevel::none access for all connections. + * + * @param on_bind function to call when the port has been successfully opened or failed to + * open. For addresses set up before .start() this will be called during `start()` itself; for + * post-start listens this will be called from the proxy thread when it opens the new port. + * Note that this function must is called directly from the proxy thread and so should be fast + * and non-blocking. */ - void listen_curve(std::string bind, AllowFunc allow_connection = [](auto, auto, auto) { return AuthLevel::none; }); + void listen_curve(std::string bind, + AllowFunc allow_connection = nullptr, + std::function on_bind = nullptr); /** Start listening on the given bind address in unauthenticated plain text mode. Incoming * connections can come from anywhere. `allow_connection` is invoked for any incoming @@ -989,10 +1009,14 @@ public: * such as: "tcp://\*:4567" or "tcp://1.2.3.4:5678". * * @param allow_connection function to call to determine whether to allow the connection and, if - * so, the authentication level it receives. If omitted the default returns AuthLevel::none - * access. + * so, the authentication level it receives. If omitted (or null) the default returns + * AuthLevel::none access for all connections. + * + * @param on_result called after binding with the result; see `listen_curve` for details. */ - void listen_plain(std::string bind, AllowFunc allow_connection = [](auto, auto, auto) { return AuthLevel::none; }); + void listen_plain(std::string bind, + AllowFunc allow_connection = nullptr, + std::function on_bind = nullptr); /** * Try to initiate a connection to the given SN in anticipation of needing a connection in the @@ -1446,8 +1470,8 @@ namespace connect_option { /// Typically use: `connect_options::ephemeral_routing_id{}` or `connect_options::ephemeral_routing_id{false}`. struct ephemeral_routing_id { bool use_ephemeral_routing_id = true; - // Constructor; default construction gives you pubkey routing, but the bool parameter can be - // specified as false to explicitly disable the pubkey routing flag. + // Constructor; default construction gives you ephemeral routing id, but the bool parameter can + // be specified as false to use pubkey routing flag. explicit ephemeral_routing_id(bool use = true) : use_ephemeral_routing_id{use} {} }; diff --git a/oxenmq/proxy.cpp b/oxenmq/proxy.cpp index 1778599..6208aab 100644 --- a/oxenmq/proxy.cpp +++ b/oxenmq/proxy.cpp @@ -293,6 +293,11 @@ void OxenMQ::proxy_control_message(std::vector& parts) { return proxy_timer(data); } else if (cmd == "TIMER_DEL") { return proxy_timer_del(bt_deserialize(data)); + } else if (cmd == "BIND") { + auto b = detail::deserialize_object(bt_deserialize(data)); + if (proxy_bind(b, bind.size())) + bind.push_back(std::move(b)); + return; } } else if (parts.size() == 2) { if (cmd == "START") { @@ -317,6 +322,38 @@ void OxenMQ::proxy_control_message(std::vector& parts) { std::string{cmd} + " (" + std::to_string(parts.size()) + ")"); } +bool OxenMQ::proxy_bind(bind_data& b, size_t index) { + zmq::socket_t listener{context, zmq::socket_type::router}; + setup_incoming_socket(listener, b.curve, pubkey, privkey, index); + + bool good = true; + try { + listener.bind(b.address); + } catch (const zmq::error_t&) { + good = false; + } + if (b.on_bind) { + b.on_bind(good); + b.on_bind = nullptr; + } + if (!good) { + LMQ_LOG(warn, "OxenMQ failed to listen on ", b.address); + return false; + } + + LMQ_LOG(info, "OxenMQ listening on ", b.address); + + connections.push_back(std::move(listener)); + auto conn_id = next_conn_id++; + conn_index_to_id.push_back(conn_id); + incoming_conn_index[conn_id] = connections.size() - 1; + b.index = connections.size() - 1; + + pollitems_stale = true; + + return true; +} + void OxenMQ::proxy_loop() { #if defined(__linux__) || defined(__sun) || defined(__MINGW32__) @@ -364,27 +401,10 @@ void OxenMQ::proxy_loop() { #endif for (size_t i = 0; i < bind.size(); i++) { - auto& b = bind[i].second; - zmq::socket_t listener{context, zmq::socket_type::router}; - - setup_external_socket(listener); - listener.set(zmq::sockopt::zap_domain, bt_serialize(i)); - if (b.curve) { - listener.set(zmq::sockopt::curve_server, true); - listener.set(zmq::sockopt::curve_publickey, pubkey); - listener.set(zmq::sockopt::curve_secretkey, privkey); + if (!proxy_bind(bind[i], i)) { + LMQ_LOG(warn, "OxenMQ failed to listen on ", bind[i].address); + throw zmq::error_t{}; } - listener.set(zmq::sockopt::router_handover, true); - listener.set(zmq::sockopt::router_mandatory, true); - - listener.bind(bind[i].first); - LMQ_LOG(info, "OxenMQ listening on ", bind[i].first); - - connections.push_back(std::move(listener)); - auto conn_id = next_conn_id++; - conn_index_to_id.push_back(conn_id); - incoming_conn_index[conn_id] = connections.size() - 1; - b.index = connections.size() - 1; } #ifndef _WIN32 @@ -393,13 +413,11 @@ void OxenMQ::proxy_loop() { // set socket gid / uid if it is provided if (SOCKET_GID != -1 or SOCKET_UID != -1) { - for(size_t i = 0; i < bind.size(); i++) { - const address addr(bind[i].first); - if(addr.ipc()) { - if(chown(addr.socket.c_str(), SOCKET_UID, SOCKET_GID) == -1) { + for (auto& b : bind) { + const address addr(b.address); + if (addr.ipc()) + if (chown(addr.socket.c_str(), SOCKET_UID, SOCKET_GID) == -1) throw std::runtime_error("cannot set group on " + addr.socket + ": " + strerror(errno)); - } - } } } #endif diff --git a/tests/common.h b/tests/common.h index 6dda8f6..2798675 100644 --- a/tests/common.h +++ b/tests/common.h @@ -1,6 +1,7 @@ #pragma once #include "oxenmq/oxenmq.h" #include +#include using namespace oxenmq; @@ -16,14 +17,23 @@ inline std::string random_localhost() { } -/// Waits up to 100ms for something to happen. +// Catch2 macros aren't thread safe, so guard with a mutex +inline std::unique_lock catch_lock() { + static std::mutex mutex; + return std::unique_lock{mutex}; +} + +/// Waits up to 200ms for something to happen. template inline void wait_for(Func f) { - for (int i = 0; i < 10; i++) { + auto start = std::chrono::steady_clock::now(); + for (int i = 0; i < 20; i++) { if (f()) break; std::this_thread::sleep_for(10ms); } + auto lock = catch_lock(); + UNSCOPED_INFO("done waiting after " << (std::chrono::steady_clock::now() - start).count() << "ns"); } /// Waits on an atomic bool for up to 100ms for an initial connection, which is more than enough @@ -35,12 +45,6 @@ inline void wait_for_conn(std::atomic &c) { /// Waits enough time for us to receive a reply from a localhost remote. inline void reply_sleep() { std::this_thread::sleep_for(10ms); } -// Catch2 macros aren't thread safe, so guard with a mutex -inline std::unique_lock catch_lock() { - static std::mutex mutex; - return std::unique_lock{mutex}; -} - inline OxenMQ::Logger get_logger(std::string prefix = "") { std::string me = "tests/common.h"; std::string strip = __FILE__; diff --git a/tests/test_commands.cpp b/tests/test_commands.cpp index b0c4e24..606fed3 100644 --- a/tests/test_commands.cpp +++ b/tests/test_commands.cpp @@ -173,18 +173,26 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]") server.add_category("hey google", Access{AuthLevel::none}); server.add_request_command("hey google", "remember", [&](Message& m) { - auto l = catch_lock(); - subscribers.emplace_back(m.conn, std::string{m.data[0]}); + bool bd; + { + auto l = catch_lock(); + subscribers.emplace_back(m.conn, std::string{m.data[0]}); + bd = (bool) backdoor; + } m.send_reply("Okay, I'll remember that."); - if (backdoor) + if (bd) m.oxenmq.send(backdoor, "backdoor.data", m.data[0]); }); server.add_command("hey google", "recall", [&](Message& m) { - auto l = catch_lock(); - for (auto& s : subscribers) { - server.send(s.first, "personal.detail", s.second); + decltype(subscribers) subs; + { + auto l = catch_lock(); + subs = subscribers; } + + for (auto& s : subs) + server.send(s.first, "personal.detail", s.second); }); server.add_command("hey google", "install backdoor", [&](Message& m) { auto l = catch_lock(); @@ -363,7 +371,7 @@ TEST_CASE("send failure callbacks", "[commands][queue_full]") { } } -TEST_CASE("data parts", "[send][data_parts]") { +TEST_CASE("data parts", "[commands][send][data_parts]") { std::string listen = random_localhost(); OxenMQ server{ "", "", // generate ephemeral keys @@ -446,7 +454,7 @@ TEST_CASE("data parts", "[send][data_parts]") { } } -TEST_CASE("deferred replies", "[send][deferred]") { +TEST_CASE("deferred replies", "[commands][send][deferred]") { std::string listen = random_localhost(); OxenMQ server{ "", "", // generate ephemeral keys @@ -461,9 +469,9 @@ TEST_CASE("deferred replies", "[send][deferred]") { server.add_request_command("public", "echo", [&](Message& m) { std::string msg = m.data.empty() ? ""s : std::string{m.data.front()}; std::thread t{[send=m.send_later(), msg=std::move(msg)] { - { auto lock = catch_lock(); INFO("sleeping"); } + { auto lock = catch_lock(); UNSCOPED_INFO("sleeping"); } std::this_thread::sleep_for(50ms); - { auto lock = catch_lock(); INFO("sending"); } + { auto lock = catch_lock(); UNSCOPED_INFO("sending"); } send.reply(msg); }}; t.detach(); @@ -472,8 +480,8 @@ TEST_CASE("deferred replies", "[send][deferred]") { server.start(); OxenMQ client( - [](LogLevel, const char* file, int line, std::string msg) { std::cerr << file << ":" << line << " --C-- " << msg << "\n"; } - ); + get_logger("C» "), + LogLevel::trace); //client.log_level(LogLevel::trace); client.start(); diff --git a/tests/test_connect.cpp b/tests/test_connect.cpp index 2c7f10e..723f184 100644 --- a/tests/test_connect.cpp +++ b/tests/test_connect.cpp @@ -129,6 +129,64 @@ TEST_CASE("plain-text connections", "[plaintext][connect]") { } } +TEST_CASE("post-start listening", "[connect][listen]") { + OxenMQ server{get_logger("S» "), LogLevel::trace}; + server.add_category("x", AuthLevel::none) + .add_request_command("y", [&](Message& m) { m.send_reply("hi", m.data[0]); }); + server.start(); + std::atomic listens = 0; + auto listen_curve = random_localhost(); + server.listen_curve(listen_curve, nullptr, [&](bool success) { if (success) listens++; }); + auto listen_plain = random_localhost(); + server.listen_plain(listen_plain, nullptr, [&](bool success) { if (success) listens += 10; }); + + wait_for([&] { return listens.load() >= 11; }); + { + auto lock = catch_lock(); + REQUIRE( listens == 11 ); + } + + // This should fail since we're already listening on it: + server.listen_curve(listen_plain, nullptr, [&](bool success) { if (!success) listens++; }); + + wait_for([&] { return listens.load() >= 12; }); + { + auto lock = catch_lock(); + REQUIRE( listens == 12 ); + } + + + OxenMQ client{get_logger("C1» "), LogLevel::trace}; + client.start(); + std::atomic conns = 0; + auto c1 = client.connect_remote(address{listen_curve, server.get_pubkey()}, + [&](auto) { conns++; }, + [&](auto, auto why) { auto lock = catch_lock(); UNSCOPED_INFO("connection failed: " << why); }); + auto c2 = client.connect_remote(address{listen_plain}, + [&](auto) { conns += 10; }, + [&](auto, auto why) { auto lock = catch_lock(); UNSCOPED_INFO("connection failed: " << why); }); + + + wait_for([&] { return conns.load() >= 11; }); + { + auto lock = catch_lock(); + REQUIRE( conns == 11 ); + } + + std::atomic replies = 0; + std::string reply1, reply2; + client.request(c1, "x.y", [&](auto success, auto parts) { replies++; for (auto& p : parts) reply1 += p; }, " world"); + client.request(c2, "x.y", [&](auto success, auto parts) { replies += 10; for (auto& p : parts) reply2 += p; }, " cat"); + + wait_for([&] { return replies.load() >= 11; }); + { + auto lock = catch_lock(); + REQUIRE( replies == 11 ); + REQUIRE( reply1 == "hi world" ); + REQUIRE( reply2 == "hi cat" ); + } +} + TEST_CASE("unique connection IDs", "[connect][id]") { std::string listen = random_localhost(); OxenMQ server{get_logger("S» "), LogLevel::trace};