diff --git a/lokimq/connections.h b/lokimq/connections.h index dff3631..34e79cf 100644 --- a/lokimq/connections.h +++ b/lokimq/connections.h @@ -17,11 +17,17 @@ bt_dict build_send(ConnectionID to, string_view cmd, T&&... opts); /// anywhere a ConnectionID is called for). For non-SN remote connections you need to keep a copy /// of the ConnectionID returned by connect_remote(). struct ConnectionID { + // Default construction; creates a ConnectionID with an invalid internal ID that will not match + // an actual connection. + ConnectionID() : ConnectionID(0) {} + // Construction from a service node pubkey ConnectionID(std::string pubkey_) : id{SN_ID}, pk{std::move(pubkey_)} { if (pk.size() != 32) throw std::runtime_error{"Invalid pubkey: expected 32 bytes"}; } + // Construction from a service node pubkey ConnectionID(string_view pubkey_) : ConnectionID(std::string{pubkey_}) {} + ConnectionID(const ConnectionID&) = default; ConnectionID(ConnectionID&&) = default; ConnectionID& operator=(const ConnectionID&) = default; @@ -33,29 +39,30 @@ struct ConnectionID { } // Two ConnectionIDs are equal if they are both SNs and have matching pubkeys, or they are both - // not SNs and have matching internal IDs. (Pubkeys do not have to match for non-SNs, and - // routes are not considered for equality at all). + // not SNs and have matching internal IDs and routes. (Pubkeys do not have to match for + // non-SNs). bool operator==(const ConnectionID &o) const { - if (id == SN_ID && o.id == SN_ID) + if (sn() && o.sn()) return pk == o.pk; - return id == o.id; + return id == o.id && route == o.route; } bool operator!=(const ConnectionID &o) const { return !(*this == o); } bool operator<(const ConnectionID &o) const { - if (id == SN_ID && o.id == SN_ID) + if (sn() && o.sn()) return pk < o.pk; - return id < o.id; + return id < o.id || (id == o.id && route < o.route); } + // Returns true if this ConnectionID represents a SN connection bool sn() const { return id == SN_ID; } - // Returns this connection's pubkey, if any. (Note that it is possible to have a pubkey and not - // be a SN when connecting to secure remotes: having a non-empty pubkey does not imply that - // `sn()` is true). + // Returns this connection's pubkey, if any. (Note that all curve connections have pubkeys, not + // only SNs). const std::string& pubkey() const { return pk; } - // Default construction; creates a ConnectionID with an invalid internal ID that will not match - // an actual connection. - ConnectionID() : ConnectionID(0) {} + + // Returns a copy of the ConnectionID with the route set to empty. + ConnectionID unrouted() { return ConnectionID{id, pk, ""}; } + private: ConnectionID(long long id) : id{id} {} ConnectionID(long long id, std::string pubkey, std::string route = "") @@ -77,7 +84,7 @@ namespace std { template <> struct hash { size_t operator()(const lokimq::ConnectionID &c) const { return c.sn() ? lokimq::already_hashed{}(c.pk) : - std::hash{}(c.id); + std::hash{}(c.id) + std::hash{}(c.route); } }; } // namespace std diff --git a/lokimq/lokimq.cpp b/lokimq/lokimq.cpp index eef35be..eebb61d 100644 --- a/lokimq/lokimq.cpp +++ b/lokimq/lokimq.cpp @@ -199,7 +199,7 @@ LokiMQ::LokiMQ( sn_lookup{std::move(lookup)}, log_lvl{level}, logger{std::move(logger)} { - LMQ_TRACE("Constructing listening LokiMQ, id=", object_id, ", this=", this); + LMQ_TRACE("Constructing LokiMQ, id=", object_id, ", this=", this); if (pubkey.empty() != privkey.empty()) { throw std::invalid_argument("LokiMQ construction failed: one (and only one) of pubkey/privkey is empty. Both must be specified, or both empty to generate a key."); diff --git a/lokimq/lokimq.h b/lokimq/lokimq.h index 8b161b4..425d2ca 100644 --- a/lokimq/lokimq.h +++ b/lokimq/lokimq.h @@ -326,11 +326,13 @@ private: /// SN pubkey string. std::unordered_multimap peers; - /// Maps connection indices (which can change) to ConnectionIDs (which are permanent). + /// Maps connection indices (which can change) to ConnectionID values (which are permanent). + /// This is primarily for outgoing sockets, but incoming sockets are here too (with empty-route + /// (and thus unroutable) ConnectionIDs). std::vector conn_index_to_id; /// Maps listening socket ConnectionIDs to connection index values (these don't have peers - /// entries) + /// entries). The keys here have empty routes (and thus aren't actually routable). std::unordered_map incoming_conn_index; /// The next ConnectionID value we should use (for non-SN connections). diff --git a/lokimq/proxy.cpp b/lokimq/proxy.cpp index 7219353..58fae15 100644 --- a/lokimq/proxy.cpp +++ b/lokimq/proxy.cpp @@ -115,7 +115,7 @@ void LokiMQ::proxy_send(bt_dict_consumer data) { send_to = sock_route.first; conn_id.route = std::move(sock_route.second); } else if (!conn_id.route.empty()) { // incoming non-SN connection - auto it = incoming_conn_index.find(conn_id); + auto it = incoming_conn_index.find(conn_id.unrouted()); if (it == incoming_conn_index.end()) { LMQ_LOG(warn, "Unable to send to ", conn_id, ": incoming listening socket not found"); break; diff --git a/tests/test_connect.cpp b/tests/test_connect.cpp index 7335a71..12edde9 100644 --- a/tests/test_connect.cpp +++ b/tests/test_connect.cpp @@ -128,6 +128,62 @@ TEST_CASE("plain-text connections", "[plaintext][connect]") { } } +TEST_CASE("unique connection IDs", "[connect][id]") { + std::string listen = "tcp://127.0.0.1:4455"; + LokiMQ server{get_logger("S» "), LogLevel::trace}; + + ConnectionID first, second; + server.add_category("x", Access{AuthLevel::none}) + .add_request_command("x", [&](Message& m) { first = m.conn; m.send_reply("hi"); }) + .add_request_command("y", [&](Message& m) { second = m.conn; m.send_reply("hi"); }) + ; + + server.listen_plain(listen); + + server.start(); + + LokiMQ client1{get_logger("C1» "), LogLevel::trace}; + LokiMQ client2{get_logger("C2» "), LogLevel::trace}; + client1.start(); + client2.start(); + + std::atomic good1{false}, good2{false}; + auto r1 = client1.connect_remote(listen, + [&](auto conn) { good1 = true; }, + [&](auto conn, string_view reason) { auto lock = catch_lock(); INFO("connection failed: " << reason); } + ); + auto r2 = client2.connect_remote(listen, + [&](auto conn) { good2 = true; }, + [&](auto conn, string_view reason) { auto lock = catch_lock(); INFO("connection failed: " << reason); } + ); + + wait_for_conn(good1); + wait_for_conn(good2); + { + auto lock = catch_lock(); + REQUIRE( good1 ); + REQUIRE( good2 ); + REQUIRE( first == second ); + REQUIRE_FALSE( first ); + REQUIRE_FALSE( second ); + } + + good1 = false; + good2 = false; + client1.request(r1, "x.x", [&](auto success_, auto parts_) { good1 = true; }); + client2.request(r2, "x.y", [&](auto success_, auto parts_) { good2 = true; }); + reply_sleep(); + + { + auto lock = catch_lock(); + REQUIRE( good1 ); + REQUIRE( good2 ); + REQUIRE_FALSE( first == second ); + REQUIRE_FALSE( std::hash{}(first) == std::hash{}(second) ); + } +} + + TEST_CASE("SN disconnections", "[connect][disconnect]") { std::vector> lmq; std::vector pubkey, privkey;