Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/actions/setup-env/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ runs:
- name: Set up Python
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
with:
python-version: "3.8"
python-version: "3.12"

- name: Update System (Linux)
shell: bash
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ build-backend = "scikit_build_core.build"
[project]
name = "scaler"
description = "Scaler Distribution Framework"
requires-python = ">=3.8"
requires-python = ">=3.10"
readme = { file = "README.md", content-type = "text/markdown" }
license = { text = "Apache 2.0" }
authors = [{ name = "Citi", email = "opensource@citi.com" }]
Expand Down
104 changes: 33 additions & 71 deletions scaler/io/async_object_storage_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Dict, Optional, Tuple

from scaler.io.mixins import AsyncObjectStorageConnector
from scaler.io.ymq.ymq import *
from scaler.protocol.capnp._python import _object_storage # noqa
from scaler.protocol.python.object_storage import ObjectRequestHeader, ObjectResponseHeader, to_capnp_object_id
from scaler.utility.exceptions import ObjectStorageException
Expand All @@ -22,40 +23,26 @@ def __init__(self):

self._connected_event = asyncio.Event()

self._reader: Optional[asyncio.StreamReader] = None
self._writer: Optional[asyncio.StreamWriter] = None

self._next_request_id = 0
self._pending_get_requests: Dict[ObjectID, asyncio.Future] = {}

self._identity: bytes = f"{os.getpid()}|{socket.gethostname().split('.')[0]}|{uuid.uuid4()}".encode()
self._lock = asyncio.Lock()
self._identity: str = f"{os.getpid()}|{socket.gethostname().split('.')[0]}|{uuid.uuid4()}"
self._io_context: IOContext = IOContext()
self._io_socket = self._io_context.createIOSocket_sync(self._identity, IOSocketType.Connector)

def __del__(self):
if not self.is_connected():
return

self._writer.close()
self._io_socket = None

async def connect(self, host: str, port: int):
self._host = host
self._port = port

if self.is_connected():
raise ObjectStorageException("connector is already connected.")

self._reader, self._writer = await asyncio.open_connection(self._host, self._port)
await self.__read_framed_message()
self.__write_framed(self._identity)

try:
await self._writer.drain()
except ConnectionResetError:
self.__raise_connection_failure()

# Makes sure the socket is TCP_NODELAY. It seems to be the case by default, but that's not specified in the
# asyncio's documentation and might change in the future.
self._writer.get_extra_info("socket").setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

await self._io_socket.connect(self.address)
self._connected_event.set()

async def wait_until_connected(self):
Expand All @@ -67,23 +54,10 @@ def is_connected(self) -> bool:
async def destroy(self):
if not self.is_connected():
return

if not self._writer.is_closing:
self._writer.close()

await self._writer.wait_closed()

@property
def reader(self) -> Optional[asyncio.StreamReader]:
return self._reader

@property
def writer(self) -> Optional[asyncio.StreamWriter]:
return self._writer
self._io_socket = None

@property
def address(self) -> str:
self.__ensure_is_connected()
return f"tcp://{self._host}:{self._port}"

async def routine(self):
Expand Down Expand Up @@ -136,12 +110,9 @@ async def duplicate_object_id(self, object_id: ObjectID, new_object_id: ObjectID
)

def __ensure_is_connected(self):
if self._writer is None:
if self._io_socket is None:
raise ObjectStorageException("connector is not connected.")

if self._writer.is_closing():
raise ObjectStorageException("connection is closed.")

async def __send_request(
self,
object_id: ObjectID,
Expand All @@ -150,75 +121,66 @@ async def __send_request(
payload: Optional[bytes],
):
self.__ensure_is_connected()
assert self._writer is not None

request_id = self._next_request_id
self._next_request_id += 1
self._next_request_id %= 2**64 - 1 # UINT64_MAX

header = ObjectRequestHeader.new_msg(object_id, payload_length, request_id, request_type)

self.__write_request_header(header)
try:
async with self._lock:
await self.__write_request_header(header)

if payload is not None:
self.__write_request_payload(payload)
if payload is not None:
await self.__write_request_payload(payload)

try:
await self._writer.drain()
except ConnectionResetError:
except YMQException:
self._io_socket = None
self.__raise_connection_failure()

def __write_request_header(self, header: ObjectRequestHeader):
assert self._writer is not None
self.__write_framed(header.get_message().to_bytes())
async def __write_request_header(self, header: ObjectRequestHeader):
assert self._io_socket is not None
await self._io_socket.send(Message(address=None, payload=header.get_message().to_bytes()))

def __write_request_payload(self, payload: bytes):
assert self._writer is not None
self.__write_framed(payload)
async def __write_request_payload(self, payload: bytes):
assert self._io_socket is not None
await self._io_socket.send(Message(address=None, payload=payload))

async def __receive_response(self) -> Optional[Tuple[ObjectResponseHeader, bytes]]:
assert self._reader is not None

if self._writer.is_closing():
if self._io_socket is None:
return None

try:
header = await self.__read_response_header()
payload = await self.__read_response_payload(header)
except asyncio.IncompleteReadError:
except YMQException:
self._io_socket = None
self.__raise_connection_failure()

return header, payload

async def __read_response_header(self) -> ObjectResponseHeader:
assert self._reader is not None
assert self._io_socket is not None

header_data = await self.__read_framed_message()
msg = await self._io_socket.recv()
header_data = msg.payload.data
assert len(header_data) == ObjectResponseHeader.MESSAGE_LENGTH

with _object_storage.ObjectResponseHeader.from_bytes(header_data) as header_message:
return ObjectResponseHeader(header_message)

async def __read_response_payload(self, header: ObjectResponseHeader) -> bytes:
assert self._reader is not None
assert self._io_socket is not None
# assert self._reader is not None

if header.payload_length > 0:
res = await self.__read_framed_message()
assert len(res) == header.payload_length
return res
res = await self._io_socket.recv()
assert len(res.payload) == header.payload_length
return res.payload.data
else:
return b""

async def __read_framed_message(self) -> bytes:
length_bytes = await self._reader.readexactly(8)
(payload_length,) = struct.unpack("<Q", length_bytes)
return await self._reader.readexactly(payload_length) if payload_length > 0 else bytes()

def __write_framed(self, payload: bytes):
self._writer.write(struct.pack("<Q", len(payload)))
self._writer.write(payload)
return

@staticmethod
def __raise_connection_failure():
raise ObjectStorageException("connection failure to object storage server.")
Loading