diff --git a/python/ray/serve/_private/application_state.py b/python/ray/serve/_private/application_state.py
index e44ac3859e0c..0adeb0b1f9d4 100644
--- a/python/ray/serve/_private/application_state.py
+++ b/python/ray/serve/_private/application_state.py
@@ -26,6 +26,7 @@
 from ray.serve._private.config import DeploymentConfig
 from ray.serve._private.constants import (
     DEFAULT_AUTOSCALING_POLICY_NAME,
+    DEFAULT_QUEUE_BASED_AUTOSCALING_POLICY,
     DEFAULT_REQUEST_ROUTER_PATH,
     RAY_SERVE_ENABLE_TASK_EVENTS,
     SERVE_LOGGER_NAME,
@@ -39,6 +40,10 @@
 from ray.serve._private.deployment_state import DeploymentStateManager
 from ray.serve._private.endpoint_state import EndpointState
 from ray.serve._private.logging_utils import configure_component_logger
+from ray.serve._private.queue_monitor import (
+    QueueMonitorConfig,
+    create_queue_monitor_actor,
+)
 from ray.serve._private.storage.kv_store import KVStoreBase
 from ray.serve._private.usage import ServeUsageTag
 from ray.serve._private.utils import (
@@ -74,6 +79,87 @@
 CHECKPOINT_KEY = "serve-application-state-checkpoint"
 
 
+def _is_task_consumer_deployment(deployment_info: DeploymentInfo) -> bool:
+    """Check if a deployment is a TaskConsumer."""
+    try:
+        deployment_def = deployment_info.replica_config.deployment_def
+        if deployment_def is None:
+            return False
+        return getattr(deployment_def, "_is_task_consumer", False)
+    except Exception as e:
+        logger.debug(f"Error checking if deployment is TaskConsumer: {e}")
+        return False
+
+
+def _get_queue_monitor_config(
+    deployment_info: DeploymentInfo,
+) -> Optional[QueueMonitorConfig]:
+    """Extract the QueueMonitorConfig from a TaskConsumer deployment."""
+    try:
+        deployment_def = deployment_info.replica_config.deployment_def
+        if hasattr(deployment_def, "get_queue_monitor_config"):
+            return deployment_def.get_queue_monitor_config()
+    except Exception as e:
+        logger.warning(f"Failed to get queue monitor config: {e}")
+    return None
+
+
+def _configure_queue_based_autoscaling_for_task_consumers(
+    deployment_infos: Dict[str, DeploymentInfo]
+) -> None:
+    """
+    Configure queue-based autoscaling for TaskConsumers.
+
+    For TaskConsumer deployments with autoscaling enabled and no custom policy,
+    this function switches the autoscaling policy to queue-based autoscaling.
+
+    Args:
+        deployment_infos: Map of deployment name to DeploymentInfo, mutated in place.
+    """
+    for deployment_name, deployment_info in deployment_infos.items():
+        is_task_consumer = _is_task_consumer_deployment(deployment_info)
+        has_autoscaling = (
+            deployment_info.deployment_config.autoscaling_config is not None
+        )
+
+        # Set the queue-based autoscaling policy on a TaskConsumer only if
+        # the user hasn't set a custom policy (an explicitly configured
+        # policy is left untouched).
+        if is_task_consumer and has_autoscaling:
+            logger.info(
+                f"Deployment '{deployment_name}' is a TaskConsumer with autoscaling enabled"
+            )
+            is_default_policy = (
+                deployment_info.deployment_config.autoscaling_config.policy.is_default_policy_function()
+            )
+
+            if is_default_policy:
+                queue_monitor_config = _get_queue_monitor_config(deployment_info)
+                if queue_monitor_config is not None:
+                    # Create the QueueMonitor as a Ray actor (not a Serve deployment).
+                    # This avoids a deadlock when the autoscaling policy queries it from the controller.
+                    try:
+                        create_queue_monitor_actor(
+                            deployment_name=deployment_name,
+                            config=queue_monitor_config,
+                        )
+                    except Exception as e:
+                        logger.error(
+                            f"Failed to create QueueMonitor actor for '{deployment_name}': {e}"
+                        )
+                        continue
+
+                    # Switch to the queue-based autoscaling policy.
+                    deployment_info.deployment_config.autoscaling_config.policy = (
+                        AutoscalingPolicy(
+                            policy_function=DEFAULT_QUEUE_BASED_AUTOSCALING_POLICY
+                        )
+                    )
+                    logger.info(
+                        f"Switched TaskConsumer '{deployment_name}' to queue-based autoscaling policy"
+                    )
+
+
 class BuildAppStatus(Enum):
     """Status of the build application task."""
@@ -1220,6 +1306,10 @@ def deploy_apps(
             )
             for params in deployment_args
         }
+
+        # Configure queue-based autoscaling for TaskConsumers.
+        _configure_queue_based_autoscaling_for_task_consumers(deployment_infos)
+
         self._application_states[name].deploy_app(
             deployment_infos, external_scaler_enabled
         )
diff --git a/python/ray/serve/_private/autoscaling_state.py b/python/ray/serve/_private/autoscaling_state.py
index a17b22e67ffb..94b32fb73e3f 100644
--- a/python/ray/serve/_private/autoscaling_state.py
+++ b/python/ray/serve/_private/autoscaling_state.py
@@ -23,6 +23,7 @@
     aggregate_timeseries,
     merge_instantaneous_total,
 )
+from ray.serve._private.queue_monitor import delete_queue_monitor_actor
 from ray.serve._private.usage import ServeUsageTag
 from ray.serve._private.utils import get_capacity_adjusted_num_replicas
 from ray.serve.config import AutoscalingContext, AutoscalingPolicy
@@ -941,6 +942,10 @@ def deregister_deployment(self, deployment_id: DeploymentID):
         )
         app_state.deregister_deployment(deployment_id)
 
+        # Clean up the QueueMonitor actor if one exists for this deployment.
+        # This is needed for TaskConsumer deployments with queue-based autoscaling.
+        delete_queue_monitor_actor(deployment_id.name)
+
     def register_application(
         self,
         app_name: ApplicationName,
diff --git a/python/ray/serve/_private/constants.py b/python/ray/serve/_private/constants.py
index e96034b77993..c9026501be4b 100644
--- a/python/ray/serve/_private/constants.py
+++ b/python/ray/serve/_private/constants.py
@@ -378,6 +378,11 @@
     "ray.serve.autoscaling_policy:default_autoscaling_policy"
 )
 
+# The default queue-based autoscaling policy to use for TaskConsumers if none is specified.
+DEFAULT_QUEUE_BASED_AUTOSCALING_POLICY = (
+    "ray.serve.autoscaling_policy:default_queue_based_autoscaling_policy"
+)
+
 # Feature flag to enable collecting all queued and ongoing request
 # metrics at handles instead of replicas. ON by default.
 RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE = get_env_bool(
diff --git a/python/ray/serve/_private/queue_monitor.py b/python/ray/serve/_private/queue_monitor.py
new file mode 100644
index 000000000000..d0b5d8c32112
--- /dev/null
+++ b/python/ray/serve/_private/queue_monitor.py
@@ -0,0 +1,275 @@
+import logging
+from typing import Any, Dict
+
+import pika
+import redis
+
+import ray
+from ray.serve._private.constants import SERVE_LOGGER_NAME
+
+logger = logging.getLogger(SERVE_LOGGER_NAME)
+
+# Actor name prefix for QueueMonitor actors
+QUEUE_MONITOR_ACTOR_PREFIX = "QUEUE_MONITOR::"
+
+
+class QueueMonitorConfig:
+    """Configuration for the QueueMonitor actor."""
+
+    def __init__(self, broker_url: str, queue_name: str):
+        self.broker_url = broker_url
+        self.queue_name = queue_name
+
+    @property
+    def broker_type(self) -> str:
+        url_lower = self.broker_url.lower()
+        if url_lower.startswith("redis"):
+            return "redis"
+        elif url_lower.startswith("amqp") or url_lower.startswith("pyamqp"):
+            return "rabbitmq"
+        elif "sqs" in url_lower:
+            return "sqs"
+        else:
+            return "unknown"
+
+
+class QueueMonitor:
+    """
+    Monitors queue length by directly querying the broker.
+
+    Uses native broker clients:
+    - Redis: Uses the redis-py library with the LLEN command
+    - RabbitMQ: Uses the pika library with a passive queue declaration
+    """
+
+    def __init__(self, config: QueueMonitorConfig):
+        self._config = config
+        self._client: Any = None
+        self._last_queue_length: int = 0
+        self._is_initialized: bool = False
+
+    def initialize(self) -> None:
+        """
+        Initialize the connection to the broker.
+
+        Creates the appropriate client based on broker type and tests the connection.
+        """
+        if self._is_initialized:
+            return
+
+        broker_type = self._config.broker_type
+        try:
+            if broker_type == "redis":
+                self._init_redis()
+            elif broker_type == "rabbitmq":
+                self._init_rabbitmq()
+            else:
+                raise ValueError(
+                    f"Unsupported broker type: {broker_type}. Supported: redis, rabbitmq"
+                )
+
+            self._is_initialized = True
+            logger.info(
+                f"QueueMonitor initialized for queue '{self._config.queue_name}' (broker: {broker_type})"
+            )
+
+        except Exception as e:
+            logger.error(f"Failed to initialize QueueMonitor: {e}")
+            raise
+
+    def _init_redis(self) -> None:
+        """Initialize the Redis client."""
+        self._client = redis.from_url(self._config.broker_url)
+
+        # Test connection
+        self._client.ping()
+
+    def _init_rabbitmq(self) -> None:
+        """Initialize RabbitMQ connection parameters."""
+        # Store connection parameters - we'll create connections as needed
+        self._connection_params = pika.URLParameters(self._config.broker_url)
+
+        # Test connection
+        connection = pika.BlockingConnection(self._connection_params)
+        connection.close()
+
+    def _get_redis_queue_length(self) -> int:
+        return self._client.llen(self._config.queue_name)
+
+    def _get_rabbitmq_queue_length(self) -> int:
+        connection = pika.BlockingConnection(self._connection_params)
+        try:
+            channel = connection.channel()
+
+            # Passive declaration - doesn't create the queue, just gets info
+            result = channel.queue_declare(queue=self._config.queue_name, passive=True)
+
+            return result.method.message_count
+        finally:
+            connection.close()
+
+    def get_config(self) -> Dict[str, str]:
+        """
+        Get the QueueMonitor configuration as a serializable dict.
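+
+        Example return value (illustrative, assuming a local Redis broker):
+            {"broker_url": "redis://localhost:6379/0", "queue_name": "tasks"}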
+ + Returns: + Dict with 'broker_url' and 'queue_name' keys + """ + return { + "broker_url": self._config.broker_url, + "queue_name": self._config.queue_name, + } + + def get_queue_length(self) -> int: + """ + Get the current queue length from the broker. + + Returns: + Number of pending tasks in the queue + """ + if not self._is_initialized: + logger.warning( + f"QueueMonitor not initialized for queue '{self._config.queue_name}', returning 0" + ) + return 0 + + try: + broker_type = self._config.broker_type + + if broker_type == "redis": + queue_length = self._get_redis_queue_length() + elif broker_type == "rabbitmq": + queue_length = self._get_rabbitmq_queue_length() + else: + raise ValueError(f"Unsupported broker type: {broker_type}") + + # Update cache + self._last_queue_length = queue_length + + return queue_length + + except Exception as e: + logger.warning( + f"Failed to query queue length: {e}. Using last known value: {self._last_queue_length}" + ) + return self._last_queue_length + + def shutdown(self) -> None: + if self._client is not None: + try: + if hasattr(self._client, "close"): + self._client.close() + except Exception as e: + logger.warning(f"Error closing client: {e}") + + self._client = None + self._is_initialized = False + + def __del__(self): + self.shutdown() + + +@ray.remote(num_cpus=0) +class QueueMonitorActor(QueueMonitor): + """ + Ray actor version of QueueMonitor for direct access from ServeController. + + This is used instead of a Serve deployment because the autoscaling policy + runs inside the ServeController, and using serve.get_deployment_handle() + from within the controller causes a deadlock. + """ + + def __init__(self, config: QueueMonitorConfig): + super().__init__(config) + self.initialize() + + +def create_queue_monitor_actor( + deployment_name: str, + config: QueueMonitorConfig, + namespace: str = "serve", +) -> ray.actor.ActorHandle: + """ + Create a named QueueMonitor Ray actor. + + Args: + deployment_name: Name of the deployment + config: QueueMonitorConfig with broker URL and queue name + namespace: Ray namespace for the actor + + Returns: + ActorHandle for the QueueMonitor actor + """ + full_actor_name = f"{QUEUE_MONITOR_ACTOR_PREFIX}{deployment_name}" + + # Check if actor already exists + try: + existing = ray.get_actor(full_actor_name, namespace=namespace) + logger.info(f"QueueMonitor actor '{full_actor_name}' already exists, reusing") + + return existing + except ValueError: + pass # Actor doesn't exist, create it + + actor = QueueMonitorActor.options( + name=full_actor_name, + namespace=namespace, + lifetime="detached", + ).remote(config) + + logger.info( + f"Created QueueMonitor actor '{full_actor_name}' in namespace '{namespace}'" + ) + return actor + + +def get_queue_monitor_actor( + deployment_name: str, + namespace: str = "serve", +) -> ray.actor.ActorHandle: + """ + Get an existing QueueMonitor actor by name. + + Args: + deployment_name: Name of the deployment + namespace: Ray namespace + + Returns: + ActorHandle for the QueueMonitor actor + + Raises: + ValueError: If actor doesn't exist + """ + full_actor_name = f"{QUEUE_MONITOR_ACTOR_PREFIX}{deployment_name}" + return ray.get_actor(full_actor_name, namespace=namespace) + + +def delete_queue_monitor_actor( + deployment_name: str, + namespace: str = "serve", +) -> bool: + """ + Delete a QueueMonitor actor by name. 
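+
+    Safe to call unconditionally: if no actor exists for the deployment, the
+    call is a no-op and returns False.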
+ + Args: + deployment_name: Name of the deployment + namespace: Ray namespace + + Returns: + True if actor was deleted, False if it didn't exist + """ + full_actor_name = f"{QUEUE_MONITOR_ACTOR_PREFIX}{deployment_name}" + try: + actor = ray.get_actor(full_actor_name, namespace=namespace) + ray.kill(actor) + logger.info(f"Deleted QueueMonitor actor '{full_actor_name}'") + return True + except ValueError: + # Actor doesn't exist + return False diff --git a/python/ray/serve/_private/test_utils.py b/python/ray/serve/_private/test_utils.py index db8cd4def7ca..623129481451 100644 --- a/python/ray/serve/_private/test_utils.py +++ b/python/ray/serve/_private/test_utils.py @@ -302,7 +302,7 @@ def check_num_replicas_eq( target: int, app_name: str = SERVE_DEFAULT_APP_NAME, use_controller: bool = False, -) -> int: +) -> bool: """Check if num replicas is == target.""" if use_controller: diff --git a/python/ray/serve/autoscaling_policy.py b/python/ray/serve/autoscaling_policy.py index 35083aa68df0..c2a698165820 100644 --- a/python/ray/serve/autoscaling_policy.py +++ b/python/ray/serve/autoscaling_policy.py @@ -2,10 +2,17 @@ import math from typing import Any, Dict, Optional, Tuple +import ray from ray.serve._private.constants import ( CONTROL_LOOP_INTERVAL_S, SERVE_AUTOSCALING_DECISION_COUNTERS_KEY, SERVE_LOGGER_NAME, + SERVE_NAMESPACE, +) +from ray.serve._private.queue_monitor import ( + QUEUE_MONITOR_ACTOR_PREFIX, + QueueMonitorConfig, + create_queue_monitor_actor, ) from ray.serve.config import AutoscalingConfig, AutoscalingContext from ray.util.annotations import PublicAPI @@ -119,8 +126,6 @@ def replica_queue_length_autoscaling_policy( ) return curr_target_num_replicas, policy_state - decision_num_replicas = curr_target_num_replicas - desired_num_replicas = _calculate_desired_num_replicas( config, total_num_requests, @@ -128,53 +133,260 @@ def replica_queue_length_autoscaling_policy( override_min_replicas=capacity_adjusted_min_replicas, override_max_replicas=capacity_adjusted_max_replicas, ) - # Scale up. + + decision_num_replicas, decision_counter = _apply_scaling_decision_smoothing( + desired_num_replicas=desired_num_replicas, + curr_target_num_replicas=curr_target_num_replicas, + decision_counter=decision_counter, + config=config, + ) + + policy_state["decision_counter"] = decision_counter + policy_state[SERVE_AUTOSCALING_DECISION_COUNTERS_KEY] = decision_counter + return decision_num_replicas, policy_state + + +@PublicAPI(stability="alpha") +def queue_based_autoscaling_policy( + ctx: AutoscalingContext, +) -> Tuple[int, Dict[str, Any]]: + """ + Autoscaling policy for TaskConsumer deployments based on queue depth. + + This policy scales replicas based on the number of pending tasks in the + message queue, rather than HTTP request load. 
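+
+    For example, with 25 pending tasks and target_ongoing_requests=10, the
+    policy computes ceil(25 / 10) = 3 desired replicas, which is then clamped
+    to the min/max bounds and smoothed by the upscale/downscale delays.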
+ + Formula: + desired_replicas = ceil(queue_length / target_ongoing_requests) + + Behavior: + - Queries QueueMonitor Ray actor directly via ray.get_actor() + - If QueueMonitor unavailable, maintains current replica count + - Uses same smoothing/delay logic as default policy to prevent oscillation + + Args: + ctx: AutoscalingContext containing metrics, config, and state + + Returns: + Tuple of (desired_num_replicas, updated_policy_state) + """ + + # Extract state + policy_state: Dict[str, Any] = ctx.policy_state + current_num_replicas: int = ctx.current_num_replicas + curr_target_num_replicas: int = ctx.target_num_replicas + config: Optional[AutoscalingConfig] = ctx.config + capacity_adjusted_min_replicas: int = ctx.capacity_adjusted_min_replicas + capacity_adjusted_max_replicas: int = ctx.capacity_adjusted_max_replicas + + # Get decision counter from state (for smoothing) + decision_counter = policy_state.get("decision_counter", 0) + + # === STEP 1: Get queue length from QueueMonitor actor === + # Actor name format: "QUEUE_MONITOR::" + queue_monitor_actor_name = f"{QUEUE_MONITOR_ACTOR_PREFIX}{ctx.deployment_name}" + queue_monitor_actor, actor_found = _get_or_recover_queue_monitor_actor( + queue_monitor_actor_name=queue_monitor_actor_name, + deployment_name=ctx.deployment_name, + policy_state=policy_state, + ) + + if not actor_found: + logger.warning( + f"[{ctx.deployment_name}] QueueMonitor actor unavailable, maintaining {curr_target_num_replicas} replicas" + ) + return curr_target_num_replicas, policy_state + + try: + queue_length = ray.get( + queue_monitor_actor.get_queue_length.remote(), timeout=5.0 + ) + + # Store config in policy_state if not already stored (for future recovery) + if "queue_monitor_config" not in policy_state: + try: + config_dict = ray.get( + queue_monitor_actor.get_config.remote(), timeout=5.0 + ) + policy_state["queue_monitor_config"] = config_dict + logger.info( + f"[{ctx.deployment_name}] Stored QueueMonitor config in policy_state for recovery" + ) + except Exception as e: + logger.warning( + f"[{ctx.deployment_name}] Failed to store config in policy_state: {e}" + ) + + except Exception as e: + # Error querying actor - maintain current replicas + logger.warning( + f"[{ctx.deployment_name}] Could not query QueueMonitor ({e}), maintaining {curr_target_num_replicas} replicas" + ) + return curr_target_num_replicas, policy_state + + # === STEP 2: Calculate desired replicas === + target_ongoing_requests = config.get_target_ongoing_requests() + + policy_state["last_queue_length"] = queue_length + + # Handle scale from zero + if current_num_replicas == 0: + if queue_length > 0: + desired = math.ceil(queue_length / target_ongoing_requests) + desired = max(1, min(desired, capacity_adjusted_max_replicas)) + return desired, policy_state + return 0, policy_state + + # Calculate desired replicas based on queue depth + desired_num_replicas = math.ceil(queue_length / target_ongoing_requests) + + # Clamp to min/max bounds + desired_num_replicas = max( + capacity_adjusted_min_replicas, + min(capacity_adjusted_max_replicas, desired_num_replicas), + ) + + # === STEP 3: Apply smoothing (same logic as default policy) === + decision_num_replicas, decision_counter = _apply_scaling_decision_smoothing( + desired_num_replicas=desired_num_replicas, + curr_target_num_replicas=curr_target_num_replicas, + decision_counter=decision_counter, + config=config, + ) + + # Update policy state + policy_state["decision_counter"] = decision_counter + + 
policy_state[SERVE_AUTOSCALING_DECISION_COUNTERS_KEY] = decision_counter + return decision_num_replicas, policy_state + + +def _apply_scaling_decision_smoothing( + desired_num_replicas: int, + curr_target_num_replicas: int, + decision_counter: int, + config: AutoscalingConfig, +) -> Tuple[int, int]: + """ + Apply smoothing logic to prevent oscillation in scaling decisions. + + This function implements delay-based smoothing: a scaling decision must be + made for a consecutive number of periods before actually scaling. + + Args: + desired_num_replicas: The calculated desired number of replicas. + curr_target_num_replicas: Current target number of replicas. + decision_counter: Counter tracking consecutive scaling decisions. + Positive = consecutive scale-up decisions, negative = scale-down. + config: Autoscaling configuration containing delay settings. + + Returns: + Tuple of (decision_num_replicas, updated_decision_counter). + """ + decision_num_replicas = curr_target_num_replicas + + # Scale up if desired_num_replicas > curr_target_num_replicas: - # If the previous decision was to scale down (the counter was - # negative), we reset it and then increment it (set to 1). - # Otherwise, just increment. if decision_counter < 0: decision_counter = 0 decision_counter += 1 - # Only actually scale the replicas if we've made this decision for - # 'scale_up_consecutive_periods' in a row. + # Only scale after upscale_delay_s if decision_counter > int(config.upscale_delay_s / CONTROL_LOOP_INTERVAL_S): decision_counter = 0 decision_num_replicas = desired_num_replicas - # Scale down. + # Scale down elif desired_num_replicas < curr_target_num_replicas: - # If the previous decision was to scale up (the counter was - # positive), reset it to zero before decrementing. - if decision_counter > 0: decision_counter = 0 decision_counter -= 1 + # Downscaling to zero is only allowed from 1 -> 0 - is_scaling_to_zero = curr_target_num_replicas == 1 - # Determine the delay to use - if is_scaling_to_zero: - # Check if the downscale_to_zero_delay_s is set - if config.downscale_to_zero_delay_s is not None: - delay_s = config.downscale_to_zero_delay_s - else: - delay_s = config.downscale_delay_s + is_scaling_to_zero = curr_target_num_replicas == 1 and desired_num_replicas == 0 + if is_scaling_to_zero and config.downscale_to_zero_delay_s is not None: + delay_s = config.downscale_to_zero_delay_s else: delay_s = config.downscale_delay_s - # The desired_num_replicas>0 for downscaling cases other than 1->0 + # Ensure desired_num_replicas >= 1 for non-zero scaling cases desired_num_replicas = max(1, desired_num_replicas) - # Only actually scale the replicas if we've made this decision for - # 'scale_down_consecutive_periods' in a row. + + # Only scale after delay if decision_counter < -int(delay_s / CONTROL_LOOP_INTERVAL_S): decision_counter = 0 decision_num_replicas = desired_num_replicas - # Do nothing. + + # No change else: decision_counter = 0 - policy_state[SERVE_AUTOSCALING_DECISION_COUNTERS_KEY] = decision_counter - return decision_num_replicas, policy_state + return decision_num_replicas, decision_counter + + +def _get_or_recover_queue_monitor_actor( + queue_monitor_actor_name: str, + deployment_name: str, + policy_state: Dict[str, Any], +) -> Tuple[Optional[ray.actor.ActorHandle], bool]: + """ + Try to get an existing QueueMonitor actor, or recover it from policy_state. + + Args: + queue_monitor_actor_name: The name of the QueueMonitor actor to look up. + deployment_name: The deployment name (for logging). 
+        policy_state: The policy state dict that may contain stored config for recovery.
+
+    Returns:
+        Tuple of (queue_monitor_actor, actor_found). If actor_found is False,
+        queue_monitor_actor will be None.
+    """
+    queue_monitor_actor = None
+    actor_found = False
+
+    # Try to get existing actor
+    try:
+        queue_monitor_actor = ray.get_actor(
+            queue_monitor_actor_name, namespace=SERVE_NAMESPACE
+        )
+        actor_found = True
+    except ValueError:
+        # Actor not found - try to recover from policy_state
+        logger.warning(
+            f"[{deployment_name}] QueueMonitor actor not found, checking policy_state for recovery"
+        )
+
+        stored_config = policy_state.get("queue_monitor_config")
+        if stored_config is not None:
+            # Attempt to recreate actor from stored config
+            try:
+                logger.info(
+                    f"[{deployment_name}] Attempting to recreate QueueMonitor actor from stored config"
+                )
+                queue_config = QueueMonitorConfig(
+                    broker_url=stored_config["broker_url"],
+                    queue_name=stored_config["queue_name"],
+                )
+                queue_monitor_actor = create_queue_monitor_actor(
+                    deployment_name=deployment_name,
+                    config=queue_config,
+                )
+                actor_found = True
+                logger.info(
+                    f"[{deployment_name}] Successfully recreated QueueMonitor actor"
+                )
+            except Exception as e:
+                logger.error(
+                    f"[{deployment_name}] Failed to recreate QueueMonitor actor: {e}"
+                )
+        else:
+            logger.warning(
+                f"[{deployment_name}] No stored config in policy_state, "
+                f"cannot recover QueueMonitor actor"
+            )
+
+    return queue_monitor_actor, actor_found
 
 
 default_autoscaling_policy = replica_queue_length_autoscaling_policy
+
+default_queue_based_autoscaling_policy = queue_based_autoscaling_policy
diff --git a/python/ray/serve/task_consumer.py b/python/ray/serve/task_consumer.py
index e057eaf966a0..7cbf378cdd01 100644
--- a/python/ray/serve/task_consumer.py
+++ b/python/ray/serve/task_consumer.py
@@ -8,6 +8,7 @@
     DEFAULT_CONSUMER_CONCURRENCY,
     SERVE_LOGGER_NAME,
 )
+from ray.serve._private.queue_monitor import QueueMonitorConfig
 from ray.serve._private.task_consumer import TaskConsumerWrapper
 from ray.serve._private.utils import copy_class_metadata
 from ray.serve.schema import (
@@ -161,6 +162,40 @@
         copy_class_metadata(_TaskConsumerWrapper, target_cls)
 
+        # Attach metadata for TaskConsumer detection
+        _TaskConsumerWrapper._is_task_consumer = True
+        _TaskConsumerWrapper._task_processor_config = task_processor_config
+
+        @classmethod
+        def get_queue_monitor_config(cls) -> Optional[QueueMonitorConfig]:
+            """
+            Returns the QueueMonitorConfig for this TaskConsumer.
+
+            This method is called by ApplicationState to create the internal
+            QueueMonitor actor when deploying a TaskConsumer.
+
+            Returns:
+                QueueMonitorConfig for connecting to the broker.
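+
+            Raises:
+                ValueError: If the adapter config does not provide a broker_url.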
+ """ + # Extract broker_url from adapter_config + adapter_config = task_processor_config.adapter_config + broker_url = getattr(adapter_config, "broker_url", None) + + if broker_url is None: + raise ValueError( + "broker_url is required in adapter_config for queue monitoring" + ) + + return QueueMonitorConfig( + broker_url=broker_url, + queue_name=task_processor_config.queue_name, + ) + + _TaskConsumerWrapper.get_queue_monitor_config = get_queue_monitor_config + return _TaskConsumerWrapper return decorator diff --git a/python/ray/serve/tests/test_task_processor.py b/python/ray/serve/tests/test_task_processor.py index f09e09c1b3f6..0071d109d7e2 100644 --- a/python/ray/serve/tests/test_task_processor.py +++ b/python/ray/serve/tests/test_task_processor.py @@ -2,6 +2,7 @@ import os import sys import tempfile +import time from collections import defaultdict from pathlib import Path @@ -10,6 +11,10 @@ import ray from ray import serve from ray._common.test_utils import SignalActor, wait_for_condition +from ray.serve._private.test_utils import ( + check_num_replicas_eq, + check_num_replicas_gte, +) from ray.serve.schema import CeleryAdapterConfig, TaskProcessorConfig from ray.serve.task_consumer import ( instantiate_adapter_from_config, @@ -851,5 +856,241 @@ def process_request(self, data: str): ) +def check_app_running(app_name) -> bool: + try: + status = serve.status() + if app_name not in status.applications: + return False + app_status = status.applications[app_name] + return app_status.status == "RUNNING" + except Exception: + return False + + +def get_task_processor_config(queue_name, redis_address): + return TaskProcessorConfig( + queue_name=queue_name, + adapter_config=CeleryAdapterConfig( + broker_url=f"redis://{redis_address}/0", + backend_url=f"redis://{redis_address}/1", + app_custom_config={"worker_prefetch_multiplier": 1}, + ), + ) + + +def create_autoscaling_task_consumer( + processor_config: TaskProcessorConfig, + autoscaling_config: dict, + signal_actor, + processed_tasks_tracker, +): + """Factory to create an AutoscalingTaskConsumer deployment with given config.""" + + @serve.deployment( + autoscaling_config=autoscaling_config, + max_ongoing_requests=1, + ) + @task_consumer(task_processor_config=processor_config) + class AutoscalingTaskConsumer: + def __init__(self, signal_actor, processed_tasks_tracker): + self._signal = signal_actor + self._processed_tasks_tracker = processed_tasks_tracker + + @task_handler(name="blocking_task") + def blocking_task(self, data): + ray.get(self._signal.wait.remote()) + ray.get(self._processed_tasks_tracker.add_task.remote(data)) + return f"Processed: {data}" + + return AutoscalingTaskConsumer.bind(signal_actor, processed_tasks_tracker) + + +def enqueue_tasks( + processor_config: TaskProcessorConfig, + num_tasks: int, + task_name: str = "blocking_task", +): + """Enqueue multiple tasks to the queue.""" + for i in range(num_tasks): + send_request_to_queue.remote(processor_config, f"task_{i}", task_name=task_name) + + +def wait_for_tasks_processed( + processed_tasks_tracker, expected_count: int, timeout: int = 60 +): + """Wait for expected number of tasks to be processed.""" + wait_for_condition( + lambda: ray.get(processed_tasks_tracker.get_count.remote()) == expected_count, + timeout=timeout, + ) + + +@pytest.mark.skipif(sys.platform == "win32", reason="Flaky on Windows.") +class TestTaskConsumerQueueBasedAutoscaling: + """Integration tests for queue-based autoscaling with TaskConsumer. 
+ + These tests verify that TaskConsumer deployments with autoscaling enabled + correctly scale based on queue depth using the queue_based_autoscaling_policy. + + The tests use: + - external_redis: Real Redis broker for task queue + - SignalActor: Block task handlers to accumulate tasks in queue + - wait_for_condition: Verify replica count scales based on queue length + """ + + def _setup_test(self): + """Common setup for autoscaling tests.""" + redis_address = os.environ.get("RAY_REDIS_ADDRESS") + queue_name = f"autoscaling_queue_{time.time_ns()}" + processor_config = get_task_processor_config(queue_name, redis_address) + signal = SignalActor.remote() + processed_tasks_tracker = ProcessedTasksTracker.remote() + return processor_config, signal, processed_tasks_tracker + + def _deploy_and_wait(self, deployment, app_name: str): + """Deploy and wait for the application to be running.""" + serve.run(deployment, name=app_name) + wait_for_condition(lambda: check_app_running(app_name), timeout=30) + + def _release_and_wait_for_completion( + self, signal, processed_tasks_tracker, num_tasks: int + ): + """Release blocked tasks and wait for all to complete.""" + ray.get(signal.send.remote()) + wait_for_tasks_processed(processed_tasks_tracker, num_tasks) + + def test_task_consumer_scales_up_based_on_queue_depth( + self, external_redis, serve_instance # noqa: F811 + ): + """Test that TaskConsumer deployment scales up when queue fills up. + + Scenario: + - Deploy TaskConsumer with autoscaling (min=1, max=5, target_ongoing_requests=5) + - Block task handler using SignalActor + - Enqueue 20 tasks + - Verify replicas scale up: 20 tasks / 5 target = 4 replicas expected + """ + processor_config, signal, tracker = self._setup_test() + app_name = "autoscaling_scale_up_test" + + autoscaling_config = { + "min_replicas": 1, + "max_replicas": 5, + "target_ongoing_requests": 5, + "upscale_delay_s": 0, + "downscale_delay_s": 10, + } + + deployment = create_autoscaling_task_consumer( + processor_config, autoscaling_config, signal, tracker + ) + self._deploy_and_wait(deployment, app_name) + + # Enqueue 20 tasks -> expect ceil(20/5) = 4 replicas + num_tasks = 20 + enqueue_tasks(processor_config, num_tasks) + + wait_for_condition( + lambda: check_num_replicas_eq("AutoscalingTaskConsumer", 4, app_name), + timeout=60, + ) + + # Release tasks and verify scale down to min_replicas + self._release_and_wait_for_completion(signal, tracker, num_tasks) + + wait_for_condition( + lambda: check_num_replicas_eq("AutoscalingTaskConsumer", 1, app_name), + timeout=60, + ) + + def test_task_consumer_autoscaling_respects_max_replicas( + self, external_redis, serve_instance # noqa: F811 + ): + """Test that autoscaling respects max_replicas even with large queue. 
+ + Scenario: + - Deploy TaskConsumer with max_replicas=2 + - Enqueue 20 tasks (would need ceil(20/2)=10 replicas) + - Verify scaling is capped at max_replicas=2 + """ + processor_config, signal, tracker = self._setup_test() + app_name = "autoscaling_max_replicas_test" + + autoscaling_config = { + "min_replicas": 1, + "max_replicas": 2, + "target_ongoing_requests": 2, + "upscale_delay_s": 0, + "downscale_delay_s": 10, + } + + deployment = create_autoscaling_task_consumer( + processor_config, autoscaling_config, signal, tracker + ) + self._deploy_and_wait(deployment, app_name) + + # Enqueue 20 tasks -> would need 10 replicas, but capped at max=2 + num_tasks = 20 + enqueue_tasks(processor_config, num_tasks) + + wait_for_condition( + lambda: check_num_replicas_eq("AutoscalingTaskConsumer", 2, app_name), + timeout=60, + ) + + # Release tasks and verify scale down to min_replicas + self._release_and_wait_for_completion(signal, tracker, num_tasks) + + wait_for_condition( + lambda: check_num_replicas_eq("AutoscalingTaskConsumer", 1, app_name), + timeout=60, + ) + + def test_task_consumer_autoscaling_respects_min_replicas( + self, external_redis, serve_instance # noqa: F811 + ): + """Test that autoscaling respects min_replicas even with empty queue. + + Scenario: + - Deploy TaskConsumer with min_replicas=2 + - Enqueue tasks and let them complete (queue becomes empty) + - Verify replicas stay at min_replicas=2, not scaling below + """ + processor_config, signal, tracker = self._setup_test() + app_name = "autoscaling_min_replicas_test" + + autoscaling_config = { + "min_replicas": 2, + "max_replicas": 5, + "target_ongoing_requests": 5, + "upscale_delay_s": 0, + "downscale_delay_s": 0, # Fast downscale to test min_replicas floor + } + + deployment = create_autoscaling_task_consumer( + processor_config, autoscaling_config, signal, tracker + ) + self._deploy_and_wait(deployment, app_name) + + # Wait for initial scale to min_replicas + wait_for_condition( + lambda: check_num_replicas_gte("AutoscalingTaskConsumer", 2, app_name), + timeout=30, + ) + + # Enqueue a few tasks and let them complete + num_tasks = 5 + enqueue_tasks(processor_config, num_tasks) + self._release_and_wait_for_completion(signal, tracker, num_tasks) + + # Wait and verify replicas stay at min_replicas (2), not below + time.sleep(5) + + wait_for_condition( + lambda: check_num_replicas_eq("AutoscalingTaskConsumer", 2, app_name), + timeout=30, + ) + + if __name__ == "__main__": sys.exit(pytest.main(["-v", "-s", __file__])) diff --git a/python/ray/serve/tests/unit/test_queue_autoscaling_policy.py b/python/ray/serve/tests/unit/test_queue_autoscaling_policy.py new file mode 100644 index 000000000000..b02e2e5c1ed6 --- /dev/null +++ b/python/ray/serve/tests/unit/test_queue_autoscaling_policy.py @@ -0,0 +1,489 @@ +import sys +import time +from typing import Any, Dict, Optional +from unittest.mock import MagicMock, patch + +import pytest + +from ray.serve._private.constants import CONTROL_LOOP_INTERVAL_S +from ray.serve._private.queue_monitor import QUEUE_MONITOR_ACTOR_PREFIX +from ray.serve.autoscaling_policy import queue_based_autoscaling_policy +from ray.serve.config import AutoscalingConfig, AutoscalingContext + + +def create_autoscaling_context( + current_num_replicas: int = 1, + target_num_replicas: int = 1, + min_replicas: int = 0, + max_replicas: int = 10, + target_ongoing_requests: float = 10.0, + upscale_delay_s: float = 0.0, + downscale_delay_s: float = 0.0, + downscale_to_zero_delay_s: Optional[float] = None, + policy_state: 
Optional[Dict[str, Any]] = None, + current_time: Optional[float] = None, + deployment_name: str = "test_deployment", + app_name: str = "test_app", +) -> AutoscalingContext: + """Helper to create AutoscalingContext for tests.""" + config = AutoscalingConfig( + min_replicas=min_replicas, + max_replicas=max_replicas, + target_ongoing_requests=target_ongoing_requests, + upscale_delay_s=upscale_delay_s, + downscale_delay_s=downscale_delay_s, + downscale_to_zero_delay_s=downscale_to_zero_delay_s, + ) + + return AutoscalingContext( + config=config, + current_num_replicas=current_num_replicas, + target_num_replicas=target_num_replicas, + total_num_requests=0, + capacity_adjusted_min_replicas=min_replicas, + capacity_adjusted_max_replicas=max_replicas, + policy_state=policy_state or {}, + deployment_id=None, + deployment_name=deployment_name, + app_name=app_name, + running_replicas=None, + current_time=current_time or time.time(), + total_queued_requests=None, + total_running_requests=None, + aggregated_metrics={}, + raw_metrics={}, + last_scale_up_time=None, + last_scale_down_time=None, + ) + + +@pytest.fixture +def mock_ray_actor_methods(): + """Fixture to mock ray.get_actor and ray.get for QueueMonitor actor access.""" + with patch("ray.serve.autoscaling_policy.ray.get_actor") as mock_get_actor, patch( + "ray.serve.autoscaling_policy.ray.get" + ) as mock_ray_get: + yield mock_get_actor, mock_ray_get + + +def setup_queue_monitor_mocks( + mock_get_actor, + mock_ray_get, + queue_length, + deployment_name="test_deployment", + config_dict=None, +): + """Helper to set up all mocks for a successful queue monitor query.""" + queue_monitor_actor_name = f"{QUEUE_MONITOR_ACTOR_PREFIX}{deployment_name}" + + # Mock the actor handle + mock_actor = MagicMock() + mock_queue_length_ref = MagicMock() + mock_config_ref = MagicMock() + mock_actor.get_queue_length.remote.return_value = mock_queue_length_ref + mock_actor.get_config.remote.return_value = mock_config_ref + mock_get_actor.return_value = mock_actor + + # Mock ray.get to return the queue length or config based on the ref + if config_dict is None: + config_dict = {"broker_url": "redis://localhost", "queue_name": "test_queue"} + + def mock_ray_get_side_effect(ref, **kwargs): + if ref == mock_queue_length_ref: + return queue_length + elif ref == mock_config_ref: + return config_dict + return queue_length # fallback + + mock_ray_get.side_effect = mock_ray_get_side_effect + + return queue_monitor_actor_name + + +class TestQueueBasedAutoscalingPolicy: + """Tests for queue_based_autoscaling_policy function.""" + + def test_queue_monitor_unavailable_maintains_replicas(self, mock_ray_actor_methods): + """Test that unavailable QueueMonitor maintains current replica count.""" + mock_get_actor, mock_ray_get = mock_ray_actor_methods + # Mock actor not found + mock_get_actor.side_effect = ValueError("Actor not found") + + ctx = create_autoscaling_context( + current_num_replicas=5, + target_num_replicas=5, + ) + + new_replicas, _ = queue_based_autoscaling_policy(ctx) + assert new_replicas == 5 + + def test_queue_monitor_query_fails_maintains_replicas(self, mock_ray_actor_methods): + """Test that failed query maintains current replica count.""" + mock_get_actor, mock_ray_get = mock_ray_actor_methods + # Mock successful actor but failed ray.get + mock_actor = MagicMock() + mock_get_actor.return_value = mock_actor + mock_ray_get.side_effect = Exception("Query failed") + + ctx = create_autoscaling_context( + current_num_replicas=3, + target_num_replicas=3, + ) + + 
new_replicas, _ = queue_based_autoscaling_policy(ctx) + assert new_replicas == 3 + + @pytest.mark.parametrize( + "queue_length,target_per_replica,expected_replicas", + [ + (10, 10, 1), # 10/10 = 1 + (15, 10, 2), # 15/10 = 1.5 -> ceil = 2 + (100, 10, 10), # 100/10 = 10 + (5, 10, 1), # 5/10 = 0.5 -> ceil = 1 (clamped to min) + ], + ) + def test_basic_scaling_formula( + self, + mock_ray_actor_methods, + queue_length, + target_per_replica, + expected_replicas, + ): + """Test basic scaling formula: ceil(queue_length / target_per_replica).""" + mock_get_actor, mock_ray_get = mock_ray_actor_methods + setup_queue_monitor_mocks(mock_get_actor, mock_ray_get, queue_length) + + ctx = create_autoscaling_context( + current_num_replicas=5, + target_num_replicas=5, + target_ongoing_requests=target_per_replica, + min_replicas=0, + max_replicas=20, + upscale_delay_s=0, + downscale_delay_s=0, + ) + + new_replicas, _ = queue_based_autoscaling_policy(ctx) + assert new_replicas == expected_replicas + + def test_scale_down_to_one_before_zero(self, mock_ray_actor_methods): + """Test that scaling to zero goes through 1 first (policy enforces 1->0).""" + mock_get_actor, mock_ray_get = mock_ray_actor_methods + setup_queue_monitor_mocks(mock_get_actor, mock_ray_get, 0) + + ctx = create_autoscaling_context( + current_num_replicas=5, + target_num_replicas=5, + target_ongoing_requests=10, + min_replicas=0, + max_replicas=20, + upscale_delay_s=0, + downscale_delay_s=0, + ) + + new_replicas, _ = queue_based_autoscaling_policy(ctx) + # Policy enforces min of 1 for non-zero to zero transition + assert new_replicas == 1 + + def test_respects_max_replicas(self, mock_ray_actor_methods): + """Test that scaling respects max_replicas bound.""" + mock_get_actor, mock_ray_get = mock_ray_actor_methods + setup_queue_monitor_mocks(mock_get_actor, mock_ray_get, 1000) + + ctx = create_autoscaling_context( + current_num_replicas=5, + target_num_replicas=5, + target_ongoing_requests=10, + max_replicas=10, + upscale_delay_s=0, + ) + + new_replicas, _ = queue_based_autoscaling_policy(ctx) + assert new_replicas == 10 + + def test_respects_min_replicas(self, mock_ray_actor_methods): + """Test that scaling respects min_replicas bound.""" + mock_get_actor, mock_ray_get = mock_ray_actor_methods + setup_queue_monitor_mocks(mock_get_actor, mock_ray_get, 5) + + ctx = create_autoscaling_context( + current_num_replicas=5, + target_num_replicas=5, + target_ongoing_requests=10, + min_replicas=3, + upscale_delay_s=0, + downscale_delay_s=0, + ) + + new_replicas, _ = queue_based_autoscaling_policy(ctx) + assert new_replicas == 3 + + def test_scale_from_zero(self, mock_ray_actor_methods): + """Test scaling up from zero replicas.""" + mock_get_actor, mock_ray_get = mock_ray_actor_methods + setup_queue_monitor_mocks(mock_get_actor, mock_ray_get, 50) + + ctx = create_autoscaling_context( + current_num_replicas=0, + target_num_replicas=0, + target_ongoing_requests=10, + min_replicas=0, + max_replicas=10, + ) + + new_replicas, _ = queue_based_autoscaling_policy(ctx) + assert new_replicas == 5 # ceil(50/10) + + def test_scale_from_zero_with_one_task(self, mock_ray_actor_methods): + """Test scaling from zero with a single task.""" + mock_get_actor, mock_ray_get = mock_ray_actor_methods + setup_queue_monitor_mocks(mock_get_actor, mock_ray_get, 1) + + ctx = create_autoscaling_context( + current_num_replicas=0, + target_num_replicas=0, + target_ongoing_requests=10, + min_replicas=0, + max_replicas=10, + ) + + new_replicas, _ = queue_based_autoscaling_policy(ctx) + 
assert new_replicas == 1 + + def test_stay_at_zero_with_empty_queue(self, mock_ray_actor_methods): + """Test staying at zero replicas when queue is empty.""" + mock_get_actor, mock_ray_get = mock_ray_actor_methods + setup_queue_monitor_mocks(mock_get_actor, mock_ray_get, 0) + + ctx = create_autoscaling_context( + current_num_replicas=0, + target_num_replicas=0, + target_ongoing_requests=10, + min_replicas=0, + ) + + new_replicas, _ = queue_based_autoscaling_policy(ctx) + assert new_replicas == 0 + + def test_correct_queue_monitor_actor_name(self, mock_ray_actor_methods): + """Test that correct QueueMonitor actor name is used.""" + mock_get_actor, mock_ray_get = mock_ray_actor_methods + queue_monitor_actor_name = f"{QUEUE_MONITOR_ACTOR_PREFIX}my_task_consumer" + setup_queue_monitor_mocks( + mock_get_actor, mock_ray_get, 50, deployment_name="my_task_consumer" + ) + + ctx = create_autoscaling_context( + current_num_replicas=5, + target_num_replicas=5, + deployment_name="my_task_consumer", + app_name="my_app", + ) + + queue_based_autoscaling_policy(ctx) + + # Verify ray.get_actor was called with the correct actor name + mock_get_actor.assert_called_once_with( + queue_monitor_actor_name, + namespace="serve", + ) + + +class TestQueueBasedAutoscalingPolicyDelays: + """Tests for upscale and downscale delays in queue_based_autoscaling_policy.""" + + def test_upscale_delay(self, mock_ray_actor_methods): + """Test that upscale decisions require delay.""" + mock_get_actor, mock_ray_get = mock_ray_actor_methods + setup_queue_monitor_mocks(mock_get_actor, mock_ray_get, 200) + + upscale_delay_s = 30.0 + wait_periods = int(upscale_delay_s / CONTROL_LOOP_INTERVAL_S) + + policy_state = {} + ctx = create_autoscaling_context( + current_num_replicas=5, + target_num_replicas=5, + target_ongoing_requests=10, + max_replicas=20, + upscale_delay_s=upscale_delay_s, + policy_state=policy_state, + ) + + # First wait_periods calls should not scale + for i in range(wait_periods): + new_replicas, policy_state = queue_based_autoscaling_policy(ctx) + ctx = create_autoscaling_context( + current_num_replicas=5, + target_num_replicas=5, + target_ongoing_requests=10, + max_replicas=20, + upscale_delay_s=upscale_delay_s, + policy_state=policy_state, + ) + assert new_replicas == 5, f"Should not scale up at iteration {i}" + + # Next call should scale + new_replicas, _ = queue_based_autoscaling_policy(ctx) + assert new_replicas == 20 + + def test_downscale_delay(self, mock_ray_actor_methods): + """Test that downscale decisions require delay.""" + mock_get_actor, mock_ray_get = mock_ray_actor_methods + setup_queue_monitor_mocks(mock_get_actor, mock_ray_get, 20) + + downscale_delay_s = 60.0 + wait_periods = int(downscale_delay_s / CONTROL_LOOP_INTERVAL_S) + + policy_state = {} + ctx = create_autoscaling_context( + current_num_replicas=10, + target_num_replicas=10, + target_ongoing_requests=10, + min_replicas=1, + downscale_delay_s=downscale_delay_s, + policy_state=policy_state, + ) + + # First wait_periods calls should not scale down + for i in range(wait_periods): + new_replicas, policy_state = queue_based_autoscaling_policy(ctx) + ctx = create_autoscaling_context( + current_num_replicas=10, + target_num_replicas=10, + target_ongoing_requests=10, + min_replicas=1, + downscale_delay_s=downscale_delay_s, + policy_state=policy_state, + ) + assert new_replicas == 10, f"Should not scale down at iteration {i}" + + # Next call should scale + new_replicas, _ = queue_based_autoscaling_policy(ctx) + assert new_replicas == 2 + + +class 
TestQueueBasedAutoscalingPolicyState: + """Tests for policy state management.""" + + def test_stores_queue_length_in_state(self, mock_ray_actor_methods): + """Test that queue_length is stored in policy state.""" + mock_get_actor, mock_ray_get = mock_ray_actor_methods + setup_queue_monitor_mocks(mock_get_actor, mock_ray_get, 42) + + ctx = create_autoscaling_context( + current_num_replicas=5, + target_num_replicas=5, + upscale_delay_s=0, + downscale_delay_s=0, + ) + + _, policy_state = queue_based_autoscaling_policy(ctx) + assert policy_state.get("last_queue_length") == 42 + + def test_preserves_decision_counter(self, mock_ray_actor_methods): + """Test that decision counter is preserved across calls.""" + mock_get_actor, mock_ray_get = mock_ray_actor_methods + setup_queue_monitor_mocks(mock_get_actor, mock_ray_get, 200) + + policy_state = {} + ctx = create_autoscaling_context( + current_num_replicas=5, + target_num_replicas=5, + target_ongoing_requests=10, + upscale_delay_s=30.0, + policy_state=policy_state, + ) + + _, policy_state = queue_based_autoscaling_policy(ctx) + assert policy_state.get("decision_counter", 0) == 1 + + # Call again + ctx = create_autoscaling_context( + current_num_replicas=5, + target_num_replicas=5, + target_ongoing_requests=10, + upscale_delay_s=30.0, + policy_state=policy_state, + ) + _, policy_state = queue_based_autoscaling_policy(ctx) + assert policy_state.get("decision_counter", 0) == 2 + + +class TestQueueBasedAutoscalingPolicyActorRecovery: + """Tests for QueueMonitor actor recovery via policy_state.""" + + def test_stores_config_in_policy_state(self, mock_ray_actor_methods): + """Test that QueueMonitor config is stored in policy_state on first call.""" + mock_get_actor, mock_ray_get = mock_ray_actor_methods + config_dict = {"broker_url": "redis://myredis:6379", "queue_name": "myqueue"} + setup_queue_monitor_mocks( + mock_get_actor, mock_ray_get, 50, config_dict=config_dict + ) + + ctx = create_autoscaling_context( + current_num_replicas=5, + target_num_replicas=5, + upscale_delay_s=0, + downscale_delay_s=0, + ) + + _, policy_state = queue_based_autoscaling_policy(ctx) + + assert "queue_monitor_config" in policy_state + assert ( + policy_state["queue_monitor_config"]["broker_url"] == "redis://myredis:6379" + ) + assert policy_state["queue_monitor_config"]["queue_name"] == "myqueue" + + def test_recovers_actor_from_policy_state(self): + """Test that actor is recreated from policy_state when not found.""" + with patch( + "ray.serve.autoscaling_policy.ray.get_actor" + ) as mock_get_actor, patch( + "ray.serve.autoscaling_policy.ray.get" + ) as mock_ray_get, patch( + "ray.serve.autoscaling_policy.create_queue_monitor_actor" + ) as mock_create_actor: + + # First call: actor not found + mock_get_actor.side_effect = ValueError("Actor not found") + + # Mock the newly created actor + mock_new_actor = MagicMock() + mock_queue_length_ref = MagicMock() + mock_new_actor.get_queue_length.remote.return_value = mock_queue_length_ref + mock_create_actor.return_value = mock_new_actor + mock_ray_get.return_value = 100 # Queue length + + # Pass stored config in policy_state + stored_config = { + "broker_url": "redis://localhost:6379", + "queue_name": "tasks", + } + policy_state = {"queue_monitor_config": stored_config} + + ctx = create_autoscaling_context( + current_num_replicas=1, + target_num_replicas=1, + target_ongoing_requests=10, + max_replicas=20, + upscale_delay_s=0, + policy_state=policy_state, + ) + + new_replicas, _ = queue_based_autoscaling_policy(ctx) + + # Verify 
actor was recreated + mock_create_actor.assert_called_once() + call_kwargs = mock_create_actor.call_args.kwargs + assert call_kwargs["deployment_name"] == "test_deployment" + assert call_kwargs["config"].broker_url == "redis://localhost:6379" + assert call_kwargs["config"].queue_name == "tasks" + + # Verify scaling happened based on queue length + assert new_replicas == 10 # ceil(100/10) + + +if __name__ == "__main__": + sys.exit(pytest.main(["-v", "-s", __file__])) diff --git a/python/ray/serve/tests/unit/test_queue_monitor.py b/python/ray/serve/tests/unit/test_queue_monitor.py new file mode 100644 index 000000000000..e439b34767d6 --- /dev/null +++ b/python/ray/serve/tests/unit/test_queue_monitor.py @@ -0,0 +1,164 @@ +import sys +from unittest.mock import MagicMock, patch + +import pytest + +from ray.serve._private.queue_monitor import ( + QueueMonitor, + QueueMonitorConfig, +) + + +class TestQueueMonitorConfig: + """Tests for QueueMonitorConfig class.""" + + def test_redis_broker_type(self): + """Test Redis broker type detection.""" + config = QueueMonitorConfig( + broker_url="redis://localhost:6379/0", + queue_name="test_queue", + ) + assert config.broker_type == "redis" + + def test_redis_with_password_broker_type(self): + """Test Redis with password broker type detection.""" + config = QueueMonitorConfig( + broker_url="rediss://user:password@localhost:6379/0", + queue_name="test_queue", + ) + assert config.broker_type == "redis" + + def test_rabbitmq_amqp_broker_type(self): + """Test RabbitMQ AMQP broker type detection.""" + config = QueueMonitorConfig( + broker_url="amqp://guest:guest@localhost:5672//", + queue_name="test_queue", + ) + assert config.broker_type == "rabbitmq" + + def test_rabbitmq_pyamqp_broker_type(self): + """Test RabbitMQ pyamqp broker type detection.""" + config = QueueMonitorConfig( + broker_url="pyamqp://guest:guest@localhost:5672//", + queue_name="test_queue", + ) + assert config.broker_type == "rabbitmq" + + def test_sqs_broker_type(self): + """Test SQS broker type detection.""" + config = QueueMonitorConfig( + broker_url="sqs://...", + queue_name="test_queue", + ) + assert config.broker_type == "sqs" + + def test_unknown_broker_type(self): + """Test unknown broker type detection.""" + config = QueueMonitorConfig( + broker_url="some://unknown/broker", + queue_name="test_queue", + ) + assert config.broker_type == "unknown" + + def test_config_stores_values(self): + """Test config stores provided values.""" + config = QueueMonitorConfig( + broker_url="redis://localhost:6379", + queue_name="my_queue", + ) + assert config.broker_url == "redis://localhost:6379" + assert config.queue_name == "my_queue" + + +class TestQueueMonitor: + """Tests for QueueMonitor class.""" + + @pytest.fixture + def redis_config(self): + """Provides a Redis QueueMonitorConfig.""" + return QueueMonitorConfig( + broker_url="redis://localhost:6379/0", + queue_name="test_queue", + ) + + @pytest.fixture + def rabbitmq_config(self): + """Provides a RabbitMQ QueueMonitorConfig.""" + return QueueMonitorConfig( + broker_url="amqp://guest:guest@localhost:5672//", + queue_name="test_queue", + ) + + @patch("ray.serve._private.queue_monitor.redis") + def test_get_redis_queue_length(self, mock_redis_module, redis_config): + """Test Redis queue length retrieval.""" + mock_client = MagicMock() + mock_client.llen.return_value = 42 + mock_redis_module.from_url.return_value = mock_client + + monitor = QueueMonitor(redis_config) + monitor.initialize() + length = monitor.get_queue_length() + + assert length 
== 42 + mock_client.llen.assert_called_with("test_queue") + + @patch("ray.serve._private.queue_monitor.pika") + def test_get_rabbitmq_queue_length(self, mock_pika, rabbitmq_config): + """Test RabbitMQ queue length retrieval.""" + mock_params = MagicMock() + mock_pika.URLParameters.return_value = mock_params + + # Mock for initialization + mock_init_connection = MagicMock() + # Mock for queue length query + mock_query_connection = MagicMock() + mock_channel = MagicMock() + mock_result = MagicMock() + mock_result.method.message_count = 25 + + mock_query_connection.channel.return_value = mock_channel + mock_channel.queue_declare.return_value = mock_result + + # First call is for initialization, second is for query + mock_pika.BlockingConnection.side_effect = [ + mock_init_connection, + mock_query_connection, + ] + + monitor = QueueMonitor(rabbitmq_config) + monitor.initialize() + length = monitor.get_queue_length() + + assert length == 25 + mock_channel.queue_declare.assert_called_with( + queue="test_queue", + passive=True, + ) + + @patch("ray.serve._private.queue_monitor.redis") + def test_get_queue_length_returns_cached_on_error( + self, mock_redis_module, redis_config + ): + """Test get_queue_length returns cached value on error.""" + mock_client = MagicMock() + mock_client.llen.return_value = 50 + mock_redis_module.from_url.return_value = mock_client + + monitor = QueueMonitor(redis_config) + monitor.initialize() + + # First successful query + length = monitor.get_queue_length() + assert length == 50 + + # Now make queries fail + mock_client.llen.side_effect = Exception("Connection lost") + + # Should return cached value + length = monitor.get_queue_length() + assert length == 50 + + +if __name__ == "__main__": + sys.exit(pytest.main(["-v", "-s", __file__])) diff --git a/python/ray/serve/tests/unit/test_task_consumer.py b/python/ray/serve/tests/unit/test_task_consumer.py index 2068eb19924c..4a9b3cceb33a 100644 --- a/python/ray/serve/tests/unit/test_task_consumer.py +++ b/python/ray/serve/tests/unit/test_task_consumer.py @@ -300,6 +300,31 @@ def my_task(self): assert MyTaskConsumer.name == "MyTaskConsumer" +def test_queue_monitor_provides_queue_length(): + """Test that QueueMonitor provides queue length for autoscaling.""" + + queue_name = f"test_queue_{uuid.uuid4().hex}" + task_processor_config = TaskProcessorConfig( + queue_name=queue_name, + adapter_config=CeleryAdapterConfig( + broker_url="redis://localhost:6379/0", + backend_url="redis://localhost:6379/0", + ), + adapter=MockTaskProcessorAdapter, + ) + + @task_consumer(task_processor_config=task_processor_config) + class MyConsumer: + @task_handler + def process(self): + pass + + monitor_config = MyConsumer.get_queue_monitor_config() + + assert monitor_config.broker_url == "redis://localhost:6379/0" + assert monitor_config.queue_name == task_processor_config.queue_name + + def test_task_consumer_preserves_metadata(config): class OriginalConsumer: """Docstring for a task consumer.""" diff --git a/python/setup.py b/python/setup.py index 1f14f08ea669..318c4f95f2c9 100644 --- a/python/setup.py +++ b/python/setup.py @@ -315,6 +315,7 @@ def get_packages(self): setup_spec.extras["serve"] + [ "celery", + "pika", ] ) )