diff --git a/.github/actions/setup-env/action.yml b/.github/actions/setup-env/action.yml index 1ed9ff343..7bb02d3f8 100644 --- a/.github/actions/setup-env/action.yml +++ b/.github/actions/setup-env/action.yml @@ -12,7 +12,7 @@ runs: - name: Set up Python uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 with: - python-version: "3.8" + python-version: "3.12" - name: Update System (Linux) shell: bash diff --git a/pyproject.toml b/pyproject.toml index 5e1c0b1ac..228edc212 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "scikit_build_core.build" [project] name = "scaler" description = "Scaler Distribution Framework" -requires-python = ">=3.8" +requires-python = ">=3.10" readme = { file = "README.md", content-type = "text/markdown" } license = { text = "Apache 2.0" } authors = [{ name = "Citi", email = "opensource@citi.com" }] diff --git a/scaler/io/async_object_storage_connector.py b/scaler/io/async_object_storage_connector.py index 940201cce..32eec5095 100644 --- a/scaler/io/async_object_storage_connector.py +++ b/scaler/io/async_object_storage_connector.py @@ -7,6 +7,7 @@ from typing import Dict, Optional, Tuple from scaler.io.mixins import AsyncObjectStorageConnector +from scaler.io.ymq.ymq import * from scaler.protocol.capnp._python import _object_storage # noqa from scaler.protocol.python.object_storage import ObjectRequestHeader, ObjectResponseHeader, to_capnp_object_id from scaler.utility.exceptions import ObjectStorageException @@ -22,19 +23,18 @@ def __init__(self): self._connected_event = asyncio.Event() - self._reader: Optional[asyncio.StreamReader] = None - self._writer: Optional[asyncio.StreamWriter] = None - self._next_request_id = 0 self._pending_get_requests: Dict[ObjectID, asyncio.Future] = {} - self._identity: bytes = f"{os.getpid()}|{socket.gethostname().split('.')[0]}|{uuid.uuid4()}".encode() + self._lock = asyncio.Lock() + self._identity: str = f"{os.getpid()}|{socket.gethostname().split('.')[0]}|{uuid.uuid4()}" + self._io_context: IOContext = IOContext() + self._io_socket = self._io_context.createIOSocket_sync(self._identity, IOSocketType.Connector) def __del__(self): if not self.is_connected(): return - - self._writer.close() + self._io_socket = None async def connect(self, host: str, port: int): self._host = host @@ -42,20 +42,7 @@ async def connect(self, host: str, port: int): if self.is_connected(): raise ObjectStorageException("connector is already connected.") - - self._reader, self._writer = await asyncio.open_connection(self._host, self._port) - await self.__read_framed_message() - self.__write_framed(self._identity) - - try: - await self._writer.drain() - except ConnectionResetError: - self.__raise_connection_failure() - - # Makes sure the socket is TCP_NODELAY. It seems to be the case by default, but that's not specified in the - # asyncio's documentation and might change in the future. - self._writer.get_extra_info("socket").setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - + await self._io_socket.connect(self.address) self._connected_event.set() async def wait_until_connected(self): @@ -67,23 +54,10 @@ def is_connected(self) -> bool: async def destroy(self): if not self.is_connected(): return - - if not self._writer.is_closing: - self._writer.close() - - await self._writer.wait_closed() - - @property - def reader(self) -> Optional[asyncio.StreamReader]: - return self._reader - - @property - def writer(self) -> Optional[asyncio.StreamWriter]: - return self._writer + self._io_socket = None @property def address(self) -> str: - self.__ensure_is_connected() return f"tcp://{self._host}:{self._port}" async def routine(self): @@ -136,12 +110,9 @@ async def duplicate_object_id(self, object_id: ObjectID, new_object_id: ObjectID ) def __ensure_is_connected(self): - if self._writer is None: + if self._io_socket is None: raise ObjectStorageException("connector is not connected.") - if self._writer.is_closing(): - raise ObjectStorageException("connection is closed.") - async def __send_request( self, object_id: ObjectID, @@ -150,7 +121,6 @@ async def __send_request( payload: Optional[bytes], ): self.__ensure_is_connected() - assert self._writer is not None request_id = self._next_request_id self._next_request_id += 1 @@ -158,67 +128,59 @@ async def __send_request( header = ObjectRequestHeader.new_msg(object_id, payload_length, request_id, request_type) - self.__write_request_header(header) + try: + async with self._lock: + await self.__write_request_header(header) - if payload is not None: - self.__write_request_payload(payload) + if payload is not None: + await self.__write_request_payload(payload) - try: - await self._writer.drain() - except ConnectionResetError: + except YMQException: + self._io_socket = None self.__raise_connection_failure() - def __write_request_header(self, header: ObjectRequestHeader): - assert self._writer is not None - self.__write_framed(header.get_message().to_bytes()) + async def __write_request_header(self, header: ObjectRequestHeader): + assert self._io_socket is not None + await self._io_socket.send(Message(address=None, payload=header.get_message().to_bytes())) - def __write_request_payload(self, payload: bytes): - assert self._writer is not None - self.__write_framed(payload) + async def __write_request_payload(self, payload: bytes): + assert self._io_socket is not None + await self._io_socket.send(Message(address=None, payload=payload)) async def __receive_response(self) -> Optional[Tuple[ObjectResponseHeader, bytes]]: - assert self._reader is not None - - if self._writer.is_closing(): + if self._io_socket is None: return None try: header = await self.__read_response_header() payload = await self.__read_response_payload(header) - except asyncio.IncompleteReadError: + except YMQException: + self._io_socket = None self.__raise_connection_failure() return header, payload async def __read_response_header(self) -> ObjectResponseHeader: - assert self._reader is not None + assert self._io_socket is not None - header_data = await self.__read_framed_message() + msg = await self._io_socket.recv() + header_data = msg.payload.data assert len(header_data) == ObjectResponseHeader.MESSAGE_LENGTH with _object_storage.ObjectResponseHeader.from_bytes(header_data) as header_message: return ObjectResponseHeader(header_message) async def __read_response_payload(self, header: ObjectResponseHeader) -> bytes: - assert self._reader is not None + assert self._io_socket is not None + # assert self._reader is not None if header.payload_length > 0: - res = await self.__read_framed_message() - assert len(res) == header.payload_length - return res + res = await self._io_socket.recv() + assert len(res.payload) == header.payload_length + return res.payload.data else: return b"" - async def __read_framed_message(self) -> bytes: - length_bytes = await self._reader.readexactly(8) - (payload_length,) = struct.unpack(" 0 else bytes() - - def __write_framed(self, payload: bytes): - self._writer.write(struct.pack(" str: @@ -114,7 +109,7 @@ def duplicate_object_id(self, object_id: ObjectID, new_object_id: ObjectID) -> N self.__ensure_empty_payload(response_payload) def __ensure_is_connected(self): - if self._socket is None: + if self._io_socket is None: raise ObjectStorageException("connector is closed.") def __ensure_response_type( @@ -135,7 +130,7 @@ def __send_request( payload: Optional[bytes] = None, ): self.__ensure_is_connected() - assert self._socket is not None + assert self._io_socket is not None request_id = self._next_request_id self._next_request_id += 1 @@ -145,102 +140,46 @@ def __send_request( header_bytes = header.get_message().to_bytes() if payload is not None: - self.__send_buffers( - [struct.pack(" None: - if len(buffers) < 1: - return - - total_size = sum(len(buffer) for buffer in buffers) - - # If the message is small enough, first try to send it at once with sendmsg(). This would ensure the message can - # be transmitted within a single TCP segment. - if total_size < MAX_CHUNK_SIZE: - sent = self._socket.sendmsg(buffers) - - if sent <= 0: - self.__raise_connection_failure() - - remaining_buffers = collections.deque(buffers) - while sent > len(remaining_buffers[0]): - removed_buffer = remaining_buffers.popleft() - sent -= len(removed_buffer) - - if sent > 0: - # Truncate the first partially sent buffer - remaining_buffers[0] = memoryview(remaining_buffers[0])[sent:] - - buffers = list(remaining_buffers) - - # Send the remaining buffers sequentially - for buffer in buffers: - self.__send_buffer(buffer) - - def __send_buffer(self, buffer: bytes) -> None: - buffer_view = memoryview(buffer) + self._io_socket.send_sync(Message(address=None, payload=header_bytes)) - total_sent = 0 - while total_sent < len(buffer): - sent = self._socket.send(buffer_view[total_sent : MAX_CHUNK_SIZE + total_sent]) + def __receive_response(self): + assert self._io_socket is not None - if sent <= 0: - self.__raise_connection_failure() - - total_sent += sent - - def __receive_response(self) -> Tuple[ObjectResponseHeader, bytearray]: - assert self._socket is not None - - header = self.__read_response_header() - payload = self.__read_response_payload(header) - - return header, payload + try: + header = self.__read_response_header() + payload = self.__read_response_payload(header) + return header, payload + except YMQException: + self.__raise_connection_failure() def __read_response_header(self) -> ObjectResponseHeader: - assert self._socket is not None + assert self._io_socket is not None - header_bytearray = self.__read_framed_message() + header_bytes = self._io_socket.recv_sync().payload.data + if header_bytes is None: + self.__raise_connection_failure() # pycapnp does not like to read from a bytearray object. This look like an not-yet-resolved issue. # That's is annoying because it leads to an unnecessary copy of the header's buffer. # See https://github.com/capnproto/pycapnp/issues/153 - header_bytes = bytes(header_bytearray) + # header_bytes = bytes(header_bytearray) with _object_storage.ObjectResponseHeader.from_bytes(header_bytes) as header_message: return ObjectResponseHeader(header_message) def __read_response_payload(self, header: ObjectResponseHeader) -> bytearray: if header.payload_length > 0: - res = self.__read_framed_message() + res = self._io_socket.recv_sync().payload.data + if res is None: + self.__raise_connection_failure() assert len(res) == header.payload_length - return res + return bytearray(res) else: return bytearray() - def __read_exactly(self, length: int) -> bytearray: - buffer = bytearray(length) - - total_received = 0 - while total_received < length: - chunk_size = min(MAX_CHUNK_SIZE, length - total_received) - received = self._socket.recv_into(memoryview(buffer)[total_received:], chunk_size) - - if received <= 0: - self.__raise_connection_failure() - - total_received += received - - return buffer - - def __read_framed_message(self) -> bytearray: - length_bytes = self.__read_exactly(8) - (payload_length,) = struct.unpack(" 0 else bytearray() - @staticmethod def __raise_connection_failure(): raise ObjectStorageException("connection failure to object storage server.") diff --git a/scaler/io/ymq/CMakeLists.txt b/scaler/io/ymq/CMakeLists.txt index 63b898929..5d6cc2d1c 100644 --- a/scaler/io/ymq/CMakeLists.txt +++ b/scaler/io/ymq/CMakeLists.txt @@ -61,13 +61,12 @@ if(LINUX) find_package(Python3 COMPONENTS Development.Module REQUIRED) add_library(py_ymq SHARED - pymod_ymq/async.h pymod_ymq/bytes.h pymod_ymq/exception.h + pymod_ymq/gil.h pymod_ymq/message.h pymod_ymq/io_context.h pymod_ymq/io_socket.h - pymod_ymq/utils.h pymod_ymq/ymq.h pymod_ymq/ymq.cpp ) @@ -81,7 +80,7 @@ if(LINUX) set_target_properties(py_ymq PROPERTIES PREFIX "" - OUTPUT_NAME "ymq" + OUTPUT_NAME "_ymq" LINKER_LANGUAGE CXX ) diff --git a/scaler/io/ymq/ymq.pyi b/scaler/io/ymq/_ymq.pyi similarity index 68% rename from scaler/io/ymq/ymq.pyi rename to scaler/io/ymq/_ymq.pyi index 03229bca9..f27e9b45e 100644 --- a/scaler/io/ymq/ymq.pyi +++ b/scaler/io/ymq/_ymq.pyi @@ -1,9 +1,8 @@ # NOTE: NOT IMPLEMENTATION, TYPE INFORMATION ONLY # This file contains type stubs for the Ymq Python C Extension module import sys -from collections.abc import Awaitable from enum import IntEnum -from typing import SupportsBytes +from typing import Callable, Optional, SupportsBytes, Union if sys.version_info >= (3, 12): from collections.abc import Buffer @@ -39,46 +38,33 @@ class IOSocketType(IntEnum): Unicast = 3 Multicast = 4 -class IOContext: +class BaseIOContext: num_threads: int def __init__(self, num_threads: int = 1) -> None: ... def __repr__(self) -> str: ... - def createIOSocket(self, /, identity: str, socket_type: IOSocketType) -> Awaitable[IOSocket]: + def createIOSocket( + self, callback: Callable[[Union[BaseIOSocket, Exception]], None], identity: str, socket_type: IOSocketType + ) -> None: """Create an io socket with an identity and socket type""" - def createIOSocket_sync(self, /, identity: str, socket_type: IOSocketType) -> IOSocket: - """Create an io socket with an identity and socket type synchronously""" - -class IOSocket: +class BaseIOSocket: identity: str socket_type: IOSocketType def __repr__(self) -> str: ... - async def send(self, message: Message) -> None: + def send(self, callback: Callable[[Optional[Exception]], None], message: Message) -> None: """Send a message to one of the socket's peers""" - async def recv(self) -> Message: + def recv(self, callback: Callable[[Union[Message, Exception]], None]) -> None: """Receive a message from one of the socket's peers""" - async def bind(self, address: str) -> None: + def bind(self, callback: Callable[[Optional[Exception]], None], address: str) -> None: """Bind the socket to an address and listen for incoming connections""" - async def connect(self, address: str) -> None: + def connect(self, callback: Callable[[Optional[Exception]], None], address: str) -> None: """Connect to a remote socket""" - def send_sync(self, message: Message) -> None: - """Send a message to one of the socket's peers synchronously""" - - def recv_sync(self) -> Message: - """Receive a message from one of the socket's peers synchronously""" - - def bind_sync(self, address: str) -> None: - """Bind the socket to an address and listen for incoming connections synchronously""" - - def connect_sync(self, address: str) -> None: - """Connect to a remote socket synchronously""" - class ErrorCode(IntEnum): Uninit = 0 InvalidPortFormat = 1 @@ -108,6 +94,3 @@ class YMQException(Exception): def __init__(self, /, code: ErrorCode, message: str) -> None: ... def __repr__(self) -> str: ... def __str__(self) -> str: ... - -class YMQInterruptedException(YMQException): - def __init__(self) -> None: ... diff --git a/scaler/io/ymq/pymod_ymq/async.h b/scaler/io/ymq/pymod_ymq/async.h deleted file mode 100644 index 602097eb9..000000000 --- a/scaler/io/ymq/pymod_ymq/async.h +++ /dev/null @@ -1,97 +0,0 @@ -#pragma once - -// Python -#include "scaler/io/ymq/pymod_ymq/python.h" - -// C++ -#include - -// First-party -#include "scaler/io/ymq/pymod_ymq/ymq.h" - -// wraps an async callback that accepts a Python asyncio future -static PyObject* async_wrapper(PyObject* self, const std::function&& callback) -{ - auto state = YMQStateFromSelf(self); - if (!state) - return nullptr; - - OwnedPyObject loop = PyObject_CallMethod(*state->asyncioModule, "get_event_loop", nullptr); - if (!loop) { - PyErr_SetString(PyExc_RuntimeError, "Failed to get event loop"); - return nullptr; - } - - OwnedPyObject future = PyObject_CallMethod(*loop, "create_future", nullptr); - if (!future) { - PyErr_SetString(PyExc_RuntimeError, "Failed to create future"); - return nullptr; - } - - // create the awaitable before calling the callback - // this ensures that we create a new strong reference to the future before the callback decrefs it - auto awaitable = PyObject_CallFunction(*state->PyAwaitableType, "O", *future); - - // async - // we transfer ownership of the future to the callback - // TODO: investigate having the callback take an OwnedPyObject, and just std::move() - callback(state, future.take()); - - return awaitable; -} - -struct Awaitable { - PyObject_HEAD; - OwnedPyObject<> future; -}; - -extern "C" { - -static int Awaitable_init(Awaitable* self, PyObject* args, PyObject* kwds) -{ - PyObject* future = nullptr; - if (!PyArg_ParseTuple(args, "O", &future)) - return -1; - - new (&self->future) OwnedPyObject<>(); - self->future = OwnedPyObject<>::fromBorrowed(future); - - return 0; -} - -static PyObject* Awaitable_await(Awaitable* self) -{ - // Easy: coroutines are just iterators and we don't need anything fancy - // so we can just return the future's iterator! - return PyObject_GetIter(*self->future); -} - -static void Awaitable_dealloc(Awaitable* self) -{ - try { - self->future.~OwnedPyObject(); - } catch (...) { - PyErr_SetString(PyExc_RuntimeError, "Failed to deallocate Awaitable"); - PyErr_WriteUnraisable((PyObject*)self); - } - - auto* tp = Py_TYPE(self); - tp->tp_free(self); - Py_DECREF(tp); -} -} - -static PyType_Slot Awaitable_slots[] = { - {Py_tp_init, (void*)Awaitable_init}, - {Py_tp_dealloc, (void*)Awaitable_dealloc}, - {Py_am_await, (void*)Awaitable_await}, - {0, nullptr}, -}; - -static PyType_Spec Awaitable_spec { - .name = "ymq.Awaitable", - .basicsize = sizeof(Awaitable), - .itemsize = 0, - .flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_IMMUTABLETYPE, - .slots = Awaitable_slots, -}; diff --git a/scaler/io/ymq/pymod_ymq/bytes.h b/scaler/io/ymq/pymod_ymq/bytes.h index 0941c02c9..40743d7b6 100644 --- a/scaler/io/ymq/pymod_ymq/bytes.h +++ b/scaler/io/ymq/pymod_ymq/bytes.h @@ -6,8 +6,6 @@ // First-party #include "scaler/io/ymq/bytes.h" -using namespace scaler::ymq; - struct PyBytesYMQ { PyObject_HEAD; Bytes bytes; @@ -19,9 +17,8 @@ static int PyBytesYMQ_init(PyBytesYMQ* self, PyObject* args, PyObject* kwds) { Py_buffer view {.buf = nullptr}; const char* keywords[] = {"bytes", nullptr}; - if (!PyArg_ParseTupleAndKeywords(args, kwds, "|y*", (char**)keywords, &view)) { + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|y*", (char**)keywords, &view)) return -1; // Error parsing arguments - } if (!view.buf) { // If no bytes were provided, initialize with an empty Bytes object @@ -94,11 +91,6 @@ static PyGetSetDef PyBytesYMQ_properties[] = { {nullptr, nullptr, nullptr, nullptr, nullptr}, // Sentinel }; -static PyBufferProcs PyBytesYMQBufferProcs = { - .bf_getbuffer = (getbufferproc)PyBytesYMQ_getbuffer, - .bf_releasebuffer = (releasebufferproc)PyBytesYMQ_releasebuffer, -}; - static PyType_Slot PyBytesYMQ_slots[] = { {Py_tp_init, (void*)PyBytesYMQ_init}, {Py_tp_dealloc, (void*)PyBytesYMQ_dealloc}, @@ -113,7 +105,7 @@ static PyType_Slot PyBytesYMQ_slots[] = { }; static PyType_Spec PyBytesYMQ_spec = { - .name = "ymq.Bytes", + .name = "_ymq.Bytes", .basicsize = sizeof(PyBytesYMQ), .itemsize = 0, .flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_IMMUTABLETYPE, diff --git a/scaler/io/ymq/pymod_ymq/exception.h b/scaler/io/ymq/pymod_ymq/exception.h index 2369d3b45..fb63862a0 100644 --- a/scaler/io/ymq/pymod_ymq/exception.h +++ b/scaler/io/ymq/pymod_ymq/exception.h @@ -3,13 +3,12 @@ // Python #include "scaler/io/ymq/pymod_ymq/python.h" -// C++ -#include - // First-party -#include "scaler/io/ymq/pymod_ymq/utils.h" +#include "scaler/io/ymq/error.h" #include "scaler/io/ymq/pymod_ymq/ymq.h" +using namespace scaler::ymq; + // the order of the members in the exception args tuple const Py_ssize_t YMQException_errorCodeIndex = 0; const Py_ssize_t YMQException_messageIndex = 1; @@ -81,9 +80,9 @@ static PyType_Slot YMQException_slots[] = { }; static PyType_Spec YMQException_spec = { - "ymq.YMQException", sizeof(YMQException), 0, Py_TPFLAGS_DEFAULT, YMQException_slots}; + "_ymq.YMQException", sizeof(YMQException), 0, Py_TPFLAGS_DEFAULT, YMQException_slots}; -OwnedPyObject<> YMQException_argtupleFromCoreError(YMQState* state, const Error* error) +inline OwnedPyObject<> YMQException_argtupleFromCoreError(YMQState* state, const Error* error) { OwnedPyObject code = PyLong_FromLong(static_cast(error->_errorCode)); @@ -103,7 +102,7 @@ OwnedPyObject<> YMQException_argtupleFromCoreError(YMQState* state, const Error* return PyTuple_Pack(2, *pyCode, *message); } -void YMQException_setFromCoreError(YMQState* state, const Error* error) +inline void YMQException_setFromCoreError(YMQState* state, const Error* error) { auto tuple = YMQException_argtupleFromCoreError(state, error); if (!tuple) @@ -112,7 +111,7 @@ void YMQException_setFromCoreError(YMQState* state, const Error* error) PyErr_SetObject(*state->PyExceptionType, *tuple); } -PyObject* YMQException_createFromCoreError(YMQState* state, const Error* error) +inline PyObject* YMQException_createFromCoreError(YMQState* state, const Error* error) { auto tuple = YMQException_argtupleFromCoreError(state, error); if (!tuple) diff --git a/scaler/io/ymq/pymod_ymq/gil.h b/scaler/io/ymq/pymod_ymq/gil.h new file mode 100644 index 000000000..c28590e8d --- /dev/null +++ b/scaler/io/ymq/pymod_ymq/gil.h @@ -0,0 +1,15 @@ +#include "scaler/io/ymq/pymod_ymq/python.h" + +class AcquireGIL { +public: + AcquireGIL() : _state(PyGILState_Ensure()) {} + ~AcquireGIL() { PyGILState_Release(_state); } + + AcquireGIL(const AcquireGIL&) = delete; + AcquireGIL& operator=(const AcquireGIL&) = delete; + AcquireGIL(AcquireGIL&&) = delete; + AcquireGIL& operator=(AcquireGIL&&) = delete; + +private: + PyGILState_STATE _state; +}; diff --git a/scaler/io/ymq/pymod_ymq/io_context.h b/scaler/io/ymq/pymod_ymq/io_context.h index deb63003e..c3c116749 100644 --- a/scaler/io/ymq/pymod_ymq/io_context.h +++ b/scaler/io/ymq/pymod_ymq/io_context.h @@ -4,8 +4,6 @@ #include "scaler/io/ymq/pymod_ymq/python.h" // C++ -#include -#include #include // First-party @@ -64,25 +62,26 @@ static PyObject* PyIOContext_repr(PyIOContext* self) return PyUnicode_FromFormat("", (void*)self->ioContext.get()); } -static PyObject* PyIOContext_createIOSocket_( - PyIOContext* self, - PyObject* args, - PyObject* kwargs, - std::function fn) +static PyObject* PyIOContext_numThreads_getter(PyIOContext* self, void* Py_UNUSED(closure)) { - const char* identity = nullptr; - Py_ssize_t identityLen = 0; - PyObject* pySocketType = nullptr; - const char* kwlist[] = {"identity", "pySocketType", nullptr}; - - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "s#O", (char**)kwlist, &identity, &identityLen, &pySocketType)) - return nullptr; + return PyLong_FromSize_t(self->ioContext->numThreads()); +} +static PyObject* PyIOContext_createIOSocket(PyIOContext* self, PyObject* args, PyObject* kwargs) +{ YMQState* state = YMQStateFromSelf((PyObject*)self); - if (!state) return nullptr; + PyObject* callback = nullptr; + const char* identity = nullptr; + Py_ssize_t identityLen = 0; + PyObject* pySocketType = nullptr; + const char* kwlist[] = {"", "identity", "socket_type", nullptr}; + if (!PyArg_ParseTupleAndKeywords( + args, kwargs, "Os#O", (char**)kwlist, &callback, &identity, &identityLen, &pySocketType)) + return nullptr; + if (!PyObject_IsInstance(pySocketType, *state->PyIOSocketEnumType)) { PyErr_SetString(PyExc_TypeError, "Expected socket_type to be an instance of IOSocketType"); return nullptr; @@ -98,83 +97,36 @@ static PyObject* PyIOContext_createIOSocket_( } long socketTypeValue = PyLong_AsLong(*value); - if (socketTypeValue < 0 && PyErr_Occurred()) return nullptr; - IOSocketType socketType = static_cast(socketTypeValue); - + IOSocketType socketType = static_cast(socketTypeValue); OwnedPyObject ioSocket = PyObject_New(PyIOSocket, (PyTypeObject*)*state->PyIOSocketType); if (!ioSocket) return nullptr; + Py_INCREF(callback); + try { // ensure the fields are init new (&ioSocket->socket) std::shared_ptr(); new (&ioSocket->ioContext) std::shared_ptr(); ioSocket->ioContext = self->ioContext; - } catch (...) { - PyErr_SetString(PyExc_RuntimeError, "Failed to create IOSocket"); - return nullptr; - } - // move ownership of the ioSocket to the callback - return fn(ioSocket.take(), identity, socketType); -} + self->ioContext->createIOSocket( + std::string(identity, identityLen), socketType, [callback, ioSocket](auto socket) { + AcquireGIL _; -static PyObject* PyIOContext_createIOSocket(PyIOContext* self, PyObject* args, PyObject* kwargs) -{ - return PyIOContext_createIOSocket_( - self, args, kwargs, [self](auto ioSocket, Identity identity, IOSocketType socketType) { - return async_wrapper((PyObject*)self, [=](YMQState* state, auto future) { - self->ioContext->createIOSocket(identity, socketType, [=](std::shared_ptr socket) { - future_set_result(future, [=] { - ioSocket->socket = std::move(socket); - return (PyObject*)ioSocket; - }); - }); + ioSocket->socket = socket; + OwnedPyObject _ = PyObject_CallFunctionObjArgs(callback, *ioSocket, nullptr); + Py_DECREF(callback); }); - }); -} - -static PyObject* PyIOContext_createIOSocket_sync(PyIOContext* self, PyObject* args, PyObject* kwargs) -{ - auto state = YMQStateFromSelf((PyObject*)self); - if (!state) + } catch (...) { + PyErr_SetString(PyExc_RuntimeError, "Failed to create IOSocket"); return nullptr; + } - return PyIOContext_createIOSocket_( - self, args, kwargs, [self, state](auto ioSocket, Identity identity, IOSocketType socketType) { - PyThreadState* _save = PyEval_SaveThread(); - - std::shared_ptr socket {}; - try { - Waiter waiter(state->wakeupfd_rd); - - self->ioContext->createIOSocket( - identity, socketType, [waiter, &socket](std::shared_ptr s) mutable { - socket = std::move(s); - waiter.signal(); - }); - - if (waiter.wait()) - CHECK_SIGNALS; - } catch (...) { - PyEval_RestoreThread(_save); - PyErr_SetString(PyExc_RuntimeError, "Failed to create io socket synchronously"); - return (PyObject*)nullptr; - } - - PyEval_RestoreThread(_save); - - ioSocket->socket = socket; - return (PyObject*)ioSocket; - }); -} - -static PyObject* PyIOContext_numThreads_getter(PyIOContext* self, void* Py_UNUSED(closure)) -{ - return PyLong_FromSize_t(self->ioContext->numThreads()); + Py_RETURN_NONE; } } // extern "C" @@ -184,10 +136,6 @@ static PyMethodDef PyIOContext_methods[] = { (PyCFunction)PyIOContext_createIOSocket, METH_VARARGS | METH_KEYWORDS, PyDoc_STR("Create a new IOSocket")}, - {"createIOSocket_sync", - (PyCFunction)PyIOContext_createIOSocket_sync, - METH_VARARGS | METH_KEYWORDS, - PyDoc_STR("Create a new IOSocket")}, {nullptr, nullptr, 0, nullptr}, }; @@ -210,9 +158,9 @@ static PyType_Slot PyIOContext_slots[] = { }; static PyType_Spec PyIOContext_spec = { - .name = "ymq.IOContext", + .name = "_ymq.BaseIOContext", .basicsize = sizeof(PyIOContext), .itemsize = 0, - .flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_IMMUTABLETYPE, + .flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_IMMUTABLETYPE | Py_TPFLAGS_BASETYPE, .slots = PyIOContext_slots, }; diff --git a/scaler/io/ymq/pymod_ymq/io_socket.h b/scaler/io/ymq/pymod_ymq/io_socket.h index 251a6a461..af318d109 100644 --- a/scaler/io/ymq/pymod_ymq/io_socket.h +++ b/scaler/io/ymq/pymod_ymq/io_socket.h @@ -4,27 +4,25 @@ #include "scaler/io/ymq/pymod_ymq/python.h" // C++ -#include -#include #include -#include #include // C +#include #include #include #include // First-party #include "scaler/io/ymq/bytes.h" +#include "scaler/io/ymq/error.h" #include "scaler/io/ymq/io_context.h" #include "scaler/io/ymq/io_socket.h" #include "scaler/io/ymq/message.h" -#include "scaler/io/ymq/pymod_ymq/async.h" #include "scaler/io/ymq/pymod_ymq/bytes.h" #include "scaler/io/ymq/pymod_ymq/exception.h" +#include "scaler/io/ymq/pymod_ymq/gil.h" #include "scaler/io/ymq/pymod_ymq/message.h" -#include "scaler/io/ymq/pymod_ymq/utils.h" #include "scaler/io/ymq/pymod_ymq/ymq.h" using namespace scaler::ymq; @@ -54,231 +52,133 @@ static void PyIOSocket_dealloc(PyIOSocket* self) } static PyObject* PyIOSocket_send(PyIOSocket* self, PyObject* args, PyObject* kwargs) -{ - // borrowed reference - PyMessage* message = nullptr; - const char* kwlist[] = {"message", nullptr}; - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O", (char**)kwlist, &message)) - return nullptr; - - auto address = message->address.is_none() ? Bytes() : std::move(message->address->bytes); - auto payload = std::move(message->payload->bytes); - - return async_wrapper((PyObject*)self, [=](YMQState* state, auto future) { - try { - self->socket->sendMessage({.address = std::move(address), .payload = std::move(payload)}, [=](auto result) { - future_set_result(future, [=] -> std::expected { - if (result) { - Py_RETURN_NONE; - } else { - return std::unexpected {YMQException_createFromCoreError(state, &result.error())}; - } - }); - }); - } catch (...) { - future_raise_exception( - future, [] { return PyErr_CreateFromString(PyExc_RuntimeError, "Failed to send message"); }); - } - }); -} - -static PyObject* PyIOSocket_send_sync(PyIOSocket* self, PyObject* args, PyObject* kwargs) { auto state = YMQStateFromSelf((PyObject*)self); if (!state) return nullptr; - // borrowed reference + PyObject* callback = nullptr; PyMessage* message = nullptr; - const char* kwlist[] = {"message", nullptr}; - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O", (char**)kwlist, &message)) - return nullptr; - Bytes address = message->address.is_none() ? Bytes() : std::move(message->address->bytes); - Bytes payload = std::move(message->payload->bytes); - - PyThreadState* _save = PyEval_SaveThread(); - - std::shared_ptr> result = std::make_shared>(); - try { - Waiter waiter(state->wakeupfd_rd); - - self->socket->sendMessage({.address = std::move(address), .payload = std::move(payload)}, [=](auto r) mutable { - *result = std::move(r); - waiter.signal(); - }); + // empty str -> positional only + const char* kwlist[] = {"", "message", nullptr}; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OO", (char**)kwlist, &callback, &message)) + return nullptr; - if (waiter.wait()) - CHECK_SIGNALS; - } catch (...) { - PyEval_RestoreThread(_save); - PyErr_SetString(PyExc_RuntimeError, "Failed to send synchronously"); + if (!PyObject_TypeCheck(message, (PyTypeObject*)*state->PyMessageType)) { + PyErr_SetString(PyExc_TypeError, "message must be a Message"); return nullptr; } - PyEval_RestoreThread(_save); + auto address = message->address.is_none() ? Bytes() : std::move(message->address->bytes); + auto payload = std::move(message->payload->bytes); + + Py_INCREF(callback); - if (!result) { - YMQException_setFromCoreError(state, &result->error()); + try { + self->socket->sendMessage( + {.address = std::move(address), .payload = std::move(payload)}, [callback, state](auto result) { + AcquireGIL _; + + if (result) { + OwnedPyObject result = PyObject_CallFunctionObjArgs(callback, Py_None, nullptr); + } else { + OwnedPyObject obj = YMQException_createFromCoreError(state, &result.error()); + OwnedPyObject _ = PyObject_CallFunctionObjArgs(callback, *obj, nullptr); + } + + Py_DECREF(callback); + }); + } catch (...) { + PyErr_SetString(PyExc_RuntimeError, "Failed to send message"); return nullptr; } Py_RETURN_NONE; } -static PyObject* PyIOSocket_recv(PyIOSocket* self, PyObject* args) -{ - return async_wrapper((PyObject*)self, [=](YMQState* state, auto future) { - self->socket->recvMessage([=](auto result) { - try { - future_set_result(future, [=] -> std::expected { - if (result.second._errorCode != Error::ErrorCode::Uninit) { - return std::unexpected {YMQException_createFromCoreError(state, &result.second)}; - } - - auto message = result.first; - OwnedPyObject address = (PyBytesYMQ*)PyObject_CallNoArgs(*state->PyBytesYMQType); - if (!address) - return YMQ_GetRaisedException(); - - address->bytes = std::move(message.address); - - OwnedPyObject payload = (PyBytesYMQ*)PyObject_CallNoArgs(*state->PyBytesYMQType); - if (!payload) - return YMQ_GetRaisedException(); - - payload->bytes = std::move(message.payload); - - OwnedPyObject pyMessage = - (PyMessage*)PyObject_CallFunction(*state->PyMessageType, "OO", *address, *payload); - if (!pyMessage) - return YMQ_GetRaisedException(); - - return (PyObject*)pyMessage.take(); - }); - } catch (...) { - future_raise_exception( - future, [] { return PyErr_CreateFromString(PyExc_RuntimeError, "Failed to receive message"); }); - } - }); - }); -} - -static PyObject* PyIOSocket_recv_sync(PyIOSocket* self, PyObject* args) +static PyObject* PyIOSocket_recv(PyIOSocket* self, PyObject* args, PyObject* kwargs) { auto state = YMQStateFromSelf((PyObject*)self); if (!state) return nullptr; - PyThreadState* _save = PyEval_SaveThread(); - - std::shared_ptr> result = std::make_shared>(); - try { - Waiter waiter(state->wakeupfd_rd); - - self->socket->recvMessage([=](auto r) mutable { - *result = std::move(r); - waiter.signal(); - }); - - if (waiter.wait()) - CHECK_SIGNALS; - } catch (...) { - PyEval_RestoreThread(_save); - PyErr_SetString(PyExc_RuntimeError, "Failed to recv synchronously"); + PyObject* callback = nullptr; + const char* kwlist[] = {"", nullptr}; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O", (char**)kwlist, &callback)) return nullptr; - } - PyEval_RestoreThread(_save); + Py_INCREF(callback); - if (result->second._errorCode != Error::ErrorCode::Uninit) { - YMQException_setFromCoreError(state, &result->second); - return nullptr; - } - - auto message = result->first; + try { + self->socket->recvMessage([callback, state](std::pair result) { + AcquireGIL _; - OwnedPyObject address = (PyBytesYMQ*)PyObject_CallNoArgs(*state->PyBytesYMQType); - if (!address) - return nullptr; + if (result.second._errorCode != Error::ErrorCode::Uninit) { + OwnedPyObject obj = YMQException_createFromCoreError(state, &result.second); + OwnedPyObject _ = PyObject_CallFunctionObjArgs(callback, *obj, nullptr); + return; + } - address->bytes = std::move(message.address); + auto message = result.first; + OwnedPyObject address = (PyBytesYMQ*)PyObject_CallNoArgs(*state->PyBytesYMQType); + if (!address) + return completeCallbackWithRaisedException(callback); - OwnedPyObject payload = (PyBytesYMQ*)PyObject_CallNoArgs(*state->PyBytesYMQType); - if (!payload) - return nullptr; + address->bytes = std::move(message.address); - payload->bytes = std::move(message.payload); + OwnedPyObject payload = (PyBytesYMQ*)PyObject_CallNoArgs(*state->PyBytesYMQType); + if (!payload) + return completeCallbackWithRaisedException(callback); - OwnedPyObject pyMessage = - (PyMessage*)PyObject_CallFunction(*state->PyMessageType, "OO", *address, *payload); - if (!pyMessage) - return nullptr; + payload->bytes = std::move(message.payload); - return (PyObject*)pyMessage.take(); -} + OwnedPyObject pyMessage = + (PyMessage*)PyObject_CallFunction(*state->PyMessageType, "OO", *address, *payload); + if (!pyMessage) + return completeCallbackWithRaisedException(callback); -static PyObject* PyIOSocket_bind(PyIOSocket* self, PyObject* args, PyObject* kwargs) -{ - const char* address = nullptr; - Py_ssize_t addressLen = 0; - const char* kwlist[] = {"address", nullptr}; - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "s#", (char**)kwlist, &address, &addressLen)) + OwnedPyObject _ = PyObject_CallFunctionObjArgs(callback, *pyMessage, nullptr); + Py_DECREF(callback); + }); + } catch (...) { + PyErr_SetString(PyExc_RuntimeError, "Failed to receive message"); return nullptr; + } - return async_wrapper((PyObject*)self, [=](YMQState* state, auto future) { - try { - self->socket->bindTo(std::string(address, addressLen), [=](auto result) { - future_set_result(future, [=] -> std::expected { - if (!result) { - return std::unexpected {YMQException_createFromCoreError(state, &result.error())}; - } - - Py_RETURN_NONE; - }); - }); - } catch (...) { - future_raise_exception( - future, [] { return PyErr_CreateFromString(PyExc_RuntimeError, "Failed to bind to address"); }); - } - }); + Py_RETURN_NONE; } -static PyObject* PyIOSocket_bind_sync(PyIOSocket* self, PyObject* args, PyObject* kwargs) +static PyObject* PyIOSocket_bind(PyIOSocket* self, PyObject* args, PyObject* kwargs) { auto state = YMQStateFromSelf((PyObject*)self); if (!state) return nullptr; + PyObject* callback = nullptr; const char* address = nullptr; Py_ssize_t addressLen = 0; - const char* kwlist[] = {"address", nullptr}; - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "s#", (char**)kwlist, &address, &addressLen)) + const char* kwlist[] = {"", "address", nullptr}; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "Os#", (char**)kwlist, &callback, &address, &addressLen)) return nullptr; - PyThreadState* _save = PyEval_SaveThread(); + Py_INCREF(callback); - auto result = std::make_shared>(); try { - Waiter waiter(state->wakeupfd_rd); + self->socket->bindTo(std::string(address, addressLen), [callback, state](auto result) { + AcquireGIL _; + + if (!result) { + OwnedPyObject exc = YMQException_createFromCoreError(state, &result.error()); + OwnedPyObject _ = PyObject_CallFunctionObjArgs(callback, *exc, nullptr); + } else { + OwnedPyObject _ = PyObject_CallFunctionObjArgs(callback, Py_None, nullptr); + } - self->socket->bindTo(std::string(address, addressLen), [=](auto r) mutable { - *result = std::move(r); - waiter.signal(); + Py_DECREF(callback); }); - - if (waiter.wait()) - CHECK_SIGNALS; } catch (...) { - PyEval_RestoreThread(_save); - PyErr_SetString(PyExc_RuntimeError, "Failed to bind synchronously"); - return nullptr; - } - - PyEval_RestoreThread(_save); - - if (!result) { - YMQException_setFromCoreError(state, &result->error()); + PyErr_SetString(PyExc_RuntimeError, "Failed to bind to address"); return nullptr; } @@ -286,66 +186,35 @@ static PyObject* PyIOSocket_bind_sync(PyIOSocket* self, PyObject* args, PyObject } static PyObject* PyIOSocket_connect(PyIOSocket* self, PyObject* args, PyObject* kwargs) -{ - const char* address = nullptr; - Py_ssize_t addressLen = 0; - const char* kwlist[] = {"address", nullptr}; - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "s#", (char**)kwlist, &address, &addressLen)) - return nullptr; - - return async_wrapper((PyObject*)self, [=](YMQState* state, auto future) { - try { - self->socket->connectTo(std::string(address, addressLen), [=](auto result) { - future_set_result(future, [=] -> std::expected { - if (result || result.error()._errorCode == Error::ErrorCode::InitialConnectFailedWithInProgress) { - Py_RETURN_NONE; - } else { - return std::unexpected {YMQException_createFromCoreError(state, &result.error())}; - } - }); - }); - } catch (...) { - future_raise_exception( - future, [] { return PyErr_CreateFromString(PyExc_RuntimeError, "Failed to connect to address"); }); - } - }); -} - -static PyObject* PyIOSocket_connect_sync(PyIOSocket* self, PyObject* args, PyObject* kwargs) { auto state = YMQStateFromSelf((PyObject*)self); if (!state) return nullptr; + PyObject* callback = nullptr; const char* address = nullptr; Py_ssize_t addressLen = 0; - const char* kwlist[] = {"address", nullptr}; - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "s#", (char**)kwlist, &address, &addressLen)) + const char* kwlist[] = {"", "address", nullptr}; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "Os#", (char**)kwlist, &callback, &address, &addressLen)) return nullptr; - PyThreadState* _save = PyEval_SaveThread(); + Py_INCREF(callback); - std::shared_ptr> result = std::make_shared>(); try { - Waiter waiter(state->wakeupfd_rd); + self->socket->connectTo(std::string(address, addressLen), [callback, state](auto result) { + AcquireGIL _; + + if (result || result.error()._errorCode == Error::ErrorCode::InitialConnectFailedWithInProgress) { + OwnedPyObject _ = PyObject_CallFunctionObjArgs(callback, Py_None, nullptr); + } else { + OwnedPyObject exc = YMQException_createFromCoreError(state, &result.error()); + OwnedPyObject _ = PyObject_CallFunctionObjArgs(callback, *exc, nullptr); + } - self->socket->connectTo(std::string(address, addressLen), [=](auto r) mutable { - *result = std::move(r); - waiter.signal(); + Py_DECREF(callback); }); - - if (waiter.wait()) - CHECK_SIGNALS; } catch (...) { - PyEval_RestoreThread(_save); - PyErr_SetString(PyExc_RuntimeError, "Failed to connect synchronously"); - return nullptr; - } - - PyEval_RestoreThread(_save); - - if (!result && result->error()._errorCode != Error::ErrorCode::InitialConnectFailedWithInProgress) { - YMQException_setFromCoreError(state, &result->error()); + PyErr_SetString(PyExc_RuntimeError, "Failed to connect to address"); return nullptr; } @@ -386,7 +255,7 @@ static PyGetSetDef PyIOSocket_properties[] = { static PyMethodDef PyIOSocket_methods[] = { {"send", (PyCFunction)PyIOSocket_send, METH_VARARGS | METH_KEYWORDS, PyDoc_STR("Send data through the IOSocket")}, - {"recv", (PyCFunction)PyIOSocket_recv, METH_NOARGS, PyDoc_STR("Receive data from the IOSocket")}, + {"recv", (PyCFunction)PyIOSocket_recv, METH_VARARGS | METH_KEYWORDS, PyDoc_STR("Receive data from the IOSocket")}, {"bind", (PyCFunction)PyIOSocket_bind, METH_VARARGS | METH_KEYWORDS, @@ -395,19 +264,6 @@ static PyMethodDef PyIOSocket_methods[] = { (PyCFunction)PyIOSocket_connect, METH_VARARGS | METH_KEYWORDS, PyDoc_STR("Connect to a remote IOSocket")}, - {"send_sync", - (PyCFunction)PyIOSocket_send_sync, - METH_VARARGS | METH_KEYWORDS, - PyDoc_STR("Send data through the IOSocket synchronously")}, - {"recv_sync", (PyCFunction)PyIOSocket_recv_sync, METH_NOARGS, PyDoc_STR("Receive data from the IOSocket")}, - {"bind_sync", - (PyCFunction)PyIOSocket_bind_sync, - METH_VARARGS | METH_KEYWORDS, - PyDoc_STR("Bind to an address and listen for incoming connections")}, - {"connect_sync", - (PyCFunction)PyIOSocket_connect_sync, - METH_VARARGS | METH_KEYWORDS, - PyDoc_STR("Connect to a remote IOSocket")}, {nullptr, nullptr, 0, nullptr}, }; @@ -421,7 +277,7 @@ static PyType_Slot PyIOSocket_slots[] = { }; static PyType_Spec PyIOSocket_spec = { - .name = "ymq.IOSocket", + .name = "_ymq.BaseIOSocket", .basicsize = sizeof(PyIOSocket), .itemsize = 0, .flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_IMMUTABLETYPE | Py_TPFLAGS_DISALLOW_INSTANTIATION, diff --git a/scaler/io/ymq/pymod_ymq/message.h b/scaler/io/ymq/pymod_ymq/message.h index 52da66763..d8f7df5c9 100644 --- a/scaler/io/ymq/pymod_ymq/message.h +++ b/scaler/io/ymq/pymod_ymq/message.h @@ -5,7 +5,6 @@ // First-party #include "scaler/io/ymq/pymod_ymq/bytes.h" -#include "scaler/io/ymq/pymod_ymq/utils.h" #include "scaler/io/ymq/pymod_ymq/ymq.h" struct PyMessage { @@ -92,7 +91,7 @@ static PyType_Slot PyMessage_slots[] = { }; static PyType_Spec PyMessage_spec = { - .name = "ymq.Message", + .name = "_ymq.Message", .basicsize = sizeof(PyMessage), .itemsize = 0, .flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_IMMUTABLETYPE, diff --git a/scaler/io/ymq/pymod_ymq/python.h b/scaler/io/ymq/pymod_ymq/python.h index d627d92c3..4e0df4d2b 100644 --- a/scaler/io/ymq/pymod_ymq/python.h +++ b/scaler/io/ymq/pymod_ymq/python.h @@ -4,8 +4,6 @@ #include #include -#include "scaler/io/ymq/pymod_ymq/utils.h" - #if PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION == 8 static inline PyObject* Py_NewRef(PyObject* obj) { @@ -77,7 +75,7 @@ class OwnedPyObject { // steals a reference OwnedPyObject(T* ptr): _ptr(ptr) {} - OwnedPyObject(const OwnedPyObject& other) { this->_ptr = Py_XNewRef(other._ptr); } + OwnedPyObject(const OwnedPyObject& other) { this->_ptr = (T*)Py_XNewRef((PyObject*)other._ptr); } OwnedPyObject(OwnedPyObject&& other) noexcept: _ptr(other._ptr) { other._ptr = nullptr; } OwnedPyObject& operator=(const OwnedPyObject& other) { @@ -85,7 +83,7 @@ class OwnedPyObject { return *this; this->free(); - this->_ptr = Py_XNewRef(other._ptr); + this->_ptr = (T*)Py_XNewRef((PyObject*)other._ptr); return *this; } OwnedPyObject& operator=(OwnedPyObject&& other) noexcept diff --git a/scaler/io/ymq/pymod_ymq/utils.h b/scaler/io/ymq/pymod_ymq/utils.h deleted file mode 100644 index 522f819e8..000000000 --- a/scaler/io/ymq/pymod_ymq/utils.h +++ /dev/null @@ -1,110 +0,0 @@ -#pragma once - -// Python -#include - -#include "scaler/io/ymq/pymod_ymq/python.h" - -// C++ -#include - -// C -#include -#include - -#include - -// First-party -#include "scaler/io/ymq/common.h" -#include "scaler/io/ymq/pymod_ymq/ymq.h" - -class Waiter { -public: - Waiter(int wakeFd): _waiter(std::shared_ptr(new int, &destroy_efd)), _wakeFd(wakeFd) - { - auto fd = eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK); - if (fd < 0) - throw std::runtime_error("failed to create eventfd"); - - *_waiter = fd; - } - - Waiter(const Waiter& other): _waiter(other._waiter), _wakeFd(other._wakeFd) {} - Waiter(Waiter&& other) noexcept: _waiter(std::move(other._waiter)), _wakeFd(other._wakeFd) - { - other._wakeFd = -1; // invalidate the moved-from object - } - - Waiter& operator=(const Waiter& other) - { - if (this == &other) - return *this; - - this->_waiter = other._waiter; - this->_wakeFd = other._wakeFd; - return *this; - } - - Waiter& operator=(Waiter&& other) noexcept - { - if (this == &other) - return *this; - - this->_waiter = std::move(other._waiter); - this->_wakeFd = other._wakeFd; - other._wakeFd = -1; // invalidate the moved-from object - return *this; - } - - void signal() - { - if (eventfd_write(*_waiter, 1) < 0) { - std::println(stderr, "Failed to signal waiter: {}", std::strerror(errno)); - } - } - - // true -> error - // false -> ok - bool wait() - { - pollfd pfds[2] = { - { - .fd = *_waiter, - .events = POLLIN, - .revents = 0, - }, - { - .fd = _wakeFd, - .events = POLLIN, - .revents = 0, - }}; - - for (;;) { - int ready = poll(pfds, 2, -1); - if (ready < 0) { - if (errno == EINTR) - continue; - throw std::runtime_error("poll failed"); - } - - if (pfds[0].revents & POLLIN) - return false; // we got a message - - if (pfds[1].revents & POLLIN) - return true; // signal received - } - } - -private: - std::shared_ptr _waiter; - int _wakeFd; - - static void destroy_efd(int* fd) - { - if (!fd) - return; - - close(*fd); - delete fd; - } -}; diff --git a/scaler/io/ymq/pymod_ymq/ymq.cpp b/scaler/io/ymq/pymod_ymq/ymq.cpp index 8444c21a2..768110472 100644 --- a/scaler/io/ymq/pymod_ymq/ymq.cpp +++ b/scaler/io/ymq/pymod_ymq/ymq.cpp @@ -15,7 +15,7 @@ inline void ymqUnrecoverableError(scaler::ymq::Error e) std::exit(EXIT_FAILURE); } -PyMODINIT_FUNC PyInit_ymq(void) +PyMODINIT_FUNC PyInit__ymq(void) { unrecoverableErrorFunctionHookPtr = ymqUnrecoverableError; diff --git a/scaler/io/ymq/pymod_ymq/ymq.h b/scaler/io/ymq/pymod_ymq/ymq.h index 7442a7b02..59c5cef60 100644 --- a/scaler/io/ymq/pymod_ymq/ymq.h +++ b/scaler/io/ymq/pymod_ymq/ymq.h @@ -9,20 +9,14 @@ // C++ #include -#include -#include #include #include #include // First-party #include "scaler/io/ymq/error.h" -#include "scaler/io/ymq/pymod_ymq/utils.h" struct YMQState { - int wakeupfd_wr; - int wakeupfd_rd; - OwnedPyObject<> enumModule; // Reference to the enum module OwnedPyObject<> asyncioModule; // Reference to the asyncio module @@ -33,79 +27,8 @@ struct YMQState { OwnedPyObject<> PyIOSocketType; // Reference to the IOSocket type OwnedPyObject<> PyIOContextType; // Reference to the IOContext type OwnedPyObject<> PyExceptionType; // Reference to the Exception type - OwnedPyObject<> PyInterruptedExceptionType; // Reference to the YMQInterruptedException type - OwnedPyObject<> PyAwaitableType; // Reference to the Awaitable type }; -#define CHECK_SIGNALS \ - do { \ - PyEval_RestoreThread(_save); \ - if (PyErr_CheckSignals() >= 0) \ - PyErr_SetString( \ - *state->PyInterruptedExceptionType, "A synchronous YMQ operation was interrupted by a signal"); \ - return (PyObject*)nullptr; \ - } while (0); - -static bool future_do_(PyObject* future_, const std::function()>& fn) -{ - // this is an owned reference to the future created in `async_wrapper()` - OwnedPyObject future(future_); - OwnedPyObject loop = PyObject_CallMethod(*future, "get_loop", nullptr); - if (!loop) - return true; - - // if future is already done, no need to call the method - OwnedPyObject result1 = PyObject_CallMethod(*future, "done", nullptr); - if (*result1 == Py_True) - return false; - - const char* method_name = nullptr; - OwnedPyObject arg {}; - - if (auto result = fn()) { - method_name = "set_result"; - arg = *result; - } else { - method_name = "set_exception"; - arg = result.error(); - } - - OwnedPyObject method = PyObject_GetAttrString(*future, method_name); - if (!method) - return true; - - OwnedPyObject obj = PyObject_GetAttrString(*loop, "call_soon_threadsafe"); - - // auto result = PyObject_CallMethod(loop, "call_soon_threadsafe", "OO", method, fn()); - OwnedPyObject result2 = PyObject_CallFunctionObjArgs(*obj, *method, *arg, nullptr); - return !result2; -} - -// this function must be called from a C++ thread -// this function will lock the GIL, call `fn()` and use its return value to set the future's result/exception -static void future_do(PyObject* future, const std::function()>& fn) -{ - PyGILState_STATE gstate = PyGILState_Ensure(); - // begin python critical section - - auto error = future_do_(future, fn); - if (error) - PyErr_WriteUnraisable(future); - - // end python critical section - PyGILState_Release(gstate); -} - -static void future_set_result(PyObject* future, std::function()> fn) -{ - return future_do(future, fn); -} - -static void future_raise_exception(PyObject* future, std::function fn) -{ - return future_do(future, [=] { return std::unexpected {fn()}; }); -} - static YMQState* YMQStateFromSelf(PyObject* self) { // replace with PyType_GetModuleByDef(Py_TYPE(self), &YMQ_module) in a newer Python version @@ -151,8 +74,13 @@ std::expected YMQ_GetRaisedException() #endif } +void completeCallbackWithRaisedException(PyObject* callback) +{ + auto result = YMQ_GetRaisedException(); + OwnedPyObject _ =PyObject_CallFunctionObjArgs(callback, result.value_or(result.error())); +} + // First-Party -#include "scaler/io/ymq/pymod_ymq/async.h" #include "scaler/io/ymq/pymod_ymq/bytes.h" #include "scaler/io/ymq/pymod_ymq/exception.h" #include "scaler/io/ymq/pymod_ymq/io_context.h" @@ -173,22 +101,10 @@ static void YMQ_free(YMQState* state) state->PyIOSocketType.~OwnedPyObject(); state->PyIOContextType.~OwnedPyObject(); state->PyExceptionType.~OwnedPyObject(); - state->PyInterruptedExceptionType.~OwnedPyObject(); - state->PyAwaitableType.~OwnedPyObject(); } catch (...) { PyErr_SetString(PyExc_RuntimeError, "Failed to free YMQState"); PyErr_WriteUnraisable(nullptr); } - - if (close(state->wakeupfd_wr) < 0) { - PyErr_SetString(PyExc_RuntimeError, "Failed to close waitfd_wr"); - PyErr_WriteUnraisable(nullptr); - } - - if (close(state->wakeupfd_rd) < 0) { - PyErr_SetString(PyExc_RuntimeError, "Failed to close waitfd_rd"); - PyErr_WriteUnraisable(nullptr); - } } static int YMQ_createIntEnum( @@ -322,21 +238,6 @@ static int YMQ_createErrorCodeEnum(PyObject* pyModule, YMQState* state) } } -static int YMQ_createInterruptedException(PyObject* pyModule, OwnedPyObject<>* storage) -{ - *storage = PyErr_NewExceptionWithDoc( - "ymq.YMQInterruptedException", - "Raised when a synchronous method is interrupted by a signal", - PyExc_Exception, - nullptr); - - if (!*storage) - return -1; - if (PyModule_AddObjectRef(pyModule, "YMQInterruptedException", **storage) < 0) - return -1; - return 0; -} - // internal convenience function to create a type and add it to the module static int YMQ_createType( // the module object @@ -380,36 +281,12 @@ static int YMQ_createType( return 0; } -static int YMQ_setupWakeupFd(YMQState* state) -{ - int pipefd[2]; - if (pipe2(pipefd, O_NONBLOCK | O_CLOEXEC) < 0) { - PyErr_SetString(PyExc_RuntimeError, "Failed to create pipe for wakeup fd"); - return -1; - } - - state->wakeupfd_rd = pipefd[0]; - state->wakeupfd_wr = pipefd[1]; - - OwnedPyObject signalModule = PyImport_ImportModule("signal"); - if (!signalModule) - return -1; - - OwnedPyObject result = PyObject_CallMethod(*signalModule, "set_wakeup_fd", "i", state->wakeupfd_wr); - if (!result) - return -1; - return 0; -} - static int YMQ_exec(PyObject* pyModule) { auto state = (YMQState*)PyModule_GetState(pyModule); if (!state) return -1; - if (YMQ_setupWakeupFd(state) < 0) - return -1; - state->enumModule = PyImport_ImportModule("enum"); if (!state->enumModule) return -1; @@ -443,10 +320,10 @@ static int YMQ_exec(PyObject* pyModule) if (YMQ_createType(pyModule, &state->PyMessageType, &PyMessage_spec, "Message") < 0) return -1; - if (YMQ_createType(pyModule, &state->PyIOSocketType, &PyIOSocket_spec, "IOSocket") < 0) + if (YMQ_createType(pyModule, &state->PyIOSocketType, &PyIOSocket_spec, "BaseIOSocket") < 0) return -1; - if (YMQ_createType(pyModule, &state->PyIOContextType, &PyIOContext_spec, "IOContext") < 0) + if (YMQ_createType(pyModule, &state->PyIOContextType, &PyIOContext_spec, "BaseIOContext") < 0) return -1; PyObject* exceptionBases = PyTuple_Pack(1, PyExc_Exception); @@ -460,12 +337,6 @@ static int YMQ_exec(PyObject* pyModule) } Py_DECREF(exceptionBases); - if (YMQ_createInterruptedException(pyModule, &state->PyInterruptedExceptionType) < 0) - return -1; - - if (YMQ_createType(pyModule, &state->PyAwaitableType, &Awaitable_spec, "Awaitable", false) < 0) - return -1; - return 0; } @@ -476,7 +347,7 @@ static PyModuleDef_Slot YMQ_slots[] = { static PyModuleDef YMQ_module = { .m_base = PyModuleDef_HEAD_INIT, - .m_name = "ymq", + .m_name = "_ymq", .m_doc = PyDoc_STR("YMQ Python bindings"), .m_size = sizeof(YMQState), .m_slots = YMQ_slots, diff --git a/scaler/io/ymq/ymq.py b/scaler/io/ymq/ymq.py new file mode 100644 index 000000000..0028016de --- /dev/null +++ b/scaler/io/ymq/ymq.py @@ -0,0 +1,117 @@ +# This file wraps the interface exported by the C implementation of the module +# and provides a more ergonomic interface supporting both asynchronous and synchronous execution + +__all__ = ["IOSocket", "IOContext", "Message", "IOSocketType", "YMQException", "Bytes", "ErrorCode"] + +import asyncio +import concurrent.futures +from typing import Callable, Concatenate, Optional, ParamSpec, TypeVar, Union + +from scaler.io.ymq._ymq import BaseIOContext, BaseIOSocket, Bytes, ErrorCode, IOSocketType, Message, YMQException + + +class IOSocket: + _base: BaseIOSocket + + def __init__(self, base: BaseIOSocket) -> None: + self._base = base + + @property + def socket_type(self) -> IOSocketType: + return self._base.socket_type + + @property + def identity(self) -> str: + return self._base.identity + + async def bind(self, address: str) -> None: + """Bind the socket to an address and listen for incoming connections""" + await call_async(self._base.bind, address) + + def bind_sync(self, address: str, /, timeout: Optional[float] = None) -> None: + """Bind the socket to an address and listen for incoming connections""" + call_sync(self._base.bind, address, timeout=timeout) + + async def connect(self, address: str) -> None: + """Connect to a remote socket""" + await call_async(self._base.connect, address) + + def connect_sync(self, address: str, /, timeout: Optional[float] = None) -> None: + """Connect to a remote socket""" + call_sync(self._base.connect, address, timeout=timeout) + + async def send(self, message: Message) -> None: + """Send a message to one of the socket's peers""" + await call_async(self._base.send, message) + + def send_sync(self, message: Message, /, timeout: Optional[float] = None) -> None: + """Send a message to one of the socket's peers""" + call_sync(self._base.send, message, timeout=timeout) + + async def recv(self) -> Message: + """Receive a message from one of the socket's peers""" + return await call_async(self._base.recv) + + def recv_sync(self, /, timeout: Optional[float] = None) -> Message: + """Receive a message from one of the socket's peers""" + return call_sync(self._base.recv, timeout=timeout) + + +class IOContext: + _base: BaseIOContext + + def __init__(self, num_threads: int = 1) -> None: + self._base = BaseIOContext(num_threads) + + async def createIOSocket(self, identity: str, socket_type: IOSocketType) -> IOSocket: + """Create an io socket with an identity and socket type""" + return IOSocket(await call_async(self._base.createIOSocket, identity, socket_type)) + + def createIOSocket_sync(self, identity: str, socket_type: IOSocketType) -> IOSocket: + """Create an io socket with an identity and socket type""" + return IOSocket(call_sync(self._base.createIOSocket, identity, socket_type)) + + +P = ParamSpec("P") +T = TypeVar("T") + + +async def call_async( + func: Callable[Concatenate[Callable[[Union[T, Exception]], None], P], None], *args: P.args, **kwargs: P.kwargs +) -> T: + future = asyncio.get_event_loop().create_future() + + def callback(result: Union[T, Exception]): + if future.done(): + return + + loop = future.get_loop() + + if isinstance(result, Exception): + loop.call_soon_threadsafe(future.set_exception, result) + else: + loop.call_soon_threadsafe(future.set_result, result) + + func(callback, *args, **kwargs) + return await future + + +def call_sync( # type: ignore[valid-type] + func: Callable[Concatenate[Callable[[Union[T, Exception]], None], P], None], + *args: P.args, + timeout: Optional[float] = None, + **kwargs: P.kwargs, +) -> T: + future: concurrent.futures.Future = concurrent.futures.Future() + + def callback(result: Union[T, Exception]): + if future.done(): + return + + if isinstance(result, Exception): + future.set_exception(result) + else: + future.set_result(result) + + func(callback, *args, **kwargs) + return future.result(timeout)