diff --git a/oxenmq/connections.cpp b/oxenmq/connections.cpp index 26aed93..57ec6fa 100644 --- a/oxenmq/connections.cpp +++ b/oxenmq/connections.cpp @@ -174,6 +174,7 @@ OxenMQ::proxy_connect_sn(std::string_view remote, std::string_view connect_hint, p.idle_expiry = keep_alive; p.activity(); connections_updated = true; + outgoing_sn_conns.emplace_hint(outgoing_sn_conns.end(), p.conn_id, ConnectionID{remote}); auto it = connections.emplace_hint(connections.end(), p.conn_id, std::move(socket)); return {&it->second, ""s}; @@ -217,6 +218,8 @@ void OxenMQ::proxy_close_connection(int64_t id, std::chrono::milliseconds linger it->second.set(zmq::sockopt::linger, linger > 0ms ? (int) linger.count() : 0); connections.erase(it); connections_updated = true; + + outgoing_sn_conns.erase(id); } void OxenMQ::proxy_expire_idle_peers() { diff --git a/oxenmq/oxenmq.h b/oxenmq/oxenmq.h index bc3536f..6b49b35 100644 --- a/oxenmq/oxenmq.h +++ b/oxenmq/oxenmq.h @@ -378,6 +378,11 @@ private: /// SN pubkey string. std::unordered_multimap peers; + /// For outgoing connections to service nodes `peers` contains the service node connection id, + /// but we sometimes need to be able to get the peer info from a numeric connection id (for + /// example, for incoming messages on a connection we made); this map lets us do that. + std::map outgoing_sn_conns; + /// The next ConnectionID value we should use (for outgoing, non-SN connections). std::atomic next_conn_id{1}; @@ -1359,6 +1364,10 @@ struct data_parts_impl { template ()), std::string_view>>> data_parts_impl data_parts(InputIt begin, InputIt end) { return {std::move(begin), std::move(end)}; } +/// Shortcut for send_option::data_parts(container.begin(), container.end()) +template +auto data_parts(const Container& c) { return data_parts(c.begin(), c.end()); } + /// Specifies a connection hint when passed in to send(). If there is no current connection to the /// peer then the hint is used to save a call to the SNRemoteAddress to get the connection location. /// (Note that there is no guarantee that the given hint will be used or that a SNRemoteAddress call diff --git a/oxenmq/proxy.cpp b/oxenmq/proxy.cpp index db564e1..9252483 100644 --- a/oxenmq/proxy.cpp +++ b/oxenmq/proxy.cpp @@ -563,8 +563,10 @@ void OxenMQ::proxy_loop() { continue; } - if (!proxy_handle_builtin(id, sock, parts)) + if (!proxy_handle_builtin(id, sock, parts)) { + LMQ_LOG(warn, "proxying to worker from connection ", id); proxy_to_worker(id, sock, parts); + } if (connections_updated) { // If connections got updated then our points are stale, to restart the proxy loop; diff --git a/oxenmq/worker.cpp b/oxenmq/worker.cpp index acd8e99..ae5046f 100644 --- a/oxenmq/worker.cpp +++ b/oxenmq/worker.cpp @@ -283,7 +283,11 @@ void OxenMQ::proxy_to_worker(int64_t conn_id, zmq::socket_t& sock, std::vectorsecond) + : peers.find(conn_id); + if (it == peers.end()) { LMQ_LOG(warn, "Internal error: connection id ", conn_id, " not found"); return; diff --git a/tests/test_connect.cpp b/tests/test_connect.cpp index 723f184..0b4b846 100644 --- a/tests/test_connect.cpp +++ b/tests/test_connect.cpp @@ -446,3 +446,61 @@ TEST_CASE("SN single worker test", "[connect][worker]") { } } + +TEST_CASE("SN backchatter", "[connect][sn]") { + // When we have a SN connection A -> B and then B sends a message to A on that existing + // connection, A should see it as coming from B. + std::vector> omq; + std::vector pubkey, privkey; + std::unordered_map conn; + REQUIRE(sodium_init() != -1); + for (int i = 0; i < 2; i++) { + pubkey.emplace_back(); + privkey.emplace_back(); + pubkey[i].resize(crypto_box_PUBLICKEYBYTES); + privkey[i].resize(crypto_box_SECRETKEYBYTES); + crypto_box_keypair(reinterpret_cast(&pubkey[i][0]), reinterpret_cast(&privkey[i][0])); + conn.emplace(pubkey[i], random_localhost()); + } + + for (int i = 0; i < pubkey.size(); i++) { + omq.push_back(std::make_unique( + pubkey[i], privkey[i], true, + [conn](auto pk) { auto it = conn.find((std::string) pk); if (it != conn.end()) return it->second; return ""s; }, + get_logger("S" + std::to_string(i) + "ยป "), + LogLevel::trace + )); + auto& server = *omq.back(); + + server.listen_curve(conn[pubkey[i]]); + server.set_active_sns({pubkey.begin(), pubkey.end()}); + } + std::string f; + omq[0]->add_category("a", Access{AuthLevel::none, true}) + .add_command("a", [&](Message& m) { + m.oxenmq.send(m.conn, "b.b", "abc"); + //m.send_back("b.b", "abc"); + }) + .add_command("z", [&](Message& m) { + auto lock = catch_lock(); + f = m.data[0]; + }); + omq[1]->add_category("b", Access{AuthLevel::none, true}) + .add_command("b", [&](Message& m) { + { + auto lock = catch_lock(); + UNSCOPED_INFO("b.b from conn " << m.conn); + } + m.send_back("a.z", m.data[0]); + }); + + for (auto& server : omq) + server->start(); + + auto c = omq[1]->connect_sn(pubkey[0]); + omq[1]->send(c, "a.a"); + std::this_thread::sleep_for(50ms); + + auto lock = catch_lock(); + REQUIRE(f == "abc"); +}