diff --git a/oxenmq/jobs.cpp b/oxenmq/jobs.cpp index deae049..2409823 100644 --- a/oxenmq/jobs.cpp +++ b/oxenmq/jobs.cpp @@ -54,27 +54,29 @@ void OxenMQ::proxy_run_batch_jobs(std::queue& jobs, const int reserve // Called either within the proxy thread, or before the proxy thread has been created; actually adds // the timer. If the timer object hasn't been set up yet it gets set up here. -void OxenMQ::proxy_timer(std::function job, std::chrono::milliseconds interval, bool squelch, int thread) { +void OxenMQ::proxy_timer(int id, std::function job, std::chrono::milliseconds interval, bool squelch, int thread) { if (!timers) timers.reset(zmq_timers_new()); - int timer_id = zmq_timers_add(timers.get(), + int zmq_timer_id = zmq_timers_add(timers.get(), interval.count(), [](int timer_id, void* self) { static_cast(self)->_queue_timer_job(timer_id); }, this); - if (timer_id == -1) + if (zmq_timer_id == -1) throw zmq::error_t{}; - timer_jobs[timer_id] = { std::move(job), squelch, false, thread }; + timer_jobs[zmq_timer_id] = { std::move(job), squelch, false, thread }; + timer_zmq_id[id] = zmq_timer_id; } void OxenMQ::proxy_timer(bt_list_consumer timer_data) { + auto timer_id = timer_data.consume_integer(); std::unique_ptr> func{reinterpret_cast*>(timer_data.consume_integer())}; auto interval = std::chrono::milliseconds{timer_data.consume_integer()}; auto squelch = timer_data.consume_integer(); auto thread = timer_data.consume_integer(); if (!timer_data.is_finished()) throw std::runtime_error("Internal error: proxied timer request contains unexpected data"); - proxy_timer(std::move(*func), interval, squelch, thread); + proxy_timer(timer_id, std::move(*func), interval, squelch, thread); } void OxenMQ::_queue_timer_job(int timer_id) { @@ -118,16 +120,37 @@ void OxenMQ::_queue_timer_job(int timer_id) { queue.emplace(static_cast(b), 0); } -void OxenMQ::add_timer(std::function job, std::chrono::milliseconds interval, bool squelch, std::optional thread) { +TimerID OxenMQ::add_timer(std::function job, std::chrono::milliseconds interval, bool squelch, std::optional thread) { + int id = next_timer_id++; int th_id = thread ? thread->_id : 0; if (proxy_thread.joinable()) { detail::send_control(get_control_socket(), "TIMER", bt_serialize(bt_list{{ + id, detail::serialize_object(std::move(job)), interval.count(), squelch, th_id}})); } else { - proxy_timer(std::move(job), interval, squelch, th_id); + proxy_timer(id, std::move(job), interval, squelch, th_id); + } + return TimerID{id}; +} + +void OxenMQ::proxy_timer_del(int id) { + if (!timers) + return; + auto it = timer_zmq_id.find(id); + if (it == timer_zmq_id.end()) + return; + zmq_timers_cancel(timers.get(), it->second); + timer_zmq_id.erase(it); +} + +void OxenMQ::cancel_timer(TimerID timer_id) { + if (proxy_thread.joinable()) { + detail::send_control(get_control_socket(), "TIMER_DEL", bt_serialize(timer_id._id)); + } else { + proxy_timer_del(timer_id._id); } } diff --git a/oxenmq/oxenmq.h b/oxenmq/oxenmq.h index 1bde562..3dba390 100644 --- a/oxenmq/oxenmq.h +++ b/oxenmq/oxenmq.h @@ -104,6 +104,16 @@ private: template friend class Batch; }; +/// Opaque handler for a timer constructed by add_timer(...). Not directly constructible, but is +/// safe (and cheap) to copy. The only real use of this is to pass it in to cancel_timer() to +/// cancel a timer. +struct TimerID { +private: + int _id; + explicit constexpr TimerID(int id) : _id{id} {} + friend class OxenMQ; +}; + /** * Class that handles OxenMQ listeners, connections, proxying, and workers. An application * typically has just one instance of this class. @@ -415,8 +425,13 @@ private: /// Timers. TODO: once cppzmq adds an interface around the zmq C timers API then switch to it. struct TimersDeleter { void operator()(void* timers); }; struct timer_data { std::function function; bool squelch; bool running; int thread; }; - std::unordered_map timer_jobs; + std::unordered_map timer_jobs; // keys are zmq timer ids std::unique_ptr timers; + // The next internal timer id (returned opaquely via TimerID return from add_timer) + std::atomic next_timer_id = 1; + // Maps our internal timer id values (returned by add_timer) to zmq timer ids; used for + // delete_timer(). + std::unordered_map timer_zmq_id; public: // This needs to be public because we have to be able to call it from a plain C function. // Nothing external may call it! @@ -556,13 +571,16 @@ private: /// take over and queue batch jobs. void proxy_batch(detail::Batch* batch); - /// TIMER command. Called with a serialized list containing: function pointer to assume - /// ownership of, an interval count (in ms), and whether or not jobs should be squelched (see - /// `add_timer()`). + /// TIMER command. Called with a serialized list containing: our local timer_id, function + /// pointer to assume ownership of, an interval count (in ms), and whether or not jobs should be + /// squelched (see `add_timer()`). void proxy_timer(bt_list_consumer timer_data); /// Same, but deserialized - void proxy_timer(std::function job, std::chrono::milliseconds interval, bool squelch, int thread); + void proxy_timer(int timer_id, std::function job, std::chrono::milliseconds interval, bool squelch, int thread); + + /// TIMER_DEL command. Called with a timer_id to delete an active timer. + void proxy_timer_del(int timer_id); /// ZAP (https://rfc.zeromq.org/spec:27/ZAP/) authentication handler; this does non-blocking /// processing of any waiting authentication requests for new incoming connections. @@ -1239,9 +1257,22 @@ public: * (so that, under heavy load or long jobs, there can be more than one of the same job scheduled * or running at a time) then specify `squelch` as `false`. * + * The returned value can be kept and later passed into `cancel_timer()` if you want to be able + * to cancel a timer. + * * \param thread specifies a thread (added with add_tagged_thread()) on which this timer must run. */ - void add_timer(std::function job, std::chrono::milliseconds interval, bool squelch = true, std::optional = std::nullopt); + TimerID add_timer(std::function job, std::chrono::milliseconds interval, bool squelch = true, std::optional = std::nullopt); + + /** + * Cancels a running timer. Note that an existing timer job (or multiple, if the timer disabled + * squelch) that have already been scheduled may still be executed after cancel_timer is called. + * + * It is safe (though does nothing) to call this more than once with the same TimerID value. + * + * \param timer a TimerID value as returned by add_timer. + */ + void cancel_timer(TimerID timer); }; /// Helper class that slightly simplifies adding commands to a category. diff --git a/oxenmq/proxy.cpp b/oxenmq/proxy.cpp index ca6dcb2..1778599 100644 --- a/oxenmq/proxy.cpp +++ b/oxenmq/proxy.cpp @@ -291,6 +291,8 @@ void OxenMQ::proxy_control_message(std::vector& parts) { return proxy_disconnect(data); } else if (cmd == "TIMER") { return proxy_timer(data); + } else if (cmd == "TIMER_DEL") { + return proxy_timer_del(bt_deserialize(data)); } } else if (parts.size() == 2) { if (cmd == "START") { diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 96d5710..83f5b8a 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -13,6 +13,7 @@ set(LMQ_TEST_SRC test_inject.cpp test_requests.cpp test_tagged_threads.cpp + test_timer.cpp ) add_executable(tests ${LMQ_TEST_SRC}) diff --git a/tests/test_timer.cpp b/tests/test_timer.cpp new file mode 100644 index 0000000..5011670 --- /dev/null +++ b/tests/test_timer.cpp @@ -0,0 +1,101 @@ +#include "oxenmq/oxenmq.h" +#include "common.h" +#include +#include + +TEST_CASE("timer test", "[timer][basic]") { + oxenmq::OxenMQ omq{get_logger(""), LogLevel::trace}; + + omq.set_general_threads(1); + omq.set_batch_threads(1); + + std::atomic ticks = 0; + auto timer = omq.add_timer([&] { ticks++; }, 5ms); + omq.start(); + auto start = std::chrono::steady_clock::now(); + wait_for([&] { return ticks.load() > 3; }); + { + auto lock = catch_lock(); + REQUIRE( ticks.load() > 3 ); + REQUIRE( std::chrono::steady_clock::now() - start < 40ms ); + } +} + +TEST_CASE("timer squelch", "[timer][squelch]") { + oxenmq::OxenMQ omq{get_logger(""), LogLevel::trace}; + + omq.set_general_threads(3); + omq.set_batch_threads(3); + + std::atomic first = true; + std::atomic done = false; + std::atomic ticks = 0; + + // Set up a timer with squelch on; the job shouldn't get rescheduled until the first call + // finishes, by which point we set `done` and so should get exactly 1 tick. + auto timer = omq.add_timer([&] { + if (first.exchange(false)) { + std::this_thread::sleep_for(30ms); + ticks++; + done = true; + } else if (!done) { + ticks++; + } + }, 5ms, true /* squelch */); + omq.start(); + + wait_for([&] { return done.load(); }); + { + auto lock = catch_lock(); + REQUIRE( done.load() ); + REQUIRE( ticks.load() == 1 ); + } + + // Start another timer with squelch *off*; the subsequent jobs should get scheduled even while + // the first one blocks + std::atomic first2 = true; + std::atomic done2 = false; + std::atomic ticks2 = 0; + auto timer2 = omq.add_timer([&] { + if (first2.exchange(false)) { + std::this_thread::sleep_for(30ms); + done2 = true; + } else if (!done2) { + ticks2++; + } + }, 5ms, false /* squelch */); + + wait_for([&] { return done2.load(); }); + { + auto lock = catch_lock(); + REQUIRE( ticks2.load() > 2 ); + REQUIRE( done2.load() ); + } +} + +TEST_CASE("timer cancel", "[timer][cancel]") { + oxenmq::OxenMQ omq{get_logger(""), LogLevel::trace}; + + omq.set_general_threads(1); + omq.set_batch_threads(1); + + std::atomic ticks = 0; + + // We set up *and cancel* this timer before omq starts, so it should never fire + auto notimer = omq.add_timer([&] { ticks += 1000; }, 5ms); + omq.cancel_timer(notimer); + + TimerID timer = omq.add_timer([&] { + if (++ticks == 3) + omq.cancel_timer(timer); + }, 5ms); + + omq.start(); + + wait_for([&] { return ticks.load() >= 3; }); + { + auto lock = catch_lock(); + REQUIRE( ticks.load() == 3 ); + } +} +