From b0edf998f89832bc0bfcda1df18daf3e6bcfa783 Mon Sep 17 00:00:00 2001
From: Jason Rhinelander
Date: Wed, 27 Sep 2023 16:06:08 -0300
Subject: [PATCH] Fix python binding segfault

This fixes a segfault in the reply handling code, where the destructor of a
`py::function` wasn't necessarily running with the GIL held, which could make
Python segfault.

The previous attempt to address this didn't work because it relied on
`std::optional` captures: it is possible (and apparently sometimes happens)
that the std::function this lambda gets stuffed into gets moved (or maybe
copied?), which then results in *two* lambda destructions. We were correctly
dealing with the one that eventually gets called, by clearing its captures
when it is called, but the temporary's destructor also fires, and that is the
one that broke.

This changes it to instead leak bare pointers into the lambda and then
recapture them inside when we get called; since we are guaranteed to be called
exactly once, this reclaims them without losing them but doesn't incur
destruction of a py::function deep in oxenmq (outside of GIL scope).
--- setup.py | 2 +- src/oxenmq.cpp | 41 ++++++++++++++++++++++------------------- 2 files changed, 23 insertions(+), 20 deletions(-) diff --git a/setup.py b/setup.py index d6f63d6..d64c8ae 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ from setuptools import setup # Available at setup time due to pyproject.toml from pybind11.setup_helpers import Pybind11Extension, build_ext -__version__ = "1.0.4" +__version__ = "1.0.5" # Note: # Sort input source files if you glob sources to ensure bit-for-bit diff --git a/src/oxenmq.cpp b/src/oxenmq.cpp index c10c413..b691ea9 100644 --- a/src/oxenmq.cpp +++ b/src/oxenmq.cpp @@ -757,12 +757,12 @@ the background).)") } bool request = kwargs.contains("request") && kwargs["request"].cast(); - std::optional on_reply, on_reply_failure; + std::unique_ptr on_reply, on_reply_failure; if (request) { if (kwargs.contains("on_reply")) - on_reply = kwargs["on_reply"].cast(); + on_reply = std::make_unique(kwargs["on_reply"].cast()); if (kwargs.contains("on_reply_failure")) - on_reply_failure = kwargs["on_reply_failure"].cast(); + on_reply_failure = std::make_unique(kwargs["on_reply_failure"].cast()); } else if (kwargs.contains("on_reply") || kwargs.contains("on_reply_failure")) { throw std::logic_error{"Error: send(...) 
on_reply=/on_reply_failure= option " "requires request=True (perhaps you meant to use `.request(...)` instead?)"}; @@ -795,29 +795,32 @@ the background).)") hint, optional, incoming, outgoing, keep_alive, request_timeout, std::move(qfail), std::move(qfull)); } else { - auto reply_cb = [on_reply = std::move(on_reply), on_fail = std::move(on_reply_failure)] - (bool success, std::vector data) mutable { + auto reply_cb = [reply_rawptr = on_reply.release(), fail_rawptr = on_reply_failure.release()] + (bool success, std::vector data) { // The gil here makes things tricky: the function invocation itself is // already gil protected, but the *destruction* of the lambda isn't, and // that breaks things because the destruction frees a python reference to // the callback. However oxenmq invokes this callback exactly once so we - // can deal with it by stealing the captures out of the lambda to force - // destruction here, with the gil held. + // can deal with it by leaking raw pointers into the lambda captures then + // reclaiming them the one and only time we are called. py::gil_scoped_acquire gil; - auto reply = std::move(on_reply); - auto fail = std::move(on_fail); + std::unique_ptr reply{reply_rawptr}; + std::unique_ptr fail{fail_rawptr}; - if (success ? !reply : !fail) - return; - py::list l; if (success) { - for (const auto& part : data) - l.append(py::memoryview::from_memory(part.data(), part.size())); - (*reply)(l); - } else if (on_fail) { - for (const auto& part : data) - l.append(py::bytes(part.data(), part.size())); - (*fail)(l); + if (reply) { + py::list l; + for (const auto& part : data) + l.append(py::memoryview::from_memory(part.data(), part.size())); + (*reply)(l); + } + } else { + if (fail) { + py::list l; + for (const auto& part : data) + l.append(py::bytes(part.data(), part.size())); + (*fail)(l); + } } };