Simplify conn index handling (#41)

The existing code was overly complicated by trying to track indices into the
`connections` vector; the complication arises because entries get removed from
`connections`, which requires updating all of the internal index values that
point past the removed element.  So we ended up with a connection ID inside the
ConnectionID object, plus a map of those connection IDs to the `connections`
index, plus a map back from indices to ConnectionIDs.

Though this usually seems to work, I recently noticed an oxen-storage-server
sending oxend requests on the wrong connection, so I suspect there are some
rare edge cases here where a failed connection's index might not be updated
properly.

This PR simplifies the whole thing by getting rid of connection indices
entirely and keeping the connections in a map keyed on connection ids (which
never change).  This might end up being a little less efficient than the
vector, but it's unlikely to matter and the added complexity isn't worth it.
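
In short, the bookkeeping goes from three parallel index structures to a single
map plus a counter.  A minimal sketch of the affected declarations (lifted from
the header changes below):

    // Before: positional indices that must be renumbered on every removal.
    std::vector<zmq::socket_t> connections;                        // index -> socket
    std::vector<ConnectionID> conn_index_to_id;                    // index -> permanent ID
    std::unordered_map<ConnectionID, size_t> incoming_conn_index;  // listener ID -> index

    // After: permanent ids as map keys; closing a connection is a single erase().
    std::map<int64_t, zmq::socket_t> connections;
    std::atomic<int64_t> next_conn_id{1};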
Jason Rhinelander 2021-06-23 17:51:25 -03:00 committed by GitHub
parent 7ba81a7d50
commit ad04c53c0e
6 changed files with 127 additions and 150 deletions


@@ -31,7 +31,7 @@ std::string zmtp_metadata(std::string_view key, std::string_view value) {
 }
 
-bool OxenMQ::proxy_check_auth(size_t conn_index, bool outgoing, const peer_info& peer,
+bool OxenMQ::proxy_check_auth(int64_t conn_id, bool outgoing, const peer_info& peer,
         zmq::message_t& cmd, const cat_call_t& cat_call, std::vector<zmq::message_t>& data) {
     auto command = view(cmd);
     std::string reply;
@@ -72,7 +72,7 @@ bool OxenMQ::proxy_check_auth(size_t conn_index, bool outgoing, const peer_info&
     }
     try {
-        send_message_parts(connections[conn_index], msgs);
+        send_message_parts(connections.at(conn_id), msgs);
     } catch (const zmq::error_t& err) {
         /* can't send: possibly already disconnected.  Ignore. */
         LMQ_LOG(debug, "Couldn't send auth failure message ", reply, " to peer [", to_hex(peer.pubkey), "]/", peer_address(cmd), ": ", err.what());
@@ -178,11 +178,11 @@ void OxenMQ::proxy_update_active_sns_clean(pubkey_set added, pubkey_set removed)
         auto range = peers.equal_range(c);
         for (auto it = range.first; it != range.second; ) {
             bool outgoing = it->second.outgoing();
-            size_t conn_index = it->second.conn_index;
+            auto conn_id = it->second.conn_id;
             it = peers.erase(it);
             if (outgoing) {
                 LMQ_LOG(debug, "Closing outgoing connection to ", c);
-                proxy_close_connection(conn_index, CLOSE_LINGER);
+                proxy_close_connection(conn_id, CLOSE_LINGER);
             }
         }
     }
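
A side note on the connections.at(conn_id) lookup above: with connections now a
std::map, plain connections[conn_id] would silently default-construct a socket
for an id that has already been closed, whereas .at() throws std::out_of_range.
A minimal standalone illustration (toy types, not from this commit):

    #include <cstdint>
    #include <map>
    #include <stdexcept>

    int main() {
        std::map<int64_t, int> connections{{1, 10}};
        connections[2];         // silently inserts a default-constructed value for id 2
        try {
            connections.at(3);  // throws instead of inserting
        } catch (const std::out_of_range&) { /* id 3 was never opened, or already closed */ }
        return connections.size() == 2 ? 0 : 1;
    }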


@@ -30,9 +30,9 @@ void OxenMQ::rebuild_pollitems() {
     add_pollitem(pollitems, workers_socket);
     add_pollitem(pollitems, zap_auth);
-    for (auto& s : connections)
+    for (auto& [id, s] : connections)
         add_pollitem(pollitems, s);
-    pollitems_stale = false;
+    connections_updated = false;
 }
 
 void OxenMQ::setup_external_socket(zmq::socket_t& socket) {
@@ -128,7 +128,7 @@ OxenMQ::proxy_connect_sn(std::string_view remote, std::string_view connect_hint,
             }
             peer->activity();
         }
-        return {&connections[peer->conn_index], peer->route};
+        return {&connections[peer->conn_id], peer->route};
     } else if (optional || incoming_only) {
         LMQ_LOG(debug, "proxy asked for optional or incoming connection, but no appropriate connection exists so aborting connection attempt");
         return {nullptr, ""s};
@@ -166,18 +166,17 @@ OxenMQ::proxy_connect_sn(std::string_view remote, std::string_view connect_hint,
         LMQ_LOG(error, "Outgoing connection to ", addr, " failed: ", e.what());
         return {nullptr, ""s};
     }
-    peer_info p{};
+    auto& p = peers.emplace(std::move(remote_cid), peer_info{})->second;
     p.service_node = true;
     p.pubkey = std::string{remote};
-    p.conn_index = connections.size();
+    p.conn_id = next_conn_id++;
     p.idle_expiry = keep_alive;
     p.activity();
-    conn_index_to_id.push_back(remote_cid);
-    peers.emplace(std::move(remote_cid), std::move(p));
-    connections.push_back(std::move(socket));
-    pollitems_stale = true;
-    return {&connections.back(), ""s};
+    connections_updated = true;
+    auto it = connections.emplace_hint(connections.end(), p.conn_id, std::move(socket));
+    return {&it->second, ""s};
 }
 
 std::pair<zmq::socket_t *, std::string> OxenMQ::proxy_connect_sn(bt_dict_consumer data) {
@@ -205,39 +204,19 @@ std::pair<zmq::socket_t *, std::string> OxenMQ::proxy_connect_sn(bt_dict_consume
     return proxy_connect_sn(remote_pk, hint, optional, incoming_only, outgoing_only, ephemeral_rid, keep_alive);
 }
 
-template <typename Container, typename AccessIndex>
-void update_connection_indices(Container& c, size_t index, AccessIndex get_index) {
-    for (auto it = c.begin(); it != c.end(); ) {
-        size_t& i = get_index(*it);
-        if (index == i) {
-            it = c.erase(it);
-            continue;
-        }
-        if (i > index)
-            --i;
-        ++it;
-    }
-}
-
 /// Closes outgoing connections and removes all references.  Note that this will call `erase()`
 /// which can invalidate iterators on the various connection containers - if you don't want that,
 /// delete it first so that the container won't contain the element being deleted.
-void OxenMQ::proxy_close_connection(size_t index, std::chrono::milliseconds linger) {
-    connections[index].set(zmq::sockopt::linger, linger > 0ms ? (int) linger.count() : 0);
-    pollitems_stale = true;
-    connections.erase(connections.begin() + index);
-
-    LMQ_LOG(debug, "Closing conn index ", index);
-    update_connection_indices(peers, index,
-            [](auto& p) -> size_t& { return p.second.conn_index; });
-    update_connection_indices(pending_connects, index,
-            [](auto& pc) -> size_t& { return std::get<size_t>(pc); });
-    update_connection_indices(bind, index,
-            [](auto& b) -> size_t& { return b.index; });
-    update_connection_indices(incoming_conn_index, index,
-            [](auto& oci) -> size_t& { return oci.second; });
-    assert(index < conn_index_to_id.size());
-    conn_index_to_id.erase(conn_index_to_id.begin() + index);
+void OxenMQ::proxy_close_connection(int64_t id, std::chrono::milliseconds linger) {
+    auto it = connections.find(id);
+    if (it == connections.end()) {
+        LMQ_LOG(warn, "internal error: connection to close (", id, ") doesn't exist!");
+        return;
+    }
+    LMQ_LOG(debug, "Closing conn ", id);
+    it->second.set(zmq::sockopt::linger, linger > 0ms ? (int) linger.count() : 0);
+    connections.erase(it);
+    connections_updated = true;
 }
 
 void OxenMQ::proxy_expire_idle_peers() {
@@ -249,8 +228,8 @@ void OxenMQ::proxy_expire_idle_peers() {
                 LMQ_LOG(debug, "Closing outgoing connection to ", it->first, ": idle time (",
                         std::chrono::duration_cast<std::chrono::milliseconds>(idle).count(), "ms) reached connection timeout (",
                         info.idle_expiry.count(), "ms)");
-                ++it; // The below is going to delete our current element
-                proxy_close_connection(info.conn_index, CLOSE_LINGER);
+                proxy_close_connection(info.conn_id, CLOSE_LINGER);
+                it = peers.erase(it);
             } else {
                 LMQ_LOG(trace, "Not closing ", it->first, ": ", std::chrono::duration_cast<std::chrono::milliseconds>(idle).count(),
                         "ms <= ", info.idle_expiry.count(), "ms");
@@ -279,9 +258,10 @@ void OxenMQ::proxy_conn_cleanup() {
     for (auto it = pending_connects.begin(); it != pending_connects.end(); ) {
         auto& pc = *it;
         if (std::get<std::chrono::steady_clock::time_point>(pc) < now) {
-            job([cid = ConnectionID{std::get<long long>(pc)}, callback = std::move(std::get<ConnectFailure>(pc))] { callback(cid, "connection attempt timed out"); });
+            auto id = std::get<int64_t>(pc);
+            job([cid = ConnectionID{id}, callback = std::move(std::get<ConnectFailure>(pc))] { callback(cid, "connection attempt timed out"); });
             it = pending_connects.erase(it); // Don't let the below erase it (because it invalidates iterators)
-            proxy_close_connection(std::get<size_t>(pc), CLOSE_LINGER);
+            proxy_close_connection(id, CLOSE_LINGER);
         } else {
             ++it;
         }
@@ -337,8 +317,6 @@ void OxenMQ::proxy_connect_remote(bt_dict_consumer data) {
     LMQ_LOG(debug, "Establishing remote connection to ", remote, remote_pubkey.empty() ? " (NULL auth)" : " via CURVE expecting pubkey " + to_hex(remote_pubkey));
 
-    assert(conn_index_to_id.size() == connections.size());
-
     zmq::socket_t sock{context, zmq::socket_type::dealer};
     try {
         setup_outgoing_socket(sock, remote_pubkey, ephemeral_rid);
@@ -350,23 +328,19 @@
         return;
     }
 
-    connections.push_back(std::move(sock));
-    pollitems_stale = true;
+    auto &s = connections.emplace_hint(connections.end(), conn_id, std::move(sock))->second;
+    connections_updated = true;
     LMQ_LOG(debug, "Opened new zmq socket to ", remote, ", conn_id ", conn_id, "; sending HI");
-    send_direct_message(connections.back(), "HI");
-    pending_connects.emplace_back(connections.size()-1, conn_id, std::chrono::steady_clock::now() + timeout,
+    send_direct_message(s, "HI");
+    pending_connects.emplace_back(conn_id, std::chrono::steady_clock::now() + timeout,
             std::move(on_connect), std::move(on_failure));
-    peer_info peer;
+    auto& peer = peers.emplace(ConnectionID{conn_id, remote_pubkey}, peer_info{})->second;
     peer.pubkey = std::move(remote_pubkey);
     peer.service_node = false;
     peer.auth_level = auth_level;
-    peer.conn_index = connections.size() - 1;
-    ConnectionID conn{conn_id, peer.pubkey};
-    conn_index_to_id.push_back(conn);
-    assert(connections.size() == conn_index_to_id.size());
+    peer.conn_id = conn_id;
     peer.idle_expiry = 24h * 10 * 365; // "forever"
     peer.activity();
-    peers.emplace(std::move(conn), std::move(peer));
 }
 
 void OxenMQ::proxy_disconnect(bt_dict_consumer data) {
@@ -392,7 +366,8 @@ void OxenMQ::proxy_disconnect(ConnectionID conn, std::chrono::milliseconds linge
         auto& peer = it->second;
         if (peer.outgoing()) {
             LMQ_LOG(debug, "Closing outgoing connection to ", conn);
-            proxy_close_connection(peer.conn_index, linger);
+            proxy_close_connection(peer.conn_id, linger);
+            peers.erase(it);
             return;
         }
     }
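
A note on the emplace_hint(connections.end(), ...) calls above: next_conn_id
only ever increases, so each new socket's key sorts after every existing key
and the end() hint makes the insertion amortized O(1) instead of O(log n).  A
toy standalone version of the pattern (hypothetical names and values):

    #include <cstdint>
    #include <map>
    #include <string>

    int main() {
        std::map<int64_t, std::string> connections;
        int64_t next_conn_id = 1;
        for (auto* name : {"listener", "remote-a", "remote-b"})
            // Keys are strictly increasing, so the end() hint is always correct.
            connections.emplace_hint(connections.end(), next_conn_id++, name);
        connections.erase(2);  // closing a connection requires no index fixups
        return connections.size() == 2 ? 0 : 1;
    }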


@@ -69,12 +69,12 @@ struct ConnectionID {
     ConnectionID unrouted() { return ConnectionID{id, pk, ""}; }
 
 private:
-    ConnectionID(long long id) : id{id} {}
-    ConnectionID(long long id, std::string pubkey, std::string route = "")
+    ConnectionID(int64_t id) : id{id} {}
+    ConnectionID(int64_t id, std::string pubkey, std::string route = "")
             : id{id}, pk{std::move(pubkey)}, route{std::move(route)} {}
 
-    constexpr static long long SN_ID = -1;
-    long long id = 0;
+    constexpr static int64_t SN_ID = -1;
+    int64_t id = 0;
     std::string pk;
     std::string route;
 
     friend class OxenMQ;
@@ -89,7 +89,7 @@ namespace std {
     template <> struct hash<oxenmq::ConnectionID> {
         size_t operator()(const oxenmq::ConnectionID &c) const {
             return c.sn() ? oxenmq::already_hashed{}(c.pk) :
-                std::hash<long long>{}(c.id) + std::hash<std::string>{}(c.route);
+                std::hash<int64_t>{}(c.id) + std::hash<std::string>{}(c.route);
         }
     };
 } // namespace std


@@ -322,11 +322,11 @@ private:
     struct bind_data {
         std::string address;
         bool curve;
-        size_t index;
+        int64_t conn_id;
         AllowFunc allow;
         std::function<void(bool)> on_bind;
         bind_data(std::string addr, bool curve, AllowFunc allow, std::function<void(bool)> on_bind)
-            : address{std::move(addr)}, curve{curve}, index{0}, allow{std::move(allow)}, on_bind{std::move(on_bind)} {}
+            : address{std::move(addr)}, curve{curve}, conn_id{0}, allow{std::move(allow)}, on_bind{std::move(on_bind)} {}
     };
 
     /// Addresses on which we are listening (or, before start(), on which we will listen).
@@ -349,8 +349,8 @@ private:
         /// specified during outgoing connections.
         AuthLevel auth_level = AuthLevel::none;
 
-        /// The actual internal socket index through which this connection is established
-        size_t conn_index;
+        /// The socket id through which this connection is established
+        int64_t conn_id;
 
         /// Will be set to a non-empty routing prefix *if* one is necessary on the connection.  This
         /// is used only for SN peers (non-SN incoming connections don't have a peer_info record,
@@ -378,23 +378,15 @@ private:
     /// SN pubkey string.
     std::unordered_multimap<ConnectionID, peer_info> peers;
 
-    /// Maps connection indices (which can change) to ConnectionID values (which are permanent).
-    /// This is primarily for outgoing sockets, but incoming sockets are here too (with empty-route
-    /// (and thus unroutable) ConnectionIDs).
-    std::vector<ConnectionID> conn_index_to_id;
-
-    /// Maps listening socket ConnectionIDs to connection index values (these don't have peers
-    /// entries).  The keys here have empty routes (and thus aren't actually routable).
-    std::unordered_map<ConnectionID, size_t> incoming_conn_index;
-
-    /// The next ConnectionID value we should use (for non-SN connections).
-    std::atomic<long long> next_conn_id{1};
+    /// The next ConnectionID value we should use (for outgoing, non-SN connections).
+    std::atomic<int64_t> next_conn_id{1};
 
     /// Remotes we are still trying to connect to (via connect_remote(), not connect_sn()); when
     /// we pass handshaking we move them out of here and (if set) trigger the on_connect callback.
     /// Unlike regular node-to-node peers, these have an extra "HI"/"HELLO" sequence that we used
     /// before we consider ourselves connected to the remote.
-    std::list<std::tuple<size_t /*conn_index*/, long long /*conn_id*/, std::chrono::steady_clock::time_point, ConnectSuccess, ConnectFailure>> pending_connects;
+    std::list<std::tuple<int64_t /*conn_id*/, std::chrono::steady_clock::time_point, ConnectSuccess, ConnectFailure>>
+        pending_connects;
 
     /// Pending requests that have been sent out but not yet received a matching "REPLY".  The value
     /// is the timeout timestamp.
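
Dropping the size_t conn_index element also lets the proxy code fetch tuple
elements by type (std::get<int64_t>(pc)), which requires that the requested
type appear exactly once among the tuple's elements.  A minimal illustration
(hypothetical values):

    #include <chrono>
    #include <cstdint>
    #include <tuple>

    int main() {
        using clock = std::chrono::steady_clock;
        std::tuple<int64_t, clock::time_point> pc{42, clock::now() + std::chrono::seconds{10}};
        int64_t conn_id = std::get<int64_t>(pc);            // access by type, not position
        auto timeout_at = std::get<clock::time_point>(pc);  // ditto
        return conn_id == 42 && timeout_at > clock::now() ? 0 : 1;
    }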
@@ -404,20 +396,18 @@ private:
     /// different polling sockets the proxy handler polls: this always contains some internal
     /// sockets for inter-thread communication followed by a pollitem for every connection (both
     /// incoming and outgoing) in `connections`.  We rebuild this from `connections` whenever
-    /// `pollitems_stale` is set to true.
+    /// `connections_updated` is set to true.
     std::vector<zmq::pollitem_t> pollitems;
 
-    /// If set then rebuild pollitems before the next poll (set when establishing new connections or
-    /// closing existing ones).
-    bool pollitems_stale = true;
-
     /// Rebuilds pollitems to include the internal sockets + all incoming/outgoing sockets.
     void rebuild_pollitems();
 
-    /// The connections to/from remotes we currently have open, both listening and outgoing.  Each
-    /// element [i] here corresponds to an the pollitem_t at pollitems[i+1+poll_internal_size].
-    /// (Ideally we'd use one structure, but zmq requires the pollitems be in contiguous storage).
-    std::vector<zmq::socket_t> connections;
+    /// The connections to/from remotes we currently have open, both listening and outgoing.
+    std::map<int64_t, zmq::socket_t> connections;
+
+    /// If set then it indicates a change in `connections` which means we need to rebuild pollitems
+    /// and stop using existing connections iterators.
+    bool connections_updated = true;
 
     /// Socket we listen on to receive control messages in the proxy thread.  Each thread has its own
     /// internal "control" connection (returned by `get_control_socket()`) to this socket used to
@@ -477,17 +467,17 @@ private:
     void proxy_schedule_reply_job(std::function<void()> f);
 
-    /// Looks up a peers element given a connect index (for outgoing connections where we already
+    /// Looks up a peers element given a connect id (for outgoing connections where we already
     /// knew the pubkey and SN status) or an incoming zmq message (which has the pubkey and sn
     /// status metadata set during initial connection authentication), creating a new peer element
     /// if required.
-    decltype(peers)::iterator proxy_lookup_peer(int conn_index, zmq::message_t& msg);
+    decltype(peers)::iterator proxy_lookup_peer(int64_t conn_id, zmq::message_t& msg);
 
     /// Handles built-in primitive commands in the proxy thread for things like "BYE" that have to
     /// be done in the proxy thread anyway (if we forwarded to a worker the worker would just have
     /// to send an instruction back to the proxy to do it).  Returns true if one was handled, false
     /// to continue with sending to a worker.
-    bool proxy_handle_builtin(size_t conn_index, std::vector<zmq::message_t>& parts);
+    bool proxy_handle_builtin(int64_t conn_id, zmq::socket_t& sock, std::vector<zmq::message_t>& parts);
 
     struct run_info;
     /// Gets an idle worker's run_info and removes the worker from the idle worker list.  If there
@@ -502,7 +492,7 @@ private:
     void proxy_run_worker(run_info& run);
 
     /// Sets up a job for a worker then signals the worker (or starts a worker thread)
-    void proxy_to_worker(size_t conn_index, std::vector<zmq::message_t>& parts);
+    void proxy_to_worker(int64_t conn_id, zmq::socket_t& sock, std::vector<zmq::message_t>& parts);
 
     /// proxy thread command handlers for commands sent from the outer object QUIT.  This doesn't
     /// get called immediately on a QUIT command: the QUIT commands tells workers to quit, then this
@@ -510,7 +500,7 @@ private:
     void proxy_quit();
 
     /// proxy handler for binding to addresses given via listen_*().
-    bool proxy_bind(bind_data& bind, size_t index);
+    bool proxy_bind(bind_data& bind, size_t bind_index);
 
     // Common setup code for setting up an external (incoming or outgoing) socket.
     void setup_external_socket(zmq::socket_t& socket);
@@ -603,7 +593,7 @@ private:
     void proxy_expire_idle_peers();
 
     /// Helper method to actually close a remote connection and update the stuff that needs updating.
-    void proxy_close_connection(size_t removed, std::chrono::milliseconds linger);
+    void proxy_close_connection(int64_t removed, std::chrono::milliseconds linger);
 
     /// Closes an outgoing connection immediately, updates internal variables appropriately.
     /// Returns the next iterator (the original may or may not be removed from peers, depending on
@@ -638,7 +628,7 @@ private:
     /// Checks a peer's authentication level.  Returns true if allowed, warns and returns false if
     /// not.
-    bool proxy_check_auth(size_t conn_index, bool outgoing, const peer_info& peer,
+    bool proxy_check_auth(int64_t conn_id, bool outgoing, const peer_info& peer,
             zmq::message_t& command, const cat_call_t& cat_call, std::vector<zmq::message_t>& data);
 
     struct injected_task {
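
The pollitems vector stays separate from the socket map because zmq_poll()
wants the pollitems in contiguous storage, so the map is flattened into the
vector whenever connections_updated is set.  A hedged sketch of that flattening
step (assumed shapes; the real rebuild_pollitems() also adds the internal
control/worker/zap sockets first):

    #include <cstdint>
    #include <map>
    #include <vector>
    #include <zmq.hpp>

    std::vector<zmq::pollitem_t> flatten(std::map<int64_t, zmq::socket_t>& connections) {
        std::vector<zmq::pollitem_t> pollitems;
        pollitems.reserve(connections.size());
        for (auto& [id, s] : connections)  // iteration order is stable, sorted by id
            pollitems.push_back({static_cast<void*>(s), 0, ZMQ_POLLIN, 0});
        return pollitems;
    }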


@@ -33,7 +33,7 @@ void OxenMQ::proxy_quit() {
     }
     workers_socket.close();
     int linger = std::chrono::milliseconds{CLOSE_LINGER}.count();
-    for (auto& s : connections)
+    for (auto& [id, s] : connections)
         s.set(zmq::sockopt::linger, linger);
     connections.clear();
     peers.clear();
@@ -129,12 +129,12 @@ void OxenMQ::proxy_send(bt_dict_consumer data) {
             send_to = sock_route.first;
             conn_id.route = std::move(sock_route.second);
         } else if (!conn_id.route.empty()) { // incoming non-SN connection
-            auto it = incoming_conn_index.find(conn_id.unrouted());
-            if (it == incoming_conn_index.end()) {
+            auto it = connections.find(conn_id.id);
+            if (it == connections.end()) {
                 LMQ_LOG(warn, "Unable to send to ", conn_id, ": incoming listening socket not found");
                 break;
             }
-            send_to = &connections[it->second];
+            send_to = &it->second;
         } else {
             auto pr = peers.equal_range(conn_id);
             if (pr.first == peers.end()) {
@@ -142,7 +142,12 @@ void OxenMQ::proxy_send(bt_dict_consumer data) {
                 break;
             }
             auto& peer = pr.first->second;
-            send_to = &connections[peer.conn_index];
+            auto it = connections.find(peer.conn_id);
+            if (it == connections.end()) {
+                LMQ_LOG(warn, "Unable to send: peer connection id ", conn_id, " is not (or is no longer) a valid outgoing connection");
+                break;
+            }
+            send_to = &it->second;
         }
 
         try {
@@ -241,12 +246,21 @@ void OxenMQ::proxy_reply(bt_dict_consumer data) {
         // SNs there might be one incoming and one outgoing).
         for (auto it = pr.first; it != pr.second; ) {
             try {
-                send_message_parts(connections[it->second.conn_index], build_send_parts(send, it->second.route));
+                send_message_parts(connections[it->second.conn_id], build_send_parts(send, it->second.route));
                 break;
             } catch (const zmq::error_t &err) {
                 if (err.num() == EHOSTUNREACH) {
-                    LMQ_LOG(debug, "Unable to send reply to incoming non-SN request: remote is no longer connected; removing peer details");
-                    it = peers.erase(it);
+                    if (it->second.outgoing()) {
+                        LMQ_LOG(debug, "Unable to send reply to non-SN request on outgoing socket: "
+                                "remote is no longer connected; closing connection");
+                        proxy_close_connection(it->second.conn_id, CLOSE_LINGER);
+                        it = peers.erase(it);
+                        ++it;
+                    } else {
+                        LMQ_LOG(debug, "Unable to send reply to non-SN request on incoming socket: "
+                                "remote is no longer connected; removing peer details");
+                        it = peers.erase(it);
+                    }
                 } else {
                     LMQ_LOG(warn, "Unable to send reply to incoming non-SN request: ", err.what());
                     ++it;
@@ -322,9 +336,9 @@ void OxenMQ::proxy_control_message(std::vector<zmq::message_t>& parts) {
             std::string{cmd} + " (" + std::to_string(parts.size()) + ")");
 }
 
-bool OxenMQ::proxy_bind(bind_data& b, size_t index) {
+bool OxenMQ::proxy_bind(bind_data& b, size_t bind_index) {
     zmq::socket_t listener{context, zmq::socket_type::router};
-    setup_incoming_socket(listener, b.curve, pubkey, privkey, index);
+    setup_incoming_socket(listener, b.curve, pubkey, privkey, bind_index);
 
     bool good = true;
     try {
@@ -343,13 +357,10 @@ bool OxenMQ::proxy_bind(bind_data& b, size_t index) {
     LMQ_LOG(info, "OxenMQ listening on ", b.address);
 
-    connections.push_back(std::move(listener));
-    auto conn_id = next_conn_id++;
-    conn_index_to_id.push_back(conn_id);
-    incoming_conn_index[conn_id] = connections.size() - 1;
-    b.index = connections.size() - 1;
-    pollitems_stale = true;
+    b.conn_id = next_conn_id++;
+    connections.emplace_hint(connections.end(), b.conn_id, std::move(listener));
+    connections_updated = true;
 
     return true;
 }
@@ -422,7 +433,7 @@ void OxenMQ::proxy_loop() {
     }
 #endif
 
-    pollitems_stale = true;
+    connections_updated = true;
 
     // Also add an internal connection to self so that calling code can avoid needing to
     // special-case rare situations where we are supposed to talk to a quorum member that happens to
@@ -485,14 +496,14 @@ void OxenMQ::proxy_loop() {
             poll_timeout = std::chrono::milliseconds{zmq_timers_timeout(timers.get())};
         }
 
+        if (connections_updated)
+            rebuild_pollitems();
+
         if (proxy_skip_one_poll)
             proxy_skip_one_poll = false;
         else {
             LMQ_TRACE("polling for new messages");
 
-            if (pollitems_stale)
-                rebuild_pollitems();
-
             // We poll the control socket and worker socket for any incoming messages.  If we have
             // available worker room then also poll incoming connections and outgoing connections
             // for messages to forward to a worker.  Otherwise, we just look for a control message
@@ -528,36 +539,37 @@ void OxenMQ::proxy_loop() {
         // We round-robin connections when pulling off pending messages one-by-one rather than
         // pulling off all messages from one connection before moving to the next; thus in cases of
         // contention we end up fairly distributing.
-        const int num_sockets = connections.size();
-        std::queue<int> queue_index;
-        for (int i = 0; i < num_sockets; i++)
-            queue_index.push(i);
+        std::vector<std::pair<const int64_t, zmq::socket_t>*> queue; // Used as a circular buffer
+        queue.reserve(connections.size() + 1);
+        for (auto& id_sock : connections)
+            queue.push_back(&id_sock);
+        queue.push_back(nullptr);
+        size_t end = queue.size() - 1;
 
-        for (parts.clear(); !queue_index.empty(); parts.clear()) {
-            size_t i = queue_index.front();
-            queue_index.pop();
-            auto& sock = connections[i];
+        for (size_t pos = 0; pos != end; ++pos %= queue.size()) {
+            parts.clear();
+            auto& [id, sock] = *queue[pos];
 
             if (!recv_message_parts(sock, parts, zmq::recv_flags::dontwait))
                 continue;
 
             // We only pull this one message now but then requeue the socket so that after we check
             // all other sockets we come back to this one to check again.
-            queue_index.push(i);
+            queue[end] = queue[pos];
+            ++end %= queue.size();
 
             if (parts.empty()) {
                 LMQ_LOG(warn, "Ignoring empty (0-part) incoming message");
                 continue;
             }
 
-            if (!proxy_handle_builtin(i, parts))
-                proxy_to_worker(i, parts);
+            if (!proxy_handle_builtin(id, sock, parts))
+                proxy_to_worker(id, sock, parts);
 
-            if (pollitems_stale) {
-                // If our items became stale then we may have just closed a connection and so our
-                // queue index maybe also be stale, so restart the proxy loop (so that we rebuild
-                // pollitems).
-                LMQ_TRACE("pollitems became stale; short-circuiting incoming message loop");
+            if (connections_updated) {
+                // If connections got updated then our pointers are stale, so restart the proxy
+                // loop; if there are still messages waiting we'll end up right back here.
+                LMQ_TRACE("connections became stale; short-circuiting incoming message loop");
                 break;
             }
         }
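
The rewritten loop keeps the old std::queue<int> round-robin fairness but
stores map-entry pointers in a vector used as a circular buffer: a socket that
yielded a message is copied to the end slot so it is revisited only after every
other pending socket, and the spare nullptr slot keeps end from colliding with
pos.  The same pattern in isolation (a hedged sketch; has_work stands in for
the non-blocking recv, and like the real loop it relies on that test eventually
failing):

    #include <cstddef>
    #include <vector>

    template <typename T, typename Pending>
    void round_robin(std::vector<T*> queue, Pending has_work) {
        queue.push_back(nullptr);              // spare slot: end never catches pos
        size_t end = queue.size() - 1;
        for (size_t pos = 0; pos != end; ++pos %= queue.size()) {
            if (!has_work(*queue[pos]))
                continue;                      // drained: drop it from the rotation
            queue[end] = queue[pos];           // still busy: requeue for another pass
            ++end %= queue.size();
        }
    }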
@@ -572,9 +584,9 @@ static bool is_error_response(std::string_view cmd) {
 
 // Return true if we recognized/handled the builtin command (even if we reject it for whatever
 // reason)
-bool OxenMQ::proxy_handle_builtin(size_t conn_index, std::vector<zmq::message_t>& parts) {
+bool OxenMQ::proxy_handle_builtin(int64_t conn_id, zmq::socket_t& sock, std::vector<zmq::message_t>& parts) {
     // Doubling as a bool and an offset:
-    size_t incoming = connections[conn_index].get(zmq::sockopt::type) == ZMQ_ROUTER;
+    size_t incoming = sock.get(zmq::sockopt::type) == ZMQ_ROUTER;
 
     std::string_view route, cmd;
     if (parts.size() < 1 + incoming) {
@@ -618,7 +630,7 @@ bool OxenMQ::proxy_handle_builtin(size_t conn_index, std::vector<zmq::message_t>
         }
         LMQ_LOG(debug, "Incoming client from ", peer_address(parts.back()), " sent HI, replying with HELLO");
         try {
-            send_routed_message(connections[conn_index], std::string{route}, "HELLO");
+            send_routed_message(sock, std::string{route}, "HELLO");
         } catch (const std::exception &e) { LMQ_LOG(warn, "Couldn't reply with HELLO: ", e.what()); }
         return true;
     } else if (cmd == "HELLO") {
@@ -627,13 +639,13 @@ bool OxenMQ::proxy_handle_builtin(size_t conn_index, std::vector<zmq::message_t>
             return true;
         }
         auto it = std::find_if(pending_connects.begin(), pending_connects.end(),
-                [&](auto& pc) { return std::get<size_t>(pc) == conn_index; });
+                [&](auto& pc) { return std::get<int64_t>(pc) == conn_id; });
         if (it == pending_connects.end()) {
             LMQ_LOG(warn, "Got invalid 'HELLO' message on an already handshaked incoming connection; ignoring");
             return true;
         }
         auto& pc = *it;
-        auto pit = peers.find(std::get<long long>(pc));
+        auto pit = peers.find(std::get<int64_t>(pc));
         if (pit == peers.end()) {
             LMQ_LOG(warn, "Got invalid 'HELLO' message with invalid conn_id; ignoring");
             return true;
@@ -641,7 +653,7 @@ bool OxenMQ::proxy_handle_builtin(size_t conn_index, std::vector<zmq::message_t>
         LMQ_LOG(debug, "Got initial HELLO server response from ", peer_address(parts.back()));
         proxy_schedule_reply_job([on_success=std::move(std::get<ConnectSuccess>(pc)),
-                conn=conn_index_to_id[conn_index]] {
+                conn=pit->first] {
             on_success(conn);
         });
         pending_connects.erase(it);
@@ -649,7 +661,7 @@ bool OxenMQ::proxy_handle_builtin(size_t conn_index, std::vector<zmq::message_t>
     } else if (cmd == "BYE") {
         if (!incoming) {
             LMQ_LOG(debug, "BYE command received; disconnecting from ", peer_address(parts.back()));
-            proxy_close_connection(conn_index, 0s);
+            proxy_close_connection(conn_id, 0s);
         } else {
             LMQ_LOG(warn, "Got invalid 'BYE' command on an incoming socket; ignoring");
         }


@@ -275,17 +275,17 @@ void OxenMQ::proxy_run_worker(run_info& run) {
     send_routed_message(workers_socket, run.worker_routing_id, "RUN");
 }
 
-void OxenMQ::proxy_to_worker(size_t conn_index, std::vector<zmq::message_t>& parts) {
-    bool outgoing = connections[conn_index].get(zmq::sockopt::type) == ZMQ_DEALER;
+void OxenMQ::proxy_to_worker(int64_t conn_id, zmq::socket_t& sock, std::vector<zmq::message_t>& parts) {
+    bool outgoing = sock.get(zmq::sockopt::type) == ZMQ_DEALER;
 
     peer_info tmp_peer;
-    tmp_peer.conn_index = conn_index;
+    tmp_peer.conn_id = conn_id;
     if (!outgoing) tmp_peer.route = parts[0].to_string();
     peer_info* peer = nullptr;
     if (outgoing) {
-        auto it = peers.find(conn_index_to_id[conn_index]);
+        auto it = peers.find(conn_id);
         if (it == peers.end()) {
-            LMQ_LOG(warn, "Internal error: connection index ", conn_index, " not found");
+            LMQ_LOG(warn, "Internal error: connection id ", conn_id, " not found");
             return;
         }
         peer = &it->second;
@@ -298,7 +298,7 @@ void OxenMQ::proxy_to_worker(size_t conn_index, std::vector<zmq::message_t>& par
         // the same route, and if not, add one.
         auto pr = peers.equal_range(tmp_peer.pubkey);
         for (auto it = pr.first; it != pr.second; ++it) {
-            if (it->second.conn_index == tmp_peer.conn_index && it->second.route == tmp_peer.route) {
+            if (it->second.conn_id == tmp_peer.conn_id && it->second.route == tmp_peer.route) {
                 peer = &it->second;
                 // Update the stored auth level just in case the peer reconnected
                 peer->auth_level = tmp_peer.auth_level;
@@ -330,7 +330,7 @@ void OxenMQ::proxy_to_worker(size_t conn_index, std::vector<zmq::message_t>& par
     auto cat_call = get_command(command);
 
     // Check that command is valid, that we have permission, etc.
-    if (!proxy_check_auth(conn_index, outgoing, *peer, parts[command_part_index], cat_call, data_parts))
+    if (!proxy_check_auth(conn_id, outgoing, *peer, parts[command_part_index], cat_call, data_parts))
         return;
 
     auto& category = *cat_call.first;
@@ -345,7 +345,7 @@ void OxenMQ::proxy_to_worker(size_t conn_index, std::vector<zmq::message_t>& par
         }
         LMQ_LOG(debug, "No available free workers, queuing ", command, " for later");
-        ConnectionID conn{peer->service_node ? ConnectionID::SN_ID : conn_index_to_id[conn_index].id, peer->pubkey, std::move(tmp_peer.route)};
+        ConnectionID conn{peer->service_node ? ConnectionID::SN_ID : conn_id, peer->pubkey, std::move(tmp_peer.route)};
         pending_commands.emplace_back(category, std::move(command), std::move(data_parts), cat_call.second,
                 std::move(conn), std::move(access), peer_address(parts[command_part_index]));
         category.queued++;
@@ -359,7 +359,7 @@ void OxenMQ::proxy_to_worker(size_t conn_index, std::vector<zmq::message_t>& par
     auto& run = get_idle_worker();
     {
-        ConnectionID c{peer->service_node ? ConnectionID::SN_ID : conn_index_to_id[conn_index].id, peer->pubkey};
+        ConnectionID c{peer->service_node ? ConnectionID::SN_ID : conn_id, peer->pubkey};
         c.route = std::move(tmp_peer.route);
         if (outgoing || peer->service_node)
             tmp_peer.route.clear();