diff --git a/CMakeLists.txt b/CMakeLists.txt index 10667af..ba0be63 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -17,7 +17,7 @@ cmake_minimum_required(VERSION 3.7) set(CMAKE_OSX_DEPLOYMENT_TARGET 10.12 CACHE STRING "macOS deployment target (Apple clang only)") project(liboxenmq - VERSION 1.2.13 + VERSION 1.2.14 LANGUAGES CXX C) include(GNUInstallDirs) diff --git a/oxenmq/pubsub.h b/oxenmq/pubsub.h new file mode 100644 index 0000000..5669a3d --- /dev/null +++ b/oxenmq/pubsub.h @@ -0,0 +1,172 @@ +#pragma once + +#include "connections.h" + +#include +#include +#include +#include +#include +#include +#include + +namespace oxenmq { + +using namespace std::chrono_literals; + +namespace detail { + struct no_data_t final {}; + inline constexpr no_data_t no_data{}; + + template + struct SubData { + std::chrono::steady_clock::time_point expiry; + UserData user_data; + explicit SubData(std::chrono::steady_clock::time_point _exp) + : expiry{_exp}, user_data{} {} + }; + + template <> + struct SubData { + std::chrono::steady_clock::time_point expiry; + }; +} + + +/** + * OMQ Subscription class. Handles pub/sub connections such that the user only needs to call + * methods to subscribe and publish. + * + * FIXME: do we want an unsubscribe, or is expiry / conn management sufficient? + * + * Type UserData can contain whatever information the user may need at publish time, for example if + * the subscription is for logs the subscriber can specify log levels or categories, and the + * publisher can choose to send or not based on those. The UserData type, if provided and non-void, + * must be default constructible, must be comparable with ==, and must be movable. + */ +template +class Subscription { + static constexpr bool have_user_data = !std::is_void_v; + using UserData_if_present = std::conditional_t; + using subdata_t = detail::SubData; + + std::unordered_map subs; + std::shared_mutex _mutex; + const std::string description; // description of the sub for logging + const std::chrono::milliseconds sub_duration; // extended by re-subscribe + +public: + + Subscription() = delete; + Subscription(std::string description, std::chrono::milliseconds sub_duration = 30min) + : description{std::move(description)}, sub_duration{sub_duration} {} + + // returns true if new sub, false if refresh sub. throws on error. `data` will be checked + // against the existing data: if there is existing data and it compares `==` to the given value, + // false is returned (and the existing data is not replaced). Otherwise the given data gets + // stored for this connection (replacing existing data, if present), and true is returned. + bool subscribe(const ConnectionID& conn, UserData_if_present data) { + std::unique_lock lock{_mutex}; + auto expiry = std::chrono::steady_clock::now() + sub_duration; + auto [value, added] = subs.emplace(conn, subdata_t{expiry}); + if (added) { + if constexpr (have_user_data) + value->second.user_data = std::move(data); + return true; + } + + value->second.expiry = expiry; + + if constexpr (have_user_data) { + // if user_data changed, consider it a new sub rather than refresh, and update + // user_data in the mapped value. + if (!(value->second.user_data == data)) { + value->second.user_data = std::move(data); + return true; + } + } + return false; + } + + // no-user-data version, only available for Subscription (== Subscription without a + // UserData type). + template = 0> + bool subscribe(const ConnectionID& conn) { + return subscribe(conn, detail::no_data); + } + + // unsubscribe a connection ID. return the user data, if a sub was present. + template = 0> + std::optional unsubscribe(const ConnectionID& conn) { + std::unique_lock lock{_mutex}; + + auto node = subs.extract(conn); + if (!node.empty()) + return node.mapped().user_data; + + return std::nullopt; + } + + // no-user-data version, only available for Subscription (== Subscription without a + // UserData type). + template = 0> + bool unsubscribe(const ConnectionID& conn) { + std::unique_lock lock{_mutex}; + auto node = subs.extract(conn); + return !node.empty(); // true if removed, false if wasn't present + } + + // force removal of expired subscriptions. removal will otherwise only happen on publish + void remove_expired() { + std::unique_lock lock{_mutex}; + auto now = std::chrono::steady_clock::now(); + for (auto itr = subs.begin(); itr != subs.end();) { + if (itr->second.expiry < now) + itr = subs.erase(itr); + else + itr++; + } + } + + // Func is any callable which takes: + // - (const ConnectionID&, const UserData&) for Subscription with non-void UserData + // - (const ConnectionID&) for Subscription. + template + void publish(Func&& func) { + std::vector to_remove; + { + std::shared_lock lock(_mutex); + if (subs.empty()) + return; + + auto now = std::chrono::steady_clock::now(); + + for (const auto& [conn, sub] : subs) { + if (sub.expiry < now) + to_remove.push_back(conn); + else if constexpr (have_user_data) + func(conn, sub.user_data); + else + func(conn); + } + } + + if (to_remove.empty()) + return; + + std::unique_lock lock{_mutex}; + auto now = std::chrono::steady_clock::now(); + for (auto& conn : to_remove) { + auto it = subs.find(conn); + if (it != subs.end() && it->second.expiry < now /* recheck: client might have resubscribed in between locks */) { + subs.erase(it); + } + } + } + +}; + + +} // namespace oxenmq + +// vim:sw=4:et diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index fadf85a..20e2282 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -9,6 +9,7 @@ add_executable(tests test_commands.cpp test_failures.cpp test_inject.cpp + test_pubsub.cpp test_requests.cpp test_socket_limit.cpp test_tagged_threads.cpp diff --git a/tests/test_pubsub.cpp b/tests/test_pubsub.cpp new file mode 100644 index 0000000..75015be --- /dev/null +++ b/tests/test_pubsub.cpp @@ -0,0 +1,611 @@ +#include "common.h" +#include "oxenmq/pubsub.h" + +#include + +using namespace oxenmq; +using namespace std::chrono_literals; + +TEST_CASE("sub OK", "[pubsub]") { + std::string listen = random_localhost(); + OxenMQ server{ + "", "", // generate ephemeral keys + false, // not a service node + [](auto) { return ""; }, + }; + server.listen_curve(listen); + + Subscription<> greetings{"greetings"}; + + std::atomic is_new{false}; + server.add_category("public", Access{AuthLevel::none}); + server.add_request_command("public", "greetings", [&](Message& m) { + is_new = greetings.subscribe(m.conn); + m.send_reply("OK"); + }); + server.start(); + + OxenMQ client( + [](LogLevel, const char* file, int line, std::string msg) { std::cerr << file << ":" << line << " --C-- " << msg << "\n"; } + ); + + std::atomic reply_count{0}; + client.add_category("notify", Access{AuthLevel::none}); + client.add_command("notify", "greetings", [&](Message& m) { + const auto& data = m.data; + if (!data.size()) + { + std::cerr << "client received public.greetings with empty data\n"; + return; + } + if (data[0] == "hello") + reply_count++; + }); + + client.start(); + + std::atomic connected{false}, failed{false}; + std::string pubkey; + + auto c = client.connect_remote(address{listen, server.get_pubkey()}, + [&](auto conn) { pubkey = conn.pubkey(); connected = true; }, + [&](auto, auto) { failed = true; }); + + wait_for([&] { return connected || failed; }); + { + auto lock = catch_lock(); + REQUIRE( connected ); + REQUIRE_FALSE( failed ); + REQUIRE( oxenc::to_hex(pubkey) == oxenc::to_hex(server.get_pubkey()) ); + } + + std::atomic got_reply{false}; + bool success; + std::vector data; + client.request(c, "public.greetings", [&](bool ok, std::vector data_) { + got_reply = true; + success = ok; + data = std::move(data_); + }); + + reply_sleep(); + { + auto lock = catch_lock(); + REQUIRE( got_reply.load() ); + REQUIRE( success ); + REQUIRE( data == std::vector{{"OK"}} ); + } + + greetings.publish([&](auto& conn) { + server.send(conn, "notify.greetings", "hello"); + }); + + reply_sleep(); + { + auto lock = catch_lock(); + REQUIRE( reply_count == 1 ); + } + + greetings.publish([&](auto& conn) { + server.send(conn, "notify.greetings", "hello"); + }); + + reply_sleep(); + { + auto lock = catch_lock(); + REQUIRE( reply_count == 2 ); + } + +} + +TEST_CASE("user data", "[pubsub]") { + std::string listen = random_localhost(); + OxenMQ server{ + "", "", // generate ephemeral keys + false, // not a service node + [](auto) { return ""; }, + }; + server.listen_curve(listen); + + Subscription greetings{"greetings"}; + + std::atomic is_new{false}; + server.add_category("public", Access{AuthLevel::none}); + server.add_request_command("public", "greetings", [&](Message& m) { + is_new = greetings.subscribe(m.conn, std::string{m.data[0]}); + m.send_reply("OK"); + }); + server.start(); + + OxenMQ client( + [](LogLevel, const char* file, int line, std::string msg) { std::cerr << file << ":" << line << " --C-- " << msg << "\n"; } + ); + + std::string response{"foo"}; + std::atomic reply_count{0}; + std::atomic foo_count{0}; + client.add_category("notify", Access{AuthLevel::none}); + client.add_command("notify", "greetings", [&](Message& m) { + const auto& data = m.data; + if (!data.size()) + { + std::cerr << "client received public.greetings with empty data\n"; + return; + } + if (data[0] == response) + reply_count++; + if (data[0] == "foo") + foo_count++; + }); + + client.start(); + + std::atomic connected{false}, failed{false}; + std::string pubkey; + + auto c = client.connect_remote(address{listen, server.get_pubkey()}, + [&](auto conn) { pubkey = conn.pubkey(); connected = true; }, + [&](auto, auto) { failed = true; }); + + wait_for([&] { return connected || failed; }); + { + auto lock = catch_lock(); + REQUIRE( connected ); + REQUIRE_FALSE( failed ); + REQUIRE( oxenc::to_hex(pubkey) == oxenc::to_hex(server.get_pubkey()) ); + } + + std::atomic got_reply{false}; + std::atomic success; + std::vector data; + client.request(c, "public.greetings", [&](bool ok, std::vector data_) { + got_reply = true; + success = ok; + data = std::move(data_); + }, response); + + reply_sleep(); + { + auto lock = catch_lock(); + REQUIRE( got_reply.load() ); + REQUIRE( success ); + REQUIRE( is_new ); + REQUIRE( data == std::vector{{"OK"}} ); + } + + got_reply = false; + success = false; + client.request(c, "public.greetings", [&](bool ok, std::vector data_) { + got_reply = true; + success = ok; + data = std::move(data_); + }, response); + + reply_sleep(); + { + auto lock = catch_lock(); + REQUIRE( got_reply.load() ); + REQUIRE( success ); + REQUIRE_FALSE( is_new ); + REQUIRE( data == std::vector{{"OK"}} ); + } + + greetings.publish([&](auto& conn, std::string user) { + server.send(conn, "notify.greetings", user); + }); + + reply_sleep(); + { + auto lock = catch_lock(); + REQUIRE( reply_count == 1 ); + REQUIRE( foo_count == 1 ); + } + + got_reply = false; + success = false; + response = "bar"; + client.request(c, "public.greetings", [&](bool ok, std::vector data_) { + got_reply = true; + success = ok; + data = std::move(data_); + }, response); + + reply_sleep(); + { + auto lock = catch_lock(); + REQUIRE( got_reply.load() ); + REQUIRE( success ); + REQUIRE( is_new ); + REQUIRE( data == std::vector{{"OK"}} ); + } + + greetings.publish([&](auto& conn, std::string user) { + server.send(conn, "notify.greetings", user); + }); + + reply_sleep(); + { + auto lock = catch_lock(); + REQUIRE( reply_count == 2 ); + REQUIRE( foo_count == 1 ); + } + +} + +TEST_CASE("unsubscribe", "[pubsub]") { + std::string listen = random_localhost(); + OxenMQ server{ + "", "", // generate ephemeral keys + false, // not a service node + [](auto) { return ""; }, + }; + server.listen_curve(listen); + + Subscription<> greetings{"greetings"}; + + std::atomic was_subbed{false}; + server.add_category("public", Access{AuthLevel::none}); + server.add_request_command("public", "greetings", [&](Message& m) { + greetings.subscribe(m.conn); + m.send_reply("OK"); + }); + server.add_request_command("public", "goodbye", [&](Message& m) { + was_subbed = greetings.unsubscribe(m.conn); + m.send_reply("OK"); + }); + server.start(); + + OxenMQ client( + [](LogLevel, const char* file, int line, std::string msg) { std::cerr << file << ":" << line << " --C-- " << msg << "\n"; } + ); + + std::atomic reply_count{0}; + client.add_category("notify", Access{AuthLevel::none}); + client.add_command("notify", "greetings", [&](Message& m) { + const auto& data = m.data; + if (!data.size()) + { + std::cerr << "client received public.greetings with empty data\n"; + return; + } + if (data[0] == "hello") + reply_count++; + }); + + client.start(); + + std::atomic connected{false}, failed{false}; + std::string pubkey; + + auto c = client.connect_remote(address{listen, server.get_pubkey()}, + [&](auto conn) { pubkey = conn.pubkey(); connected = true; }, + [&](auto, auto) { failed = true; }); + + wait_for([&] { return connected || failed; }); + { + auto lock = catch_lock(); + REQUIRE( connected ); + REQUIRE_FALSE( failed ); + REQUIRE( oxenc::to_hex(pubkey) == oxenc::to_hex(server.get_pubkey()) ); + } + + std::atomic got_reply{false}; + std::atomic success; + std::vector data; + client.request(c, "public.greetings", [&](bool ok, std::vector data_) { + got_reply = true; + success = ok; + data = std::move(data_); + }); + + reply_sleep(); + { + auto lock = catch_lock(); + REQUIRE( got_reply.load() ); + REQUIRE( success ); + REQUIRE( data == std::vector{{"OK"}} ); + } + + greetings.publish([&](auto& conn) { + server.send(conn, "notify.greetings", "hello"); + }); + + reply_sleep(); + { + auto lock = catch_lock(); + REQUIRE( reply_count == 1 ); + } + + got_reply = false; + success = false; + client.request(c, "public.goodbye", [&](bool ok, std::vector data_) { + got_reply = true; + success = ok; + data = std::move(data_); + }); + + reply_sleep(); + { + auto lock = catch_lock(); + REQUIRE( got_reply.load() ); + REQUIRE( success ); + REQUIRE( data == std::vector{{"OK"}} ); + REQUIRE( was_subbed ); + } + + greetings.publish([&](auto& conn) { + server.send(conn, "notify.greetings", "hello"); + }); + + reply_sleep(); + { + auto lock = catch_lock(); + REQUIRE( reply_count == 1 ); + } + + got_reply = false; + success = false; + client.request(c, "public.goodbye", [&](bool ok, std::vector data_) { + got_reply = true; + success = ok; + data = std::move(data_); + }); + + reply_sleep(); + { + auto lock = catch_lock(); + REQUIRE( got_reply.load() ); + REQUIRE( success ); + REQUIRE( data == std::vector{{"OK"}} ); + REQUIRE( was_subbed == false); + } + +} + +TEST_CASE("expire", "[pubsub]") { + std::string listen = random_localhost(); + OxenMQ server{ + "", "", // generate ephemeral keys + false, // not a service node + [](auto) { return ""; }, + }; + server.listen_curve(listen); + + Subscription<> greetings{"greetings", 250ms}; + + std::atomic was_subbed{false}; + server.add_category("public", Access{AuthLevel::none}); + server.add_request_command("public", "greetings", [&](Message& m) { + greetings.subscribe(m.conn); + m.send_reply("OK"); + }); + server.add_request_command("public", "goodbye", [&](Message& m) { + was_subbed = greetings.unsubscribe(m.conn); + m.send_reply("OK"); + }); + server.start(); + + OxenMQ client( + [](LogLevel, const char* file, int line, std::string msg) { std::cerr << file << ":" << line << " --C-- " << msg << "\n"; } + ); + + std::atomic reply_count{0}; + client.add_category("notify", Access{AuthLevel::none}); + client.add_command("notify", "greetings", [&](Message& m) { + const auto& data = m.data; + if (!data.size()) + { + std::cerr << "client received public.greetings with empty data\n"; + return; + } + if (data[0] == "hello") + reply_count++; + }); + + client.start(); + + std::atomic connected{false}, failed{false}; + std::string pubkey; + + auto c = client.connect_remote(address{listen, server.get_pubkey()}, + [&](auto conn) { pubkey = conn.pubkey(); connected = true; }, + [&](auto, auto) { failed = true; }); + + wait_for([&] { return connected || failed; }); + { + auto lock = catch_lock(); + REQUIRE( connected ); + REQUIRE_FALSE( failed ); + REQUIRE( oxenc::to_hex(pubkey) == oxenc::to_hex(server.get_pubkey()) ); + } + + std::atomic got_reply{false}; + bool success; + std::vector data; + client.request(c, "public.greetings", [&](bool ok, std::vector data_) { + got_reply = true; + success = ok; + data = std::move(data_); + }); + + reply_sleep(); + { + auto lock = catch_lock(); + REQUIRE( got_reply.load() ); + REQUIRE( success ); + REQUIRE( data == std::vector{{"OK"}} ); + } + + // should be expired by now + std::this_thread::sleep_for(500ms); + + greetings.remove_expired(); + + got_reply = false; + success = false; + client.request(c, "public.goodbye", [&](bool ok, std::vector data_) { + got_reply = true; + success = ok; + data = std::move(data_); + }); + + reply_sleep(); + { + auto lock = catch_lock(); + REQUIRE( got_reply.load() ); + REQUIRE( success ); + REQUIRE( data == std::vector{{"OK"}} ); + REQUIRE( was_subbed == false); + } + +} + +TEST_CASE("multiple subs", "[pubsub]") { + std::string listen = random_localhost(); + OxenMQ server{ + "", "", // generate ephemeral keys + false, // not a service node + [](auto) { return ""; }, + }; + server.listen_curve(listen); + + Subscription<> greetings{"greetings"}; + + std::atomic is_new{false}; + server.add_category("public", Access{AuthLevel::none}); + server.add_request_command("public", "greetings", [&](Message& m) { + is_new = greetings.subscribe(m.conn); + m.send_reply("OK"); + }); + server.start(); + +/* client 1 */ + std::atomic reply_count_c1{0}; + std::atomic connected_c1{false}, failed_c1{false}; + std::atomic got_reply_c1{false}; + bool success_c1; + std::vector data_c1; + std::string pubkey_c1; + OxenMQ client1( + [](LogLevel, const char* file, int line, std::string msg) { std::cerr << file << ":" << line << " --C-- " << msg << "\n"; } + ); + + client1.add_category("notify", Access{AuthLevel::none}); + client1.add_command("notify", "greetings", [&](Message& m) { + const auto& data = m.data; + if (!data.size()) + { + std::cerr << "client received public.greetings with empty data\n"; + return; + } + if (data[0] == "hello") + reply_count_c1++; + }); + + client1.start(); + + auto c1 = client1.connect_remote(address{listen, server.get_pubkey()}, + [&](auto conn) { pubkey_c1 = conn.pubkey(); connected_c1 = true; }, + [&](auto, auto) { failed_c1 = true; }); + + wait_for([&] { return connected_c1 || failed_c1; }); + { + auto lock = catch_lock(); + REQUIRE( connected_c1 ); + REQUIRE_FALSE( failed_c1 ); + REQUIRE( oxenc::to_hex(pubkey_c1) == oxenc::to_hex(server.get_pubkey()) ); + } + + client1.request(c1, "public.greetings", [&](bool ok, std::vector data_) { + got_reply_c1 = true; + success_c1 = ok; + data_c1 = std::move(data_); + }); + + reply_sleep(); + { + auto lock = catch_lock(); + REQUIRE( got_reply_c1.load() ); + REQUIRE( success_c1 ); + REQUIRE( data_c1 == std::vector{{"OK"}} ); + } +/* end client 1 */ + +/* client 2 */ + std::atomic reply_count_c2{0}; + std::atomic connected_c2{false}, failed_c2{false}; + std::atomic got_reply_c2{false}; + bool success_c2; + std::vector data_c2; + std::string pubkey_c2; + OxenMQ client2( + [](LogLevel, const char* file, int line, std::string msg) { std::cerr << file << ":" << line << " --C-- " << msg << "\n"; } + ); + + client2.add_category("notify", Access{AuthLevel::none}); + client2.add_command("notify", "greetings", [&](Message& m) { + const auto& data = m.data; + if (!data.size()) + { + std::cerr << "client received public.greetings with empty data\n"; + return; + } + if (data[0] == "hello") + reply_count_c2++; + }); + + client2.start(); + + auto c2 = client2.connect_remote(address{listen, server.get_pubkey()}, + [&](auto conn) { pubkey_c2 = conn.pubkey(); connected_c2 = true; }, + [&](auto, auto) { failed_c2 = true; }); + + wait_for([&] { return connected_c2 || failed_c2; }); + { + auto lock = catch_lock(); + REQUIRE( connected_c2 ); + REQUIRE_FALSE( failed_c2 ); + REQUIRE( oxenc::to_hex(pubkey_c2) == oxenc::to_hex(server.get_pubkey()) ); + } + + client2.request(c2, "public.greetings", [&](bool ok, std::vector data_) { + got_reply_c2 = true; + success_c2 = ok; + data_c2 = std::move(data_); + }); + + reply_sleep(); + { + auto lock = catch_lock(); + REQUIRE( got_reply_c2.load() ); + REQUIRE( success_c2 ); + REQUIRE( data_c2 == std::vector{{"OK"}} ); + } +/* end client2 */ + + greetings.publish([&](auto& conn) { + server.send(conn, "notify.greetings", "hello"); + }); + + reply_sleep(); + { + auto lock = catch_lock(); + REQUIRE( reply_count_c1 == 1 ); + REQUIRE( reply_count_c2 == 1 ); + } + + greetings.publish([&](auto& conn) { + server.send(conn, "notify.greetings", "hello"); + }); + + reply_sleep(); + { + auto lock = catch_lock(); + REQUIRE( reply_count_c1 == 2 ); + REQUIRE( reply_count_c2 == 2 ); + } + +} + + +// vim:sw=4:et