Merge pull request #40 from jagerman/dynamic-listen

Add support for listening on a new port after start()
This commit is contained in:
Jason Rhinelander 2021-06-23 14:06:37 -03:00 committed by GitHub
commit 7ba81a7d50
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 205 additions and 82 deletions

View File

@ -256,14 +256,14 @@ void OxenMQ::process_zap_requests() {
LMQ_LOG(error, "Bad ZAP authentication request: invalid auth domain '", auth_domain, "'");
status_code = "400";
status_text = "Unknown authentication domain: " + std::string{auth_domain};
} else if (bind[bind_id].second.curve
} else if (bind[bind_id].curve
? !(frames.size() == 7 && view(frames[5]) == "CURVE")
: !(frames.size() == 6 && view(frames[5]) == "NULL")) {
LMQ_LOG(error, "Bad ZAP authentication request: invalid ",
bind[bind_id].second.curve ? "CURVE" : "NULL", " authentication request");
bind[bind_id].curve ? "CURVE" : "NULL", " authentication request");
status_code = "500";
status_text = "Invalid authentication request mechanism";
} else if (bind[bind_id].second.curve && frames[6].size() != 32) {
} else if (bind[bind_id].curve && frames[6].size() != 32) {
LMQ_LOG(error, "Bad ZAP authentication request: invalid request pubkey");
status_code = "500";
status_text = "Invalid public key size for CURVE authentication";
@ -271,13 +271,13 @@ void OxenMQ::process_zap_requests() {
auto ip = view(frames[3]);
std::string_view pubkey;
bool sn = false;
if (bind[bind_id].second.curve) {
if (bind[bind_id].curve) {
pubkey = view(frames[6]);
sn = active_service_nodes.count(std::string{pubkey});
}
auto auth = bind[bind_id].second.allow(ip, pubkey, sn);
auto auth = bind[bind_id].allow(ip, pubkey, sn);
auto& user_id = response_vals[4];
if (bind[bind_id].second.curve) {
if (bind[bind_id].curve) {
user_id.reserve(64);
to_hex(pubkey.begin(), pubkey.end(), std::back_inserter(user_id));
}

View File

@ -67,6 +67,21 @@ void OxenMQ::setup_outgoing_socket(zmq::socket_t& socket, std::string_view remot
// else let ZMQ pick a random one
}
void OxenMQ::setup_incoming_socket(zmq::socket_t& listener, bool curve, std::string_view pubkey, std::string_view privkey, size_t bind_index) {
setup_external_socket(listener);
listener.set(zmq::sockopt::zap_domain, bt_serialize(bind_index));
if (curve) {
listener.set(zmq::sockopt::curve_server, true);
listener.set(zmq::sockopt::curve_publickey, pubkey);
listener.set(zmq::sockopt::curve_secretkey, privkey);
}
listener.set(zmq::sockopt::router_handover, true);
listener.set(zmq::sockopt::router_mandatory, true);
}
// Deprecated versions:
ConnectionID OxenMQ::connect_remote(std::string_view remote, ConnectSuccess on_connect,
ConnectFailure on_failure, AuthLevel auth_level, std::chrono::milliseconds timeout) {
@ -218,7 +233,7 @@ void OxenMQ::proxy_close_connection(size_t index, std::chrono::milliseconds ling
update_connection_indices(pending_connects, index,
[](auto& pc) -> size_t& { return std::get<size_t>(pc); });
update_connection_indices(bind, index,
[](auto& b) -> size_t& { return b.second.index; });
[](auto& b) -> size_t& { return b.index; });
update_connection_indices(incoming_conn_index, index,
[](auto& oci) -> size_t& { return oci.second; });
assert(index < conn_index_to_id.size());

View File

@ -232,12 +232,6 @@ void OxenMQ::start() {
if (proxy_thread.joinable())
throw std::logic_error("Cannot call start() multiple times!");
// If we're not binding to anything then we don't listen, i.e. we can only establish outbound
// connections. Don't allow this if we are in service_node mode because, if we aren't
// listening, we are useless as a service node.
if (bind.empty() && local_service_node)
throw std::invalid_argument{"Cannot create a service node listener with no address(es) to bind"};
LMQ_LOG(info, "Initializing OxenMQ ", bind.empty() ? "remote-only" : "listener", " with pubkey ", to_hex(pubkey));
int zmq_socket_limit = context.get(zmq::ctxopt::socket_limit);
@ -267,20 +261,22 @@ void OxenMQ::start() {
LMQ_LOG(debug, "Proxy thread is ready");
}
void OxenMQ::listen_curve(std::string bind_addr, AllowFunc allow_connection) {
// TODO: there's no particular reason we can't start listening after starting up; just needs to
// be implemented. (But if we can start we'll probably also want to be able to stop, so it's
// more than just binding that needs implementing).
check_not_started(proxy_thread, "start listening");
bind.emplace_back(std::move(bind_addr), bind_data{true, std::move(allow_connection)});
void OxenMQ::listen_curve(std::string bind_addr, AllowFunc allow_connection, std::function<void(bool)> on_bind) {
if (!allow_connection) allow_connection = [](auto, auto, auto) { return AuthLevel::none; };
bind_data d{std::move(bind_addr), true, std::move(allow_connection), std::move(on_bind)};
if (proxy_thread.joinable())
detail::send_control(get_control_socket(), "BIND", bt_serialize(detail::serialize_object(std::move(d))));
else
bind.push_back(std::move(d));
}
void OxenMQ::listen_plain(std::string bind_addr, AllowFunc allow_connection) {
// TODO: As above.
check_not_started(proxy_thread, "start listening");
bind.emplace_back(std::move(bind_addr), bind_data{false, std::move(allow_connection)});
void OxenMQ::listen_plain(std::string bind_addr, AllowFunc allow_connection, std::function<void(bool)> on_bind) {
if (!allow_connection) allow_connection = [](auto, auto, auto) { return AuthLevel::none; };
bind_data d{std::move(bind_addr), false, std::move(allow_connection), std::move(on_bind)};
if (proxy_thread.joinable())
detail::send_control(get_control_socket(), "BIND", bt_serialize(detail::serialize_object(std::move(d))));
else
bind.push_back(std::move(d));
}

View File

@ -320,15 +320,17 @@ private:
zmq::socket_t zap_auth{context, zmq::socket_type::rep};
struct bind_data {
std::string address;
bool curve;
size_t index;
AllowFunc allow;
bind_data(bool curve, AllowFunc allow)
: curve{curve}, index{0}, allow{std::move(allow)} {}
std::function<void(bool)> on_bind;
bind_data(std::string addr, bool curve, AllowFunc allow, std::function<void(bool)> on_bind)
: address{std::move(addr)}, curve{curve}, index{0}, allow{std::move(allow)}, on_bind{std::move(on_bind)} {}
};
/// Addresses on which we are listening (or, before start(), on which we will listen).
std::vector<std::pair<std::string, bind_data>> bind;
std::vector<bind_data> bind;
/// Info about a peer's established connection with us. Note that "established" means both
/// connected and authenticated. Note that we only store peer info data for SN connections (in
@ -507,6 +509,9 @@ private:
/// gets called after all works have done so.
void proxy_quit();
/// proxy handler for binding to addresses given via listen_*().
bool proxy_bind(bind_data& bind, size_t index);
// Common setup code for setting up an external (incoming or outgoing) socket.
void setup_external_socket(zmq::socket_t& socket);
@ -516,6 +521,9 @@ private:
// either accepting curve connections, or not accepting curve).
void setup_outgoing_socket(zmq::socket_t& socket, std::string_view remote_pubkey, bool use_ephemeral_routing_id);
/// Sets the various properties on an listening socket prior to binding.
void setup_incoming_socket(zmq::socket_t& socket, bool curve, std::string_view pubkey, std::string_view privkey, size_t bind_index);
/// Common connection implementation used by proxy_connect/proxy_send. Returns the socket and,
/// if a routing prefix is needed, the required prefix (or an empty string if not needed). For
/// an optional connect that fails (or some other connection failure), returns nullptr for the
@ -971,14 +979,26 @@ public:
* will be encrypted. `allow_connection` is invoked for any incoming connections on this
* address to determine the incoming remote's access and authentication level.
*
* If called before `start()` then the given bind address is mandatory and start() will throw if
* the bind fails. If called after `start()` then the bind may fail (in which case the callback
* will be used to notify of the failure).
*
* @param bind address - can be any string zmq supports; typically a tcp IP/port combination
* such as: "tcp://\*:4567" or "tcp://1.2.3.4:5678".
*
* @param allow_connection function to call to determine whether to allow the connection and, if
* so, the authentication level it receives. If omitted the default returns AuthLevel::none
* access.
* so, the authentication level it receives. If omitted (or null) the default returns
* AuthLevel::none access for all connections.
*
* @param on_bind function to call when the port has been successfully opened or failed to
* open. For addresses set up before .start() this will be called during `start()` itself; for
* post-start listens this will be called from the proxy thread when it opens the new port.
* Note that this function must is called directly from the proxy thread and so should be fast
* and non-blocking.
*/
void listen_curve(std::string bind, AllowFunc allow_connection = [](auto, auto, auto) { return AuthLevel::none; });
void listen_curve(std::string bind,
AllowFunc allow_connection = nullptr,
std::function<void(bool success)> on_bind = nullptr);
/** Start listening on the given bind address in unauthenticated plain text mode. Incoming
* connections can come from anywhere. `allow_connection` is invoked for any incoming
@ -989,10 +1009,14 @@ public:
* such as: "tcp://\*:4567" or "tcp://1.2.3.4:5678".
*
* @param allow_connection function to call to determine whether to allow the connection and, if
* so, the authentication level it receives. If omitted the default returns AuthLevel::none
* access.
* so, the authentication level it receives. If omitted (or null) the default returns
* AuthLevel::none access for all connections.
*
* @param on_result called after binding with the result; see `listen_curve` for details.
*/
void listen_plain(std::string bind, AllowFunc allow_connection = [](auto, auto, auto) { return AuthLevel::none; });
void listen_plain(std::string bind,
AllowFunc allow_connection = nullptr,
std::function<void(bool success)> on_bind = nullptr);
/**
* Try to initiate a connection to the given SN in anticipation of needing a connection in the
@ -1446,8 +1470,8 @@ namespace connect_option {
/// Typically use: `connect_options::ephemeral_routing_id{}` or `connect_options::ephemeral_routing_id{false}`.
struct ephemeral_routing_id {
bool use_ephemeral_routing_id = true;
// Constructor; default construction gives you pubkey routing, but the bool parameter can be
// specified as false to explicitly disable the pubkey routing flag.
// Constructor; default construction gives you ephemeral routing id, but the bool parameter can
// be specified as false to use pubkey routing flag.
explicit ephemeral_routing_id(bool use = true) : use_ephemeral_routing_id{use} {}
};

View File

@ -293,6 +293,11 @@ void OxenMQ::proxy_control_message(std::vector<zmq::message_t>& parts) {
return proxy_timer(data);
} else if (cmd == "TIMER_DEL") {
return proxy_timer_del(bt_deserialize<int>(data));
} else if (cmd == "BIND") {
auto b = detail::deserialize_object<bind_data>(bt_deserialize<uintptr_t>(data));
if (proxy_bind(b, bind.size()))
bind.push_back(std::move(b));
return;
}
} else if (parts.size() == 2) {
if (cmd == "START") {
@ -317,6 +322,38 @@ void OxenMQ::proxy_control_message(std::vector<zmq::message_t>& parts) {
std::string{cmd} + " (" + std::to_string(parts.size()) + ")");
}
bool OxenMQ::proxy_bind(bind_data& b, size_t index) {
zmq::socket_t listener{context, zmq::socket_type::router};
setup_incoming_socket(listener, b.curve, pubkey, privkey, index);
bool good = true;
try {
listener.bind(b.address);
} catch (const zmq::error_t&) {
good = false;
}
if (b.on_bind) {
b.on_bind(good);
b.on_bind = nullptr;
}
if (!good) {
LMQ_LOG(warn, "OxenMQ failed to listen on ", b.address);
return false;
}
LMQ_LOG(info, "OxenMQ listening on ", b.address);
connections.push_back(std::move(listener));
auto conn_id = next_conn_id++;
conn_index_to_id.push_back(conn_id);
incoming_conn_index[conn_id] = connections.size() - 1;
b.index = connections.size() - 1;
pollitems_stale = true;
return true;
}
void OxenMQ::proxy_loop() {
#if defined(__linux__) || defined(__sun) || defined(__MINGW32__)
@ -364,27 +401,10 @@ void OxenMQ::proxy_loop() {
#endif
for (size_t i = 0; i < bind.size(); i++) {
auto& b = bind[i].second;
zmq::socket_t listener{context, zmq::socket_type::router};
setup_external_socket(listener);
listener.set(zmq::sockopt::zap_domain, bt_serialize(i));
if (b.curve) {
listener.set(zmq::sockopt::curve_server, true);
listener.set(zmq::sockopt::curve_publickey, pubkey);
listener.set(zmq::sockopt::curve_secretkey, privkey);
if (!proxy_bind(bind[i], i)) {
LMQ_LOG(warn, "OxenMQ failed to listen on ", bind[i].address);
throw zmq::error_t{};
}
listener.set(zmq::sockopt::router_handover, true);
listener.set(zmq::sockopt::router_mandatory, true);
listener.bind(bind[i].first);
LMQ_LOG(info, "OxenMQ listening on ", bind[i].first);
connections.push_back(std::move(listener));
auto conn_id = next_conn_id++;
conn_index_to_id.push_back(conn_id);
incoming_conn_index[conn_id] = connections.size() - 1;
b.index = connections.size() - 1;
}
#ifndef _WIN32
@ -393,13 +413,11 @@ void OxenMQ::proxy_loop() {
// set socket gid / uid if it is provided
if (SOCKET_GID != -1 or SOCKET_UID != -1) {
for(size_t i = 0; i < bind.size(); i++) {
const address addr(bind[i].first);
if(addr.ipc()) {
if(chown(addr.socket.c_str(), SOCKET_UID, SOCKET_GID) == -1) {
for (auto& b : bind) {
const address addr(b.address);
if (addr.ipc())
if (chown(addr.socket.c_str(), SOCKET_UID, SOCKET_GID) == -1)
throw std::runtime_error("cannot set group on " + addr.socket + ": " + strerror(errno));
}
}
}
}
#endif

View File

@ -1,6 +1,7 @@
#pragma once
#include "oxenmq/oxenmq.h"
#include <catch2/catch.hpp>
#include <chrono>
using namespace oxenmq;
@ -16,14 +17,23 @@ inline std::string random_localhost() {
}
/// Waits up to 100ms for something to happen.
// Catch2 macros aren't thread safe, so guard with a mutex
inline std::unique_lock<std::mutex> catch_lock() {
static std::mutex mutex;
return std::unique_lock<std::mutex>{mutex};
}
/// Waits up to 200ms for something to happen.
template <typename Func>
inline void wait_for(Func f) {
for (int i = 0; i < 10; i++) {
auto start = std::chrono::steady_clock::now();
for (int i = 0; i < 20; i++) {
if (f())
break;
std::this_thread::sleep_for(10ms);
}
auto lock = catch_lock();
UNSCOPED_INFO("done waiting after " << (std::chrono::steady_clock::now() - start).count() << "ns");
}
/// Waits on an atomic bool for up to 100ms for an initial connection, which is more than enough
@ -35,12 +45,6 @@ inline void wait_for_conn(std::atomic<bool> &c) {
/// Waits enough time for us to receive a reply from a localhost remote.
inline void reply_sleep() { std::this_thread::sleep_for(10ms); }
// Catch2 macros aren't thread safe, so guard with a mutex
inline std::unique_lock<std::mutex> catch_lock() {
static std::mutex mutex;
return std::unique_lock<std::mutex>{mutex};
}
inline OxenMQ::Logger get_logger(std::string prefix = "") {
std::string me = "tests/common.h";
std::string strip = __FILE__;

View File

@ -173,18 +173,26 @@ TEST_CASE("deferred replies on incoming connections", "[commands][hey google]")
server.add_category("hey google", Access{AuthLevel::none});
server.add_request_command("hey google", "remember", [&](Message& m) {
auto l = catch_lock();
subscribers.emplace_back(m.conn, std::string{m.data[0]});
bool bd;
{
auto l = catch_lock();
subscribers.emplace_back(m.conn, std::string{m.data[0]});
bd = (bool) backdoor;
}
m.send_reply("Okay, I'll remember that.");
if (backdoor)
if (bd)
m.oxenmq.send(backdoor, "backdoor.data", m.data[0]);
});
server.add_command("hey google", "recall", [&](Message& m) {
auto l = catch_lock();
for (auto& s : subscribers) {
server.send(s.first, "personal.detail", s.second);
decltype(subscribers) subs;
{
auto l = catch_lock();
subs = subscribers;
}
for (auto& s : subs)
server.send(s.first, "personal.detail", s.second);
});
server.add_command("hey google", "install backdoor", [&](Message& m) {
auto l = catch_lock();
@ -363,7 +371,7 @@ TEST_CASE("send failure callbacks", "[commands][queue_full]") {
}
}
TEST_CASE("data parts", "[send][data_parts]") {
TEST_CASE("data parts", "[commands][send][data_parts]") {
std::string listen = random_localhost();
OxenMQ server{
"", "", // generate ephemeral keys
@ -446,7 +454,7 @@ TEST_CASE("data parts", "[send][data_parts]") {
}
}
TEST_CASE("deferred replies", "[send][deferred]") {
TEST_CASE("deferred replies", "[commands][send][deferred]") {
std::string listen = random_localhost();
OxenMQ server{
"", "", // generate ephemeral keys
@ -461,9 +469,9 @@ TEST_CASE("deferred replies", "[send][deferred]") {
server.add_request_command("public", "echo", [&](Message& m) {
std::string msg = m.data.empty() ? ""s : std::string{m.data.front()};
std::thread t{[send=m.send_later(), msg=std::move(msg)] {
{ auto lock = catch_lock(); INFO("sleeping"); }
{ auto lock = catch_lock(); UNSCOPED_INFO("sleeping"); }
std::this_thread::sleep_for(50ms);
{ auto lock = catch_lock(); INFO("sending"); }
{ auto lock = catch_lock(); UNSCOPED_INFO("sending"); }
send.reply(msg);
}};
t.detach();
@ -472,8 +480,8 @@ TEST_CASE("deferred replies", "[send][deferred]") {
server.start();
OxenMQ client(
[](LogLevel, const char* file, int line, std::string msg) { std::cerr << file << ":" << line << " --C-- " << msg << "\n"; }
);
get_logger(""),
LogLevel::trace);
//client.log_level(LogLevel::trace);
client.start();

View File

@ -129,6 +129,64 @@ TEST_CASE("plain-text connections", "[plaintext][connect]") {
}
}
TEST_CASE("post-start listening", "[connect][listen]") {
OxenMQ server{get_logger(""), LogLevel::trace};
server.add_category("x", AuthLevel::none)
.add_request_command("y", [&](Message& m) { m.send_reply("hi", m.data[0]); });
server.start();
std::atomic<int> listens = 0;
auto listen_curve = random_localhost();
server.listen_curve(listen_curve, nullptr, [&](bool success) { if (success) listens++; });
auto listen_plain = random_localhost();
server.listen_plain(listen_plain, nullptr, [&](bool success) { if (success) listens += 10; });
wait_for([&] { return listens.load() >= 11; });
{
auto lock = catch_lock();
REQUIRE( listens == 11 );
}
// This should fail since we're already listening on it:
server.listen_curve(listen_plain, nullptr, [&](bool success) { if (!success) listens++; });
wait_for([&] { return listens.load() >= 12; });
{
auto lock = catch_lock();
REQUIRE( listens == 12 );
}
OxenMQ client{get_logger("C1» "), LogLevel::trace};
client.start();
std::atomic<int> conns = 0;
auto c1 = client.connect_remote(address{listen_curve, server.get_pubkey()},
[&](auto) { conns++; },
[&](auto, auto why) { auto lock = catch_lock(); UNSCOPED_INFO("connection failed: " << why); });
auto c2 = client.connect_remote(address{listen_plain},
[&](auto) { conns += 10; },
[&](auto, auto why) { auto lock = catch_lock(); UNSCOPED_INFO("connection failed: " << why); });
wait_for([&] { return conns.load() >= 11; });
{
auto lock = catch_lock();
REQUIRE( conns == 11 );
}
std::atomic<int> replies = 0;
std::string reply1, reply2;
client.request(c1, "x.y", [&](auto success, auto parts) { replies++; for (auto& p : parts) reply1 += p; }, " world");
client.request(c2, "x.y", [&](auto success, auto parts) { replies += 10; for (auto& p : parts) reply2 += p; }, " cat");
wait_for([&] { return replies.load() >= 11; });
{
auto lock = catch_lock();
REQUIRE( replies == 11 );
REQUIRE( reply1 == "hi world" );
REQUIRE( reply2 == "hi cat" );
}
}
TEST_CASE("unique connection IDs", "[connect][id]") {
std::string listen = random_localhost();
OxenMQ server{get_logger(""), LogLevel::trace};