Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions scaler/client/agent/client_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@
from concurrent.futures import Future
from typing import Optional

import zmq.asyncio
from scaler.io.ymq import ymq

from scaler.client.agent.disconnect_manager import ClientDisconnectManager
from scaler.client.agent.future_manager import ClientFutureManager
from scaler.client.agent.heartbeat_manager import ClientHeartbeatManager
from scaler.client.agent.object_manager import ClientObjectManager
from scaler.client.agent.task_manager import ClientTaskManager
from scaler.client.serializer.mixins import Serializer
from scaler.io.async_connector import ZMQAsyncConnector
from scaler.io.ymq_async_connector import YMQAsyncConnector
from scaler.io.mixins import AsyncConnector
from scaler.protocol.python.common import ObjectStorageAddress
from scaler.protocol.python.message import (
Expand Down Expand Up @@ -41,7 +41,7 @@ def __init__(
identity: ClientID,
client_agent_address: ZMQConfig,
scheduler_address: ZMQConfig,
context: zmq.Context,
context: ymq.IOContext,
future_manager: ClientFutureManager,
stop_event: threading.Event,
timeout_seconds: int,
Expand All @@ -63,19 +63,19 @@ def __init__(

self._future_manager = future_manager

self._connector_internal: AsyncConnector = ZMQAsyncConnector(
context=zmq.asyncio.Context.shadow(self._context),
self._connector_internal: AsyncConnector = YMQAsyncConnector(
context=self._context,
name="client_agent_internal",
socket_type=zmq.PAIR,
socket_type=ymq.IOSocketType.Binder,
bind_or_connect="bind",
address=self._client_agent_address,
callback=self.__on_receive_from_client,
identity=None,
)
self._connector_external: AsyncConnector = ZMQAsyncConnector(
context=zmq.asyncio.Context.shadow(self._context),
self._connector_external: AsyncConnector = YMQAsyncConnector(
context=self._context,
name="client_agent_external",
socket_type=zmq.DEALER,
socket_type=ymq.IOSocketType.Connector,
address=self._scheduler_address,
bind_or_connect="connect",
callback=self.__on_receive_from_scheduler,
Expand Down
15 changes: 8 additions & 7 deletions scaler/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
import functools
import logging
import threading
import random
import uuid
from collections import Counter
from inspect import signature
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import zmq

from scaler.io.ymq import ymq
from scaler.client.agent.client_agent import ClientAgent
from scaler.client.agent.future_manager import ClientFutureManager
from scaler.client.future import ScalerFuture
Expand All @@ -18,7 +18,7 @@
from scaler.client.serializer.mixins import Serializer
from scaler.io.config import DEFAULT_CLIENT_TIMEOUT_SECONDS, DEFAULT_HEARTBEAT_INTERVAL_SECONDS
from scaler.io.mixins import SyncConnector, SyncObjectStorageConnector
from scaler.io.sync_connector import ZMQSyncConnector
from scaler.io.ymq_sync_connector import YMQSyncConnector
from scaler.io.sync_object_storage_connector import PySyncObjectStorageConnector
from scaler.protocol.python.message import ClientDisconnect, ClientShutdownResponse, GraphTask, Task
from scaler.utility.exceptions import ClientQuitException, MissingObjects
Expand Down Expand Up @@ -89,15 +89,16 @@ def __initialize__(
self._stream_output = stream_output
self._identity = ClientID.generate_client_id()

self._client_agent_address = ZMQConfig(ZMQType.inproc, host=f"scaler_client_{uuid.uuid4().hex}")
# self._client_agent_address = ZMQConfig(ZMQType.inproc, host=f"scaler_client_{uuid.uuid4().hex}")
self._client_agent_address = ZMQConfig(ZMQType.tcp, host="127.0.0.1", port=random.randint(20000, 30000))
self._scheduler_address = ZMQConfig.from_string(address)
self._timeout_seconds = timeout_seconds
self._heartbeat_interval_seconds = heartbeat_interval_seconds

self._stop_event = threading.Event()
self._context = zmq.Context()
self._connector_agent: SyncConnector = ZMQSyncConnector(
context=self._context, socket_type=zmq.PAIR, address=self._client_agent_address, identity=self._identity
self._context = ymq.IOContext()
self._connector_agent: SyncConnector = YMQSyncConnector(
context=self._context, socket_type=ymq.IOSocketType.Connector, address=self._client_agent_address, identity=self._identity.extend("|agent")
)

self._future_manager = ClientFutureManager(self._serializer)
Expand Down
4 changes: 2 additions & 2 deletions scaler/entry_points/top.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import functools
from typing import Dict, List, Literal, Union

from scaler.io.sync_subscriber import ZMQSyncSubscriber
from scaler.io.ymq_sync_subscriber import YMQSyncSubscriber
from scaler.protocol.python.message import StateScheduler
from scaler.protocol.python.mixins import Message
from scaler.utility.formatter import (
Expand Down Expand Up @@ -51,7 +51,7 @@ def poke(screen, args):
screen.nodelay(1)

try:
subscriber = ZMQSyncSubscriber(
subscriber = YMQSyncSubscriber(
address=ZMQConfig.from_string(args.address),
callback=functools.partial(show_status, screen=screen),
topic=b"",
Expand Down
2 changes: 1 addition & 1 deletion scaler/entry_points/worker_adapter_native.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def get_args():

# Server configuration
parser.add_argument(
"--host", type=str, default="localhost", help="host address for the native worker adapter HTTP server"
"--host", type=str, default="127.0.0.1", help="host address for the native worker adapter HTTP server"
)
parser.add_argument("--port", "-p", type=int, help="port for the native worker adapter HTTP server")

Expand Down
2 changes: 1 addition & 1 deletion scaler/entry_points/worker_adapter_symphony.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def get_args():

# Server configuration
parser.add_argument(
"--host", type=str, default="localhost", help="host address for the native worker adapter HTTP server"
"--host", type=str, default="127.0.0.1", help="host address for the native worker adapter HTTP server"
)
parser.add_argument("--port", "-p", type=int, required=True, help="port for the native worker adapter HTTP server")

Expand Down
3 changes: 1 addition & 2 deletions scaler/io/ymq/pymod_ymq/io_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,7 @@ static PyObject* PyIOContext_createIOSocket_sync(PyIOContext* self, PyObject* ar
waiter.signal();
});

if (waiter.wait())
CHECK_SIGNALS;
WAIT(waiter);
} catch (...) {
PyEval_RestoreThread(_save);
PyErr_SetString(PyExc_RuntimeError, "Failed to create io socket synchronously");
Expand Down
50 changes: 30 additions & 20 deletions scaler/io/ymq/pymod_ymq/io_socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
#include <chrono>
#include <expected>
#include <memory>
#include <optional>
#include <thread>
#include <utility>

// C
#include <methodobject.h>
#include <semaphore.h>
#include <sys/eventfd.h>
#include <sys/poll.h>
Expand Down Expand Up @@ -40,6 +42,7 @@ extern "C" {
static void PyIOSocket_dealloc(PyIOSocket* self)
{
try {
self->ioContext->requestIOSocketStop(self->socket);
self->ioContext->removeIOSocket(self->socket);
self->ioContext.~shared_ptr();
self->socket.~shared_ptr();
Expand Down Expand Up @@ -90,8 +93,9 @@ static PyObject* PyIOSocket_send_sync(PyIOSocket* self, PyObject* args, PyObject

// borrowed reference
PyMessage* message = nullptr;
const char* kwlist[] = {"message", nullptr};
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O", (char**)kwlist, &message))
int timeout_secs = -1;
const char* kwlist[] = {"message", "timeout_secs", nullptr};
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|i", (char**)kwlist, &message, &timeout_secs))
return nullptr;

Bytes address = message->address.is_none() ? Bytes() : std::move(message->address->bytes);
Expand All @@ -101,15 +105,14 @@ static PyObject* PyIOSocket_send_sync(PyIOSocket* self, PyObject* args, PyObject

std::shared_ptr<std::expected<void, Error>> result = std::make_shared<std::expected<void, Error>>();
try {
Waiter waiter(state->wakeupfd_rd);
Waiter waiter(state->wakeupfd_rd, timeout_secs == -1 ? std::nullopt : std::optional {timeout_secs});

self->socket->sendMessage({.address = std::move(address), .payload = std::move(payload)}, [=](auto r) mutable {
*result = std::move(r);
waiter.signal();
});

if (waiter.wait())
CHECK_SIGNALS;
WAIT(waiter);
} catch (...) {
PyEval_RestoreThread(_save);
PyErr_SetString(PyExc_RuntimeError, "Failed to send synchronously");
Expand Down Expand Up @@ -164,25 +167,29 @@ static PyObject* PyIOSocket_recv(PyIOSocket* self, PyObject* args)
});
}

static PyObject* PyIOSocket_recv_sync(PyIOSocket* self, PyObject* args)
static PyObject* PyIOSocket_recv_sync(PyIOSocket* self, PyObject* args, PyObject* kwargs)
{
auto state = YMQStateFromSelf((PyObject*)self);
if (!state)
return nullptr;

int timeout_secs = -1;
const char* kwlist[] = {"timeout_secs", nullptr};
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|i", (char**)kwlist, &timeout_secs))
return nullptr;

PyThreadState* _save = PyEval_SaveThread();

std::shared_ptr<std::pair<Message, Error>> result = std::make_shared<std::pair<Message, Error>>();
try {
Waiter waiter(state->wakeupfd_rd);
Waiter waiter(state->wakeupfd_rd, timeout_secs == -1 ? std::nullopt : std::optional {timeout_secs});

self->socket->recvMessage([=](auto r) mutable {
*result = std::move(r);
waiter.signal();
});

if (waiter.wait())
CHECK_SIGNALS;
WAIT(waiter);
} catch (...) {
PyEval_RestoreThread(_save);
PyErr_SetString(PyExc_RuntimeError, "Failed to recv synchronously");
Expand Down Expand Up @@ -252,23 +259,23 @@ static PyObject* PyIOSocket_bind_sync(PyIOSocket* self, PyObject* args, PyObject

const char* address = nullptr;
Py_ssize_t addressLen = 0;
const char* kwlist[] = {"address", nullptr};
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "s#", (char**)kwlist, &address, &addressLen))
int timeout_secs = -1;
const char* kwlist[] = {"address", "timeout_secs", nullptr};
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "s#|i", (char**)kwlist, &address, &addressLen, &timeout_secs))
return nullptr;

PyThreadState* _save = PyEval_SaveThread();

auto result = std::make_shared<std::expected<void, Error>>();
try {
Waiter waiter(state->wakeupfd_rd);
Waiter waiter(state->wakeupfd_rd, timeout_secs == -1 ? std::nullopt : std::optional {timeout_secs});

self->socket->bindTo(std::string(address, addressLen), [=](auto r) mutable {
*result = std::move(r);
waiter.signal();
});

if (waiter.wait())
CHECK_SIGNALS;
WAIT(waiter);
} catch (...) {
PyEval_RestoreThread(_save);
PyErr_SetString(PyExc_RuntimeError, "Failed to bind synchronously");
Expand Down Expand Up @@ -319,23 +326,23 @@ static PyObject* PyIOSocket_connect_sync(PyIOSocket* self, PyObject* args, PyObj

const char* address = nullptr;
Py_ssize_t addressLen = 0;
const char* kwlist[] = {"address", nullptr};
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "s#", (char**)kwlist, &address, &addressLen))
int timeout_secs = -1;
const char* kwlist[] = {"address", "timeout_secs", nullptr};
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "s#|i", (char**)kwlist, &address, &addressLen, timeout_secs))
return nullptr;

PyThreadState* _save = PyEval_SaveThread();

std::shared_ptr<std::expected<void, Error>> result = std::make_shared<std::expected<void, Error>>();
try {
Waiter waiter(state->wakeupfd_rd);
Waiter waiter(state->wakeupfd_rd, timeout_secs == -1 ? std::nullopt : std::optional {timeout_secs});

self->socket->connectTo(std::string(address, addressLen), [=](auto r) mutable {
*result = std::move(r);
waiter.signal();
});

if (waiter.wait())
CHECK_SIGNALS;
WAIT(waiter);
} catch (...) {
PyEval_RestoreThread(_save);
PyErr_SetString(PyExc_RuntimeError, "Failed to connect synchronously");
Expand Down Expand Up @@ -399,7 +406,10 @@ static PyMethodDef PyIOSocket_methods[] = {
(PyCFunction)PyIOSocket_send_sync,
METH_VARARGS | METH_KEYWORDS,
PyDoc_STR("Send data through the IOSocket synchronously")},
{"recv_sync", (PyCFunction)PyIOSocket_recv_sync, METH_NOARGS, PyDoc_STR("Receive data from the IOSocket")},
{"recv_sync",
(PyCFunction)PyIOSocket_recv_sync,
METH_VARARGS | METH_KEYWORDS,
PyDoc_STR("Receive data from the IOSocket")},
{"bind_sync",
(PyCFunction)PyIOSocket_bind_sync,
METH_VARARGS | METH_KEYWORDS,
Expand Down
Loading
Loading