diff --git a/.github/actions/run-test/action.yml b/.github/actions/run-test/action.yml
index 49770009d..a0faad365 100644
--- a/.github/actions/run-test/action.yml
+++ b/.github/actions/run-test/action.yml
@@ -18,6 +18,20 @@ runs:
       shell: pwsh
       run: uv pip install --system pydivert pywin32
 
+    - name: Install Redis Server
+      if: inputs.os == 'Linux'
+      shell: bash
+      run: |
+        sudo apt-get update
+        sudo apt-get install -y redis-server
+        sudo systemctl start redis-server || redis-server --daemonize yes
+        sleep 2
+        redis-cli ping || echo "Warning: Redis may not be running"
+
+    - name: Install Redis Python Package
+      shell: bash
+      run: uv pip install --system "redis>=5.0.0"
+
     - name: Run C++ Tests (Linux)
       if: inputs.os == 'Linux'
       shell: bash
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 79880cc9d..df5c1789d 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -28,6 +28,7 @@ Content
    tutorials/features
    tutorials/compatibility/ray
    tutorials/configuration
+   tutorials/redis_backend
    tutorials/examples
    tutorials/development/devcontainer
    tutorials/development/guidelines
diff --git a/docs/source/tutorials/redis_backend.rst b/docs/source/tutorials/redis_backend.rst
new file mode 100644
index 000000000..349dc9827
--- /dev/null
+++ b/docs/source/tutorials/redis_backend.rst
@@ -0,0 +1,244 @@
+Redis Backend for Object Storage
+==================================
+
+The Redis backend is an optional storage backend for Scaler's object storage system.
+It enables distributed object sharing across multiple nodes and optional persistence of objects.
+
+Features
+--------
+
+* **Distributed Access** - Share objects across multiple Scaler nodes
+* **Optional Persistence** - Configure Redis for disk persistence
+* **Automatic Serialization** - Transparent data and metadata handling
+* **Connection Pooling** - Efficient Redis connection management
+* **Size Limits** - Configurable max object size protection
+* **Key Namespacing** - Prefix support to avoid conflicts
+
+Installation
+------------
+
+Install Redis Backend Support
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: bash
+
+   # Install Scaler with Redis support (quoted so the shell does not
+   # expand the brackets or treat ">" as a redirection)
+   pip install "opengris-scaler[redis]"
+
+   # Or if you already have Scaler installed
+   pip install "redis>=5.0.0"
+
+Install Redis Server
+~~~~~~~~~~~~~~~~~~~~
+
+**Ubuntu/Debian:**
+
+.. code-block:: bash
+
+   sudo apt update
+   sudo apt install redis-server
+   sudo systemctl start redis-server
+
+**Docker:**
+
+.. code-block:: bash
+
+   docker run -d -p 6379:6379 redis:latest
+
+Quick Start
+-----------
+
+1. Configure Redis Backend
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Create a configuration file ``config.toml``:
+
+.. code-block:: toml
+
+   [object_storage_server]
+   object_storage_address = "tcp://127.0.0.1:2346"
+   backend = "redis"
+
+   [object_storage_server.redis]
+   url = "redis://localhost:6379/0"
+   max_object_size_mb = 100
+
+2. Start Object Storage Server
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: bash
+
+   scaler_object_storage_server --config config.toml
+
+3. Use in Your Application
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+   from scaler import Client
+
+   # Connect to scheduler (which uses Redis-backed object storage)
+   with Client(address="tcp://127.0.0.1:2345") as client:
+       # Objects are automatically stored in Redis
+       result = client.submit(lambda x: x * 2, 21)
+       print(result.result())  # 42
+
+Configuration
+-------------
+
+Basic Configuration
+~~~~~~~~~~~~~~~~~~~
+
+.. 
code-block:: toml + + [object_storage_server] + backend = "redis" + + [object_storage_server.redis] + url = "redis://localhost:6379/0" + max_object_size_mb = 100 + key_prefix = "scaler:obj:" + connection_pool_size = 10 + +Configuration Options +~~~~~~~~~~~~~~~~~~~~~ + +.. list-table:: + :header-rows: 1 + :widths: 20 15 25 40 + + * - Option + - Type + - Default + - Description + * - ``url`` + - string + - ``redis://localhost:6379/0`` + - Redis connection URL + * - ``max_object_size_mb`` + - int + - ``100`` + - Maximum object size in MB + * - ``key_prefix`` + - string + - ``scaler:obj:`` + - Prefix for Redis keys + * - ``connection_pool_size`` + - int + - ``10`` + - Connection pool size + +Redis URL Format +~~~~~~~~~~~~~~~~ + +.. code-block:: text + + redis://[username:password@]host:port/database + + Examples: + redis://localhost:6379/0 # Local, no auth + redis://:password@localhost:6379/0 # With password + redis://user:pass@redis.example.com:6379/0 # With user & password + rediss://localhost:6380/0 # SSL/TLS connection + +Advanced Configuration +---------------------- + +With Authentication +~~~~~~~~~~~~~~~~~~~ + +.. code-block:: toml + + [object_storage_server.redis] + url = "redis://:secure_password@localhost:6379/0" + +With Custom Settings +~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: toml + + [object_storage_server.redis] + url = "redis://localhost:6379/5" # Use database 5 + max_object_size_mb = 50 # Limit to 50MB + key_prefix = "myapp:scaler:obj:" # Custom prefix + connection_pool_size = 20 # More connections + +Multiple Nodes Sharing Redis +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +**Node 1 config:** + +.. code-block:: toml + + [object_storage_server] + object_storage_address = "tcp://0.0.0.0:2346" + backend = "redis" + + [object_storage_server.redis] + url = "redis://redis-server.example.com:6379/0" + +**Node 2 config:** + +.. code-block:: toml + + [object_storage_server] + object_storage_address = "tcp://0.0.0.0:2346" + backend = "redis" + + [object_storage_server.redis] + url = "redis://redis-server.example.com:6379/0" + +Both nodes now share the same object storage! + +Usage Examples +-------------- + +Running the Demo +~~~~~~~~~~~~~~~~ + +.. code-block:: bash + + # Ensure Redis is running + redis-cli ping # Should return: PONG + + # Run the demo script + python examples/redis_backend_demo.py + +Programmatic Usage +~~~~~~~~~~~~~~~~~~ + +.. code-block:: python + + from scaler.object_storage.redis_backend import RedisObjectStorageBackend + + # Create backend + backend = RedisObjectStorageBackend( + redis_url="redis://localhost:6379/0", + max_object_size_mb=100 + ) + + # Store object + object_id = b"my_object" + data = b"Hello, World!" 
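+   # metadata is treated as opaque bytes by the backend; a MIME-type
+   # string, as used here, is just one possible convention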
+ metadata = b"text/plain" + backend.put(object_id, data, metadata) + + # Retrieve object + result = backend.get(object_id) + if result: + data, metadata = result + print(f"Data: {data.decode()}") + print(f"Metadata: {metadata.decode()}") + + # Check existence + if backend.exists(object_id): + print("Object exists!") + + # Delete object + backend.delete(object_id) + + # Get info + info = backend.get_info() + print(f"Object count: {info['object_count']}") + print(f"Total size: {info['total_size_bytes']} bytes") + diff --git a/pyproject.toml b/pyproject.toml index 478a625d2..94c26d54b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,9 @@ graphblas = [ aws = [ "boto3", ] +redis = [ + "redis>=5.0.0", +] all = [ "nicegui[plotly]==2.24.2; python_version == '3.8'", "nicegui[plotly]==3.4.1; python_version >= '3.9'", @@ -59,6 +62,7 @@ all = [ "numpy==2.4.0; python_version >= '3.10'", "uvloop; platform_system != 'Windows'", "boto3", + "redis>=5.0.0", ] [dependency-groups] diff --git a/src/scaler/cluster/object_storage_server.py b/src/scaler/cluster/object_storage_server.py index 245a602c0..b2347916d 100644 --- a/src/scaler/cluster/object_storage_server.py +++ b/src/scaler/cluster/object_storage_server.py @@ -1,9 +1,10 @@ import logging import multiprocessing +import os +import select from typing import Optional, Tuple from scaler.config.types.object_storage_server import ObjectStorageAddressConfig -from scaler.object_storage.object_storage_server import ObjectStorageServer from scaler.utility.logging.utility import get_logger_info, setup_logger @@ -14,6 +15,8 @@ def __init__( logging_paths: Tuple[str, ...], logging_level: str, logging_config_file: Optional[str], + backend: str = "memory", + redis_config: Optional[dict] = None, ): multiprocessing.Process.__init__(self, name="ObjectStorageServer") @@ -22,24 +25,91 @@ def __init__( self._logging_config_file = logging_config_file self._object_storage_address = object_storage_address + self._backend = backend + self._redis_config = redis_config - self._server = ObjectStorageServer() + # For Python server: set up ready signaling pipe BEFORE fork + self._ready_read_fd: Optional[int] = None + self._ready_write_fd: Optional[int] = None + self._is_python_server = backend == "redis" + + if self._is_python_server: + # Create pipe for ready signaling (must be before fork) + self._ready_read_fd, self._ready_write_fd = os.pipe() + + # Server instance created lazily in run() for Python server + self._server = None def wait_until_ready(self) -> None: - """Blocks until the object storage server is available to server requests.""" - self._server.wait_until_ready() + """Blocks until the object storage server is available to serve requests.""" + if self._is_python_server: + # Wait for ready signal via pipe + if self._ready_read_fd is not None: + try: + readable, _, _ = select.select([self._ready_read_fd], [], [], 30.0) + if readable: + os.read(self._ready_read_fd, 1) + except (OSError, ValueError): + pass + finally: + try: + os.close(self._ready_read_fd) + except OSError: + pass + self._ready_read_fd = None + else: + # C++ server: need to create instance to wait + if self._server is None: + from scaler.object_storage.object_storage_server import ObjectStorageServer + + self._server = ObjectStorageServer() + self._server.wait_until_ready() def run(self) -> None: setup_logger(self._logging_paths, self._logging_config_file, self._logging_level) - logging.info(f"ObjectStorageServer: start and listen to {self._object_storage_address.to_string()}") + + 
backend_info = f" (backend={self._backend})" if self._backend != "memory" else "" + logging.info( + f"ObjectStorageServer: start and listen to {self._object_storage_address.to_string()}{backend_info}" + ) log_format_str, log_level_str, logging_paths = get_logger_info(logging.getLogger()) - self._server.run( - self._object_storage_address.host, - self._object_storage_address.port, - self._object_storage_address.identity, - log_level_str, - log_format_str, - logging_paths, - ) + if self._is_python_server: + # Close read end in child - we only write from here + if self._ready_read_fd is not None: + try: + os.close(self._ready_read_fd) + except OSError: + pass + self._ready_read_fd = None + + from scaler.object_storage.python_object_storage_server import PythonObjectStorageServer, create_backend + + storage_backend = create_backend("redis", self._redis_config) + server = PythonObjectStorageServer(storage_backend) + + # Pass the write FD to the server for signaling ready + server._ready_write_fd = self._ready_write_fd + + server.run( + self._object_storage_address.host, + self._object_storage_address.port, + self._object_storage_address.identity, + log_level_str, + log_format_str, + logging_paths, + multiprocessing_ready=False, # We handle FDs ourselves + ) + else: + from scaler.object_storage.object_storage_server import ObjectStorageServer + + self._server = ObjectStorageServer() + self._server.run( + self._object_storage_address.host, + self._object_storage_address.port, + self._object_storage_address.identity, + log_level_str, + log_format_str, + logging_paths, + ) diff --git a/src/scaler/config/section/object_storage_server.py b/src/scaler/config/section/object_storage_server.py index e24b9e5d8..4d24e8570 100644 --- a/src/scaler/config/section/object_storage_server.py +++ b/src/scaler/config/section/object_storage_server.py @@ -1,9 +1,22 @@ import dataclasses +from typing import Optional from scaler.config.config_class import ConfigClass from scaler.config.types.object_storage_server import ObjectStorageAddressConfig +@dataclasses.dataclass +class RedisBackendConfig: + """Configuration for Redis object storage backend.""" + + url: str = dataclasses.field( + default="redis://localhost:6379/0", metadata=dict(help="Redis connection URL (e.g., redis://localhost:6379/0)") + ) + max_object_size_mb: int = dataclasses.field(default=100, metadata=dict(help="Maximum object size in MB")) + key_prefix: str = dataclasses.field(default="scaler:obj:", metadata=dict(help="Key prefix for Redis keys")) + connection_pool_size: int = dataclasses.field(default=10, metadata=dict(help="Redis connection pool size")) + + @dataclasses.dataclass class ObjectStorageServerConfig(ConfigClass): object_storage_address: ObjectStorageAddressConfig = dataclasses.field( @@ -11,3 +24,20 @@ class ObjectStorageServerConfig(ConfigClass): positional=True, help="specify the object storage server address to listen to, e.g. tcp://localhost:2345." 
) ) + backend: str = dataclasses.field( + default="memory", + metadata=dict( + help="Storage backend type: 'memory' (default, uses C++ server) or 'redis' (uses Python server with Redis)" + ), + ) + redis: Optional[RedisBackendConfig] = dataclasses.field( + default=None, metadata=dict(help="Redis backend configuration (only used when backend='redis')") + ) + + def __post_init__(self): + if self.backend not in ("memory", "redis"): + raise ValueError(f"backend must be 'memory' or 'redis', got '{self.backend}'") + + if self.backend == "redis" and self.redis is None: + # Use default Redis config if not specified + self.redis = RedisBackendConfig() diff --git a/src/scaler/config/section/scheduler.py b/src/scaler/config/section/scheduler.py index 7c70fec12..85846d832 100644 --- a/src/scaler/config/section/scheduler.py +++ b/src/scaler/config/section/scheduler.py @@ -5,6 +5,7 @@ from scaler.config import defaults from scaler.config.common.logging import LoggingConfig from scaler.config.config_class import ConfigClass +from scaler.config.section.object_storage_server import RedisBackendConfig from scaler.config.types.object_storage_server import ObjectStorageAddressConfig from scaler.config.types.zmq import ZMQConfig from scaler.scheduler.allocate_policy.allocate_policy import AllocatePolicy @@ -27,6 +28,17 @@ class SchedulerConfig(ConfigClass): "then object storage address is tcp://localhost:2346", ), ) + object_storage_backend: str = dataclasses.field( + default="memory", + metadata=dict( + short="-osb", + help="object storage backend: 'memory' (default, high-performance C++) or 'redis' (distributed, Python)", + ), + ) + object_storage_redis: Optional[RedisBackendConfig] = dataclasses.field( + default=None, + metadata=dict(help="Redis configuration for object storage (only used when object_storage_backend='redis')"), + ) monitor_address: Optional[ZMQConfig] = dataclasses.field( default=None, metadata=dict( @@ -124,3 +136,7 @@ def __post_init__(self): raise ValueError(f"adapter_webhook_urls contains url '{adapter_webhook_url}' which is not a valid URL.") if self.worker_io_threads <= 0: raise ValueError("worker_io_threads must be a positive integer.") + if self.object_storage_backend not in ("memory", "redis"): + raise ValueError(f"object_storage_backend must be 'memory' or 'redis', got '{self.object_storage_backend}'") + if self.object_storage_backend == "redis" and self.object_storage_redis is None: + self.object_storage_redis = RedisBackendConfig() diff --git a/src/scaler/entry_points/object_storage_server.py b/src/scaler/entry_points/object_storage_server.py index cbdf2b3db..f924d0e5e 100644 --- a/src/scaler/entry_points/object_storage_server.py +++ b/src/scaler/entry_points/object_storage_server.py @@ -2,7 +2,6 @@ import sys from scaler.config.section.object_storage_server import ObjectStorageServerConfig -from scaler.object_storage.object_storage_server import ObjectStorageServer from scaler.utility.logging.utility import get_logger_info, setup_logger @@ -14,13 +13,43 @@ def main(): log_format_str, log_level_str, log_paths = get_logger_info(logging.getLogger()) try: - ObjectStorageServer().run( - oss_config.object_storage_address.host, - oss_config.object_storage_address.port, - oss_config.object_storage_address.identity, - log_level_str, - log_format_str, - log_paths, - ) + # Use Python server for Redis backend, C++ server for memory backend + if oss_config.backend == "redis": + from scaler.object_storage.python_object_storage_server import PythonObjectStorageServer, create_backend + + # 
Build Redis config dict from config object + redis_config = None + if oss_config.redis is not None: + redis_config = { + "url": oss_config.redis.url, + "max_object_size_mb": oss_config.redis.max_object_size_mb, + "key_prefix": oss_config.redis.key_prefix, + "connection_pool_size": oss_config.redis.connection_pool_size, + } + + backend = create_backend("redis", redis_config) + server = PythonObjectStorageServer(backend) + + logging.info(f"Using Python object storage server with Redis backend") + server.run( + oss_config.object_storage_address.host, + oss_config.object_storage_address.port, + oss_config.object_storage_address.identity, + log_level_str, + log_format_str, + log_paths, + ) + else: + # Default: use the high-performance C++ server + from scaler.object_storage.object_storage_server import ObjectStorageServer + + ObjectStorageServer().run( + oss_config.object_storage_address.host, + oss_config.object_storage_address.port, + oss_config.object_storage_address.identity, + log_level_str, + log_format_str, + log_paths, + ) except KeyboardInterrupt: sys.exit(0) diff --git a/src/scaler/entry_points/scheduler.py b/src/scaler/entry_points/scheduler.py index b3f9c5bba..ffbf01382 100644 --- a/src/scaler/entry_points/scheduler.py +++ b/src/scaler/entry_points/scheduler.py @@ -15,11 +15,24 @@ def main(): object_storage_address = ObjectStorageAddressConfig( host=scheduler_config.scheduler_address.host, port=scheduler_config.scheduler_address.port + 1 ) + + # Build Redis config dict if using Redis backend + redis_config = None + if scheduler_config.object_storage_backend == "redis" and scheduler_config.object_storage_redis is not None: + redis_config = { + "url": scheduler_config.object_storage_redis.url, + "max_object_size_mb": scheduler_config.object_storage_redis.max_object_size_mb, + "key_prefix": scheduler_config.object_storage_redis.key_prefix, + "connection_pool_size": scheduler_config.object_storage_redis.connection_pool_size, + } + object_storage = ObjectStorageServerProcess( object_storage_address=object_storage_address, logging_paths=scheduler_config.logging_config.paths, logging_config_file=scheduler_config.logging_config.config_file, logging_level=scheduler_config.logging_config.level, + backend=scheduler_config.object_storage_backend, + redis_config=redis_config, ) object_storage.start() object_storage.wait_until_ready() # object storage should be ready before starting the cluster diff --git a/src/scaler/object_storage/backend.py b/src/scaler/object_storage/backend.py new file mode 100644 index 000000000..859d7bba3 --- /dev/null +++ b/src/scaler/object_storage/backend.py @@ -0,0 +1,93 @@ +""" +Abstract backend interface for object storage. + +This module defines the interface that all object storage backends must implement, +allowing for pluggable storage strategies (in-memory, Redis, disk-based, etc.). +""" + +from abc import ABC, abstractmethod +from typing import Optional, Tuple + + +class ObjectStorageBackend(ABC): + """Abstract base class for object storage backends.""" + + @abstractmethod + def put(self, object_id: bytes, data: bytes, metadata: bytes) -> bool: + """ + Store an object in the backend. + + Args: + object_id: Unique identifier for the object + data: The object's data payload + metadata: The object's metadata + + Returns: + True if successful, False otherwise + """ + pass + + @abstractmethod + def get(self, object_id: bytes) -> Optional[Tuple[bytes, bytes]]: + """ + Retrieve an object from the backend. 
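+
+        A minimal illustrative call (the object ID here is hypothetical)::
+
+            result = backend.get(b"some-object-id")
+            if result is not None:
+                data, metadata = result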
+ + Args: + object_id: Unique identifier for the object + + Returns: + Tuple of (data, metadata) if found, None otherwise + """ + pass + + @abstractmethod + def delete(self, object_id: bytes) -> bool: + """ + Delete an object from the backend. + + Args: + object_id: Unique identifier for the object + + Returns: + True if the object existed and was deleted, False otherwise + """ + pass + + @abstractmethod + def exists(self, object_id: bytes) -> bool: + """ + Check if an object exists in the backend. + + Args: + object_id: Unique identifier for the object + + Returns: + True if the object exists, False otherwise + """ + pass + + @abstractmethod + def size(self) -> int: + """ + Get the total size of all objects in bytes. + + Returns: + Total bytes stored in the backend + """ + pass + + @abstractmethod + def clear(self) -> None: + """ + Clear all objects from the backend. + Used primarily for testing and cleanup. + """ + pass + + def close(self) -> None: + """ + Close any resources held by the backend. + + Optional - implementations may override if they have resources to clean up. + """ + pass diff --git a/src/scaler/object_storage/memory_backend.py b/src/scaler/object_storage/memory_backend.py new file mode 100644 index 000000000..85b909083 --- /dev/null +++ b/src/scaler/object_storage/memory_backend.py @@ -0,0 +1,86 @@ +""" +Simple in-memory object storage backend. + +This provides the default backend for Scaler's object storage server when +Redis is not configured. +""" + +import logging +import threading +from typing import Dict, Optional, Tuple + +from scaler.object_storage.backend import ObjectStorageBackend + +logger = logging.getLogger(__name__) + + +class MemoryObjectStorageBackend(ObjectStorageBackend): + """ + Thread-safe in-memory object storage backend. + + This is a simple Python implementation for use with the Python + object storage server. For maximum performance, use the default + C++ object storage server which has its own optimized in-memory storage. + """ + + def __init__(self): + self._lock = threading.Lock() + self._objects: Dict[bytes, Tuple[bytes, bytes]] = {} # object_id -> (data, metadata) + self._total_size = 0 + + def put(self, object_id: bytes, data: bytes, metadata: bytes) -> bool: + """Store an object in memory.""" + with self._lock: + # Track size changes for updates + old_size = 0 + if object_id in self._objects: + old_data, old_meta = self._objects[object_id] + old_size = len(old_data) + len(old_meta) + + self._objects[object_id] = (data, metadata) + self._total_size += len(data) + len(metadata) - old_size + + logger.debug(f"Stored object {object_id.hex()[:16]}... ({len(data)} bytes)") + return True + + def get(self, object_id: bytes) -> Optional[Tuple[bytes, bytes]]: + """Retrieve an object from memory.""" + with self._lock: + result = self._objects.get(object_id) + if result: + logger.debug(f"Retrieved object {object_id.hex()[:16]}... 
({len(result[0])} bytes)") + return result + + def delete(self, object_id: bytes) -> bool: + """Delete an object from memory.""" + with self._lock: + if object_id in self._objects: + data, metadata = self._objects.pop(object_id) + self._total_size -= len(data) + len(metadata) + logger.debug(f"Deleted object {object_id.hex()[:16]}...") + return True + return False + + def exists(self, object_id: bytes) -> bool: + """Check if an object exists.""" + with self._lock: + return object_id in self._objects + + def size(self) -> int: + """Get total size of all objects in bytes.""" + with self._lock: + return self._total_size + + def clear(self) -> None: + """Clear all objects.""" + with self._lock: + count = len(self._objects) + self._objects.clear() + self._total_size = 0 + if count: + logger.info(f"Cleared {count} objects from memory") + + def count(self) -> int: + """Get number of objects stored.""" + with self._lock: + return len(self._objects) diff --git a/src/scaler/object_storage/python_object_storage_server.py b/src/scaler/object_storage/python_object_storage_server.py new file mode 100644 index 000000000..16d5948b6 --- /dev/null +++ b/src/scaler/object_storage/python_object_storage_server.py @@ -0,0 +1,443 @@ +""" +Python-based Object Storage Server. + +This server implements the same protocol as the C++ ObjectStorageServer, +but allows pluggable backends (memory, Redis, etc.). + +Use this when you need Redis backend support. For maximum performance +with in-memory storage, use the default C++ server. +""" + +import asyncio +import logging +import signal +import struct +import sys +from typing import Dict, List, Optional, Tuple + +from scaler.object_storage.backend import ObjectStorageBackend +from scaler.object_storage.memory_backend import MemoryObjectStorageBackend +from scaler.protocol.capnp._python import _object_storage +from scaler.protocol.python.object_storage import ( + ObjectRequestHeader, + ObjectResponseHeader, + from_capnp_object_id, + to_capnp_object_id, +) +from scaler.utility.identifiers import ObjectID + +logger = logging.getLogger(__name__) + + +class PythonObjectStorageServer: + """ + Python implementation of the Scaler Object Storage Server. + + This server uses asyncio for concurrency and supports pluggable + storage backends via the ObjectStorageBackend interface. + """ + + def __init__(self, backend: Optional[ObjectStorageBackend] = None): + """ + Initialize the server with a storage backend. + + Args: + backend: Storage backend to use. Defaults to MemoryObjectStorageBackend. 
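+
+        Illustrative construction of a Redis-backed server (assumes a local
+        Redis instance is reachable)::
+
+            backend = create_backend("redis", {"url": "redis://localhost:6379/0"})
+            server = PythonObjectStorageServer(backend)
+            server.run("127.0.0.1", 2346)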
+ """ + self._backend = backend or MemoryObjectStorageBackend() + self._server: Optional[asyncio.Server] = None + self._shutdown_event = asyncio.Event() + + # For cross-process ready signaling (like C++ server uses pipe fds) + self._ready_read_fd: Optional[int] = None + self._ready_write_fd: Optional[int] = None + + # Track pending GET requests (waiting for objects that don't exist yet) + # Maps object_id -> list of (writer, request_header) tuples + self._pending_gets: Dict[bytes, List[Tuple[asyncio.StreamWriter, ObjectRequestHeader]]] = {} + + # Lock for thread-safe access to pending gets + self._pending_lock = asyncio.Lock() + + def _init_ready_fds(self): + """Initialize pipe for ready signaling.""" + import os + + self._ready_read_fd, self._ready_write_fd = os.pipe() + + def _set_ready(self): + """Signal that server is ready.""" + import os + + if self._ready_write_fd is not None: + os.write(self._ready_write_fd, b"1") + os.close(self._ready_write_fd) + self._ready_write_fd = None + + def _close_ready_fds(self): + """Close ready signaling fds.""" + import os + + if self._ready_read_fd is not None: + try: + os.close(self._ready_read_fd) + except OSError: + pass + self._ready_read_fd = None + if self._ready_write_fd is not None: + try: + os.close(self._ready_write_fd) + except OSError: + pass + self._ready_write_fd = None + + def wait_until_ready(self) -> None: + """Block until the server is ready to accept connections.""" + import os + import select + + if self._ready_read_fd is None: + # Not using multiprocessing, just return + return + + # Wait for ready signal with timeout + try: + readable, _, _ = select.select([self._ready_read_fd], [], [], 30.0) + if readable: + os.read(self._ready_read_fd, 1) + except (OSError, ValueError): + pass + finally: + try: + os.close(self._ready_read_fd) + except OSError: + pass + self._ready_read_fd = None + + async def _handle_client(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter): + """Handle a single client connection.""" + peer = writer.get_extra_info("peername") + logger.debug(f"New client connection from {peer}") + + try: + # Exchange identities (YMQ-style framing) + # Server sends identity FIRST, then reads client identity + server_identity = b"PythonObjectStorageServer" + await self._write_framed_message(writer, server_identity) + + # Read client identity + client_identity = await self._read_framed_message(reader) + logger.debug(f"Client identity: {client_identity[:50]}...") + + # Process requests + while not self._shutdown_event.is_set(): + try: + request = await self._read_request(reader) + if request is None: + break + + header, payload = request + await self._process_request(writer, header, payload) + + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"Error processing request from {peer}: {e}") + break + + except Exception as e: + logger.debug(f"Client {peer} disconnected: {e}") + finally: + # Clean up pending requests for this client + await self._cleanup_client_pending_requests(writer) + writer.close() + try: + await writer.wait_closed() + except Exception: + pass + + async def _cleanup_client_pending_requests(self, writer: asyncio.StreamWriter): + """Remove pending requests for a disconnected client.""" + async with self._pending_lock: + for object_id in list(self._pending_gets.keys()): + self._pending_gets[object_id] = [(w, h) for w, h in self._pending_gets[object_id] if w != writer] + if not self._pending_gets[object_id]: + del self._pending_gets[object_id] + + async def 
_read_framed_message(self, reader: asyncio.StreamReader) -> bytes:
+        """Read a length-prefixed message."""
+        length_bytes = await reader.readexactly(8)
+        (length,) = struct.unpack("<Q", length_bytes)  # 8-byte little-endian length prefix
+        return await reader.readexactly(length)
+
+    async def _write_framed_message(self, writer: asyncio.StreamWriter, message: bytes):
+        """Write a length-prefixed message."""
+        writer.write(struct.pack("<Q", len(message)))
+        writer.write(message)
+        await writer.drain()
+
+    async def _read_request(self, reader: asyncio.StreamReader) -> Optional[Tuple[ObjectRequestHeader, bytes]]:
+        """Read a request (header + optional payload)."""
+        try:
+            # Read header
+            header_bytes = await self._read_framed_message(reader)
+            if not header_bytes:
+                return None
+
+            with _object_storage.ObjectRequestHeader.from_bytes(bytes(header_bytes)) as msg:
+                header = ObjectRequestHeader(msg)
+
+            # Read payload if present (for SET and DUPLICATE requests)
+            payload = b""
+            request_type = header.request_type
+            if request_type in (
+                ObjectRequestHeader.ObjectRequestType.SetObject,
+                ObjectRequestHeader.ObjectRequestType.DuplicateObjectID,
+            ):
+                payload = await self._read_framed_message(reader)
+
+            return header, payload
+
+        except asyncio.IncompleteReadError:
+            return None
+
+    async def _write_response(
+        self,
+        writer: asyncio.StreamWriter,
+        object_id: ObjectID,
+        response_id: int,
+        response_type: ObjectResponseHeader.ObjectResponseType,
+        payload: bytes = b"",
+    ):
+        """Write a response (header + optional payload)."""
+        header = ObjectResponseHeader.new_msg(object_id, len(payload), response_id, response_type)
+        header_bytes = header.get_message().to_bytes()
+
+        await self._write_framed_message(writer, header_bytes)
+
+        if payload:
+            await self._write_framed_message(writer, payload)
+
+    async def _process_request(self, writer: asyncio.StreamWriter, header: ObjectRequestHeader, payload: bytes):
+        """Process a single request."""
+        request_type = header.request_type
+        object_id = header.object_id
+
+        if request_type == ObjectRequestHeader.ObjectRequestType.SetObject:
+            await self._handle_set(writer, header, payload)
+        elif request_type == ObjectRequestHeader.ObjectRequestType.GetObject:
+            await self._handle_get(writer, header)
+        elif request_type == ObjectRequestHeader.ObjectRequestType.DeleteObject:
+            await self._handle_delete(writer, header)
+        elif request_type == ObjectRequestHeader.ObjectRequestType.DuplicateObjectID:
+            await self._handle_duplicate(writer, header, payload)
+        else:
+            logger.warning(f"Unknown request type: {request_type}")
+
+    async def _handle_set(self, writer: asyncio.StreamWriter, header: ObjectRequestHeader, payload: bytes):
+        """Handle SET request."""
+        object_id = header.object_id
+        object_id_bytes = self._object_id_to_bytes(object_id)
+
+        # Store in backend (metadata is empty for now - Scaler doesn't use it)
+        self._backend.put(object_id_bytes, payload, b"")
+
+        # Send response
+        await self._write_response(writer, object_id, header.request_id, ObjectResponseHeader.ObjectResponseType.SetOK)
+
+        # Check if there are pending GET requests for this object
+        await self._fulfill_pending_gets(object_id, object_id_bytes, payload)
+
+    async def _fulfill_pending_gets(self, object_id: ObjectID, object_id_bytes: bytes, payload: bytes):
+        """Send responses to clients waiting for this object."""
+        async with self._pending_lock:
+            pending = self._pending_gets.pop(object_id_bytes, [])
+
+        for pending_writer, pending_header in pending:
+            try:
+                # Respect max payload length
+                max_len = pending_header.payload_length
+                response_payload = payload[:max_len] if max_len < len(payload) else payload
+
+                await self._write_response(
+                    pending_writer,
+                    object_id,
+                    pending_header.request_id,
+                    ObjectResponseHeader.ObjectResponseType.GetOK,
+                    response_payload,
+                )
+            except Exception as e:
+                logger.debug(f"Failed to send pending GET response: {e}")
+
+    async def 
_handle_get(self, writer: asyncio.StreamWriter, header: ObjectRequestHeader): + """Handle GET request.""" + object_id = header.object_id + object_id_bytes = self._object_id_to_bytes(object_id) + + result = self._backend.get(object_id_bytes) + + if result is not None: + data, _ = result + # Respect max payload length + max_len = header.payload_length + response_payload = data[:max_len] if max_len < len(data) else data + + await self._write_response( + writer, object_id, header.request_id, ObjectResponseHeader.ObjectResponseType.GetOK, response_payload + ) + else: + # Object doesn't exist yet - queue the request + async with self._pending_lock: + if object_id_bytes not in self._pending_gets: + self._pending_gets[object_id_bytes] = [] + self._pending_gets[object_id_bytes].append((writer, header)) + logger.debug(f"Queued GET request for object {object_id_bytes.hex()[:16]}...") + + async def _handle_delete(self, writer: asyncio.StreamWriter, header: ObjectRequestHeader): + """Handle DELETE request.""" + object_id = header.object_id + object_id_bytes = self._object_id_to_bytes(object_id) + + existed = self._backend.delete(object_id_bytes) + + response_type = ( + ObjectResponseHeader.ObjectResponseType.DelOK + if existed + else ObjectResponseHeader.ObjectResponseType.DelNotExists + ) + + await self._write_response(writer, object_id, header.request_id, response_type) + + async def _handle_duplicate(self, writer: asyncio.StreamWriter, header: ObjectRequestHeader, payload: bytes): + """Handle DUPLICATE request (link new object ID to existing object's content).""" + new_object_id = header.object_id + new_object_id_bytes = self._object_id_to_bytes(new_object_id) + + # Parse original object ID from payload + with _object_storage.ObjectID.from_bytes(bytes(payload)) as msg: + original_object_id = from_capnp_object_id(msg) + original_object_id_bytes = self._object_id_to_bytes(original_object_id) + + result = self._backend.get(original_object_id_bytes) + + if result is not None: + data, metadata = result + # Create copy with new ID + self._backend.put(new_object_id_bytes, data, metadata) + + await self._write_response( + writer, new_object_id, header.request_id, ObjectResponseHeader.ObjectResponseType.DuplicateOK + ) + else: + # Original doesn't exist yet - queue the request + async with self._pending_lock: + if original_object_id_bytes not in self._pending_gets: + self._pending_gets[original_object_id_bytes] = [] + # Store as a special duplicate request + self._pending_gets[original_object_id_bytes].append( + (writer, header) # We'll need special handling for this + ) + logger.debug(f"Queued DUPLICATE request for object {original_object_id_bytes.hex()[:16]}...") + + @staticmethod + def _object_id_to_bytes(object_id: ObjectID) -> bytes: + """Convert ObjectID to bytes for backend storage. + + ObjectID is already a 32-byte bytes subclass, so just return it directly. 
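+
+        Note: when the Redis backend is in use, this raw value is hex-encoded
+        by RedisObjectStorageBackend._make_key to form the Redis key name.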
+ """ + return bytes(object_id) + + async def _run_server(self, host: str, port: int): + """Main server loop.""" + self._server = await asyncio.start_server(self._handle_client, host, port) + + addr = self._server.sockets[0].getsockname() + logger.info(f"PythonObjectStorageServer listening on {addr[0]}:{addr[1]}") + + # Signal ready via pipe for cross-process signaling + self._set_ready() + + async with self._server: + await self._server.serve_forever() + + def run( + self, + host: str, + port: int, + identity: str = "PythonObjectStorageServer", + log_level: str = "INFO", + log_format: str = "%(levelname)s: %(message)s", + log_paths: Tuple[str, ...] = ("/dev/stdout",), + multiprocessing_ready: bool = False, + ): + """ + Run the object storage server. + + This method blocks until the server is shut down. + + Args: + host: Host address to bind to + port: Port to listen on + identity: Server identity string + log_level: Logging level + log_format: Logging format string + log_paths: Paths to log to + multiprocessing_ready: If True, use pipe-based ready signaling for multiprocessing + """ + # Initialize ready signaling for multiprocessing + if multiprocessing_ready: + self._init_ready_fds() + + # Set up signal handlers + def handle_signal(signum, frame): + logger.info(f"Received signal {signum}, shutting down...") + self._shutdown_event.set() + if self._server: + self._server.close() + + signal.signal(signal.SIGTERM, handle_signal) + signal.signal(signal.SIGINT, handle_signal) + + # Run the async server + try: + asyncio.run(self._run_server(host, port)) + except KeyboardInterrupt: + pass + finally: + self._backend.close() + self._close_ready_fds() + logger.info("PythonObjectStorageServer stopped") + + +def create_backend(backend_type: str, redis_config: Optional[dict] = None) -> ObjectStorageBackend: + """ + Create a storage backend based on configuration. + + Args: + backend_type: "memory" or "redis" + redis_config: Configuration dict for Redis backend (url, max_object_size_mb, key_prefix, connection_pool_size) + + Returns: + ObjectStorageBackend instance + """ + if backend_type == "memory": + return MemoryObjectStorageBackend() + elif backend_type == "redis": + try: + from scaler.object_storage.redis_backend import RedisObjectStorageBackend + except ImportError: + raise ImportError( + "Redis backend requires the 'redis' package. " "Install with: pip install opengris-scaler[redis]" + ) + + config = redis_config or {} + return RedisObjectStorageBackend( + redis_url=config.get("url", "redis://localhost:6379/0"), + max_object_size_mb=config.get("max_object_size_mb", 100), + key_prefix=config.get("key_prefix", "scaler:obj:"), + connection_pool_size=config.get("connection_pool_size", 10), + ) + else: + raise ValueError(f"Unknown backend type: {backend_type}. Use 'memory' or 'redis'.") diff --git a/src/scaler/object_storage/redis_backend.py b/src/scaler/object_storage/redis_backend.py new file mode 100644 index 000000000..e379af151 --- /dev/null +++ b/src/scaler/object_storage/redis_backend.py @@ -0,0 +1,307 @@ +""" +Redis-based object storage backend. + +This module provides an optional Redis backend for distributed object storage. 
+
+Redis is useful for:
+- Sharing objects across multiple nodes
+- Persisting objects across scheduler restarts
+- Small-to-medium sized objects (< 100MB recommended)
+
+Note: Redis must be installed separately: pip install "opengris-scaler[redis]"
+"""
+
+import logging
+import struct
+from typing import Optional, Tuple
+
+from scaler.object_storage.backend import ObjectStorageBackend
+
+logger = logging.getLogger(__name__)
+
+
+class RedisObjectStorageBackend(ObjectStorageBackend):
+    """
+    Redis-based object storage backend.
+
+    This backend stores objects in Redis, which provides:
+    - Distributed access across multiple nodes
+    - Optional persistence to disk
+    - Automatic memory management
+
+    Limitations:
+    - Redis has a default max value size of 512MB
+    - All data is stored in RAM (unless Redis persistence is configured)
+    - Network latency for remote Redis servers
+    """
+
+    def __init__(
+        self,
+        redis_url: str = "redis://localhost:6379/0",
+        max_object_size_mb: int = 100,
+        key_prefix: str = "scaler:obj:",
+        connection_pool_size: int = 10,
+    ):
+        """
+        Initialize Redis backend.
+
+        Args:
+            redis_url: Redis connection URL (e.g., "redis://localhost:6379/0")
+            max_object_size_mb: Maximum object size in MB (default: 100)
+            key_prefix: Prefix for all Redis keys (default: "scaler:obj:")
+            connection_pool_size: Size of connection pool (default: 10)
+
+        Raises:
+            ImportError: If redis package is not installed
+            ConnectionError: If unable to connect to Redis
+        """
+        try:
+            import redis
+        except ImportError:
+            raise ImportError(
+                "Redis backend requires the 'redis' package. Install it with: pip install opengris-scaler[redis]"
+            )
+
+        self.max_object_size = max_object_size_mb * 1024 * 1024  # Convert to bytes
+        self.key_prefix = key_prefix
+
+        # Create Redis client with connection pooling
+        self.client = redis.Redis.from_url(
+            redis_url,
+            decode_responses=False,  # we work with raw bytes, not strings
+            max_connections=connection_pool_size,
+        )
+
+        # Test connection
+        try:
+            self.client.ping()
+            logger.info(f"RedisObjectStorageBackend: Connected to Redis at {redis_url}")
+        except redis.ConnectionError as e:
+            raise ConnectionError(f"Failed to connect to Redis at {redis_url}: {e}")
+
+        # Track total size (approximate, since Redis doesn't provide this directly)
+        # Use _meta: prefix to distinguish from object keys
+        self._size_key = f"{key_prefix}_meta:total_size"
+        if not self.client.exists(self._size_key):
+            self.client.set(self._size_key, 0)
+
+    def _make_key(self, object_id: bytes) -> bytes:
+        """Convert object_id to Redis key."""
+        return (self.key_prefix + object_id.hex()).encode()
+
+    def _serialize(self, data: bytes, metadata: bytes) -> bytes:
+        """
+        Serialize data and metadata into a single byte string.
+
+        Format: [metadata_len (4 bytes)] [metadata] [data]
+        """
+        metadata_len = struct.pack("<I", len(metadata))  # 4-byte little-endian length
+        return metadata_len + metadata + data
+
+    def _deserialize(self, value: bytes) -> Tuple[bytes, bytes]:
+        """
+        Deserialize byte string into data and metadata.
+
+        Returns:
+            Tuple of (data, metadata)
+        """
+        metadata_len = struct.unpack("<I", value[:4])[0]
+        metadata = value[4 : 4 + metadata_len]
+        data = value[4 + metadata_len :]
+        return data, metadata
+
+    def put(self, object_id: bytes, data: bytes, metadata: bytes) -> bool:
+        """
+        Store an object in Redis.
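+
+        Note: oversized objects are rejected by returning False rather than
+        raising, so callers should check the return value.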
+ + Args: + object_id: Unique identifier for the object + data: The object's data payload + metadata: The object's metadata + + Returns: + True if successful, False if object is too large + """ + total_size = len(data) + len(metadata) + + # Check size limit + if total_size > self.max_object_size: + logger.warning( + f"Object {object_id.hex()} size ({total_size} bytes) exceeds " + f"max_object_size ({self.max_object_size} bytes)" + ) + return False + + key = self._make_key(object_id) + value = self._serialize(data, metadata) + + try: + # Use pipeline for atomicity and fewer round trips + pipe = self.client.pipeline() + + # Get old value size if exists (for accurate size tracking) + old_value = self.client.get(key) + old_size = len(old_value) if old_value else 0 + + # Store the object and update size tracking atomically + pipe.set(key, value) + size_delta = len(value) - old_size + pipe.incrby(self._size_key, size_delta) + pipe.execute() + + logger.debug(f"Stored object {object_id.hex()} ({total_size} bytes) in Redis") + return True + + except Exception as e: + logger.error(f"Failed to store object {object_id.hex()} in Redis: {e}") + return False + + def get(self, object_id: bytes) -> Optional[Tuple[bytes, bytes]]: + """ + Retrieve an object from Redis. + + Args: + object_id: Unique identifier for the object + + Returns: + Tuple of (data, metadata) if found, None otherwise + """ + key = self._make_key(object_id) + + try: + value = self.client.get(key) + if value is None: + return None + + data, metadata = self._deserialize(value) + logger.debug(f"Retrieved object {object_id.hex()} ({len(data)} bytes) from Redis") + return data, metadata + + except Exception as e: + logger.error(f"Failed to retrieve object {object_id.hex()} from Redis: {e}") + return None + + def delete(self, object_id: bytes) -> bool: + """ + Delete an object from Redis. + + Args: + object_id: Unique identifier for the object + + Returns: + True if the object existed and was deleted, False otherwise + """ + key = self._make_key(object_id) + + try: + # Get size before deletion for tracking + value = self.client.get(key) + if value is None: + return False + + # Delete the object + deleted = self.client.delete(key) > 0 + + if deleted: + # Update size tracking + self.client.decrby(self._size_key, len(value)) + logger.debug(f"Deleted object {object_id.hex()} from Redis") + + return deleted + + except Exception as e: + logger.error(f"Failed to delete object {object_id.hex()} from Redis: {e}") + return False + + def exists(self, object_id: bytes) -> bool: + """ + Check if an object exists in Redis. + + Args: + object_id: Unique identifier for the object + + Returns: + True if the object exists, False otherwise + """ + key = self._make_key(object_id) + try: + return self.client.exists(key) > 0 + except Exception as e: + logger.error(f"Failed to check existence of object {object_id.hex()} in Redis: {e}") + return False + + def size(self) -> int: + """ + Get the approximate total size of all objects in bytes. + + Note: This is an approximation based on tracking. Redis overhead + (keys, internal structures) is not included. + + Returns: + Total bytes stored in the backend + """ + try: + size_bytes = self.client.get(self._size_key) + if size_bytes is None: + return 0 + return int(size_bytes) + except Exception as e: + logger.error(f"Failed to get total size from Redis: {e}") + return 0 + + def clear(self) -> None: + """ + Clear all objects from the backend. + + Warning: This deletes ALL keys matching the key_prefix pattern. 
+ Use with caution in production! + """ + try: + # Find all keys with our prefix + pattern = f"{self.key_prefix}*".encode() + keys = list(self.client.scan_iter(pattern, count=1000)) + + if keys: + # Delete all keys in batches + self.client.delete(*keys) + logger.info(f"Cleared {len(keys)} objects from Redis") + + # Reset size counter + self.client.set(self._size_key, 0) + + except Exception as e: + logger.error(f"Failed to clear objects from Redis: {e}") + + def get_info(self) -> dict: + """ + Get information about the Redis backend. + + Returns: + Dictionary with backend statistics + """ + try: + info = self.client.info("memory") + pattern = f"{self.key_prefix}*".encode() + # Count only object keys, excluding metadata keys + count = sum(1 for key in self.client.scan_iter(pattern, count=1000) if b"_meta:" not in key) + + return { + "type": "redis", + "object_count": count, + "total_size_bytes": self.size(), + "max_object_size_mb": self.max_object_size // (1024 * 1024), + "redis_used_memory": info.get("used_memory", 0), + "redis_used_memory_human": info.get("used_memory_human", "unknown"), + } + except Exception as e: + logger.error(f"Failed to get Redis info: {e}") + return {"type": "redis", "error": str(e)} + + def close(self) -> None: + """ + Close the Redis connection. + + Should be called when the backend is no longer needed. + """ + try: + self.client.close() + logger.debug("Closed Redis connection") + except Exception as e: + logger.error(f"Failed to close Redis connection: {e}") diff --git a/src/scaler/scheduler/controllers/scaling_controller.py b/src/scaler/scheduler/controllers/scaling_controller.py new file mode 100644 index 000000000..d3c239131 --- /dev/null +++ b/src/scaler/scheduler/controllers/scaling_controller.py @@ -0,0 +1,95 @@ +import logging +import math +from typing import Dict, List + +import aiohttp +from aiohttp import web + +from scaler.protocol.python.message import InformationSnapshot +from scaler.protocol.python.status import ScalingManagerStatus +from scaler.scheduler.controllers.scaling_policies.mixins import ScalingController +from scaler.scheduler.controllers.scaling_policies.types import WorkerGroupID +from scaler.utility.identifiers import WorkerID + + +class VanillaScalingController(ScalingController): + def __init__(self, adapter_webhook_url: str): + self._adapter_webhook_url = adapter_webhook_url + self._lower_task_ratio = 1 + self._upper_task_ratio = 10 + + self._worker_groups: Dict[WorkerGroupID, List[WorkerID]] = {} + + def get_status(self): + return ScalingManagerStatus.new_msg(worker_groups=self._worker_groups) + + async def on_snapshot(self, information_snapshot: InformationSnapshot): + if not information_snapshot.workers: + if information_snapshot.tasks: + await self._start_worker_group() + return + + task_ratio = len(information_snapshot.tasks) / len(information_snapshot.workers) + if task_ratio > self._upper_task_ratio: + await self._start_worker_group() + elif task_ratio < self._lower_task_ratio: + worker_group_task_counts = { + worker_group_id: sum( + information_snapshot.workers[worker_id].queued_tasks + for worker_id in worker_ids + if worker_id in information_snapshot.workers + ) + for worker_group_id, worker_ids in self._worker_groups.items() + } + if not worker_group_task_counts: + logging.warning( + "No worker groups available to shut down. There might be statically provisioned workers." 
+ ) + return + + worker_group_id = min(worker_group_task_counts, key=worker_group_task_counts.get) + await self._shutdown_worker_group(worker_group_id) + + async def _start_worker_group(self): + response, status = await self._make_request({"action": "get_worker_adapter_info"}) + if status != web.HTTPOk.status_code: + logging.warning("Failed to get worker adapter info.") + return + + if len(self._worker_groups) >= response.get("max_worker_groups", math.inf): + return + + response, status = await self._make_request({"action": "start_worker_group"}) + if status == web.HTTPTooManyRequests.status_code: + logging.warning("Capacity exceeded, cannot start new worker group.") + return + if status == web.HTTPInternalServerError.status_code: + logging.error(f"Failed to start worker group: {response.get('error', 'Unknown error')}") + return + + worker_group_id = response["worker_group_id"].encode() + self._worker_groups[worker_group_id] = [WorkerID(worker_id.encode()) for worker_id in response["worker_ids"]] + logging.info(f"Started worker group: {worker_group_id.decode()}") + + async def _shutdown_worker_group(self, worker_group_id: WorkerGroupID): + if worker_group_id not in self._worker_groups: + logging.error(f"Worker group with ID {worker_group_id.decode()} does not exist.") + return + + response, status = await self._make_request( + {"action": "shutdown_worker_group", "worker_group_id": worker_group_id.decode()} + ) + if status == web.HTTPNotFound.status_code: + logging.error(f"Worker group with ID {worker_group_id.decode()} not found in adapter.") + return + if status == web.HTTPInternalServerError.status_code: + logging.error(f"Failed to shutdown worker group: {response.get('error', 'Unknown error')}") + return + + self._worker_groups.pop(worker_group_id) + logging.info(f"Shutdown worker group: {worker_group_id.decode()}") + + async def _make_request(self, payload): + async with aiohttp.ClientSession() as session: + async with session.post(self._adapter_webhook_url, json=payload) as response: + return await response.json(), response.status diff --git a/tests/object_storage/test_redis_backend.py b/tests/object_storage/test_redis_backend.py new file mode 100644 index 000000000..2f0b3d9a7 --- /dev/null +++ b/tests/object_storage/test_redis_backend.py @@ -0,0 +1,345 @@ +""" +Unit tests for Redis object storage backend. + +These tests require a running Redis server. They can be run with: + python -m unittest tests.object_storage.test_redis_backend + +To skip these tests if Redis is not available, they will automatically skip. 
+""" + +import unittest + +from scaler.utility.logging.utility import setup_logger +from tests.utility.utility import logging_test_name + +# Try to import redis and the backend +try: + import redis + + from scaler.object_storage.redis_backend import RedisObjectStorageBackend + + REDIS_AVAILABLE = True +except ImportError: + REDIS_AVAILABLE = False + + +@unittest.skipIf(not REDIS_AVAILABLE, "Redis package not installed") +class TestRedisBackend(unittest.TestCase): + """Tests for RedisObjectStorageBackend.""" + + @classmethod + def setUpClass(cls): + """Set up test class.""" + setup_logger() + if not REDIS_AVAILABLE: + return + + # Test Redis connection + cls.redis_url = "redis://localhost:6379/15" # Use database 15 for testing + try: + test_client = redis.Redis.from_url(cls.redis_url) + test_client.ping() + cls.redis_available = True + except (redis.ConnectionError, Exception): + cls.redis_available = False + + def setUp(self): + """Set up each test.""" + logging_test_name(self) + + if not self.redis_available: + self.skipTest("Redis server not available") + + # Create backend + self.backend = RedisObjectStorageBackend( + redis_url=self.redis_url, max_object_size_mb=10, key_prefix="test:scaler:obj:" + ) + # Clear any existing test data + self.backend.clear() + + def tearDown(self): + """Clean up after each test.""" + if hasattr(self, "backend"): + self.backend.clear() + + def test_put_and_get(self): + """Test storing and retrieving an object.""" + object_id = b"test_object_1" + data = b"Hello, World!" + metadata = b"metadata_value" + + # Put object + self.assertTrue(self.backend.put(object_id, data, metadata)) + + # Get object + result = self.backend.get(object_id) + self.assertIsNotNone(result) + retrieved_data, retrieved_metadata = result + self.assertEqual(retrieved_data, data) + self.assertEqual(retrieved_metadata, metadata) + + def test_get_nonexistent(self): + """Test getting an object that doesn't exist.""" + result = self.backend.get(b"nonexistent") + self.assertIsNone(result) + + def test_delete(self): + """Test deleting an object.""" + object_id = b"test_object_2" + data = b"data" + metadata = b"metadata" + + # Put and verify + self.backend.put(object_id, data, metadata) + self.assertTrue(self.backend.exists(object_id)) + + # Delete and verify + self.assertTrue(self.backend.delete(object_id)) + self.assertFalse(self.backend.exists(object_id)) + + # Delete again should return False + self.assertFalse(self.backend.delete(object_id)) + + def test_exists(self): + """Test checking if an object exists.""" + object_id = b"test_object_3" + data = b"data" + metadata = b"metadata" + + # Should not exist initially + self.assertFalse(self.backend.exists(object_id)) + + # Put and check + self.backend.put(object_id, data, metadata) + self.assertTrue(self.backend.exists(object_id)) + + # Delete and check + self.backend.delete(object_id) + self.assertFalse(self.backend.exists(object_id)) + + def test_overwrite(self): + """Test overwriting an existing object.""" + object_id = b"test_object_4" + data1 = b"first_data" + metadata1 = b"first_metadata" + data2 = b"second_data" + metadata2 = b"second_metadata" + + # Put first version + self.backend.put(object_id, data1, metadata1) + result = self.backend.get(object_id) + self.assertEqual(result, (data1, metadata1)) + + # Overwrite with second version + self.backend.put(object_id, data2, metadata2) + result = self.backend.get(object_id) + self.assertEqual(result, (data2, metadata2)) + + def test_size_tracking(self): + """Test size tracking.""" + 
object_id1 = b"test_object_5" + data1 = b"x" * 1000 + metadata1 = b"m" * 100 + + object_id2 = b"test_object_6" + data2 = b"y" * 2000 + metadata2 = b"n" * 200 + + # Initial size should be 0 + initial_size = self.backend.size() + + # Add first object + self.backend.put(object_id1, data1, metadata1) + size_after_first = self.backend.size() + self.assertGreater(size_after_first, initial_size) + + # Add second object + self.backend.put(object_id2, data2, metadata2) + size_after_second = self.backend.size() + self.assertGreater(size_after_second, size_after_first) + + # Delete first object + self.backend.delete(object_id1) + size_after_delete = self.backend.size() + self.assertLess(size_after_delete, size_after_second) + + def test_max_object_size(self): + """Test maximum object size limit.""" + object_id = b"test_object_large" + # Try to store object larger than max_object_size_mb (10MB) + large_data = b"x" * (11 * 1024 * 1024) # 11MB + metadata = b"metadata" + + # Should fail due to size limit + result = self.backend.put(object_id, large_data, metadata) + self.assertFalse(result) + + # Object should not exist + self.assertFalse(self.backend.exists(object_id)) + + def test_empty_data(self): + """Test storing objects with empty data.""" + object_id = b"test_object_empty" + data = b"" + metadata = b"some_metadata" + + self.backend.put(object_id, data, metadata) + result = self.backend.get(object_id) + self.assertEqual(result, (data, metadata)) + + def test_empty_metadata(self): + """Test storing objects with empty metadata.""" + object_id = b"test_object_no_meta" + data = b"some_data" + metadata = b"" + + self.backend.put(object_id, data, metadata) + result = self.backend.get(object_id) + self.assertEqual(result, (data, metadata)) + + def test_binary_data(self): + """Test storing binary data with various byte values.""" + object_id = b"test_object_binary" + # Create data with all possible byte values + data = bytes(range(256)) + metadata = bytes(range(128, 256)) + bytes(range(0, 128)) + + self.backend.put(object_id, data, metadata) + result = self.backend.get(object_id) + self.assertEqual(result, (data, metadata)) + + def test_multiple_objects(self): + """Test storing and retrieving multiple objects.""" + objects = [(b"obj_%d" % i, b"data_%d" % i, b"meta_%d" % i) for i in range(10)] + + # Store all objects + for object_id, data, metadata in objects: + self.assertTrue(self.backend.put(object_id, data, metadata)) + + # Retrieve and verify all objects + for object_id, expected_data, expected_metadata in objects: + result = self.backend.get(object_id) + self.assertEqual(result, (expected_data, expected_metadata)) + + def test_clear(self): + """Test clearing all objects.""" + # Add multiple objects + for i in range(5): + object_id = b"test_object_%d" % i + self.backend.put(object_id, b"data", b"metadata") + + # Verify they exist + for i in range(5): + self.assertTrue(self.backend.exists(b"test_object_%d" % i)) + + # Clear all + self.backend.clear() + + # Verify all are gone + for i in range(5): + self.assertFalse(self.backend.exists(b"test_object_%d" % i)) + + # Size should be 0 + self.assertEqual(self.backend.size(), 0) + + def test_get_info(self): + """Test getting backend information.""" + info = self.backend.get_info() + + self.assertEqual(info["type"], "redis") + self.assertIn("object_count", info) + self.assertIn("total_size_bytes", info) + self.assertIn("max_object_size_mb", info) + self.assertEqual(info["max_object_size_mb"], 10) + + def test_key_isolation(self): + """Test that different key 
prefixes isolate objects."""
+        backend1 = RedisObjectStorageBackend(redis_url=self.redis_url, key_prefix="prefix1:")
+        backend2 = RedisObjectStorageBackend(redis_url=self.redis_url, key_prefix="prefix2:")
+
+        try:
+            backend1.clear()
+            backend2.clear()
+
+            # Store object in backend1
+            object_id = b"shared_id"
+            backend1.put(object_id, b"data1", b"meta1")
+
+            # Should exist in backend1 but not backend2
+            self.assertTrue(backend1.exists(object_id))
+            self.assertFalse(backend2.exists(object_id))
+
+            # Store different data with same ID in backend2
+            backend2.put(object_id, b"data2", b"meta2")
+
+            # Both should exist with different data
+            self.assertEqual(backend1.get(object_id), (b"data1", b"meta1"))
+            self.assertEqual(backend2.get(object_id), (b"data2", b"meta2"))
+
+        finally:
+            backend1.clear()
+            backend2.clear()
+
+    def test_concurrent_access(self):
+        """Test that multiple operations work correctly."""
+        import threading
+
+        results = []
+
+        def worker(worker_id):
+            try:
+                object_id = b"worker_%d" % worker_id
+                data = b"data_%d" % worker_id
+                metadata = b"meta_%d" % worker_id
+
+                # Put
+                self.backend.put(object_id, data, metadata)
+
+                # Get
+                result = self.backend.get(object_id)
+                results.append((worker_id, result == (data, metadata)))
+
+                # Delete
+                self.backend.delete(object_id)
+            except Exception as e:
+                results.append((worker_id, False, str(e)))
+
+        threads = [threading.Thread(target=worker, args=(i,)) for i in range(10)]
+
+        for t in threads:
+            t.start()
+        for t in threads:
+            t.join()
+
+        # All workers should succeed; index the success flag rather than
+        # unpacking, since error entries carry a third element (the message)
+        self.assertEqual(len(results), 10)
+        self.assertTrue(all(entry[1] for entry in results))
+
+    def test_large_metadata(self):
+        """Test storing objects with large metadata."""
+        object_id = b"test_large_meta"
+        data = b"small_data"
+        metadata = b"x" * (1024 * 1024)  # 1MB metadata
+
+        result = self.backend.put(object_id, data, metadata)
+        self.assertTrue(result)
+
+        retrieved = self.backend.get(object_id)
+        self.assertEqual(retrieved, (data, metadata))
+
+    def test_special_object_ids(self):
+        """Test various object ID formats."""
+        special_ids = [
+            b"\x00\x01\x02",  # Null bytes
+            b"\xff\xfe\xfd",  # High bytes
+            b"a" * 100,  # Long ID
+        ]
+
+        for object_id in special_ids:
+            data = b"test_data"
+            metadata = b"test_metadata"
+
+            self.backend.put(object_id, data, metadata)
+            result = self.backend.get(object_id)
+            self.assertEqual(result, (data, metadata))
+            self.backend.delete(object_id)
+
+
+if __name__ == "__main__":
+    unittest.main()