diff --git a/scaler/client/agent/client_agent.py b/scaler/client/agent/client_agent.py index c3921bde2..617ba184a 100644 --- a/scaler/client/agent/client_agent.py +++ b/scaler/client/agent/client_agent.py @@ -5,7 +5,7 @@ from concurrent.futures import Future from typing import Optional -import zmq.asyncio +from scaler.io.ymq import ymq from scaler.client.agent.disconnect_manager import ClientDisconnectManager from scaler.client.agent.future_manager import ClientFutureManager @@ -13,7 +13,7 @@ from scaler.client.agent.object_manager import ClientObjectManager from scaler.client.agent.task_manager import ClientTaskManager from scaler.client.serializer.mixins import Serializer -from scaler.io.async_connector import ZMQAsyncConnector +from scaler.io.ymq_async_connector import YMQAsyncConnector from scaler.io.mixins import AsyncConnector from scaler.protocol.python.common import ObjectStorageAddress from scaler.protocol.python.message import ( @@ -41,7 +41,7 @@ def __init__( identity: ClientID, client_agent_address: ZMQConfig, scheduler_address: ZMQConfig, - context: zmq.Context, + context: ymq.IOContext, future_manager: ClientFutureManager, stop_event: threading.Event, timeout_seconds: int, @@ -63,19 +63,19 @@ def __init__( self._future_manager = future_manager - self._connector_internal: AsyncConnector = ZMQAsyncConnector( - context=zmq.asyncio.Context.shadow(self._context), + self._connector_internal: AsyncConnector = YMQAsyncConnector( + context=self._context, name="client_agent_internal", - socket_type=zmq.PAIR, + socket_type=ymq.IOSocketType.Binder, bind_or_connect="bind", address=self._client_agent_address, callback=self.__on_receive_from_client, identity=None, ) - self._connector_external: AsyncConnector = ZMQAsyncConnector( - context=zmq.asyncio.Context.shadow(self._context), + self._connector_external: AsyncConnector = YMQAsyncConnector( + context=self._context, name="client_agent_external", - socket_type=zmq.DEALER, + socket_type=ymq.IOSocketType.Connector, 
address=self._scheduler_address, bind_or_connect="connect", callback=self.__on_receive_from_scheduler, diff --git a/scaler/client/client.py b/scaler/client/client.py index 2868df929..0f0f67a08 100644 --- a/scaler/client/client.py +++ b/scaler/client/client.py @@ -2,13 +2,13 @@ import functools import logging import threading +import random import uuid from collections import Counter from inspect import signature from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union -import zmq - +from scaler.io.ymq import ymq from scaler.client.agent.client_agent import ClientAgent from scaler.client.agent.future_manager import ClientFutureManager from scaler.client.future import ScalerFuture @@ -18,7 +18,7 @@ from scaler.client.serializer.mixins import Serializer from scaler.io.config import DEFAULT_CLIENT_TIMEOUT_SECONDS, DEFAULT_HEARTBEAT_INTERVAL_SECONDS from scaler.io.mixins import SyncConnector, SyncObjectStorageConnector -from scaler.io.sync_connector import ZMQSyncConnector +from scaler.io.ymq_sync_connector import YMQSyncConnector from scaler.io.sync_object_storage_connector import PySyncObjectStorageConnector from scaler.protocol.python.message import ClientDisconnect, ClientShutdownResponse, GraphTask, Task from scaler.utility.exceptions import ClientQuitException, MissingObjects @@ -89,15 +89,16 @@ def __initialize__( self._stream_output = stream_output self._identity = ClientID.generate_client_id() - self._client_agent_address = ZMQConfig(ZMQType.inproc, host=f"scaler_client_{uuid.uuid4().hex}") + # self._client_agent_address = ZMQConfig(ZMQType.inproc, host=f"scaler_client_{uuid.uuid4().hex}") + self._client_agent_address = ZMQConfig(ZMQType.tcp, host="127.0.0.1", port=random.randint(20000, 30000)) self._scheduler_address = ZMQConfig.from_string(address) self._timeout_seconds = timeout_seconds self._heartbeat_interval_seconds = heartbeat_interval_seconds self._stop_event = threading.Event() - self._context = zmq.Context() - 
self._connector_agent: SyncConnector = ZMQSyncConnector( - context=self._context, socket_type=zmq.PAIR, address=self._client_agent_address, identity=self._identity + self._context = ymq.IOContext() + self._connector_agent: SyncConnector = YMQSyncConnector( + context=self._context, socket_type=ymq.IOSocketType.Connector, address=self._client_agent_address, identity=self._identity.extend("|agent") ) self._future_manager = ClientFutureManager(self._serializer) diff --git a/scaler/entry_points/top.py b/scaler/entry_points/top.py index 44b77773c..9a4cf4de2 100644 --- a/scaler/entry_points/top.py +++ b/scaler/entry_points/top.py @@ -3,7 +3,7 @@ import functools from typing import Dict, List, Literal, Union -from scaler.io.sync_subscriber import ZMQSyncSubscriber +from scaler.io.ymq_sync_subscriber import YMQSyncSubscriber from scaler.protocol.python.message import StateScheduler from scaler.protocol.python.mixins import Message from scaler.utility.formatter import ( @@ -51,7 +51,7 @@ def poke(screen, args): screen.nodelay(1) try: - subscriber = ZMQSyncSubscriber( + subscriber = YMQSyncSubscriber( address=ZMQConfig.from_string(args.address), callback=functools.partial(show_status, screen=screen), topic=b"", diff --git a/scaler/entry_points/worker_adapter_native.py b/scaler/entry_points/worker_adapter_native.py index 07580bee5..7a67fa860 100644 --- a/scaler/entry_points/worker_adapter_native.py +++ b/scaler/entry_points/worker_adapter_native.py @@ -27,7 +27,7 @@ def get_args(): # Server configuration parser.add_argument( - "--host", type=str, default="localhost", help="host address for the native worker adapter HTTP server" + "--host", type=str, default="127.0.0.1", help="host address for the native worker adapter HTTP server" ) parser.add_argument("--port", "-p", type=int, help="port for the native worker adapter HTTP server") diff --git a/scaler/entry_points/worker_adapter_symphony.py b/scaler/entry_points/worker_adapter_symphony.py index 6b9db7903..a40acd60d 100644 --- 
a/scaler/entry_points/worker_adapter_symphony.py +++ b/scaler/entry_points/worker_adapter_symphony.py @@ -24,7 +24,7 @@ def get_args(): # Server configuration parser.add_argument( - "--host", type=str, default="localhost", help="host address for the native worker adapter HTTP server" + "--host", type=str, default="127.0.0.1", help="host address for the native worker adapter HTTP server" ) parser.add_argument("--port", "-p", type=int, required=True, help="port for the native worker adapter HTTP server") diff --git a/scaler/io/ymq/pymod_ymq/io_context.h b/scaler/io/ymq/pymod_ymq/io_context.h index deb63003e..b571b572e 100644 --- a/scaler/io/ymq/pymod_ymq/io_context.h +++ b/scaler/io/ymq/pymod_ymq/io_context.h @@ -157,8 +157,7 @@ static PyObject* PyIOContext_createIOSocket_sync(PyIOContext* self, PyObject* ar waiter.signal(); }); - if (waiter.wait()) - CHECK_SIGNALS; + WAIT(waiter); } catch (...) { PyEval_RestoreThread(_save); PyErr_SetString(PyExc_RuntimeError, "Failed to create io socket synchronously"); diff --git a/scaler/io/ymq/pymod_ymq/io_socket.h b/scaler/io/ymq/pymod_ymq/io_socket.h index 251a6a461..1244c565f 100644 --- a/scaler/io/ymq/pymod_ymq/io_socket.h +++ b/scaler/io/ymq/pymod_ymq/io_socket.h @@ -7,10 +7,12 @@ #include #include #include +#include #include #include // C +#include #include #include #include @@ -40,6 +42,7 @@ extern "C" { static void PyIOSocket_dealloc(PyIOSocket* self) { try { + self->ioContext->requestIOSocketStop(self->socket); self->ioContext->removeIOSocket(self->socket); self->ioContext.~shared_ptr(); self->socket.~shared_ptr(); @@ -90,8 +93,9 @@ static PyObject* PyIOSocket_send_sync(PyIOSocket* self, PyObject* args, PyObject // borrowed reference PyMessage* message = nullptr; - const char* kwlist[] = {"message", nullptr}; - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O", (char**)kwlist, &message)) + int timeout_secs = -1; + const char* kwlist[] = {"message", "timeout_secs", nullptr}; + if (!PyArg_ParseTupleAndKeywords(args, 
kwargs, "O|i", (char**)kwlist, &message, &timeout_secs)) return nullptr; Bytes address = message->address.is_none() ? Bytes() : std::move(message->address->bytes); @@ -101,15 +105,14 @@ static PyObject* PyIOSocket_send_sync(PyIOSocket* self, PyObject* args, PyObject std::shared_ptr> result = std::make_shared>(); try { - Waiter waiter(state->wakeupfd_rd); + Waiter waiter(state->wakeupfd_rd, timeout_secs == -1 ? std::nullopt : std::optional {timeout_secs}); self->socket->sendMessage({.address = std::move(address), .payload = std::move(payload)}, [=](auto r) mutable { *result = std::move(r); waiter.signal(); }); - if (waiter.wait()) - CHECK_SIGNALS; + WAIT(waiter); } catch (...) { PyEval_RestoreThread(_save); PyErr_SetString(PyExc_RuntimeError, "Failed to send synchronously"); @@ -164,25 +167,29 @@ static PyObject* PyIOSocket_recv(PyIOSocket* self, PyObject* args) }); } -static PyObject* PyIOSocket_recv_sync(PyIOSocket* self, PyObject* args) +static PyObject* PyIOSocket_recv_sync(PyIOSocket* self, PyObject* args, PyObject* kwargs) { auto state = YMQStateFromSelf((PyObject*)self); if (!state) return nullptr; + int timeout_secs = -1; + const char* kwlist[] = {"timeout_secs", nullptr}; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|i", (char**)kwlist, &timeout_secs)) + return nullptr; + PyThreadState* _save = PyEval_SaveThread(); std::shared_ptr> result = std::make_shared>(); try { - Waiter waiter(state->wakeupfd_rd); + Waiter waiter(state->wakeupfd_rd, timeout_secs == -1 ? std::nullopt : std::optional {timeout_secs}); self->socket->recvMessage([=](auto r) mutable { *result = std::move(r); waiter.signal(); }); - if (waiter.wait()) - CHECK_SIGNALS; + WAIT(waiter); } catch (...) 
{ PyEval_RestoreThread(_save); PyErr_SetString(PyExc_RuntimeError, "Failed to recv synchronously"); @@ -252,23 +259,23 @@ static PyObject* PyIOSocket_bind_sync(PyIOSocket* self, PyObject* args, PyObject const char* address = nullptr; Py_ssize_t addressLen = 0; - const char* kwlist[] = {"address", nullptr}; - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "s#", (char**)kwlist, &address, &addressLen)) + int timeout_secs = -1; + const char* kwlist[] = {"address", "timeout_secs", nullptr}; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "s#|i", (char**)kwlist, &address, &addressLen, &timeout_secs)) return nullptr; PyThreadState* _save = PyEval_SaveThread(); auto result = std::make_shared>(); try { - Waiter waiter(state->wakeupfd_rd); + Waiter waiter(state->wakeupfd_rd, timeout_secs == -1 ? std::nullopt : std::optional {timeout_secs}); self->socket->bindTo(std::string(address, addressLen), [=](auto r) mutable { *result = std::move(r); waiter.signal(); }); - if (waiter.wait()) - CHECK_SIGNALS; + WAIT(waiter); } catch (...) { PyEval_RestoreThread(_save); PyErr_SetString(PyExc_RuntimeError, "Failed to bind synchronously"); @@ -319,23 +326,23 @@ static PyObject* PyIOSocket_connect_sync(PyIOSocket* self, PyObject* args, PyObj const char* address = nullptr; Py_ssize_t addressLen = 0; - const char* kwlist[] = {"address", nullptr}; - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "s#", (char**)kwlist, &address, &addressLen)) + int timeout_secs = -1; + const char* kwlist[] = {"address", "timeout_secs", nullptr}; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "s#|i", (char**)kwlist, &address, &addressLen, &timeout_secs)) return nullptr; PyThreadState* _save = PyEval_SaveThread(); std::shared_ptr> result = std::make_shared>(); try { - Waiter waiter(state->wakeupfd_rd); + Waiter waiter(state->wakeupfd_rd, timeout_secs == -1 ? 
std::nullopt : std::optional {timeout_secs}); self->socket->connectTo(std::string(address, addressLen), [=](auto r) mutable { *result = std::move(r); waiter.signal(); }); - if (waiter.wait()) - CHECK_SIGNALS; + WAIT(waiter); } catch (...) { PyEval_RestoreThread(_save); PyErr_SetString(PyExc_RuntimeError, "Failed to connect synchronously"); @@ -399,7 +406,10 @@ static PyMethodDef PyIOSocket_methods[] = { (PyCFunction)PyIOSocket_send_sync, METH_VARARGS | METH_KEYWORDS, PyDoc_STR("Send data through the IOSocket synchronously")}, - {"recv_sync", (PyCFunction)PyIOSocket_recv_sync, METH_NOARGS, PyDoc_STR("Receive data from the IOSocket")}, + {"recv_sync", + (PyCFunction)PyIOSocket_recv_sync, + METH_VARARGS | METH_KEYWORDS, + PyDoc_STR("Receive data from the IOSocket")}, {"bind_sync", (PyCFunction)PyIOSocket_bind_sync, METH_VARARGS | METH_KEYWORDS, diff --git a/scaler/io/ymq/pymod_ymq/utils.h b/scaler/io/ymq/pymod_ymq/utils.h index 522f819e8..a54e4fae9 100644 --- a/scaler/io/ymq/pymod_ymq/utils.h +++ b/scaler/io/ymq/pymod_ymq/utils.h @@ -1,6 +1,7 @@ #pragma once // Python +#include #include #include "scaler/io/ymq/pymod_ymq/python.h" @@ -11,6 +12,7 @@ // C #include #include +#include #include @@ -18,21 +20,37 @@ #include "scaler/io/ymq/common.h" #include "scaler/io/ymq/pymod_ymq/ymq.h" +enum class WaitResult { + Ok, + Signal, + Timeout, +}; + class Waiter { public: - Waiter(int wakeFd): _waiter(std::shared_ptr(new int, &destroy_efd)), _wakeFd(wakeFd) + Waiter(int wakeFd, std::optional timeout_secs = std::nullopt) + : _timeout_secs(timeout_secs) + , _timer_fd(std::shared_ptr(new int, &destroy_fd)) + , _waiter(std::shared_ptr(new int, &destroy_fd)) + , _wake_fd(wakeFd) { - auto fd = eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK); - if (fd < 0) + auto efd = eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK); + if (efd < 0) throw std::runtime_error("failed to create eventfd"); - *_waiter = fd; + *_waiter = efd; + + auto tfd = timerfd_create(CLOCK_MONOTONIC, TFD_CLOEXEC | TFD_NONBLOCK); + if (tfd 
< 0) + throw std::runtime_error("failed to create timerfd"); + + *_timer_fd = tfd; } - Waiter(const Waiter& other): _waiter(other._waiter), _wakeFd(other._wakeFd) {} - Waiter(Waiter&& other) noexcept: _waiter(std::move(other._waiter)), _wakeFd(other._wakeFd) + Waiter(const Waiter& other): _waiter(other._waiter), _wake_fd(other._wake_fd) {} + Waiter(Waiter&& other) noexcept: _waiter(std::move(other._waiter)), _wake_fd(other._wake_fd) { - other._wakeFd = -1; // invalidate the moved-from object + other._wake_fd = -1; // invalidate the moved-from object } Waiter& operator=(const Waiter& other) @@ -40,8 +58,8 @@ class Waiter { if (this == &other) return *this; - this->_waiter = other._waiter; - this->_wakeFd = other._wakeFd; + this->_waiter = other._waiter; + this->_wake_fd = other._wake_fd; return *this; } @@ -50,9 +68,9 @@ class Waiter { if (this == &other) return *this; - this->_waiter = std::move(other._waiter); - this->_wakeFd = other._wakeFd; - other._wakeFd = -1; // invalidate the moved-from object + this->_waiter = std::move(other._waiter); + this->_wake_fd = other._wake_fd; + other._wake_fd = -1; // invalidate the moved-from object return *this; } @@ -63,24 +81,37 @@ class Waiter { } } - // true -> error - // false -> ok - bool wait() + WaitResult wait() { - pollfd pfds[2] = { + pollfd pfds[3] = { { .fd = *_waiter, .events = POLLIN, .revents = 0, }, { - .fd = _wakeFd, + .fd = _wake_fd, + .events = POLLIN, + .revents = 0, + }, + { + .fd = *_timer_fd, .events = POLLIN, .revents = 0, }}; + if (_timeout_secs) { + itimerspec new_value { + .it_interval = {0, 0}, + .it_value = {*_timeout_secs, 0}, + }; + + if (timerfd_settime(*_timer_fd, 0, &new_value, nullptr) < 0) + throw std::runtime_error("failed to set timerfd"); + } + for (;;) { - int ready = poll(pfds, 2, -1); + int ready = poll(pfds, 3, -1); if (ready < 0) { if (errno == EINTR) continue; @@ -88,18 +119,23 @@ class Waiter { } if (pfds[0].revents & POLLIN) - return false; // we got a message + return 
WaitResult::Ok; // we got a message if (pfds[1].revents & POLLIN) - return true; // signal received + return WaitResult::Signal; // signal received + + if (pfds[2].revents & POLLIN) + return WaitResult::Timeout; // timeout } } private: + std::optional _timeout_secs; + std::shared_ptr _timer_fd; std::shared_ptr _waiter; - int _wakeFd; + int _wake_fd; - static void destroy_efd(int* fd) + static void destroy_fd(int* fd) { if (!fd) return; diff --git a/scaler/io/ymq/pymod_ymq/ymq.h b/scaler/io/ymq/pymod_ymq/ymq.h index f25a7c079..7397bc524 100644 --- a/scaler/io/ymq/pymod_ymq/ymq.h +++ b/scaler/io/ymq/pymod_ymq/ymq.h @@ -34,16 +34,24 @@ struct YMQState { OwnedPyObject<> PyIOContextType; // Reference to the IOContext type OwnedPyObject<> PyExceptionType; // Reference to the Exception type OwnedPyObject<> PyInterruptedExceptionType; // Reference to the YMQInterruptedException type + OwnedPyObject<> PyTimeoutExceptionType; // Reference to the YMQTimeoutException type OwnedPyObject<> PyAwaitableType; // Reference to the Awaitable type }; -#define CHECK_SIGNALS \ - do { \ - PyEval_RestoreThread(_save); \ - if (PyErr_CheckSignals() >= 0) \ - PyErr_SetString( \ - *state->PyInterruptedExceptionType, "A synchronous YMQ operation was interrupted by a signal"); \ - return (PyObject*)nullptr; \ +#define WAIT(waiter) \ + do { \ + auto result = waiter.wait(); \ + if (result == WaitResult::Signal) { \ + PyEval_RestoreThread(_save); \ + if (PyErr_CheckSignals() >= 0) \ + PyErr_SetString( \ + *state->PyInterruptedExceptionType, "A synchronous YMQ operation was interrupted by a signal"); \ + return (PyObject*)nullptr; \ + } else if (result == WaitResult::Timeout) { \ + PyEval_RestoreThread(_save); \ + PyErr_SetString(*state->PyTimeoutExceptionType, "A synchronous YMQ operation timed out"); \ + return (PyObject*)nullptr; \ + } \ } while (0); static bool future_do_(PyObject* future_, const std::function()>& fn) @@ -174,6 +182,7 @@ static void YMQ_free(YMQState* state) 
state->PyIOContextType.~OwnedPyObject(); state->PyExceptionType.~OwnedPyObject(); state->PyInterruptedExceptionType.~OwnedPyObject(); + state->PyTimeoutExceptionType.~OwnedPyObject(); state->PyAwaitableType.~OwnedPyObject(); } catch (...) { PyErr_SetString(PyExc_RuntimeError, "Failed to free YMQState"); @@ -335,6 +344,18 @@ static int YMQ_createInterruptedException(PyObject* pyModule, OwnedPyObject<>* s return 0; } +static int YMQ_createTimeoutException(PyObject* pyModule, OwnedPyObject<>* storage) +{ + *storage = PyErr_NewExceptionWithDoc( + "ymq.YMQTimeoutException", "Raised when a synchronous method times out", PyExc_Exception, nullptr); + + if (!*storage) + return -1; + if (PyModule_AddObjectRef(pyModule, "YMQTimeoutException", **storage) < 0) + return -1; + return 0; +} + // internal convenience function to create a type and add it to the module static int YMQ_createType( // the module object @@ -461,6 +482,9 @@ static int YMQ_exec(PyObject* pyModule) if (YMQ_createInterruptedException(pyModule, &state->PyInterruptedExceptionType) < 0) return -1; + if (YMQ_createTimeoutException(pyModule, &state->PyTimeoutExceptionType) < 0) + return -1; + if (YMQ_createType(pyModule, &state->PyAwaitableType, &Awaitable_spec, "Awaitable", false) < 0) return -1; diff --git a/scaler/io/ymq/ymq.pyi b/scaler/io/ymq/ymq.pyi index d436fd136..d675e1daa 100644 --- a/scaler/io/ymq/ymq.pyi +++ b/scaler/io/ymq/ymq.pyi @@ -67,16 +67,16 @@ class IOSocket: async def connect(self, address: str) -> None: """Connect to a remote socket""" - def send_sync(self, message: Message) -> None: + def send_sync(self, message: Message, timeout_secs: int = -1) -> None: """Send a message to one of the socket's peers synchronously""" - def recv_sync(self) -> Message: + def recv_sync(self, timeout_secs: int = -1) -> Message: """Receive a message from one of the socket's peers synchronously""" - def bind_sync(self, address: str) -> None: + def bind_sync(self, address: str, timeout_secs: int = -1) -> None: 
"""Bind the socket to an address and listen for incoming connections synchronously""" - def connect_sync(self, address: str) -> None: + def connect_sync(self, address: str, timeout_secs: int = -1) -> None: """Connect to a remote socket synchronously""" class ErrorCode(IntEnum): @@ -109,3 +109,6 @@ class YMQException(Exception): class YMQInterruptedException(YMQException): def __init__(self) -> None: ... + +class YMQTimeoutException(YMQException): + def __init__(self) -> None: ... diff --git a/scaler/io/ymq_async_binder.py b/scaler/io/ymq_async_binder.py new file mode 100644 index 000000000..a97a0ee72 --- /dev/null +++ b/scaler/io/ymq_async_binder.py @@ -0,0 +1,74 @@ +import logging +import os +import uuid +from collections import defaultdict +from typing import Awaitable, Callable, Dict, List, Optional + +from scaler.io.ymq import ymq + +from scaler.io.mixins import AsyncBinder +from scaler.io.utility import deserialize, serialize +from scaler.protocol.python.mixins import Message +from scaler.protocol.python.status import BinderStatus +from scaler.utility.zmq_config import ZMQConfig, ZMQType + + +class YMQAsyncBinder(AsyncBinder): + def __init__(self, context: ymq.IOContext, name: str, address: ZMQConfig, identity: Optional[bytes] = None): + self._address = address + + if identity is None: + identity = f"{os.getpid()}|{name}|{uuid.uuid4()}".encode() + self._identity = identity + + self._context = context + self._socket = self._context.createIOSocket_sync(self.identity.decode(), ymq.IOSocketType.Binder) + + if self._address.type != ZMQType.tcp: + raise ValueError(f"YMQ only supports tcp transport, got {self._address.type}") + + self._socket.bind_sync(self._address.to_address()) + + self._callback: Optional[Callable[[bytes, Message], Awaitable[None]]] = None + + self._received: Dict[str, int] = defaultdict(lambda: 0) + self._sent: Dict[str, int] = defaultdict(lambda: 0) + + @property + def identity(self): + return self._identity + + def destroy(self): + 
self._context = None + self._socket = None + + def register(self, callback: Callable[[bytes, Message], Awaitable[None]]): + self._callback = callback + + async def routine(self): + recvd = await self._socket.recv() + + # TODO: zero-copy + message: Optional[Message] = deserialize(recvd.payload.data) + if message is None: + logging.error(f"received unknown message from {recvd.address!r}: {recvd.payload!r}") + return + + self.__count_received(message.__class__.__name__) + await self._callback(recvd.address.data, message) + + async def send(self, to: bytes, message: Message): + self.__count_sent(message.__class__.__name__) + await self._socket.send(ymq.Message(to, serialize(message))) + + def get_status(self) -> BinderStatus: + return BinderStatus.new_msg(received=self._received, sent=self._sent) + + def __count_received(self, message_type: str): + self._received[message_type] += 1 + + def __count_sent(self, message_type: str): + self._sent[message_type] += 1 + + def __get_prefix(self): + return f"{self.__class__.__name__}[{self._identity.decode()}]:" diff --git a/scaler/io/ymq_async_connector.py b/scaler/io/ymq_async_connector.py new file mode 100644 index 000000000..6d5caaac8 --- /dev/null +++ b/scaler/io/ymq_async_connector.py @@ -0,0 +1,90 @@ +import logging +import os +import uuid +from typing import Awaitable, Callable, Literal, Optional + +from scaler.io.ymq import ymq + +from scaler.io.mixins import AsyncConnector +from scaler.io.utility import deserialize, serialize +from scaler.protocol.python.mixins import Message +from scaler.utility.zmq_config import ZMQConfig, ZMQType + + +class YMQAsyncConnector(AsyncConnector): + def __init__( + self, + context: ymq.IOContext, + name: str, + socket_type: int, + address: ZMQConfig, + bind_or_connect: Literal["bind", "connect"], + callback: Optional[Callable[[Message], Awaitable[None]]], + identity: Optional[bytes], + ): + self._address = address + self._context = context + + if identity is None: + identity = 
f"{os.getpid()}|{name}|{uuid.uuid4().bytes.hex()}".encode() + self._identity = identity + + self._socket = self._context.createIOSocket_sync(self.identity.decode(), socket_type) + + if self._address.type != ZMQType.tcp: + raise ValueError(f"YMQ only supports tcp transport, got {self._address.type}") + + if bind_or_connect == "bind": + self._socket.bind_sync(self.address) + elif bind_or_connect == "connect": + self._socket.connect_sync(self.address) + else: + raise TypeError("bind_or_connect has to be 'bind' or 'connect'") + + self._callback: Optional[Callable[[Message], Awaitable[None]]] = callback + + def __del__(self): + self.destroy() + + def destroy(self): + self._context = None + self._socket = None + + @property + def identity(self) -> bytes: + return self._identity + + @property + def socket(self) -> ymq.IOSocket: + return self._socket + + @property + def address(self) -> str: + return self._address.to_address() + + async def routine(self): + if self._callback is None: + return + + message: Optional[Message] = await self.receive() + if message is None: + return + + await self._callback(message) + + async def receive(self) -> Optional[Message]: + if self._socket is None: + return None + + msg = await self._socket.recv() + + # TODO: zero-copy + result: Optional[Message] = deserialize(msg.payload.data) + if result is None: + logging.error(f"received unknown message: {msg.payload!r}") + return None + + return result + + async def send(self, message: Message): + await self._socket.send(ymq.Message(None, serialize(message))) diff --git a/scaler/io/ymq_sync_connector.py b/scaler/io/ymq_sync_connector.py new file mode 100644 index 000000000..6c8dfda2d --- /dev/null +++ b/scaler/io/ymq_sync_connector.py @@ -0,0 +1,69 @@ +import logging +import os +import socket +import threading +import uuid +from typing import Optional + +from scaler.io.ymq import ymq + +from scaler.io.mixins import SyncConnector +from scaler.io.utility import deserialize, serialize +from 
scaler.protocol.python.mixins import Message +from scaler.utility.zmq_config import ZMQConfig, ZMQType + + +class YMQSyncConnector(SyncConnector): + def __init__(self, context: ymq.IOContext, socket_type: int, address: ZMQConfig, identity: Optional[bytes]): + self._address = address + self._context = context + + self._identity: bytes = ( + f"{os.getpid()}|{socket.gethostname().split('.')[0]}|{uuid.uuid4()}".encode() + if identity is None + else identity + ) + + self._socket = self._context.createIOSocket_sync(self.identity.decode(), socket_type) + + if self._address.type != ZMQType.tcp: + raise ValueError(f"YMQ only supports tcp transport, got {self._address.type}") + + self._socket.connect_sync(self._address.to_address()) + + self._lock = threading.Lock() + + def destroy(self): + self._context = None + self._socket = None + + @property + def address(self) -> str: + return self._address.to_address() + + @property + def identity(self) -> bytes: + return self._identity + + def send(self, message: Message): + with self._lock: + self._socket.send_sync(ymq.Message(None, serialize(message)), timeout_secs=3) + + def receive(self) -> Optional[Message]: + """Receive one message (3 s timeout, raising ymq.YMQTimeoutException); None if it cannot be deserialized.""" + with self._lock: + msg = self._socket.recv_sync(timeout_secs=3) + + # TODO: zero-copy + return self.__compose_message(msg.payload.data) + + def __compose_message(self, payload: bytes) -> Optional[Message]: + result: Optional[Message] = deserialize(payload) + if result is None: + logging.error(f"{self.__get_prefix()}: received unknown message: {payload!r}") + return None + + return result + + def __get_prefix(self): + return f"{self.__class__.__name__}[{self._identity.decode()}]:" diff --git a/scaler/io/ymq_sync_subscriber.py b/scaler/io/ymq_sync_subscriber.py new file mode 100644 index 000000000..3fe7d9db5 --- /dev/null +++ b/scaler/io/ymq_sync_subscriber.py @@ -0,0 +1,81 @@ +import logging +import threading +import random +from typing import Callable, Optional + +from scaler.io.ymq import ymq + +from 
scaler.io.mixins import SyncSubscriber +from scaler.io.utility import deserialize +from scaler.protocol.python.mixins import Message +from scaler.utility.zmq_config import ZMQConfig, ZMQType + + +class YMQSyncSubscriber(SyncSubscriber, threading.Thread): + def __init__( + self, + address: ZMQConfig, + callback: Callable[[Message], None], + topic: bytes, + exit_callback: Optional[Callable[[], None]] = None, + stop_event: threading.Event = threading.Event(), + daemonic: bool = False, + timeout_seconds: int = -1, + ): + threading.Thread.__init__(self) + + self._stop_event = stop_event + self._address = address + self._callback = callback + self._exit_callback = exit_callback + self._topic = topic + self.daemon = bool(daemonic) + self._timeout_seconds = timeout_seconds + + self._context: Optional[ymq.IOContext] = None + self._socket: Optional[ymq.IOSocket] = None + + def __close(self): + self._context = None + self._socket = None + + def __stop_polling(self): + self._stop_event.set() + + def destroy(self): + self.__stop_polling() + + def run(self) -> None: + self.__initialize() + + while not self._stop_event.is_set(): + self.__routine_polling() + + if self._exit_callback is not None: + self._exit_callback() + + self.__close() + + def __initialize(self): + self._context = ymq.IOContext() + self._socket = self._context.createIOSocket_sync(f"{self._topic.decode()}_subscriber_{random.randint(10_000, 20_000)}", ymq.IOSocketType.Unicast) + + if self._address.type != ZMQType.tcp: + raise ValueError(f"YMQ only supports tcp transport, got {self._address.type}") + + self._socket.connect_sync(self._address.to_address()) + + def __routine_polling(self): + try: + # TODO: zero-copy + self.__routine_receive(self._socket.recv_sync(timeout_secs=1).payload.data) + except ymq.YMQTimeoutException: + pass + + def __routine_receive(self, payload: bytes): + result: Optional[Message] = deserialize(payload) + if result is None: + logging.error(f"received unknown message: {payload!r}") + return 
None + + self._callback(result) diff --git a/scaler/io/async_binder.py b/scaler/io/zmq_async_binder.py similarity index 100% rename from scaler/io/async_binder.py rename to scaler/io/zmq_async_binder.py diff --git a/scaler/io/async_connector.py b/scaler/io/zmq_async_connector.py similarity index 100% rename from scaler/io/async_connector.py rename to scaler/io/zmq_async_connector.py diff --git a/scaler/io/sync_connector.py b/scaler/io/zmq_sync_connector.py similarity index 100% rename from scaler/io/sync_connector.py rename to scaler/io/zmq_sync_connector.py diff --git a/scaler/io/sync_subscriber.py b/scaler/io/zmq_sync_subscriber.py similarity index 100% rename from scaler/io/sync_subscriber.py rename to scaler/io/zmq_sync_subscriber.py diff --git a/scaler/scheduler/scheduler.py b/scaler/scheduler/scheduler.py index 6598c2784..0e31d5dce 100644 --- a/scaler/scheduler/scheduler.py +++ b/scaler/scheduler/scheduler.py @@ -2,10 +2,9 @@ import functools import logging -import zmq.asyncio - -from scaler.io.async_binder import ZMQAsyncBinder -from scaler.io.async_connector import ZMQAsyncConnector +from scaler.io.ymq import ymq +from scaler.io.ymq_async_binder import YMQAsyncBinder +from scaler.io.ymq_async_connector import YMQAsyncConnector from scaler.io.async_object_storage_connector import PyAsyncObjectStorageConnector from scaler.io.config import CLEANUP_INTERVAL_SECONDS, STATUS_REPORT_INTERVAL_SECONDS from scaler.io.mixins import AsyncBinder, AsyncConnector, AsyncObjectStorageConnector @@ -66,18 +65,18 @@ def __init__(self, config: SchedulerConfig): monitor_address = config.monitor_address self._config_controller.update_config("monitor_address", monitor_address) - self._context = zmq.asyncio.Context(io_threads=config.io_threads) + self._context = ymq.IOContext(config.io_threads) - self._binder: AsyncBinder = ZMQAsyncBinder(context=self._context, name="scheduler", address=config.address) + self._binder: AsyncBinder = YMQAsyncBinder(context=self._context, 
name="scheduler", address=config.address) logging.info(f"{self.__class__.__name__}: listen to scheduler address {config.address}") self._connector_storage: AsyncObjectStorageConnector = PyAsyncObjectStorageConnector() logging.info(f"{self.__class__.__name__}: connect to object storage server {object_storage_address!r}") - self._binder_monitor: AsyncConnector = ZMQAsyncConnector( + self._binder_monitor: AsyncConnector = YMQAsyncConnector( context=self._context, name="scheduler_monitor", - socket_type=zmq.PUB, + socket_type=ymq.IOSocketType.Multicast, address=monitor_address, bind_or_connect="bind", callback=None, diff --git a/scaler/ui/webui.py b/scaler/ui/webui.py index 75e80ffed..8b0cbcad1 100644 --- a/scaler/ui/webui.py +++ b/scaler/ui/webui.py @@ -4,7 +4,7 @@ from nicegui import ui -from scaler.io.sync_subscriber import ZMQSyncSubscriber +from scaler.io.ymq_sync_subscriber import YMQSyncSubscriber from scaler.protocol.python.message import StateScheduler, StateTask from scaler.protocol.python.mixins import Message from scaler.ui.constants import ( @@ -75,7 +75,7 @@ def start_webui(address: str, host: str, port: int): with ui.tab_panel(settings_tab): tables.settings_section.draw_section() - subscriber = ZMQSyncSubscriber( + subscriber = YMQSyncSubscriber( address=ZMQConfig.from_string(address), callback=partial(__show_status, tables=tables), topic=b"", diff --git a/scaler/utility/identifiers.py b/scaler/utility/identifiers.py index 88a2054c2..d55bcea15 100644 --- a/scaler/utility/identifiers.py +++ b/scaler/utility/identifiers.py @@ -10,6 +10,9 @@ class Identifier(bytes, metaclass=abc.ABCMeta): def __repr__(self) -> str: raise NotImplementedError() + def extend(self, extra: str): + return self.__class__(self + extra.encode()) + class ClientID(Identifier): def __repr__(self) -> str: diff --git a/scaler/worker/agent/processor/processor.py b/scaler/worker/agent/processor/processor.py index 33a6e01b9..0b4a0b74c 100644 --- a/scaler/worker/agent/processor/processor.py 
+++ b/scaler/worker/agent/processor/processor.py @@ -9,10 +9,10 @@ from typing import IO, Callable, List, Optional, Tuple, cast import tblib.pickling_support -import zmq +from scaler.io.ymq import ymq from scaler.io.mixins import SyncConnector, SyncObjectStorageConnector -from scaler.io.sync_connector import ZMQSyncConnector +from scaler.io.ymq_sync_connector import YMQSyncConnector from scaler.io.sync_object_storage_connector import PySyncObjectStorageConnector from scaler.protocol.python.common import ObjectMetadata, TaskResultType from scaler.protocol.python.message import ObjectInstruction, ProcessorInitialized, Task, TaskLog, TaskResult @@ -65,6 +65,7 @@ def __init__( def run(self) -> None: self.__initialize() self.__run_forever() + print("PROCESSOR EXITING..........") @staticmethod def get_current_processor() -> Optional["Processor"]: @@ -83,8 +84,8 @@ def __initialize(self): setup_logger(log_paths=tuple(logging_paths), logging_level=self._logging_level) tblib.pickling_support.install() - self._connector_agent: SyncConnector = ZMQSyncConnector( - context=zmq.Context(), socket_type=zmq.DEALER, address=self._agent_address, identity=None + self._connector_agent: SyncConnector = YMQSyncConnector( + context=ymq.IOContext(), socket_type=ymq.IOSocketType.Connector, address=self._agent_address, identity=None ) self._connector_storage: SyncObjectStorageConnector = PySyncObjectStorageConnector( self._storage_address.host, self._storage_address.port @@ -127,11 +128,7 @@ def __run_forever(self): self.__on_connector_receive(message) - except zmq.error.ZMQError as e: - if e.errno != zmq.ENOTSOCK: # ignore if socket got closed - raise - - except (KeyboardInterrupt, InterruptedError): + except (KeyboardInterrupt, InterruptedError, ymq.YMQInterruptedException): pass except Exception as e: diff --git a/scaler/worker/worker.py b/scaler/worker/worker.py index 18092fe40..afa8ee42f 100644 --- a/scaler/worker/worker.py +++ b/scaler/worker/worker.py @@ -4,13 +4,14 @@ import os 
import signal import tempfile +import random import uuid from typing import Dict, Optional, Tuple -import zmq.asyncio +from scaler.io.ymq import ymq -from scaler.io.async_binder import ZMQAsyncBinder -from scaler.io.async_connector import ZMQAsyncConnector +from scaler.io.ymq_async_binder import YMQAsyncBinder +from scaler.io.ymq_async_connector import YMQAsyncConnector from scaler.io.async_object_storage_connector import PyAsyncObjectStorageConnector from scaler.io.config import PROFILING_INTERVAL_SECONDS from scaler.io.mixins import AsyncBinder, AsyncConnector, AsyncObjectStorageConnector @@ -70,8 +71,9 @@ def __init__( self._ident = WorkerID.generate_worker_id(name) # _identity is internal to multiprocessing.Process - self._address_path_internal = os.path.join(tempfile.gettempdir(), f"scaler_worker_{uuid.uuid4().hex}") - self._address_internal = ZMQConfig(ZMQType.ipc, host=self._address_path_internal) + # self._address_path_internal = os.path.join(tempfile.gettempdir(), f"scaler_worker_{uuid.uuid4().hex}") + # self._address_internal = ZMQConfig(ZMQType.ipc, host=self._address_path_internal) + self._address_internal = ZMQConfig(ZMQType.tcp, host="127.0.0.1", port=random.randint(20000, 30000)) self._task_queue_size = task_queue_size self._heartbeat_interval_seconds = heartbeat_interval_seconds @@ -84,7 +86,7 @@ def __init__( self._logging_paths = logging_paths self._logging_level = logging_level - self._context: Optional[zmq.asyncio.Context] = None + self._context: Optional[ymq.IOContext] = None self._connector_external: Optional[AsyncConnector] = None self._binder_internal: Optional[AsyncBinder] = None self._connector_storage: Optional[AsyncObjectStorageConnector] = None @@ -105,19 +107,19 @@ def __initialize(self): setup_logger() register_event_loop(self._event_loop) - self._context = zmq.asyncio.Context() - self._connector_external = ZMQAsyncConnector( + self._context = ymq.IOContext() + self._connector_external = YMQAsyncConnector( context=self._context, 
name=self.name, - socket_type=zmq.DEALER, + socket_type=ymq.IOSocketType.Connector, address=self._address, bind_or_connect="connect", callback=self.__on_receive_external, identity=self._ident, ) - self._binder_internal = ZMQAsyncBinder( - context=self._context, name=self.name, address=self._address_internal, identity=self._ident + self._binder_internal = YMQAsyncBinder( + context=self._context, name=self.name, address=self._address_internal, identity=self._ident.extend("|internal") ) self._binder_internal.register(self.__on_receive_internal) @@ -228,7 +230,7 @@ async def __get_loops(self): create_async_loop_routine(self._profiling_manager.routine, PROFILING_INTERVAL_SECONDS), ) except asyncio.CancelledError: - pass + print("WORKER TASK CANCELED") except (ClientShutdownException, TimeoutError) as e: logging.info(f"{self.identity!r}: {str(e)}") except Exception as e: @@ -236,10 +238,11 @@ async def __get_loops(self): await self._connector_external.send(DisconnectRequest.new_msg(self.identity)) + self._connector_external.destroy() self._processor_manager.destroy("quit") self._binder_internal.destroy() - os.remove(self._address_path_internal) + # os.remove(self._address_path_internal) logging.info(f"{self.identity!r}: quit") diff --git a/scaler/worker_adapter/symphony/worker.py b/scaler/worker_adapter/symphony/worker.py index fdd6f7811..652ffbc32 100644 --- a/scaler/worker_adapter/symphony/worker.py +++ b/scaler/worker_adapter/symphony/worker.py @@ -5,9 +5,9 @@ from collections import deque from typing import Dict, Optional -import zmq +from scaler.io.ymq import ymq -from scaler.io.async_connector import ZMQAsyncConnector +from scaler.io.ymq_async_connector import YMQAsyncConnector from scaler.io.async_object_storage_connector import PyAsyncObjectStorageConnector from scaler.io.mixins import AsyncConnector, AsyncObjectStorageConnector from scaler.protocol.python.message import ( @@ -68,7 +68,7 @@ def __init__( self._death_timeout_seconds = death_timeout_seconds 
self._task_queue_size = task_queue_size - self._context: Optional[zmq.asyncio.Context] = None + self._context: Optional[ymq.IOContext] = None self._connector_external: Optional[AsyncConnector] = None self._connector_storage: Optional[AsyncObjectStorageConnector] = None self._task_manager: Optional[SymphonyTaskManager] = None @@ -93,11 +93,11 @@ def __initialize(self): setup_logger() register_event_loop(self._event_loop) - self._context = zmq.asyncio.Context() - self._connector_external = ZMQAsyncConnector( + self._context = ymq.IOContext() + self._connector_external = YMQAsyncConnector( context=self._context, name=self.name, - socket_type=zmq.DEALER, + socket_type=ymq.IOSocketType.Connector, address=self._address, bind_or_connect="connect", callback=self.__on_receive_external,