diff --git a/lokimq/lokimq.cpp b/lokimq/lokimq.cpp
index fd9aa08..2fbae74 100644
--- a/lokimq/lokimq.cpp
+++ b/lokimq/lokimq.cpp
@@ -349,30 +349,45 @@ void LokiMQ::set_general_threads(int threads) {
 LokiMQ::run_info& LokiMQ::run_info::load(category* cat_, std::string command_, ConnectionID conn_,
         Access access_, std::string remote_, std::vector<zmq::message_t> data_parts_,
         const std::pair<CommandCallback, bool>* callback_) {
-    is_batch_job = false;
-    is_reply_job = false;
-    is_tagged_thread_job = false;
+    reset();
     cat = cat_;
     command = std::move(command_);
     conn = std::move(conn_);
     access = std::move(access_);
     remote = std::move(remote_);
     data_parts = std::move(data_parts_);
-    callback = callback_;
+    to_run = callback_;
+    return *this;
+}
+
+LokiMQ::run_info& LokiMQ::run_info::load(category* cat_, std::string command_, std::string remote_, std::function<void()> callback) {
+    reset();
+    is_injected = true;
+    cat = cat_;
+    command = std::move(command_);
+    conn = {};
+    access = {};
+    remote = std::move(remote_);
+    to_run = std::move(callback);
     return *this;
 }
 
 LokiMQ::run_info& LokiMQ::run_info::load(pending_command&& pending) {
+    if (auto *f = std::get_if<std::function<void()>>(&pending.callback))
+        return load(&pending.cat, std::move(pending.command), std::move(pending.remote), std::move(*f));
+
+    assert(pending.callback.index() == 0);
     return load(&pending.cat, std::move(pending.command), std::move(pending.conn), std::move(pending.access),
-            std::move(pending.remote), std::move(pending.data_parts), pending.callback);
+            std::move(pending.remote), std::move(pending.data_parts), std::get<0>(pending.callback));
 }
 
 LokiMQ::run_info& LokiMQ::run_info::load(batch_job&& bj, bool reply_job, int tagged_thread) {
+    reset();
     is_batch_job = true;
     is_reply_job = reply_job;
     is_tagged_thread_job = tagged_thread > 0;
     batch_jobno = bj.second;
-    batch = bj.first;
+    to_run = bj.first;
     return *this;
 }
 
diff --git a/lokimq/lokimq.h b/lokimq/lokimq.h
index 05f0499..ac76b4d 100644
--- a/lokimq/lokimq.h
+++ b/lokimq/lokimq.h
@@ -596,6 +596,18 @@ private:
     bool proxy_check_auth(size_t conn_index, bool outgoing, const peer_info& peer,
             zmq::message_t& command, const cat_call_t& cat_call, std::vector<zmq::message_t>& data);
 
+    struct injected_task {
+        category& cat;
+        std::string command;
+        std::string remote;
+        std::function<void()> callback;
+    };
+
+    /// Injects an external callback to be handled by a worker; this is the proxy side of
+    /// inject_task().
+    void proxy_inject_task(injected_task task);
+
+
     /// Set of active service nodes.
     pubkey_set active_service_nodes;
 
@@ -607,20 +619,30 @@ private:
     void proxy_update_active_sns_clean(pubkey_set added, pubkey_set removed);
 
     /// Details for a pending command; such a command already has authenticated access and is just
-    /// waiting for a thread to become available to handle it.
+    /// waiting for a thread to become available to handle it.  This also gets used (via the
+    /// `callback` variant) for injected external jobs to be able to integrate some external
+    /// interface with the lokimq job queue.
     struct pending_command {
         category& cat;
         std::string command;
         std::vector<zmq::message_t> data_parts;
-        const std::pair<CommandCallback, bool>* callback;
+        std::variant<
+            const std::pair<CommandCallback, bool>*, // Normal command callback
+            std::function<void()> // Injected external callback
+        > callback;
         ConnectionID conn;
         Access access;
         std::string remote;
 
+        // Normal ctor for an actual lmq command being processed
         pending_command(category& cat, std::string command, std::vector<zmq::message_t> data_parts,
                 const std::pair<CommandCallback, bool>* callback, ConnectionID conn, Access access, std::string remote)
             : cat{cat}, command{std::move(command)}, data_parts{std::move(data_parts)}, callback{callback},
             conn{std::move(conn)}, access{std::move(access)}, remote{std::move(remote)} {}
+
+        // Ctor for an injected external command.
+        pending_command(category& cat, std::string command, std::function<void()> callback, std::string remote)
+            : cat{cat}, command{std::move(command)}, callback{std::move(callback)}, remote{std::move(remote)} {}
     };
     std::list<pending_command> pending_commands;
 
@@ -635,9 +657,15 @@ private:
         bool is_batch_job = false;
         bool is_reply_job = false;
         bool is_tagged_thread_job = false;
+        bool is_injected = false;
+
+        // resets the job type bools, above.
+        void reset() { is_batch_job = is_reply_job = is_tagged_thread_job = is_injected = false; }
 
         // If is_batch_job is false then these will be set appropriate (if is_batch_job is true then
-        // these shouldn't be accessed and likely contain stale data).
+        // these shouldn't be accessed and likely contain stale data).  Note that if the command is
+        // an external, injected command then conn, access, conn_route, and data_parts will be
+        // empty/default constructed.
         category *cat;
         std::string command;
         ConnectionID conn; // The connection (or SN pubkey) to reply on/to.
@@ -649,10 +677,13 @@ private:
         // If is_batch_job true then these are set (if is_batch_job false then don't access these!):
         int batch_jobno; // >= 0 for a job, -1 for the completion job
 
-        union {
-            const std::pair<CommandCallback, bool>* callback; // set if !is_batch_job
-            detail::Batch* batch; // set if is_batch_job
-        };
+        // The callback or batch job to run.  The first of these is for regular tasks, the second
+        // for batch jobs, the third for injected external tasks.
+        std::variant<
+            const std::pair<CommandCallback, bool>*,
+            detail::Batch*,
+            std::function<void()>
+        > to_run;
 
         // These belong to the proxy thread and must not be accessed by a worker:
         std::thread worker_thread;
@@ -663,8 +694,12 @@ private:
         run_info& load(category* cat, std::string command, ConnectionID conn, Access access,
                 std::string remote, std::vector<zmq::message_t> data_parts,
                 const std::pair<CommandCallback, bool>* callback);
+        /// Loads the run info with an injected external command
+        run_info& load(category* cat, std::string command, std::string remote, std::function<void()> callback);
+        /// Loads the run info with a stored pending command
         run_info& load(pending_command&& pending);
+        /// Loads the run info with a batch job
         run_info& load(batch_job&& bj, bool reply_job = false, int tagged_thread = 0);
     };
 
@@ -1091,6 +1126,32 @@ public:
     template <typename... T>
     void request(ConnectionID to, std::string_view cmd, ReplyCallback callback, const T&... opts);
 
+    /** Injects an external task into the lokimq command queue.  This is used to allow connecting
+     * non-LokiMQ requests into the LokiMQ thread pool as if they were ordinary requests, to be
+     * scheduled as commands of an individual category.  For example, you might support rpc requests
+     * via LokiMQ as `rpc.some_command` and *also* accept them over HTTP.  Using `inject_task()`
+     * allows you to handle processing the request in the same thread pool with the same priority as
+     * `rpc.*` commands.
+     *
+     * @param category - the category name that should handle the request for the purposes of
+     * scheduling the job.  The category must have been added using add_category().  The category
+     * can be an actual category with added commands, in which case the injected tasks are queued
+     * along with LMQ requests for that category, or can have no commands to set up a distinct
+     * category for the injected jobs.
+     *
+     * @param command - a command name; this is mainly used for debugging and does not need to
+     * actually exist (and, in fact, is often less confusing if it does not).  It is recommended for
+     * clarity purposes to use something that doesn't look like a typical command, for example
+     * "(http)".
+     *
+     * @param remote - some free-form identifier of the remote connection.  For example, this could
+     * be a remote IP address.  Can be blank if there is nothing suitable.
+     *
+     * @param callback - the function to call from a worker thread when the injected task is
+     * processed.  Takes no arguments.
+     */
+    void inject_task(const std::string& category, std::string command, std::string remote, std::function<void()> callback);
+
     /// The key pair this LokiMQ was created with; if empty keys were given during construction then
     /// this returns the generated keys.
     const std::string& get_pubkey() const { return pubkey; }
diff --git a/lokimq/proxy.cpp b/lokimq/proxy.cpp
index 9312d6a..8c9ecb3 100644
--- a/lokimq/proxy.cpp
+++ b/lokimq/proxy.cpp
@@ -263,6 +263,9 @@ void LokiMQ::proxy_control_message(std::vector<zmq::message_t>& parts) {
             LMQ_TRACE("proxy batch jobs");
             auto ptrval = bt_deserialize<uintptr_t>(data);
             return proxy_batch(reinterpret_cast<detail::Batch*>(ptrval));
+        } else if (cmd == "INJECT") {
+            LMQ_TRACE("proxy inject");
+            return proxy_inject_task(detail::deserialize_object<injected_task>(bt_deserialize<uintptr_t>(data)));
         } else if (cmd == "SET_SNS") {
             return proxy_set_active_sns(data);
         } else if (cmd == "UPDATE_SNS") {
@@ -684,5 +687,4 @@ void LokiMQ::proxy_process_queue() {
     }
 }
 
-
 }
diff --git a/lokimq/worker.cpp b/lokimq/worker.cpp
index 86386d4..68e860e 100644
--- a/lokimq/worker.cpp
+++ b/lokimq/worker.cpp
@@ -92,13 +92,19 @@ void LokiMQ::worker_thread(unsigned int index, std::optional<std::string> tagged
 
         try {
             if (run.is_batch_job) {
+                auto* batch = std::get<detail::Batch*>(run.to_run);
                 if (run.batch_jobno >= 0) {
-                    LMQ_TRACE("worker thread ", worker_id, " running batch ", run.batch, "#", run.batch_jobno);
-                    run.batch->run_job(run.batch_jobno);
+                    LMQ_TRACE("worker thread ", worker_id, " running batch ", batch, "#", run.batch_jobno);
+                    batch->run_job(run.batch_jobno);
                 } else if (run.batch_jobno == -1) {
-                    LMQ_TRACE("worker thread ", worker_id, " running batch ", run.batch, " completion");
-                    run.batch->job_completion();
+                    LMQ_TRACE("worker thread ", worker_id, " running batch ", batch, " completion");
+                    batch->job_completion();
                 }
+            } else if (run.is_injected) {
+                auto& func = std::get<std::function<void()>>(run.to_run);
+                LMQ_TRACE("worker thread ", worker_id, " invoking injected command ", run.command);
+                func();
+                func = nullptr;
             } else {
                 message.conn = run.conn;
                 message.access = run.access;
@@ -107,7 +113,8 @@ void LokiMQ::worker_thread(unsigned int index, std::optional<std::string> tagged
                 LMQ_TRACE("Got incoming command from ", message.remote, "/", message.conn,
                         message.conn.route.empty() ? " (outgoing)" : " (incoming)");
 
-                if (run.callback->second /*is_request*/) {
+                auto& [callback, is_request] = *std::get<const std::pair<CommandCallback, bool>*>(run.to_run);
+                if (is_request) {
                     message.reply_tag = {run.data_parts[0].data(), run.data_parts[0].size()};
                     for (auto it = run.data_parts.begin() + 1; it != run.data_parts.end(); ++it)
                         message.data.emplace_back(it->data(), it->size());
@@ -117,7 +124,7 @@ void LokiMQ::worker_thread(unsigned int index, std::optional<std::string> tagged
                 }
 
                 LMQ_TRACE("worker thread ", worker_id, " invoking ", run.command, " callback with ",
                         message.data.size(), " message parts");
-                run.callback->first(message);
+                callback(message);
             }
         } catch (const bt_deserialize_invalid& e) {
@@ -194,16 +201,17 @@ void LokiMQ::proxy_worker_message(std::vector<zmq::message_t>& parts) {
             active--;
         }
         bool clear_job = false;
+        auto* batch = std::get<detail::Batch*>(run.to_run);
         if (run.batch_jobno == -1) {
             // Returned from the completion function
             clear_job = true;
         } else {
-            auto [state, thread] = run.batch->job_finished();
+            auto [state, thread] = batch->job_finished();
             if (state == detail::BatchState::complete) {
                 if (thread == -1) { // run directly in proxy
                     LMQ_TRACE("Completion job running directly in proxy");
                     try {
-                        run.batch->job_completion(); // RUN DIRECTLY IN PROXY THREAD
+                        batch->job_completion(); // RUN DIRECTLY IN PROXY THREAD
                     } catch (const std::exception &e) {
                         // Raise these to error levels: the caller really shouldn't be doing
                         // anything non-trivial in an in-proxy completion function!
@@ -219,7 +227,7 @@ void LokiMQ::proxy_worker_message(std::vector<zmq::message_t>& parts) {
                         : run.is_reply_job ? reply_jobs
                         : batch_jobs;
-                    jobs.emplace(run.batch, -1);
+                    jobs.emplace(batch, -1);
                 }
             } else if (state == detail::BatchState::done) {
                 // No completion job
@@ -229,9 +237,9 @@ void LokiMQ::proxy_worker_message(std::vector<zmq::message_t>& parts) {
         }
 
         if (clear_job) {
-            batches.erase(run.batch);
-            delete run.batch;
-            run.batch = nullptr;
+            batches.erase(batch);
+            delete batch;
+            run.to_run = static_cast<detail::Batch*>(nullptr);
         }
     } else {
         assert(run.cat->active_threads > 0);
@@ -360,5 +368,38 @@ void LokiMQ::proxy_to_worker(size_t conn_index, std::vector<zmq::message_t>& par
 
     category.active_threads++;
 }
 
+void LokiMQ::inject_task(const std::string& category, std::string command, std::string remote, std::function<void()> callback) {
+    if (!callback) return;
+    auto it = categories.find(category);
+    if (it == categories.end())
+        throw std::out_of_range{"Invalid category `" + category + "': category does not exist"};
+    detail::send_control(get_control_socket(), "INJECT", bt_serialize(detail::serialize_object(
+        injected_task{it->second, std::move(command), std::move(remote), std::move(callback)})));
+}
+
+void LokiMQ::proxy_inject_task(injected_task task) {
+    auto& category = task.cat;
+    if (category.active_threads >= category.reserved_threads && active_workers() >= general_workers) {
+        // No free worker slot, queue for later
+        if (category.max_queue >= 0 && category.queued >= category.max_queue) {
+            LMQ_LOG(warn, "No space to queue injected task ", task.command, "; already have ", category.queued,
+                    " commands queued in that category (max ", category.max_queue, "); dropping task");
+            return;
+        }
+        LMQ_LOG(debug, "No available free workers for injected task ", task.command, "; queuing for later");
+        pending_commands.emplace_back(category, std::move(task.command), std::move(task.callback), std::move(task.remote));
+        category.queued++;
+        return;
+    }
+
+    auto& run = get_idle_worker();
+    LMQ_TRACE("Forwarding incoming injected task ", task.command, " from ", task.remote, " to worker ", run.worker_routing_id);
+    run.load(&category, std::move(task.command), std::move(task.remote), std::move(task.callback));
+
+    proxy_run_worker(run);
+    category.active_threads++;
+}
+
+
 }
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index 15fda07..2b377ce 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -10,6 +10,7 @@ set(LMQ_TEST_SRC
     test_commands.cpp
     test_encoding.cpp
     test_failures.cpp
+    test_inject.cpp
    test_requests.cpp
    test_tagged_threads.cpp
    )
diff --git a/tests/test_inject.cpp b/tests/test_inject.cpp
new file mode 100644
index 0000000..168e02f
--- /dev/null
+++ b/tests/test_inject.cpp
@@ -0,0 +1,87 @@
+#include "common.h"
+
+using namespace lokimq;
+
+TEST_CASE("injected external commands", "[injected]") {
+    std::string listen = "tcp://127.0.0.1:4567";
+    LokiMQ server{
+        "", "", // generate ephemeral keys
+        false, // not a service node
+        [](auto) { return ""; },
+        get_logger("S» "),
+        LogLevel::trace
+    };
+    server.set_general_threads(1);
+    server.listen_curve(listen);
+
+    std::atomic<int> hellos = 0;
+    std::atomic<bool> done = false;
+    server.add_category("public", AuthLevel::none, 3);
+    server.add_command("public", "hello", [&](Message& m) {
+        hellos++;
+        while (!done) std::this_thread::sleep_for(10ms);
+    });
+
+    server.start();
+
+    LokiMQ client{get_logger("C» "), LogLevel::trace};
+    client.start();
+
+    std::atomic<bool> got{false};
+    bool success = false;
+
+    auto c = client.connect_remote(listen,
+            [&](auto conn) { success = true; got = true; },
+            [&](auto conn, std::string_view) { got = true; },
+            server.get_pubkey());
+
+    wait_for_conn(got);
+    {
+        auto lock = catch_lock();
+        REQUIRE( got );
+        REQUIRE( success );
+    }
+
+    // First make sure that basic message respects the 3 thread limit
+    client.send(c, "public.hello");
+    client.send(c, "public.hello");
+    client.send(c, "public.hello");
+    client.send(c, "public.hello");
+    wait_for([&] { return hellos >= 3; });
+    std::this_thread::sleep_for(20ms);
+    {
+        auto lock = catch_lock();
+        REQUIRE( hellos == 3 );
+    }
+    done = true;
+    wait_for([&] { return hellos >= 4; });
+    {
+        auto lock = catch_lock();
+        REQUIRE( hellos == 4 );
+    }
+
+    // Now try injecting external commands
+    done = false;
+    hellos = 0;
+    client.send(c, "public.hello");
+    wait_for([&] { return hellos >= 1; });
+    server.inject_task("public", "(injected)", "localhost", [&] { hellos += 10; while (!done) std::this_thread::sleep_for(10ms); });
+    wait_for([&] { return hellos >= 11; });
+    client.send(c, "public.hello");
+    wait_for([&] { return hellos >= 12; });
+    server.inject_task("public", "(injected)", "localhost", [&] { hellos += 10; while (!done) std::this_thread::sleep_for(10ms); });
+    server.inject_task("public", "(injected)", "localhost", [&] { hellos += 10; while (!done) std::this_thread::sleep_for(10ms); });
+    server.inject_task("public", "(injected)", "localhost", [&] { hellos += 10; while (!done) std::this_thread::sleep_for(10ms); });
+    wait_for([&] { return hellos >= 12; });
+    std::this_thread::sleep_for(20ms);
+    {
+        auto lock = catch_lock();
+        REQUIRE( hellos == 12 );
+    }
+    done = true;
+    wait_for([&] { return hellos >= 42; });
+    {
+        auto lock = catch_lock();
+        REQUIRE( hellos == 42 );
+    }
+}
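
A minimal usage sketch of the new inject_task() API (not part of the patch): it shows how externally
received work, for example a request arriving over a separate HTTP server, might be handed to the
LokiMQ worker pool. The "rpc" category name, the remote_addr and body values, and the commented-out
handle_rpc()/send_http_response() helpers are illustrative placeholders, not part of LokiMQ.

    #include <lokimq/lokimq.h>
    #include <string>

    int main() {
        // Assumes the logger/log-level defaulted constructor; substitute your own setup as needed.
        lokimq::LokiMQ lmq;

        // Reserve up to 2 worker threads for this category; injected tasks obey the same
        // scheduling and queue limits as ordinary "rpc.*" commands in the category.
        lmq.add_category("rpc", lokimq::AuthLevel::none, /*reserved_threads=*/2);
        lmq.start();

        // Somewhere in an external (non-LokiMQ) request handler:
        std::string remote_addr = "127.0.0.1";  // free-form remote identifier, used for logging
        std::string body = "...";               // payload received from the external source
        lmq.inject_task("rpc", "(http)", remote_addr, [body] {
            // Runs in a LokiMQ worker thread, queued like any other "rpc" command, e.g.:
            // handle_rpc(body); send_http_response(...);
            (void) body; // placeholder; real code would consume the request here
        });
    }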