Add support for listening after startup

This commit adds support for listening on new ports after startup.  This
will make things easier in storage server, in particular, where we want
to delay listening on public ports until we have an established
connection and initial block status update from oxend.
This commit is contained in:
Jason Rhinelander 2021-06-23 10:51:08 -03:00
parent dccbd1e8cd
commit 5dd7c12219
6 changed files with 173 additions and 62 deletions

View File

@ -256,14 +256,14 @@ void OxenMQ::process_zap_requests() {
LMQ_LOG(error, "Bad ZAP authentication request: invalid auth domain '", auth_domain, "'");
status_code = "400";
status_text = "Unknown authentication domain: " + std::string{auth_domain};
} else if (bind[bind_id].second.curve
} else if (bind[bind_id].curve
? !(frames.size() == 7 && view(frames[5]) == "CURVE")
: !(frames.size() == 6 && view(frames[5]) == "NULL")) {
LMQ_LOG(error, "Bad ZAP authentication request: invalid ",
bind[bind_id].second.curve ? "CURVE" : "NULL", " authentication request");
bind[bind_id].curve ? "CURVE" : "NULL", " authentication request");
status_code = "500";
status_text = "Invalid authentication request mechanism";
} else if (bind[bind_id].second.curve && frames[6].size() != 32) {
} else if (bind[bind_id].curve && frames[6].size() != 32) {
LMQ_LOG(error, "Bad ZAP authentication request: invalid request pubkey");
status_code = "500";
status_text = "Invalid public key size for CURVE authentication";
@ -271,13 +271,13 @@ void OxenMQ::process_zap_requests() {
auto ip = view(frames[3]);
std::string_view pubkey;
bool sn = false;
if (bind[bind_id].second.curve) {
if (bind[bind_id].curve) {
pubkey = view(frames[6]);
sn = active_service_nodes.count(std::string{pubkey});
}
auto auth = bind[bind_id].second.allow(ip, pubkey, sn);
auto auth = bind[bind_id].allow(ip, pubkey, sn);
auto& user_id = response_vals[4];
if (bind[bind_id].second.curve) {
if (bind[bind_id].curve) {
user_id.reserve(64);
to_hex(pubkey.begin(), pubkey.end(), std::back_inserter(user_id));
}

View File

@ -67,6 +67,21 @@ void OxenMQ::setup_outgoing_socket(zmq::socket_t& socket, std::string_view remot
// else let ZMQ pick a random one
}
void OxenMQ::setup_incoming_socket(zmq::socket_t& listener, bool curve, std::string_view pubkey, std::string_view privkey, size_t bind_index) {
setup_external_socket(listener);
listener.set(zmq::sockopt::zap_domain, bt_serialize(bind_index));
if (curve) {
listener.set(zmq::sockopt::curve_server, true);
listener.set(zmq::sockopt::curve_publickey, pubkey);
listener.set(zmq::sockopt::curve_secretkey, privkey);
}
listener.set(zmq::sockopt::router_handover, true);
listener.set(zmq::sockopt::router_mandatory, true);
}
// Deprecated versions:
ConnectionID OxenMQ::connect_remote(std::string_view remote, ConnectSuccess on_connect,
ConnectFailure on_failure, AuthLevel auth_level, std::chrono::milliseconds timeout) {
@ -218,7 +233,7 @@ void OxenMQ::proxy_close_connection(size_t index, std::chrono::milliseconds ling
update_connection_indices(pending_connects, index,
[](auto& pc) -> size_t& { return std::get<size_t>(pc); });
update_connection_indices(bind, index,
[](auto& b) -> size_t& { return b.second.index; });
[](auto& b) -> size_t& { return b.index; });
update_connection_indices(incoming_conn_index, index,
[](auto& oci) -> size_t& { return oci.second; });
assert(index < conn_index_to_id.size());

View File

@ -232,12 +232,6 @@ void OxenMQ::start() {
if (proxy_thread.joinable())
throw std::logic_error("Cannot call start() multiple times!");
// If we're not binding to anything then we don't listen, i.e. we can only establish outbound
// connections. Don't allow this if we are in service_node mode because, if we aren't
// listening, we are useless as a service node.
if (bind.empty() && local_service_node)
throw std::invalid_argument{"Cannot create a service node listener with no address(es) to bind"};
LMQ_LOG(info, "Initializing OxenMQ ", bind.empty() ? "remote-only" : "listener", " with pubkey ", to_hex(pubkey));
int zmq_socket_limit = context.get(zmq::ctxopt::socket_limit);
@ -267,20 +261,22 @@ void OxenMQ::start() {
LMQ_LOG(debug, "Proxy thread is ready");
}
void OxenMQ::listen_curve(std::string bind_addr, AllowFunc allow_connection) {
// TODO: there's no particular reason we can't start listening after starting up; just needs to
// be implemented. (But if we can start we'll probably also want to be able to stop, so it's
// more than just binding that needs implementing).
check_not_started(proxy_thread, "start listening");
bind.emplace_back(std::move(bind_addr), bind_data{true, std::move(allow_connection)});
void OxenMQ::listen_curve(std::string bind_addr, AllowFunc allow_connection, std::function<void(bool)> on_bind) {
if (!allow_connection) allow_connection = [](auto, auto, auto) { return AuthLevel::none; };
bind_data d{std::move(bind_addr), true, std::move(allow_connection), std::move(on_bind)};
if (proxy_thread.joinable())
detail::send_control(get_control_socket(), "BIND", bt_serialize(detail::serialize_object(std::move(d))));
else
bind.push_back(std::move(d));
}
void OxenMQ::listen_plain(std::string bind_addr, AllowFunc allow_connection) {
// TODO: As above.
check_not_started(proxy_thread, "start listening");
bind.emplace_back(std::move(bind_addr), bind_data{false, std::move(allow_connection)});
void OxenMQ::listen_plain(std::string bind_addr, AllowFunc allow_connection, std::function<void(bool)> on_bind) {
if (!allow_connection) allow_connection = [](auto, auto, auto) { return AuthLevel::none; };
bind_data d{std::move(bind_addr), false, std::move(allow_connection), std::move(on_bind)};
if (proxy_thread.joinable())
detail::send_control(get_control_socket(), "BIND", bt_serialize(detail::serialize_object(std::move(d))));
else
bind.push_back(std::move(d));
}

View File

@ -320,15 +320,17 @@ private:
zmq::socket_t zap_auth{context, zmq::socket_type::rep};
struct bind_data {
std::string address;
bool curve;
size_t index;
AllowFunc allow;
bind_data(bool curve, AllowFunc allow)
: curve{curve}, index{0}, allow{std::move(allow)} {}
std::function<void(bool)> on_bind;
bind_data(std::string addr, bool curve, AllowFunc allow, std::function<void(bool)> on_bind)
: address{std::move(addr)}, curve{curve}, index{0}, allow{std::move(allow)}, on_bind{std::move(on_bind)} {}
};
/// Addresses on which we are listening (or, before start(), on which we will listen).
std::vector<std::pair<std::string, bind_data>> bind;
std::vector<bind_data> bind;
/// Info about a peer's established connection with us. Note that "established" means both
/// connected and authenticated. Note that we only store peer info data for SN connections (in
@ -507,6 +509,9 @@ private:
/// gets called after all works have done so.
void proxy_quit();
/// proxy handler for binding to addresses given via listen_*().
bool proxy_bind(bind_data& bind, size_t index);
// Common setup code for setting up an external (incoming or outgoing) socket.
void setup_external_socket(zmq::socket_t& socket);
@ -516,6 +521,9 @@ private:
// either accepting curve connections, or not accepting curve).
void setup_outgoing_socket(zmq::socket_t& socket, std::string_view remote_pubkey, bool use_ephemeral_routing_id);
/// Sets the various properties on an listening socket prior to binding.
void setup_incoming_socket(zmq::socket_t& socket, bool curve, std::string_view pubkey, std::string_view privkey, size_t bind_index);
/// Common connection implementation used by proxy_connect/proxy_send. Returns the socket and,
/// if a routing prefix is needed, the required prefix (or an empty string if not needed). For
/// an optional connect that fails (or some other connection failure), returns nullptr for the
@ -971,14 +979,26 @@ public:
* will be encrypted. `allow_connection` is invoked for any incoming connections on this
* address to determine the incoming remote's access and authentication level.
*
* If called before `start()` then the given bind address is mandatory and start() will throw if
* the bind fails. If called after `start()` then the bind may fail (in which case the callback
* will be used to notify of the failure).
*
* @param bind address - can be any string zmq supports; typically a tcp IP/port combination
* such as: "tcp://\*:4567" or "tcp://1.2.3.4:5678".
*
* @param allow_connection function to call to determine whether to allow the connection and, if
* so, the authentication level it receives. If omitted the default returns AuthLevel::none
* access.
* so, the authentication level it receives. If omitted (or null) the default returns
* AuthLevel::none access for all connections.
*
* @param on_bind function to call when the port has been successfully opened or failed to
* open. For addresses set up before .start() this will be called during `start()` itself; for
* post-start listens this will be called from the proxy thread when it opens the new port.
* Note that this function must is called directly from the proxy thread and so should be fast
* and non-blocking.
*/
void listen_curve(std::string bind, AllowFunc allow_connection = [](auto, auto, auto) { return AuthLevel::none; });
void listen_curve(std::string bind,
AllowFunc allow_connection = nullptr,
std::function<void(bool success)> on_bind = nullptr);
/** Start listening on the given bind address in unauthenticated plain text mode. Incoming
* connections can come from anywhere. `allow_connection` is invoked for any incoming
@ -989,10 +1009,14 @@ public:
* such as: "tcp://\*:4567" or "tcp://1.2.3.4:5678".
*
* @param allow_connection function to call to determine whether to allow the connection and, if
* so, the authentication level it receives. If omitted the default returns AuthLevel::none
* access.
* so, the authentication level it receives. If omitted (or null) the default returns
* AuthLevel::none access for all connections.
*
* @param on_result called after binding with the result; see `listen_curve` for details.
*/
void listen_plain(std::string bind, AllowFunc allow_connection = [](auto, auto, auto) { return AuthLevel::none; });
void listen_plain(std::string bind,
AllowFunc allow_connection = nullptr,
std::function<void(bool success)> on_bind = nullptr);
/**
* Try to initiate a connection to the given SN in anticipation of needing a connection in the
@ -1446,8 +1470,8 @@ namespace connect_option {
/// Typically use: `connect_options::ephemeral_routing_id{}` or `connect_options::ephemeral_routing_id{false}`.
struct ephemeral_routing_id {
bool use_ephemeral_routing_id = true;
// Constructor; default construction gives you pubkey routing, but the bool parameter can be
// specified as false to explicitly disable the pubkey routing flag.
// Constructor; default construction gives you ephemeral routing id, but the bool parameter can
// be specified as false to use pubkey routing flag.
explicit ephemeral_routing_id(bool use = true) : use_ephemeral_routing_id{use} {}
};

View File

@ -293,6 +293,11 @@ void OxenMQ::proxy_control_message(std::vector<zmq::message_t>& parts) {
return proxy_timer(data);
} else if (cmd == "TIMER_DEL") {
return proxy_timer_del(bt_deserialize<int>(data));
} else if (cmd == "BIND") {
auto b = detail::deserialize_object<bind_data>(bt_deserialize<uintptr_t>(data));
if (proxy_bind(b, bind.size()))
bind.push_back(std::move(b));
return;
}
} else if (parts.size() == 2) {
if (cmd == "START") {
@ -317,6 +322,38 @@ void OxenMQ::proxy_control_message(std::vector<zmq::message_t>& parts) {
std::string{cmd} + " (" + std::to_string(parts.size()) + ")");
}
bool OxenMQ::proxy_bind(bind_data& b, size_t index) {
zmq::socket_t listener{context, zmq::socket_type::router};
setup_incoming_socket(listener, b.curve, pubkey, privkey, index);
bool good = true;
try {
listener.bind(b.address);
} catch (const zmq::error_t&) {
good = false;
}
if (b.on_bind) {
b.on_bind(good);
b.on_bind = nullptr;
}
if (!good) {
LMQ_LOG(warn, "OxenMQ failed to listen on ", b.address);
return false;
}
LMQ_LOG(info, "OxenMQ listening on ", b.address);
connections.push_back(std::move(listener));
auto conn_id = next_conn_id++;
conn_index_to_id.push_back(conn_id);
incoming_conn_index[conn_id] = connections.size() - 1;
b.index = connections.size() - 1;
pollitems_stale = true;
return true;
}
void OxenMQ::proxy_loop() {
#if defined(__linux__) || defined(__sun) || defined(__MINGW32__)
@ -364,27 +401,10 @@ void OxenMQ::proxy_loop() {
#endif
for (size_t i = 0; i < bind.size(); i++) {
auto& b = bind[i].second;
zmq::socket_t listener{context, zmq::socket_type::router};
setup_external_socket(listener);
listener.set(zmq::sockopt::zap_domain, bt_serialize(i));
if (b.curve) {
listener.set(zmq::sockopt::curve_server, true);
listener.set(zmq::sockopt::curve_publickey, pubkey);
listener.set(zmq::sockopt::curve_secretkey, privkey);
if (!proxy_bind(bind[i], i)) {
LMQ_LOG(warn, "OxenMQ failed to listen on ", bind[i].address);
throw zmq::error_t{};
}
listener.set(zmq::sockopt::router_handover, true);
listener.set(zmq::sockopt::router_mandatory, true);
listener.bind(bind[i].first);
LMQ_LOG(info, "OxenMQ listening on ", bind[i].first);
connections.push_back(std::move(listener));
auto conn_id = next_conn_id++;
conn_index_to_id.push_back(conn_id);
incoming_conn_index[conn_id] = connections.size() - 1;
b.index = connections.size() - 1;
}
#ifndef _WIN32
@ -393,13 +413,11 @@ void OxenMQ::proxy_loop() {
// set socket gid / uid if it is provided
if (SOCKET_GID != -1 or SOCKET_UID != -1) {
for(size_t i = 0; i < bind.size(); i++) {
const address addr(bind[i].first);
if(addr.ipc()) {
if(chown(addr.socket.c_str(), SOCKET_UID, SOCKET_GID) == -1) {
for (auto& b : bind) {
const address addr(b.address);
if (addr.ipc())
if (chown(addr.socket.c_str(), SOCKET_UID, SOCKET_GID) == -1)
throw std::runtime_error("cannot set group on " + addr.socket + ": " + strerror(errno));
}
}
}
}
#endif

View File

@ -129,6 +129,64 @@ TEST_CASE("plain-text connections", "[plaintext][connect]") {
}
}
TEST_CASE("post-start listening", "[connect][listen]") {
OxenMQ server{get_logger(""), LogLevel::trace};
server.add_category("x", AuthLevel::none)
.add_request_command("y", [&](Message& m) { m.send_reply("hi", m.data[0]); });
server.start();
std::atomic<int> listens;
auto listen_curve = random_localhost();
server.listen_curve(listen_curve, nullptr, [&](bool success) { if (success) listens++; });
auto listen_plain = random_localhost();
server.listen_plain(listen_plain, nullptr, [&](bool success) { if (success) listens += 10; });
wait_for([&] { return listens.load() >= 11; });
{
auto lock = catch_lock();
REQUIRE( listens == 11 );
}
// This should fail since we're already listening on it:
server.listen_curve(listen_plain, nullptr, [&](bool success) { if (!success) listens++; });
wait_for([&] { return listens.load() >= 12; });
{
auto lock = catch_lock();
REQUIRE( listens == 12 );
}
OxenMQ client{get_logger("C1» "), LogLevel::trace};
client.start();
std::atomic<int> conns = 0;
auto c1 = client.connect_remote(address{listen_curve, server.get_pubkey()},
[&](auto) { conns++; },
[&](auto, auto why) { auto lock = catch_lock(); UNSCOPED_INFO("connection failed: " << why); });
auto c2 = client.connect_remote(address{listen_plain},
[&](auto) { conns += 10; },
[&](auto, auto why) { auto lock = catch_lock(); UNSCOPED_INFO("connection failed: " << why); });
wait_for([&] { return conns.load() >= 11; });
{
auto lock = catch_lock();
REQUIRE( conns == 11 );
}
std::atomic<int> replies = 0;
std::string reply1, reply2;
client.request(c1, "x.y", [&](auto success, auto parts) { replies++; for (auto& p : parts) reply1 += p; }, " world");
client.request(c2, "x.y", [&](auto success, auto parts) { replies += 10; for (auto& p : parts) reply2 += p; }, " cat");
wait_for([&] { return replies.load() >= 11; });
{
auto lock = catch_lock();
REQUIRE( replies == 11 );
REQUIRE( reply1 == "hi world" );
REQUIRE( reply2 == "hi cat" );
}
}
TEST_CASE("unique connection IDs", "[connect][id]") {
std::string listen = random_localhost();
OxenMQ server{get_logger(""), LogLevel::trace};