diff --git a/python/ray/serve/_private/broker.py b/python/ray/serve/_private/broker.py new file mode 100644 index 000000000000..12ddfbf5f90b --- /dev/null +++ b/python/ray/serve/_private/broker.py @@ -0,0 +1,285 @@ +# This module provides broker clients for querying queue lengths from message brokers. +# Adapted from Flower's broker.py (https://github.com/mher/flower/blob/master/flower/utils/broker.py) +# with the following modification: +# - Added close() method to BrokerBase and RedisBase for resource cleanup + +import json +import logging +import numbers +import socket +from urllib.parse import quote, unquote, urljoin, urlparse + +from tornado import httpclient, ioloop + +from ray.serve._private.constants import SERVE_LOGGER_NAME + +try: + import redis +except ImportError: + redis = None + + +logger = logging.getLogger(SERVE_LOGGER_NAME) + + +class BrokerBase: + def __init__(self, broker_url, *_, **__): + purl = urlparse(broker_url) + self.host = purl.hostname + self.port = purl.port + self.vhost = purl.path[1:] + + username = purl.username + password = purl.password + + self.username = unquote(username) if username else username + self.password = unquote(password) if password else password + + async def queues(self, names): + raise NotImplementedError + + def close(self): + """Close any open connections. Override in subclasses as needed.""" + pass + + +class RabbitMQ(BrokerBase): + def __init__(self, broker_url, http_api, io_loop=None, **__): + super().__init__(broker_url) + self.io_loop = io_loop or ioloop.IOLoop.instance() + + self.host = self.host or "localhost" + self.port = self.port or 15672 + self.vhost = quote(self.vhost, "") or "/" if self.vhost != "/" else self.vhost + self.username = self.username or "guest" + self.password = self.password or "guest" + + if not http_api: + http_api = f"http://{self.username}:{self.password}@{self.host}:{self.port}/api/{self.vhost}" + + try: + self.validate_http_api(http_api) + except ValueError: + logger.error("Invalid broker api url: %s", http_api) + + self.http_api = http_api + + async def queues(self, names): + url = urljoin(self.http_api, "queues/" + self.vhost) + api_url = urlparse(self.http_api) + username = unquote(api_url.username or "") or self.username + password = unquote(api_url.password or "") or self.password + + http_client = httpclient.AsyncHTTPClient() + try: + response = await http_client.fetch( + url, + auth_username=username, + auth_password=password, + connect_timeout=1.0, + request_timeout=2.0, + validate_cert=False, + ) + except (socket.error, httpclient.HTTPError) as e: + logger.error("RabbitMQ management API call failed: %s", e) + return [] + finally: + http_client.close() + + if response.code == 200: + info = json.loads(response.body.decode()) + return [x for x in info if x["name"] in names] + response.rethrow() + + @classmethod + def validate_http_api(cls, http_api): + url = urlparse(http_api) + if url.scheme not in ("http", "https"): + raise ValueError(f"Invalid http api schema: {url.scheme}") + + +class RedisBase(BrokerBase): + DEFAULT_SEP = "\x06\x16" + DEFAULT_PRIORITY_STEPS = [0, 3, 6, 9] + + def __init__(self, broker_url, *_, **kwargs): + super().__init__(broker_url) + self.redis = None + + if not redis: + raise ImportError("redis library is required") + + broker_options = kwargs.get("broker_options", {}) + self.priority_steps = broker_options.get( + "priority_steps", self.DEFAULT_PRIORITY_STEPS + ) + self.sep = broker_options.get("sep", self.DEFAULT_SEP) + self.broker_prefix = broker_options.get("global_keyprefix", "") + + def _q_for_pri(self, queue, pri): + if pri not in self.priority_steps: + raise ValueError("Priority not in priority steps") + # pylint: disable=consider-using-f-string + return "{0}{1}{2}".format(*((queue, self.sep, pri) if pri else (queue, "", ""))) + + async def queues(self, names): + queue_stats = [] + for name in names: + priority_names = [ + self.broker_prefix + self._q_for_pri(name, pri) + for pri in self.priority_steps + ] + queue_stats.append( + { + "name": name, + "messages": sum((self.redis.llen(x) for x in priority_names)), + } + ) + return queue_stats + + def close(self): + """Close the Redis connection.""" + if self.redis is not None: + self.redis.close() + self.redis = None + + +class Redis(RedisBase): + def __init__(self, broker_url, *args, **kwargs): + super().__init__(broker_url, *args, **kwargs) + self.host = self.host or "localhost" + self.port = self.port or 6379 + self.vhost = self._prepare_virtual_host(self.vhost) + self.redis = self._get_redis_client() + + def _prepare_virtual_host(self, vhost): + if not isinstance(vhost, numbers.Integral): + if not vhost or vhost == "/": + vhost = 0 + elif vhost.startswith("/"): + vhost = vhost[1:] + try: + vhost = int(vhost) + except ValueError as exc: + raise ValueError( + f"Database is int between 0 and limit - 1, not {vhost}" + ) from exc + return vhost + + def _get_redis_client_args(self): + return { + "host": self.host, + "port": self.port, + "db": self.vhost, + "username": self.username, + "password": self.password, + } + + def _get_redis_client(self): + return redis.Redis(**self._get_redis_client_args()) + + +class RedisSentinel(RedisBase): + def __init__(self, broker_url, *args, **kwargs): + super().__init__(broker_url, *args, **kwargs) + broker_options = kwargs.get("broker_options", {}) + broker_use_ssl = kwargs.get("broker_use_ssl", None) + self.host = self.host or "localhost" + self.port = self.port or 26379 + self.vhost = self._prepare_virtual_host(self.vhost) + self.master_name = self._prepare_master_name(broker_options) + self.redis = self._get_redis_client(broker_options, broker_use_ssl) + + def _prepare_virtual_host(self, vhost): + if not isinstance(vhost, numbers.Integral): + if not vhost or vhost == "/": + vhost = 0 + elif vhost.startswith("/"): + vhost = vhost[1:] + try: + vhost = int(vhost) + except ValueError as exc: + raise ValueError( + f"Database is int between 0 and limit - 1, not {vhost}" + ) from exc + return vhost + + def _prepare_master_name(self, broker_options): + try: + master_name = broker_options["master_name"] + except KeyError as exc: + raise ValueError("master_name is required for Sentinel broker") from exc + return master_name + + def _get_redis_client(self, broker_options, broker_use_ssl): + connection_kwargs = { + "password": self.password, + "sentinel_kwargs": broker_options.get("sentinel_kwargs"), + } + if isinstance(broker_use_ssl, dict): + connection_kwargs["ssl"] = True + connection_kwargs.update(broker_use_ssl) + # get all sentinel hosts from Celery App config and use them to initialize Sentinel + sentinel = redis.sentinel.Sentinel( + [(self.host, self.port)], **connection_kwargs + ) + redis_client = sentinel.master_for(self.master_name) + return redis_client + + +class RedisSocket(RedisBase): + def __init__(self, broker_url, *args, **kwargs): + super().__init__(broker_url, *args, **kwargs) + self.redis = redis.Redis( + unix_socket_path="/" + self.vhost, password=self.password + ) + + +class RedisSsl(Redis): + """ + Redis SSL class offering connection to the broker over SSL. + This does not currently support SSL settings through the url, only through + the broker_use_ssl celery configuration. + """ + + def __init__(self, broker_url, *args, **kwargs): + if "broker_use_ssl" not in kwargs: + raise ValueError("rediss broker requires broker_use_ssl") + self.broker_use_ssl = kwargs.get("broker_use_ssl", {}) + super().__init__(broker_url, *args, **kwargs) + + def _get_redis_client_args(self): + client_args = super()._get_redis_client_args() + client_args["ssl"] = True + if isinstance(self.broker_use_ssl, dict): + client_args.update(self.broker_use_ssl) + return client_args + + +class Broker: + """Factory returning the appropriate broker client based on URL scheme. + + Supported schemes: + ``amqp`` or ``amqps`` -> :class:`RabbitMQ` + ``redis`` -> :class:`Redis` + ``rediss`` -> :class:`RedisSsl` + ``redis+socket`` -> :class:`RedisSocket` + ``sentinel`` -> :class:`RedisSentinel` + """ + + def __new__(cls, broker_url, *args, **kwargs): + scheme = urlparse(broker_url).scheme + if scheme in ("amqp", "amqps"): + return RabbitMQ(broker_url, *args, **kwargs) + if scheme == "redis": + return Redis(broker_url, *args, **kwargs) + if scheme == "rediss": + return RedisSsl(broker_url, *args, **kwargs) + if scheme == "redis+socket": + return RedisSocket(broker_url, *args, **kwargs) + if scheme == "sentinel": + return RedisSentinel(broker_url, *args, **kwargs) + raise NotImplementedError + + async def queues(self, names): + raise NotImplementedError diff --git a/python/ray/serve/_private/queue_monitor.py b/python/ray/serve/_private/queue_monitor.py new file mode 100644 index 000000000000..2b0c3f5294ae --- /dev/null +++ b/python/ray/serve/_private/queue_monitor.py @@ -0,0 +1,163 @@ +import logging +from typing import Any, Dict + +import ray +from ray._common.constants import HEAD_NODE_RESOURCE_NAME +from ray.serve._private.broker import Broker +from ray.serve._private.constants import SERVE_LOGGER_NAME + +logger = logging.getLogger(SERVE_LOGGER_NAME) + +# Actor name prefix for QueueMonitor actors +QUEUE_MONITOR_ACTOR_PREFIX = "QUEUE_MONITOR::" + + +@ray.remote(num_cpus=0) +class QueueMonitorActor: + """ + Actor that monitors queue length by directly querying the broker. + + Returns pending tasks in the queue. + + Uses native broker clients: + - Redis: Uses redis-py library with LLEN command + - RabbitMQ: Uses HTTP management API + """ + + def __init__( + self, + broker_url: str, + queue_name: str, + rabbitmq_http_url: str = "http://guest:guest@localhost:15672/api/", + ): + self._broker_url = broker_url + self._queue_name = queue_name + self._rabbitmq_http_url = rabbitmq_http_url + + self._broker = Broker(self._broker_url, http_api=self._rabbitmq_http_url) + + def __ray_shutdown__(self): + if self._broker is not None: + self._broker.close() + self._broker = None + + def get_config(self) -> Dict[str, Any]: + """ + Get the QueueMonitor configuration as a serializable dict. + + Returns: + Dict with 'broker_url', 'queue_name', and 'rabbitmq_http_url' keys + """ + return { + "broker_url": self._broker_url, + "queue_name": self._queue_name, + "rabbitmq_http_url": self._rabbitmq_http_url, + } + + async def get_queue_length(self) -> int: + """ + Get the current queue length from the broker. + + Returns: + Number of pending tasks in the queue. + + Raises: + ValueError: If queue is not found in broker response or + if queue data is missing the 'messages' field. + """ + queues = await self._broker.queues([self._queue_name]) + if queues is not None: + for q in queues: + if q.get("name") == self._queue_name: + queue_length = q.get("messages") + if queue_length is None: + raise ValueError( + f"Queue '{self._queue_name}' is missing 'messages' field" + ) + return queue_length + + raise ValueError(f"Queue '{self._queue_name}' not found in broker response") + + +def create_queue_monitor_actor( + deployment_name: str, + broker_url: str, + queue_name: str, + rabbitmq_http_url: str = "http://guest:guest@localhost:15672/api/", + namespace: str = "serve", +) -> ray.actor.ActorHandle: + """ + Create a named QueueMonitor Ray actor. + + Args: + deployment_name: Name of the deployment + broker_url: URL of the message broker + queue_name: Name of the queue to monitor + rabbitmq_http_url: HTTP API URL for RabbitMQ management (only for RabbitMQ) + namespace: Ray namespace for the actor + + Returns: + ActorHandle for the QueueMonitor actor + """ + full_actor_name = f"{QUEUE_MONITOR_ACTOR_PREFIX}{deployment_name}" + + # Check if actor already exists + try: + existing = get_queue_monitor_actor(deployment_name, namespace=namespace) + logger.info(f"QueueMonitor actor '{full_actor_name}' already exists, reusing") + return existing + except ValueError: + actor = QueueMonitorActor.options( + name=full_actor_name, + namespace=namespace, + max_restarts=-1, + max_task_retries=-1, + resources={HEAD_NODE_RESOURCE_NAME: 0.001}, + ).remote(broker_url, queue_name, rabbitmq_http_url) + + logger.info( + f"Created QueueMonitor actor '{full_actor_name}' in namespace '{namespace}'" + ) + return actor + + +def get_queue_monitor_actor( + deployment_name: str, + namespace: str = "serve", +) -> ray.actor.ActorHandle: + """ + Get an existing QueueMonitor actor by name. + + Args: + deployment_name: Name of the deployment + namespace: Ray namespace + + Returns: + ActorHandle for the QueueMonitor actor + + Raises: + ValueError: If actor doesn't exist + """ + full_actor_name = f"{QUEUE_MONITOR_ACTOR_PREFIX}{deployment_name}" + return ray.get_actor(full_actor_name, namespace=namespace) + + +def kill_queue_monitor_actor( + deployment_name: str, + namespace: str = "serve", +) -> None: + """ + Delete a QueueMonitor actor by name. + + Args: + deployment_name: Name of the deployment + namespace: Ray namespace + + Raises: + ValueError: If actor doesn't exist + """ + full_actor_name = f"{QUEUE_MONITOR_ACTOR_PREFIX}{deployment_name}" + actor = get_queue_monitor_actor(deployment_name, namespace=namespace) + + ray.kill(actor, no_restart=True) + logger.info(f"Deleted QueueMonitor actor '{full_actor_name}'") diff --git a/python/ray/serve/_private/test_utils.py b/python/ray/serve/_private/test_utils.py index db8cd4def7ca..623129481451 100644 --- a/python/ray/serve/_private/test_utils.py +++ b/python/ray/serve/_private/test_utils.py @@ -302,7 +302,7 @@ def check_num_replicas_eq( target: int, app_name: str = SERVE_DEFAULT_APP_NAME, use_controller: bool = False, -) -> int: +) -> bool: """Check if num replicas is == target.""" if use_controller: diff --git a/python/ray/serve/tests/BUILD.bazel b/python/ray/serve/tests/BUILD.bazel index f1327f9f8f60..1ed740fc9902 100644 --- a/python/ray/serve/tests/BUILD.bazel +++ b/python/ray/serve/tests/BUILD.bazel @@ -125,6 +125,7 @@ py_test_module_list( "test_multiplex.py", "test_proxy.py", "test_proxy_response_generator.py", + "test_queue_monitor.py", "test_ray_client.py", "test_record_routing_stats.py", "test_regression.py", diff --git a/python/ray/serve/tests/test_queue_monitor.py b/python/ray/serve/tests/test_queue_monitor.py new file mode 100644 index 000000000000..bf15092b6c92 --- /dev/null +++ b/python/ray/serve/tests/test_queue_monitor.py @@ -0,0 +1,74 @@ +"""Integration tests for QueueMonitorActor using real Redis.""" +import os +import sys + +import pytest +import redis + +import ray +from ray.serve._private.queue_monitor import ( + create_queue_monitor_actor, +) +from ray.tests.conftest import external_redis # noqa: F401 + + +@pytest.fixture +def redis_client(external_redis): # noqa: F811 + """Create a Redis client connected to the external Redis.""" + redis_address = os.environ.get("RAY_REDIS_ADDRESS") + host, port = redis_address.split(":") + client = redis.Redis(host=host, port=int(port), db=0) + yield client + # Cleanup: delete test queue after each test + client.delete("test_queue") + client.close() + + +@pytest.fixture +def redis_broker_url(external_redis): # noqa: F811 + """Get the Redis broker URL for the external Redis.""" + redis_address = os.environ.get("RAY_REDIS_ADDRESS") + return f"redis://{redis_address}/0" + + +class TestQueueMonitorActor: + """Integration tests for QueueMonitorActor with real Redis.""" + + def test_get_queue_length(self, ray_instance, redis_client, redis_broker_url): + """Test queue length returns number of messages from broker.""" + # Push some messages to the queue + for i in range(30): + redis_client.lpush("test_queue", f"message_{i}") + + monitor = create_queue_monitor_actor( + "test_deployment", redis_broker_url, "test_queue" + ) + length = ray.get(monitor.get_queue_length.remote()) + + assert length == 30 + + def test_get_queue_length_empty_queue( + self, ray_instance, redis_client, redis_broker_url + ): + """Test queue length returns 0 for empty queue.""" + monitor = create_queue_monitor_actor( + "test_deployment", redis_broker_url, "test_queue" + ) + length = ray.get(monitor.get_queue_length.remote()) + + assert length == 0 + + def test_get_config(self, ray_instance, redis_broker_url): + """Test get_config returns the configuration as a dict.""" + monitor = create_queue_monitor_actor( + "test_deployment", redis_broker_url, "test_queue" + ) + config = ray.get(monitor.get_config.remote()) + + assert config["broker_url"] == redis_broker_url + assert config["queue_name"] == "test_queue" + assert config["rabbitmq_http_url"] == "http://guest:guest@localhost:15672/api/" + + +if __name__ == "__main__": + sys.exit(pytest.main(["-v", "-s", __file__]))