mirror of https://github.com/oxen-io/oxen-mq.git
Fix incoming ConnectionIDs not being storable
ConnectionIDs weren't comparing their routes, which meant that if external code stored one in a map or set *all* incoming connections on the same listener would be considered the same connection. This fixes it by considering route for equality/hashing, and strips route off internally where we need to map it to a socket.
This commit is contained in:
parent
f4f1506df0
commit
3a0508fdce
|
@ -17,11 +17,17 @@ bt_dict build_send(ConnectionID to, string_view cmd, T&&... opts);
|
||||||
/// anywhere a ConnectionID is called for). For non-SN remote connections you need to keep a copy
|
/// anywhere a ConnectionID is called for). For non-SN remote connections you need to keep a copy
|
||||||
/// of the ConnectionID returned by connect_remote().
|
/// of the ConnectionID returned by connect_remote().
|
||||||
struct ConnectionID {
|
struct ConnectionID {
|
||||||
|
// Default construction; creates a ConnectionID with an invalid internal ID that will not match
|
||||||
|
// an actual connection.
|
||||||
|
ConnectionID() : ConnectionID(0) {}
|
||||||
|
// Construction from a service node pubkey
|
||||||
ConnectionID(std::string pubkey_) : id{SN_ID}, pk{std::move(pubkey_)} {
|
ConnectionID(std::string pubkey_) : id{SN_ID}, pk{std::move(pubkey_)} {
|
||||||
if (pk.size() != 32)
|
if (pk.size() != 32)
|
||||||
throw std::runtime_error{"Invalid pubkey: expected 32 bytes"};
|
throw std::runtime_error{"Invalid pubkey: expected 32 bytes"};
|
||||||
}
|
}
|
||||||
|
// Construction from a service node pubkey
|
||||||
ConnectionID(string_view pubkey_) : ConnectionID(std::string{pubkey_}) {}
|
ConnectionID(string_view pubkey_) : ConnectionID(std::string{pubkey_}) {}
|
||||||
|
|
||||||
ConnectionID(const ConnectionID&) = default;
|
ConnectionID(const ConnectionID&) = default;
|
||||||
ConnectionID(ConnectionID&&) = default;
|
ConnectionID(ConnectionID&&) = default;
|
||||||
ConnectionID& operator=(const ConnectionID&) = default;
|
ConnectionID& operator=(const ConnectionID&) = default;
|
||||||
|
@ -33,29 +39,30 @@ struct ConnectionID {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Two ConnectionIDs are equal if they are both SNs and have matching pubkeys, or they are both
|
// Two ConnectionIDs are equal if they are both SNs and have matching pubkeys, or they are both
|
||||||
// not SNs and have matching internal IDs. (Pubkeys do not have to match for non-SNs, and
|
// not SNs and have matching internal IDs and routes. (Pubkeys do not have to match for
|
||||||
// routes are not considered for equality at all).
|
// non-SNs).
|
||||||
bool operator==(const ConnectionID &o) const {
|
bool operator==(const ConnectionID &o) const {
|
||||||
if (id == SN_ID && o.id == SN_ID)
|
if (sn() && o.sn())
|
||||||
return pk == o.pk;
|
return pk == o.pk;
|
||||||
return id == o.id;
|
return id == o.id && route == o.route;
|
||||||
}
|
}
|
||||||
bool operator!=(const ConnectionID &o) const { return !(*this == o); }
|
bool operator!=(const ConnectionID &o) const { return !(*this == o); }
|
||||||
bool operator<(const ConnectionID &o) const {
|
bool operator<(const ConnectionID &o) const {
|
||||||
if (id == SN_ID && o.id == SN_ID)
|
if (sn() && o.sn())
|
||||||
return pk < o.pk;
|
return pk < o.pk;
|
||||||
return id < o.id;
|
return id < o.id || (id == o.id && route < o.route);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns true if this ConnectionID represents a SN connection
|
// Returns true if this ConnectionID represents a SN connection
|
||||||
bool sn() const { return id == SN_ID; }
|
bool sn() const { return id == SN_ID; }
|
||||||
|
|
||||||
// Returns this connection's pubkey, if any. (Note that it is possible to have a pubkey and not
|
// Returns this connection's pubkey, if any. (Note that all curve connections have pubkeys, not
|
||||||
// be a SN when connecting to secure remotes: having a non-empty pubkey does not imply that
|
// only SNs).
|
||||||
// `sn()` is true).
|
|
||||||
const std::string& pubkey() const { return pk; }
|
const std::string& pubkey() const { return pk; }
|
||||||
// Default construction; creates a ConnectionID with an invalid internal ID that will not match
|
|
||||||
// an actual connection.
|
// Returns a copy of the ConnectionID with the route set to empty.
|
||||||
ConnectionID() : ConnectionID(0) {}
|
ConnectionID unrouted() { return ConnectionID{id, pk, ""}; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
ConnectionID(long long id) : id{id} {}
|
ConnectionID(long long id) : id{id} {}
|
||||||
ConnectionID(long long id, std::string pubkey, std::string route = "")
|
ConnectionID(long long id, std::string pubkey, std::string route = "")
|
||||||
|
@ -77,7 +84,7 @@ namespace std {
|
||||||
template <> struct hash<lokimq::ConnectionID> {
|
template <> struct hash<lokimq::ConnectionID> {
|
||||||
size_t operator()(const lokimq::ConnectionID &c) const {
|
size_t operator()(const lokimq::ConnectionID &c) const {
|
||||||
return c.sn() ? lokimq::already_hashed{}(c.pk) :
|
return c.sn() ? lokimq::already_hashed{}(c.pk) :
|
||||||
std::hash<long long>{}(c.id);
|
std::hash<long long>{}(c.id) + std::hash<std::string>{}(c.route);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
} // namespace std
|
} // namespace std
|
||||||
|
|
|
@ -199,7 +199,7 @@ LokiMQ::LokiMQ(
|
||||||
sn_lookup{std::move(lookup)}, log_lvl{level}, logger{std::move(logger)}
|
sn_lookup{std::move(lookup)}, log_lvl{level}, logger{std::move(logger)}
|
||||||
{
|
{
|
||||||
|
|
||||||
LMQ_TRACE("Constructing listening LokiMQ, id=", object_id, ", this=", this);
|
LMQ_TRACE("Constructing LokiMQ, id=", object_id, ", this=", this);
|
||||||
|
|
||||||
if (pubkey.empty() != privkey.empty()) {
|
if (pubkey.empty() != privkey.empty()) {
|
||||||
throw std::invalid_argument("LokiMQ construction failed: one (and only one) of pubkey/privkey is empty. Both must be specified, or both empty to generate a key.");
|
throw std::invalid_argument("LokiMQ construction failed: one (and only one) of pubkey/privkey is empty. Both must be specified, or both empty to generate a key.");
|
||||||
|
|
|
@ -326,11 +326,13 @@ private:
|
||||||
/// SN pubkey string.
|
/// SN pubkey string.
|
||||||
std::unordered_multimap<ConnectionID, peer_info> peers;
|
std::unordered_multimap<ConnectionID, peer_info> peers;
|
||||||
|
|
||||||
/// Maps connection indices (which can change) to ConnectionIDs (which are permanent).
|
/// Maps connection indices (which can change) to ConnectionID values (which are permanent).
|
||||||
|
/// This is primarily for outgoing sockets, but incoming sockets are here too (with empty-route
|
||||||
|
/// (and thus unroutable) ConnectionIDs).
|
||||||
std::vector<ConnectionID> conn_index_to_id;
|
std::vector<ConnectionID> conn_index_to_id;
|
||||||
|
|
||||||
/// Maps listening socket ConnectionIDs to connection index values (these don't have peers
|
/// Maps listening socket ConnectionIDs to connection index values (these don't have peers
|
||||||
/// entries)
|
/// entries). The keys here have empty routes (and thus aren't actually routable).
|
||||||
std::unordered_map<ConnectionID, size_t> incoming_conn_index;
|
std::unordered_map<ConnectionID, size_t> incoming_conn_index;
|
||||||
|
|
||||||
/// The next ConnectionID value we should use (for non-SN connections).
|
/// The next ConnectionID value we should use (for non-SN connections).
|
||||||
|
|
|
@ -115,7 +115,7 @@ void LokiMQ::proxy_send(bt_dict_consumer data) {
|
||||||
send_to = sock_route.first;
|
send_to = sock_route.first;
|
||||||
conn_id.route = std::move(sock_route.second);
|
conn_id.route = std::move(sock_route.second);
|
||||||
} else if (!conn_id.route.empty()) { // incoming non-SN connection
|
} else if (!conn_id.route.empty()) { // incoming non-SN connection
|
||||||
auto it = incoming_conn_index.find(conn_id);
|
auto it = incoming_conn_index.find(conn_id.unrouted());
|
||||||
if (it == incoming_conn_index.end()) {
|
if (it == incoming_conn_index.end()) {
|
||||||
LMQ_LOG(warn, "Unable to send to ", conn_id, ": incoming listening socket not found");
|
LMQ_LOG(warn, "Unable to send to ", conn_id, ": incoming listening socket not found");
|
||||||
break;
|
break;
|
||||||
|
|
|
@ -128,6 +128,62 @@ TEST_CASE("plain-text connections", "[plaintext][connect]") {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_CASE("unique connection IDs", "[connect][id]") {
|
||||||
|
std::string listen = "tcp://127.0.0.1:4455";
|
||||||
|
LokiMQ server{get_logger("S» "), LogLevel::trace};
|
||||||
|
|
||||||
|
ConnectionID first, second;
|
||||||
|
server.add_category("x", Access{AuthLevel::none})
|
||||||
|
.add_request_command("x", [&](Message& m) { first = m.conn; m.send_reply("hi"); })
|
||||||
|
.add_request_command("y", [&](Message& m) { second = m.conn; m.send_reply("hi"); })
|
||||||
|
;
|
||||||
|
|
||||||
|
server.listen_plain(listen);
|
||||||
|
|
||||||
|
server.start();
|
||||||
|
|
||||||
|
LokiMQ client1{get_logger("C1» "), LogLevel::trace};
|
||||||
|
LokiMQ client2{get_logger("C2» "), LogLevel::trace};
|
||||||
|
client1.start();
|
||||||
|
client2.start();
|
||||||
|
|
||||||
|
std::atomic<bool> good1{false}, good2{false};
|
||||||
|
auto r1 = client1.connect_remote(listen,
|
||||||
|
[&](auto conn) { good1 = true; },
|
||||||
|
[&](auto conn, string_view reason) { auto lock = catch_lock(); INFO("connection failed: " << reason); }
|
||||||
|
);
|
||||||
|
auto r2 = client2.connect_remote(listen,
|
||||||
|
[&](auto conn) { good2 = true; },
|
||||||
|
[&](auto conn, string_view reason) { auto lock = catch_lock(); INFO("connection failed: " << reason); }
|
||||||
|
);
|
||||||
|
|
||||||
|
wait_for_conn(good1);
|
||||||
|
wait_for_conn(good2);
|
||||||
|
{
|
||||||
|
auto lock = catch_lock();
|
||||||
|
REQUIRE( good1 );
|
||||||
|
REQUIRE( good2 );
|
||||||
|
REQUIRE( first == second );
|
||||||
|
REQUIRE_FALSE( first );
|
||||||
|
REQUIRE_FALSE( second );
|
||||||
|
}
|
||||||
|
|
||||||
|
good1 = false;
|
||||||
|
good2 = false;
|
||||||
|
client1.request(r1, "x.x", [&](auto success_, auto parts_) { good1 = true; });
|
||||||
|
client2.request(r2, "x.y", [&](auto success_, auto parts_) { good2 = true; });
|
||||||
|
reply_sleep();
|
||||||
|
|
||||||
|
{
|
||||||
|
auto lock = catch_lock();
|
||||||
|
REQUIRE( good1 );
|
||||||
|
REQUIRE( good2 );
|
||||||
|
REQUIRE_FALSE( first == second );
|
||||||
|
REQUIRE_FALSE( std::hash<ConnectionID>{}(first) == std::hash<ConnectionID>{}(second) );
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
TEST_CASE("SN disconnections", "[connect][disconnect]") {
|
TEST_CASE("SN disconnections", "[connect][disconnect]") {
|
||||||
std::vector<std::unique_ptr<LokiMQ>> lmq;
|
std::vector<std::unique_ptr<LokiMQ>> lmq;
|
||||||
std::vector<std::string> pubkey, privkey;
|
std::vector<std::string> pubkey, privkey;
|
||||||
|
|
Loading…
Reference in New Issue